Fix issue with loops that have multiple continue statements in a certain configuration.

PiperOrigin-RevId: 251038502
This commit is contained in:
A. Unique TensorFlower 2019-06-01 05:49:04 -07:00 committed by TensorFlower Gardener
parent 9628bc838b
commit efdc88cf13
2 changed files with 37 additions and 40 deletions
tensorflow/python/autograph/converters

View File

@ -42,20 +42,17 @@ class _Block(object):
`continue` statements (e.g. `if not continue_:`).
Attributes:
guard_created: bool, whether the guard has been created for the last
continue statement.
create_guard: bool, whether a guard should be created because a continue
statement has just been encountered.
create_guard_current: bool, whether to create a guard for the current
statement.
create_guard_next: bool, whether to create a guard for the next
statement.
is_loop_type: bool, whether this block is the body of a loop.
"""
def __init__(self):
self.is_loop_type = False
self.reset_guard_state()
def reset_guard_state(self):
self.guard_created = False
self.create_guard = False
self.create_guard_current = False
self.create_guard_next = False
class ContinueCanonicalizationTransformer(converter.Base):
@ -64,11 +61,12 @@ class ContinueCanonicalizationTransformer(converter.Base):
def visit_Continue(self, node):
self.state[_Continue].used = True
for block in reversed(self.state[_Block].stack):
block.reset_guard_state()
# See ContinueCanonicalizationTest.test_multiple_continues for an example
# it's necessary to reset the state of all enclosing affected blocks, not
# it's necessary to create guards for all enclosing affected blocks, not
# just that of the current block.
block.create_guard_next = True
if block.is_loop_type:
# continue only affects the innermost loop
break
template = """
var_name = True
@ -77,35 +75,13 @@ class ContinueCanonicalizationTransformer(converter.Base):
template, var_name=self.state[_Continue].control_var_name)
def _postprocess_statement(self, node):
# Example of how the state machine below works:
#
# 1| stmt # State: Continue_.used = False
# | # Action: none
# 2| if cond:
# 3| continue # State: Continue_.used = True,
# | # Continue_.guard_created = False,
# | # Continue_.create_guard = False
# | # Action: Continue_.create_guard = True
# 4| stmt # State: Continue_.used = True,
# | # Continue_.guard_created = False,
# | # Continue_.create_guard = True
# | # Action: create `if not continue_used`,
# | # set Continue_.guard_created = True
# 5| stmt # State: Continue_.used = True,
# | # Continue_.guard_created = True
# | # Action: none (will be wrapped under previously
# | # created if node)
if self.state[_Continue].used:
if self.state[_Block].guard_created:
return node, None
elif not self.state[_Block].create_guard:
self.state[_Block].create_guard = True
return node, None
else:
self.state[_Block].guard_created = True
block = self.state[_Block]
should_wrap_current = block.create_guard_current
# After processing propagate whether to guard the next statement
block.create_guard_current = block.create_guard_next
block.create_guard_next = False
if should_wrap_current:
template = """
if ag__.not_(var_name):
original_node
@ -160,7 +136,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
return node
def visit_If(self, node):
node.body = self.visit_block(node.body)
node.body = self._visit_non_loop_body(node.body)
node.orelse = self._visit_non_loop_body(node.orelse)
return node

View File

@ -203,6 +203,27 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 4)
def test_multiple_guarded_continues_with_side_effects(self):
def test_fn(x):
def track(u, x):
u.append(x)
return x
u = []
v = []
while x > 0:
x -= 1
if track(u, x) > 1:
continue
if track(u, x) > 2:
continue
v.append(x)
return u, v
self.assertTransformedEquivalent(test_fn, 3)
self.assertTransformedEquivalent(test_fn, 2)
if __name__ == '__main__':
test.main()