From 7e297bab3f9cace20d03055db1f419933b190d35 Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Wed, 31 Jul 2019 14:23:51 -0700 Subject: [PATCH] Use ObjectIdentitySet to store tensors PiperOrigin-RevId: 260997853 --- tensorflow/python/ops/critical_section_ops.py | 3 ++- tensorflow/python/util/object_identity.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/critical_section_ops.py b/tensorflow/python/ops/critical_section_ops.py index 85d828cb40c..16419e45bda 100644 --- a/tensorflow/python/ops/critical_section_ops.py +++ b/tensorflow/python/ops/critical_section_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest +from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import tf_export @@ -243,7 +244,7 @@ class CriticalSection(object): # captured_resources is a list of resources that are directly # accessed only by ops created during fn(), not by any # ancestors of those ops in the graph. - captured_resources = set([ + captured_resources = object_identity.ObjectIdentitySet([ input_ for op in created_ops for input_ in op.inputs if input_.dtype == dtypes.resource diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py index c86bf89053f..4d756d4aef2 100644 --- a/tensorflow/python/util/object_identity.py +++ b/tensorflow/python/util/object_identity.py @@ -131,6 +131,9 @@ class ObjectIdentitySet(collections_abc.MutableSet): def update(self, items): self._storage.update([self._wrap_key(item) for item in items]) + def intersection(self, items): + return self._storage.intersection([self._wrap_key(item) for item in items]) + def __len__(self): return len(self._storage)