Fixes flaky test in dnn_linear_combined_test.

PiperOrigin-RevId: 157622951
This commit is contained in:
A. Unique TensorFlower 2017-05-31 13:00:07 -07:00 committed by TensorFlower Gardener
parent c9cc388dc2
commit fd6c3c4f1b

View File

@ -1042,9 +1042,15 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
dnn_hidden_units=[3, 3],
fix_global_step_increment_bug=False)
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
global_step = classifier.get_variable_value('global_step')
# Expected is 100, but because of the global step increment bug, this is 50.
self.assertEqual(50, step_counter.steps)
if global_step == 100:
# Expected is 100, but because of the global step increment bug, is 50.
self.assertEqual(50, step_counter.steps)
else:
# Occasionally, training stops when global_step == 101, due to a race
# condition.
self.assertEqual(51, step_counter.steps)
def testGlobalStepDNNLinearCombinedBugFixed(self):
"""Tests global step update for dnn-linear combined model."""