Merge pull request #28826 from astropeak:issue-27492-custom-gradient
PiperOrigin-RevId: 249263479
This commit is contained in:
commit
269efd353a
@ -210,11 +210,12 @@ def _graph_mode_decorator(f, *args, **kwargs):
|
||||
logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
|
||||
"no ResourceVariables were used on the forward pass.")
|
||||
flat_result = nest.flatten(result)
|
||||
flat_result_len = len(flat_result)
|
||||
all_tensors = flat_result + args + variables
|
||||
|
||||
def tape_grad_fn(*result_grads):
|
||||
"""Custom grad fn wrapper."""
|
||||
result_grads = result_grads[:len(flat_result)]
|
||||
result_grads = result_grads[:flat_result_len]
|
||||
if variables:
|
||||
input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
|
||||
if len(variable_grads) != len(variables):
|
||||
@ -228,7 +229,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
|
||||
# gradients of the inputs of the custom_gradient function with the
|
||||
# gradients of the outputs as well.
|
||||
input_grads = nest.flatten(input_grads)
|
||||
return ([None] * len(flat_result)) + input_grads + variable_grads
|
||||
return ([None] * flat_result_len) + input_grads + variable_grads
|
||||
|
||||
@ops.RegisterGradient(name)
|
||||
def internal_grad_fn(unused_op, *result_grads): # pylint: disable=unused-variable
|
||||
@ -250,7 +251,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
|
||||
for ot, t in zip(original_tensors, all_tensors):
|
||||
copy_handle_data(ot, t)
|
||||
return nest.pack_sequence_as(
|
||||
structure=result, flat_sequence=all_tensors[:len(flat_result)])
|
||||
structure=result, flat_sequence=all_tensors[:flat_result_len])
|
||||
|
||||
|
||||
def _eager_mode_decorator(f, *args, **kwargs):
|
||||
|
Loading…
x
Reference in New Issue
Block a user