Fix issue with loops that have multiple continue statements in a certain configuration.
PiperOrigin-RevId: 251038502
This commit is contained in:
parent
9628bc838b
commit
efdc88cf13
tensorflow/python/autograph/converters
@ -42,20 +42,17 @@ class _Block(object):
|
||||
`continue` statements (e.g. `if not continue_:`).
|
||||
|
||||
Attributes:
|
||||
guard_created: bool, whether the guard has been created for the last
|
||||
continue statement.
|
||||
create_guard: bool, whether a guard should be created because a continue
|
||||
statement has just been encountered.
|
||||
create_guard_current: bool, whether to create a guard for the current
|
||||
statement.
|
||||
create_guard_next: bool, whether to create a guard for the next
|
||||
statement.
|
||||
is_loop_type: bool, whether this block is the body of a loop.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.is_loop_type = False
|
||||
self.reset_guard_state()
|
||||
|
||||
def reset_guard_state(self):
|
||||
self.guard_created = False
|
||||
self.create_guard = False
|
||||
self.create_guard_current = False
|
||||
self.create_guard_next = False
|
||||
|
||||
|
||||
class ContinueCanonicalizationTransformer(converter.Base):
|
||||
@ -64,11 +61,12 @@ class ContinueCanonicalizationTransformer(converter.Base):
|
||||
def visit_Continue(self, node):
|
||||
self.state[_Continue].used = True
|
||||
for block in reversed(self.state[_Block].stack):
|
||||
block.reset_guard_state()
|
||||
# See ContinueCanonicalizationTest.test_multiple_continues for an example
|
||||
# it's necessary to reset the state of all enclosing affected blocks, not
|
||||
# it's necessary to create guards for all enclosing affected blocks, not
|
||||
# just that of the current block.
|
||||
block.create_guard_next = True
|
||||
if block.is_loop_type:
|
||||
# continue only affects the innermost loop
|
||||
break
|
||||
template = """
|
||||
var_name = True
|
||||
@ -77,35 +75,13 @@ class ContinueCanonicalizationTransformer(converter.Base):
|
||||
template, var_name=self.state[_Continue].control_var_name)
|
||||
|
||||
def _postprocess_statement(self, node):
|
||||
# Example of how the state machine below works:
|
||||
#
|
||||
# 1| stmt # State: Continue_.used = False
|
||||
# | # Action: none
|
||||
# 2| if cond:
|
||||
# 3| continue # State: Continue_.used = True,
|
||||
# | # Continue_.guard_created = False,
|
||||
# | # Continue_.create_guard = False
|
||||
# | # Action: Continue_.create_guard = True
|
||||
# 4| stmt # State: Continue_.used = True,
|
||||
# | # Continue_.guard_created = False,
|
||||
# | # Continue_.create_guard = True
|
||||
# | # Action: create `if not continue_used`,
|
||||
# | # set Continue_.guard_created = True
|
||||
# 5| stmt # State: Continue_.used = True,
|
||||
# | # Continue_.guard_created = True
|
||||
# | # Action: none (will be wrapped under previously
|
||||
# | # created if node)
|
||||
|
||||
if self.state[_Continue].used:
|
||||
if self.state[_Block].guard_created:
|
||||
return node, None
|
||||
|
||||
elif not self.state[_Block].create_guard:
|
||||
self.state[_Block].create_guard = True
|
||||
return node, None
|
||||
|
||||
else:
|
||||
self.state[_Block].guard_created = True
|
||||
block = self.state[_Block]
|
||||
should_wrap_current = block.create_guard_current
|
||||
# After processing propagate whether to guard the next statement
|
||||
block.create_guard_current = block.create_guard_next
|
||||
block.create_guard_next = False
|
||||
if should_wrap_current:
|
||||
template = """
|
||||
if ag__.not_(var_name):
|
||||
original_node
|
||||
@ -160,7 +136,7 @@ class ContinueCanonicalizationTransformer(converter.Base):
|
||||
return node
|
||||
|
||||
def visit_If(self, node):
|
||||
node.body = self.visit_block(node.body)
|
||||
node.body = self._visit_non_loop_body(node.body)
|
||||
node.orelse = self._visit_non_loop_body(node.orelse)
|
||||
return node
|
||||
|
||||
|
@ -203,6 +203,27 @@ class ContinueCanonicalizationTest(converter_testing.TestCase):
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
|
||||
def test_multiple_guarded_continues_with_side_effects(self):
|
||||
|
||||
def test_fn(x):
|
||||
def track(u, x):
|
||||
u.append(x)
|
||||
return x
|
||||
|
||||
u = []
|
||||
v = []
|
||||
while x > 0:
|
||||
x -= 1
|
||||
if track(u, x) > 1:
|
||||
continue
|
||||
if track(u, x) > 2:
|
||||
continue
|
||||
v.append(x)
|
||||
return u, v
|
||||
|
||||
self.assertTransformedEquivalent(test_fn, 3)
|
||||
self.assertTransformedEquivalent(test_fn, 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user