Handle case when delement
is None in PopBackGrad.
Fixes https://github.com/tensorflow/tensorflow/issues/37230 PiperOrigin-RevId: 303845628 Change-Id: Ia0159cb2dfbc70112f822f17e88182e414a83494
This commit is contained in:
parent
76ce154bb5
commit
cf09044d9e
@ -410,7 +410,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
|
||||
|
||||
absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
||||
const tensorflow::string &op_name) {
|
||||
static std::array<OpIndexInfo, 460> a = {{
|
||||
static std::array<OpIndexInfo, 459> a = {{
|
||||
{"Abs"},
|
||||
{"AccumulateNV2"},
|
||||
{"Acos"},
|
||||
@ -833,7 +833,6 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
|
||||
{"TensorListGather"},
|
||||
{"TensorListGetItem"},
|
||||
{"TensorListLength"},
|
||||
{"TensorListPopBack", 1, {1}},
|
||||
{"TensorListPushBack"},
|
||||
{"TensorListPushBackBatch"},
|
||||
{"TensorListResize"},
|
||||
|
@ -1665,6 +1665,26 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
self.assertAllEqual(f(), [b"A", b"B", b"C"])
|
||||
|
||||
def testPopBackGrad(self):
|
||||
# https://github.com/tensorflow/tensorflow/issues/37230
|
||||
|
||||
@def_function.function
|
||||
def g(x):
|
||||
x_prod = constant_op.constant([1.])
|
||||
for unused_i in math_ops.range(3):
|
||||
x_prod = x_prod * x
|
||||
return x_prod
|
||||
|
||||
x = constant_op.constant(1.)
|
||||
with backprop.GradientTape() as t:
|
||||
t.watch(x)
|
||||
with backprop.GradientTape() as tt:
|
||||
tt.watch(x)
|
||||
loss = g(x)
|
||||
jac = tt.gradient(loss, x)
|
||||
hess = t.gradient(jac, x)
|
||||
self.assertAllEqual(hess, 6.)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -186,6 +186,8 @@ def _PopBackGrad(op, dlist, delement):
|
||||
element_dtype=delement.dtype,
|
||||
element_shape=gen_list_ops.tensor_list_element_shape(
|
||||
op.outputs[0], shape_type=dtypes.int32))
|
||||
if delement is None:
|
||||
delement = array_ops.zeros_like(op.outputs[1])
|
||||
return gen_list_ops.tensor_list_push_back(dlist, delement), None
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user