diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc index 278f3640856..882c8097a0f 100644 --- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc +++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc @@ -410,7 +410,7 @@ absl::optional> OpGradientUnusedInputIndices( absl::optional> OpGradientUnusedOutputIndices( const tensorflow::string &op_name) { - static std::array a = {{ + static std::array a = {{ {"Abs"}, {"AccumulateNV2"}, {"Acos"}, @@ -833,7 +833,6 @@ absl::optional> OpGradientUnusedOutputIndices( {"TensorListGather"}, {"TensorListGetItem"}, {"TensorListLength"}, - {"TensorListPopBack", 1, {1}}, {"TensorListPushBack"}, {"TensorListPushBackBatch"}, {"TensorListResize"}, diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index e618e21ed9d..53ebdd3ab88 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -1665,6 +1665,26 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): self.assertAllEqual(f(), [b"A", b"B", b"C"]) + def testPopBackGrad(self): + # https://github.com/tensorflow/tensorflow/issues/37230 + + @def_function.function + def g(x): + x_prod = constant_op.constant([1.]) + for unused_i in math_ops.range(3): + x_prod = x_prod * x + return x_prod + + x = constant_op.constant(1.) + with backprop.GradientTape() as t: + t.watch(x) + with backprop.GradientTape() as tt: + tt.watch(x) + loss = g(x) + jac = tt.gradient(loss, x) + hess = t.gradient(jac, x) + self.assertAllEqual(hess, 6.) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py index ee01ff7cf68..3e7c116ec97 100644 --- a/tensorflow/python/ops/list_ops.py +++ b/tensorflow/python/ops/list_ops.py @@ -186,6 +186,8 @@ def _PopBackGrad(op, dlist, delement): element_dtype=delement.dtype, element_shape=gen_list_ops.tensor_list_element_shape( op.outputs[0], shape_type=dtypes.int32)) + if delement is None: + delement = array_ops.zeros_like(op.outputs[1]) return gen_list_ops.tensor_list_push_back(dlist, delement), None