Update evaluation_test to not rely on Keras metrics.

PiperOrigin-RevId: 315931625
Change-Id: I15e7759c439445dabcfacafe4499cbad94540e13
This commit is contained in:
Scott Zhu 2020-06-11 10:47:26 -07:00 committed by TensorFlower Gardener
parent 863ceb97de
commit 6bee0d45f8

View File

@ -26,10 +26,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.layers import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
@ -117,19 +117,18 @@ class EvaluateOnceTest(test.TestCase):
logits = logistic_classifier(inputs)
predictions = math_ops.round(logits)
accuracy = metrics_module.Accuracy()
update_op = accuracy.update_state(labels, predictions)
accuracy, update_op = metrics_module.accuracy(labels, predictions)
checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
final_ops_values = evaluation._evaluate_once(
checkpoint_path=checkpoint_path,
eval_ops=update_op,
final_ops={'accuracy': (accuracy.result(), update_op)},
final_ops={'accuracy': (accuracy, update_op)},
hooks=[
evaluation._StopAfterNEvalsHook(1),
])
self.assertTrue(final_ops_values['accuracy'] > .99)
self.assertGreater(final_ops_values['accuracy'], .99)
def testEvaluateWithFiniteInputs(self):
checkpoint_dir = os.path.join(self.get_temp_dir(),
@ -150,8 +149,7 @@ class EvaluateOnceTest(test.TestCase):
logits = logistic_classifier(inputs)
predictions = math_ops.round(logits)
accuracy = metrics_module.Accuracy()
update_op = accuracy.update_state(labels, predictions)
accuracy, update_op = metrics_module.accuracy(labels, predictions)
checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
@ -159,7 +157,7 @@ class EvaluateOnceTest(test.TestCase):
checkpoint_path=checkpoint_path,
eval_ops=update_op,
final_ops={
'accuracy': (accuracy.result(), update_op),
'accuracy': (accuracy, update_op),
'eval_steps': evaluation._get_or_create_eval_step()
},
hooks=[