Verify the early stopping condition in for loops before consuming the target, rather than just at the beginning of the iteration.
PiperOrigin-RevId: 286311186 Change-Id: Ie2afdd2a9f26d5045baf4517141517c0d28e3879
This commit is contained in:
parent
afd8b4dbcf
commit
16740c1266
@ -356,12 +356,19 @@ def for_stmt(iter_,
|
|||||||
def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
|
def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
|
||||||
"""Overload of for_stmt that executes a Python for loop."""
|
"""Overload of for_stmt that executes a Python for loop."""
|
||||||
del get_state, set_state
|
del get_state, set_state
|
||||||
|
|
||||||
state = init_vars
|
state = init_vars
|
||||||
for target in iter_:
|
|
||||||
if extra_test is not None and not extra_test(*state):
|
if extra_test is not None:
|
||||||
break
|
if extra_test(*state):
|
||||||
state = body(target, *state)
|
for target in iter_:
|
||||||
|
state = body(target, *state)
|
||||||
|
if not extra_test(*state):
|
||||||
|
break
|
||||||
|
|
||||||
|
else:
|
||||||
|
for target in iter_:
|
||||||
|
state = body(target, *state)
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,6 +166,53 @@ class ForLoopTest(test.TestCase):
|
|||||||
opts={})
|
opts={})
|
||||||
self.assertEqual(s, (1234,))
|
self.assertEqual(s, (1234,))
|
||||||
|
|
||||||
|
def test_python_generator_with_early_stopping(self):
|
||||||
|
def new_generator():
|
||||||
|
for i in range(1, 5):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
gen = new_generator()
|
||||||
|
def run_loop():
|
||||||
|
return control_flow.for_stmt(
|
||||||
|
gen,
|
||||||
|
extra_test=lambda s, c: c == 0, # Break after first iteration
|
||||||
|
body=lambda i, s, c: (s * 10 + i, c + 1),
|
||||||
|
get_state=None,
|
||||||
|
set_state=None,
|
||||||
|
init_vars=(0, 0),
|
||||||
|
basic_symbol_names=('s', 'c'),
|
||||||
|
composite_symbol_names=(),
|
||||||
|
opts={})
|
||||||
|
|
||||||
|
self.assertEqual(run_loop(), (1, 1))
|
||||||
|
self.assertEqual(run_loop(), (2, 1))
|
||||||
|
self.assertEqual(run_loop(), (3, 1))
|
||||||
|
|
||||||
|
self.assertEqual(next(gen), 4)
|
||||||
|
|
||||||
|
def test_python_generator_with_early_stopping_before_loop(self):
|
||||||
|
def new_generator():
|
||||||
|
for i in range(5):
|
||||||
|
yield i
|
||||||
|
|
||||||
|
gen = new_generator()
|
||||||
|
def run_loop():
|
||||||
|
return control_flow.for_stmt(
|
||||||
|
gen,
|
||||||
|
extra_test=lambda s: False, # Break before loop
|
||||||
|
body=lambda i, s: (s * 10 + i,),
|
||||||
|
get_state=None,
|
||||||
|
set_state=None,
|
||||||
|
init_vars=(0,),
|
||||||
|
basic_symbol_names=('s',),
|
||||||
|
composite_symbol_names=(),
|
||||||
|
opts={})
|
||||||
|
|
||||||
|
self.assertEqual(run_loop(), (0,))
|
||||||
|
self.assertEqual(run_loop(), (0,))
|
||||||
|
|
||||||
|
self.assertEqual(next(gen), 0)
|
||||||
|
|
||||||
def test_tf_dataset(self):
|
def test_tf_dataset(self):
|
||||||
s = control_flow.for_stmt(
|
s = control_flow.for_stmt(
|
||||||
dataset_ops.Dataset.range(5),
|
dataset_ops.Dataset.range(5),
|
||||||
|
Loading…
Reference in New Issue
Block a user