diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py
index dd0bad30cb8..4a1156f534f 100644
--- a/tensorflow/python/eager/forwardprop_test.py
+++ b/tensorflow/python/eager/forwardprop_test.py
@@ -90,6 +90,26 @@ def _jacfwd(f, primals):
   return nest.pack_sequence_as(primals, jac_flat)
 
 
+def _jvp_batch(f, primal, tangents):
+  tf_function = def_function.function(f)
+
+  return control_flow_ops.vectorized_map(
+      functools.partial(_jvp, tf_function, primal), tangents)
+
+
+def _jvp_batch_matmul(f, primals, tangent_batch):
+  """Computes the Jacobian of `f` at `primals` times each batched tangent."""
+  jac_fwd = _jacfwd(f, primals)
+
+  def jac_mul(tangent):
+    flat_tangent = array_ops.reshape(tangent, shape=[-1])
+    tangent_vector = array_ops.expand_dims(flat_tangent, 1)
+    jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
+    return array_ops.reshape(jvp_vector, tangent.shape)
+
+  return control_flow_ops.vectorized_map(jac_mul, tangent_batch)
+
+
 def _grad(f, argnums=0):
   """Return a function which computes the gradient of `f`."""
 
@@ -941,6 +961,18 @@ class HessianTests(test.TestCase, parameterized.TestCase):
     self.assertAllClose(hess_value, hessian_pfor)
 
 
+class JacobianTests(test.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters([(math_ops.sin, (2, 3), 5),
+                             (math_ops.sin, (2, 3, 4), 10)])
+  def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
+    primals = [random_ops.random_uniform(primal_shape)]
+    tangent_batch = [random_ops.random_uniform([batch_size, *primal_shape])]
+    self.assertAllClose(
+        _jvp_batch(f, primals, tangent_batch)[1],
+        _jvp_batch_matmul(f, primals, *tangent_batch))
+
+
 if __name__ == "__main__":
   # TODO(allenl): Also test with 1.x-style graph mode.
   ops.enable_eager_execution()
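
The property exercised by the new testJVPBatchCorrectness case is that a batch of forward-mode JVPs of `f` at a primal matches the full Jacobian of `f` at that primal multiplied by each flattened tangent in the batch. Below is a rough standalone sketch of that equivalence, outside the patch and using only public TensorFlow APIs; the helper names batched_jvp and batched_jacobian_matmul are illustrative and not part of the test file:

import tensorflow as tf


def batched_jvp(f, primal, tangent_batch):
  # Forward-mode JVP of `f` at `primal`, evaluated once per tangent along the
  # leading dimension of `tangent_batch`.
  jvps = []
  for i in range(tangent_batch.shape[0]):
    with tf.autodiff.ForwardAccumulator(primal, tangent_batch[i]) as acc:
      out = f(primal)
    jvps.append(acc.jvp(out))
  return tf.stack(jvps)


def batched_jacobian_matmul(f, primal, tangent_batch):
  # Reference value: build the dense Jacobian once, then multiply it by each
  # flattened tangent and reshape the result back to the output shape.
  with tf.GradientTape() as tape:
    tape.watch(primal)
    out = f(primal)
  jac = tape.jacobian(out, primal)  # shape: out.shape + primal.shape
  jac_2d = tf.reshape(
      jac, [out.shape.num_elements(), primal.shape.num_elements()])
  products = []
  for i in range(tangent_batch.shape[0]):
    flat_tangent = tf.reshape(tangent_batch[i], [-1])
    products.append(
        tf.reshape(tf.linalg.matvec(jac_2d, flat_tangent), out.shape))
  return tf.stack(products)


primal = tf.random.uniform([2, 3])
tangents = tf.random.uniform([5, 2, 3])
# The two computations should agree up to float32 round-off.
print(tf.reduce_max(tf.abs(
    batched_jvp(tf.sin, primal, tangents)
    - batched_jacobian_matmul(tf.sin, primal, tangents))))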