execute eagerly

This commit is contained in:
Abhineet Choudhary 2020-05-25 22:44:01 +05:30
parent 488533ca3c
commit 86db9c0581

View File

@ -91,25 +91,26 @@ def _jacfwd(f, primals):
def _jvp_batch(f, primal, tangents):
tf_function = def_function.function(f)
return control_flow_ops.vectorized_map(
functools.partial(_jvp, f, primal),
functools.partial(_jvp, tf_function, primal),
tangents
)
)
def _jvp_batch_matmul(f, primals, tangent_batch):
"""Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
jac_fwd = _jacfwd(f, primals)
def jac_mul(tangent):
flat_tangent = array_ops.reshape(tangent, shape=[-1])
tangent_vector = array_ops.expand_dims(flat_tangent, 1)
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
return array_ops.reshape(jvp_vector, tangent.shape)
return control_flow_ops.vectorized_map(
jac_mul,
tangent_batch)
"""Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
jac_fwd = _jacfwd(f, primals)
def jac_mul(tangent):
flat_tangent = array_ops.reshape(tangent, shape=[-1])
tangent_vector = array_ops.expand_dims(flat_tangent, 1)
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
return array_ops.reshape(jvp_vector, tangent.shape)
return control_flow_ops.vectorized_map(
jac_mul,
tangent_batch
)
def _grad(f, argnums=0):
@ -962,10 +963,10 @@ class HessianTests(test.TestCase, parameterized.TestCase):
class JacobianTests(test.TestCase, parameterized.TestCase):
@parameterized.parameters([
(math_ops.sin, (2, 3), 5),
(math_ops.sin, (2, 3, 4), 10),
(math_ops.sin, (2, 3), 5),
(math_ops.sin, (2, 3, 4), 10)
])
def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
primals = [random_ops.random_uniform(primal_shape)]