Add use of ObjectIdentitySet for tensor equality
PiperOrigin-RevId: 261232429
This commit is contained in:
parent
38c098a463
commit
166c742245
@ -36,6 +36,7 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import object_identity
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -150,7 +151,8 @@ def _lift_unlifted_variables(graph, variable_holder):
|
||||
ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
local_collection_variables = ops.get_collection(
|
||||
ops.GraphKeys.LOCAL_VARIABLES)
|
||||
existing_captures = set(graph.internal_captures)
|
||||
existing_captures = object_identity.ObjectIdentitySet(
|
||||
graph.internal_captures)
|
||||
lifted_variables = {}
|
||||
|
||||
def _should_lift_variable(v):
|
||||
@ -249,7 +251,8 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
|
||||
# Ignoring all feeds that are captures allows prune to be called
|
||||
# using wrapped_func.inputs even when it uses variables
|
||||
internal_captures = self.graph.internal_captures
|
||||
internal_captures = object_identity.ObjectIdentitySet(
|
||||
self.graph.internal_captures)
|
||||
flat_feeds = [f for f in flat_feeds if f not in internal_captures]
|
||||
|
||||
operation_fetches = []
|
||||
@ -302,7 +305,7 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
lift_map = lift_to_graph.lift_to_graph(
|
||||
operation_fetches + tensor_fetches,
|
||||
pruned_graph,
|
||||
sources=flat_feeds + internal_captures)
|
||||
sources=flat_feeds + self.graph.internal_captures)
|
||||
|
||||
# Note that we add the component tensors of any composite tensors to the
|
||||
# returned function's outputs list; the list must contain these component
|
||||
|
Loading…
Reference in New Issue
Block a user