Update evaluation_test to not rely on Keras metrics.
PiperOrigin-RevId: 315931625 Change-Id: I15e7759c439445dabcfacafe4499cbad94540e13
This commit is contained in:
parent
863ceb97de
commit
6bee0d45f8
@ -26,10 +26,10 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.keras import metrics as metrics_module
|
||||
from tensorflow.python.layers import layers
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import metrics as metrics_module
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.losses import losses
|
||||
@ -117,19 +117,18 @@ class EvaluateOnceTest(test.TestCase):
|
||||
logits = logistic_classifier(inputs)
|
||||
predictions = math_ops.round(logits)
|
||||
|
||||
accuracy = metrics_module.Accuracy()
|
||||
update_op = accuracy.update_state(labels, predictions)
|
||||
accuracy, update_op = metrics_module.accuracy(labels, predictions)
|
||||
|
||||
checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
|
||||
|
||||
final_ops_values = evaluation._evaluate_once(
|
||||
checkpoint_path=checkpoint_path,
|
||||
eval_ops=update_op,
|
||||
final_ops={'accuracy': (accuracy.result(), update_op)},
|
||||
final_ops={'accuracy': (accuracy, update_op)},
|
||||
hooks=[
|
||||
evaluation._StopAfterNEvalsHook(1),
|
||||
])
|
||||
self.assertTrue(final_ops_values['accuracy'] > .99)
|
||||
self.assertGreater(final_ops_values['accuracy'], .99)
|
||||
|
||||
def testEvaluateWithFiniteInputs(self):
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(),
|
||||
@ -150,8 +149,7 @@ class EvaluateOnceTest(test.TestCase):
|
||||
logits = logistic_classifier(inputs)
|
||||
predictions = math_ops.round(logits)
|
||||
|
||||
accuracy = metrics_module.Accuracy()
|
||||
update_op = accuracy.update_state(labels, predictions)
|
||||
accuracy, update_op = metrics_module.accuracy(labels, predictions)
|
||||
|
||||
checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
|
||||
|
||||
@ -159,7 +157,7 @@ class EvaluateOnceTest(test.TestCase):
|
||||
checkpoint_path=checkpoint_path,
|
||||
eval_ops=update_op,
|
||||
final_ops={
|
||||
'accuracy': (accuracy.result(), update_op),
|
||||
'accuracy': (accuracy, update_op),
|
||||
'eval_steps': evaluation._get_or_create_eval_step()
|
||||
},
|
||||
hooks=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user