Require statement directives like set_loop_options to be the first statement in their respective block.

PiperOrigin-RevId: 285177965
Change-Id: I0f817602d6144016db5e17b7a70361c0e1eb248f
This commit is contained in:
Dan Moldovan 2019-12-12 06:16:41 -08:00 committed by TensorFlower Gardener
parent 27f3043b89
commit b963d6a436
2 changed files with 39 additions and 7 deletions
tensorflow/python/autograph/converters

View File

@ -40,11 +40,18 @@ from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.util import tf_inspect
ENCLOSING_LOOP = 'enclosing_loop'
STATIC_VALUE = 'static_value'
"""Used for AST annotations, see visit_Name."""
class _LoopScope(object):
def __init__(self):
self.ast_node = None
self.statements_visited = 0
def _map_args(call_node, function):
"""Maps AST call nodes to the actual function's arguments.
@ -94,10 +101,14 @@ class DirectivesTransformer(converter.Base):
return call_node
def _process_statement_directive(self, call_node, directive):
if self.local_scope_level < 2:
if self.state[_LoopScope].statements_visited > 1:
raise ValueError(
'"%s" must be the first statement in the loop block' % (
directive.__name__))
if self.state[_LoopScope].level < 2:
raise ValueError(
'"%s" must be used inside a statement' % directive.__name__)
target = self.get_local(ENCLOSING_LOOP)
target = self.state[_LoopScope].ast_node
node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {})
node_anno[directive] = _map_args(call_node, directive)
anno.setanno(target, anno.Basic.DIRECTIVES, node_anno)
@ -120,7 +131,16 @@ class DirectivesTransformer(converter.Base):
anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr))
return node
def visit_Assign(self, node):
self.state[_LoopScope].statements_visited += 1
return self.generic_visit(node)
def visit_AugAssign(self, node):
self.state[_LoopScope].statements_visited += 1
return self.generic_visit(node)
def visit_Expr(self, node):
self.state[_LoopScope].statements_visited += 1
node = self.generic_visit(node)
if isinstance(node.value, gast.Call):
call_node = node.value
@ -141,10 +161,10 @@ class DirectivesTransformer(converter.Base):
# That means that if we ever have a directive that affects things other than
# loops, we'll need support for parallel scopes, or have multiple converters.
def _track_and_visit_loop(self, node):
self.enter_local_scope()
self.set_local(ENCLOSING_LOOP, node)
self.state[_LoopScope].enter()
self.state[_LoopScope].ast_node = node
node = self.generic_visit(node)
self.exit_local_scope()
self.state[_LoopScope].exit()
return node
def visit_While(self, node):

View File

@ -73,7 +73,7 @@ class DirectivesTest(converter_testing.TestCase):
self.assertEqual(d['back_prop'].id, 'a')
self.assertNotIn('swap_memory', d)
def test_loop_target_with_no_loop(self):
def test_loop_target_no_loop(self):
def test_fn():
directives.set_loop_options()
@ -82,6 +82,18 @@ class DirectivesTest(converter_testing.TestCase):
with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'):
node = directives_converter.transform(node, ctx)
def test_loop_target_not_first(self):
def test_fn():
a = 1
while True:
a = 2
directives.set_loop_options(parallel_iterations=10, back_prop=a)
node, ctx = self.prepare(test_fn, {'directives': directives})
with self.assertRaisesRegexp(ValueError, 'must be the first statement'):
node = directives_converter.transform(node, ctx)
def test_invalid_default(self):
def invalid_directive(valid_arg, invalid_default=object()):