Verify the early stopping condition in for loops before consuming the target, rather than just at the beginning of the iteration.
PiperOrigin-RevId: 286311186 Change-Id: Ie2afdd2a9f26d5045baf4517141517c0d28e3879
This commit is contained in:
parent
afd8b4dbcf
commit
16740c1266
@ -356,12 +356,19 @@ def for_stmt(iter_,
|
||||
def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
|
||||
"""Overload of for_stmt that executes a Python for loop."""
|
||||
del get_state, set_state
|
||||
|
||||
state = init_vars
|
||||
for target in iter_:
|
||||
if extra_test is not None and not extra_test(*state):
|
||||
break
|
||||
state = body(target, *state)
|
||||
|
||||
if extra_test is not None:
|
||||
if extra_test(*state):
|
||||
for target in iter_:
|
||||
state = body(target, *state)
|
||||
if not extra_test(*state):
|
||||
break
|
||||
|
||||
else:
|
||||
for target in iter_:
|
||||
state = body(target, *state)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
|
@ -166,6 +166,53 @@ class ForLoopTest(test.TestCase):
|
||||
opts={})
|
||||
self.assertEqual(s, (1234,))
|
||||
|
||||
def test_python_generator_with_early_stopping(self):
|
||||
def new_generator():
|
||||
for i in range(1, 5):
|
||||
yield i
|
||||
|
||||
gen = new_generator()
|
||||
def run_loop():
|
||||
return control_flow.for_stmt(
|
||||
gen,
|
||||
extra_test=lambda s, c: c == 0, # Break after first iteration
|
||||
body=lambda i, s, c: (s * 10 + i, c + 1),
|
||||
get_state=None,
|
||||
set_state=None,
|
||||
init_vars=(0, 0),
|
||||
basic_symbol_names=('s', 'c'),
|
||||
composite_symbol_names=(),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(run_loop(), (1, 1))
|
||||
self.assertEqual(run_loop(), (2, 1))
|
||||
self.assertEqual(run_loop(), (3, 1))
|
||||
|
||||
self.assertEqual(next(gen), 4)
|
||||
|
||||
def test_python_generator_with_early_stopping_before_loop(self):
|
||||
def new_generator():
|
||||
for i in range(5):
|
||||
yield i
|
||||
|
||||
gen = new_generator()
|
||||
def run_loop():
|
||||
return control_flow.for_stmt(
|
||||
gen,
|
||||
extra_test=lambda s: False, # Break before loop
|
||||
body=lambda i, s: (s * 10 + i,),
|
||||
get_state=None,
|
||||
set_state=None,
|
||||
init_vars=(0,),
|
||||
basic_symbol_names=('s',),
|
||||
composite_symbol_names=(),
|
||||
opts={})
|
||||
|
||||
self.assertEqual(run_loop(), (0,))
|
||||
self.assertEqual(run_loop(), (0,))
|
||||
|
||||
self.assertEqual(next(gen), 0)
|
||||
|
||||
def test_tf_dataset(self):
|
||||
s = control_flow.for_stmt(
|
||||
dataset_ops.Dataset.range(5),
|
||||
|
Loading…
Reference in New Issue
Block a user