Require statement directives like set_loop_options to be the first statement in their respective block.
PiperOrigin-RevId: 285177965 Change-Id: I0f817602d6144016db5e17b7a70361c0e1eb248f
This commit is contained in:
parent
27f3043b89
commit
b963d6a436
tensorflow/python/autograph/converters
@ -40,11 +40,18 @@ from tensorflow.python.autograph.lang import directives
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
ENCLOSING_LOOP = 'enclosing_loop'
|
||||
|
||||
STATIC_VALUE = 'static_value'
|
||||
"""Used for AST annotations, see visit_Name."""
|
||||
|
||||
|
||||
class _LoopScope(object):
|
||||
|
||||
def __init__(self):
|
||||
self.ast_node = None
|
||||
self.statements_visited = 0
|
||||
|
||||
|
||||
def _map_args(call_node, function):
|
||||
"""Maps AST call nodes to the actual function's arguments.
|
||||
|
||||
@ -94,10 +101,14 @@ class DirectivesTransformer(converter.Base):
|
||||
return call_node
|
||||
|
||||
def _process_statement_directive(self, call_node, directive):
|
||||
if self.local_scope_level < 2:
|
||||
if self.state[_LoopScope].statements_visited > 1:
|
||||
raise ValueError(
|
||||
'"%s" must be the first statement in the loop block' % (
|
||||
directive.__name__))
|
||||
if self.state[_LoopScope].level < 2:
|
||||
raise ValueError(
|
||||
'"%s" must be used inside a statement' % directive.__name__)
|
||||
target = self.get_local(ENCLOSING_LOOP)
|
||||
target = self.state[_LoopScope].ast_node
|
||||
node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {})
|
||||
node_anno[directive] = _map_args(call_node, directive)
|
||||
anno.setanno(target, anno.Basic.DIRECTIVES, node_anno)
|
||||
@ -120,7 +131,16 @@ class DirectivesTransformer(converter.Base):
|
||||
anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr))
|
||||
return node
|
||||
|
||||
def visit_Assign(self, node):
|
||||
self.state[_LoopScope].statements_visited += 1
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_AugAssign(self, node):
|
||||
self.state[_LoopScope].statements_visited += 1
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_Expr(self, node):
|
||||
self.state[_LoopScope].statements_visited += 1
|
||||
node = self.generic_visit(node)
|
||||
if isinstance(node.value, gast.Call):
|
||||
call_node = node.value
|
||||
@ -141,10 +161,10 @@ class DirectivesTransformer(converter.Base):
|
||||
# That means that if we ever have a directive that affects things other than
|
||||
# loops, we'll need support for parallel scopes, or have multiple converters.
|
||||
def _track_and_visit_loop(self, node):
|
||||
self.enter_local_scope()
|
||||
self.set_local(ENCLOSING_LOOP, node)
|
||||
self.state[_LoopScope].enter()
|
||||
self.state[_LoopScope].ast_node = node
|
||||
node = self.generic_visit(node)
|
||||
self.exit_local_scope()
|
||||
self.state[_LoopScope].exit()
|
||||
return node
|
||||
|
||||
def visit_While(self, node):
|
||||
|
@ -73,7 +73,7 @@ class DirectivesTest(converter_testing.TestCase):
|
||||
self.assertEqual(d['back_prop'].id, 'a')
|
||||
self.assertNotIn('swap_memory', d)
|
||||
|
||||
def test_loop_target_with_no_loop(self):
|
||||
def test_loop_target_no_loop(self):
|
||||
|
||||
def test_fn():
|
||||
directives.set_loop_options()
|
||||
@ -82,6 +82,18 @@ class DirectivesTest(converter_testing.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'must be used inside a statement'):
|
||||
node = directives_converter.transform(node, ctx)
|
||||
|
||||
def test_loop_target_not_first(self):
|
||||
|
||||
def test_fn():
|
||||
a = 1
|
||||
while True:
|
||||
a = 2
|
||||
directives.set_loop_options(parallel_iterations=10, back_prop=a)
|
||||
|
||||
node, ctx = self.prepare(test_fn, {'directives': directives})
|
||||
with self.assertRaisesRegexp(ValueError, 'must be the first statement'):
|
||||
node = directives_converter.transform(node, ctx)
|
||||
|
||||
def test_invalid_default(self):
|
||||
|
||||
def invalid_directive(valid_arg, invalid_default=object()):
|
||||
|
Loading…
Reference in New Issue
Block a user