Handle degenerate shape in batch_jacobian.
PiperOrigin-RevId: 260035766
parent b4865ec145
commit 2cae1803c1
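In plain terms: `batch_jacobian` computes per-example Jacobians, and a "degenerate" shape here is one with a zero-sized non-batch dimension, which makes the internal pfor loop run zero iterations. Before this change that path could yield `None`; with it, the caller gets an all-zeros Jacobian of the correct shape. A minimal sketch of the fixed behavior, written against the public `tf.GradientTape` API for illustration (the diff itself touches the internal `backprop` and `parallel_for` modules):

```python
import tensorflow as tf

x = tf.zeros([1, 0])  # batch of 1 with a zero-sized inner dimension
with tf.GradientTape() as tape:
  tape.watch(x)
  y = x ** 2

j = tape.batch_jacobian(y, x)
print(j.shape)  # (1, 0, 0): a correctly shaped empty Jacobian, not None
```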
@@ -1128,7 +1128,7 @@ class GradientTape(object):
     See [wikipedia article](http://en.wikipedia.org/wiki/jacobian_matrix_and_determinant) for the
     definition of a Jacobian. This function is essentially an efficient
     implementation of the following:
 
     `tf.stack([self.jacobian(y[i], x[i]) for i in range(x.shape[0])])`.
 
     Note that compared to `GradientTape.jacobian` which computes gradient of
@@ -1146,7 +1146,7 @@ class GradientTape(object):
       x = tf.constant([[1., 2.], [3., 4.]], dtype=tf.float32)
       g.watch(x)
       y = x * x
     batch_jacobian = g.batch_jacobian(y, x)
     # batch_jacobian is [[[2, 0], [0, 4]], [[6, 0], [0, 8]]]
     ```
 
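As a sanity check of the equivalence the docstring claims, the batch Jacobian matches the per-example diagonal blocks of the full Jacobian. A minimal sketch, illustrative and not part of the commit:

```python
import tensorflow as tf

x = tf.constant([[1., 2.], [3., 4.]])
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x

jb = g.batch_jacobian(y, x)   # shape [2, 2, 2]: per-example Jacobians
full = g.jacobian(y, x)       # shape [2, 2, 2, 2]: every dy/dx pair

# Since y[i] depends only on x[i], the off-diagonal blocks of the full
# Jacobian are zero, and its diagonal blocks equal the batch Jacobian.
ref = tf.stack([full[i, :, i, :] for i in range(2)])
tf.debugging.assert_near(jb, ref)  # passes: [[[2, 0], [0, 4]], [[6, 0], [0, 8]]]
```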
@@ -1229,10 +1229,11 @@ class GradientTape(object):
           " with experimental_use_pfor set to False.")
       output = pfor_ops.for_loop(loop_fn, target.dtype, target_row_size,
                                  parallel_iterations=parallel_iterations)
-    if output is None:
-      return None
-    output = array_ops.reshape(output,
-                               [target_row_size, batch_size, -1])
-    output = array_ops.transpose(output, [1, 0, 2])
     new_shape = array_ops.concat([target_shape, source_shape[1:]], axis=0)
-    return array_ops.reshape(output, new_shape)
+    if output is None:
+      return array_ops.zeros(new_shape)
+    else:
+      output = array_ops.reshape(output,
+                                 [target_row_size, batch_size, -1])
+      output = array_ops.transpose(output, [1, 0, 2])
+      return array_ops.reshape(output, new_shape)
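The key to this hunk is that `new_shape` is now computed before the `None` check: the result shape `[batch, target_dims..., source_dims[1:]...]` is known from the input shapes alone, even when the loop never runs. A small illustration of that shape arithmetic with public TF ops; the values match the new test, and `target_shape`/`source_shape` here are illustrative stand-ins for the internal variables:

```python
import tensorflow as tf

# For target of shape [batch, t1, ...] and source of shape [batch, s1, ...],
# batch_jacobian returns shape [batch, t1, ..., s1, ...].
target_shape = tf.constant([1, 0])  # e.g. y = x**2 with x of shape [1, 0]
source_shape = tf.constant([1, 0])
new_shape = tf.concat([target_shape, source_shape[1:]], axis=0)
print(new_shape.numpy())  # [1 0 0]: shape of the zeros returned when pfor yields None
```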
@@ -18,6 +18,7 @@ from __future__ import print_function
 
 import functools
 
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python import pywrap_tensorflow
@@ -128,7 +129,7 @@ class BackpropTest(test.TestCase):
       _ = v + 1.0  # This reads the variable inside the loop context
       with backprop.GradientTape() as t:
         result = v * 2
-      self.assertTrue(t.gradient(result, v) is not None)
+      self.assertIsNotNone(t.gradient(result, v))
       return 1.0
 
     control_flow_ops.while_loop(lambda i: False, body, [1.0])
@@ -268,8 +269,8 @@ class BackpropTest(test.TestCase):
 
     grads = backprop.implicit_grad(f)()
     ordered_variables = [x[1] for x in grads]
-    self.assertTrue(ordered_variables[0] is v0)
-    self.assertTrue(ordered_variables[1] is v1)
+    self.assertIs(ordered_variables[0], v0)
+    self.assertIs(ordered_variables[1], v1)
 
   def testTapeNoOpGradient(self):
     x = constant_op.constant(3.0)
@@ -1482,7 +1483,7 @@ class JacobianTest(test.TestCase):
 
 
 @test_util.run_all_in_graph_and_eager_modes
-class BatchJacobianTest(test.TestCase):
+class BatchJacobianTest(test.TestCase, parameterized.TestCase):
 
   def _batch_jacobian(self, experimental_use_pfor):
     persistent = context.executing_eagerly and not experimental_use_pfor
@@ -1583,6 +1584,23 @@ class BatchJacobianTest(test.TestCase):
     self.assertAllClose(g.batch_jacobian(y, x, parallel_iterations=2),
                         g.batch_jacobian(y, x, parallel_iterations=3))
 
+  @parameterized.parameters(
+      (True, True),
+      (True, False),
+      (False, True),
+      (False, False))
+  def test_degenerate_shape(self, use_function, use_pfor):
+
+    def f(x):
+      with backprop.GradientTape(persistent=True) as tape:
+        tape.watch(x)
+        y = x**2
+      return tape.batch_jacobian(y, x, experimental_use_pfor=use_pfor)
+
+    if use_function:
+      f = def_function.function(f)
+    self.assertAllEqual([1, 0, 0], array_ops.shape(f(array_ops.zeros([1, 0]))))
+
 
 class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
 
@@ -99,7 +99,12 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
 
   output = [None if is_none else ta.concat()
             for ta, is_none in zip(ta_list, is_none_list)]
-  return nest.pack_sequence_as(loop_fn_dtypes, output)
+  assert len(output) in (0, len(flat_loop_fn_dtypes))
+  if not output:
+    # This may happen for the case where iters == 0.
+    return None
+  else:
+    return nest.pack_sequence_as(loop_fn_dtypes, output)
 
 
 def _flatten_first_two_dims(x):
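For context, this last hunk makes `for_loop`'s zero-iterations contract explicit: with `iters == 0` nothing is ever written to the loop's TensorArrays, so there is nothing to concat and the function returns `None`, leaving the caller (`batch_jacobian` above) to substitute correctly shaped zeros. A toy analogue of that contract follows; `toy_for_loop` is a hypothetical helper, not the pfor implementation:

```python
import tensorflow as tf

def toy_for_loop(loop_fn, iters):
  # Run loop_fn for each iteration and stack the results.
  outputs = [loop_fn(tf.constant(i)) for i in range(iters)]
  if not outputs:
    # Mirrors the new branch in for_loop: nothing was produced for
    # iters == 0, so signal the caller to build the empty result itself.
    return None
  return tf.stack(outputs)

print(toy_for_loop(lambda i: tf.cast(i, tf.float32) * 2.0, 0))  # None
print(toy_for_loop(lambda i: tf.cast(i, tf.float32) * 2.0, 3))  # [0. 2. 4.]
```

The design point carried by the diff is the same: the producer signals "no iterations" with `None`, and the consumer, which alone knows the intended result shape, materializes the empty tensor.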