execute eagerly

This commit is contained in:
Abhineet Choudhary 2020-05-25 22:44:01 +05:30
parent 488533ca3c
commit 86db9c0581

View File

@ -91,25 +91,26 @@ def _jacfwd(f, primals):
def _jvp_batch(f, primal, tangents): def _jvp_batch(f, primal, tangents):
tf_function = def_function.function(f)
return control_flow_ops.vectorized_map( return control_flow_ops.vectorized_map(
functools.partial(_jvp, f, primal), functools.partial(_jvp, tf_function, primal),
tangents tangents
) )
def _jvp_batch_matmul(f, primals, tangent_batch): def _jvp_batch_matmul(f, primals, tangent_batch):
"""Compute the jacobian of `f` at `primals` multiplied by `tangents`.""" """Compute the jacobian of `f` at `primals` multiplied by `tangents`."""
jac_fwd = _jacfwd(f, primals) jac_fwd = _jacfwd(f, primals)
def jac_mul(tangent): def jac_mul(tangent):
flat_tangent = array_ops.reshape(tangent, shape=[-1]) flat_tangent = array_ops.reshape(tangent, shape=[-1])
tangent_vector = array_ops.expand_dims(flat_tangent, 1) tangent_vector = array_ops.expand_dims(flat_tangent, 1)
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector) jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
return array_ops.reshape(jvp_vector, tangent.shape) return array_ops.reshape(jvp_vector, tangent.shape)
return control_flow_ops.vectorized_map(
return control_flow_ops.vectorized_map( jac_mul,
jac_mul, tangent_batch
tangent_batch) )
def _grad(f, argnums=0): def _grad(f, argnums=0):
@ -964,8 +965,8 @@ class HessianTests(test.TestCase, parameterized.TestCase):
class JacobianTests(test.TestCase, parameterized.TestCase): class JacobianTests(test.TestCase, parameterized.TestCase):
@parameterized.parameters([ @parameterized.parameters([
(math_ops.sin, (2, 3), 5), (math_ops.sin, (2, 3), 5),
(math_ops.sin, (2, 3, 4), 10), (math_ops.sin, (2, 3, 4), 10)
]) ])
def testJVPBatchCorrectness(self, f, primal_shape, batch_size): def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
primals = [random_ops.random_uniform(primal_shape)] primals = [random_ops.random_uniform(primal_shape)]