Preserve the directives annotation while lowering break statements.
PiperOrigin-RevId: 295780462 Change-Id: I48fa59628c110aafe250ba20b7b6cdf2cae73e26
This commit is contained in:
parent
149f584de1
commit
26bf35aec5
@ -71,6 +71,7 @@ class BreakTransformer(converter.Base):
|
|||||||
return nodes, break_used
|
return nodes, break_used
|
||||||
|
|
||||||
def visit_While(self, node):
|
def visit_While(self, node):
|
||||||
|
original_node = node
|
||||||
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
||||||
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
|
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
|
||||||
|
|
||||||
@ -98,9 +99,13 @@ class BreakTransformer(converter.Base):
|
|||||||
body=node.body,
|
body=node.body,
|
||||||
orelse=guarded_orelse)
|
orelse=guarded_orelse)
|
||||||
|
|
||||||
|
new_while_node = node[1]
|
||||||
|
anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_For(self, node):
|
def visit_For(self, node):
|
||||||
|
original_node = node
|
||||||
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
|
||||||
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
|
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
|
||||||
|
|
||||||
@ -137,7 +142,9 @@ class BreakTransformer(converter.Base):
|
|||||||
body=node.body,
|
body=node.body,
|
||||||
orelse=guarded_orelse)
|
orelse=guarded_orelse)
|
||||||
|
|
||||||
anno.setanno(node[1], 'extra_test', extra_test)
|
new_for_node = node[1]
|
||||||
|
anno.setanno(new_for_node, 'extra_test', extra_test)
|
||||||
|
anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
|
||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.autograph.converters import break_statements
|
from tensorflow.python.autograph.converters import break_statements
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
from tensorflow.python.autograph.core import converter_testing
|
||||||
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -46,6 +47,21 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
self.assertTransformedEquivalent(test_fn, 1)
|
self.assertTransformedEquivalent(test_fn, 1)
|
||||||
self.assertTransformedEquivalent(test_fn, 4)
|
self.assertTransformedEquivalent(test_fn, 4)
|
||||||
|
|
||||||
|
def test_while_loop_preserves_directives(self):
|
||||||
|
|
||||||
|
def test_fn(x):
|
||||||
|
while x > 0:
|
||||||
|
x -= 1
|
||||||
|
if x % 2 == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
node, ctx = self.prepare(test_fn, {})
|
||||||
|
fake_annotation = object()
|
||||||
|
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
||||||
|
node = break_statements.transform(node, ctx)
|
||||||
|
self.assertIs(
|
||||||
|
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
|
||||||
|
|
||||||
def test_for_loop(self):
|
def test_for_loop(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def test_fn(a):
|
||||||
@ -63,6 +79,20 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
|
|||||||
# but the section following the break will be skipped.
|
# but the section following the break will be skipped.
|
||||||
self.assertEqual([3], result.test_fn([5, 4]))
|
self.assertEqual([3], result.test_fn([5, 4]))
|
||||||
|
|
||||||
|
def test_for_loop_preserves_directives(self):
|
||||||
|
|
||||||
|
def test_fn(a):
|
||||||
|
for x in a:
|
||||||
|
if x % 2 == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
node, ctx = self.prepare(test_fn, {})
|
||||||
|
fake_annotation = object()
|
||||||
|
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
|
||||||
|
node = break_statements.transform(node, ctx)
|
||||||
|
self.assertIs(
|
||||||
|
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
|
||||||
|
|
||||||
def test_nested(self):
|
def test_nested(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def test_fn(x):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user