From 34f34bf7ed4c094ea505f3adacaed99f2c1d62df Mon Sep 17 00:00:00 2001
From: Pavithra Vijay <psv@google.com>
Date: Wed, 1 May 2019 15:14:50 -0700
Subject: [PATCH] Add test case for checking multiple metric instances with
 default parameters.

PiperOrigin-RevId: 246208421
---
 tensorflow/python/keras/metrics_test.py  | 41 ++++++++++++++++++++++--
 tensorflow/python/keras/testing_utils.py | 10 ++++++
 2 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index a6d56257b17..c8b3a35f4d0 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -158,8 +158,8 @@ class KerasSumTest(test.TestCase):
     self.assertEqual(600., self.evaluate(restore_sum.result()))
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class KerasMeanTest(test.TestCase):
+@keras_parameterized.run_all_keras_modes
+class KerasMeanTest(keras_parameterized.TestCase):
 
   # TODO(b/120949004): Re-enable garbage collection check
   # @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
@@ -294,6 +294,43 @@ class KerasMeanTest(test.TestCase):
     self.assertEqual(200., self.evaluate(restore_mean.result()))
     self.assertEqual(3, self.evaluate(restore_mean.count))
 
+  def test_multiple_instances(self):
+    m = metrics.Mean()
+    m2 = metrics.Mean()
+
+    self.assertEqual(m.name, 'mean')
+    self.assertEqual(m2.name, 'mean')
+
+    self.assertEqual([v.name for v in m.variables],
+                     testing_utils.get_expected_metric_variable_names(
+                         ['total', 'count']))
+    self.assertEqual([v.name for v in m2.variables],
+                     testing_utils.get_expected_metric_variable_names(
+                         ['total', 'count'], name_suffix='_1'))
+
+    self.evaluate(variables.variables_initializer(m.variables))
+    self.evaluate(variables.variables_initializer(m2.variables))
+
+    # check initial state
+    self.assertEqual(self.evaluate(m.total), 0)
+    self.assertEqual(self.evaluate(m.count), 0)
+    self.assertEqual(self.evaluate(m2.total), 0)
+    self.assertEqual(self.evaluate(m2.count), 0)
+
+    # check __call__()
+    self.assertEqual(self.evaluate(m(100)), 100)
+    self.assertEqual(self.evaluate(m.total), 100)
+    self.assertEqual(self.evaluate(m.count), 1)
+    self.assertEqual(self.evaluate(m2.total), 0)
+    self.assertEqual(self.evaluate(m2.count), 0)
+
+    self.assertEqual(self.evaluate(m2([63, 10])), 36.5)
+    self.assertEqual(self.evaluate(m2.total), 73)
+    self.assertEqual(self.evaluate(m2.count), 2)
+    self.assertEqual(self.evaluate(m.result()), 100)
+    self.assertEqual(self.evaluate(m.total), 100)
+    self.assertEqual(self.evaluate(m.count), 1)
+
 
 @test_util.run_all_in_graph_and_eager_modes
 class KerasAccuracyTest(test.TestCase):
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index dabfe1a79b3..81a9452f6a8 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -23,6 +23,7 @@ import threading
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
@@ -682,3 +683,12 @@ def get_v2_optimizer(name, **kwargs):
     raise ValueError(
         'Could not find requested v2 optimizer: {}\nValid choices: {}'.format(
             name, list(_V2_OPTIMIZER_MAP.keys())))
+
+
+def get_expected_metric_variable_names(var_names, name_suffix=''):
+  """Returns expected metric variable names given names and prefix/suffix."""
+  if tf2.enabled() or context.executing_eagerly():
+    # In V1 eager mode and V2 variable names are not made unique.
+    return [n + ':0' for n in var_names]
+  # In V1 graph mode variable names are made unique using a suffix.
+  return [n + name_suffix + ':0' for n in var_names]