tape.batch_jacobian: don't make zeros with the wrong dtype if gradients are disconnected

Fixes #43043.

PiperOrigin-RevId: 347425143
Change-Id: I6d965b49c64319d48b3baffd0821450b1155de62
This commit is contained in:
A. Unique TensorFlower 2020-12-14 11:06:49 -08:00 committed by TensorFlower Gardener
parent 94dfcdb38d
commit 3c81a30606
2 changed files with 3 additions and 29 deletions

View File

@ -1344,13 +1344,9 @@ class GradientTape(object):
parallel_iterations=parallel_iterations)
new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
if output is None:
if not experimental_use_pfor and target_row_size == 0:
# Since we can't actually run the loop function in this case, we don't
# know whether gradients are unconnected or not. We'll return a numeric
# tensor (with zero elements).
output = array_ops.zeros(new_shape, target.dtype)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
output = array_ops.zeros(new_shape)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
return output
else:
output = array_ops.reshape(output,

View File

@ -1961,28 +1961,6 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
f = def_function.function(f)
self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
@parameterized.parameters((True,), (False,))
def test_respects_disconnected_gradients(self, use_pfor):
  """Checks `batch_jacobian` when the target does not depend on the source.

  Two behaviours are verified, each with and without pfor:
    * by default a disconnected Jacobian is returned as `None`;
    * with `unconnected_gradients='zero'` the result is an all-zero tensor
      whose dtype is float64 (the dtype used for `x` and `y` in this test).

  Args:
    use_pfor: Value forwarded as `experimental_use_pfor`.
  """

  @def_function.function
  def f(x):
    # Drop the input on purpose so the output has no dependency on `x`.
    del x
    return constant_op.constant([[1.]], dtype=dtypes.float64)

  # Case 1: no `unconnected_gradients` argument -> expect None.
  with backprop.GradientTape(persistent=True) as tape:
    x = constant_op.constant([[2.]], dtype=dtypes.float64)
    tape.watch(x)
    y = f(x)
  self.assertIsNone(
      tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor))

  # Case 2: zeros requested explicitly -> expect float64 zeros.
  with backprop.GradientTape(persistent=True) as tape:
    x = constant_op.constant([[2.]], dtype=dtypes.float64)
    tape.watch(x)
    y = f(x)
  jac = tape.batch_jacobian(
      y, x, unconnected_gradients='zero', experimental_use_pfor=use_pfor)
  # The dtype check guards against zeros being created with a default dtype.
  self.assertEqual(dtypes.float64, jac.dtype)
  self.assertAllClose([[[0.]]], jac)
class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):