Clean up distributed variable capturing code
PiperOrigin-RevId: 260624136
commit ace39d8e0a
parent 8f498f0ad3
tensorflow/python/framework/func_graph.py

@@ -559,17 +559,6 @@ class FuncGraph(ops.Graph):
     # makes sure that any tensor needed by a custom_gradient is correctly
     # captured.
 
-    # TODO(b/134097853): figure out a better way to check distributed variables
-    if hasattr(tensor, "_distribute_strategy") and hasattr(tensor, "_values"):
-      # This checks if the 'tensor' is a DistributedVariable. When it is a
-      # DistributedVariable, we do not want to check its "graph" attr as the
-      # following if branch does, because "graph" is not an attr for the
-      # container DistributedVariable object, and the underlying components may
-      # not have been initialized yet.
-      # The reason we do not use isinstance() is due to cyclic dependency issue.
-      if name is None:
-        name = str("distributed_variable")
-      return self._capture_helper(tensor, name)
     if (getattr(tensor, "graph", None) is not self and
         hasattr(self, "_forward_func_graph") and
         isinstance(self._forward_func_graph, FuncGraph)):
@@ -605,6 +594,12 @@ class FuncGraph(ops.Graph):
                           lambda x: [x])
     return captured_tensor
 
+  def capture_distributed_variable(self, variable, placeholder):
+    """Add given distributed variable to captures with given placeholder."""
+    self.captures[variable] = placeholder
+    tape.record_operation("captured_value", [placeholder], [variable],
+                          lambda x: [x])
+
   @property
   def external_captures(self):
     """External tensors captured by this function."""
tensorflow/python/saved_model/load.py

@@ -176,22 +176,26 @@ class Loader(object):
       if bound_inputs:
         for bound_input, internal_capture in zip(
             bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
-          concrete_function.graph.captures[bound_input] = internal_capture
-          if internal_capture.dtype == dtypes.resource:
-            if resource_variable_ops.is_resource_variable(bound_input):
-              try:
-                handle = bound_input.handle
-              except ValueError:
-                # For mirrored variables we'll copy handle data for components
-                # as they get captured.
-                pass
-              else:
-                custom_gradient.copy_handle_data(handle, internal_capture)
-            else:
-              custom_gradient.copy_handle_data(bound_input, internal_capture)
-          # Setting "captures" first means "capture" won't create a new
-          # placeholder for this input.
-          concrete_function.graph.capture(bound_input)
+          if ds_values.is_distributed_variable(bound_input):
+            concrete_function.graph.capture_distributed_variable(
+                bound_input, internal_capture)
+          else:
+            concrete_function.graph.captures[bound_input] = internal_capture
+            if internal_capture.dtype == dtypes.resource:
+              if resource_variable_ops.is_resource_variable(bound_input):
+                try:
+                  handle = bound_input.handle
+                except ValueError:
+                  # For mirrored variables we'll copy handle data for components
+                  # as they get captured.
+                  pass
+                else:
+                  custom_gradient.copy_handle_data(handle, internal_capture)
+              else:
+                custom_gradient.copy_handle_data(bound_input, internal_capture)
+            # Setting "captures" first means "capture" won't create a new
+            # placeholder for this input.
+            concrete_function.graph.capture(bound_input)
 
   def _get_tensor_from_node(self, node_id):
     """Resolves a node id into a tensor to be captured for a function."""
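
For context, a minimal end-to-end sketch (not part of this commit) of the scenario the change targets: a tf.function that captures a variable created under tf.distribute.MirroredStrategy is saved, and on tf.saved_model.load the Loader re-captures the distributed variable through capture_distributed_variable instead of falling through the generic resource path. The module name, method name, and /tmp path below are illustrative; it assumes a TF 2.x runtime where MirroredStrategy is available.

import tensorflow as tf


class Adder(tf.Module):

  def __init__(self):
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
      # Created under the strategy scope, so this is a distributed
      # (mirrored) variable rather than a plain ResourceVariable.
      self.v = tf.Variable(1.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def add(self, x):
    # Tracing this function captures the distributed variable.
    return x + self.v


module = Adder()
tf.saved_model.save(module, "/tmp/distributed_capture_demo")
restored = tf.saved_model.load("/tmp/distributed_capture_demo")
print(restored.add(tf.constant(2.0)))  # expected value: 3.0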