Do not wrap Resource tensors inside optionals in cond_v2.

This makes cond_v2 consistent with while_v2 and avoids subtle issues with copying resource tensors wrapped in optionals across devices.

PiperOrigin-RevId: 241830574
Author: Saurabh Saxena, 2019-04-03 16:35:43 -07:00
Committed by: TensorFlower Gardener
Parent: cc686693f3
Commit: 34fc260230
3 changed files with 105 additions and 96 deletions
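
To make the behavior change concrete, here is a minimal sketch (not part of the commit) of the kind of program it affects: differentiating a cond whose branches read a resource variable. With this change, cond_v2's gradient graph captures the variable's resource handle by tracing it back to the If op's input instead of capturing an optional-wrapped copy of the handle. The snippet uses present-day TF 2.x APIs purely for illustration.

import tensorflow as tf

v = tf.Variable(2.0)

@tf.function
def f(x):
  # Both branches read the resource variable `v`; its handle is a resource
  # tensor that the gradient of the cond must capture from the forward graph.
  return tf.cond(x > 0.0, lambda: v * x, lambda: v - x)

with tf.GradientTape() as tape:
  y = f(tf.constant(3.0))
print(tape.gradient(y, v))  # d(v * x)/dv == x == 3.0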

tensorflow/python/ops/cond_v2.py

@@ -652,17 +652,30 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
     if captured_tensor is not None:
       return captured_tensor
 
-    # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
-    # an optional in the forward graph and capture the optional normally. We
-    # then unwrap the captured optional value in the gradient graph to get the
-    # raw intermediate value.
-    if tensor not in self._wrapped_intermediates:
-      # If the gradient has already been computed for this If op, 'tensor' may
-      # already be wrapped.
-      for consumer in tensor.consumers():
-        if (consumer.type == "OptionalFromValue"
-            and consumer.outputs[0] in self._forward_graph.outputs):
-          optional = consumer.outputs[0]
-          break
-      else:
+    # 'tensor' is an uncaptured intermediate in the forward graph.
+    # If it is not a resource, we wrap it in an optional in the forward graph
+    # and capture the optional normally. We then unwrap the captured optional
+    # value in the gradient graph to get the raw intermediate value.
+    # If it is a resource, we trace the resource upto the input in the forward
+    # graph and capture that.
+    if tensor.dtype == dtypes.resource:
+      # Index of the forward graph input corresponding to the resource tensor.
+      index = util.resource_input_index(
+          tensor.name, [t.name for t in self._forward_graph.inputs],
+          {op.name: op.node_def for op in self._forward_graph.get_operations()},
+          self._forward_graph._functions)
+      # This gets mapped to the corresponding If op input in
+      # `_resolve_grad_inputs`.
+      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
+          self._forward_graph.inputs[index], name)
+    else:
+      if tensor not in self._wrapped_intermediates:
+        # If the gradient has already been computed for this If op, 'tensor' may
+        # already be wrapped.
+        for consumer in tensor.consumers():
+          if (consumer.type == "OptionalFromValue" and
+              consumer.outputs[0] in self._forward_graph.outputs):
+            optional = consumer.outputs[0]
+            break
+        else:
@@ -670,13 +683,13 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
-        with self._forward_graph.as_default():
-          optional = gen_dataset_ops.optional_from_value([tensor])
-        self.if_op_needs_rewrite = True
-      self._wrapped_intermediates[tensor] = optional
-
-      optional = self._wrapped_intermediates[tensor]
-      captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
-          optional, name)
-      captured_tensor = gen_dataset_ops.optional_get_value(
-          captured_optional, [tensor.dtype], [tensor.shape])[0]
+          with self._forward_graph.as_default():
+            optional = gen_dataset_ops.optional_from_value([tensor])
+          self.if_op_needs_rewrite = True
+        self._wrapped_intermediates[tensor] = optional
+
+        optional = self._wrapped_intermediates[tensor]
+        captured_optional = super(_CondGradFuncGraph,
+                                  self)._capture_helper(optional, name)
+        captured_tensor = gen_dataset_ops.optional_get_value(
+            captured_optional, [tensor.dtype], [tensor.shape])[0]
+
     self._indirect_captures[tensor] = captured_tensor
     return captured_tensor

tensorflow/python/ops/control_flow_util_v2.py

@@ -138,3 +138,70 @@ def maybe_propagate_compile_time_consts_in_xla(op):
     op._set_attr("_xla_propagate_compile_time_consts",
                  attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access
+
+
+def resource_input_index(tensor_name, input_names, node_defs, functions):
+  """Returns the index of the input corresponding to `tensor_name`.
+
+  This method is used to find the corresponding index of an arbitrary resource
+  tensor in a function (the function could be a loop body). We assume that
+  resource handles are never created in functions, so that every resource
+  tensor can be traced back to a function input.
+
+  The awkward signature of this method is to make it work with both FuncGraphs
+  and FunctionDefs. This is so we can recurse on function call ops without
+  building the corresponding FuncGraph (note that even if a FuncGraph for a
+  FunctionDef already exists, the input/output/node names may have been
+  changed when the FuncGraph was serialized to the FunctionDef, which makes it
+  unusable with this algorithm).
+
+  Args:
+    tensor_name: the name of the resource tensor to be resolved to an input.
+    input_names: a list of the names of all inputs to the function.
+    node_defs: a dict mapping op name -> NodeDef for every op in the function.
+    functions: a dict mapping function name -> _EagerDefinedFunction.
+
+  Returns:
+    The index into input_names corresponding to `tensor_name`.
+  """
+  while tensor_name not in input_names:
+    # FunctionDefs and graphs use different tensor naming conventions.
+    parts = tensor_name.split(":")
+    if len(parts) == 3:
+      op_name, _, output_idx = parts
+    elif len(parts) == 2:
+      op_name, output_idx = parts
+    else:
+      assert len(parts) == 1
+      op_name = parts[0]
+      output_idx = 0
+    output_idx = int(output_idx)
+
+    node_def = node_defs[op_name]
+    if node_def.op == "While":
+      # Captured resources occur at the same index in the lists of inputs and
+      # outputs of a while op. So we lookup the input of `tensor.op` at the
+      # same index as the index of `tensor` in the `tensor.op.outputs`.
+      tensor_name = node_def.input[output_idx]
+    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
+      # Functions output any captured resource tensors used by their
+      # gradients. `tensor_name` is one of these outputs from a nested
+      # function call, so recursively find the corresponding input in the
+      # nested FunctionDef.
+      func_name = node_def.attr["f"].func.name
+      fdef = functions[func_name].definition
+      output_arg_name = fdef.signature.output_arg[output_idx].name
+      output_tensor_name = fdef.ret[output_arg_name]
+      input_index = resource_input_index(
+          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
+          {ndef.name: ndef for ndef in fdef.node_def}, functions)
+      tensor_name = node_def.input[input_index]
+    else:
+      # We assume there are no other ops types that will "forward" resource
+      # handles like this, so all other handles must have been created by the
+      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
+      # which we'll end up accumulating).
+      raise ValueError("Taking gradient of a while loop which creates "
+                       "a resource in its body is not supported: %s" % op_name)
+
+  return input_names.index(tensor_name)
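
The core of the new helper is the name-resolution loop: walk a resource tensor's name back through ops that merely forward the handle (While ops, function calls) until a function input is reached. The toy, self-contained sketch below illustrates that loop with hypothetical names; `forwarded_by` is a stand-in for the NodeDef/FunctionDef inspection the real code performs and is not part of the TensorFlow API.

def trace_to_input(tensor_name, input_names, forwarded_by):
  # Follow "forwarding" ops until we hit a function input.
  while tensor_name not in input_names:
    op_name = tensor_name.split(":")[0]
    if op_name not in forwarded_by:
      # The op created the handle itself; the real code raises here too.
      raise ValueError("resource created inside the function: %s" % op_name)
    tensor_name = forwarded_by[op_name]
  return input_names.index(tensor_name)

# In this toy example, the op named "while" forwards the function input
# "var_handle:0", so "while:1" resolves to input index 1.
print(trace_to_input("while:1", ["x:0", "var_handle:0"],
                     {"while": "var_handle:0"}))  # -> 1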

tensorflow/python/ops/while_v2.py

@@ -810,9 +810,8 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
     """
     assert tensor.dtype == dtypes.resource
-    index = self._resource_input_index(
-        tensor.name,
-        [t.name for t in self._forward_graph.inputs],
+    index = util.resource_input_index(
+        tensor.name, [t.name for t in self._forward_graph.inputs],
         {op.name: op.node_def for op in self._forward_graph.get_operations()},
         self._forward_graph._functions)
@@ -830,76 +829,6 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
         tensor_in_outer_graph, whitelisted=True)
     return self._indirect_captures[tensor]
 
-  def _resource_input_index(self, tensor_name, input_names, node_defs,
-                            functions):
-    """Returns the index of the input corresponding to `tensor_name`.
-
-    This method is used to find the corresponding index of an arbitrary resource
-    tensor in a function (the function could be a loop body). We assume that
-    resource handles are never created in functions, so that every resource
-    tensor can be traced back to a function input.
-
-    The awkward signature of this method is to make it work with both FuncGraphs
-    and FunctionDefs. This is so we can recurse on function call ops without
-    building the corresponding FuncGraph (note that even if a FuncGraph for a
-    FunctionDef already exists, the input/output/node names may have been
-    changed when the FuncGraph was serialized to the FunctionDef, which makes it
-    unusable with this algorithm).
-
-    Args:
-      tensor_name: the name of the resource tensor to be resolved to an input.
-      input_names: a list of the names of all inputs to the function.
-      node_defs: a dict mapping op name -> NodeDef for every op in the function.
-      functions: a dict mapping function name -> _EagerDefinedFunction.
-
-    Returns:
-      The index into input_names corresponding to `tensor_name`.
-    """
-    while tensor_name not in input_names:
-      # FunctionDefs and graphs use different tensor naming conventions.
-      parts = tensor_name.split(":")
-      if len(parts) == 3:
-        op_name, _, output_idx = parts
-      elif len(parts) == 2:
-        op_name, output_idx = parts
-      else:
-        assert len(parts) == 1
-        op_name = parts[0]
-        output_idx = 0
-      output_idx = int(output_idx)
-
-      node_def = node_defs[op_name]
-      if node_def.op == "While":
-        # Captured resources occur at the same index in the lists of inputs and
-        # outputs of a while op. So we lookup the input of `tensor.op` at the
-        # same index as the index of `tensor` in the `tensor.op.outputs`.
-        tensor_name = node_def.input[output_idx]
-      elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
-        # Functions output any captured resource tensors used by their
-        # gradients. `tensor_name` is one of these outputs from a nested
-        # function call, so recursively find the corresponding input in the
-        # nested FunctionDef.
-        func_name = node_def.attr["f"].func.name
-        fdef = functions[func_name].definition
-        output_arg_name = fdef.signature.output_arg[output_idx].name
-        output_tensor_name = fdef.ret[output_arg_name]
-        input_index = self._resource_input_index(
-            output_tensor_name,
-            [arg.name for arg in fdef.signature.input_arg],
-            {ndef.name: ndef for ndef in fdef.node_def},
-            functions)
-        tensor_name = node_def.input[input_index]
-      else:
-        # We assume there are no other ops types that will "forward" resource
-        # handles like this, so all other handles must have been created by the
-        # op. (Note that cond_v2 wraps resource handle outputs in optionals,
-        # which we'll end up accumulating).
-        raise ValueError(
-            "Taking gradient of a while loop which creates "
-            "a resource in its body is not supported: %s" % op_name)
-
-    return input_names.index(tensor_name)
-
 
 def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
   for (t, shape, input_t) in zip(output_tensors, shape_invariants,