From d85bb53d3e10036af9f0dd48b8293485e4c4448e Mon Sep 17 00:00:00 2001
From: Scott Zhu <scottzhu@google.com>
Date: Thu, 11 Jun 2020 10:52:16 -0700
Subject: [PATCH] Update the export_output_test to not rely on Keras metrics.

The API interface of export_outputs_for_mode made it clear that the input metric need to be a tuple of metric value and update_op, rather than metric instance itself.

PiperOrigin-RevId: 315932763
Change-Id: Id5c032e90c678313e9b42933dbbaada82b7f7d08
---
 .../python/saved_model/model_utils/BUILD      |  2 -
 .../model_utils/export_output_test.py         | 69 +++++++++----------
 2 files changed, 31 insertions(+), 40 deletions(-)

diff --git a/tensorflow/python/saved_model/model_utils/BUILD b/tensorflow/python/saved_model/model_utils/BUILD
index 8aab121d043..70cc89b1946 100644
--- a/tensorflow/python/saved_model/model_utils/BUILD
+++ b/tensorflow/python/saved_model/model_utils/BUILD
@@ -62,7 +62,6 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:sparse_tensor",
-        "//tensorflow/python/keras",
         "//tensorflow/python/saved_model:signature_constants",
     ],
 )
@@ -97,7 +96,6 @@ py_test(
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:sparse_tensor",
-        "//tensorflow/python/keras",
         "//tensorflow/python/saved_model:signature_constants",
         "//tensorflow/python/saved_model:signature_def_utils",
     ],
diff --git a/tensorflow/python/saved_model/model_utils/export_output_test.py b/tensorflow/python/saved_model/model_utils/export_output_test.py
index 5262e9fa1e9..13bbeec38b5 100644
--- a/tensorflow/python/saved_model/model_utils/export_output_test.py
+++ b/tensorflow/python/saved_model/model_utils/export_output_test.py
@@ -26,9 +26,9 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.keras import metrics as metrics_module
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import metrics as metrics_module
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model.model_utils import export_output as export_output_lib
@@ -244,10 +244,9 @@ class SupervisedOutputTest(test.TestCase):
     with context.graph_mode():
       loss = {'my_loss': constant_op.constant([0])}
       predictions = {u'output1': constant_op.constant(['foo'])}
-      metric_obj = metrics_module.Mean()
-      metric_obj.update_state(constant_op.constant([0]))
+      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
       metrics = {
-          'metrics': metric_obj,
+          'metrics': (mean, update_op),
           'metrics2': (constant_op.constant([0]), constant_op.constant([10]))
       }
 
@@ -256,7 +255,7 @@ class SupervisedOutputTest(test.TestCase):
       self.assertEqual(
           outputter.predictions['predictions/output1'], predictions['output1'])
       self.assertEqual(outputter.metrics['metrics/update_op'].name,
-                       'metric_op_wrapper:0')
+                       'mean/update_op:0')
       self.assertEqual(
           outputter.metrics['metrics2/update_op'], metrics['metrics2'][1])
 
@@ -267,14 +266,14 @@ class SupervisedOutputTest(test.TestCase):
       self.assertEqual(
           outputter.predictions, {'predictions': predictions['output1']})
       self.assertEqual(outputter.metrics['metrics/update_op'].name,
-                       'metric_op_wrapper_1:0')
+                       'mean/update_op:0')
 
   def test_supervised_outputs_none(self):
     outputter = MockSupervisedOutput(
         constant_op.constant([0]), None, None)
-    self.assertEqual(len(outputter.loss), 1)
-    self.assertEqual(outputter.predictions, None)
-    self.assertEqual(outputter.metrics, None)
+    self.assertLen(outputter.loss, 1)
+    self.assertIsNone(outputter.predictions)
+    self.assertIsNone(outputter.metrics)
 
   def test_supervised_outputs_invalid(self):
     with self.assertRaisesRegexp(ValueError, 'predictions output value must'):
@@ -291,11 +290,9 @@ class SupervisedOutputTest(test.TestCase):
     with context.graph_mode():
       loss = {('my', 'loss'): constant_op.constant([0])}
       predictions = {(u'output1', '2'): constant_op.constant(['foo'])}
-      metric_obj = metrics_module.Mean()
-      metric_obj.update_state(constant_op.constant([0]))
+      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
       metrics = {
-          ('metrics', '1'):
-              metric_obj,
+          ('metrics', '1'): (mean, update_op),
           ('metrics', '2'): (constant_op.constant([0]),
                              constant_op.constant([10]))
       }
