This CL adds extra tests for contrib.eager.metrics
that check eager metrics combined with while loops.
PiperOrigin-RevId: 211479604
This commit is contained in:
parent
102e0de242
commit
53bb944808
@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics
|
||||
from tensorflow.contrib.summary import summary_test_util
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import summary_ops_v2 as summary_ops
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.checkpointable import util as checkpointable_utils
|
||||
@ -244,6 +247,48 @@ class MetricsTest(test.TestCase):
|
||||
value = m.value()
|
||||
self.assertEqual(self.evaluate(value), 2.5)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testGraphAndEagerTensorGlobalVariables(self):
|
||||
m = metrics.Mean(use_global_variables=True)
|
||||
inputs = ops.convert_to_tensor([1.0, 2.0])
|
||||
accumulate = m(inputs)
|
||||
result = m.result()
|
||||
self.evaluate(m.init_variables())
|
||||
self.evaluate(accumulate)
|
||||
self.assertEqual(self.evaluate(result), 1.5)
|
||||
# Second init resets all the variables.
|
||||
self.evaluate(m.init_variables())
|
||||
inputs = ops.convert_to_tensor([2.0, 3.0])
|
||||
self.evaluate(m(inputs))
|
||||
value = m.value()
|
||||
self.assertEqual(self.evaluate(value), 2.5)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testGraphAndEagerTensorWhileLoopDoubleCall(self):
|
||||
m = metrics.Mean()
|
||||
init_value = constant_op.constant(1)
|
||||
cond = lambda i: math_ops.less(i, 3)
|
||||
def body(x):
|
||||
with ops.control_dependencies([m(x)]):
|
||||
return math_ops.add(x, 1)
|
||||
accumulate = control_flow_ops.while_loop(cond, body, [init_value])
|
||||
|
||||
result = m.result()
|
||||
self.evaluate(m.init_variables())
|
||||
self.evaluate(accumulate)
|
||||
self.assertEqual(self.evaluate(result), 1.5)
|
||||
# Second init resets all the variables.
|
||||
self.evaluate(m.init_variables())
|
||||
inputs = ops.convert_to_tensor([2.0, 3.0])
|
||||
self.evaluate(m(inputs))
|
||||
if ops.context.executing_eagerly():
|
||||
self.evaluate(control_flow_ops.while_loop(cond, body, [init_value]))
|
||||
else:
|
||||
# Reuse the loop operators in graph mode
|
||||
self.evaluate(accumulate)
|
||||
value = m.value()
|
||||
self.assertEqual(self.evaluate(value), 2.0)
|
||||
|
||||
def testTwoMeansGraph(self):
|
||||
# Verify two metrics with the same name in the same graph raises a
|
||||
# ValueError.
|
||||
|
Loading…
x
Reference in New Issue
Block a user