Fix flaky test -- when testing the metric value, make sure the updated value is reflected by evaluating the value tensor in a separate run after update_op, instead of fetching it in the same run as update_op
PiperOrigin-RevId: 220538414
parent 3146bc0a22
commit 665bd7a2ce
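For context: TF1-style metrics such as tf.metrics.mean_absolute_error return a (value, update_op) pair, and Session.run gives no ordering guarantee between fetches in a single call, so reading the value alongside the update is nondeterministic. A minimal sketch of the race and of the two-step fix the eval hunk below adopts (placeholder names and data are illustrative, not from the test itself):

```python
# Minimal TF1-style sketch of the race this commit fixes; names and data
# are illustrative. tf.metrics.* returns a (value, update_op) pair.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.placeholder(tf.float32, [None])
preds = tf.placeholder(tf.float32, [None])
mae_value, mae_update = tf.metrics.mean_absolute_error(labels, preds)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric state lives in local vars
  feed = {labels: np.array([1.0, 2.0]), preds: np.array([1.5, 2.5])}

  # Flaky: fetching value and update_op together leaves the read/update
  # order unspecified, so `maybe_stale` may or may not include this batch.
  maybe_stale, _ = sess.run((mae_value, mae_update), feed)

  # Fixed: run the update first, then read the value in a separate call.
  sess.run(mae_update, feed)
  fresh = sess.run(mae_value)  # deterministically 0.5 for this feed
```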
@@ -104,11 +104,7 @@ py_test(
     srcs_version = "PY2AND3",
     tags = [
         "no_windows",
-        # TODO(b/119022845): Re-enable this test in TAP.
-        "manual",
-        "notap",
         "notsan",
-        "no_oss",
     ],
     deps = [
         ":keras_saved_model",
@@ -345,21 +345,22 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
       inputs, outputs = load_model(sess, output_path,
                                    model_fn_lib.ModeKeys.EVAL)
 
-      sess.run(outputs['metrics/mae/update_op'], {
-          inputs[input_name]: input_arr,
-          inputs[target_name]: target_arr
-      })
-
-      eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
-                                        inputs[target_name]: target_arr})
+      # First obtain the loss and predictions, and run the metric update op by
+      # feeding in the inputs and targets.
+      loss, predictions, _ = sess.run(
+          (outputs['loss'], outputs['predictions/' + output_name],
+           outputs['metrics/mae/update_op']),
+          {inputs[input_name]: input_arr, inputs[target_name]: target_arr})
+
+      # The metric value should be run after the update op, to ensure that it
+      # reflects the correct value.
+      metric_value = sess.run(outputs['metrics/mae/value'])
 
       self.assertEqual(int(train_before_export),
                        sess.run(training_module.get_global_step()))
-      self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
-      self.assertAllClose(
-          ref_mae, eval_results['metrics/mae/value'], atol=1e-05)
-      self.assertAllClose(
-          ref_predict, eval_results['predictions/' + output_name], atol=1e-05)
+      self.assertAllClose(ref_loss, loss, atol=1e-05)
+      self.assertAllClose(ref_mae, metric_value, atol=1e-05)
+      self.assertAllClose(ref_predict, predictions, atol=1e-05)
 
     # Load train graph, and check for the train op, and prediction values
     with session.Session(graph=ops.Graph()) as sess:
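As a design note, the commit enforces the ordering by splitting the fetches across two Session.run calls. An alternative that gives the same guarantee inside a single run, not used here, is an explicit control dependency; a self-contained sketch under the same illustrative names as the sketch above:

```python
# Alternative fix (NOT what this commit does): enforce the ordering inside
# the graph with a control dependency, so a single Session.run() suffices.
# All names here are illustrative.
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

labels = tf.placeholder(tf.float32, [None])
preds = tf.placeholder(tf.float32, [None])
mae_value, mae_update = tf.metrics.mean_absolute_error(labels, preds)

# `ordered_value` can only be computed after `mae_update` has run.
with tf.control_dependencies([mae_update]):
  ordered_value = tf.identity(mae_value)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  value = sess.run(ordered_value,
                   {labels: np.array([1.0, 2.0]), preds: np.array([1.5, 2.5])})
```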