"""Utilities for collecting objects based on "is" comparison."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import collections
import weakref


# LINT.IfChange
class _ObjectIdentityWrapper:
    """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.

    Since __eq__ is based on object identity, it's safe to also define __hash__
    based on object ids. This lets us add unhashable types like trackable
    _ListWrapper objects to object-identity collections.
    """

    __slots__ = ["_wrapped", "__weakref__"]

    def __init__(self, wrapped):
        self._wrapped = wrapped

    @property
    def unwrapped(self):
        return self._wrapped

    def _assert_type(self, other):
        if not isinstance(other, _ObjectIdentityWrapper):
            raise TypeError(
                "Cannot compare wrapped object with unwrapped object. "
                f"Expect the object to be `_ObjectIdentityWrapper`. "
                f"Got: {other}"
            )

    def __lt__(self, other):
        self._assert_type(other)
        return id(self._wrapped) < id(other._wrapped)

    def __gt__(self, other):
        self._assert_type(other)
        return id(self._wrapped) > id(other._wrapped)

    def __eq__(self, other):
        if other is None:
            return False
        self._assert_type(other)
        return self._wrapped is other._wrapped

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Wrapper id() is also fine for weakrefs. In fact, we rely on
        # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
        # weakref.ref(a) in _WeakObjectIdentityWrapper.
        return id(self._wrapped)

    def __repr__(self):
        return "<{} wrapping {!r}>".format(type(self).__name__, self._wrapped)


class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):

    __slots__ = ()

    def __init__(self, wrapped):
        super().__init__(weakref.ref(wrapped))

    @property
    def unwrapped(self):
        return self._wrapped()


class Reference(_ObjectIdentityWrapper):
    """Reference that refers an object.

    ```python
    x = [1]
    y = [1]

    x_ref1 = Reference(x)
    x_ref2 = Reference(x)
    y_ref2 = Reference(y)

    print(x_ref1 == x_ref2)
    ==> True

    print(x_ref1 == y)
    ==> False
    ```
    """

    __slots__ = ()

    # Disabling super class' unwrapped field.
    unwrapped = property()

    def deref(self):
        """Returns the referenced object.

        ```python
        x_ref = Reference(x)
        print(x is x_ref.deref())
        ==> True
        ```
        """
        return self._wrapped


class ObjectIdentityDictionary(collections.abc.MutableMapping):
    """A mutable mapping data structure which compares using "is".

    This is necessary because we have trackable objects (_ListWrapper) which
    have behavior identical to built-in Python lists (including being unhashable
    and comparing based on the equality of their contents by default).
    """

    __slots__ = ["_storage"]

    def __init__(self):
        self._storage = {}

    def _wrap_key(self, key):
        return _ObjectIdentityWrapper(key)

    def __getitem__(self, key):
        return self._storage[self._wrap_key(key)]

    def __setitem__(self, key, value):
        self._storage[self._wrap_key(key)] = value

    def __delitem__(self, key):
        del self._storage[self._wrap_key(key)]

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        for key in self._storage:
            yield key.unwrapped

    def __repr__(self):
        return "ObjectIdentityDictionary(%s)" % repr(self._storage)


class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
    """Like weakref.WeakKeyDictionary, but compares objects with "is"."""

    __slots__ = ["__weakref__"]

    def _wrap_key(self, key):
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Iterate, discarding old weak refs
        return len(list(self._storage))

    def __iter__(self):
        keys = self._storage.keys()
        for key in keys:
            unwrapped = key.unwrapped
            if unwrapped is None:
                del self[key]
            else:
                yield unwrapped


class ObjectIdentitySet(collections.abc.MutableSet):
    """Like the built-in set, but compares objects with "is"."""

    __slots__ = ["_storage", "__weakref__"]

    def __init__(self, *args):
        self._storage = set(self._wrap_key(obj) for obj in list(*args))

    @staticmethod
    def _from_storage(storage):
        result = ObjectIdentitySet()
        result._storage = storage
        return result

    def _wrap_key(self, key):
        return _ObjectIdentityWrapper(key)

    def __contains__(self, key):
        return self._wrap_key(key) in self._storage

    def discard(self, key):
        self._storage.discard(self._wrap_key(key))

    def add(self, key):
        self._storage.add(self._wrap_key(key))

    def update(self, items):
        self._storage.update([self._wrap_key(item) for item in items])

    def clear(self):
        self._storage.clear()

    def intersection(self, items):
        return self._storage.intersection(
            [self._wrap_key(item) for item in items]
        )

    def difference(self, items):
        return ObjectIdentitySet._from_storage(
            self._storage.difference([self._wrap_key(item) for item in items])
        )

    def __len__(self):
        return len(self._storage)

    def __iter__(self):
        keys = list(self._storage)
        for key in keys:
            yield key.unwrapped


class ObjectIdentityWeakSet(ObjectIdentitySet):
    """Like weakref.WeakSet, but compares objects with "is"."""

    __slots__ = ()

    def _wrap_key(self, key):
        return _WeakObjectIdentityWrapper(key)

    def __len__(self):
        # Iterate, discarding old weak refs
        return len([_ for _ in self])

    def __iter__(self):
        keys = list(self._storage)
        for key in keys:
            unwrapped = key.unwrapped
            if unwrapped is None:
                self.discard(key)
            else:
                yield unwrapped


# LINT.ThenChange(//tensorflow/python/util/object_identity.py)
