Clean up distributed variable capturing code

PiperOrigin-RevId: 260624136
This commit is contained in:
Gaurav Jain 2019-07-29 18:30:45 -07:00 committed by TensorFlower Gardener
parent 8f498f0ad3
commit ace39d8e0a
2 changed files with 25 additions and 26 deletions

View File

@ -559,17 +559,6 @@ class FuncGraph(ops.Graph):
# makes sure that any tensor needed by a custom_gradient is correctly # makes sure that any tensor needed by a custom_gradient is correctly
# captured. # captured.
# TODO(b/134097853): figure out a better way to check distributed variables
if hasattr(tensor, "_distribute_strategy") and hasattr(tensor, "_values"):
# This checks if the 'tensor' is a DistributedVariable. When it is a
# DistributedVariable, we do not want to check its "graph" attr as the
# following if branch does, because "graph" is not an attr for the
# container DistributedVariable object, and the underlying components may
# not have been initialized yet.
# The reason we do not use isinstance() is due to cyclic dependency issue.
if name is None:
name = str("distributed_variable")
return self._capture_helper(tensor, name)
if (getattr(tensor, "graph", None) is not self and if (getattr(tensor, "graph", None) is not self and
hasattr(self, "_forward_func_graph") and hasattr(self, "_forward_func_graph") and
isinstance(self._forward_func_graph, FuncGraph)): isinstance(self._forward_func_graph, FuncGraph)):
@ -605,6 +594,12 @@ class FuncGraph(ops.Graph):
lambda x: [x]) lambda x: [x])
return captured_tensor return captured_tensor
def capture_distributed_variable(self, variable, placeholder):
"""Add given distributed variable to captures with given placeholder."""
self.captures[variable] = placeholder
tape.record_operation("captured_value", [placeholder], [variable],
lambda x: [x])
@property @property
def external_captures(self): def external_captures(self):
"""External tensors captured by this function.""" """External tensors captured by this function."""

View File

@ -176,6 +176,10 @@ class Loader(object):
if bound_inputs: if bound_inputs:
for bound_input, internal_capture in zip( for bound_input, internal_capture in zip(
bound_inputs, concrete_function.inputs[-len(bound_inputs):]): bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
if ds_values.is_distributed_variable(bound_input):
concrete_function.graph.capture_distributed_variable(
bound_input, internal_capture)
else:
concrete_function.graph.captures[bound_input] = internal_capture concrete_function.graph.captures[bound_input] = internal_capture
if internal_capture.dtype == dtypes.resource: if internal_capture.dtype == dtypes.resource:
if resource_variable_ops.is_resource_variable(bound_input): if resource_variable_ops.is_resource_variable(bound_input):