execute eagerly

This commit is contained in:
Abhineet Choudhary 2020-05-25 22:44:01 +05:30
parent 488533ca3c
commit 86db9c0581

View File

@ -91,9 +91,10 @@ def _jacfwd(f, primals):
def _jvp_batch(f, primal, tangents):
tf_function = def_function.function(f)
return control_flow_ops.vectorized_map(
functools.partial(_jvp, f, primal),
functools.partial(_jvp, tf_function, primal),
tangents
)
@ -106,10 +107,10 @@ def _jvp_batch_matmul(f, primals, tangent_batch):
tangent_vector = array_ops.expand_dims(flat_tangent, 1)
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
return array_ops.reshape(jvp_vector, tangent.shape)
return control_flow_ops.vectorized_map(
jac_mul,
tangent_batch)
tangent_batch
)
def _grad(f, argnums=0):
@ -965,7 +966,7 @@ class JacobianTests(test.TestCase, parameterized.TestCase):
@parameterized.parameters([
(math_ops.sin, (2, 3), 5),
(math_ops.sin, (2, 3, 4), 10),
(math_ops.sin, (2, 3, 4), 10)
])
def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
primals = [random_ops.random_uniform(primal_shape)]