Preserve the directives annotation while lowering break statements.
PiperOrigin-RevId: 295780462 Change-Id: I48fa59628c110aafe250ba20b7b6cdf2cae73e26
This commit is contained in:
parent
149f584de1
commit
26bf35aec5
@ -71,6 +71,7 @@ class BreakTransformer(converter.Base):
|
||||
return nodes, break_used
|
||||
|
||||
def visit_While(self, node):
|
||||
original_node = node
|
||||
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
||||
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
|
||||
|
||||
@ -98,9 +99,13 @@ class BreakTransformer(converter.Base):
|
||||
body=node.body,
|
||||
orelse=guarded_orelse)
|
||||
|
||||
new_while_node = node[1]
|
||||
anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)
|
||||
|
||||
return node
|
||||
|
||||
def visit_For(self, node):
|
||||
original_node = node
|
||||
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
||||
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
|
||||
|
||||
@ -137,7 +142,9 @@ class BreakTransformer(converter.Base):
|
||||
body=node.body,
|
||||
orelse=guarded_orelse)
|
||||
|
||||
anno.setanno(node[1], 'extra_test', extra_test)
|
||||
new_for_node = node[1]
|
||||
anno.setanno(new_for_node, 'extra_test', extra_test)
|
||||
anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
|
||||
|
||||
return node
|
||||
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import break_statements
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -46,6 +47,21 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
self.assertTransformedEquivalent(test_fn, 1)
|
||||
self.assertTransformedEquivalent(test_fn, 4)
|
||||
|
||||
def test_while_loop_preserves_directives(self):
|
||||
|
||||
def test_fn(x):
|
||||
while x > 0:
|
||||
x -= 1
|
||||
if x % 2 == 0:
|
||||
break
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
fake_annotation = object()
|
||||
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
||||
node = break_statements.transform(node, ctx)
|
||||
self.assertIs(
|
||||
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
|
||||
|
||||
def test_for_loop(self):
|
||||
|
||||
def test_fn(a):
|
||||
@ -63,6 +79,20 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
||||
# but the section following the break will be skipped.
|
||||
self.assertEqual([3], result.test_fn([5, 4]))
|
||||
|
||||
def test_for_loop_preserves_directives(self):
|
||||
|
||||
def test_fn(a):
|
||||
for x in a:
|
||||
if x % 2 == 0:
|
||||
break
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
fake_annotation = object()
|
||||
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
||||
node = break_statements.transform(node, ctx)
|
||||
self.assertIs(
|
||||
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
|
||||
|
||||
def test_nested(self):
|
||||
|
||||
def test_fn(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user