From 6bee0d45f8228f2498f53bd6dec0a691f53b3c7b Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 11 Jun 2020 10:47:26 -0700 Subject: [PATCH] Update evaluation_test to not rely on Keras metrics. PiperOrigin-RevId: 315931625 Change-Id: I15e7759c439445dabcfacafe4499cbad94540e13 --- tensorflow/python/training/evaluation_test.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/training/evaluation_test.py b/tensorflow/python/training/evaluation_test.py index 690c97e3db1..7f28c552a6a 100644 --- a/tensorflow/python/training/evaluation_test.py +++ b/tensorflow/python/training/evaluation_test.py @@ -26,10 +26,10 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed -from tensorflow.python.keras import metrics as metrics_module from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_module from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses @@ -117,19 +117,18 @@ class EvaluateOnceTest(test.TestCase): logits = logistic_classifier(inputs) predictions = math_ops.round(logits) - accuracy = metrics_module.Accuracy() - update_op = accuracy.update_state(labels, predictions) + accuracy, update_op = metrics_module.accuracy(labels, predictions) checkpoint_path = saver.latest_checkpoint(checkpoint_dir) final_ops_values = evaluation._evaluate_once( checkpoint_path=checkpoint_path, eval_ops=update_op, - final_ops={'accuracy': (accuracy.result(), update_op)}, + final_ops={'accuracy': (accuracy, update_op)}, hooks=[ evaluation._StopAfterNEvalsHook(1), ]) - self.assertTrue(final_ops_values['accuracy'] > .99) + self.assertGreater(final_ops_values['accuracy'], .99) def testEvaluateWithFiniteInputs(self): checkpoint_dir = os.path.join(self.get_temp_dir(), @@ -150,8 +149,7 @@ class EvaluateOnceTest(test.TestCase): logits = logistic_classifier(inputs) predictions = math_ops.round(logits) - accuracy = metrics_module.Accuracy() - update_op = accuracy.update_state(labels, predictions) + accuracy, update_op = metrics_module.accuracy(labels, predictions) checkpoint_path = saver.latest_checkpoint(checkpoint_dir) @@ -159,7 +157,7 @@ class EvaluateOnceTest(test.TestCase): checkpoint_path=checkpoint_path, eval_ops=update_op, final_ops={ - 'accuracy': (accuracy.result(), update_op), + 'accuracy': (accuracy, update_op), 'eval_steps': evaluation._get_or_create_eval_step() }, hooks=[