Use ObjectIdentitySet to store tensors
PiperOrigin-RevId: 260997853
This commit is contained in:
parent
ebca088d52
commit
7e297bab3f
@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -243,7 +244,7 @@ class CriticalSection(object):
|
||||
# captured_resources is a list of resources that are directly
|
||||
# accessed only by ops created during fn(), not by any
|
||||
# ancestors of those ops in the graph.
|
||||
captured_resources = set([
|
||||
captured_resources = object_identity.ObjectIdentitySet([
|
||||
input_ for op in created_ops
|
||||
for input_ in op.inputs
|
||||
if input_.dtype == dtypes.resource
|
||||
|
@ -131,6 +131,9 @@ class ObjectIdentitySet(collections_abc.MutableSet):
|
||||
def update(self, items):
|
||||
self._storage.update([self._wrap_key(item) for item in items])
|
||||
|
||||
def intersection(self, items):
|
||||
return self._storage.intersection([self._wrap_key(item) for item in items])
|
||||
|
||||
def __len__(self):
|
||||
return len(self._storage)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user