Verify the early stopping condition in for loops before consuming the target, rather than just at the beginning of the iteration.

PiperOrigin-RevId: 286311186
Change-Id: Ie2afdd2a9f26d5045baf4517141517c0d28e3879
This commit is contained in:
Dan Moldovan 2019-12-18 19:17:59 -08:00 committed by TensorFlower Gardener
parent afd8b4dbcf
commit 16740c1266
2 changed files with 59 additions and 5 deletions

View File

@ -356,12 +356,19 @@ def for_stmt(iter_,
def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
"""Overload of for_stmt that executes a Python for loop."""
del get_state, set_state
state = init_vars
for target in iter_:
if extra_test is not None and not extra_test(*state):
break
state = body(target, *state)
if extra_test is not None:
if extra_test(*state):
for target in iter_:
state = body(target, *state)
if not extra_test(*state):
break
else:
for target in iter_:
state = body(target, *state)
return state

View File

@ -166,6 +166,53 @@ class ForLoopTest(test.TestCase):
opts={})
self.assertEqual(s, (1234,))
def test_python_generator_with_early_stopping(self):
def new_generator():
for i in range(1, 5):
yield i
gen = new_generator()
def run_loop():
return control_flow.for_stmt(
gen,
extra_test=lambda s, c: c == 0, # Break after first iteration
body=lambda i, s, c: (s * 10 + i, c + 1),
get_state=None,
set_state=None,
init_vars=(0, 0),
basic_symbol_names=('s', 'c'),
composite_symbol_names=(),
opts={})
self.assertEqual(run_loop(), (1, 1))
self.assertEqual(run_loop(), (2, 1))
self.assertEqual(run_loop(), (3, 1))
self.assertEqual(next(gen), 4)
def test_python_generator_with_early_stopping_before_loop(self):
def new_generator():
for i in range(5):
yield i
gen = new_generator()
def run_loop():
return control_flow.for_stmt(
gen,
extra_test=lambda s: False, # Break before loop
body=lambda i, s: (s * 10 + i,),
get_state=None,
set_state=None,
init_vars=(0,),
basic_symbol_names=('s',),
composite_symbol_names=(),
opts={})
self.assertEqual(run_loop(), (0,))
self.assertEqual(run_loop(), (0,))
self.assertEqual(next(gen), 0)
def test_tf_dataset(self):
s = control_flow.for_stmt(
dataset_ops.Dataset.range(5),