This CL adds extra tests for contrib.eager.metrics
that check eager metrics combined with while loops.
PiperOrigin-RevId: 211479604
This commit is contained in:
parent
102e0de242
commit
53bb944808
@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics
|
|||||||
from tensorflow.contrib.summary import summary_test_util
|
from tensorflow.contrib.summary import summary_test_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import summary_ops_v2 as summary_ops
|
from tensorflow.python.ops import summary_ops_v2 as summary_ops
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
from tensorflow.python.training.checkpointable import util as checkpointable_utils
|
from tensorflow.python.training.checkpointable import util as checkpointable_utils
|
||||||
@ -244,6 +247,48 @@ class MetricsTest(test.TestCase):
|
|||||||
value = m.value()
|
value = m.value()
|
||||||
self.assertEqual(self.evaluate(value), 2.5)
|
self.assertEqual(self.evaluate(value), 2.5)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testGraphAndEagerTensorGlobalVariables(self):
|
||||||
|
m = metrics.Mean(use_global_variables=True)
|
||||||
|
inputs = ops.convert_to_tensor([1.0, 2.0])
|
||||||
|
accumulate = m(inputs)
|
||||||
|
result = m.result()
|
||||||
|
self.evaluate(m.init_variables())
|
||||||
|
self.evaluate(accumulate)
|
||||||
|
self.assertEqual(self.evaluate(result), 1.5)
|
||||||
|
# Second init resets all the variables.
|
||||||
|
self.evaluate(m.init_variables())
|
||||||
|
inputs = ops.convert_to_tensor([2.0, 3.0])
|
||||||
|
self.evaluate(m(inputs))
|
||||||
|
value = m.value()
|
||||||
|
self.assertEqual(self.evaluate(value), 2.5)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testGraphAndEagerTensorWhileLoopDoubleCall(self):
|
||||||
|
m = metrics.Mean()
|
||||||
|
init_value = constant_op.constant(1)
|
||||||
|
cond = lambda i: math_ops.less(i, 3)
|
||||||
|
def body(x):
|
||||||
|
with ops.control_dependencies([m(x)]):
|
||||||
|
return math_ops.add(x, 1)
|
||||||
|
accumulate = control_flow_ops.while_loop(cond, body, [init_value])
|
||||||
|
|
||||||
|
result = m.result()
|
||||||
|
self.evaluate(m.init_variables())
|
||||||
|
self.evaluate(accumulate)
|
||||||
|
self.assertEqual(self.evaluate(result), 1.5)
|
||||||
|
# Second init resets all the variables.
|
||||||
|
self.evaluate(m.init_variables())
|
||||||
|
inputs = ops.convert_to_tensor([2.0, 3.0])
|
||||||
|
self.evaluate(m(inputs))
|
||||||
|
if ops.context.executing_eagerly():
|
||||||
|
self.evaluate(control_flow_ops.while_loop(cond, body, [init_value]))
|
||||||
|
else:
|
||||||
|
# Reuse the loop operators in graph mode
|
||||||
|
self.evaluate(accumulate)
|
||||||
|
value = m.value()
|
||||||
|
self.assertEqual(self.evaluate(value), 2.0)
|
||||||
|
|
||||||
def testTwoMeansGraph(self):
|
def testTwoMeansGraph(self):
|
||||||
# Verify two metrics with the same name in the same graph raises a
|
# Verify two metrics with the same name in the same graph raises a
|
||||||
# ValueError.
|
# ValueError.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user