Merge pull request #28826 from astropeak:issue-27492-custom-gradient

PiperOrigin-RevId: 249263479
commit 269efd353a
Author: TensorFlower Gardener
Date:   2019-05-21 09:34:49 -07:00


@@ -210,11 +210,12 @@ def _graph_mode_decorator(f, *args, **kwargs):
     logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                  "no ResourceVariables were used on the forward pass.")
   flat_result = nest.flatten(result)
+  flat_result_len = len(flat_result)
   all_tensors = flat_result + args + variables
 
   def tape_grad_fn(*result_grads):
     """Custom grad fn wrapper."""
-    result_grads = result_grads[:len(flat_result)]
+    result_grads = result_grads[:flat_result_len]
     if variables:
       input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
       if len(variable_grads) != len(variables):
@@ -228,7 +229,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
       # gradients of the inputs of the custom_gradient function with the
       # gradients of the outputs as well.
       input_grads = nest.flatten(input_grads)
-      return ([None] * len(flat_result)) + input_grads + variable_grads
+      return ([None] * flat_result_len) + input_grads + variable_grads
 
   @ops.RegisterGradient(name)
   def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
@@ -250,7 +251,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
   for ot, t in zip(original_tensors, all_tensors):
     copy_handle_data(ot, t)
   return nest.pack_sequence_as(
-      structure=result, flat_sequence=all_tensors[:len(flat_result)])
+      structure=result, flat_sequence=all_tensors[:flat_result_len])
 
 
 def _eager_mode_decorator(f, *args, **kwargs):
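
What the change does: len(flat_result) is computed once into flat_result_len, and that cached value is used both inside the tape_grad_fn closure and in the final pack_sequence_as slice, instead of re-reading len(flat_result) at gradient-construction time.

For context, a minimal sketch (not from this commit; names and values are illustrative) of the graph-mode path these lines sit on: a @tf.custom_gradient function whose grad_fn takes the variables keyword because a ResourceVariable participates in the forward pass. Written against the TF 1.x graph-mode API current as of this commit.

import tensorflow as tf  # TF 1.x graph mode assumed

v = tf.Variable(2.0, use_resource=True)  # ResourceVariable captured by the forward pass

@tf.custom_gradient
def scale(x):
  y = v * x
  def grad_fn(dy, variables=None):
    # One gradient for the input, plus a list with one entry per captured
    # variable; _graph_mode_decorator splices these together in tape_grad_fn.
    return dy * v, [dy * x]
  return y, grad_fn

x = tf.constant(3.0)
y = scale(x)
grads = tf.gradients(y, [x, v])  # dispatched to tape_grad_fn via the gradient registered above
with tf.Session() as sess:
  sess.run(v.initializer)
  print(sess.run(grads))  # [2.0, 3.0]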