Fork the object_identity related class to Keras.
PiperOrigin-RevId: 339948109 Change-Id: Ia2612efa66b48ac5784d2a53cdeb0ce9001f4025
This commit is contained in:
parent
662c0c2ece
commit
f4280211b2
@ -96,6 +96,7 @@ py_library(
|
|||||||
"//tensorflow/python/distribute:multi_worker_util",
|
"//tensorflow/python/distribute:multi_worker_util",
|
||||||
"//tensorflow/python/keras/engine:keras_tensor",
|
"//tensorflow/python/keras/engine:keras_tensor",
|
||||||
"//tensorflow/python/keras/utils:control_flow_util",
|
"//tensorflow/python/keras/utils:control_flow_util",
|
||||||
|
"//tensorflow/python/keras/utils:object_identity",
|
||||||
"//tensorflow/python/keras/utils:tf_contextlib",
|
"//tensorflow/python/keras/utils:tf_contextlib",
|
||||||
"//tensorflow/python/keras/utils:tf_inspect",
|
"//tensorflow/python/keras/utils:tf_inspect",
|
||||||
],
|
],
|
||||||
|
@ -55,6 +55,7 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.keras import backend_config
|
from tensorflow.python.keras import backend_config
|
||||||
from tensorflow.python.keras.engine import keras_tensor
|
from tensorflow.python.keras.engine import keras_tensor
|
||||||
from tensorflow.python.keras.utils import control_flow_util
|
from tensorflow.python.keras.utils import control_flow_util
|
||||||
|
from tensorflow.python.keras.utils import object_identity
|
||||||
from tensorflow.python.keras.utils import tf_contextlib
|
from tensorflow.python.keras.utils import tf_contextlib
|
||||||
from tensorflow.python.keras.utils import tf_inspect
|
from tensorflow.python.keras.utils import tf_inspect
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -84,7 +85,6 @@ from tensorflow.python.training.tracking import util as tracking_util
|
|||||||
from tensorflow.python.util import dispatch
|
from tensorflow.python.util import dispatch
|
||||||
from tensorflow.python.util import keras_deps
|
from tensorflow.python.util import keras_deps
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
from tensorflow.tools.docs import doc_controls
|
from tensorflow.tools.docs import doc_controls
|
||||||
|
|
||||||
|
@ -159,6 +159,7 @@ py_library(
|
|||||||
"//tensorflow/python/keras/saving",
|
"//tensorflow/python/keras/saving",
|
||||||
"//tensorflow/python/keras/utils:generic_utils",
|
"//tensorflow/python/keras/utils:generic_utils",
|
||||||
"//tensorflow/python/keras/utils:layer_utils",
|
"//tensorflow/python/keras/utils:layer_utils",
|
||||||
|
"//tensorflow/python/keras/utils:object_identity",
|
||||||
"//tensorflow/python/keras/utils:tf_utils",
|
"//tensorflow/python/keras/utils:tf_utils",
|
||||||
"//tensorflow/python/keras/utils:version_utils",
|
"//tensorflow/python/keras/utils:version_utils",
|
||||||
"//tensorflow/python/module",
|
"//tensorflow/python/module",
|
||||||
@ -206,6 +207,7 @@ py_library(
|
|||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
"//tensorflow/python:tensor_spec",
|
"//tensorflow/python:tensor_spec",
|
||||||
|
"//tensorflow/python/keras/utils:object_identity",
|
||||||
"@six_archive//:six",
|
"@six_archive//:six",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -62,6 +62,7 @@ from tensorflow.python.keras.mixed_precision import policy
|
|||||||
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
|
from tensorflow.python.keras.utils import object_identity
|
||||||
from tensorflow.python.keras.utils import tf_inspect
|
from tensorflow.python.keras.utils import tf_inspect
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.keras.utils import version_utils
|
from tensorflow.python.keras.utils import version_utils
|
||||||
@ -82,7 +83,6 @@ from tensorflow.python.training.tracking import data_structures
|
|||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
|
||||||
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
from tensorflow.python.util.tf_export import get_canonical_name_for_symbol
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
from tensorflow.tools.docs import doc_controls
|
from tensorflow.tools.docs import doc_controls
|
||||||
|
@ -52,6 +52,7 @@ from tensorflow.python.keras.mixed_precision import policy
|
|||||||
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
from tensorflow.python.keras.saving.saved_model import layer_serialization
|
||||||
from tensorflow.python.keras.utils import generic_utils
|
from tensorflow.python.keras.utils import generic_utils
|
||||||
from tensorflow.python.keras.utils import layer_utils
|
from tensorflow.python.keras.utils import layer_utils
|
||||||
|
from tensorflow.python.keras.utils import object_identity
|
||||||
from tensorflow.python.keras.utils import tf_inspect
|
from tensorflow.python.keras.utils import tf_inspect
|
||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
# A module that only depends on `keras.layers` import these from here.
|
# A module that only depends on `keras.layers` import these from here.
|
||||||
@ -68,7 +69,6 @@ from tensorflow.python.training.tracking import base as trackable
|
|||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
|
||||||
from tensorflow.tools.docs import doc_controls
|
from tensorflow.tools.docs import doc_controls
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,11 +25,11 @@ from tensorflow.python.framework import sparse_tensor
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import type_spec as type_spec_module
|
from tensorflow.python.framework import type_spec as type_spec_module
|
||||||
|
from tensorflow.python.keras.utils import object_identity
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import
|
from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
|
||||||
|
|
||||||
# pylint: disable=g-classes-have-attributes
|
# pylint: disable=g-classes-have-attributes
|
||||||
|
|
||||||
|
@ -92,6 +92,7 @@ py_library(
|
|||||||
srcs = ["tf_utils.py"],
|
srcs = ["tf_utils.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":object_identity",
|
||||||
"//tensorflow/python:composite_tensor",
|
"//tensorflow/python:composite_tensor",
|
||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
@ -210,6 +211,13 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "object_identity",
|
||||||
|
srcs = ["object_identity.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "tf_contextlib",
|
name = "tf_contextlib",
|
||||||
srcs = ["tf_contextlib.py"],
|
srcs = ["tf_contextlib.py"],
|
||||||
|
247
tensorflow/python/keras/utils/object_identity.py
Normal file
247
tensorflow/python/keras/utils/object_identity.py
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
"""Utilities for collecting objects based on "is" comparison."""
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections.abc as collections_abc
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
|
||||||
|
# LINT.IfChange
|
||||||
|
class _ObjectIdentityWrapper(object):
|
||||||
|
"""Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
|
||||||
|
|
||||||
|
Since __eq__ is based on object identity, it's safe to also define __hash__
|
||||||
|
based on object ids. This lets us add unhashable types like trackable
|
||||||
|
_ListWrapper objects to object-identity collections.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ["_wrapped", "__weakref__"]
|
||||||
|
|
||||||
|
def __init__(self, wrapped):
|
||||||
|
self._wrapped = wrapped
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unwrapped(self):
|
||||||
|
return self._wrapped
|
||||||
|
|
||||||
|
def _assert_type(self, other):
|
||||||
|
if not isinstance(other, _ObjectIdentityWrapper):
|
||||||
|
raise TypeError("Cannot compare wrapped object with unwrapped object")
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
self._assert_type(other)
|
||||||
|
return id(self._wrapped) < id(other._wrapped) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
self._assert_type(other)
|
||||||
|
return id(self._wrapped) > id(other._wrapped) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if other is None:
|
||||||
|
return False
|
||||||
|
self._assert_type(other)
|
||||||
|
return self._wrapped is other._wrapped # pylint: disable=protected-access
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
# Wrapper id() is also fine for weakrefs. In fact, we rely on
|
||||||
|
# id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
|
||||||
|
# weakref.ref(a) in _WeakObjectIdentityWrapper.
|
||||||
|
return id(self._wrapped)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<{} wrapping {!r}>".format(type(self).__name__, self._wrapped)
|
||||||
|
|
||||||
|
|
||||||
|
class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def __init__(self, wrapped):
|
||||||
|
super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unwrapped(self):
|
||||||
|
return self._wrapped()
|
||||||
|
|
||||||
|
|
||||||
|
class Reference(_ObjectIdentityWrapper):
|
||||||
|
"""Reference that refers an object.
|
||||||
|
|
||||||
|
```python
|
||||||
|
x = [1]
|
||||||
|
y = [1]
|
||||||
|
|
||||||
|
x_ref1 = Reference(x)
|
||||||
|
x_ref2 = Reference(x)
|
||||||
|
y_ref2 = Reference(y)
|
||||||
|
|
||||||
|
print(x_ref1 == x_ref2)
|
||||||
|
==> True
|
||||||
|
|
||||||
|
print(x_ref1 == y)
|
||||||
|
==> False
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
# Disabling super class' unwrapped field.
|
||||||
|
unwrapped = property()
|
||||||
|
|
||||||
|
def deref(self):
|
||||||
|
"""Returns the referenced object.
|
||||||
|
|
||||||
|
```python
|
||||||
|
x_ref = Reference(x)
|
||||||
|
print(x is x_ref.deref())
|
||||||
|
==> True
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
return self._wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectIdentityDictionary(collections_abc.MutableMapping):
|
||||||
|
"""A mutable mapping data structure which compares using "is".
|
||||||
|
|
||||||
|
This is necessary because we have trackable objects (_ListWrapper) which
|
||||||
|
have behavior identical to built-in Python lists (including being unhashable
|
||||||
|
and comparing based on the equality of their contents by default).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ["_storage"]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._storage = {}
|
||||||
|
|
||||||
|
def _wrap_key(self, key):
|
||||||
|
return _ObjectIdentityWrapper(key)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._storage[self._wrap_key(key)]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
self._storage[self._wrap_key(key)] = value
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del self._storage[self._wrap_key(key)]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._storage)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for key in self._storage:
|
||||||
|
yield key.unwrapped
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "ObjectIdentityDictionary(%s)" % repr(self._storage)
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
|
||||||
|
"""Like weakref.WeakKeyDictionary, but compares objects with "is"."""
|
||||||
|
|
||||||
|
__slots__ = ["__weakref__"]
|
||||||
|
|
||||||
|
def _wrap_key(self, key):
|
||||||
|
return _WeakObjectIdentityWrapper(key)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# Iterate, discarding old weak refs
|
||||||
|
return len(list(self._storage))
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
keys = self._storage.keys()
|
||||||
|
for key in keys:
|
||||||
|
unwrapped = key.unwrapped
|
||||||
|
if unwrapped is None:
|
||||||
|
del self[key]
|
||||||
|
else:
|
||||||
|
yield unwrapped
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectIdentitySet(collections_abc.MutableSet):
|
||||||
|
"""Like the built-in set, but compares objects with "is"."""
|
||||||
|
|
||||||
|
__slots__ = ["_storage", "__weakref__"]
|
||||||
|
|
||||||
|
def __init__(self, *args):
|
||||||
|
self._storage = set(self._wrap_key(obj) for obj in list(*args))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _from_storage(storage):
|
||||||
|
result = ObjectIdentitySet()
|
||||||
|
result._storage = storage # pylint: disable=protected-access
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _wrap_key(self, key):
|
||||||
|
return _ObjectIdentityWrapper(key)
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return self._wrap_key(key) in self._storage
|
||||||
|
|
||||||
|
def discard(self, key):
|
||||||
|
self._storage.discard(self._wrap_key(key))
|
||||||
|
|
||||||
|
def add(self, key):
|
||||||
|
self._storage.add(self._wrap_key(key))
|
||||||
|
|
||||||
|
def update(self, items):
|
||||||
|
self._storage.update([self._wrap_key(item) for item in items])
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self._storage.clear()
|
||||||
|
|
||||||
|
def intersection(self, items):
|
||||||
|
return self._storage.intersection([self._wrap_key(item) for item in items])
|
||||||
|
|
||||||
|
def difference(self, items):
|
||||||
|
return ObjectIdentitySet._from_storage(
|
||||||
|
self._storage.difference([self._wrap_key(item) for item in items]))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._storage)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
keys = list(self._storage)
|
||||||
|
for key in keys:
|
||||||
|
yield key.unwrapped
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectIdentityWeakSet(ObjectIdentitySet):
|
||||||
|
"""Like weakref.WeakSet, but compares objects with "is"."""
|
||||||
|
|
||||||
|
__slots__ = ()
|
||||||
|
|
||||||
|
def _wrap_key(self, key):
|
||||||
|
return _WeakObjectIdentityWrapper(key)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# Iterate, discarding old weak refs
|
||||||
|
return len([_ for _ in self])
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
keys = list(self._storage)
|
||||||
|
for key in keys:
|
||||||
|
unwrapped = key.unwrapped
|
||||||
|
if unwrapped is None:
|
||||||
|
self.discard(key)
|
||||||
|
else:
|
||||||
|
yield unwrapped
|
||||||
|
# LINT.ThenChange(//tensorflow/python/util/object_identity.py)
|
@ -32,13 +32,13 @@ from tensorflow.python.framework import tensor_util
|
|||||||
from tensorflow.python.framework import type_spec
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.keras import backend as K
|
from tensorflow.python.keras import backend as K
|
||||||
from tensorflow.python.keras.engine import keras_tensor
|
from tensorflow.python.keras.engine import keras_tensor
|
||||||
|
from tensorflow.python.keras.utils import object_identity
|
||||||
from tensorflow.python.keras.utils import tf_contextlib
|
from tensorflow.python.keras.utils import tf_contextlib
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import object_identity
|
|
||||||
|
|
||||||
|
|
||||||
def is_tensor_or_tensor_list(v):
|
def is_tensor_or_tensor_list(v):
|
||||||
|
@ -22,6 +22,7 @@ import weakref
|
|||||||
from tensorflow.python.util.compat import collections_abc
|
from tensorflow.python.util.compat import collections_abc
|
||||||
|
|
||||||
|
|
||||||
|
# LINT.IfChange
|
||||||
class _ObjectIdentityWrapper(object):
|
class _ObjectIdentityWrapper(object):
|
||||||
"""Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
|
"""Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
|
||||||
|
|
||||||
@ -244,3 +245,4 @@ class ObjectIdentityWeakSet(ObjectIdentitySet):
|
|||||||
self.discard(key)
|
self.discard(key)
|
||||||
else:
|
else:
|
||||||
yield unwrapped
|
yield unwrapped
|
||||||
|
# LINT.ThenChange(//tensorflow/python/keras/utils/object_identity.py)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user