Allow global and nonlocal keywords in converted code. No transformations are made - the statements retain original semantics.
PiperOrigin-RevId: 250415616
This commit is contained in:
parent
ee65ff1690
commit
4941b4a73f
tensorflow
python/autograph
core
impl
pyct
tools
@ -27,13 +27,6 @@ class UnsupportedFeaturesChecker(gast.NodeTransformer):
|
|||||||
Any features detected will cause AutoGraph to not compile a function.
|
Any features detected will cause AutoGraph to not compile a function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO(b/124103128): Implement support for `global` statements
|
|
||||||
def visit_Global(self, node):
|
|
||||||
raise NotImplementedError('The global keyword is not yet supported.')
|
|
||||||
|
|
||||||
def visit_Nonlocal(self, node):
|
|
||||||
raise NotImplementedError('The nonlocal keyword is not yet supported.')
|
|
||||||
|
|
||||||
# These checks could potentially be replaced with inspect.isgeneratorfunction
|
# These checks could potentially be replaced with inspect.isgeneratorfunction
|
||||||
# to avoid a getsource/parse/ast-walk round trip.
|
# to avoid a getsource/parse/ast-walk round trip.
|
||||||
def visit_Yield(self, node):
|
def visit_Yield(self, node):
|
||||||
|
@ -48,7 +48,7 @@ from tensorflow.python.util import tf_inspect
|
|||||||
tf = utils.fake_tf()
|
tf = utils.fake_tf()
|
||||||
|
|
||||||
|
|
||||||
testing_global_numeric = 2
|
global_n = 2
|
||||||
|
|
||||||
|
|
||||||
class TestResource(object):
|
class TestResource(object):
|
||||||
@ -660,13 +660,14 @@ class ApiTest(test.TestCase):
|
|||||||
def test_to_graph_with_globals(self):
|
def test_to_graph_with_globals(self):
|
||||||
|
|
||||||
def test_fn(x):
|
def test_fn(x):
|
||||||
global testing_global_numeric
|
global global_n
|
||||||
testing_global_numeric = x + testing_global_numeric
|
global_n = x + global_n
|
||||||
return testing_global_numeric
|
return global_n
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
converted_fn = api.to_graph(test_fn)
|
||||||
NotImplementedError, 'global keyword is not yet supported'):
|
prev_val = global_n
|
||||||
api.to_graph(test_fn)
|
converted_fn(10)
|
||||||
|
self.assertGreater(global_n, prev_val)
|
||||||
|
|
||||||
def test_to_graph_with_kwargs_clashing_converted_call(self):
|
def test_to_graph_with_kwargs_clashing_converted_call(self):
|
||||||
|
|
||||||
|
@ -701,6 +701,12 @@ class AstToCfg(gast.NodeVisitor):
|
|||||||
def visit_Pass(self, node):
|
def visit_Pass(self, node):
|
||||||
self._process_basic_statement(node)
|
self._process_basic_statement(node)
|
||||||
|
|
||||||
|
def visit_Global(self, node):
|
||||||
|
self._process_basic_statement(node)
|
||||||
|
|
||||||
|
def visit_Nonlocal(self, node):
|
||||||
|
self._process_basic_statement(node)
|
||||||
|
|
||||||
def visit_Print(self, node):
|
def visit_Print(self, node):
|
||||||
self._process_basic_statement(node)
|
self._process_basic_statement(node)
|
||||||
|
|
||||||
|
@ -46,8 +46,37 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "activity_test_lib",
|
||||||
|
testonly = True,
|
||||||
|
srcs = ["activity_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":static_analysis",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/autograph/pyct",
|
||||||
|
"@gast_archive//:gast",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "activity_py3_test",
|
||||||
|
srcs = ["activity_py3_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
tags = ["no_oss_py2"],
|
||||||
|
deps = [
|
||||||
|
":activity_test_lib",
|
||||||
|
":static_analysis",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/autograph/pyct",
|
||||||
|
"@gast_archive//:gast",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "liveness_test",
|
name = "liveness_test",
|
||||||
|
testonly = True,
|
||||||
srcs = ["liveness_test.py"],
|
srcs = ["liveness_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
@ -57,6 +86,32 @@ py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "liveness_test_lib",
|
||||||
|
srcs = ["liveness_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":static_analysis",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/autograph/pyct",
|
||||||
|
"@gast_archive//:gast",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "liveness_py3_test",
|
||||||
|
srcs = ["liveness_py3_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
tags = ["no_oss_py2"],
|
||||||
|
deps = [
|
||||||
|
":liveness_test_lib",
|
||||||
|
":static_analysis",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/autograph/pyct",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "reaching_definitions_test",
|
name = "reaching_definitions_test",
|
||||||
srcs = ["reaching_definitions_test.py"],
|
srcs = ["reaching_definitions_test.py"],
|
||||||
@ -67,3 +122,29 @@ py_test(
|
|||||||
"//tensorflow/python/autograph/pyct",
|
"//tensorflow/python/autograph/pyct",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "reaching_definitions_test_lib",
|
||||||
|
srcs = ["reaching_definitions_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":static_analysis",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/autograph/pyct",
|
||||||
|
"@gast_archive//:gast",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "reaching_definitions_py3_test",
|
||||||
|
srcs = ["reaching_definitions_py3_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY3",
|
||||||
|
tags = ["no_oss_py2"],
|
||||||
|
deps = [
|
||||||
|
":reaching_definitions_test_lib",
|
||||||
|
":static_analysis",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/autograph/pyct",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -262,12 +262,6 @@ class ActivityAnalyzer(transformer.Base):
|
|||||||
self._exit_scope()
|
self._exit_scope()
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_Nonlocal(self, node):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def visit_Global(self, node):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def visit_Expr(self, node):
|
def visit_Expr(self, node):
|
||||||
return self._process_statement(node)
|
return self._process_statement(node)
|
||||||
|
|
||||||
|
@ -0,0 +1,51 @@
|
|||||||
|
# python3
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for activity module, that only run in Python 3."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.pyct import anno
|
||||||
|
from tensorflow.python.autograph.pyct.static_analysis import activity_test
|
||||||
|
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
NodeAnno = annos.NodeAnno
|
||||||
|
|
||||||
|
|
||||||
|
class ActivityAnalyzerTest(activity_test.ActivityAnalyzerTestBase):
|
||||||
|
"""Tests which can only run in Python 3."""
|
||||||
|
|
||||||
|
def test_nonlocal_symbol(self):
|
||||||
|
|
||||||
|
nonlocal_a = 3
|
||||||
|
nonlocal_b = 13
|
||||||
|
|
||||||
|
def test_fn(c):
|
||||||
|
nonlocal nonlocal_a
|
||||||
|
nonlocal nonlocal_b
|
||||||
|
nonlocal_a = nonlocal_b + c
|
||||||
|
|
||||||
|
node, _ = self._parse_and_analyze(test_fn)
|
||||||
|
fn_node = node
|
||||||
|
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||||
|
self.assertScopeIs(body_scope, ('nonlocal_b', 'c'), ('nonlocal_a',))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -25,12 +25,18 @@ from tensorflow.python.autograph.pyct import anno
|
|||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.pyct import qual_names
|
from tensorflow.python.autograph.pyct import qual_names
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
from tensorflow.python.autograph.pyct.qual_names import QN
|
|
||||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||||
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
|
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
QN = qual_names.QN
|
||||||
|
NodeAnno = annos.NodeAnno
|
||||||
|
|
||||||
|
global_a = 7
|
||||||
|
global_b = 17
|
||||||
|
|
||||||
|
|
||||||
class ScopeTest(test.TestCase):
|
class ScopeTest(test.TestCase):
|
||||||
|
|
||||||
def assertMissing(self, qn, scope):
|
def assertMissing(self, qn, scope):
|
||||||
@ -110,7 +116,7 @@ class ScopeTest(test.TestCase):
|
|||||||
self.assertFalse(QN('a') in child.referenced)
|
self.assertFalse(QN('a') in child.referenced)
|
||||||
|
|
||||||
|
|
||||||
class ActivityAnalyzerTest(test.TestCase):
|
class ActivityAnalyzerTestBase(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
node, source = parser.parse_entity(test_fn, future_features=())
|
node, source = parser.parse_entity(test_fn, future_features=())
|
||||||
@ -137,6 +143,9 @@ class ActivityAnalyzerTest(test.TestCase):
|
|||||||
self.assertSymbolSetsAre(used, scope.read, 'read')
|
self.assertSymbolSetsAre(used, scope.read, 'read')
|
||||||
self.assertSymbolSetsAre(modified, scope.modified, 'modified')
|
self.assertSymbolSetsAre(modified, scope.modified, 'modified')
|
||||||
|
|
||||||
|
|
||||||
|
class ActivityAnalyzerTest(ActivityAnalyzerTestBase):
|
||||||
|
|
||||||
def test_print_statement(self):
|
def test_print_statement(self):
|
||||||
|
|
||||||
def test_fn(a):
|
def test_fn(a):
|
||||||
@ -497,6 +506,18 @@ class ActivityAnalyzerTest(test.TestCase):
|
|||||||
else:
|
else:
|
||||||
self.assertScopeIs(body_scope, ('a', 'b'), ('b',))
|
self.assertScopeIs(body_scope, ('a', 'b'), ('b',))
|
||||||
|
|
||||||
|
def test_global_symbol(self):
|
||||||
|
|
||||||
|
def test_fn(c):
|
||||||
|
global global_a
|
||||||
|
global global_b
|
||||||
|
global_a = global_b + c
|
||||||
|
|
||||||
|
node, _ = self._parse_and_analyze(test_fn)
|
||||||
|
fn_node = node
|
||||||
|
body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
|
||||||
|
self.assertScopeIs(body_scope, ('global_b', 'c'), ('global_a',))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -70,8 +70,9 @@ class Analyzer(cfg.GraphVisitor):
|
|||||||
# Nodes that don't have a scope annotation are assumed not to touch any
|
# Nodes that don't have a scope annotation are assumed not to touch any
|
||||||
# symbols.
|
# symbols.
|
||||||
# This Name node below is a literal name, e.g. False
|
# This Name node below is a literal name, e.g. False
|
||||||
assert isinstance(node.ast_node, (gast.Name, gast.Continue, gast.Break,
|
assert isinstance(node.ast_node,
|
||||||
gast.Pass)), type(node.ast_node)
|
(gast.Name, gast.Continue, gast.Break, gast.Pass,
|
||||||
|
gast.Global, gast.Nonlocal)), type(node.ast_node)
|
||||||
live_out = set()
|
live_out = set()
|
||||||
for n in node.next:
|
for n in node.next:
|
||||||
live_out |= self.in_[n]
|
live_out |= self.in_[n]
|
||||||
@ -145,12 +146,6 @@ class WholeTreeAnalyzer(transformer.Base):
|
|||||||
self.current_analyzer = parent_analyzer
|
self.current_analyzer = parent_analyzer
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_Nonlocal(self, node):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def visit_Global(self, node):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class Annotator(transformer.Base):
|
class Annotator(transformer.Base):
|
||||||
"""AST visitor that annotates each control flow block with live symbols."""
|
"""AST visitor that annotates each control flow block with live symbols."""
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
# python3
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for liveness module, that only run in Python 3."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||||
|
from tensorflow.python.autograph.pyct.static_analysis import liveness_test
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
NodeAnno = annos.NodeAnno
|
||||||
|
|
||||||
|
|
||||||
|
class LivenessAnalyzerTest(liveness_test.LivenessAnalyzerTestBase):
|
||||||
|
"""Tests which can only run in Python 3."""
|
||||||
|
|
||||||
|
def test_nonlocal_symbol(self):
|
||||||
|
|
||||||
|
nonlocal_a = 3
|
||||||
|
nonlocal_b = 13
|
||||||
|
|
||||||
|
def test_fn(c):
|
||||||
|
nonlocal nonlocal_a
|
||||||
|
nonlocal nonlocal_b
|
||||||
|
if nonlocal_a:
|
||||||
|
nonlocal_b = c
|
||||||
|
else:
|
||||||
|
nonlocal_b = c
|
||||||
|
return nonlocal_b
|
||||||
|
|
||||||
|
node = self._parse_and_analyze(test_fn)
|
||||||
|
fn_body = node.body
|
||||||
|
self.assertHasLiveOut(fn_body[2], ('nonlocal_b',))
|
||||||
|
self.assertHasLiveIn(fn_body[2], ('nonlocal_a', 'c'))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -28,7 +28,11 @@ from tensorflow.python.autograph.pyct.static_analysis import liveness
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class LivenessTest(test.TestCase):
|
global_a = 7
|
||||||
|
global_b = 17
|
||||||
|
|
||||||
|
|
||||||
|
class LivenessAnalyzerTestBase(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
node, source = parser.parse_entity(test_fn, future_features=())
|
node, source = parser.parse_entity(test_fn, future_features=())
|
||||||
@ -59,6 +63,9 @@ class LivenessTest(test.TestCase):
|
|||||||
expected = (expected,)
|
expected = (expected,)
|
||||||
self.assertSetEqual(live_in_strs, set(expected))
|
self.assertSetEqual(live_in_strs, set(expected))
|
||||||
|
|
||||||
|
|
||||||
|
class LivenessAnalyzerTest(LivenessAnalyzerTestBase):
|
||||||
|
|
||||||
def test_live_out_try_block(self):
|
def test_live_out_try_block(self):
|
||||||
|
|
||||||
def test_fn(x, a, b, c): # pylint:disable=unused-argument
|
def test_fn(x, a, b, c): # pylint:disable=unused-argument
|
||||||
@ -421,6 +428,22 @@ class LivenessTest(test.TestCase):
|
|||||||
|
|
||||||
self.assertHasLiveIn(fn_body[0], ('y',))
|
self.assertHasLiveIn(fn_body[0], ('y',))
|
||||||
|
|
||||||
|
def test_global_symbol(self):
|
||||||
|
|
||||||
|
def test_fn(c):
|
||||||
|
global global_a
|
||||||
|
global global_b
|
||||||
|
if global_a:
|
||||||
|
global_b = c
|
||||||
|
else:
|
||||||
|
global_b = c
|
||||||
|
return global_b
|
||||||
|
|
||||||
|
node = self._parse_and_analyze(test_fn)
|
||||||
|
fn_body = node.body
|
||||||
|
self.assertHasLiveOut(fn_body[2], ('global_b',))
|
||||||
|
self.assertHasLiveIn(fn_body[2], ('global_a', 'c'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -34,6 +34,7 @@ import gast
|
|||||||
|
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import cfg
|
from tensorflow.python.autograph.pyct import cfg
|
||||||
|
from tensorflow.python.autograph.pyct import qual_names
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||||
|
|
||||||
@ -147,6 +148,24 @@ class Analyzer(cfg.GraphVisitor):
|
|||||||
kill = node_scope.modified | node_scope.deleted
|
kill = node_scope.modified | node_scope.deleted
|
||||||
defs_out = gen | (defs_in - kill)
|
defs_out = gen | (defs_in - kill)
|
||||||
|
|
||||||
|
elif isinstance(node.ast_node, (gast.Global, gast.Nonlocal)):
|
||||||
|
# Special case for global and nonlocal: they generate a definition,
|
||||||
|
# but are not tracked by activity analysis.
|
||||||
|
if node not in self.gen_map:
|
||||||
|
node_symbols = {}
|
||||||
|
for s in node.ast_node.names:
|
||||||
|
qn = qual_names.QN(s)
|
||||||
|
if qn in defs_in.value:
|
||||||
|
# In Python 2, this is a syntax warning. In Python 3, it's an error.
|
||||||
|
raise ValueError(
|
||||||
|
'"{}" is assigned before global definition'.format(s))
|
||||||
|
def_ = self._definition_factory()
|
||||||
|
node_symbols[qn] = def_
|
||||||
|
self.gen_map[node] = _NodeState(node_symbols)
|
||||||
|
|
||||||
|
gen = self.gen_map[node]
|
||||||
|
defs_out = defs_in | gen
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Nodes that don't have a scope annotation are assumed not to touch any
|
# Nodes that don't have a scope annotation are assumed not to touch any
|
||||||
# symbols.
|
# symbols.
|
||||||
@ -216,12 +235,6 @@ class TreeAnnotator(transformer.Base):
|
|||||||
|
|
||||||
return node
|
return node
|
||||||
|
|
||||||
def visit_Nonlocal(self, node):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def visit_Global(self, node):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def visit_Name(self, node):
|
def visit_Name(self, node):
|
||||||
if self.current_analyzer is None:
|
if self.current_analyzer is None:
|
||||||
# Names may appear outside function defs - for example in class
|
# Names may appear outside function defs - for example in class
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
# python3
|
||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for reaching_definitions module, that only run in Python 3."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions_test
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class ReachingDefinitionsAnalyzerTest(
|
||||||
|
reaching_definitions_test.ReachingDefinitionsAnalyzerTestBase):
|
||||||
|
"""Tests which can only run in Python 3."""
|
||||||
|
|
||||||
|
def test_nonlocal_symbol(self):
|
||||||
|
|
||||||
|
nonlocal_a = 3
|
||||||
|
nonlocal_b = 13
|
||||||
|
|
||||||
|
def test_fn():
|
||||||
|
nonlocal nonlocal_a
|
||||||
|
nonlocal nonlocal_b
|
||||||
|
if nonlocal_a:
|
||||||
|
nonlocal_b = []
|
||||||
|
return nonlocal_a, nonlocal_b
|
||||||
|
|
||||||
|
node = self._parse_and_analyze(test_fn)
|
||||||
|
fn_body = node.body
|
||||||
|
|
||||||
|
self.assertHasDefs(fn_body[2].test, 1)
|
||||||
|
self.assertHasDefs(fn_body[2].body[0].targets[0], 1)
|
||||||
|
self.assertHasDefs(fn_body[3].value.elts[0], 1)
|
||||||
|
self.assertHasDefs(fn_body[3].value.elts[1], 2)
|
||||||
|
|
||||||
|
self.assertSameDef(fn_body[2].test, fn_body[3].value.elts[0])
|
||||||
|
|
||||||
|
self.assertHasDefinedIn(fn_body[2], ('nonlocal_a', 'nonlocal_b'))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -30,7 +30,11 @@ from tensorflow.python.autograph.pyct.static_analysis import reaching_definition
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
class DefinitionInfoTest(test.TestCase):
|
global_a = 7
|
||||||
|
global_b = 17
|
||||||
|
|
||||||
|
|
||||||
|
class ReachingDefinitionsAnalyzerTestBase(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
node, source = parser.parse_entity(test_fn, future_features=())
|
node, source = parser.parse_entity(test_fn, future_features=())
|
||||||
@ -73,6 +77,9 @@ class DefinitionInfoTest(test.TestCase):
|
|||||||
anno.getanno(first, anno.Static.DEFINITIONS)[0],
|
anno.getanno(first, anno.Static.DEFINITIONS)[0],
|
||||||
anno.getanno(second, anno.Static.DEFINITIONS)[0])
|
anno.getanno(second, anno.Static.DEFINITIONS)[0])
|
||||||
|
|
||||||
|
|
||||||
|
class ReachingDefinitionsAnalyzerTest(ReachingDefinitionsAnalyzerTestBase):
|
||||||
|
|
||||||
def test_conditional(self):
|
def test_conditional(self):
|
||||||
|
|
||||||
def test_fn(a, b):
|
def test_fn(a, b):
|
||||||
@ -366,6 +373,45 @@ class DefinitionInfoTest(test.TestCase):
|
|||||||
else:
|
else:
|
||||||
self.assertHasDefs(retval, 0)
|
self.assertHasDefs(retval, 0)
|
||||||
|
|
||||||
|
def test_function_definition(self):
|
||||||
|
|
||||||
|
def test_fn():
|
||||||
|
def a():
|
||||||
|
pass
|
||||||
|
if a: # pylint:disable=using-constant-test
|
||||||
|
a = None
|
||||||
|
return a
|
||||||
|
|
||||||
|
node = self._parse_and_analyze(test_fn)
|
||||||
|
fn_body = node.body
|
||||||
|
|
||||||
|
self.assertHasDefs(fn_body[1].test, 1)
|
||||||
|
self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
|
||||||
|
self.assertHasDefs(fn_body[2].value, 2)
|
||||||
|
|
||||||
|
self.assertHasDefinedIn(fn_body[1], ('a',))
|
||||||
|
|
||||||
|
def test_global(self):
|
||||||
|
|
||||||
|
def test_fn():
|
||||||
|
global global_a
|
||||||
|
global global_b
|
||||||
|
if global_a:
|
||||||
|
global_b = []
|
||||||
|
return global_a, global_b
|
||||||
|
|
||||||
|
node = self._parse_and_analyze(test_fn)
|
||||||
|
fn_body = node.body
|
||||||
|
|
||||||
|
self.assertHasDefs(fn_body[2].test, 1)
|
||||||
|
self.assertHasDefs(fn_body[2].body[0].targets[0], 1)
|
||||||
|
self.assertHasDefs(fn_body[3].value.elts[0], 1)
|
||||||
|
self.assertHasDefs(fn_body[3].value.elts[1], 2)
|
||||||
|
|
||||||
|
self.assertSameDef(fn_body[2].test, fn_body[3].value.elts[0])
|
||||||
|
|
||||||
|
self.assertHasDefinedIn(fn_body[2], ('global_a', 'global_b'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -105,7 +105,8 @@ do_pylint() {
|
|||||||
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
|
"^tensorflow/python/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
|
||||||
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
"^tensorflow/python/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
|
||||||
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
|
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned "\
|
||||||
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable"
|
"^tensorflow/python/keras/utils/data_utils.py.*\[E1102.*not-callable "\
|
||||||
|
"^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "
|
||||||
|
|
||||||
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
|
||||||
|
|
||||||
|
@ -159,7 +159,9 @@ def main():
|
|||||||
|
|
||||||
missing_dependencies = []
|
missing_dependencies = []
|
||||||
# File extensions and endings to ignore
|
# File extensions and endings to ignore
|
||||||
ignore_extensions = ["_test", "_test.py", "_test_gpu", "_test_gpu.py"]
|
ignore_extensions = [
|
||||||
|
"_test", "_test.py", "_test_gpu", "_test_gpu.py", "_test_lib"
|
||||||
|
]
|
||||||
|
|
||||||
ignored_files_count = 0
|
ignored_files_count = 0
|
||||||
blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST)
|
blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST)
|
||||||
|
Loading…
Reference in New Issue
Block a user