diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 093cfcdd045..6099964be5b 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -96,6 +96,7 @@ py_library( "//tensorflow/python/distribute:multi_worker_util", "//tensorflow/python/keras/engine:keras_tensor", "//tensorflow/python/keras/utils:control_flow_util", + "//tensorflow/python/keras/utils:object_identity", "//tensorflow/python/keras/utils:tf_contextlib", "//tensorflow/python/keras/utils:tf_inspect", ], diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index fbb5d9e31ab..e6b2b65b27b 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -55,6 +55,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.keras import backend_config from tensorflow.python.keras.engine import keras_tensor from tensorflow.python.keras.utils import control_flow_util +from tensorflow.python.keras.utils import object_identity from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.ops import array_ops @@ -84,7 +85,6 @@ from tensorflow.python.training.tracking import util as tracking_util from tensorflow.python.util import dispatch from tensorflow.python.util import keras_deps from tensorflow.python.util import nest -from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls diff --git a/tensorflow/python/keras/engine/BUILD b/tensorflow/python/keras/engine/BUILD index 7b830353a7f..e2605b98652 100644 --- a/tensorflow/python/keras/engine/BUILD +++ b/tensorflow/python/keras/engine/BUILD @@ -159,6 +159,7 @@ py_library( "//tensorflow/python/keras/saving", "//tensorflow/python/keras/utils:generic_utils", "//tensorflow/python/keras/utils:layer_utils", + "//tensorflow/python/keras/utils:object_identity", "//tensorflow/python/keras/utils:tf_utils", "//tensorflow/python/keras/utils:version_utils", "//tensorflow/python/module", @@ -206,6 +207,7 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:lib", "//tensorflow/python:tensor_spec", + "//tensorflow/python/keras/utils:object_identity", "@six_archive//:six", ], ) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index af04591a556..32b53f61d5b 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -62,6 +62,7 @@ from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.saving.saved_model import layer_serialization from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils import object_identity from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import version_utils @@ -82,7 +83,6 @@ from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import tracking from tensorflow.python.util import compat from tensorflow.python.util import nest -from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import get_canonical_name_for_symbol from tensorflow.python.util.tf_export import keras_export from tensorflow.tools.docs import doc_controls diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index 4e8859ca238..ddb9d53120c 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -52,6 +52,7 @@ from tensorflow.python.keras.mixed_precision import policy from tensorflow.python.keras.saving.saved_model import layer_serialization from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.keras.utils import object_identity from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_utils # A module that only depends on `keras.layers` import these from here. @@ -68,7 +69,6 @@ from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import tracking from tensorflow.python.util import nest -from tensorflow.python.util import object_identity from tensorflow.tools.docs import doc_controls diff --git a/tensorflow/python/keras/engine/keras_tensor.py b/tensorflow/python/keras/engine/keras_tensor.py index 88cde689b5b..fbefbe73fb1 100644 --- a/tensorflow/python/keras/engine/keras_tensor.py +++ b/tensorflow/python/keras/engine/keras_tensor.py @@ -25,11 +25,11 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import type_spec as type_spec_module +from tensorflow.python.keras.utils import object_identity from tensorflow.python.ops import array_ops from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.util import nest -from tensorflow.python.util import object_identity # pylint: disable=g-classes-have-attributes diff --git a/tensorflow/python/keras/utils/BUILD b/tensorflow/python/keras/utils/BUILD index 59fd7235869..65f03eb6c23 100644 --- a/tensorflow/python/keras/utils/BUILD +++ b/tensorflow/python/keras/utils/BUILD @@ -92,6 +92,7 @@ py_library( srcs = ["tf_utils.py"], srcs_version = "PY2AND3", deps = [ + ":object_identity", "//tensorflow/python:composite_tensor", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", @@ -210,6 +211,13 @@ py_library( ], ) +py_library( + name = "object_identity", + srcs = ["object_identity.py"], + srcs_version = "PY2AND3", + deps = [], +) + py_library( name = "tf_contextlib", srcs = ["tf_contextlib.py"], diff --git a/tensorflow/python/keras/utils/object_identity.py b/tensorflow/python/keras/utils/object_identity.py new file mode 100644 index 00000000000..d87f6aaeb53 --- /dev/null +++ b/tensorflow/python/keras/utils/object_identity.py @@ -0,0 +1,247 @@ +"""Utilities for collecting objects based on "is" comparison.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections.abc as collections_abc +import weakref + + +# LINT.IfChange +class _ObjectIdentityWrapper(object): + """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped. + + Since __eq__ is based on object identity, it's safe to also define __hash__ + based on object ids. This lets us add unhashable types like trackable + _ListWrapper objects to object-identity collections. + """ + + __slots__ = ["_wrapped", "__weakref__"] + + def __init__(self, wrapped): + self._wrapped = wrapped + + @property + def unwrapped(self): + return self._wrapped + + def _assert_type(self, other): + if not isinstance(other, _ObjectIdentityWrapper): + raise TypeError("Cannot compare wrapped object with unwrapped object") + + def __lt__(self, other): + self._assert_type(other) + return id(self._wrapped) < id(other._wrapped) # pylint: disable=protected-access + + def __gt__(self, other): + self._assert_type(other) + return id(self._wrapped) > id(other._wrapped) # pylint: disable=protected-access + + def __eq__(self, other): + if other is None: + return False + self._assert_type(other) + return self._wrapped is other._wrapped # pylint: disable=protected-access + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + # Wrapper id() is also fine for weakrefs. In fact, we rely on + # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is + # weakref.ref(a) in _WeakObjectIdentityWrapper. + return id(self._wrapped) + + def __repr__(self): + return "<{} wrapping {!r}>".format(type(self).__name__, self._wrapped) + + +class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper): + + __slots__ = () + + def __init__(self, wrapped): + super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped)) + + @property + def unwrapped(self): + return self._wrapped() + + +class Reference(_ObjectIdentityWrapper): + """Reference that refers an object. + + ```python + x = [1] + y = [1] + + x_ref1 = Reference(x) + x_ref2 = Reference(x) + y_ref2 = Reference(y) + + print(x_ref1 == x_ref2) + ==> True + + print(x_ref1 == y) + ==> False + ``` + """ + + __slots__ = () + + # Disabling super class' unwrapped field. + unwrapped = property() + + def deref(self): + """Returns the referenced object. + + ```python + x_ref = Reference(x) + print(x is x_ref.deref()) + ==> True + ``` + """ + return self._wrapped + + +class ObjectIdentityDictionary(collections_abc.MutableMapping): + """A mutable mapping data structure which compares using "is". + + This is necessary because we have trackable objects (_ListWrapper) which + have behavior identical to built-in Python lists (including being unhashable + and comparing based on the equality of their contents by default). + """ + + __slots__ = ["_storage"] + + def __init__(self): + self._storage = {} + + def _wrap_key(self, key): + return _ObjectIdentityWrapper(key) + + def __getitem__(self, key): + return self._storage[self._wrap_key(key)] + + def __setitem__(self, key, value): + self._storage[self._wrap_key(key)] = value + + def __delitem__(self, key): + del self._storage[self._wrap_key(key)] + + def __len__(self): + return len(self._storage) + + def __iter__(self): + for key in self._storage: + yield key.unwrapped + + def __repr__(self): + return "ObjectIdentityDictionary(%s)" % repr(self._storage) + + +class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary): + """Like weakref.WeakKeyDictionary, but compares objects with "is".""" + + __slots__ = ["__weakref__"] + + def _wrap_key(self, key): + return _WeakObjectIdentityWrapper(key) + + def __len__(self): + # Iterate, discarding old weak refs + return len(list(self._storage)) + + def __iter__(self): + keys = self._storage.keys() + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + del self[key] + else: + yield unwrapped + + +class ObjectIdentitySet(collections_abc.MutableSet): + """Like the built-in set, but compares objects with "is".""" + + __slots__ = ["_storage", "__weakref__"] + + def __init__(self, *args): + self._storage = set(self._wrap_key(obj) for obj in list(*args)) + + @staticmethod + def _from_storage(storage): + result = ObjectIdentitySet() + result._storage = storage # pylint: disable=protected-access + return result + + def _wrap_key(self, key): + return _ObjectIdentityWrapper(key) + + def __contains__(self, key): + return self._wrap_key(key) in self._storage + + def discard(self, key): + self._storage.discard(self._wrap_key(key)) + + def add(self, key): + self._storage.add(self._wrap_key(key)) + + def update(self, items): + self._storage.update([self._wrap_key(item) for item in items]) + + def clear(self): + self._storage.clear() + + def intersection(self, items): + return self._storage.intersection([self._wrap_key(item) for item in items]) + + def difference(self, items): + return ObjectIdentitySet._from_storage( + self._storage.difference([self._wrap_key(item) for item in items])) + + def __len__(self): + return len(self._storage) + + def __iter__(self): + keys = list(self._storage) + for key in keys: + yield key.unwrapped + + +class ObjectIdentityWeakSet(ObjectIdentitySet): + """Like weakref.WeakSet, but compares objects with "is".""" + + __slots__ = () + + def _wrap_key(self, key): + return _WeakObjectIdentityWrapper(key) + + def __len__(self): + # Iterate, discarding old weak refs + return len([_ for _ in self]) + + def __iter__(self): + keys = list(self._storage) + for key in keys: + unwrapped = key.unwrapped + if unwrapped is None: + self.discard(key) + else: + yield unwrapped +# LINT.ThenChange(//tensorflow/python/util/object_identity.py) diff --git a/tensorflow/python/keras/utils/tf_utils.py b/tensorflow/python/keras/utils/tf_utils.py index 82e0e3c7ea6..02974467b97 100644 --- a/tensorflow/python/keras/utils/tf_utils.py +++ b/tensorflow/python/keras/utils/tf_utils.py @@ -32,13 +32,13 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.framework import type_spec from tensorflow.python.keras import backend as K from tensorflow.python.keras.engine import keras_tensor +from tensorflow.python.keras.utils import object_identity from tensorflow.python.keras.utils import tf_contextlib from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.util import nest -from tensorflow.python.util import object_identity def is_tensor_or_tensor_list(v): diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py index b4f4c63bc96..d3a704e1415 100644 --- a/tensorflow/python/util/object_identity.py +++ b/tensorflow/python/util/object_identity.py @@ -22,6 +22,7 @@ import weakref from tensorflow.python.util.compat import collections_abc +# LINT.IfChange class _ObjectIdentityWrapper(object): """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped. @@ -244,3 +245,4 @@ class ObjectIdentityWeakSet(ObjectIdentitySet): self.discard(key) else: yield unwrapped +# LINT.ThenChange(//tensorflow/python/keras/utils/object_identity.py)