Internal cleanup: consolidate the function transformations in a single file, and name it accordingly.
PiperOrigin-RevId: 304887346 Change-Id: I8049bf59d22e8677dc832d22c222b7e994e50c67
This commit is contained in:
parent
64daf94a9b
commit
5b630e8624
@ -28,7 +28,7 @@ py_library(
|
||||
"control_flow.py",
|
||||
"control_flow_deprecated_py2.py",
|
||||
"directives.py",
|
||||
"function_scopes.py",
|
||||
"functions.py",
|
||||
"list_comprehensions.py",
|
||||
"lists.py",
|
||||
"logical_expressions.py",
|
||||
@ -154,8 +154,8 @@ py_test(
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "function_scopes_test",
|
||||
srcs = ["function_scopes_test.py"],
|
||||
name = "functions_test",
|
||||
srcs = ["functions_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":converters",
|
||||
|
@ -19,7 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import asserts
|
||||
from tensorflow.python.autograph.converters import function_scopes
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -36,7 +36,7 @@ class AssertsTest(converter_testing.TestCase):
|
||||
return a
|
||||
|
||||
with ops.Graph().as_default():
|
||||
with self.converted(test_fn, (function_scopes, asserts), {}) as result:
|
||||
with self.converted(test_fn, (functions, asserts), {}) as result:
|
||||
op = result.test_fn(constant_op.constant(False))
|
||||
|
||||
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, 'testmsg'):
|
||||
|
@ -96,47 +96,31 @@ class CallTreeTransformer(converter.Base):
|
||||
"""Transforms the call tree by renaming transformed symbols."""
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
if anno.hasanno(node, 'function_context_name'):
|
||||
if not anno.hasanno(node, 'function_context_name'):
|
||||
# Lambda functions created during the conversion process have no
|
||||
# context manager.
|
||||
self.state[_Function].enter()
|
||||
self.state[_Function].context_name = anno.getanno(
|
||||
node, 'function_context_name')
|
||||
node = self.generic_visit(node)
|
||||
self.state[_Function].exit()
|
||||
else:
|
||||
node = self.generic_visit(node)
|
||||
return node
|
||||
return self.generic_visit(node)
|
||||
with self.state[_Function] as fn_scope:
|
||||
fn_scope.context_name = anno.getanno(node, 'function_context_name')
|
||||
return self.generic_visit(node)
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self.state[_Function].enter()
|
||||
# Note: if the conversion process ever creates helper functions, this
|
||||
# assumption will no longer hold.
|
||||
assert anno.hasanno(node, 'function_context_name'), (
|
||||
'The function_scopes converter always creates a scope for functions.')
|
||||
self.state[_Function].context_name = anno.getanno(
|
||||
node, 'function_context_name')
|
||||
node.args = self.visit(node.args)
|
||||
node.body = self.visit_block(node.body)
|
||||
|
||||
if self.state[_Function].level < 2:
|
||||
# Top-level functions lose their decorator because the conversion is
|
||||
# always just-in-time and by the time it happens the decorators are
|
||||
# already set to be applied.
|
||||
node.decorator_list = []
|
||||
else:
|
||||
# TODO(mdan): Fix the tests so that we can always add this decorator.
|
||||
# Inner functions are converted already, so we insert a decorator to
|
||||
# prevent double conversion. Double conversion would work too, but this
|
||||
# saves the overhead.
|
||||
node.decorator_list.append(
|
||||
parser.parse_expression('ag__.autograph_artifact'))
|
||||
|
||||
if node.returns:
|
||||
node.returns = self.visit(node.returns)
|
||||
|
||||
self.state[_Function].exit()
|
||||
return node
|
||||
# Decorators and arg defaults are part of the outer scope.
|
||||
node.decorator_list = self.visit_block(node.decorator_list)
|
||||
node.args.defaults = self.visit_block(node.args.defaults)
|
||||
for i, d in enumerate(node.args.kw_defaults):
|
||||
if d is not None:
|
||||
node.args.kw_defaults[i] = self.visit(d)
|
||||
with self.state[_Function] as fn_scope:
|
||||
# Note: if the conversion process ever creates helper functions, this
|
||||
# assumption will no longer hold.
|
||||
assert anno.hasanno(node, 'function_context_name'), (
|
||||
'The function_scopes converter always creates a scope for functions.')
|
||||
fn_scope.context_name = anno.getanno(node, 'function_context_name')
|
||||
node.body = self.visit_block(node.body)
|
||||
if node.returns:
|
||||
node.returns = self.visit(node.returns)
|
||||
return node
|
||||
|
||||
def visit_With(self, node):
|
||||
# Context manager calls (in node.items) are not converted.
|
||||
|
@ -22,7 +22,7 @@ from __future__ import print_function
|
||||
import imp
|
||||
|
||||
from tensorflow.python.autograph.converters import call_trees
|
||||
from tensorflow.python.autograph.converters import function_scopes
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -34,7 +34,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f):
|
||||
return f() + 20
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(result.test_fn(lambda: 1), 21)
|
||||
self.assertListEqual(self.dynamic_calls, [((), None)])
|
||||
|
||||
@ -43,7 +43,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, g):
|
||||
return f(g() + 20) + 4000
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(result.test_fn(lambda x: x + 300, lambda: 1), 4321)
|
||||
self.assertListEqual(self.dynamic_calls, [
|
||||
((), None),
|
||||
@ -55,7 +55,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, g):
|
||||
return f(g()) + 300
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(result.test_fn(lambda x: x + 20, lambda: 1), 321)
|
||||
self.assertListEqual(self.dynamic_calls, [
|
||||
((), None),
|
||||
@ -70,7 +70,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn():
|
||||
return get_one().__add__(20)
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees),
|
||||
with self.converted(test_fn, (functions, call_trees),
|
||||
{'get_one': get_one}, ()) as result:
|
||||
|
||||
self.assertEqual(result.test_fn(), 21)
|
||||
@ -85,7 +85,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, a):
|
||||
return f(a) + 20
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(result.test_fn(lambda a: a, 1), 21)
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
||||
|
||||
@ -94,7 +94,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, a, b):
|
||||
return f(a, b) + 300
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(result.test_fn(lambda a, b: a + b, 1, 20), 321)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 20), None)])
|
||||
|
||||
@ -103,7 +103,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, a, b):
|
||||
return f(a, c=b) + 300
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(result.test_fn(lambda a, c: a + c, 1, 20), 321)
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {'c': 20})])
|
||||
|
||||
@ -112,7 +112,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, a, *args, **kwargs):
|
||||
return f(a, *args, **kwargs) + 5
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(lambda *args, **kwargs: 7, 1, *[2, 3], **{
|
||||
'b': 4,
|
||||
@ -129,7 +129,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
args = [1, 20, 300]
|
||||
return f(*args) + 4000
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees),
|
||||
with self.converted(test_fn, (functions, call_trees),
|
||||
{'f': f}) as result:
|
||||
self.assertEqual(result.test_fn(), 4321)
|
||||
self.assertListEqual(self.dynamic_calls, [((1, 20, 300), None)])
|
||||
@ -145,7 +145,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
# args2 = [3]
|
||||
# return f(*args1, 2, *args2, 4)
|
||||
#
|
||||
# with self.converted(test_fn, (function_scopes, call_trees),
|
||||
# with self.converted(test_fn, (functions, call_trees),
|
||||
# {'f': f}) as result:
|
||||
# self.assertEqual(result.test_fn(), 1234)
|
||||
# self.assertListEqual(self.dynamic_calls, [((1, 2, 3, 4), None)])
|
||||
@ -155,7 +155,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, a, b, **kwargs):
|
||||
return f(a, b=b, **kwargs) + 5
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(
|
||||
result.test_fn(lambda *args, **kwargs: 7, 1, 2, **{'c': 3}), 12)
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), {'b': 2, 'c': 3})])
|
||||
@ -166,7 +166,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
# def test_fn(f, a, b, c, kwargs1, kwargs2):
|
||||
# return f(a, b=b, **kwargs1, c=c, **kwargs2) + 5
|
||||
#
|
||||
# with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
# with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
# self.assertEqual(
|
||||
# result.test_fn(lambda *args, **kwargs: 7, 1, 2, 3, {'d': 4},
|
||||
# {'e': 5}), 12)
|
||||
@ -188,7 +188,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn(f, g, a, *args):
|
||||
return f(lambda x: g(x, *args), a)
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees), {}) as result:
|
||||
with self.converted(test_fn, (functions, call_trees), {}) as result:
|
||||
self.assertEqual(result.test_fn(f, g, 1, *(20, 300)), 4321)
|
||||
|
||||
def test_debugger_set_trace(self):
|
||||
@ -201,7 +201,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
def test_fn():
|
||||
return pdb.set_trace()
|
||||
|
||||
with self.converted(test_fn, (function_scopes, call_trees),
|
||||
with self.converted(test_fn, (functions, call_trees),
|
||||
{'pdb': pdb}) as result:
|
||||
result.test_fn()
|
||||
self.assertListEqual(tracking_list, [1])
|
||||
@ -217,7 +217,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
return self.other_method(a) + 300
|
||||
|
||||
tc = TestClass()
|
||||
with self.converted(TestClass.test_method, (function_scopes, call_trees),
|
||||
with self.converted(TestClass.test_method, (functions, call_trees),
|
||||
{}) as result:
|
||||
self.assertEqual(321, result.test_method(tc, 1))
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
||||
@ -233,7 +233,7 @@ class CallTreesTest(converter_testing.TestCase):
|
||||
return self.other_method(a) + 300
|
||||
|
||||
tc = TestClass()
|
||||
with self.converted(tc.test_method, (function_scopes, call_trees),
|
||||
with self.converted(tc.test_method, (functions, call_trees),
|
||||
{}) as result:
|
||||
self.assertEqual(321, result.test_method(tc, 1))
|
||||
self.assertListEqual(self.dynamic_calls, [((1,), None)])
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Wraps the body of a converted function with auxiliary constructs."""
|
||||
"""Converts function definitions and lambdas by adding necessary boilerplate."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -22,6 +22,7 @@ import gast
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct.static_analysis import annos
|
||||
|
||||
@ -32,7 +33,7 @@ class _Function(object):
|
||||
self.context_name = None
|
||||
|
||||
|
||||
class FunctionBodyTransformer(converter.Base):
|
||||
class FunctionTransformer(converter.Base):
|
||||
"""Wraps function bodies around autograph-specific boilerplate."""
|
||||
|
||||
def visit_Return(self, node):
|
||||
@ -60,10 +61,7 @@ class FunctionBodyTransformer(converter.Base):
|
||||
with self.state[_Function] as fn_scope:
|
||||
node = self.generic_visit(node)
|
||||
|
||||
# Only wrap the top-level function. Theoretically, we can and should wrap
|
||||
# everything, but that can lead to excessive boilerplate when lambdas are
|
||||
# nested.
|
||||
# TODO(mdan): Looks more closely for use cases that actually require this.
|
||||
# TODO(mdan): Fix the tests so that we can always add this decorator.
|
||||
if fn_scope.level > 2:
|
||||
return templates.replace_as_expression(
|
||||
'ag__.autograph_artifact(l)', l=node)
|
||||
@ -98,6 +96,19 @@ class FunctionBodyTransformer(converter.Base):
|
||||
|
||||
node = self.generic_visit(node)
|
||||
|
||||
if fn_scope.level <= 2:
|
||||
# Top-level functions lose their decorator because the conversion is
|
||||
# always just-in-time and by the time it happens the decorators are
|
||||
# already set to be applied.
|
||||
node.decorator_list = []
|
||||
else:
|
||||
# TODO(mdan): Fix the tests so that we can always add this decorator.
|
||||
# Inner functions are converted already, so we insert a decorator to
|
||||
# prevent double conversion. Double conversion would work too, but this
|
||||
# saves the overhead.
|
||||
node.decorator_list.append(
|
||||
parser.parse_expression('ag__.autograph_artifact'))
|
||||
|
||||
docstring_node = None
|
||||
if node.body:
|
||||
first_statement = node.body[0]
|
||||
@ -128,4 +139,4 @@ class FunctionBodyTransformer(converter.Base):
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
return FunctionBodyTransformer(ctx).visit(node)
|
||||
return FunctionTransformer(ctx).visit(node)
|
@ -12,13 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for function_scopes module."""
|
||||
"""Tests for functions module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import function_scopes
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
@ -28,7 +28,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
class FunctionTransformer(converter_testing.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_basic(self):
|
||||
@ -39,7 +39,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
l += a
|
||||
return l
|
||||
|
||||
with self.converted(test_fn, function_scopes, {}) as result:
|
||||
with self.converted(test_fn, functions, {}) as result:
|
||||
result_op = result.test_fn(constant_op.constant(1))
|
||||
self.assertIn('test_fn/', result_op.op.name)
|
||||
self.assertEqual('Docstring.', result.test_fn.__doc__)
|
||||
@ -56,7 +56,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
"""
|
||||
return tf.constant(1)
|
||||
|
||||
with self.converted(test_fn, function_scopes, {},
|
||||
with self.converted(test_fn, functions, {},
|
||||
(constant_op.constant,)) as result:
|
||||
result_op = result.test_fn()
|
||||
self.assertIn('test_fn/', result_op.op.name)
|
||||
@ -74,7 +74,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
l += 1
|
||||
return l, inner_fn(l)
|
||||
|
||||
with self.converted(test_fn, function_scopes, {},
|
||||
with self.converted(test_fn, functions, {},
|
||||
(ops.name_scope,)) as result:
|
||||
first, second = result.test_fn(constant_op.constant(1))
|
||||
self.assertIn('test_fn/', first.op.name)
|
||||
@ -100,7 +100,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
'ag_ctx': ag_ctx,
|
||||
'converter': converter
|
||||
}
|
||||
with self.converted(test_fn, function_scopes, ns) as result:
|
||||
with self.converted(test_fn, functions, ns) as result:
|
||||
result.test_fn()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@ -118,7 +118,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
|
||||
ns = {'TestClass': TestClass}
|
||||
node, ctx = self.prepare(TestClass, ns)
|
||||
node = function_scopes.transform(node, ctx)
|
||||
node = functions.transform(node, ctx)
|
||||
|
||||
with self.compiled(node, {}, (ops.name_scope,)) as result:
|
||||
first, second = result.TestClass().test_fn(constant_op.constant(1))
|
||||
@ -131,7 +131,7 @@ class FunctionBodyTransformerTest(converter_testing.TestCase):
|
||||
def test_fn():
|
||||
return lambda x: x + 1
|
||||
|
||||
with self.converted(test_fn, function_scopes, {}) as result:
|
||||
with self.converted(test_fn, functions, {}) as result:
|
||||
result_l = result.test_fn()
|
||||
self.assertTrue(result_l.fake_autograph_artifact)
|
||||
|
@ -40,7 +40,7 @@ from tensorflow.python.autograph.converters import conditional_expressions
|
||||
from tensorflow.python.autograph.converters import continue_statements
|
||||
from tensorflow.python.autograph.converters import control_flow
|
||||
from tensorflow.python.autograph.converters import directives
|
||||
from tensorflow.python.autograph.converters import function_scopes
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.converters import lists
|
||||
from tensorflow.python.autograph.converters import logical_expressions
|
||||
from tensorflow.python.autograph.converters import return_statements
|
||||
@ -616,7 +616,7 @@ def node_to_graph(node, context):
|
||||
unsupported_features_checker.verify(node)
|
||||
|
||||
node = converter.standard_analysis(node, context, is_initial=True)
|
||||
node = converter.apply_(node, context, function_scopes)
|
||||
node = converter.apply_(node, context, functions)
|
||||
node = converter.apply_(node, context, arg_defaults)
|
||||
node = converter.apply_(node, context, directives)
|
||||
node = converter.apply_(node, context, break_statements)
|
||||
|
Loading…
Reference in New Issue
Block a user