diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 850dd356c12..57e70e169ca 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -1042,9 +1042,15 @@ class DNNLinearCombinedClassifierTest(test.TestCase): dnn_hidden_units=[3, 3], fix_global_step_increment_bug=False) classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter]) + global_step = classifier.get_variable_value('global_step') - # Expected is 100, but because of the global step increment bug, this is 50. - self.assertEqual(50, step_counter.steps) + if global_step == 100: + # Expected is 100, but because of the global step increment bug, is 50. + self.assertEqual(50, step_counter.steps) + else: + # Occasionally, training stops when global_step == 101, due to a race + # condition. + self.assertEqual(51, step_counter.steps) def testGlobalStepDNNLinearCombinedBugFixed(self): """Tests global step update for dnn-linear combined model."""