From 86db9c0581616666df559ce10afa13d40fca7693 Mon Sep 17 00:00:00 2001 From: Abhineet Choudhary Date: Mon, 25 May 2020 22:44:01 +0530 Subject: [PATCH] execute eagerly --- tensorflow/python/eager/forwardprop_test.py | 33 +++++++++++---------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/eager/forwardprop_test.py b/tensorflow/python/eager/forwardprop_test.py index 2533db8c232..337bf46fbad 100644 --- a/tensorflow/python/eager/forwardprop_test.py +++ b/tensorflow/python/eager/forwardprop_test.py @@ -91,25 +91,26 @@ def _jacfwd(f, primals): def _jvp_batch(f, primal, tangents): + tf_function = def_function.function(f) return control_flow_ops.vectorized_map( - functools.partial(_jvp, f, primal), + functools.partial(_jvp, tf_function, primal), tangents - ) + ) def _jvp_batch_matmul(f, primals, tangent_batch): - """Compute the jacobian of `f` at `primals` multiplied by `tangents`.""" - jac_fwd = _jacfwd(f, primals) - def jac_mul(tangent): - flat_tangent = array_ops.reshape(tangent, shape=[-1]) - tangent_vector = array_ops.expand_dims(flat_tangent, 1) - jvp_vector = math_ops.matmul(jac_fwd, tangent_vector) - return array_ops.reshape(jvp_vector, tangent.shape) - - return control_flow_ops.vectorized_map( - jac_mul, - tangent_batch) + """Compute the jacobian of `f` at `primals` multiplied by `tangents`.""" + jac_fwd = _jacfwd(f, primals) + def jac_mul(tangent): + flat_tangent = array_ops.reshape(tangent, shape=[-1]) + tangent_vector = array_ops.expand_dims(flat_tangent, 1) + jvp_vector = math_ops.matmul(jac_fwd, tangent_vector) + return array_ops.reshape(jvp_vector, tangent.shape) + return control_flow_ops.vectorized_map( + jac_mul, + tangent_batch + ) def _grad(f, argnums=0): @@ -962,10 +963,10 @@ class HessianTests(test.TestCase, parameterized.TestCase): class JacobianTests(test.TestCase, parameterized.TestCase): - + @parameterized.parameters([ - (math_ops.sin, (2, 3), 5), - (math_ops.sin, (2, 3, 4), 10), + (math_ops.sin, (2, 3), 5), + (math_ops.sin, (2, 3, 4), 10) ]) def testJVPBatchCorrectness(self, f, primal_shape, batch_size): primals = [random_ops.random_uniform(primal_shape)]