Fixes flaky test in dnn_linear_combined_test.
PiperOrigin-RevId: 157622951
This commit is contained in:
parent
c9cc388dc2
commit
fd6c3c4f1b
@ -1042,9 +1042,15 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
|
|||||||
dnn_hidden_units=[3, 3],
|
dnn_hidden_units=[3, 3],
|
||||||
fix_global_step_increment_bug=False)
|
fix_global_step_increment_bug=False)
|
||||||
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
|
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
|
||||||
|
global_step = classifier.get_variable_value('global_step')
|
||||||
|
|
||||||
# Expected is 100, but because of the global step increment bug, this is 50.
|
if global_step == 100:
|
||||||
|
# Expected is 100, but because of the global step increment bug, is 50.
|
||||||
self.assertEqual(50, step_counter.steps)
|
self.assertEqual(50, step_counter.steps)
|
||||||
|
else:
|
||||||
|
# Occasionally, training stops when global_step == 101, due to a race
|
||||||
|
# condition.
|
||||||
|
self.assertEqual(51, step_counter.steps)
|
||||||
|
|
||||||
def testGlobalStepDNNLinearCombinedBugFixed(self):
|
def testGlobalStepDNNLinearCombinedBugFixed(self):
|
||||||
"""Tests global step update for dnn-linear combined model."""
|
"""Tests global step update for dnn-linear combined model."""
|
||||||
|
Loading…
Reference in New Issue
Block a user