Use ObjectIdentitySet to store tensors

PiperOrigin-RevId: 260997853
This commit is contained in:
Gaurav Jain 2019-07-31 14:23:51 -07:00 committed by TensorFlower Gardener
parent ebca088d52
commit 7e297bab3f
2 changed files with 5 additions and 1 deletions

View File

@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
@ -243,7 +244,7 @@ class CriticalSection(object):
# captured_resources is a list of resources that are directly
# accessed only by ops created during fn(), not by any
# ancestors of those ops in the graph.
captured_resources = set([
captured_resources = object_identity.ObjectIdentitySet([
input_ for op in created_ops
for input_ in op.inputs
if input_.dtype == dtypes.resource

View File

@ -131,6 +131,9 @@ class ObjectIdentitySet(collections_abc.MutableSet):
def update(self, items):
self._storage.update([self._wrap_key(item) for item in items])
def intersection(self, items):
return self._storage.intersection([self._wrap_key(item) for item in items])
def __len__(self):
return len(self._storage)