Add test case for checking multiple metric instances with default parameters.

PiperOrigin-RevId: 246208421
This commit is contained in:
Pavithra Vijay 2019-05-01 15:14:50 -07:00 committed by TensorFlower Gardener
parent 96f5824e31
commit 34f34bf7ed
2 changed files with 49 additions and 2 deletions
tensorflow/python/keras

View File

@ -158,8 +158,8 @@ class KerasSumTest(test.TestCase):
self.assertEqual(600., self.evaluate(restore_sum.result())) self.assertEqual(600., self.evaluate(restore_sum.result()))
@test_util.run_all_in_graph_and_eager_modes @keras_parameterized.run_all_keras_modes
class KerasMeanTest(test.TestCase): class KerasMeanTest(keras_parameterized.TestCase):
# TODO(b/120949004): Re-enable garbage collection check # TODO(b/120949004): Re-enable garbage collection check
# @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) # @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
@ -294,6 +294,43 @@ class KerasMeanTest(test.TestCase):
self.assertEqual(200., self.evaluate(restore_mean.result())) self.assertEqual(200., self.evaluate(restore_mean.result()))
self.assertEqual(3, self.evaluate(restore_mean.count)) self.assertEqual(3, self.evaluate(restore_mean.count))
def test_multiple_instances(self):
m = metrics.Mean()
m2 = metrics.Mean()
self.assertEqual(m.name, 'mean')
self.assertEqual(m2.name, 'mean')
self.assertEqual([v.name for v in m.variables],
testing_utils.get_expected_metric_variable_names(
['total', 'count']))
self.assertEqual([v.name for v in m2.variables],
testing_utils.get_expected_metric_variable_names(
['total', 'count'], name_suffix='_1'))
self.evaluate(variables.variables_initializer(m.variables))
self.evaluate(variables.variables_initializer(m2.variables))
# check initial state
self.assertEqual(self.evaluate(m.total), 0)
self.assertEqual(self.evaluate(m.count), 0)
self.assertEqual(self.evaluate(m2.total), 0)
self.assertEqual(self.evaluate(m2.count), 0)
# check __call__()
self.assertEqual(self.evaluate(m(100)), 100)
self.assertEqual(self.evaluate(m.total), 100)
self.assertEqual(self.evaluate(m.count), 1)
self.assertEqual(self.evaluate(m2.total), 0)
self.assertEqual(self.evaluate(m2.count), 0)
self.assertEqual(self.evaluate(m2([63, 10])), 36.5)
self.assertEqual(self.evaluate(m2.total), 73)
self.assertEqual(self.evaluate(m2.count), 2)
self.assertEqual(self.evaluate(m.result()), 100)
self.assertEqual(self.evaluate(m.total), 100)
self.assertEqual(self.evaluate(m.count), 1)
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
class KerasAccuracyTest(test.TestCase): class KerasAccuracyTest(test.TestCase):

View File

@ -23,6 +23,7 @@ import threading
import numpy as np import numpy as np
from tensorflow.python import keras from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -682,3 +683,12 @@ def get_v2_optimizer(name, **kwargs):
raise ValueError( raise ValueError(
'Could not find requested v2 optimizer: {}\nValid choices: {}'.format( 'Could not find requested v2 optimizer: {}\nValid choices: {}'.format(
name, list(_V2_OPTIMIZER_MAP.keys()))) name, list(_V2_OPTIMIZER_MAP.keys())))
def get_expected_metric_variable_names(var_names, name_suffix=''):
"""Returns expected metric variable names given names and prefix/suffix."""
if tf2.enabled() or context.executing_eagerly():
# In V1 eager mode and V2 variable names are not made unique.
return [n + ':0' for n in var_names]
# In V1 graph mode variable names are made unique using a suffix.
return [n + name_suffix + ':0' for n in var_names]