tape.batch_jacobian: don't make zeros with the wrong dtype if gradients are disconnected
It'd be nice to return `None` if gradients are disconnected, but users are relying on the default returning zeros here, so I'm not planning to change that. Fixes #43043. PiperOrigin-RevId: 347524869 Change-Id: I44c5d807fff3bad9e117cd868319ead9efd9d6f2
This commit is contained in:
parent
ba28d6de31
commit
9a02380d42
@ -1344,7 +1344,10 @@ class GradientTape(object):
|
||||
parallel_iterations=parallel_iterations)
|
||||
new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
|
||||
if output is None:
|
||||
output = array_ops.zeros(new_shape)
|
||||
# Note that this block is returning zeros when it could use `None` to
|
||||
# represent unconnected gradients. This is to maintain compatibility with
|
||||
# the previous behavior, which ignored `unconnected_gradients`.
|
||||
output = array_ops.zeros(new_shape, target.dtype)
|
||||
if rewrap_as_ndarray:
|
||||
output = np_arrays.tensor_to_ndarray(output)
|
||||
return output
|
||||
|
@ -1961,6 +1961,31 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
|
||||
f = def_function.function(f)
|
||||
self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
|
||||
|
||||
@parameterized.parameters((True,), (False))
def test_zeros_type_correct(self, use_pfor):
  """Zeros returned for unconnected gradients must use the target's dtype.

  Regression test: `batch_jacobian` used to build its all-zeros result with
  the default float32 dtype even when the target was float64. The check is
  run both with the default `unconnected_gradients` and with an explicit
  'zero', and for both the pfor and while-loop code paths.
  """
  for dtype in [dtypes.float32, dtypes.float64]:

    @def_function.function
    def f(x):
      # Deliberately ignores its input so the jacobian is disconnected.
      del x
      return constant_op.constant([[1.]], dtype=dtype)  # pylint: disable=cell-var-from-loop

    def check_zeros_jacobian(**jacobian_kwargs):
      # Records y = f(x) on a fresh tape, then verifies the (all-zero)
      # batch jacobian carries the target dtype rather than float32.
      with backprop.GradientTape(persistent=True) as tape:
        x = constant_op.constant([[2.]], dtype=dtype)
        tape.watch(x)
        y = f(x)
      jac = tape.batch_jacobian(
          y, x, experimental_use_pfor=use_pfor, **jacobian_kwargs)
      self.assertEqual(dtype, jac.dtype)
      self.assertAllClose([[[0.]]], jac)

    check_zeros_jacobian()
    check_zeros_jacobian(unconnected_gradients='zero')
|
||||
|
||||
|
||||
class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user