Do not wrap Resource tensors inside optionals in cond_v2.
This makes cond_v2 consistent with while_v2 and avoids the issues that arise when resources wrapped in optionals are copied across devices.

PiperOrigin-RevId: 241830574
commit 34fc260230 (parent cc686693f3)
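To make the effect concrete, here is a minimal sketch of the kind of program this code path serves: a gradient through a tf.cond whose branches read a resource variable. It is not part of the commit; it assumes a current TF 2.x tf.function / tf.GradientTape API, and the variable names are chosen purely for illustration. With this change the gradient function captures the resource via the corresponding forward-graph input instead of wrapping the handle in an optional.

import tensorflow as tf

v = tf.Variable(2.0)  # its handle is a DT_RESOURCE tensor read by both branches
x = tf.Variable(3.0)

@tf.function
def f():
  # tf.cond lowers to an If op (cond_v2) inside the function graph.
  return tf.cond(x > 0.0, lambda: v * x, lambda: v - x)

with tf.GradientTape() as tape:
  y = f()

# Gradients flow through the taken branch: dy/dx == v and dy/dv == x here.
print(tape.gradient(y, [x, v]))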
@@ -652,31 +652,44 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
     if captured_tensor is not None:
       return captured_tensor
 
-    # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
-    # an optional in the forward graph and capture the optional normally. We
-    # then unwrap the captured optional value in the gradient graph to get the
-    # raw intermediate value.
-    if tensor not in self._wrapped_intermediates:
-      # If the gradient has already been computed for this If op, 'tensor' may
-      # already be wrapped.
-      for consumer in tensor.consumers():
-        if (consumer.type == "OptionalFromValue"
-            and consumer.outputs[0] in self._forward_graph.outputs):
-          optional = consumer.outputs[0]
-          break
-      else:
-        # 'tensor' hasn't been wrapped, do it now.
-        with self._forward_graph.as_default():
-          optional = gen_dataset_ops.optional_from_value([tensor])
-        self.if_op_needs_rewrite = True
-      self._wrapped_intermediates[tensor] = optional
-
-    optional = self._wrapped_intermediates[tensor]
-    captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
-        optional, name)
-    captured_tensor = gen_dataset_ops.optional_get_value(
-        captured_optional, [tensor.dtype], [tensor.shape])[0]
+    # 'tensor' is an uncaptured intermediate in the forward graph.
+    # If it is not a resource, we wrap it in an optional in the forward graph
+    # and capture the optional normally. We then unwrap the captured optional
+    # value in the gradient graph to get the raw intermediate value.
+    # If it is a resource, we trace the resource up to the input in the forward
+    # graph and capture that.
+
+    if tensor.dtype == dtypes.resource:
+      # Index of the forward graph input corresponding to the resource tensor.
+      index = util.resource_input_index(
+          tensor.name, [t.name for t in self._forward_graph.inputs],
+          {op.name: op.node_def for op in self._forward_graph.get_operations()},
+          self._forward_graph._functions)
+      # This gets mapped to the corresponding If op input in
+      # `_resolve_grad_inputs`.
+      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
+          self._forward_graph.inputs[index], name)
+    else:
+      if tensor not in self._wrapped_intermediates:
+        # If the gradient has already been computed for this If op, 'tensor' may
+        # already be wrapped.
+        for consumer in tensor.consumers():
+          if (consumer.type == "OptionalFromValue" and
+              consumer.outputs[0] in self._forward_graph.outputs):
+            optional = consumer.outputs[0]
+            break
+        else:
+          # 'tensor' hasn't been wrapped, do it now.
+          with self._forward_graph.as_default():
+            optional = gen_dataset_ops.optional_from_value([tensor])
+          self.if_op_needs_rewrite = True
+        self._wrapped_intermediates[tensor] = optional
+
+      optional = self._wrapped_intermediates[tensor]
+      captured_optional = super(_CondGradFuncGraph,
+                                self)._capture_helper(optional, name)
+      captured_tensor = gen_dataset_ops.optional_get_value(
+          captured_optional, [tensor.dtype], [tensor.shape])[0]
+
     self._indirect_captures[tensor] = captured_tensor
     return captured_tensor
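For reference, the wrap/unwrap path that the non-resource branch above still uses can be exercised directly through the public raw-op aliases of gen_dataset_ops.optional_from_value / optional_get_value. This is a hedged illustration only (assuming TF 2.x eager execution and the tf.raw_ops names), not code from the commit.

import tensorflow as tf

t = tf.constant([1.0, 2.0])

# Wrap an intermediate in an Optional (what the forward graph does)...
opt = tf.raw_ops.OptionalFromValue(components=[t])

# ...and unwrap it to recover the raw value (what the gradient graph does).
val = tf.raw_ops.OptionalGetValue(
    optional=opt, output_types=[t.dtype], output_shapes=[t.shape])[0]

print(val.numpy())  # [1. 2.]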
@@ -138,3 +138,70 @@ def maybe_propagate_compile_time_consts_in_xla(op):
     op._set_attr("_xla_propagate_compile_time_consts",
                  attr_value_pb2.AttrValue(b=True))
   # pylint: enable=protected-access
+
+
+def resource_input_index(tensor_name, input_names, node_defs, functions):
+  """Returns the index of the input corresponding to `tensor_name`.
+
+  This method is used to find the corresponding index of an arbitrary resource
+  tensor in a function (the function could be a loop body). We assume that
+  resource handles are never created in functions, so that every resource
+  tensor can be traced back to a function input.
+
+  The awkward signature of this method is to make it work with both FuncGraphs
+  and FunctionDefs. This is so we can recurse on function call ops without
+  building the corresponding FuncGraph (note that even if a FuncGraph for a
+  FunctionDef already exists, the input/output/node names may have been
+  changed when the FuncGraph was serialized to the FunctionDef, which makes it
+  unusable with this algorithm).
+
+  Args:
+    tensor_name: the name of the resource tensor to be resolved to an input.
+    input_names: a list of the names of all inputs to the function.
+    node_defs: a dict mapping op name -> NodeDef for every op in the function.
+    functions: a dict mapping function name -> _EagerDefinedFunction.
+
+  Returns:
+    The index into input_names corresponding to `tensor_name`.
+  """
+  while tensor_name not in input_names:
+    # FunctionDefs and graphs use different tensor naming conventions.
+    parts = tensor_name.split(":")
+    if len(parts) == 3:
+      op_name, _, output_idx = parts
+    elif len(parts) == 2:
+      op_name, output_idx = parts
+    else:
+      assert len(parts) == 1
+      op_name = parts[0]
+      output_idx = 0
+    output_idx = int(output_idx)
+    node_def = node_defs[op_name]
+
+    if node_def.op == "While":
+      # Captured resources occur at the same index in the lists of inputs and
+      # outputs of a while op. So we look up the input of `tensor.op` at the
+      # same index as the index of `tensor` in `tensor.op.outputs`.
+      tensor_name = node_def.input[output_idx]
+    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
+      # Functions output any captured resource tensors used by their
+      # gradients. `tensor_name` is one of these outputs from a nested
+      # function call, so recursively find the corresponding input in the
+      # nested FunctionDef.
+      func_name = node_def.attr["f"].func.name
+      fdef = functions[func_name].definition
+      output_arg_name = fdef.signature.output_arg[output_idx].name
+      output_tensor_name = fdef.ret[output_arg_name]
+      input_index = resource_input_index(
+          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
+          {ndef.name: ndef for ndef in fdef.node_def}, functions)
+      tensor_name = node_def.input[input_index]
+    else:
+      # We assume there are no other op types that will "forward" resource
+      # handles like this, so all other handles must have been created by the
+      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
+      # which we'll end up accumulating.)
+      raise ValueError("Taking gradient of a while loop which creates "
+                       "a resource in its body is not supported: %s" % op_name)
+
+  return input_names.index(tensor_name)
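The name-parsing loop in resource_input_index leans on two tensor-naming conventions. The following standalone helper is a hypothetical illustration only (names invented, not part of the commit) that mirrors the parsing so the conventions are easy to see.

def parse_tensor_name(tensor_name):
  """Splits a tensor name into (op_name, output_index)."""
  parts = tensor_name.split(":")
  if len(parts) == 3:
    # FunctionDef convention: "op_name:output_arg_name:index".
    op_name, _, output_idx = parts
  elif len(parts) == 2:
    # Graph convention: "op_name:index".
    op_name, output_idx = parts
  else:
    # A bare op name refers to its first output.
    op_name, output_idx = parts[0], 0
  return op_name, int(output_idx)

assert parse_tensor_name("while_op:1") == ("while_op", 1)
assert parse_tensor_name("call_op:resource_out:0") == ("call_op", 0)
assert parse_tensor_name("read_var") == ("read_var", 0)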
@@ -810,9 +810,8 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
     """
     assert tensor.dtype == dtypes.resource
 
-    index = self._resource_input_index(
-        tensor.name,
-        [t.name for t in self._forward_graph.inputs],
+    index = util.resource_input_index(
+        tensor.name, [t.name for t in self._forward_graph.inputs],
         {op.name: op.node_def for op in self._forward_graph.get_operations()},
         self._forward_graph._functions)
 
@@ -830,76 +829,6 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
         tensor_in_outer_graph, whitelisted=True)
     return self._indirect_captures[tensor]
 
-  def _resource_input_index(self, tensor_name, input_names, node_defs,
-                            functions):
-    """Returns the index of the input corresponding to `tensor_name`.
-
-    This method is used to find the corresponding index of an arbitrary resource
-    tensor in a function (the function could be a loop body). We assume that
-    resource handles are never created in functions, so that every resource
-    tensor can be traced back to a function input.
-
-    The awkward signature of this method is to make it work with both FuncGraphs
-    and FunctionDefs. This is so we can recurse on function call ops without
-    building the corresponding FuncGraph (note that even if a FuncGraph for a
-    FunctionDef already exists, the input/output/node names may have been
-    changed when the FuncGraph was serialized to the FunctionDef, which makes it
-    unusable with this algorithm).
-
-    Args:
-      tensor_name: the name of the resource tensor to be resolved to an input.
-      input_names: a list of the names of all inputs to the function.
-      node_defs: a dict mapping op name -> NodeDef for every op in the function.
-      functions: a dict mapping function name -> _EagerDefinedFunction.
-
-    Returns:
-      The index into input_names corresponding to `tensor_name`.
-    """
-    while tensor_name not in input_names:
-      # FunctionDefs and graphs use different tensor naming conventions.
-      parts = tensor_name.split(":")
-      if len(parts) == 3:
-        op_name, _, output_idx = parts
-      elif len(parts) == 2:
-        op_name, output_idx = parts
-      else:
-        assert len(parts) == 1
-        op_name = parts[0]
-        output_idx = 0
-      output_idx = int(output_idx)
-      node_def = node_defs[op_name]
-
-      if node_def.op == "While":
-        # Captured resources occur at the same index in the lists of inputs and
-        # outputs of a while op. So we lookup the input of `tensor.op` at the
-        # same index as the index of `tensor` in the `tensor.op.outputs`.
-        tensor_name = node_def.input[output_idx]
-      elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
-        # Functions output any captured resource tensors used by their
-        # gradients. `tensor_name` is one of these outputs from a nested
-        # function call, so recursively find the corresponding input in the
-        # nested FunctionDef.
-        func_name = node_def.attr["f"].func.name
-        fdef = functions[func_name].definition
-        output_arg_name = fdef.signature.output_arg[output_idx].name
-        output_tensor_name = fdef.ret[output_arg_name]
-        input_index = self._resource_input_index(
-            output_tensor_name,
-            [arg.name for arg in fdef.signature.input_arg],
-            {ndef.name: ndef for ndef in fdef.node_def},
-            functions)
-        tensor_name = node_def.input[input_index]
-      else:
-        # We assume there are no other ops types that will "forward" resource
-        # handles like this, so all other handles must have been created by the
-        # op. (Note that cond_v2 wraps resource handle outputs in optionals,
-        # which we'll end up accumulating).
-        raise ValueError(
-            "Taking gradient of a while loop which creates "
-            "a resource in its body is not supported: %s" % op_name)
-
-    return input_names.index(tensor_name)
-
 
 def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
   for (t, shape, input_t) in zip(output_tensors, shape_invariants,