Fix flaky test -- when testing the metric value, make sure that the updated value is reflected by re-running the value instead of update_op

PiperOrigin-RevId: 220538414
This commit is contained in:
Katherine Wu 2018-11-07 15:17:32 -08:00 committed by TensorFlower Gardener
parent 3146bc0a22
commit 665bd7a2ce
2 changed files with 12 additions and 15 deletions

View File

@ -104,11 +104,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_windows",
# TODO(b/119022845): Re-enable this test in TAP.
"manual",
"notap",
"notsan",
"no_oss",
],
deps = [
":keras_saved_model",

View File

@ -345,21 +345,22 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
inputs, outputs = load_model(sess, output_path,
model_fn_lib.ModeKeys.EVAL)
sess.run(outputs['metrics/mae/update_op'], {
inputs[input_name]: input_arr,
inputs[target_name]: target_arr
})
# First obtain the loss and predictions, and run the metric update op by
# feeding in the inputs and targets.
loss, predictions, _ = sess.run(
(outputs['loss'], outputs['predictions/' + output_name],
outputs['metrics/mae/update_op']),
{inputs[input_name]: input_arr, inputs[target_name]: target_arr})
eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
inputs[target_name]: target_arr})
# The metric value should be run after the update op, to ensure that it
# reflects the correct value.
metric_value = sess.run(outputs['metrics/mae/value'])
self.assertEqual(int(train_before_export),
sess.run(training_module.get_global_step()))
self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
self.assertAllClose(
ref_mae, eval_results['metrics/mae/value'], atol=1e-05)
self.assertAllClose(
ref_predict, eval_results['predictions/' + output_name], atol=1e-05)
self.assertAllClose(ref_loss, loss, atol=1e-05)
self.assertAllClose(ref_mae, metric_value, atol=1e-05)
self.assertAllClose(ref_predict, predictions, atol=1e-05)
# Load train graph, and check for the train op, and prediction values
with session.Session(graph=ops.Graph()) as sess: