From 6be615c36ec88413556a304af697a89999e6f698 Mon Sep 17 00:00:00 2001
From: Astropeak
Date: Sat, 18 May 2019 21:15:11 +0800
Subject: [PATCH] Break the reference between closure tape_grad_fn and
 flat_result

---
 tensorflow/python/ops/custom_gradient.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index 0ef72e1e927..3ea2cad12d2 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -210,11 +210,12 @@ def _graph_mode_decorator(f, *args, **kwargs):
     logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                  "no ResourceVariables were used on the forward pass.")
   flat_result = nest.flatten(result)
+  flat_result_len = len(flat_result)
   all_tensors = flat_result + args + variables
 
   def tape_grad_fn(*result_grads):
     """Custom grad fn wrapper."""
-    result_grads = result_grads[:len(flat_result)]
+    result_grads = result_grads[:flat_result_len]
     if variables:
       input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
       if len(variable_grads) != len(variables):
@@ -228,7 +229,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
     # gradients of the inputs of the custom_gradient function with the
     # gradients of the outputs as well.
     input_grads = nest.flatten(input_grads)
-    return ([None] * len(flat_result)) + input_grads + variable_grads
+    return ([None] * flat_result_len) + input_grads + variable_grads
 
   @ops.RegisterGradient(name)
   def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
@@ -250,7 +251,7 @@ def _graph_mode_decorator(f, *args, **kwargs):
   for ot, t in zip(original_tensors, all_tensors):
     copy_handle_data(ot, t)
   return nest.pack_sequence_as(
-      structure=result, flat_sequence=all_tensors[:len(flat_result)])
+      structure=result, flat_sequence=all_tensors[:flat_result_len])
 
 
 def _eager_mode_decorator(f, *args, **kwargs):
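
Rationale: tape_grad_fn previously closed over flat_result, so the closure kept
the flattened forward-pass results reachable for as long as the gradient
function itself stayed alive. Hoisting len(flat_result) into flat_result_len
before the closure is defined means the closure captures only an integer,
breaking that reference (presumably so the result tensors can be collected
earlier). A minimal sketch of the pattern in plain Python, not TensorFlow code;
the function and variable names below are illustrative only:

    def capture_list(big_list):
        # Closes over the list itself: the closure's cell keeps every
        # element alive for the closure's whole lifetime.
        def fn():
            return len(big_list)
        return fn

    def capture_length(big_list):
        # Capture only the derived scalar; the list can be freed as soon
        # as the caller drops its own reference to it.
        big_list_len = len(big_list)
        def fn():
            return big_list_len
        return fn

    big = [object() for _ in range(3)]
    f1, f2 = capture_list(big), capture_length(big)
    print([c.cell_contents for c in f1.__closure__])  # the list itself
    print([c.cell_contents for c in f2.__closure__])  # [3]

The same trick applies wherever a long-lived callback needs only a scalar
derived from a large object: compute the scalar first and let the closure
capture that instead of the object.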