From 16740c1266bb193b1e048b994faa277b63aa56df Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 18 Dec 2019 19:17:59 -0800 Subject: [PATCH] Verify the early stopping condition in for loops before consuming the target, rather than just at the beginning of the iteration. PiperOrigin-RevId: 286311186 Change-Id: Ie2afdd2a9f26d5045baf4517141517c0d28e3879 --- .../autograph/operators/control_flow.py | 17 +++++-- .../autograph/operators/control_flow_test.py | 47 +++++++++++++++++++ 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index f9b2ff9338e..63f9c0233a8 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -356,12 +356,19 @@ def for_stmt(iter_, def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars): """Overload of for_stmt that executes a Python for loop.""" del get_state, set_state - state = init_vars - for target in iter_: - if extra_test is not None and not extra_test(*state): - break - state = body(target, *state) + + if extra_test is not None: + if extra_test(*state): + for target in iter_: + state = body(target, *state) + if not extra_test(*state): + break + + else: + for target in iter_: + state = body(target, *state) + return state diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py index a85d74246a1..ee5b85e7c0e 100644 --- a/tensorflow/python/autograph/operators/control_flow_test.py +++ b/tensorflow/python/autograph/operators/control_flow_test.py @@ -166,6 +166,53 @@ class ForLoopTest(test.TestCase): opts={}) self.assertEqual(s, (1234,)) + def test_python_generator_with_early_stopping(self): + def new_generator(): + for i in range(1, 5): + yield i + + gen = new_generator() + def run_loop(): + return control_flow.for_stmt( + gen, + extra_test=lambda s, c: c == 0, # Break after first iteration + body=lambda i, s, c: (s * 10 + i, c + 1), + get_state=None, + set_state=None, + init_vars=(0, 0), + basic_symbol_names=('s', 'c'), + composite_symbol_names=(), + opts={}) + + self.assertEqual(run_loop(), (1, 1)) + self.assertEqual(run_loop(), (2, 1)) + self.assertEqual(run_loop(), (3, 1)) + + self.assertEqual(next(gen), 4) + + def test_python_generator_with_early_stopping_before_loop(self): + def new_generator(): + for i in range(5): + yield i + + gen = new_generator() + def run_loop(): + return control_flow.for_stmt( + gen, + extra_test=lambda s: False, # Break before loop + body=lambda i, s: (s * 10 + i,), + get_state=None, + set_state=None, + init_vars=(0,), + basic_symbol_names=('s',), + composite_symbol_names=(), + opts={}) + + self.assertEqual(run_loop(), (0,)) + self.assertEqual(run_loop(), (0,)) + + self.assertEqual(next(gen), 0) + def test_tf_dataset(self): s = control_flow.for_stmt( dataset_ops.Dataset.range(5),