From 166c74224501a5443d6fe04f90b117e0a778a73b Mon Sep 17 00:00:00 2001 From: Gaurav Jain Date: Thu, 1 Aug 2019 17:32:15 -0700 Subject: [PATCH] Add use of ObjectIdentitySet for tensor equality PiperOrigin-RevId: 261232429 --- tensorflow/python/eager/wrap_function.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 269ec344b75..96c463ceecb 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.tracking import data_structures from tensorflow.python.util import nest +from tensorflow.python.util import object_identity from tensorflow.python.util.tf_export import tf_export @@ -150,7 +151,8 @@ def _lift_unlifted_variables(graph, variable_holder): ops.GraphKeys.GLOBAL_VARIABLES) local_collection_variables = ops.get_collection( ops.GraphKeys.LOCAL_VARIABLES) - existing_captures = set(graph.internal_captures) + existing_captures = object_identity.ObjectIdentitySet( + graph.internal_captures) lifted_variables = {} def _should_lift_variable(v): @@ -249,7 +251,8 @@ class WrappedFunction(function.ConcreteFunction): # Ignoring all feeds that are captures allows prune to be called # using wrapped_func.inputs even when it uses variables - internal_captures = self.graph.internal_captures + internal_captures = object_identity.ObjectIdentitySet( + self.graph.internal_captures) flat_feeds = [f for f in flat_feeds if f not in internal_captures] operation_fetches = [] @@ -302,7 +305,7 @@ class WrappedFunction(function.ConcreteFunction): lift_map = lift_to_graph.lift_to_graph( operation_fetches + tensor_fetches, pruned_graph, - sources=flat_feeds + internal_captures) + sources=flat_feeds + self.graph.internal_captures) # Note that we add the component tensors of any composite tensors to the # returned function's outputs list; the list must contain these component