diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index d77b2696e07..80223306063 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -652,31 +652,44 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
     if captured_tensor is not None:
       return captured_tensor
 
-    # 'tensor' is an uncaptured intermediate in the forward graph. We wrap it in
-    # an optional in the forward graph and capture the optional normally. We
-    # then unwrap the captured optional value in the gradient graph to get the
-    # raw intermediate value.
+    # 'tensor' is an uncaptured intermediate in the forward graph.
+    # If it is not a resource, we wrap it in an optional in the forward graph
+    # and capture the optional normally. We then unwrap the captured optional
+    # value in the gradient graph to get the raw intermediate value.
+    # If it is a resource, we trace the resource up to the input in the forward
+    # graph and capture that.
+
-    if tensor not in self._wrapped_intermediates:
-      # If the gradient has already been computed for this If op, 'tensor' may
-      # already be wrapped.
-      for consumer in tensor.consumers():
-        if (consumer.type == "OptionalFromValue"
-            and consumer.outputs[0] in self._forward_graph.outputs):
-          optional = consumer.outputs[0]
-          break
-      else:
-        # 'tensor' hasn't been wrapped, do it now.
-        with self._forward_graph.as_default():
-          optional = gen_dataset_ops.optional_from_value([tensor])
-        self.if_op_needs_rewrite = True
+    if tensor.dtype == dtypes.resource:
+      # Index of the forward graph input corresponding to the resource tensor.
+      index = util.resource_input_index(
+          tensor.name, [t.name for t in self._forward_graph.inputs],
+          {op.name: op.node_def for op in self._forward_graph.get_operations()},
+          self._forward_graph._functions)
+      # This gets mapped to the corresponding If op input in
+      # `_resolve_grad_inputs`.
+      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
+          self._forward_graph.inputs[index], name)
+    else:
+      if tensor not in self._wrapped_intermediates:
+        # If the gradient has already been computed for this If op, 'tensor' may
+        # already be wrapped.
+        for consumer in tensor.consumers():
+          if (consumer.type == "OptionalFromValue" and
+              consumer.outputs[0] in self._forward_graph.outputs):
+            optional = consumer.outputs[0]
+            break
+        else:
+          # 'tensor' hasn't been wrapped, do it now.
+          with self._forward_graph.as_default():
+            optional = gen_dataset_ops.optional_from_value([tensor])
+          self.if_op_needs_rewrite = True
+        self._wrapped_intermediates[tensor] = optional
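A minimal usage sketch (not part of the patch; assumes TF 2.x, where tf.cond inside tf.function lowers to cond_v2) of the case the cond_v2 change enables: the branch functions read a variable through its resource handle, and the gradient code can now capture that handle by tracing it to a forward-graph input instead of wrapping it in an optional:

import tensorflow as tf

v = tf.Variable(2.0)

@tf.function
def f(x):
  # Both branches read `v` via its resource handle.
  return tf.cond(x > 0., lambda: v * x, lambda: v + x)

with tf.GradientTape() as tape:
  y = f(tf.constant(3.0))

# The true branch is taken, so d(v * x)/dv == x == 3.0.
print(tape.gradient(y, v))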
+          with self._forward_graph.as_default():
+            optional = gen_dataset_ops.optional_from_value([tensor])
+          self.if_op_needs_rewrite = True
+        self._wrapped_intermediates[tensor] = optional
-
-      self._wrapped_intermediates[tensor] = optional
 
+      optional = self._wrapped_intermediates[tensor]
+      captured_optional = super(_CondGradFuncGraph,
+                                self)._capture_helper(optional, name)
+      captured_tensor = gen_dataset_ops.optional_get_value(
+          captured_optional, [tensor.dtype], [tensor.shape])[0]
-    optional = self._wrapped_intermediates[tensor]
-    captured_optional = super(_CondGradFuncGraph, self)._capture_helper(
-        optional, name)
-    captured_tensor = gen_dataset_ops.optional_get_value(
-        captured_optional, [tensor.dtype], [tensor.shape])[0]
 
     self._indirect_captures[tensor] = captured_tensor
     return captured_tensor
diff --git a/tensorflow/python/ops/control_flow_util_v2.py b/tensorflow/python/ops/control_flow_util_v2.py
index cd37419906b..38915e272a8 100644
--- a/tensorflow/python/ops/control_flow_util_v2.py
+++ b/tensorflow/python/ops/control_flow_util_v2.py
@@ -138,3 +138,70 @@ def maybe_propagate_compile_time_consts_in_xla(op):
     op._set_attr("_xla_propagate_compile_time_consts",
                  attr_value_pb2.AttrValue(b=True))
     # pylint: enable=protected-access
+
+
+def resource_input_index(tensor_name, input_names, node_defs, functions):
+  """Returns the index of the input corresponding to `tensor_name`.
+
+  This method is used to find the corresponding index of an arbitrary resource
+  tensor in a function (the function could be a loop body). We assume that
+  resource handles are never created in functions, so that every resource
+  tensor can be traced back to a function input.
+
+  The awkward signature of this method is to make it work with both FuncGraphs
+  and FunctionDefs. This is so we can recurse on function call ops without
+  building the corresponding FuncGraph (note that even if a FuncGraph for a
+  FunctionDef already exists, the input/output/node names may have been
+  changed when the FuncGraph was serialized to the FunctionDef, which makes it
+  unusable with this algorithm).
+
+  Args:
+    tensor_name: the name of the resource tensor to be resolved to an input.
+    input_names: a list of the names of all inputs to the function.
+    node_defs: a dict mapping op name -> NodeDef for every op in the function.
+    functions: a dict mapping function name -> _EagerDefinedFunction.
+
+  Returns:
+    The index into input_names corresponding to `tensor_name`.
+  """
+  while tensor_name not in input_names:
+    # FunctionDefs and graphs use different tensor naming conventions.
+    parts = tensor_name.split(":")
+    if len(parts) == 3:
+      op_name, _, output_idx = parts
+    elif len(parts) == 2:
+      op_name, output_idx = parts
+    else:
+      assert len(parts) == 1
+      op_name = parts[0]
+      output_idx = 0
+    output_idx = int(output_idx)
+    node_def = node_defs[op_name]
+
+    if node_def.op == "While":
+      # Captured resources occur at the same index in the lists of inputs and
+      # outputs of a while op. So we look up the input of `tensor.op` at the
+      # same index as the index of `tensor` in the `tensor.op.outputs`.
+      tensor_name = node_def.input[output_idx]
+    elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
+      # Functions output any captured resource tensors used by their
+      # gradients. `tensor_name` is one of these outputs from a nested
+      # function call, so recursively find the corresponding input in the
+      # nested FunctionDef.
+      func_name = node_def.attr["f"].func.name
+      fdef = functions[func_name].definition
+      output_arg_name = fdef.signature.output_arg[output_idx].name
+      output_tensor_name = fdef.ret[output_arg_name]
+      input_index = resource_input_index(
+          output_tensor_name, [arg.name for arg in fdef.signature.input_arg],
+          {ndef.name: ndef for ndef in fdef.node_def}, functions)
+      tensor_name = node_def.input[input_index]
+    else:
+      # We assume there are no other op types that will "forward" resource
+      # handles like this, so all other handles must have been created by the
+      # op. (Note that cond_v2 wraps resource handle outputs in optionals,
+      # which we'll end up accumulating).
+      raise ValueError("Taking gradient of a while loop which creates "
+                       "a resource in its body is not supported: %s" % op_name)
+
+  return input_names.index(tensor_name)
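To make the walk above concrete, here is an illustrative pure-Python miniature of `resource_input_index` (names like `toy_node_defs` are invented for the example; the real helper operates on NodeDefs and also recurses into nested FunctionDefs). It shows the two tensor-naming conventions and the rule that a While op forwards a captured resource from input i to output i:

def toy_resource_input_index(tensor_name, input_names, toy_node_defs):
  while tensor_name not in input_names:
    parts = tensor_name.split(":")
    if len(parts) == 3:    # FunctionDef convention: "op:output_arg_name:idx"
      op_name, _, output_idx = parts
    elif len(parts) == 2:  # Graph convention: "op:idx"
      op_name, output_idx = parts
    else:                  # A bare op name means output 0.
      op_name, output_idx = parts[0], 0
    op_type, op_inputs = toy_node_defs[op_name]
    if op_type != "While":
      raise ValueError("resource created inside the body: %s" % op_name)
    # A While op passes resources straight through: output i comes from
    # input i, so keep walking backwards from the matching input.
    tensor_name = op_inputs[int(output_idx)]
  return input_names.index(tensor_name)

# The handle enters as function input "v", flows through a While op named
# "loop" (inputs ["counter", "v"]), and emerges as the tensor "loop:1".
assert toy_resource_input_index(
    "loop:1", ["counter", "v"], {"loop": ("While", ["counter", "v"])}) == 1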
+ func_name = node_def.attr["f"].func.name + fdef = functions[func_name].definition + output_arg_name = fdef.signature.output_arg[output_idx].name + output_tensor_name = fdef.ret[output_arg_name] + input_index = resource_input_index( + output_tensor_name, [arg.name for arg in fdef.signature.input_arg], + {ndef.name: ndef for ndef in fdef.node_def}, functions) + tensor_name = node_def.input[input_index] + else: + # We assume there are no other ops types that will "forward" resource + # handles like this, so all other handles must have been created by the + # op. (Note that cond_v2 wraps resource handle outputs in optionals, + # which we'll end up accumulating). + raise ValueError("Taking gradient of a while loop which creates " + "a resource in its body is not supported: %s" % op_name) + + return input_names.index(tensor_name) diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 122f275b98e..f6c70b406e0 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -810,9 +810,8 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): """ assert tensor.dtype == dtypes.resource - index = self._resource_input_index( - tensor.name, - [t.name for t in self._forward_graph.inputs], + index = util.resource_input_index( + tensor.name, [t.name for t in self._forward_graph.inputs], {op.name: op.node_def for op in self._forward_graph.get_operations()}, self._forward_graph._functions) @@ -830,76 +829,6 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): tensor_in_outer_graph, whitelisted=True) return self._indirect_captures[tensor] - def _resource_input_index(self, tensor_name, input_names, node_defs, - functions): - """Returns the index of the input corresponding to `tensor_name`. - - This method is used to find the corresponding index of an arbitrary resource - tensor in a function (the function could be a loop body). We assume that - resource handles are never created in functions, so that every resource - tensor can be traced back to a function input. - - The awkward signature of this method is to make it work with both FuncGraphs - and FunctionDefs. This is so we can recurse on function call ops without - building the corresponding FuncGraph (note that even if a FuncGraph for a - FunctionDef already exists, the input/output/node names may have been - changed when the FuncGraph was serialized to the FunctionDef, which makes it - unusable with this algorithm). - - Args: - tensor_name: the name of the resource tensor to be resolved to an input. - input_names: a list of the names of all inputs to the function. - node_defs: a dict mapping op name -> NodeDef for every op in the function. - functions: a dict mapping function name -> _EagerDefinedFunction. - - Returns: - The index into input_names corresponding to `tensor_name`. - """ - while tensor_name not in input_names: - # FunctionDefs and graphs use different tensor naming conventions. - parts = tensor_name.split(":") - if len(parts) == 3: - op_name, _, output_idx = parts - elif len(parts) == 2: - op_name, output_idx = parts - else: - assert len(parts) == 1 - op_name = parts[0] - output_idx = 0 - output_idx = int(output_idx) - node_def = node_defs[op_name] - - if node_def.op == "While": - # Captured resources occur at the same index in the lists of inputs and - # outputs of a while op. So we lookup the input of `tensor.op` at the - # same index as the index of `tensor` in the `tensor.op.outputs`. 
-        tensor_name = node_def.input[output_idx]
-      elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
-        # Functions output any captured resource tensors used by their
-        # gradients. `tensor_name` is one of these outputs from a nested
-        # function call, so recursively find the corresponding input in the
-        # nested FunctionDef.
-        func_name = node_def.attr["f"].func.name
-        fdef = functions[func_name].definition
-        output_arg_name = fdef.signature.output_arg[output_idx].name
-        output_tensor_name = fdef.ret[output_arg_name]
-        input_index = self._resource_input_index(
-            output_tensor_name,
-            [arg.name for arg in fdef.signature.input_arg],
-            {ndef.name: ndef for ndef in fdef.node_def},
-            functions)
-        tensor_name = node_def.input[input_index]
-      else:
-        # We assume there are no other ops types that will "forward" resource
-        # handles like this, so all other handles must have been created by the
-        # op. (Note that cond_v2 wraps resource handle outputs in optionals,
-        # which we'll end up accumulating).
-        raise ValueError(
-            "Taking gradient of a while loop which creates "
-            "a resource in its body is not supported: %s" % op_name)
-
-    return input_names.index(tensor_name)
-
 
 def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
   for (t, shape, input_t) in zip(output_tensors, shape_invariants,
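For the while_v2 side, a minimal sketch (again not from the patch; assumes TF 2.x, where tf.while_loop inside tf.function uses while_v2) of the situation `resource_input_index` resolves: the loop body reads a variable, so building the gradient requires tracing the captured resource handle back to a loop input:

import tensorflow as tf

v = tf.Variable(1.5)

@tf.function
def g(x):
  _, out = tf.while_loop(
      cond=lambda i, acc: i < 3,
      body=lambda i, acc: (i + 1, acc * v),  # body reads `v`'s handle
      loop_vars=(tf.constant(0), x))
  return out

with tf.GradientTape() as tape:
  y = g(tf.constant(2.0))

# out == x * v**3, so d(out)/dv == 3 * x * v**2 == 13.5.
print(tape.gradient(y, v))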