Allow global and nonlocal keywords in converted code. No transformations are made - the statements retain original semantics.

PiperOrigin-RevId: 250415616
This commit is contained in:
Dan Moldovan 2019-05-28 20:36:34 -07:00 committed by TensorFlower Gardener
parent ee65ff1690
commit 4941b4a73f
15 changed files with 378 additions and 41 deletions

View File

@ -27,13 +27,6 @@ class UnsupportedFeaturesChecker(gast.NodeTransformer):
Any features detected will cause AutoGraph to not compile a function.
"""
# TODO(b/124103128): Implement support for `global` statements
def visit_Global(self, node):
raise NotImplementedError('The global keyword is not yet supported.')
def visit_Nonlocal(self, node):
raise NotImplementedError('The nonlocal keyword is not yet supported.')
# These checks could potentially be replaced with inspect.isgeneratorfunction
# to avoid a getsource/parse/ast-walk round trip.
def visit_Yield(self, node):

View File

@ -48,7 +48,7 @@ from tensorflow.python.util import tf_inspect
tf = utils.fake_tf()
testing_global_numeric = 2
global_n = 2
class TestResource(object):
@ -660,13 +660,14 @@ class ApiTest(test.TestCase):
def test_to_graph_with_globals(self):
def test_fn(x):
global testing_global_numeric
testing_global_numeric = x + testing_global_numeric
return testing_global_numeric
global global_n
global_n = x + global_n
return global_n
with self.assertRaisesRegex(
NotImplementedError, 'global keyword is not yet supported'):
api.to_graph(test_fn)
converted_fn = api.to_graph(test_fn)
prev_val = global_n
converted_fn(10)
self.assertGreater(global_n, prev_val)
def test_to_graph_with_kwargs_clashing_converted_call(self):

View File

@ -701,6 +701,12 @@ class AstToCfg(gast.NodeVisitor):
def visit_Pass(self, node):
self._process_basic_statement(node)
def visit_Global(self, node):
self._process_basic_statement(node)
def visit_Nonlocal(self, node):
self._process_basic_statement(node)
def visit_Print(self, node):
self._process_basic_statement(node)

View File

@ -46,8 +46,37 @@ py_test(
],
)
py_library(
name = "activity_test_lib",
testonly = True,
srcs = ["activity_test.py"],
srcs_version = "PY2AND3",
deps = [
":static_analysis",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
py_test(
name = "activity_py3_test",
srcs = ["activity_py3_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = ["no_oss_py2"],
deps = [
":activity_test_lib",
":static_analysis",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
py_test(
name = "liveness_test",
testonly = True,
srcs = ["liveness_test.py"],
srcs_version = "PY2AND3",
deps = [
@ -57,6 +86,32 @@ py_test(
],
)
py_library(
name = "liveness_test_lib",
srcs = ["liveness_test.py"],
srcs_version = "PY2AND3",
deps = [
":static_analysis",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
py_test(
name = "liveness_py3_test",
srcs = ["liveness_py3_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = ["no_oss_py2"],
deps = [
":liveness_test_lib",
":static_analysis",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/pyct",
],
)
py_test(
name = "reaching_definitions_test",
srcs = ["reaching_definitions_test.py"],
@ -67,3 +122,29 @@ py_test(
"//tensorflow/python/autograph/pyct",
],
)
py_library(
name = "reaching_definitions_test_lib",
srcs = ["reaching_definitions_test.py"],
srcs_version = "PY2AND3",
deps = [
":static_analysis",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/pyct",
"@gast_archive//:gast",
],
)
py_test(
name = "reaching_definitions_py3_test",
srcs = ["reaching_definitions_py3_test.py"],
python_version = "PY3",
srcs_version = "PY3",
tags = ["no_oss_py2"],
deps = [
":reaching_definitions_test_lib",
":static_analysis",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/pyct",
],
)

View File

@ -262,12 +262,6 @@ class ActivityAnalyzer(transformer.Base):
self._exit_scope()
return node
def visit_Nonlocal(self, node):
raise NotImplementedError()
def visit_Global(self, node):
raise NotImplementedError()
def visit_Expr(self, node):
return self._process_statement(node)

View File

@ -0,0 +1,51 @@
# python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for activity module, that only run in Python 3."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct.static_analysis import activity_test
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.platform import test
NodeAnno = annos.NodeAnno
class ActivityAnalyzerTest(activity_test.ActivityAnalyzerTestBase):
"""Tests which can only run in Python 3."""
def test_nonlocal_symbol(self):
nonlocal_a = 3
nonlocal_b = 13
def test_fn(c):
nonlocal nonlocal_a
nonlocal nonlocal_b
nonlocal_a = nonlocal_b + c
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('nonlocal_b', 'c'), ('nonlocal_a',))
if __name__ == '__main__':
test.main()

View File

@ -25,12 +25,18 @@ from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.qual_names import QN
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.platform import test
QN = qual_names.QN
NodeAnno = annos.NodeAnno
global_a = 7
global_b = 17
class ScopeTest(test.TestCase):
def assertMissing(self, qn, scope):
@ -110,7 +116,7 @@ class ScopeTest(test.TestCase):
self.assertFalse(QN('a') in child.referenced)
class ActivityAnalyzerTest(test.TestCase):
class ActivityAnalyzerTestBase(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn, future_features=())
@ -137,6 +143,9 @@ class ActivityAnalyzerTest(test.TestCase):
self.assertSymbolSetsAre(used, scope.read, 'read')
self.assertSymbolSetsAre(modified, scope.modified, 'modified')
class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
def test_print_statement(self):
def test_fn(a):
@ -497,6 +506,18 @@ class ActivityAnalyzerTest(test.TestCase):
else:
self.assertScopeIs(body_scope, ('a', 'b'), ('b',))
def test_global_symbol(self):
def test_fn(c):
global global_a
global global_b
global_a = global_b + c
node, _ = self._parse_and_analyze(test_fn)
fn_node = node
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',))
if __name__ == '__main__':
test.main()

View File

@ -70,8 +70,9 @@ class Analyzer(cfg.GraphVisitor):
# Nodes that don't have a scope annotation are assumed not to touch any
# symbols.
# This Name node below is a literal name, e.g. False
assert isinstance(node.ast_node, (gast.Name, gast.Continue, gast.Break,
gast.Pass)), type(node.ast_node)
assert isinstance(node.ast_node,
(gast.Name, gast.Continue, gast.Break, gast.Pass,
gast.Global, gast.Nonlocal)), type(node.ast_node)
live_out = set()
for n in node.next:
live_out |= self.in_[n]
@ -145,12 +146,6 @@ class WholeTreeAnalyzer(transformer.Base):
self.current_analyzer = parent_analyzer
return node
def visit_Nonlocal(self, node):
raise NotImplementedError()
def visit_Global(self, node):
raise NotImplementedError()
class Annotator(transformer.Base):
"""AST visitor that annotates each control flow block with live symbols."""

View File

@ -0,0 +1,54 @@
# python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for liveness module, that only run in Python 3."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.pyct.static_analysis import annos
from tensorflow.python.autograph.pyct.static_analysis import liveness_test
from tensorflow.python.platform import test
NodeAnno = annos.NodeAnno
class LivenessAnalyzerTest(liveness_test.LivenessAnalyzerTestBase):
"""Tests which can only run in Python 3."""
def test_nonlocal_symbol(self):
nonlocal_a = 3
nonlocal_b = 13
def test_fn(c):
nonlocal nonlocal_a
nonlocal nonlocal_b
if nonlocal_a:
nonlocal_b = c
else:
nonlocal_b = c
return nonlocal_b
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasLiveOut(fn_body[2], ('nonlocal_b',))
self.assertHasLiveIn(fn_body[2], ('nonlocal_a', 'c'))
if __name__ == '__main__':
test.main()

View File

@ -28,7 +28,11 @@ from tensorflow.python.autograph.pyct.static_analysis import liveness
from tensorflow.python.platform import test
class LivenessTest(test.TestCase):
global_a = 7
global_b = 17
class LivenessAnalyzerTestBase(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn, future_features=())
@ -59,6 +63,9 @@ class LivenessTest(test.TestCase):
expected = (expected,)
self.assertSetEqual(live_in_strs, set(expected))
class LivenessAnalyzerTest(LivenessAnalyzerTestBase):
def test_live_out_try_block(self):
def test_fn(x, a, b, c): # pylint:disable=unused-argument
@ -421,6 +428,22 @@ class LivenessTest(test.TestCase):
self.assertHasLiveIn(fn_body[0], ('y',))
def test_global_symbol(self):
def test_fn(c):
global global_a
global global_b
if global_a:
global_b = c
else:
global_b = c
return global_b
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasLiveOut(fn_body[2], ('global_b',))
self.assertHasLiveIn(fn_body[2], ('global_a', 'c'))
if __name__ == '__main__':
test.main()

View File

@ -34,6 +34,7 @@ import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis import annos
@ -147,6 +148,24 @@ class Analyzer(cfg.GraphVisitor):
kill = node_scope.modified | node_scope.deleted
defs_out = gen | (defs_in - kill)
elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
# Special case for global and nonlocal: they generate a definition,
# but are not tracked by activity analysis.
if node not in self.gen_map:
node_symbols = {}
for s in node.ast_node.names:
qn = qual_names.QN(s)
if qn in defs_in.value:
# In Python 2, this is a syntax warning. In Python 3, it's an error.
raise ValueError(
'"{}" is assigned before global definition'.format(s))
def_ = self._definition_factory()
node_symbols[qn] = def_
self.gen_map[node] = _NodeState(node_symbols)
gen = self.gen_map[node]
defs_out = defs_in | gen
else:
# Nodes that don't have a scope annotation are assumed not to touch any
# symbols.
@ -216,12 +235,6 @@ class TreeAnnotator(transformer.Base):
return node
def visit_Nonlocal(self, node):
raise NotImplementedError()
def visit_Global(self, node):
raise NotImplementedError()
def visit_Name(self, node):
if self.current_analyzer is None:
# Names may appear outside function defs - for example in class

View File

@ -0,0 +1,56 @@
# python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for reaching_definitions module, that only run in Python 3."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions_test
from tensorflow.python.platform import test
class ReachingDefinitionsAnalyzerTest(
reaching_definitions_test.ReachingDefinitionsAnalyzerTestBase):
"""Tests which can only run in Python 3."""
def test_nonlocal_symbol(self):
nonlocal_a = 3
nonlocal_b = 13
def test_fn():
nonlocal nonlocal_a
nonlocal nonlocal_b
if nonlocal_a:
nonlocal_b = []
return nonlocal_a, nonlocal_b
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasDefs(fn_body[2].test, 1)
self.assertHasDefs(fn_body[2].body[0].targets[0], 1)
self.assertHasDefs(fn_body[3].value.elts[0], 1)
self.assertHasDefs(fn_body[3].value.elts[1], 2)
self.assertSameDef(fn_body[2].test, fn_body[3].value.elts[0])
self.assertHasDefinedIn(fn_body[2], ('nonlocal_a', 'nonlocal_b'))
if __name__ == '__main__':
test.main()

View File

@ -30,7 +30,11 @@ from tensorflow.python.autograph.pyct.static_analysis import reaching_definition
from tensorflow.python.platform import test
class DefinitionInfoTest(test.TestCase):
global_a = 7
global_b = 17
class ReachingDefinitionsAnalyzerTestBase(test.TestCase):
def _parse_and_analyze(self, test_fn):
node, source = parser.parse_entity(test_fn, future_features=())
@ -73,6 +77,9 @@ class DefinitionInfoTest(test.TestCase):
anno.getanno(first, anno.Static.DEFINITIONS)[0],
anno.getanno(second, anno.Static.DEFINITIONS)[0])
class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
def test_conditional(self):
def test_fn(a, b):
@ -366,6 +373,45 @@ class DefinitionInfoTest(test.TestCase):
else:
self.assertHasDefs(retval, 0)
def test_function_definition(self):
def test_fn():
def a():
pass
if a: # pylint:disable=using-constant-test
a = None
return a
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasDefs(fn_body[1].test, 1)
self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
self.assertHasDefs(fn_body[2].value, 2)
self.assertHasDefinedIn(fn_body[1], ('a',))
def test_global(self):
def test_fn():
global global_a
global global_b
if global_a:
global_b = []
return global_a, global_b
node = self._parse_and_analyze(test_fn)
fn_body = node.body
self.assertHasDefs(fn_body[2].test, 1)
self.assertHasDefs(fn_body[2].body[0].targets[0], 1)
self.assertHasDefs(fn_body[3].value.elts[0], 1)
self.assertHasDefs(fn_body[3].value.elts[1], 2)
self.assertSameDef(fn_body[2].test, fn_body[3].value.elts[0])
self.assertHasDefinedIn(fn_body[2], ('global_a', 'global_b'))
if __name__ == '__main__':
test.main()

View File

@ -105,7 +105,8 @@ do_pylint() {
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable"
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
"^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""

View File

@ -159,7 +159,9 @@ def main():
missing_dependencies = []
# File extensions and endings to ignore
ignore_extensions = ["_test", "_test.py", "_test_gpu", "_test_gpu.py"]
ignore_extensions = [
"_test", "_test.py", "_test_gpu", "_test_gpu.py", "_test_lib"
]
ignored_files_count = 0
blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST)