Fix flaky test -- when testing the metric value, make sure that the updated value is reflected by re-running the value instead of update_op
PiperOrigin-RevId: 220538414
This commit is contained in:
parent
3146bc0a22
commit
665bd7a2ce
@ -104,11 +104,7 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
# TODO(b/119022845): Re-enable this test in TAP.
|
||||
"manual",
|
||||
"notap",
|
||||
"notsan",
|
||||
"no_oss",
|
||||
],
|
||||
deps = [
|
||||
":keras_saved_model",
|
||||
|
@ -345,21 +345,22 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
inputs, outputs = load_model(sess, output_path,
|
||||
model_fn_lib.ModeKeys.EVAL)
|
||||
|
||||
sess.run(outputs['metrics/mae/update_op'], {
|
||||
inputs[input_name]: input_arr,
|
||||
inputs[target_name]: target_arr
|
||||
})
|
||||
# First obtain the loss and predictions, and run the metric update op by
|
||||
# feeding in the inputs and targets.
|
||||
loss, predictions, _ = sess.run(
|
||||
(outputs['loss'], outputs['predictions/' + output_name],
|
||||
outputs['metrics/mae/update_op']),
|
||||
{inputs[input_name]: input_arr, inputs[target_name]: target_arr})
|
||||
|
||||
eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
|
||||
inputs[target_name]: target_arr})
|
||||
# The metric value should be run after the update op, to ensure that it
|
||||
# reflects the correct value.
|
||||
metric_value = sess.run(outputs['metrics/mae/value'])
|
||||
|
||||
self.assertEqual(int(train_before_export),
|
||||
sess.run(training_module.get_global_step()))
|
||||
self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
|
||||
self.assertAllClose(
|
||||
ref_mae, eval_results['metrics/mae/value'], atol=1e-05)
|
||||
self.assertAllClose(
|
||||
ref_predict, eval_results['predictions/' + output_name], atol=1e-05)
|
||||
self.assertAllClose(ref_loss, loss, atol=1e-05)
|
||||
self.assertAllClose(ref_mae, metric_value, atol=1e-05)
|
||||
self.assertAllClose(ref_predict, predictions, atol=1e-05)
|
||||
|
||||
# Load train graph, and check for the train op, and prediction values
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
|
Loading…
Reference in New Issue
Block a user