diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py index 929b7aeeec8..c6efc853b1b 100644 --- a/tensorflow/python/framework/convert_to_constants.py +++ b/tensorflow/python/framework/convert_to_constants.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.grappler import tf_optimizer from tensorflow.python.ops import array_ops +from tensorflow.python.util import object_identity from tensorflow.python.training.saver import export_meta_graph @@ -177,10 +178,12 @@ def _get_tensor_data(func): Dict """ tensor_data = {} - map_index_to_variable = { - func.captured_inputs.index(var.handle): var - for var in func.graph.variables - } + map_index_to_variable = {} + for var in func.graph.variables: + for idx, captured_input in enumerate(func.captured_inputs): + if var.handle is captured_input: # pylint: disable=protected-access + map_index_to_variable[idx] = var + break # Iterates through all captures which are represented as Placeholders. for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures): @@ -353,9 +356,10 @@ def _construct_concrete_function(func, output_graph_def, """ # Create a ConcreteFunction from the new GraphDef. input_tensors = func.graph.internal_captures - converted_inputs = set( + converted_inputs = object_identity.ObjectIdentitySet( [input_tensors[index] for index in converted_input_indices]) - not_converted_inputs = set(func.inputs).difference(converted_inputs) + not_converted_inputs = object_identity.ObjectIdentitySet( + func.inputs).difference(converted_inputs) not_converted_inputs_map = { tensor.name: tensor for tensor in not_converted_inputs } diff --git a/tensorflow/python/util/object_identity.py b/tensorflow/python/util/object_identity.py index ba134965752..2f913ddad87 100644 --- a/tensorflow/python/util/object_identity.py +++ b/tensorflow/python/util/object_identity.py @@ -156,6 +156,12 @@ class ObjectIdentitySet(collections_abc.MutableSet): def __init__(self, *args): self._storage = set([self._wrap_key(obj) for obj in list(*args)]) + @staticmethod + def _from_storage(storage): + result = ObjectIdentitySet() + result._storage = storage # pylint: disable=protected-access + return result + def _wrap_key(self, key): return _ObjectIdentityWrapper(key) @@ -174,6 +180,10 @@ class ObjectIdentitySet(collections_abc.MutableSet): def intersection(self, items): return self._storage.intersection([self._wrap_key(item) for item in items]) + def difference(self, items): + return ObjectIdentitySet._from_storage( + self._storage.difference([self._wrap_key(item) for item in items])) + def __len__(self): return len(self._storage) diff --git a/tensorflow/python/util/object_identity_test.py b/tensorflow/python/util/object_identity_test.py index 8290473be2d..5dc8be1a25d 100644 --- a/tensorflow/python/util/object_identity_test.py +++ b/tensorflow/python/util/object_identity_test.py @@ -30,5 +30,23 @@ class ObjectIdentityWrapperTest(test.TestCase): self.assertNotEqual(object_identity._ObjectIdentityWrapper(o), o) +class ObjectIdentitySetTest(test.TestCase): + + def testDifference(self): + + class Element(object): + pass + + a = Element() + b = Element() + c = Element() + set1 = object_identity.ObjectIdentitySet([a, b]) + set2 = object_identity.ObjectIdentitySet([b, c]) + diff_set = set1.difference(set2) + self.assertIn(a, diff_set) + self.assertNotIn(b, diff_set) + self.assertNotIn(c, diff_set) + + if __name__ == '__main__': test.main()