This CL adds extra tests for contrib.eager.metrics that check eager metrics combined with while loops.

PiperOrigin-RevId: 211479604
This commit is contained in:
A. Unique TensorFlower 2018-09-04 10:26:32 -07:00 committed by TensorFlower Gardener
parent 102e0de242
commit 53bb944808

View File

@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@ -244,6 +247,48 @@ class MetricsTest(test.TestCase):
value = m.value()
self.assertEqual(self.evaluate(value), 2.5)
@test_util.run_in_graph_and_eager_modes
def testGraphAndEagerTensorGlobalVariables(self):
m = metrics.Mean(use_global_variables=True)
inputs = ops.convert_to_tensor([1.0, 2.0])
accumulate = m(inputs)
result = m.result()
self.evaluate(m.init_variables())
self.evaluate(accumulate)
self.assertEqual(self.evaluate(result), 1.5)
# Second init resets all the variables.
self.evaluate(m.init_variables())
inputs = ops.convert_to_tensor([2.0, 3.0])
self.evaluate(m(inputs))
value = m.value()
self.assertEqual(self.evaluate(value), 2.5)
@test_util.run_in_graph_and_eager_modes
def testGraphAndEagerTensorWhileLoopDoubleCall(self):
m = metrics.Mean()
init_value = constant_op.constant(1)
cond = lambda i: math_ops.less(i, 3)
def body(x):
with ops.control_dependencies([m(x)]):
return math_ops.add(x, 1)
accumulate = control_flow_ops.while_loop(cond, body, [init_value])
result = m.result()
self.evaluate(m.init_variables())
self.evaluate(accumulate)
self.assertEqual(self.evaluate(result), 1.5)
# Second init resets all the variables.
self.evaluate(m.init_variables())
inputs = ops.convert_to_tensor([2.0, 3.0])
self.evaluate(m(inputs))
if ops.context.executing_eagerly():
self.evaluate(control_flow_ops.while_loop(cond, body, [init_value]))
else:
# Reuse the loop operators in graph mode
self.evaluate(accumulate)
value = m.value()
self.assertEqual(self.evaluate(value), 2.0)
def testTwoMeansGraph(self):
# Verify two metrics with the same name in the same graph raises a
# ValueError.