@@ -316,10 +313,9 @@ class SupervisedOutputTest(test.TestCase):
     with context.graph_mode():
       loss = {'loss': constant_op.constant([0])}
       predictions = {u'predictions': constant_op.constant(['foo'])}
-      metric_obj = metrics_module.Mean()
-      metric_obj.update_state(constant_op.constant([0]))
+      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
       metrics = {
-          'metrics_1': metric_obj,
+          'metrics_1': (mean, update_op),
           'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
       }
 
@@ -337,10 +333,9 @@ class SupervisedOutputTest(test.TestCase):
     with context.graph_mode():
       loss = {'my_loss': constant_op.constant([0])}
       predictions = {u'output1': constant_op.constant(['foo'])}
-      metric_obj = metrics_module.Mean()
-      metric_obj.update_state(constant_op.constant([0]))
+      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
       metrics = {
-          'metrics_1': metric_obj,
+          'metrics_1': (mean, update_op),
           'metrics_2': (constant_op.constant([0]), constant_op.constant([10]))
       }
 
@@ -350,11 +345,11 @@ class SupervisedOutputTest(test.TestCase):
                   'labels': constant_op.constant(100, shape=(100, 1))}
       sig_def = outputter.as_signature_def(receiver)
 
-      self.assertTrue('loss/my_loss' in sig_def.outputs)
-      self.assertTrue('metrics_1/value' in sig_def.outputs)
-      self.assertTrue('metrics_2/value' in sig_def.outputs)
-      self.assertTrue('predictions/output1' in sig_def.outputs)
-      self.assertTrue('features' in sig_def.inputs)
+      self.assertIn('loss/my_loss', sig_def.outputs)
+      self.assertIn('metrics_1/value', sig_def.outputs)
+      self.assertIn('metrics_2/value', sig_def.outputs)
+      self.assertIn('predictions/output1', sig_def.outputs)
+      self.assertIn('features', sig_def.inputs)
 
   def test_eval_signature_def(self):
     with context.graph_mode():
@@ -367,38 +362,36 @@ class SupervisedOutputTest(test.TestCase):
                   'labels': constant_op.constant(100, shape=(100, 1))}
       sig_def = outputter.as_signature_def(receiver)
 
-      self.assertTrue('loss/my_loss' in sig_def.outputs)
-      self.assertFalse('metrics/value' in sig_def.outputs)
-      self.assertTrue('predictions/output1' in sig_def.outputs)
-      self.assertTrue('features' in sig_def.inputs)
+      self.assertIn('loss/my_loss', sig_def.outputs)
+      self.assertNotIn('metrics/value', sig_def.outputs)
+      self.assertIn('predictions/output1', sig_def.outputs)
+      self.assertIn('features', sig_def.inputs)
 
   def test_metric_op_is_tensor(self):
     """Tests that ops.Operation is wrapped by a tensor for metric_ops."""
     with context.graph_mode():
       loss = {'my_loss': constant_op.constant([0])}
       predictions = {u'output1': constant_op.constant(['foo'])}
-      metric_obj = metrics_module.Mean()
-      metric_obj.update_state(constant_op.constant([0]))
+      mean, update_op = metrics_module.mean_tensor(constant_op.constant([0]))
       metrics = {
-          'metrics_1': metric_obj,
+          'metrics_1': (mean, update_op),
           'metrics_2': (constant_op.constant([0]), control_flow_ops.no_op())
       }
 
       outputter = MockSupervisedOutput(loss, predictions, metrics)
 
       self.assertTrue(outputter.metrics['metrics_1/update_op'].name.startswith(
-          'metric_op_wrapper'))
-      self.assertTrue(
-          isinstance(outputter.metrics['metrics_1/update_op'], ops.Tensor))
-      self.assertTrue(
-          isinstance(outputter.metrics['metrics_1/value'], ops.Tensor))
+          'mean/update_op'))
+      self.assertIsInstance(
+          outputter.metrics['metrics_1/update_op'], ops.Tensor)
+      self.assertIsInstance(outputter.metrics['metrics_1/value'], ops.Tensor)
 
       self.assertEqual(outputter.metrics['metrics_2/value'],
                        metrics['metrics_2'][0])
       self.assertTrue(outputter.metrics['metrics_2/update_op'].name.startswith(
           'metric_op_wrapper'))
-      self.assertTrue(
-          isinstance(outputter.metrics['metrics_2/update_op'], ops.Tensor))
+      self.assertIsInstance(
+          outputter.metrics['metrics_2/update_op'], ops.Tensor)
 
 
 if __name__ == '__main__':