Internal cleanup: consolidate the function transformations in a single file, and name it accordingly.

PiperOrigin-RevId: 304887346
Change-Id: I8049bf59d22e8677dc832d22c222b7e994e50c67
This commit is contained in:
Dan Moldovan 2020-04-05 06:51:30 -07:00 committed by TensorFlower Gardener
parent 64daf94a9b
commit 5b630e8624
7 changed files with 72 additions and 77 deletions

View File

@ -28,7 +28,7 @@ py_library(
"control_flow.py",
"control_flow_deprecated_py2.py",
"directives.py",
"function_scopes.py",
"functions.py",
"list_comprehensions.py",
"lists.py",
"logical_expressions.py",
@ -154,8 +154,8 @@ py_test(
)
py_test(
name = "function_scopes_test",
srcs = ["function_scopes_test.py"],
name = "functions_test",
srcs = ["functions_test.py"],
python_version = "PY3",
deps = [
":converters",

View File

@ -19,7 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
@ -36,7 +36,7 @@ class AssertsTest(converter_testing.TestCase):
return a
with ops.Graph().as_default():
with self.converted(test_fn, (function_scopes, asserts), {}) as result:
with self.converted(test_fn, (functions, asserts), {}) as result:
op = result.test_fn(constant_op.constant(False))
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):

View File

@ -96,47 +96,31 @@ class CallTreeTransformer(converter.Base):
"""Transforms the call tree by renaming transformed symbols."""
def visit_Lambda(self, node):
if anno.hasanno(node, 'function_context_name'):
if not anno.hasanno(node, 'function_context_name'):
# Lambda functions created during the conversion process have no
# context manager.
self.state[_Function].enter()
self.state[_Function].context_name = anno.getanno(
node, 'function_context_name')
node = self.generic_visit(node)
self.state[_Function].exit()
else:
node = self.generic_visit(node)
return node
return self.generic_visit(node)
with self.state[_Function] as fn_scope:
fn_scope.context_name = anno.getanno(node, 'function_context_name')
return self.generic_visit(node)
def visit_FunctionDef(self, node):
self.state[_Function].enter()
# Note: if the conversion process ever creates helper functions, this
# assumption will no longer hold.
assert anno.hasanno(node, 'function_context_name'), (
'The function_scopes converter always creates a scope for functions.')
self.state[_Function].context_name = anno.getanno(
node, 'function_context_name')
node.args = self.visit(node.args)
node.body = self.visit_block(node.body)
if self.state[_Function].level < 2:
# Top-level functions lose their decorator because the conversion is
# always just-in-time and by the time it happens the decorators are
# already set to be applied.
node.decorator_list = []
else:
# TODO(mdan): Fix the tests so that we can always add this decorator.
# Inner functions are converted already, so we insert a decorator to
# prevent double conversion. Double conversion would work too, but this
# saves the overhead.
node.decorator_list.append(
parser.parse_expression('ag__.autograph_artifact'))
if node.returns:
node.returns = self.visit(node.returns)
self.state[_Function].exit()
return node
# Decorators and arg defaults are part of the outer scope.
node.decorator_list = self.visit_block(node.decorator_list)
node.args.defaults = self.visit_block(node.args.defaults)
for i, d in enumerate(node.args.kw_defaults):
if d is not None:
node.args.kw_defaults[i] = self.visit(d)
with self.state[_Function] as fn_scope:
# Note: if the conversion process ever creates helper functions, this
# assumption will no longer hold.
assert anno.hasanno(node, 'function_context_name'), (
'The function_scopes converter always creates a scope for functions.')
fn_scope.context_name = anno.getanno(node, 'function_context_name')
node.body = self.visit_block(node.body)
if node.returns:
node.returns = self.visit(node.returns)
return node
def visit_With(self, node):
# Context manager calls (in node.items) are not converted.

View File

@ -22,7 +22,7 @@ from __future__ import print_function
import imp
from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.platform import test
@ -34,7 +34,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f):
return f() + 20
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda: 1), 21)
self.assertListEqual(self.dynamic_calls, [((), None)])
@ -43,7 +43,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, g):
return f(g() + 20) + 4000
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
self.assertListEqual(self.dynamic_calls, [
((), None),
@ -55,7 +55,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, g):
return f(g()) + 300
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
self.assertListEqual(self.dynamic_calls, [
((), None),
@ -70,7 +70,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn():
return get_one().__add__(20)
with self.converted(test_fn, (function_scopes, call_trees),
with self.converted(test_fn, (functions, call_trees),
{'get_one': get_one}, ()) as result:
self.assertEqual(result.test_fn(), 21)
@ -85,7 +85,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, a):
return f(a) + 20
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda a: a, 1), 21)
self.assertListEqual(self.dynamic_calls, [((1,), None)])
@ -94,7 +94,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, a, b):
return f(a, b) + 300
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda a, b: a + b, 1, 20), 321)
self.assertListEqual(self.dynamic_calls, [((1, 20), None)])
@ -103,7 +103,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, a, b):
return f(a, c=b) + 300
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
@ -112,7 +112,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, a, *args, **kwargs):
return f(a, *args, **kwargs) + 5
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(
result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
'b': 4,
@ -129,7 +129,7 @@ class CallTreesTest(converter_testing.TestCase):
args = [1, 20, 300]
return f(*args) + 4000
with self.converted(test_fn, (function_scopes, call_trees),
with self.converted(test_fn, (functions, call_trees),
{'f': f}) as result:
self.assertEqual(result.test_fn(), 4321)
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
@ -145,7 +145,7 @@ class CallTreesTest(converter_testing.TestCase):
# args2 = [3]
# return f(*args1, 2, *args2, 4)
#
# with self.converted(test_fn, (function_scopes, call_trees),
# with self.converted(test_fn, (functions, call_trees),
# {'f': f}) as result:
# self.assertEqual(result.test_fn(), 1234)
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
@ -155,7 +155,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, a, b, **kwargs):
return f(a, b=b, **kwargs) + 5
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
@ -166,7 +166,7 @@ class CallTreesTest(converter_testing.TestCase):
# def test_fn(f, a, b, c, kwargs1, kwargs2):
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
#
# with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
# with self.converted(test_fn, (functions, call_trees), {}) as result:
# self.assertEqual(
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
# {'e': 5}), 12)
@ -188,7 +188,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn(f, g, a, *args):
return f(lambda x: g(x, *args), a)
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
with self.converted(test_fn, (functions, call_trees), {}) as result:
self.assertEqual(result.test_fn(f, g, 1, *(20, 300)), 4321)
def test_debugger_set_trace(self):
@ -201,7 +201,7 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn():
return pdb.set_trace()
with self.converted(test_fn, (function_scopes, call_trees),
with self.converted(test_fn, (functions, call_trees),
{'pdb': pdb}) as result:
result.test_fn()
self.assertListEqual(tracking_list, [1])
@ -217,7 +217,7 @@ class CallTreesTest(converter_testing.TestCase):
return self.other_method(a) + 300
tc = TestClass()
with self.converted(TestClass.test_method, (function_scopes, call_trees),
with self.converted(TestClass.test_method, (functions, call_trees),
{}) as result:
self.assertEqual(321, result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)])
@ -233,7 +233,7 @@ class CallTreesTest(converter_testing.TestCase):
return self.other_method(a) + 300
tc = TestClass()
with self.converted(tc.test_method, (function_scopes, call_trees),
with self.converted(tc.test_method, (functions, call_trees),
{}) as result:
self.assertEqual(321, result.test_method(tc, 1))
self.assertListEqual(self.dynamic_calls, [((1,), None)])

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps the body of a converted function with auxiliary constructs."""
"""Converts function definitions and lambdas by adding necessary boilerplate."""
from __future__ import absolute_import
from __future__ import division
@ -22,6 +22,7 @@ import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import annos
@ -32,7 +33,7 @@ class _Function(object):
self.context_name = None
class FunctionBodyTransformer(converter.Base):
class FunctionTransformer(converter.Base):
"""Wraps function bodies around autograph-specific boilerplate."""
def visit_Return(self, node):
@ -60,10 +61,7 @@ class FunctionBodyTransformer(converter.Base):
with self.state[_Function] as fn_scope:
node = self.generic_visit(node)
# Only wrap the top-level function. Theoretically, we can and should wrap
# everything, but that can lead to excessive boilerplate when lambdas are
# nested.
# TODO(mdan): Looks more closely for use cases that actually require this.
# TODO(mdan): Fix the tests so that we can always add this decorator.
if fn_scope.level > 2:
return templates.replace_as_expression(
'ag__.autograph_artifact(l)', l=node)
@ -98,6 +96,19 @@ class FunctionBodyTransformer(converter.Base):
node = self.generic_visit(node)
if fn_scope.level <= 2:
# Top-level functions lose their decorator because the conversion is
# always just-in-time and by the time it happens the decorators are
# already set to be applied.
node.decorator_list = []
else:
# TODO(mdan): Fix the tests so that we can always add this decorator.
# Inner functions are converted already, so we insert a decorator to
# prevent double conversion. Double conversion would work too, but this
# saves the overhead.
node.decorator_list.append(
parser.parse_expression('ag__.autograph_artifact'))
docstring_node = None
if node.body:
first_statement = node.body[0]
@ -128,4 +139,4 @@ class FunctionBodyTransformer(converter.Base):
def transform(node, ctx):
return FunctionBodyTransformer(ctx).visit(node)
return FunctionTransformer(ctx).visit(node)

View File

@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for function_scopes module."""
"""Tests for functions module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import converter_testing
@ -28,7 +28,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
class FunctionBodyTransformerTest(converter_testing.TestCase):
class FunctionTransformer(converter_testing.TestCase):
@test_util.run_deprecated_v1
def test_basic(self):
@ -39,7 +39,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
l += a
return l
with self.converted(test_fn, function_scopes, {}) as result:
with self.converted(test_fn, functions, {}) as result:
result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name)
self.assertEqual('Docstring.', result.test_fn.__doc__)
@ -56,7 +56,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
"""
return tf.constant(1)
with self.converted(test_fn, function_scopes, {},
with self.converted(test_fn, functions, {},
(constant_op.constant,)) as result:
result_op = result.test_fn()
self.assertIn('test_fn/', result_op.op.name)
@ -74,7 +74,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
with self.converted(test_fn, function_scopes, {},
with self.converted(test_fn, functions, {},
(ops.name_scope,)) as result:
first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
@ -100,7 +100,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
'ag_ctx': ag_ctx,
'converter': converter
}
with self.converted(test_fn, function_scopes, ns) as result:
with self.converted(test_fn, functions, ns) as result:
result.test_fn()
@test_util.run_deprecated_v1
@ -118,7 +118,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
ns = {'TestClass': TestClass}
node, ctx = self.prepare(TestClass, ns)
node = function_scopes.transform(node, ctx)
node = functions.transform(node, ctx)
with self.compiled(node, {}, (ops.name_scope,)) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1))
@ -131,7 +131,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
def test_fn():
return lambda x: x + 1
with self.converted(test_fn, function_scopes, {}) as result:
with self.converted(test_fn, functions, {}) as result:
result_l = result.test_fn()
self.assertTrue(result_l.fake_autograph_artifact)

View File

@ -40,7 +40,7 @@ from tensorflow.python.autograph.converters import conditional_expressions
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
from tensorflow.python.autograph.converters import return_statements
@ -616,7 +616,7 @@ def node_to_graph(node, context):
unsupported_features_checker.verify(node)
node = converter.standard_analysis(node, context, is_initial=True)
node = converter.apply_(node, context, function_scopes)
node = converter.apply_(node, context, functions)
node = converter.apply_(node, context, arg_defaults)
node = converter.apply_(node, context, directives)
node = converter.apply_(node, context, break_statements)