tape.batch_jacobian: don't make zeros with the wrong dtype if gradients are disconnected
Fixes #43043. PiperOrigin-RevId: 347425143 Change-Id: I6d965b49c64319d48b3baffd0821450b1155de62
This commit is contained in:
parent
94dfcdb38d
commit
3c81a30606
@ -1344,13 +1344,9 @@ class GradientTape(object):
|
||||
parallel_iterations=parallel_iterations)
|
||||
new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
|
||||
if output is None:
|
||||
if not experimental_use_pfor and target_row_size == 0:
|
||||
# Since we can't actually run the loop function in this case, we don't
|
||||
# know whether gradients are unconnected or not. We'll return a numeric
|
||||
# tensor (with zero elements).
|
||||
output = array_ops.zeros(new_shape, target.dtype)
|
||||
if rewrap_as_ndarray:
|
||||
output = np_arrays.tensor_to_ndarray(output)
|
||||
output = array_ops.zeros(new_shape)
|
||||
if rewrap_as_ndarray:
|
||||
output = np_arrays.tensor_to_ndarray(output)
|
||||
return output
|
||||
else:
|
||||
output = array_ops.reshape(output,
|
||||
|
@ -1961,28 +1961,6 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
|
||||
f = def_function.function(f)
|
||||
self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
|
||||
|
||||
@parameterized.parameters((True,), (False))
|
||||
def test_respects_disconnected_gradients(self, use_pfor):
|
||||
@def_function.function
|
||||
def f(x):
|
||||
del x
|
||||
return constant_op.constant([[1.]], dtype=dtypes.float64)
|
||||
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
x = constant_op.constant([[2.]], dtype=dtypes.float64)
|
||||
tape.watch(x)
|
||||
y = f(x)
|
||||
self.assertIsNone(tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor))
|
||||
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
x = constant_op.constant([[2.]], dtype=dtypes.float64)
|
||||
tape.watch(x)
|
||||
y = f(x)
|
||||
jac = tape.batch_jacobian(y, x, unconnected_gradients='zero',
|
||||
experimental_use_pfor=use_pfor)
|
||||
self.assertEqual(dtypes.float64, jac.dtype)
|
||||
self.assertAllClose([[[0.]]], jac)
|
||||
|
||||
|
||||
class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user