tape.batch_jacobian: don't make zeros with the wrong dtype if gradients are disconnected

It'd be nice to retrun `None` if gradients are disconnected, but users are relying on the default returning zeros here, so I'm not planning to change that.

Fixes #43043.

PiperOrigin-RevId: 347524869
Change-Id: I44c5d807fff3bad9e117cd868319ead9efd9d6f2
This commit is contained in:
Allen Lavoie 2020-12-14 20:05:57 -08:00 committed by TensorFlower Gardener
parent ba28d6de31
commit 9a02380d42
2 changed files with 29 additions and 1 deletions

View File

@ -1344,7 +1344,10 @@ class GradientTape(object):
parallel_iterations=parallel_iterations)
new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
if output is None:
output = array_ops.zeros(new_shape)
# Note that this block is returning zeros when it could use `None` to
# represent unconnected gradients. This is to maintain compatibility with
# the previous behavior, which ignored `unconnected_gradients`.
output = array_ops.zeros(new_shape, target.dtype)
if rewrap_as_ndarray:
output = np_arrays.tensor_to_ndarray(output)
return output

View File

@ -1961,6 +1961,31 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
f = def_function.function(f)
self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
@parameterized.parameters((True,), (False))
def test_zeros_type_correct(self, use_pfor):
for dtype in [dtypes.float32, dtypes.float64]:
@def_function.function
def f(x):
del x
return constant_op.constant([[1.]], dtype=dtype) # pylint: disable=cell-var-from-loop
with backprop.GradientTape(persistent=True) as tape:
x = constant_op.constant([[2.]], dtype=dtype)
tape.watch(x)
y = f(x)
jac = tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor)
self.assertEqual(dtype, jac.dtype)
self.assertAllClose([[[0.]]], jac)
with backprop.GradientTape(persistent=True) as tape:
x = constant_op.constant([[2.]], dtype=dtype)
tape.watch(x)
y = f(x)
jac = tape.batch_jacobian(y, x, unconnected_gradients='zero',
experimental_use_pfor=use_pfor)
self.assertEqual(dtype, jac.dtype)
self.assertAllClose([[[0.]]], jac)
class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):