execute eagerly
This commit is contained in:
parent
488533ca3c
commit
86db9c0581
@ -91,9 +91,10 @@ def _jacfwd(f, primals):
|
|||||||
|
|
||||||
|
|
||||||
def _jvp_batch(f, primal, tangents):
|
def _jvp_batch(f, primal, tangents):
|
||||||
|
tf_function = def_function.function(f)
|
||||||
|
|
||||||
return control_flow_ops.vectorized_map(
|
return control_flow_ops.vectorized_map(
|
||||||
functools.partial(_jvp, f, primal),
|
functools.partial(_jvp, tf_function, primal),
|
||||||
tangents
|
tangents
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,10 +107,10 @@ def _jvp_batch_matmul(f, primals, tangent_batch):
|
|||||||
tangent_vector = array_ops.expand_dims(flat_tangent, 1)
|
tangent_vector = array_ops.expand_dims(flat_tangent, 1)
|
||||||
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
|
jvp_vector = math_ops.matmul(jac_fwd, tangent_vector)
|
||||||
return array_ops.reshape(jvp_vector, tangent.shape)
|
return array_ops.reshape(jvp_vector, tangent.shape)
|
||||||
|
|
||||||
return control_flow_ops.vectorized_map(
|
return control_flow_ops.vectorized_map(
|
||||||
jac_mul,
|
jac_mul,
|
||||||
tangent_batch)
|
tangent_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _grad(f, argnums=0):
|
def _grad(f, argnums=0):
|
||||||
@ -965,7 +966,7 @@ class JacobianTests(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@parameterized.parameters([
|
@parameterized.parameters([
|
||||||
(math_ops.sin, (2, 3), 5),
|
(math_ops.sin, (2, 3), 5),
|
||||||
(math_ops.sin, (2, 3, 4), 10),
|
(math_ops.sin, (2, 3, 4), 10)
|
||||||
])
|
])
|
||||||
def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
|
def testJVPBatchCorrectness(self, f, primal_shape, batch_size):
|
||||||
primals = [random_ops.random_uniform(primal_shape)]
|
primals = [random_ops.random_uniform(primal_shape)]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user