Clean up distributed variable capturing code

PiperOrigin-RevId: 260624136
Gaurav Jain 2019-07-29 18:30:45 -07:00 committed by TensorFlower Gardener
parent 8f498f0ad3
commit ace39d8e0a
2 changed files with 25 additions and 26 deletions


@@ -559,17 +559,6 @@ class FuncGraph(ops.Graph):
     # makes sure that any tensor needed by a custom_gradient is correctly
     # captured.

-    # TODO(b/134097853): figure out a better way to check distributed variables
-    if hasattr(tensor, "_distribute_strategy") and hasattr(tensor, "_values"):
-      # This checks if the 'tensor' is a DistributedVariable. When it is a
-      # DistributedVariable, we do not want to check its "graph" attr as the
-      # following if branch does, because "graph" is not an attr for the
-      # container DistributedVariable object, and the underlying components may
-      # not have been initialized yet.
-      # The reason we do not use isinstance() is due to cyclic dependency issue.
-      if name is None:
-        name = str("distributed_variable")
-      return self._capture_helper(tensor, name)
     if (getattr(tensor, "graph", None) is not self and
         hasattr(self, "_forward_func_graph") and
         isinstance(self._forward_func_graph, FuncGraph)):
@@ -605,6 +594,12 @@ class FuncGraph(ops.Graph):
                           lambda x: [x])
     return captured_tensor

+  def capture_distributed_variable(self, variable, placeholder):
+    """Add given distributed variable to captures with given placeholder."""
+    self.captures[variable] = placeholder
+    tape.record_operation("captured_value", [placeholder], [variable],
+                          lambda x: [x])
+
   @property
   def external_captures(self):
     """External tensors captured by this function."""


@@ -176,22 +176,26 @@ class Loader(object):
       if bound_inputs:
         for bound_input, internal_capture in zip(
             bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
-          concrete_function.graph.captures[bound_input] = internal_capture
-          if internal_capture.dtype == dtypes.resource:
-            if resource_variable_ops.is_resource_variable(bound_input):
-              try:
-                handle = bound_input.handle
-              except ValueError:
-                # For mirrored variables we'll copy handle data for components
-                # as they get captured.
-                pass
-              else:
-                custom_gradient.copy_handle_data(handle, internal_capture)
-            else:
-              custom_gradient.copy_handle_data(bound_input, internal_capture)
-          # Setting "captures" first means "capture" won't create a new
-          # placeholder for this input.
-          concrete_function.graph.capture(bound_input)
+          if ds_values.is_distributed_variable(bound_input):
+            concrete_function.graph.capture_distributed_variable(
+                bound_input, internal_capture)
+          else:
+            concrete_function.graph.captures[bound_input] = internal_capture
+            if internal_capture.dtype == dtypes.resource:
+              if resource_variable_ops.is_resource_variable(bound_input):
+                try:
+                  handle = bound_input.handle
+                except ValueError:
+                  # For mirrored variables we'll copy handle data for components
+                  # as they get captured.
+                  pass
+                else:
+                  custom_gradient.copy_handle_data(handle, internal_capture)
+              else:
+                custom_gradient.copy_handle_data(bound_input, internal_capture)
+            # Setting "captures" first means "capture" won't create a new
+            # placeholder for this input.
+            concrete_function.graph.capture(bound_input)

   def _get_tensor_from_node(self, node_id):
     """Resolves a node id into a tensor to be captured for a function."""