Add use of ObjectIdentitySet for tensor equality

PiperOrigin-RevId: 261232429
This commit is contained in:
Gaurav Jain 2019-08-01 17:32:15 -07:00 committed by TensorFlower Gardener
parent 38c098a463
commit 166c742245

View File

@ -36,6 +36,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import data_structures from tensorflow.python.training.tracking import data_structures
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -150,7 +151,8 @@ def _lift_unlifted_variables(graph, variable_holder):
ops.GraphKeys.GLOBAL_VARIABLES) ops.GraphKeys.GLOBAL_VARIABLES)
local_collection_variables = ops.get_collection( local_collection_variables = ops.get_collection(
ops.GraphKeys.LOCAL_VARIABLES) ops.GraphKeys.LOCAL_VARIABLES)
existing_captures = set(graph.internal_captures) existing_captures = object_identity.ObjectIdentitySet(
graph.internal_captures)
lifted_variables = {} lifted_variables = {}
def _should_lift_variable(v): def _should_lift_variable(v):
@ -249,7 +251,8 @@ class WrappedFunction(function.ConcreteFunction):
# Ignoring all feeds that are captures allows prune to be called # Ignoring all feeds that are captures allows prune to be called
# using wrapped_func.inputs even when it uses variables # using wrapped_func.inputs even when it uses variables
internal_captures = self.graph.internal_captures internal_captures = object_identity.ObjectIdentitySet(
self.graph.internal_captures)
flat_feeds = [f for f in flat_feeds if f not in internal_captures] flat_feeds = [f for f in flat_feeds if f not in internal_captures]
operation_fetches = [] operation_fetches = []
@ -302,7 +305,7 @@ class WrappedFunction(function.ConcreteFunction):
lift_map = lift_to_graph.lift_to_graph( lift_map = lift_to_graph.lift_to_graph(
operation_fetches + tensor_fetches, operation_fetches + tensor_fetches,
pruned_graph, pruned_graph,
sources=flat_feeds + internal_captures) sources=flat_feeds + self.graph.internal_captures)
# Note that we add the component tensors of any composite tensors to the # Note that we add the component tensors of any composite tensors to the
# returned function's outputs list; the list must contain these component # returned function's outputs list; the list must contain these component