Handle case when delement is None in PopBackGrad.

Fixes https://github.com/tensorflow/tensorflow/issues/37230

PiperOrigin-RevId: 303845628
Change-Id: Ia0159cb2dfbc70112f822f17e88182e414a83494
This commit is contained in:
Saurabh Saxena 2020-03-30 16:18:46 -07:00 committed by TensorFlower Gardener
parent 76ce154bb5
commit cf09044d9e
3 changed files with 23 additions and 2 deletions

View File

@ -410,7 +410,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
const tensorflow::string &op_name) {
static std::array<OpIndexInfo, 460> a = {{
static std::array<OpIndexInfo, 459> a = {{
{"Abs"},
{"AccumulateNV2"},
{"Acos"},
@ -833,7 +833,6 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
{"TensorListGather"},
{"TensorListGetItem"},
{"TensorListLength"},
{"TensorListPopBack", 1, {1}},
{"TensorListPushBack"},
{"TensorListPushBackBatch"},
{"TensorListResize"},

View File

@ -1665,6 +1665,26 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(f(), [b"A", b"B", b"C"])
def testPopBackGrad(self):
# https://github.com/tensorflow/tensorflow/issues/37230
@def_function.function
def g(x):
x_prod = constant_op.constant([1.])
for unused_i in math_ops.range(3):
x_prod = x_prod * x
return x_prod
x = constant_op.constant(1.)
with backprop.GradientTape() as t:
t.watch(x)
with backprop.GradientTape() as tt:
tt.watch(x)
loss = g(x)
jac = tt.gradient(loss, x)
hess = t.gradient(jac, x)
self.assertAllEqual(hess, 6.)
if __name__ == "__main__":
test.main()

View File

@ -186,6 +186,8 @@ def _PopBackGrad(op, dlist, delement):
element_dtype=delement.dtype,
element_shape=gen_list_ops.tensor_list_element_shape(
op.outputs[0], shape_type=dtypes.int32))
if delement is None:
delement = array_ops.zeros_like(op.outputs[1])
return gen_list_ops.tensor_list_push_back(dlist, delement), None