Merge pull request #39711 from abhichou4:tests/tangent-batch
PiperOrigin-RevId: 313833540 Change-Id: Iaa0470254e9fd32fd3bfff1fd1882a6763da9ad4
This commit is contained in:
commit
7860a9ccbc
@ -90,6 +90,26 @@ def _jacfwd(f, primals):
|
|||||||
return nest.pack_sequence_as(primals, jac_flat)
|
return nest.pack_sequence_as(primals, jac_flat)
|
||||||
|
|
||||||
|
|
||||||
|
def _jvp_batch(f, primal, tangents):
|
||||||
|
tf_function = def_function.function(f)
|
||||||
|
|
||||||
|
return control_flow_ops.vectorized_map(
|
||||||
|
functools.partial(_jvp, tf_function, primal), tangents)
|
||||||
|
|
||||||
|
|
||||||
|
def _jvp_batch_matmul(f, primals, tangent_batch):
|
||||||
|
"""Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
|
||||||
|
jac_fwd = _jacfwd(f, primals)
|
||||||
|
|
||||||
|
def jac_mul(tangent):
|
||||||
|
flat_tangent = array_ops.reshape(tangent, shape=[-1])
|
||||||
|
tangent_vector = array_ops.expand_dims(flat_tangent, 1)
|
||||||
|
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
|
||||||
|
return array_ops.reshape(jvp_vector, tangent.shape)
|
||||||
|
|
||||||
|
return control_flow_ops.vectorized_map(jac_mul, tangent_batch)
|
||||||
|
|
||||||
|
|
||||||
def _grad(f, argnums=0):
|
def _grad(f, argnums=0):
|
||||||
"""Return a function which computes the gradient of `f`."""
|
"""Return a function which computes the gradient of `f`."""
|
||||||
|
|
||||||
@ -941,6 +961,18 @@ class HessianTests(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllClose(hess_value, hessian_pfor)
|
self.assertAllClose(hess_value, hessian_pfor)
|
||||||
|
|
||||||
|
|
||||||
|
class JacobianTests(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@parameterized.parameters([(math_ops.sin, (2, 3), 5),
|
||||||
|
(math_ops.sin, (2, 3, 4), 10)])
|
||||||
|
def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
|
||||||
|
primals = [random_ops.random_uniform(primal_shape)]
|
||||||
|
tangent_batch = [random_ops.random_uniform([batch_size, *primal_shape])]
|
||||||
|
self.assertAllClose(
|
||||||
|
_jvp_batch(f, primals, tangent_batch)[1],
|
||||||
|
_jvp_batch_matmul(f, primals, *tangent_batch))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# TODO(allenl): Also test with 1.x-style graph mode.
|
# TODO(allenl): Also test with 1.x-style graph mode.
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user