Add use of ObjectIdentitySet for tensor equality
PiperOrigin-RevId: 261232429
This commit is contained in:
parent
38c098a463
commit
166c742245
@ -36,6 +36,7 @@ from tensorflow.python.ops import variable_scope
|
|||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training.tracking import data_structures
|
from tensorflow.python.training.tracking import data_structures
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
from tensorflow.python.util import object_identity
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@ -150,7 +151,8 @@ def _lift_unlifted_variables(graph, variable_holder):
|
|||||||
ops.GraphKeys.GLOBAL_VARIABLES)
|
ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
local_collection_variables = ops.get_collection(
|
local_collection_variables = ops.get_collection(
|
||||||
ops.GraphKeys.LOCAL_VARIABLES)
|
ops.GraphKeys.LOCAL_VARIABLES)
|
||||||
existing_captures = set(graph.internal_captures)
|
existing_captures = object_identity.ObjectIdentitySet(
|
||||||
|
graph.internal_captures)
|
||||||
lifted_variables = {}
|
lifted_variables = {}
|
||||||
|
|
||||||
def _should_lift_variable(v):
|
def _should_lift_variable(v):
|
||||||
@ -249,7 +251,8 @@ class WrappedFunction(function.ConcreteFunction):
|
|||||||
|
|
||||||
# Ignoring all feeds that are captures allows prune to be called
|
# Ignoring all feeds that are captures allows prune to be called
|
||||||
# using wrapped_func.inputs even when it uses variables
|
# using wrapped_func.inputs even when it uses variables
|
||||||
internal_captures = self.graph.internal_captures
|
internal_captures = object_identity.ObjectIdentitySet(
|
||||||
|
self.graph.internal_captures)
|
||||||
flat_feeds = [f for f in flat_feeds if f not in internal_captures]
|
flat_feeds = [f for f in flat_feeds if f not in internal_captures]
|
||||||
|
|
||||||
operation_fetches = []
|
operation_fetches = []
|
||||||
@ -302,7 +305,7 @@ class WrappedFunction(function.ConcreteFunction):
|
|||||||
lift_map = lift_to_graph.lift_to_graph(
|
lift_map = lift_to_graph.lift_to_graph(
|
||||||
operation_fetches + tensor_fetches,
|
operation_fetches + tensor_fetches,
|
||||||
pruned_graph,
|
pruned_graph,
|
||||||
sources=flat_feeds + internal_captures)
|
sources=flat_feeds + self.graph.internal_captures)
|
||||||
|
|
||||||
# Note that we add the component tensors of any composite tensors to the
|
# Note that we add the component tensors of any composite tensors to the
|
||||||
# returned function's outputs list; the list must contain these component
|
# returned function's outputs list; the list must contain these component
|
||||||
|
Loading…
Reference in New Issue
Block a user