Make a couple of remaining tests in math_ops_test.py run in eager mode.
PiperOrigin-RevId: 317756434
Change-Id: I14911d5f460131f4b9bdf6031cc7011ddd6a1d53
parent 23fc134a17
commit a26044ac2f
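Note: the diff below applies one pattern throughout: drop the @test_util.deprecated_graph_mode_only decorator and branch on the execution mode, recording gradients with a GradientTape in eager mode and falling back to symbolic gradients.gradients in graph mode. A minimal standalone sketch of that pattern, written against the public TF API (tf.GradientTape, tf.gradients) rather than the internal backprop/gradients test modules; the helper name grads_of_sum is hypothetical:

import tensorflow as tf


def grads_of_sum(input_vars):
  # Gradient of add_n with respect to each input, valid in both modes.
  if tf.executing_eagerly():
    # Eager: record the forward pass on a tape, then differentiate.
    with tf.GradientTape() as tape:
      tape.watch(input_vars)
      total = tf.add_n(input_vars)
    return tape.gradient(total, input_vars)
  else:
    # Graph: construct symbolic gradient ops instead.
    total = tf.add_n(input_vars)
    return tf.gradients(total, input_vars)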
@@ -397,7 +397,6 @@ class AddNTest(test_util.TensorFlowTestCase):
         self.assertAllEqual(x[0] * num_inputs,
                             math_ops.add_n([tf_x[0]] * num_inputs))
 
-  @test_util.deprecated_graph_mode_only
   def testGrad(self):
     np.random.seed(42)
     for num_inputs in range(1, 10):
@@ -406,9 +405,16 @@ class AddNTest(test_util.TensorFlowTestCase):
           variables.Variable(10.0 * np.random.random())
           for _ in range(0, num_inputs)
       ]
-      addn = math_ops.add_n(input_vars)
       self.evaluate(variables.global_variables_initializer())
-      add_n_grad = gradients.gradients(addn, input_vars)
+      if context.executing_eagerly():
+        with backprop.GradientTape() as tape:
+          tape.watch(input_vars)
+          addn = math_ops.add_n(input_vars)
+        add_n_grad = tape.gradient(addn, input_vars)
+      else:
+        addn = math_ops.add_n(input_vars)
+        add_n_grad = gradients.gradients(addn, input_vars)
+
       self.assertAllEqual(
           np.repeat(1.0, num_inputs),  # d/dx (x + y + ...) = 1
           [self.evaluate(g) for g in add_n_grad])
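Note: the invariant asserted here (every partial derivative of add_n is exactly 1, since each input enters the sum linearly) can be reproduced in isolation with the public eager API; a quick sketch, assuming TF 2.x eager defaults:

import numpy as np
import tensorflow as tf

xs = [tf.Variable(10.0 * np.random.random()) for _ in range(3)]
with tf.GradientTape() as tape:
  total = tf.add_n(xs)  # variables are watched automatically
grads = tape.gradient(total, xs)
print([g.numpy() for g in grads])  # [1.0, 1.0, 1.0]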
@@ -515,18 +521,32 @@ class DivAndModTest(test_util.TensorFlowTestCase):
     _ = math_ops.divide(foo, 1.)
     _ = math_ops.div(foo, 2.)
 
-  @test_util.deprecated_graph_mode_only
   def testFloorDivGrad(self):
     a = variables.Variable(2.)
     b = variables.Variable(4.)
+    input_vars = [a, b]
     self.evaluate(variables.global_variables_initializer())
-    c_grad = gradients.gradients(math_ops.divide(a, b), [a, b])
-    self.assertAllEqual([self.evaluate(x) for x in c_grad], [.25, -.125])
-    c_grad = gradients.gradients(math_ops.div(a, b), [a, b])
-    self.assertAllEqual([self.evaluate(x) for x in c_grad], [.25, -.125])
-    c_grad = gradients.gradients(math_ops.floordiv(a, b), [a, b])
+    if context.executing_eagerly():
+      # TODO(rmlarsen): Is there a more compact way of
+      # writing this for multiple expressions?
+      with backprop.GradientTape() as tape:
+        tape.watch(input_vars)
+        c_grad0 = tape.gradient(math_ops.divide(a, b), input_vars)
+      with backprop.GradientTape() as tape:
+        tape.watch(input_vars)
+        c_grad1 = tape.gradient(math_ops.div(a, b), input_vars)
+      with backprop.GradientTape() as tape:
+        tape.watch(input_vars)
+        c_grad2 = tape.gradient(math_ops.floordiv(a, b), input_vars)
+    else:
+      c_grad0 = gradients.gradients(math_ops.divide(a, b), input_vars)
+      c_grad1 = gradients.gradients(math_ops.div(a, b), input_vars)
+      c_grad2 = gradients.gradients(math_ops.floordiv(a, b), input_vars)
+    self.assertAllEqual([self.evaluate(x) for x in c_grad0], [.25, -.125])
+    self.assertAllEqual([self.evaluate(x) for x in c_grad1], [.25, -.125])
     self.assertAllEqual(
-        [None if x is None else self.evaluate(x) for x in c_grad], [None, None])
+        [None if x is None else self.evaluate(x) for x in c_grad2],
+        [None, None])
 
   def testConsistent(self):
     nums, divs = self.intTestData()
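Note: the last assertion depends on floordiv being piecewise constant, so TF reports its gradient as None for both inputs rather than zero. The same behavior is observable with the public API; a sketch, assuming TF 2.x:

import tensorflow as tf

a = tf.Variable(2.)
b = tf.Variable(4.)
with tf.GradientTape() as tape:
  c = a // b  # tf.math.floordiv: piecewise constant in a and b
print(tape.gradient(c, [a, b]))  # [None, None]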