Clean up distributed variable capturing code
PiperOrigin-RevId: 260624136
parent 8f498f0ad3
commit ace39d8e0a
@@ -559,17 +559,6 @@ class FuncGraph(ops.Graph):
     # makes sure that any tensor needed by a custom_gradient is correctly
     # captured.
 
-    # TODO(b/134097853): figure out a better way to check distributed variables
-    if hasattr(tensor, "_distribute_strategy") and hasattr(tensor, "_values"):
-      # This checks if the 'tensor' is a DistributedVariable. When it is a
-      # DistributedVariable, we do not want to check its "graph" attr as the
-      # following if branch does, because "graph" is not an attr for the
-      # container DistributedVariable object, and the underlying components may
-      # not have been initialized yet.
-      # The reason we do not use isinstance() is due to cyclic dependency issue.
-      if name is None:
-        name = str("distributed_variable")
-      return self._capture_helper(tensor, name)
     if (getattr(tensor, "graph", None) is not self and
         hasattr(self, "_forward_func_graph") and
         isinstance(self._forward_func_graph, FuncGraph)):
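The branch removed above recognized a DistributedVariable by duck typing rather than isinstance(), since (per the removed comment) importing the class for an isinstance() check would create a cyclic dependency. A minimal, self-contained sketch of that check, with a hypothetical _FakeDistributedVariable stand-in rather than TensorFlow's real class:

# Sketch of the removed hasattr-based check. `_FakeDistributedVariable` and
# `_looks_like_distributed_variable` are hypothetical names for illustration.

def _looks_like_distributed_variable(obj):
  """Mirrors the removed duck-typing check for DistributedVariable."""
  return hasattr(obj, "_distribute_strategy") and hasattr(obj, "_values")


class _FakeDistributedVariable(object):
  _distribute_strategy = object()  # the two attributes the check probes for
  _values = ()


assert _looks_like_distributed_variable(_FakeDistributedVariable())
assert not _looks_like_distributed_variable(object())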
@@ -605,6 +594,12 @@ class FuncGraph(ops.Graph):
                           lambda x: [x])
     return captured_tensor
 
+  def capture_distributed_variable(self, variable, placeholder):
+    """Add given distributed variable to captures with given placeholder."""
+    self.captures[variable] = placeholder
+    tape.record_operation("captured_value", [placeholder], [variable],
+                          lambda x: [x])
+
   @property
   def external_captures(self):
     """External tensors captured by this function."""
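The new capture_distributed_variable method does two things: it records the variable-to-placeholder mapping in the graph's captures, and it records a "captured_value" operation on the tape with an identity backward function so gradients flow through the capture unchanged. A rough sketch of that bookkeeping pattern, with a hypothetical SketchGraph standing in for FuncGraph and a plain list standing in for tape.record_operation:

# Hypothetical sketch of the capture bookkeeping; not TensorFlow internals.

class SketchGraph(object):

  def __init__(self):
    self.captures = {}   # external value -> internal placeholder
    self.tape_log = []   # stand-in for the gradient tape

  def capture_distributed_variable(self, variable, placeholder):
    """Add the given variable to captures with the given placeholder."""
    self.captures[variable] = placeholder
    # The identity backward function (lambda x: [x] in the real code) means
    # gradients pass through the capture boundary unchanged.
    self.tape_log.append(("captured_value", [placeholder], [variable]))


graph = SketchGraph()
graph.capture_distributed_variable("dist_var", "placeholder_0")
assert graph.captures["dist_var"] == "placeholder_0"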
@@ -176,6 +176,10 @@ class Loader(object):
     if bound_inputs:
       for bound_input, internal_capture in zip(
           bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
-        concrete_function.graph.captures[bound_input] = internal_capture
-        if internal_capture.dtype == dtypes.resource:
-          if resource_variable_ops.is_resource_variable(bound_input):
+        if ds_values.is_distributed_variable(bound_input):
+          concrete_function.graph.capture_distributed_variable(
+              bound_input, internal_capture)
+        else:
+          concrete_function.graph.captures[bound_input] = internal_capture
+          if internal_capture.dtype == dtypes.resource:
+            if resource_variable_ops.is_resource_variable(bound_input):
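In the loader, each bound input is paired with one of the last len(bound_inputs) placeholders of the concrete function, and the new branch routes distributed variables through capture_distributed_variable while everything else keeps the existing resource-variable path. A simplified sketch of that dispatch, with the predicate and graph object passed in so it runs without TensorFlow (rebind_captures is a hypothetical name):

# Hypothetical sketch of the loader's per-input dispatch. `graph` is any
# object with a `captures` dict and a `capture_distributed_variable` method,
# such as SketchGraph above.

def rebind_captures(graph, bound_inputs, internal_captures,
                    is_distributed_variable):
  """Pairs each bound input with its placeholder, as in Loader above."""
  # internal_captures plays the role of
  # concrete_function.inputs[-len(bound_inputs):] in the real code.
  for bound_input, internal_capture in zip(bound_inputs, internal_captures):
    if is_distributed_variable(bound_input):
      # Distributed variables skip the resource-variable handling entirely.
      graph.capture_distributed_variable(bound_input, internal_capture)
    else:
      graph.captures[bound_input] = internal_capture

This keeps the distributed-variable special case out of FuncGraph.capture itself and moves the decision to the call sites that know what they are binding, which appears to be the cleanup the commit title refers to.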