Preserve the directives annotation while lowering break statements.

PiperOrigin-RevId: 295780462
Change-Id: I48fa59628c110aafe250ba20b7b6cdf2cae73e26
This commit is contained in:
Dan Moldovan 2020-02-18 11:24:01 -08:00 committed by TensorFlower Gardener
parent 149f584de1
commit 26bf35aec5
2 changed files with 38 additions and 1 deletions

View File

@ -71,6 +71,7 @@ class BreakTransformer(converter.Base):
return nodes, break_used
def visit_While(self, node):
original_node = node
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
@ -98,9 +99,13 @@ class BreakTransformer(converter.Base):
body=node.body,
orelse=guarded_orelse)
new_while_node = node[1]
anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)
return node
def visit_For(self, node):
original_node = node
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
@ -137,7 +142,9 @@ class BreakTransformer(converter.Base):
body=node.body,
orelse=guarded_orelse)
anno.setanno(node[1], 'extra_test', extra_test)
new_for_node = node[1]
anno.setanno(new_for_node, 'extra_test', extra_test)
anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
return node

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
@ -46,6 +47,21 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 4)
def test_while_loop_preserves_directives(self):
def test_fn(x):
while x > 0:
x -= 1
if x % 2 == 0:
break
node, ctx = self.prepare(test_fn, {})
fake_annotation = object()
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
node = break_statements.transform(node, ctx)
self.assertIs(
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
def test_for_loop(self):
def test_fn(a):
@ -63,6 +79,20 @@ class BreakCanonicalizationTest(converter_testing.TestCase):
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))
def test_for_loop_preserves_directives(self):
def test_fn(a):
for x in a:
if x % 2 == 0:
break
node, ctx = self.prepare(test_fn, {})
fake_annotation = object()
anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
node = break_statements.transform(node, ctx)
self.assertIs(
anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
def test_nested(self):
def test_fn(x):