Do not wrap Resource tensors inside optionals in cond_v2.
This makes cond_v2 consistent with while_v2 and avoids issues with copying resources wrapped in optionals across devices.

PiperOrigin-RevId: 241830574
parent cc686693f3
commit 34fc260230
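For context, here is a minimal sketch (not part of this change) of the kind of program that reaches the code paths edited below: a tf.cond whose branches read a resource variable, differentiated inside tf.function. The function name grad_of_cond and the TF 2.x-style API usage are assumptions for illustration only.

import tensorflow as tf

v = tf.Variable(2.0)  # backed by a resource handle

@tf.function
def grad_of_cond(x):
  with tf.GradientTape() as tape:
    # Both branches read the resource variable `v`, so the gradient graph of
    # the If op has to capture the resource handle from the forward graph.
    y = tf.cond(x > 0.0, lambda: v * x, lambda: v + x)
  return tape.gradient(y, v)

print(grad_of_cond(tf.constant(3.0)))  # tf.Tensor(3.0, shape=(), dtype=float32)

Before this change, the gradient graph captured such a resource by wrapping it in an optional in the forward graph; after it, the handle is traced back to the corresponding If op input, matching what while_v2 already does.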
@@ -652,31 +652,44 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
     if captured_tensor is not None:
       return captured_tensor
 
-    # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
-    # an optional in the forward graph and capture the optional normally. We
-    # then unwrap the captured optional value in the gradient graph to get the
-    # raw intermediate value.
-    if tensor not in self._wrapped_intermediates:
-      # If the gradient has already been computed for this If op, 'tensor' may
-      # already be wrapped.
-      for consumer in tensor.consumers():
-        if (consumer.type == "OptionalFromValue"
-            and consumer.outputs[0] in self._forward_graph.outputs):
-          optional = consumer.outputs[0]
-          break
-      else:
-        # 'tensor' hasn't been wrapped, do it now.
-        with self._forward_graph.as_default():
-          optional = gen_dataset_ops.optional_from_value([tensor])
-          self.if_op_needs_rewrite = True
-      self._wrapped_intermediates[tensor] = optional
-
-    optional = self._wrapped_intermediates[tensor]
-    captured_optional = super(_CondGradFuncGraph,
-                              self)._capture_helper(optional, name)
-    captured_tensor = gen_dataset_ops.optional_get_value(
-        captured_optional, [tensor.dtype], [tensor.shape])[0]
+    # 'tensor' is an uncaptured intermediate in the forward graph.
+    # If it is not a resource, we wrap it in an optional in the forward graph
+    # and capture the optional normally. We then unwrap the captured optional
+    # value in the gradient graph to get the raw intermediate value.
+    # If it is a resource, we trace the resource upto the input in the forward
+    # graph and capture that.
+
+    if tensor.dtype == dtypes.resource:
+      # Index of the forward graph input corresponding to the resource tensor.
+      index = util.resource_input_index(
+          tensor.name, [t.name for t in self._forward_graph.inputs],
+          {op.name: op.node_def for op in self._forward_graph.get_operations()},
+          self._forward_graph._functions)
+      # This gets mapped to the corresponding If op input in
+      # `_resolve_grad_inputs`.
+      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
+          self._forward_graph.inputs[index], name)
+    else:
+      if tensor not in self._wrapped_intermediates:
+        # If the gradient has already been computed for this If op, 'tensor' may
+        # already be wrapped.
+        for consumer in tensor.consumers():
+          if (consumer.type == "OptionalFromValue" and
+              consumer.outputs[0] in self._forward_graph.outputs):
+            optional = consumer.outputs[0]
+            break
+        else:
+          # 'tensor' hasn't been wrapped, do it now.
+          with self._forward_graph.as_default():
+            optional = gen_dataset_ops.optional_from_value([tensor])
+            self.if_op_needs_rewrite = True
+        self._wrapped_intermediates[tensor] = optional
+
+      optional = self._wrapped_intermediates[tensor]
+      captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
+          optional, name)
+      captured_tensor = gen_dataset_ops.optional_get_value(
+          captured_optional, [tensor.dtype], [tensor.shape])[0]
+
     self._indirect_captures[tensor] = captured_tensor
     return captured_tensor
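The comments in the else branch above describe the optional wrap/unwrap that still applies to non-resource intermediates. As a standalone illustration (assumed, not from the diff), the same pattern can be shown with the public raw ops, which appear to mirror gen_dataset_ops.optional_from_value / optional_get_value:

import tensorflow as tf

t = tf.constant([1.0, 2.0])
# Wrap the tensor in an optional (what the forward graph does for an
# uncaptured intermediate), then unwrap it to recover the raw value (what the
# gradient graph does after capturing the optional).
opt = tf.raw_ops.OptionalFromValue(components=[t])
unwrapped = tf.raw_ops.OptionalGetValue(
    optional=opt,
    output_types=[tf.float32],
    output_shapes=[t.shape])[0]
print(unwrapped)  # tf.Tensor([1. 2.], shape=(2,), dtype=float32)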
@@ -138,3 +138,70 @@ def maybe_propagate_compile_time_consts_in_xla(op):
     op._set_attr("_xla_propagate_compile_time_consts",
                  attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access
+
+
+def resource_input_index(tensor_name, input_names, node_defs, functions):
+  """Returns the index of the input corresponding to `tensor_name`.
+
+  This method is used to find the corresponding index of an arbitrary resource
+  tensor in a function (the function could be a loop body). We assume that
+  resource handles are never created in functions, so that every resource
+  tensor can be traced back to a function input.
+
+  The awkward signature of this method is to make it work with both FuncGraphs
+  and FunctionDefs. This is so we can recurse on function call ops without
+  building the corresponding FuncGraph (note that even if a FuncGraph for a
+  FunctionDef already exists, the input/output/node names may have been
+  changed when the FuncGraph was serialized to the FunctionDef, which makes it
+  unusable with this algorithm).
+
+  Args:
+    tensor_name: the name of the resource tensor to be resolved to an input.
+    input_names: a list of the names of all inputs to the function.
+    node_defs: a dict mapping op name -> NodeDef for every op in the function.
+    functions: a dict mapping function name -> _EagerDefinedFunction.
+
+  Returns:
+    The index into input_names corresponding to `tensor_name`.
+  """
+  while tensor_name not in input_names:
+    # FunctionDefs and graphs use different tensor naming conventions.
+    parts = tensor_name.split(":")
+    if len(parts) == 3:
+      op_name, _, output_idx = parts
+    elif len(parts) == 2:
+      op_name, output_idx = parts
+    else:
+      assert len(parts) == 1
+      op_name = parts[0]
+      output_idx = 0
+    output_idx = int(output_idx)
+    node_def = node_defs[op_name]
+
+    if node_def.op == "While":
+      # Captured resources occur at the same index in the lists of inputs and
+      # outputs of a while op. So we lookup the input of `tensor.op` at the
+      # same index as the index of `tensor` in the `tensor.op.outputs`.
+      tensor_name = node_def.input[output_idx]
+    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
+      # Functions output any captured resource tensors used by their
+      # gradients. `tensor_name` is one of these outputs from a nested
+      # function call, so recursively find the corresponding input in the
+      # nested FunctionDef.
+      func_name = node_def.attr["f"].func.name
+      fdef = functions[func_name].definition
+      output_arg_name = fdef.signature.output_arg[output_idx].name
+      output_tensor_name = fdef.ret[output_arg_name]
+      input_index = resource_input_index(
+          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
+          {ndef.name: ndef for ndef in fdef.node_def}, functions)
+      tensor_name = node_def.input[input_index]
+    else:
+      # We assume there are no other ops types that will "forward" resource
+      # handles like this, so all other handles must have been created by the
+      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
+      # which we'll end up accumulating).
+      raise ValueError("Taking gradient of a while loop which creates "
+                       "a resource in its body is not supported: %s" % op_name)
+
+  return input_names.index(tensor_name)
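The subtle part of resource_input_index is the tensor-name handling: graph tensors are named "op_name:output_idx", FunctionDef tensors are named "op_name:output_arg_name:output_idx", and a bare op name refers to output 0. Below is a hedged, standalone sketch of just that parsing step; the helper name split_tensor_name and the example tensor names are hypothetical, not from the change.

def split_tensor_name(tensor_name):
  """Hypothetical helper mirroring the name parsing in the loop above."""
  parts = tensor_name.split(":")
  if len(parts) == 3:
    # FunctionDef convention: "op_name:output_arg_name:output_idx".
    op_name, _, output_idx = parts
  elif len(parts) == 2:
    # Graph convention: "op_name:output_idx".
    op_name, output_idx = parts
  else:
    # A bare op name refers to its first output.
    op_name, output_idx = parts[0], 0
  return op_name, int(output_idx)

assert split_tensor_name("while/Identity:0") == ("while/Identity", 0)
assert split_tensor_name("body_fn/mul:z:0") == ("body_fn/mul", 0)
assert split_tensor_name("x") == ("x", 0)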
@@ -810,9 +810,8 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
     """
     assert tensor.dtype == dtypes.resource
 
-    index = self._resource_input_index(
-        tensor.name,
-        [t.name for t in self._forward_graph.inputs],
+    index = util.resource_input_index(
+        tensor.name, [t.name for t in self._forward_graph.inputs],
         {op.name: op.node_def for op in self._forward_graph.get_operations()},
         self._forward_graph._functions)
@@ -830,76 +829,6 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
           tensor_in_outer_graph, whitelisted=True)
     return self._indirect_captures[tensor]
 
-  def _resource_input_index(self, tensor_name, input_names, node_defs,
-                            functions):
-    """Returns the index of the input corresponding to `tensor_name`.
-
-    This method is used to find the corresponding index of an arbitrary resource
-    tensor in a function (the function could be a loop body). We assume that
-    resource handles are never created in functions, so that every resource
-    tensor can be traced back to a function input.
-
-    The awkward signature of this method is to make it work with both FuncGraphs
-    and FunctionDefs. This is so we can recurse on function call ops without
-    building the corresponding FuncGraph (note that even if a FuncGraph for a
-    FunctionDef already exists, the input/output/node names may have been
-    changed when the FuncGraph was serialized to the FunctionDef, which makes it
-    unusable with this algorithm).
-
-    Args:
-      tensor_name: the name of the resource tensor to be resolved to an input.
-      input_names: a list of the names of all inputs to the function.
-      node_defs: a dict mapping op name -> NodeDef for every op in the function.
-      functions: a dict mapping function name -> _EagerDefinedFunction.
-
-    Returns:
-      The index into input_names corresponding to `tensor_name`.
-    """
-    while tensor_name not in input_names:
-      # FunctionDefs and graphs use different tensor naming conventions.
-      parts = tensor_name.split(":")
-      if len(parts) == 3:
-        op_name, _, output_idx = parts
-      elif len(parts) == 2:
-        op_name, output_idx = parts
-      else:
-        assert len(parts) == 1
-        op_name = parts[0]
-        output_idx = 0
-      output_idx = int(output_idx)
-      node_def = node_defs[op_name]
-
-      if node_def.op == "While":
-        # Captured resources occur at the same index in the lists of inputs and
-        # outputs of a while op. So we lookup the input of `tensor.op` at the
-        # same index as the index of `tensor` in the `tensor.op.outputs`.
-        tensor_name = node_def.input[output_idx]
-      elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
-        # Functions output any captured resource tensors used by their
-        # gradients. `tensor_name` is one of these outputs from a nested
-        # function call, so recursively find the corresponding input in the
-        # nested FunctionDef.
-        func_name = node_def.attr["f"].func.name
-        fdef = functions[func_name].definition
-        output_arg_name = fdef.signature.output_arg[output_idx].name
-        output_tensor_name = fdef.ret[output_arg_name]
-        input_index = self._resource_input_index(
-            output_tensor_name,
-            [arg.name for arg in fdef.signature.input_arg],
-            {ndef.name: ndef for ndef in fdef.node_def},
-            functions)
-        tensor_name = node_def.input[input_index]
-      else:
-        # We assume there are no other ops types that will "forward" resource
-        # handles like this, so all other handles must have been created by the
-        # op. (Note that cond_v2 wraps resource handle outputs in optionals,
-        # which we'll end up accumulating).
-        raise ValueError(
-            "Taking gradient of a while loop which creates "
-            "a resource in its body is not supported: %s" % op_name)
-
-    return input_names.index(tensor_name)
-
 
 def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
   for (t, shape, input_t) in zip(output_tensors, shape_invariants,