Internal cleanup: Move the bulk of the source code transformation infrastructure into the generic pyct module.
PiperOrigin-RevId: 305135067 Change-Id: Ifb84546c35a603942fd864769e7320a7ae95da3b
This commit is contained in:
		
							parent
							
								
									fc94412b39
								
							
						
					
					
						commit
						ff551c9f20
					
				| @ -19,7 +19,6 @@ filegroup( | ||||
| py_library( | ||||
|     name = "converters", | ||||
|     srcs = [ | ||||
|         "arg_defaults.py", | ||||
|         "asserts.py", | ||||
|         "break_statements.py", | ||||
|         "call_trees.py", | ||||
| @ -48,18 +47,6 @@ py_library( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "arg_defaults_test", | ||||
|     srcs = ["arg_defaults_test.py"], | ||||
|     python_version = "PY3", | ||||
|     srcs_version = "PY2AND3", | ||||
|     deps = [ | ||||
|         ":converters", | ||||
|         "//tensorflow/python:client_testlib", | ||||
|         "//tensorflow/python/autograph/core:test_lib", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "asserts_test", | ||||
|     srcs = ["asserts_test.py"], | ||||
|  | ||||
| @ -1,105 +0,0 @@ | ||||
| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Modifies the signature to allow resolving the value of default arguments. | ||||
| 
 | ||||
| Normally, function symbols are captured either in a function's globals or | ||||
| closure. This is not true for default arguments, which are evaluated when the | ||||
| function is defined: | ||||
| 
 | ||||
|     b = 1 | ||||
|     c = 2 | ||||
|     def f(a=b + 1): | ||||
|       return a + c | ||||
| 
 | ||||
| In the above example, the namespace of the function would include `c = 2` but | ||||
| not `b`. | ||||
| 
 | ||||
| If we were to naively generate a new function: | ||||
| 
 | ||||
|     def new_f(a=b + 1): | ||||
|       return a + c | ||||
| 
 | ||||
| The generated code would fail to load unless we exposed a symbol `b`. Capturing | ||||
| the closure of such an expression is difficult. However, we can capture the | ||||
| default value of argument `a` with relative ease. | ||||
| 
 | ||||
| This converter replaces all default argument expressions with a constant so | ||||
| that they don't cause loading to fail. This requires that the default values | ||||
| are reset after loading the transformed function: | ||||
| 
 | ||||
|     def new_f(a=None): | ||||
|       return a + c | ||||
| 
 | ||||
|     # ... later, after new_f was loaded ... | ||||
|     new_f.__defaults__ = f.__defaults__ | ||||
| 
 | ||||
| """ | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.python.autograph.core import converter | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| 
 | ||||
| 
 | ||||
| class _Function(object): | ||||
|   pass | ||||
| 
 | ||||
| 
 | ||||
| class ArgDefaultsTransformer(converter.Base): | ||||
|   """Transforms top level argument defaults.""" | ||||
| 
 | ||||
|   def visit_Lambda(self, node): | ||||
|     self.state[_Function].enter() | ||||
|     node.args = self.visit(node.args) | ||||
|     # Only the top level function is modified - no need to visit the children. | ||||
|     self.state[_Function].exit() | ||||
|     return node | ||||
| 
 | ||||
|   def visit_FunctionDef(self, node): | ||||
|     self.state[_Function].enter() | ||||
|     node.args = self.visit(node.args) | ||||
|     # Only the top level function is modified - no need to visit the children. | ||||
|     self.state[_Function].exit() | ||||
|     return node | ||||
| 
 | ||||
|   def visit_arguments(self, node): | ||||
|     if self.state[_Function].level > 2: | ||||
|       return node | ||||
| 
 | ||||
|     for i in range(len(node.defaults)): | ||||
|       node.defaults[i] = parser.parse_expression('None') | ||||
| 
 | ||||
|     for i, d in enumerate(node.kw_defaults): | ||||
|       if d is not None: | ||||
|         node.kw_defaults[i] = parser.parse_expression('None') | ||||
| 
 | ||||
|     # Only the top level function is modified - no need to visit the children. | ||||
|     return node | ||||
| 
 | ||||
| 
 | ||||
| def transform(node, ctx): | ||||
|   """Transform function call to the compiled counterparts. | ||||
| 
 | ||||
|   Args: | ||||
|     node: AST | ||||
|     ctx: EntityContext | ||||
|   Returns: | ||||
|     A tuple (node, new_names): | ||||
|         node: The transformed AST | ||||
|         new_names: set(string), containing any newly-generated names | ||||
|   """ | ||||
|   return ArgDefaultsTransformer(ctx).visit(node) | ||||
| @ -1,108 +0,0 @@ | ||||
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Tests for arg_defaults module.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.python.autograph.converters import arg_defaults | ||||
| from tensorflow.python.autograph.core import converter_testing | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| from tensorflow.python.platform import test | ||||
| 
 | ||||
| 
 | ||||
| class ArgDefaultsTransformerTest(converter_testing.TestCase): | ||||
| 
 | ||||
|   def assertTransformedFirstLineIs(self, node, expected): | ||||
|     self.assertEqual( | ||||
|         parser.unparse(node, include_encoding_marker=False).split('\n')[0], | ||||
|         expected) | ||||
| 
 | ||||
|   def test_no_args(self): | ||||
| 
 | ||||
|     def test_fn(): | ||||
|       pass | ||||
| 
 | ||||
|     node, ctx = self.prepare(test_fn, {}) | ||||
|     node = arg_defaults.transform(node, ctx) | ||||
|     self.assertTransformedFirstLineIs(node, 'def test_fn():') | ||||
| 
 | ||||
|   def test_no_defaults(self): | ||||
| 
 | ||||
|     def test_fn(a, b, *c, **e): | ||||
|       return a, b, c, e | ||||
| 
 | ||||
|     node, ctx = self.prepare(test_fn, {}) | ||||
|     node = arg_defaults.transform(node, ctx) | ||||
|     self.assertTransformedFirstLineIs(node, 'def test_fn(a, b, *c, **e):') | ||||
| 
 | ||||
|   # TODO(mdan): Add kwonly-arg tests when PY2 is no longer supported. | ||||
| 
 | ||||
|   def test_arg_defaults(self): | ||||
| 
 | ||||
|     def test_fn(a, b=1, c=2): | ||||
|       return a, b, c | ||||
| 
 | ||||
|     node, ctx = self.prepare(test_fn, {}) | ||||
|     node = arg_defaults.transform(node, ctx) | ||||
|     self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, c=None):') | ||||
| 
 | ||||
|   def test_arg_defaults_with_vararg(self): | ||||
| 
 | ||||
|     def test_fn(a, b=1, *c):  # pylint: disable=keyword-arg-before-vararg | ||||
|       return a, b, c | ||||
| 
 | ||||
|     node, ctx = self.prepare(test_fn, {}) | ||||
|     node = arg_defaults.transform(node, ctx) | ||||
|     self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, *c):') | ||||
| 
 | ||||
|   def test_arg_defaults_ignores_inner_lambda(self): | ||||
| 
 | ||||
|     def test_fn(): | ||||
|       return (lambda x=7: x)() | ||||
| 
 | ||||
|     node, ctx = self.prepare(test_fn, {}) | ||||
|     node = arg_defaults.transform(node, ctx) | ||||
|     with self.converted(test_fn, arg_defaults, {}) as result: | ||||
|       self.assertEqual(test_fn(), result.test_fn()) | ||||
| 
 | ||||
|   def test_arg_defaults_ignores_inner_function(self): | ||||
| 
 | ||||
|     def test_fn(): | ||||
|       def inner_fn(a=3): | ||||
|         return a | ||||
|       return inner_fn() | ||||
| 
 | ||||
|     node, ctx = self.prepare(test_fn, {}) | ||||
|     node = arg_defaults.transform(node, ctx) | ||||
|     with self.converted(test_fn, arg_defaults, {}) as result: | ||||
|       self.assertEqual(test_fn(), result.test_fn()) | ||||
| 
 | ||||
|   def test_arg_defaults_ignores_inner_function_returned(self): | ||||
| 
 | ||||
|     def test_fn(): | ||||
|       def inner_fn(a=3): | ||||
|         return a | ||||
|       return inner_fn | ||||
| 
 | ||||
|     node, ctx = self.prepare(test_fn, {}) | ||||
|     node = arg_defaults.transform(node, ctx) | ||||
|     with self.converted(test_fn, arg_defaults, {}) as result: | ||||
|       self.assertEqual(test_fn()(), result.test_fn()()) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   test.main() | ||||
| @ -191,7 +191,7 @@ class CallTreeTransformer(converter.Base): | ||||
|       return node | ||||
| 
 | ||||
|     if (full_name == 'print' and | ||||
|         not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): | ||||
|         not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): | ||||
|       return node | ||||
| 
 | ||||
|     template = """ | ||||
|  | ||||
| @ -54,8 +54,8 @@ class FunctionTransformer(converter.Base): | ||||
|     # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See | ||||
|     # function_wrappers.py. | ||||
|     if fn_scope.level == 2: | ||||
|       return self.ctx.program.options | ||||
|     return self.ctx.program.options.call_options() | ||||
|       return self.ctx.user.options | ||||
|     return self.ctx.user.options.call_options() | ||||
| 
 | ||||
|   def visit_Lambda(self, node): | ||||
|     with self.state[_Function] as fn_scope: | ||||
|  | ||||
| @ -53,7 +53,7 @@ class LogicalExpressionTransformer(converter.Base): | ||||
|     op_type = type(operator) | ||||
|     if op_type in LOGICAL_OPERATORS: | ||||
|       return LOGICAL_OPERATORS[op_type] | ||||
|     if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS): | ||||
|     if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS): | ||||
|       if op_type in EQUALITY_OPERATORS: | ||||
|         return EQUALITY_OPERATORS[op_type] | ||||
|     return None | ||||
| @ -83,7 +83,7 @@ class LogicalExpressionTransformer(converter.Base): | ||||
|   def visit_Compare(self, node): | ||||
|     node = self.generic_visit(node) | ||||
| 
 | ||||
|     if (not self.ctx.program.options.uses( | ||||
|     if (not self.ctx.user.options.uses( | ||||
|         converter.Feature.EQUALITY_OPERATORS)): | ||||
|       return node | ||||
| 
 | ||||
|  | ||||
| @ -253,25 +253,6 @@ class ProgramContext( | ||||
|   pass | ||||
| 
 | ||||
| 
 | ||||
| class EntityContext(transformer.Context): | ||||
|   """Tracks the conversion of a single entity. | ||||
| 
 | ||||
|   This object is mutable, and is updated during conversion. Not thread safe. | ||||
| 
 | ||||
|   Attributes: | ||||
|     namer: Namer | ||||
|     info: transformer.EntityInfo | ||||
|     program: ProgramContext, | ||||
|     targe_name: Text | ||||
|   """ | ||||
| 
 | ||||
|   def __init__(self, namer, entity_info, program_ctx, target_name=None): | ||||
|     super(EntityContext, self).__init__(entity_info) | ||||
|     self.namer = namer | ||||
|     self.program = program_ctx | ||||
|     self.target_name = target_name | ||||
| 
 | ||||
| 
 | ||||
| class Base(transformer.Base): | ||||
|   """All converters should inherit from this class. | ||||
| 
 | ||||
|  | ||||
| @ -168,12 +168,12 @@ class TestCase(test.TestCase): | ||||
|         options=converter.ConversionOptions(recursive=recursive), | ||||
|         autograph_module=None) | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         name=test_fn.__name__, | ||||
|         source_code=source, | ||||
|         source_file='<fragment>', | ||||
|         future_features=future_features, | ||||
|         namespace=namespace) | ||||
|     ctx = converter.EntityContext( | ||||
|         namer, entity_info, program_ctx, 'test_fn') | ||||
|     ctx = transformer.Context(entity_info, namer, program_ctx) | ||||
|     origin_info.resolve_entity(node, source, test_fn) | ||||
|     node = converter.standard_analysis(node, ctx, is_initial=True) | ||||
|     return node, ctx | ||||
|  | ||||
| @ -29,10 +29,7 @@ import sys | ||||
| import textwrap | ||||
| import traceback | ||||
| 
 | ||||
| # pylint:disable=g-bad-import-order | ||||
| 
 | ||||
| import six | ||||
| # pylint:enable=g-bad-import-order | ||||
| 
 | ||||
| from tensorflow.python.autograph.core import ag_ctx | ||||
| from tensorflow.python.autograph.core import converter | ||||
| @ -668,7 +665,7 @@ def to_graph(entity, recursive=True, experimental_optional_features=None): | ||||
|             user_requested=True, | ||||
|             optional_features=experimental_optional_features), | ||||
|         autograph_module=tf_inspect.getmodule(to_graph)) | ||||
|     return conversion.convert(entity, program_ctx) | ||||
|     return autograph_artifact(conversion.convert(entity, program_ctx)) | ||||
|   except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e: | ||||
|     logging.error(1, 'Error converting %s', entity, exc_info=True) | ||||
|     raise ConversionError('converting {}: {}: {}'.format( | ||||
|  | ||||
| @ -18,21 +18,12 @@ from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import collections | ||||
| import functools | ||||
| import imp | ||||
| import inspect | ||||
| import sys | ||||
| import threading | ||||
| import types | ||||
| import unittest | ||||
| import weakref | ||||
| 
 | ||||
| import gast | ||||
| 
 | ||||
| from tensorflow.python.autograph import operators | ||||
| from tensorflow.python.autograph import utils | ||||
| from tensorflow.python.autograph.converters import arg_defaults | ||||
| from tensorflow.python.autograph.converters import asserts | ||||
| from tensorflow.python.autograph.converters import break_statements | ||||
| from tensorflow.python.autograph.converters import call_trees | ||||
| @ -50,304 +41,70 @@ from tensorflow.python.autograph.core import converter | ||||
| from tensorflow.python.autograph.core import function_wrappers | ||||
| from tensorflow.python.autograph.core import unsupported_features_checker | ||||
| from tensorflow.python.autograph.lang import special_functions | ||||
| from tensorflow.python.autograph.pyct import ast_util | ||||
| from tensorflow.python.autograph.pyct import cache | ||||
| from tensorflow.python.autograph.pyct import inspect_utils | ||||
| from tensorflow.python.autograph.pyct import loader | ||||
| from tensorflow.python.autograph.pyct import naming | ||||
| from tensorflow.python.autograph.pyct import origin_info | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| from tensorflow.python.autograph.pyct import pretty_printer | ||||
| from tensorflow.python.autograph.pyct import templates | ||||
| from tensorflow.python.autograph.pyct import transformer | ||||
| from tensorflow.python.autograph.pyct import transpiler | ||||
| from tensorflow.python.autograph.utils import ag_logging as logging | ||||
| from tensorflow.python.eager import function | ||||
| from tensorflow.python.util import tf_inspect | ||||
| 
 | ||||
| 
 | ||||
| class _ConvertedEntityFactoryInfo( | ||||
|     collections.namedtuple( | ||||
|         '_ConvertedEntityFactoryInfo', | ||||
|         ('module_name', 'converted_name', 'factory_factory_name', 'source_map')) | ||||
| ): | ||||
|   """Holds metadata about a converted entity stored as a dynamic factory. | ||||
| class AutoGraphTranspiler(transpiler.FunctionTranspiler): | ||||
| 
 | ||||
|   The dynamic factory is assumed to be created by _wrap_into_dynamic_factory, | ||||
|   be named `factory_factory_name` and located inside the module named as | ||||
|   `module_name`. | ||||
|   def get_transformed_name(self, node): | ||||
|     return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node) | ||||
| 
 | ||||
|   Attributes: | ||||
|     module_name: Text, the name of the module containing the entity. | ||||
|     converted_name: Text, the name of the converted entity. | ||||
|     factory_factory_name: Text, the name of the dynamic factory. | ||||
|     source_map: Dict. | ||||
|   """ | ||||
|   def transform_ast(self, node, ctx): | ||||
|     # TODO(mdan): Insert list_comprehensions somewhere. | ||||
|     unsupported_features_checker.verify(node) | ||||
| 
 | ||||
|   def __str__(self): | ||||
|     return '_ConvertedEntityFactoryInfo({} in {})'.format( | ||||
|         self.converted_name, self.module_name) | ||||
| 
 | ||||
|   def get_module(self): | ||||
|     return sys.modules[self.module_name] | ||||
| 
 | ||||
|   def get_factory(self): | ||||
|     assert self.module_name in sys.modules | ||||
|     factory_factory = getattr(sys.modules[self.module_name], | ||||
|                               self.factory_factory_name) | ||||
|     return factory_factory() | ||||
|     node = converter.standard_analysis(node, ctx, is_initial=True) | ||||
|     node = converter.apply_(node, ctx, functions) | ||||
|     node = converter.apply_(node, ctx, directives) | ||||
|     node = converter.apply_(node, ctx, break_statements) | ||||
|     if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS): | ||||
|       node = converter.apply_(node, ctx, asserts) | ||||
|     # Note: sequencing continue canonicalization before for loop one avoids | ||||
|     # dealing with the extra loop increment operation that the for | ||||
|     # canonicalization creates. | ||||
|     node = converter.apply_(node, ctx, continue_statements) | ||||
|     node = converter.apply_(node, ctx, return_statements) | ||||
|     if ctx.user.options.uses(converter.Feature.LISTS): | ||||
|       node = converter.apply_(node, ctx, lists) | ||||
|       node = converter.apply_(node, ctx, slices) | ||||
|     node = converter.apply_(node, ctx, call_trees) | ||||
|     node = converter.apply_(node, ctx, control_flow) | ||||
|     node = converter.apply_(node, ctx, conditional_expressions) | ||||
|     node = converter.apply_(node, ctx, logical_expressions) | ||||
|     return node | ||||
| 
 | ||||
| 
 | ||||
| # TODO(mdan): Add a garbage collection hook for cleaning up modules. | ||||
| class _FunctionCache(object): | ||||
|   """A hierarchical cache that uses the converted entity as weak key. | ||||
| 
 | ||||
|   The keys soft references (i.e. they are discarded when the key is | ||||
|   destroyed). The subkeys are normal hashable values. | ||||
| 
 | ||||
|   This class is generic - see the call site for how the keys and values are | ||||
|   defined. | ||||
|   """ | ||||
| 
 | ||||
|   __slots__ = ('_cache',) | ||||
| 
 | ||||
|   def __init__(self): | ||||
|     self._cache = weakref.WeakKeyDictionary() | ||||
| 
 | ||||
|   def _get_key(self, entity): | ||||
|     raise NotImplementedError('subclasses will override') | ||||
| 
 | ||||
|   def has(self, entity, subkey): | ||||
|     key = self._get_key(entity) | ||||
|     if key not in self._cache: | ||||
|       return False | ||||
|     return subkey in self._cache[key] | ||||
| 
 | ||||
|   def __getitem__(self, entity): | ||||
|     key = self._get_key(entity) | ||||
|     if key not in self._cache: | ||||
|       # The bucket needs to be initialized to support this usage: | ||||
|       #   cache[key][subkey] = value | ||||
|       self._cache[key] = {} | ||||
|     return self._cache[key] | ||||
| 
 | ||||
|   def __len__(self): | ||||
|     return len(self._cache) | ||||
| _TRANSPILER = AutoGraphTranspiler() | ||||
| _WHITELIST_CACHE = cache.UnboundInstanceCache() | ||||
| 
 | ||||
| 
 | ||||
| class _CodeObjectCache(_FunctionCache): | ||||
|   """A function cache based on code objects (i.e., the source code). | ||||
| 
 | ||||
|   Multiple functions may share the same code object, but they may share the | ||||
|   cache because we know they have the exact source code. This properly handles | ||||
|   functions defined in a loop, bound methods, etc. | ||||
| 
 | ||||
|   Falls back to the function object, if it doesn't have a code object. | ||||
|   """ | ||||
| 
 | ||||
|   def _get_key(self, entity): | ||||
|     if hasattr(entity, '__code__'): | ||||
|       return entity.__code__ | ||||
|     else: | ||||
|       return entity | ||||
| 
 | ||||
| 
 | ||||
| class _UnboundInstanceCache(_FunctionCache): | ||||
|   """A function cache based on unbound function objects. | ||||
| 
 | ||||
|   Unlike the _CodeObjectCache, this discriminates between different functions | ||||
|   even if they have the same code. This properly handles decorators that may | ||||
|   masquerade as various functions. Bound functions are not discriminated by | ||||
|   the object they're bound to. | ||||
|   """ | ||||
| 
 | ||||
|   def _get_key(self, entity): | ||||
|     if inspect.ismethod(entity): | ||||
|       return entity.__func__ | ||||
|     return entity | ||||
| 
 | ||||
| 
 | ||||
| # Using a re-entrant lock to guard against the unlikely possibility that the | ||||
| # conversion process triggers additional code execution. | ||||
| _CACHE_LOCK = threading.RLock() | ||||
| 
 | ||||
| 
 | ||||
| _CACHE = _CodeObjectCache() | ||||
| _WHITELIST_CACHE = _UnboundInstanceCache() | ||||
| 
 | ||||
| 
 | ||||
| # Note: strictly speaking, a simple factory might have been sufficient for | ||||
| # functions. But the double factory approach allows us to control the closure | ||||
| # and globals of the converted code in a cleaner fashion. | ||||
| # TODO(mdan): A simple factory may be sufficient. | ||||
| def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name, | ||||
|                                factory_name, closure_vars, future_features): | ||||
|   """Wraps an AST into the body of a dynamic factory. | ||||
| 
 | ||||
|   This uses the dynamic factory (factory of factory) pattern to achieve the | ||||
|   following: | ||||
| 
 | ||||
|    1. The inner factory, dynamically creates the entity represented by nodes. | ||||
|    2. The entity is parametrized by `ag__`, the internal AutoGraph module. | ||||
|    3. The outer factory creates the inner factory with a lexical scope | ||||
|       in which `closure_vars` are bound local variables. This in turn allows the | ||||
|       caller to control the exact closure (i.e. non-global free variables) for | ||||
|       the inner factory. | ||||
| 
 | ||||
|   The AST is expected to define some symbol named by `entity_name`. | ||||
| 
 | ||||
|   Args: | ||||
|     nodes: ast.AST | ||||
|     entity_name: Union[Text, ast.AST] | ||||
|     factory_factory_name: Text | ||||
|     factory_name: Text | ||||
|     closure_vars: Iterable[Text] | ||||
|     future_features: Iterable[Text], see EntityInfo.future_features. | ||||
| 
 | ||||
|   Returns: | ||||
|     ast.AST | ||||
|   """ | ||||
|   if not isinstance(nodes, (list, tuple)): | ||||
|     nodes = (nodes,) | ||||
| 
 | ||||
|   dummy_closure_defs = [] | ||||
|   for var_name in closure_vars: | ||||
|     template = """ | ||||
|       var_name = None | ||||
|     """ | ||||
|     dummy_closure_defs.extend(templates.replace(template, var_name=var_name)) | ||||
| 
 | ||||
|   if future_features: | ||||
|     future_imports = gast.ImportFrom( | ||||
|         module='__future__', | ||||
|         names=[gast.alias(name=name, asname=None) for name in future_features], | ||||
|         level=0) | ||||
|   else: | ||||
|     future_imports = [] | ||||
| 
 | ||||
|   # These dummy symbol declarations create local fariables in a function scope, | ||||
|   # so that the Python parser correctly marks them as free non-global variables | ||||
|   # upon load (that is, it creates cell slots for each symbol). Their values are | ||||
|   # not used, as the cells are swapped with the original entity's cells after | ||||
|   # the code has been loaded. | ||||
|   template = """ | ||||
|     future_imports | ||||
|     def factory_factory_name(): | ||||
|       dummy_closure_defs | ||||
|       def factory_name(ag__, ag_source_map__, ag_module__): | ||||
|         entity_defs | ||||
|         entity_name.ag_source_map = ag_source_map__ | ||||
|         entity_name.ag_module = ag_module__ | ||||
|         entity_name = ag__.autograph_artifact(entity_name) | ||||
|         return entity_name | ||||
|       return factory_name | ||||
|   """ | ||||
|   return templates.replace( | ||||
|       template, | ||||
|       future_imports=future_imports, | ||||
|       factory_factory_name=factory_factory_name, | ||||
|       factory_name=factory_name, | ||||
|       dummy_closure_defs=dummy_closure_defs, | ||||
|       entity_defs=nodes, | ||||
|       entity_name=entity_name) | ||||
| 
 | ||||
| 
 | ||||
| def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names): | ||||
|   """Returns a (possibly cached) factory for the converted result of entity.""" | ||||
|   # The cache subkey encompasses any conversion options on which the generated | ||||
|   # code may depend. | ||||
|   # The cached factory includes the necessary definitions to distinguish | ||||
|   # between the global and non-global free variables. For this reason, the | ||||
|   # cache subkey includes the names of the free non-globals. | ||||
|   subkey = (program_ctx.options, frozenset(free_nonglobal_var_names)) | ||||
| 
 | ||||
|   with _CACHE_LOCK: | ||||
|     # The cache values are _ConvertedEntityFactoryInfo objects. | ||||
|     if _CACHE.has(entity, subkey): | ||||
|       # TODO(mdan): Check whether the module is still loaded. | ||||
|       converted_entity_info = _CACHE[entity][subkey] | ||||
|       logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity, subkey, | ||||
|                   converted_entity_info) | ||||
|       return converted_entity_info | ||||
| 
 | ||||
|     logging.log(1, 'Entity %s is not cached for subkey %s', entity, subkey) | ||||
| 
 | ||||
|     nodes, converted_name, entity_info = convert_entity_to_ast( | ||||
|         entity, program_ctx) | ||||
| 
 | ||||
|     namer = naming.Namer(entity_info.namespace) | ||||
|     factory_factory_name = namer.new_symbol('create_converted_entity_factory', | ||||
|                                             ()) | ||||
|     factory_name = namer.new_symbol('create_converted_entity', ()) | ||||
|     nodes = _wrap_into_dynamic_factory(nodes, converted_name, | ||||
|                                        factory_factory_name, factory_name, | ||||
|                                        free_nonglobal_var_names, | ||||
|                                        entity_info.future_features) | ||||
| 
 | ||||
|     module, _, source_map = loader.load_ast(nodes, include_source_map=True) | ||||
|     module_name = module.__name__ | ||||
| 
 | ||||
|     converted_entity_info = _ConvertedEntityFactoryInfo( | ||||
|         module_name=module_name, | ||||
|         converted_name=converted_name, | ||||
|         factory_factory_name=factory_factory_name, | ||||
|         source_map=source_map) | ||||
|     _CACHE[entity][subkey] = converted_entity_info | ||||
|     return converted_entity_info | ||||
| 
 | ||||
| 
 | ||||
| def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): | ||||
|   """Creates a converted instance and binds it to match original entity.""" | ||||
|   factory = converted_entity_info.get_factory() | ||||
| 
 | ||||
|   entity_globals = entity.__globals__ | ||||
|   entity_closure = entity.__closure__ or () | ||||
|   assert len(entity_closure) == len(free_nonglobal_var_names) | ||||
| 
 | ||||
|   # Fit the original entity's cells to match the order of factory's cells. | ||||
|   original_names_and_cells = dict(zip(free_nonglobal_var_names, entity_closure)) | ||||
|   new_factory_cells = tuple( | ||||
|       original_names_and_cells[name] for name in factory.__code__.co_freevars) | ||||
| 
 | ||||
|   bound_factory = types.FunctionType( | ||||
|       code=factory.__code__, | ||||
|       globals=entity_globals, | ||||
|       name=factory.__name__, | ||||
|       argdefs=(), | ||||
|       closure=new_factory_cells) | ||||
| 
 | ||||
|   # Two other free vars: the internal "ag__" module and the source | ||||
|   # map. These are wired via the parameters of the factory. | ||||
|   converted_entity = bound_factory(  # pylint:disable=not-callable | ||||
|       ag_internal, converted_entity_info.source_map, | ||||
|       converted_entity_info.get_module()) | ||||
| 
 | ||||
|   # Attach the default argument to the converted function. | ||||
|   converted_entity.__defaults__ = entity.__defaults__ | ||||
|   if hasattr(entity, '__kwdefaults__'): | ||||
|     converted_entity.__kwdefaults__ = entity.__kwdefaults__ | ||||
| 
 | ||||
|   return converted_entity | ||||
| custom_vars = None | ||||
| 
 | ||||
| 
 | ||||
| # TODO(mdan): Superfluous function, remove. | ||||
| # TODO(mdan): Put these extra fields inside __autograph_info__. | ||||
| def convert(entity, program_ctx): | ||||
|   """Converts an entity into an equivalent entity.""" | ||||
|   """Applies AutoGraph to entity.""" | ||||
| 
 | ||||
|   if not hasattr(entity, '__code__'): | ||||
|     raise ValueError('Cannot apply autograph to a function that doesn\'t ' | ||||
|                      'expose a __code__ object. If this is a @tf.function,' | ||||
|                      ' try passing f.python_function instead.') | ||||
|   free_nonglobal_var_names = entity.__code__.co_freevars | ||||
| 
 | ||||
|   for i, name in enumerate(free_nonglobal_var_names): | ||||
|     if (name == 'ag__' and | ||||
|         entity.__closure__[i].cell_contents is not ag_internal): | ||||
|       raise ValueError('entity {} uses the reserved symbol "{}"'.format( | ||||
|           entity, name)) | ||||
|     # TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered. | ||||
|   _create_custom_vars(program_ctx) | ||||
|   transformed, module, source_map = _TRANSPILER.transform_function( | ||||
|       entity, program_ctx.options, program_ctx, custom_vars) | ||||
| 
 | ||||
|   converted_entity_info = _convert_with_cache(entity, program_ctx, | ||||
|                                               free_nonglobal_var_names) | ||||
| 
 | ||||
|   return _instantiate(entity, converted_entity_info, free_nonglobal_var_names) | ||||
|   assert not hasattr(transformed, 'ag_module') | ||||
|   assert not hasattr(transformed, 'ag_source_map') | ||||
|   transformed.ag_module = module | ||||
|   transformed.ag_source_map = source_map | ||||
|   return transformed | ||||
| 
 | ||||
| 
 | ||||
| # TODO(mdan): allow_namedtuple_subclass should be hardcoded to True. | ||||
| @ -472,58 +229,15 @@ def cache_whitelisted(entity, options): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| # TODO(mdan): Rename to convert_*_node to avoid confusion with convert. | ||||
| def convert_entity_to_ast(o, program_ctx): | ||||
|   """Compile a Python entity into equivalent TensorFlow. | ||||
| 
 | ||||
|   Args: | ||||
|     o: A Python entity. | ||||
|     program_ctx: A ProgramContext object. | ||||
| 
 | ||||
|   Returns: | ||||
|     A tuple (ast, new_name, namespace): | ||||
|         * ast: An AST representing an entity with interface equivalent to `o`, | ||||
|             but which when executed it creates TF a graph. | ||||
|         * new_name: The symbol name under which the new entity can be found. | ||||
|         * namespace: A dict mapping all symbols visible to the converted entity, | ||||
|             keyed by their symbol name. | ||||
| 
 | ||||
|   Raises: | ||||
|     NotImplementedError: if entity is of a type that is not yet supported. | ||||
|   """ | ||||
|   logging.log(1, 'Converting %s', o) | ||||
| 
 | ||||
|   nodes, name, entity_info = convert_func_to_ast(o, program_ctx) | ||||
| 
 | ||||
|   if logging.has_verbosity(2): | ||||
|     logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes)) | ||||
|   if logging.has_verbosity(4): | ||||
|     for n in nodes: | ||||
|       logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, | ||||
|                   pretty_printer.fmt(n, color=False)) | ||||
| 
 | ||||
|   return nodes, name, entity_info | ||||
| 
 | ||||
| 
 | ||||
| def _add_reserved_symbol(namespace, name, entity): | ||||
|   if name not in namespace: | ||||
|     namespace[name] = entity | ||||
|   elif namespace[name] != entity: | ||||
|     raise ValueError('The name "%s" is reserved and may not be used.' % name) | ||||
| 
 | ||||
| 
 | ||||
| ag_internal = None | ||||
| 
 | ||||
| 
 | ||||
| # TODO(mdan): Move into core or replace with an actual importable module. | ||||
| def _add_self_references(namespace, autograph_module): | ||||
| def _create_custom_vars(program_ctx): | ||||
|   """Adds namespace references to the module that exposes the api itself.""" | ||||
|   global ag_internal | ||||
|   if ag_internal is None: | ||||
|   global custom_vars | ||||
|   if custom_vars is None: | ||||
|     # Craft a module that exposes parts of the external API as well as certain | ||||
|     # internal modules. | ||||
|     ag_internal = imp.new_module('autograph') | ||||
|     ag_internal.__dict__.update(autograph_module.__dict__) | ||||
|     ag_internal.__dict__.update(program_ctx.autograph_module.__dict__) | ||||
|     ag_internal.ConversionOptions = converter.ConversionOptions | ||||
|     ag_internal.STD = converter.STANDARD_OPTIONS | ||||
|     ag_internal.Feature = converter.Feature | ||||
| @ -536,102 +250,4 @@ def _add_self_references(namespace, autograph_module): | ||||
|     ag_internal.__dict__.update(special_functions.__dict__) | ||||
|     ag_internal.__dict__.update(operators.__dict__) | ||||
| 
 | ||||
|   _add_reserved_symbol(namespace, 'ag__', ag_internal) | ||||
| 
 | ||||
| 
 | ||||
| def convert_func_to_ast(f, program_ctx, do_rename=True): | ||||
|   """Specialization of `convert_entity_to_ast` for callable functions.""" | ||||
| 
 | ||||
|   future_features = inspect_utils.getfutureimports(f) | ||||
|   node, source = parser.parse_entity(f, future_features=future_features) | ||||
|   logging.log(3, 'Source code of %s:\n\n%s\n', f, source) | ||||
|   # Parsed AST should contain future imports and one function def node. | ||||
| 
 | ||||
|   # In general, the output of inspect.getsource is inexact for lambdas because | ||||
|   # it uses regex matching to adjust the exact location around the line number | ||||
|   # that CPython records. Then, the entire containing line is returned, which | ||||
|   # we may have trouble disambiguating. For example: | ||||
|   # x, y = lambda: 1, lambda: 2 | ||||
|   if f.__name__ == '<lambda>': | ||||
|     nodes = ast_util.find_matching_definitions(node, f) | ||||
|     if len(nodes) != 1: | ||||
|       raise ValueError( | ||||
|           'Unable to identify source code of lambda function {}. It was' | ||||
|           ' defined on this line: {}, which must contain a single lambda with' | ||||
|           ' matching signature. To avoid ambiguity, define each lambda' | ||||
|           ' in a separate expression.'.format(f, source)) | ||||
|     node, = nodes | ||||
| 
 | ||||
|   # TODO(znado): Place inside standard_analysis. | ||||
|   origin_info.resolve_entity(node, source, f) | ||||
| 
 | ||||
|   namespace = inspect_utils.getnamespace(f) | ||||
|   _add_self_references(namespace, program_ctx.autograph_module) | ||||
|   namer = naming.Namer(namespace) | ||||
| 
 | ||||
|   if isinstance(node, gast.Lambda): | ||||
|     new_name = namer.new_symbol('tf__lambda', ()) | ||||
|   elif do_rename: | ||||
|     new_name = namer.new_symbol('tf__' + f.__name__, ()) | ||||
|   else: | ||||
|     new_name = f.__name__ | ||||
| 
 | ||||
|   entity_info = transformer.EntityInfo( | ||||
|       source_code=source, | ||||
|       source_file='<fragment>', | ||||
|       future_features=future_features, | ||||
|       namespace=namespace) | ||||
|   context = converter.EntityContext(namer, entity_info, program_ctx, new_name) | ||||
|   node = node_to_graph(node, context) | ||||
| 
 | ||||
|   if isinstance(node, gast.Lambda): | ||||
|     node = gast.Assign( | ||||
|         targets=[ | ||||
|             gast.Name( | ||||
|                 new_name, ctx=gast.Store(), annotation=None, type_comment=None) | ||||
|         ], | ||||
|         value=node) | ||||
|   elif do_rename: | ||||
|     node.name = new_name | ||||
|   else: | ||||
|     assert node.name == new_name | ||||
| 
 | ||||
|   return (node,), new_name, entity_info | ||||
| 
 | ||||
| 
 | ||||
| def node_to_graph(node, context): | ||||
|   """Convert Python code to equivalent TF graph mode code. | ||||
| 
 | ||||
|   Args: | ||||
|     node: AST, the code to convert. | ||||
|     context: converter.EntityContext | ||||
| 
 | ||||
|   Returns: | ||||
|     A tuple (node, deps): | ||||
|         * node: A Python ast node, representing the converted code. | ||||
|         * deps: A set of strings, the fully qualified names of entity | ||||
|             dependencies that this node has. | ||||
|   """ | ||||
|   # TODO(mdan): Insert list_comprehensions somewhere. | ||||
|   unsupported_features_checker.verify(node) | ||||
| 
 | ||||
|   node = converter.standard_analysis(node, context, is_initial=True) | ||||
|   node = converter.apply_(node, context, functions) | ||||
|   node = converter.apply_(node, context, arg_defaults) | ||||
|   node = converter.apply_(node, context, directives) | ||||
|   node = converter.apply_(node, context, break_statements) | ||||
|   if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS): | ||||
|     node = converter.apply_(node, context, asserts) | ||||
|   # Note: sequencing continue canonicalization before for loop one avoids | ||||
|   # dealing with the extra loop increment operation that the for | ||||
|   # canonicalization creates. | ||||
|   node = converter.apply_(node, context, continue_statements) | ||||
|   node = converter.apply_(node, context, return_statements) | ||||
|   if context.program.options.uses(converter.Feature.LISTS): | ||||
|     node = converter.apply_(node, context, lists) | ||||
|     node = converter.apply_(node, context, slices) | ||||
|   node = converter.apply_(node, context, call_trees) | ||||
|   node = converter.apply_(node, context, control_flow) | ||||
|   node = converter.apply_(node, context, conditional_expressions) | ||||
|   node = converter.apply_(node, context, logical_expressions) | ||||
|   return node | ||||
|     custom_vars = {'ag__': ag_internal} | ||||
|  | ||||
| @ -20,11 +20,9 @@ from __future__ import print_function | ||||
| 
 | ||||
| import imp | ||||
| import sys | ||||
| import threading | ||||
| import types | ||||
| import weakref | ||||
| 
 | ||||
| import gast | ||||
| import six | ||||
| 
 | ||||
| from tensorflow.python.autograph import utils | ||||
| @ -33,7 +31,6 @@ from tensorflow.python.autograph.core import converter | ||||
| from tensorflow.python.autograph.impl import api | ||||
| from tensorflow.python.autograph.impl import conversion | ||||
| from tensorflow.python.autograph.impl.testing import pybind_for_testing | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| from tensorflow.python.eager import function | ||||
| from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.platform import test | ||||
| @ -126,156 +123,6 @@ class ConversionTest(test.TestCase): | ||||
|       # Note: currently, native bindings are whitelisted by a separate check. | ||||
|       self.assertFalse(conversion.is_whitelisted(test_object.method)) | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_callable(self): | ||||
|     b = 2 | ||||
| 
 | ||||
|     def f(a): | ||||
|       return a + b | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     nodes, name, info = conversion.convert_entity_to_ast(f, program_ctx) | ||||
|     fn_node, = nodes | ||||
|     self.assertIsInstance(fn_node, gast.FunctionDef) | ||||
|     self.assertEqual('tf__f', name) | ||||
|     self.assertIs(info.namespace['b'], b) | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_function_with_defaults(self): | ||||
|     b = 2 | ||||
|     c = 1 | ||||
| 
 | ||||
|     def f(a, d=c + 1): | ||||
|       return a + b + d | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     nodes, name, _ = conversion.convert_entity_to_ast(f, program_ctx) | ||||
|     fn_node, = nodes | ||||
|     self.assertIsInstance(fn_node, gast.FunctionDef) | ||||
|     self.assertEqual('tf__f', name) | ||||
|     self.assertEqual( | ||||
|         parser.unparse(fn_node.args.defaults[0], | ||||
|                        include_encoding_marker=False).strip(), 'None') | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_call_tree(self): | ||||
| 
 | ||||
|     def g(a): | ||||
|       return a | ||||
| 
 | ||||
|     def f(a): | ||||
|       return g(a) | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     nodes, _, _ = conversion.convert_entity_to_ast(f, program_ctx) | ||||
|     f_node, = nodes | ||||
|     self.assertEqual('tf__f', f_node.name) | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_lambda(self): | ||||
|     b = 2 | ||||
|     f = lambda x: b * x if x > 0 else -x | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     (fn_node,), name, entity_info = conversion.convert_entity_to_ast( | ||||
|         f, program_ctx) | ||||
|     self.assertIsInstance(fn_node, gast.Assign) | ||||
|     self.assertIsInstance(fn_node.value, gast.Lambda) | ||||
|     self.assertEqual('tf__lambda', name) | ||||
|     self.assertIs(entity_info.namespace['b'], b) | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_multiple_lambdas(self): | ||||
|     a, b = 1, 2 | ||||
|     f, _ = (lambda x: a * x, lambda y: b * y) | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     (fn_node,), name, entity_info = conversion.convert_entity_to_ast( | ||||
|         f, program_ctx) | ||||
|     self.assertIsInstance(fn_node, gast.Assign) | ||||
|     self.assertIsInstance(fn_node.value, gast.Lambda) | ||||
|     self.assertEqual('tf__lambda', name) | ||||
|     self.assertIs(entity_info.namespace['a'], a) | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_multiple_lambdas_ambiguous_definitions(self): | ||||
|     a, b = 1, 2 | ||||
|     f, _ = (lambda x: a * x, lambda x: b * x) | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     with self.assertRaises(ValueError): | ||||
|       conversion.convert_entity_to_ast(f, program_ctx) | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_lambda_code_with_garbage(self): | ||||
|     # pylint:disable=g-long-lambda | ||||
|     f = (  # intentional wrap | ||||
|         lambda x: ( | ||||
|             x  # intentional wrap | ||||
|             + 1),)[0] | ||||
|     # pylint:enable=g-long-lambda | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     (fn_node,), name, _ = conversion.convert_entity_to_ast(f, program_ctx) | ||||
|     self.assertIsInstance(fn_node, gast.Assign) | ||||
|     self.assertIsInstance(fn_node.value, gast.Lambda) | ||||
|     self.assertEqual('tf__lambda', name) | ||||
| 
 | ||||
|   def test_convert_entity_to_ast_nested_functions(self): | ||||
|     b = 2 | ||||
| 
 | ||||
|     def f(x): | ||||
| 
 | ||||
|       def g(x): | ||||
|         return b * x | ||||
| 
 | ||||
|       return g(x) | ||||
| 
 | ||||
|     program_ctx = self._simple_program_ctx() | ||||
|     (fn_node,), name, entity_info = conversion.convert_entity_to_ast( | ||||
|         f, program_ctx) | ||||
|     self.assertIsInstance(fn_node, gast.FunctionDef) | ||||
|     self.assertEqual(fn_node.name, 'tf__f') | ||||
|     self.assertEqual('tf__f', name) | ||||
|     self.assertIs(entity_info.namespace['b'], b) | ||||
| 
 | ||||
|   def test_convert_concurrency(self): | ||||
| 
 | ||||
|     def test_fn(): | ||||
|       pass | ||||
| 
 | ||||
|     generated_file_names = [] | ||||
| 
 | ||||
|     def conversion_thread(): | ||||
|       new_f = conversion.convert(test_fn, self._simple_program_ctx()) | ||||
|       generated_file_names.append(new_f.__code__.co_filename) | ||||
| 
 | ||||
|     threads = tuple( | ||||
|         threading.Thread(target=conversion_thread) for _ in range(10)) | ||||
|     for t in threads: | ||||
|       t.start() | ||||
|     for t in threads: | ||||
|       t.join() | ||||
| 
 | ||||
|     # Races would potentially create multiple files (non-deterministically, | ||||
|     # but with high likelihood). | ||||
|     self.assertEqual(len(set(generated_file_names)), 1) | ||||
| 
 | ||||
|   def test_convert_reentrance(self): | ||||
| 
 | ||||
|     def test_fn(): | ||||
|       pass | ||||
| 
 | ||||
|     # There are no known ways to cause convert to re-enter. So we instrument | ||||
|     # an internal function to do that instead. | ||||
|     old_node_to_graph = conversion.node_to_graph | ||||
|     self.num_conversions = 0 | ||||
|     def node_to_graph_wrapper(node, context): | ||||
|       self.num_conversions += 1 | ||||
|       if self.num_conversions < 2: | ||||
|         conversion.convert(test_fn, self._simple_program_ctx()) | ||||
|       return old_node_to_graph(node, context) | ||||
| 
 | ||||
|     try: | ||||
|       conversion.node_to_graph = node_to_graph_wrapper | ||||
|       new_f = conversion.convert(test_fn, self._simple_program_ctx()) | ||||
|       self.assertIsNotNone(new_f) | ||||
|     finally: | ||||
|       conversion.node_to_graph = old_node_to_graph | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   test.main() | ||||
|  | ||||
| @ -39,6 +39,7 @@ py_library( | ||||
|         "qual_names.py", | ||||
|         "templates.py", | ||||
|         "transformer.py", | ||||
|         "transpiler.py", | ||||
|     ], | ||||
|     srcs_version = "PY2AND3", | ||||
|     visibility = ["//visibility:public"], | ||||
| @ -230,3 +231,14 @@ py_test( | ||||
|         "@gast_archive//:gast", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_test( | ||||
|     name = "transpiler_test", | ||||
|     srcs = ["transpiler_test.py"], | ||||
|     python_version = "PY3", | ||||
|     srcs_version = "PY2AND3", | ||||
|     deps = [ | ||||
|         ":pyct", | ||||
|         "//tensorflow/python:client_testlib", | ||||
|     ], | ||||
| ) | ||||
|  | ||||
| @ -50,8 +50,12 @@ class AnfTestBase(test.TestCase): | ||||
| 
 | ||||
|   def _simple_context(self): | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         source_code=None, source_file=None, future_features=(), namespace=None) | ||||
|     return transformer.Context(entity_info) | ||||
|         name='test_fn', | ||||
|         source_code=None, | ||||
|         source_file=None, | ||||
|         future_features=(), | ||||
|         namespace=None) | ||||
|     return transformer.Context(entity_info, None, None) | ||||
| 
 | ||||
|   def assert_same_ast(self, expected_node, node, msg=None): | ||||
|     expected_source = parser.unparse(expected_node, indentation='  ') | ||||
|  | ||||
| @ -22,6 +22,7 @@ import gast | ||||
| import six | ||||
| 
 | ||||
| from tensorflow.python.autograph.pyct import anno | ||||
| from tensorflow.python.autograph.pyct import naming | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| from tensorflow.python.autograph.pyct import qual_names | ||||
| from tensorflow.python.autograph.pyct import transformer | ||||
| @ -113,11 +114,17 @@ class ScopeTest(test.TestCase): | ||||
| class ActivityAnalyzerTestBase(test.TestCase): | ||||
| 
 | ||||
|   def _parse_and_analyze(self, test_fn): | ||||
|     # TODO(mdan): Use a custom FunctionTransformer here. | ||||
|     node, source = parser.parse_entity(test_fn, future_features=()) | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         source_code=source, source_file=None, future_features=(), namespace={}) | ||||
|         name=test_fn.__name__, | ||||
|         source_code=source, | ||||
|         source_file=None, | ||||
|         future_features=(), | ||||
|         namespace={}) | ||||
|     node = qual_names.resolve(node) | ||||
|     ctx = transformer.Context(entity_info) | ||||
|     namer = naming.Namer({}) | ||||
|     ctx = transformer.Context(entity_info, namer, None) | ||||
|     node = activity.resolve(node, ctx) | ||||
|     return node, entity_info | ||||
| 
 | ||||
|  | ||||
| @ -20,6 +20,7 @@ from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.python.autograph.pyct import anno | ||||
| from tensorflow.python.autograph.pyct import cfg | ||||
| from tensorflow.python.autograph.pyct import naming | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| from tensorflow.python.autograph.pyct import qual_names | ||||
| from tensorflow.python.autograph.pyct import transformer | ||||
| @ -35,11 +36,17 @@ global_b = 17 | ||||
| class LivenessAnalyzerTestBase(test.TestCase): | ||||
| 
 | ||||
|   def _parse_and_analyze(self, test_fn): | ||||
|     # TODO(mdan): Use a custom FunctionTransformer here. | ||||
|     node, source = parser.parse_entity(test_fn, future_features=()) | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         source_code=source, source_file=None, future_features=(), namespace={}) | ||||
|         name=test_fn.__name__, | ||||
|         source_code=source, | ||||
|         source_file=None, | ||||
|         future_features=(), | ||||
|         namespace={}) | ||||
|     node = qual_names.resolve(node) | ||||
|     ctx = transformer.Context(entity_info) | ||||
|     namer = naming.Namer({}) | ||||
|     ctx = transformer.Context(entity_info, namer, None) | ||||
|     node = activity.resolve(node, ctx) | ||||
|     graphs = cfg.build(node) | ||||
|     liveness.resolve(node, ctx, graphs) | ||||
|  | ||||
| @ -22,6 +22,7 @@ import six | ||||
| 
 | ||||
| from tensorflow.python.autograph.pyct import anno | ||||
| from tensorflow.python.autograph.pyct import cfg | ||||
| from tensorflow.python.autograph.pyct import naming | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| from tensorflow.python.autograph.pyct import qual_names | ||||
| from tensorflow.python.autograph.pyct import transformer | ||||
| @ -37,11 +38,17 @@ global_b = 17 | ||||
| class ReachingDefinitionsAnalyzerTestBase(test.TestCase): | ||||
| 
 | ||||
|   def _parse_and_analyze(self, test_fn): | ||||
|     # TODO(mdan): Use a custom FunctionTransformer here. | ||||
|     node, source = parser.parse_entity(test_fn, future_features=()) | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         source_code=source, source_file=None, future_features=(), namespace={}) | ||||
|         name=test_fn.__name__, | ||||
|         source_code=source, | ||||
|         source_file=None, | ||||
|         future_features=(), | ||||
|         namespace={}) | ||||
|     node = qual_names.resolve(node) | ||||
|     ctx = transformer.Context(entity_info) | ||||
|     namer = naming.Namer({}) | ||||
|     ctx = transformer.Context(entity_info, namer, None) | ||||
|     node = activity.resolve(node, ctx) | ||||
|     graphs = cfg.build(node) | ||||
|     node = reaching_definitions.resolve(node, ctx, graphs, | ||||
|  | ||||
| @ -36,20 +36,26 @@ class Context(object): | ||||
| 
 | ||||
|   Attributes: | ||||
|     info: EntityInfo, immutable. | ||||
|     namer: naming.Namer. | ||||
|     current_origin: origin_info.OriginInfo, holds the OriginInfo of the last | ||||
|       AST node to be processed successfully. Useful for error handling. | ||||
|     user: An user-supplied context object. The object is opaque to the | ||||
|       infrastructure, but will pe passed through to all custom transformations. | ||||
|   """ | ||||
| 
 | ||||
|   def __init__(self, info): | ||||
|   def __init__(self, info, namer, user_context): | ||||
|     self.info = info | ||||
|     self.namer = namer | ||||
|     self.current_origin = None | ||||
|     self.user = user_context | ||||
| 
 | ||||
| 
 | ||||
| # TODO(mdan): Move to a standalone file. | ||||
| class EntityInfo( | ||||
|     collections.namedtuple( | ||||
|         'EntityInfo', | ||||
|         ('source_code', 'source_file', 'future_features', 'namespace'))): | ||||
|         ('name', 'source_code', 'source_file', 'future_features', 'namespace')) | ||||
| ): | ||||
|   """Contains information about a Python entity. | ||||
| 
 | ||||
|   Immutable. | ||||
| @ -57,6 +63,7 @@ class EntityInfo( | ||||
|   Examples of entities include functions and classes. | ||||
| 
 | ||||
|   Attributes: | ||||
|     name: The name that identifies this entity. | ||||
|     source_code: The entity's source code. | ||||
|     source_file: The entity's source file. | ||||
|     future_features: Tuple[Text], the future features that this entity was | ||||
|  | ||||
| @ -31,8 +31,12 @@ class TransformerTest(test.TestCase): | ||||
| 
 | ||||
|   def _simple_context(self): | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         source_code=None, source_file=None, future_features=(), namespace=None) | ||||
|     return transformer.Context(entity_info) | ||||
|         name='Test_fn', | ||||
|         source_code=None, | ||||
|         source_file=None, | ||||
|         future_features=(), | ||||
|         namespace=None) | ||||
|     return transformer.Context(entity_info, None, None) | ||||
| 
 | ||||
|   def assertSameAnno(self, first, second, key): | ||||
|     self.assertIs(anno.getanno(first, key), anno.getanno(second, key)) | ||||
| @ -299,8 +303,12 @@ class CodeGeneratorTest(test.TestCase): | ||||
| 
 | ||||
|   def _simple_context(self): | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         source_code=None, source_file=None, future_features=(), namespace=None) | ||||
|     return transformer.Context(entity_info) | ||||
|         name='test_fn', | ||||
|         source_code=None, | ||||
|         source_file=None, | ||||
|         future_features=(), | ||||
|         namespace=None) | ||||
|     return transformer.Context(entity_info, None, None) | ||||
| 
 | ||||
|   def test_basic_codegen(self): | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										419
									
								
								tensorflow/python/autograph/pyct/transpiler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										419
									
								
								tensorflow/python/autograph/pyct/transpiler.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,419 @@ | ||||
| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Generic source code transformation infrastructure.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import threading | ||||
| import types | ||||
| 
 | ||||
| import gast | ||||
| 
 | ||||
| from tensorflow.python.autograph.pyct import ast_util | ||||
| from tensorflow.python.autograph.pyct import cache | ||||
| from tensorflow.python.autograph.pyct import inspect_utils | ||||
| from tensorflow.python.autograph.pyct import loader | ||||
| from tensorflow.python.autograph.pyct import naming | ||||
| from tensorflow.python.autograph.pyct import origin_info | ||||
| from tensorflow.python.autograph.pyct import parser | ||||
| from tensorflow.python.autograph.pyct import templates | ||||
| from tensorflow.python.autograph.pyct import transformer | ||||
| from tensorflow.python.autograph.utils import ag_logging as logging | ||||
| 
 | ||||
| 
 | ||||
| def _wrap_into_factory(nodes, entity_name, inner_factory_name, | ||||
|                        outer_factory_name, closure_vars, factory_args, | ||||
|                        future_features): | ||||
|   """Wraps an AST into the body of a factory with consistent lexical context. | ||||
| 
 | ||||
|   The AST is expected to define some symbol with a name given by `entity_name`. | ||||
| 
 | ||||
|   This mechanism ensures that the resulting transformed entity has lexical | ||||
|   scoping identical to that of the source entity, while allowing extra | ||||
|   parametrization. | ||||
| 
 | ||||
|   Two nested factories achieve the following: | ||||
| 
 | ||||
|    1. The inner factory dynamically creates the entity represented by `nodes`. | ||||
|    2. The inner factory is parametrized by a custom set of arguments. | ||||
|    3. The inner factory has a closure identical to that of the transformed | ||||
|        entity. | ||||
|    4. The inner factory has local variables named like `args`, which `nodes` may | ||||
|        use as additional parameters. | ||||
|    5. The inner factory returns the variables given by `entity_name`. | ||||
|    6. The outer factory is niladic. | ||||
|    7. The outer factory has no closure. | ||||
|    8. The outer factory creates the necessary lexical scope for the inner | ||||
|        factory, so that the loaded code has the given configuration for | ||||
|        closure/globals. | ||||
|    9. The outer factory returns the inner factory. | ||||
| 
 | ||||
|   Roughly speaking, the following code is generated: | ||||
| 
 | ||||
|       from __future__ import future_feature_1 | ||||
|       from __future__ import future_feature_2 | ||||
|       ... | ||||
| 
 | ||||
|       def outer_factory(): | ||||
|         closure_var_1 = None | ||||
|         closure_var_2 = None | ||||
|         ... | ||||
| 
 | ||||
|         def inner_factory(arg_1, arg_2, ...): | ||||
|           <<nodes>> | ||||
|           return entity | ||||
| 
 | ||||
|         return inner_factory | ||||
| 
 | ||||
|   The lexical scoping is created using dummy symbol declarations which create | ||||
|   local fariables in the body of the outer factory, so that the Python parser | ||||
|   correctly marks them as free non-global variables upon load (that is, it | ||||
|   creates cell slots for each symbol. Thes symbols are initialized with None, | ||||
|   but their values are not expected to be used; instead, the caller is expected | ||||
|   to replace them with the cells of the source entity. For more details, see: | ||||
|   https://docs.python.org/3/reference/executionmodel.html#binding-of-names | ||||
| 
 | ||||
|   Args: | ||||
|     nodes: Tuple[ast.AST], the source code to wrap. | ||||
|     entity_name: Union[Text, ast.AST], the name of the principal entity that | ||||
|       `nodes` define. | ||||
|     inner_factory_name: Text, the name of the inner factory. | ||||
|     outer_factory_name: Text, the name of the outer factory. | ||||
|     closure_vars: Iterable[Text], names of the closure variables for the inner | ||||
|       factory. | ||||
|     factory_args: Iterable[Text], names of additional arguments for the | ||||
|       inner factory. Useful to configure variables that the converted code can | ||||
|       use. Typically, these are modules. | ||||
|     future_features: Iterable[Text], names of future statements to associate the | ||||
|       code with. | ||||
| 
 | ||||
|   Returns: | ||||
|     ast.AST | ||||
|   """ | ||||
|   dummy_closure_defs = [] | ||||
|   for var_name in closure_vars: | ||||
|     template = """ | ||||
|       var_name = None | ||||
|     """ | ||||
|     dummy_closure_defs.extend(templates.replace(template, var_name=var_name)) | ||||
| 
 | ||||
|   if future_features: | ||||
|     future_imports = gast.ImportFrom( | ||||
|         module='__future__', | ||||
|         names=[gast.alias(name=name, asname=None) for name in future_features], | ||||
|         level=0) | ||||
|   else: | ||||
|     future_imports = [] | ||||
| 
 | ||||
|   factory_args = [ | ||||
|       gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None) | ||||
|       for name in factory_args | ||||
|   ] | ||||
| 
 | ||||
|   template = """ | ||||
|     future_imports | ||||
|     def outer_factory_name(): | ||||
|       dummy_closure_defs | ||||
|       def inner_factory_name(factory_args): | ||||
|         entity_defs | ||||
|         return entity_name | ||||
|       return inner_factory_name | ||||
|   """ | ||||
|   return templates.replace( | ||||
|       template, | ||||
|       dummy_closure_defs=dummy_closure_defs, | ||||
|       entity_defs=nodes, | ||||
|       entity_name=entity_name, | ||||
|       factory_args=factory_args, | ||||
|       future_imports=future_imports, | ||||
|       inner_factory_name=inner_factory_name, | ||||
|       outer_factory_name=outer_factory_name) | ||||
| 
 | ||||
| 
 | ||||
| class _TransformedFnFactory(object): | ||||
|   """Helper object that wraps a transformed function factory.""" | ||||
| 
 | ||||
|   def __init__(self, name, freevars, extra_locals): | ||||
|     """Creates a new factory for a transformed function. | ||||
| 
 | ||||
|     Args: | ||||
|       name: The function name. | ||||
|       freevars: The list of non-global free variables for the function. | ||||
|       extra_locals: Dict[Text, Any], names and values for custom variables that | ||||
|         are accessible to the generated code as local variables. | ||||
|     """ | ||||
|     self._name = name | ||||
|     self._freevars = freevars | ||||
|     self._extra_locals = extra_locals | ||||
| 
 | ||||
|     self._unbound_factory = None | ||||
|     self.module = None | ||||
|     self.source_map = None | ||||
| 
 | ||||
|   def create(self, | ||||
|              nodes, | ||||
|              namer, | ||||
|              inner_factory_name='inner_factory', | ||||
|              outer_factory_name='outer_factory', | ||||
|              future_features=()): | ||||
|     """Initializes a transformed function.""" | ||||
|     if self._unbound_factory is not None: | ||||
|       raise ValueError('double initialization; create a new object instead') | ||||
| 
 | ||||
|     inner_factory_name = namer.new_symbol(inner_factory_name, ()) | ||||
|     outer_factory_name = namer.new_symbol(outer_factory_name, ()) | ||||
|     nodes = _wrap_into_factory(nodes, self._name, inner_factory_name, | ||||
|                                outer_factory_name, self._freevars, | ||||
|                                self._extra_locals.keys(), future_features) | ||||
| 
 | ||||
|     module, _, source_map = loader.load_ast( | ||||
|         nodes, include_source_map=True) | ||||
|     outer_factory = getattr(module, outer_factory_name) | ||||
|     self._unbound_factory = outer_factory() | ||||
|     self.module = module | ||||
|     self.source_map = source_map | ||||
| 
 | ||||
|   def instantiate(self, | ||||
|                   globals_, | ||||
|                   closure, | ||||
|                   defaults=None, | ||||
|                   kwdefaults=None): | ||||
|     """Creates a new instance of the transformed function.""" | ||||
|     if self._unbound_factory is None: | ||||
|       raise ValueError('call create first') | ||||
| 
 | ||||
|     factory_code = self._unbound_factory.__code__ | ||||
|     factory_freevars = factory_code.co_freevars | ||||
|     closure_map = dict(zip(self._freevars, closure)) | ||||
|     factory_closure = tuple( | ||||
|         closure_map[name] for name in factory_code.co_freevars) | ||||
|     if len(factory_closure) != len(closure): | ||||
|       raise ValueError( | ||||
|           'closure mismatch, requested {}, but source function had {}'.format( | ||||
|               self._freevars, factory_freevars)) | ||||
| 
 | ||||
|     bound_factory = types.FunctionType( | ||||
|         code=factory_code, | ||||
|         globals=globals_, | ||||
|         name=self._name, | ||||
|         argdefs=(), | ||||
|         closure=factory_closure) | ||||
| 
 | ||||
|     # The lint override is a false positive. | ||||
|     transformed_entity = bound_factory(**self._extra_locals)  # pylint:disable=not-callable | ||||
| 
 | ||||
|     if defaults: | ||||
|       transformed_entity.__defaults__ = defaults | ||||
|     if kwdefaults: | ||||
|       transformed_entity.__kwdefaults__ = kwdefaults | ||||
| 
 | ||||
|     return transformed_entity | ||||
| 
 | ||||
| 
 | ||||
| class FunctionTranspiler(object): | ||||
|   """A generic source-to-source transpiler for Python functions. | ||||
| 
 | ||||
|   Its interface `transform_function` API offers a function-in, function-out | ||||
|   interface. Internally, it takes care of parsing, caching and variable binding. | ||||
| 
 | ||||
|   Users typically subclass this, customizing the transform_ast method. | ||||
| 
 | ||||
|   Usually, instances of this class are singletons, since each instance manages | ||||
|   its own cache. The caching subkey allows managing multiple types of | ||||
|   transformation. | ||||
| 
 | ||||
|   Example: | ||||
| 
 | ||||
|       class MyTransformer(FunctionTranspiler): | ||||
| 
 | ||||
|         def transform_ast(self, node, ctx): | ||||
|           node = <<transform node, usually using ast.NodeTransformer classes>> | ||||
|           return node | ||||
| 
 | ||||
|       transformer = MyTransfomer() | ||||
| 
 | ||||
|       new_f, module, source_map = transformer.transform_function(f, ...) | ||||
|       # new_f is a function with signature identical to f | ||||
| 
 | ||||
|   The transformed function has access to the same namespace as the original | ||||
|   function. To allow access to internal APIs, users may inject additional | ||||
|   symbols though the `extra_locals` argument of `transform_function`. | ||||
|   """ | ||||
| 
 | ||||
|   def __init__(self): | ||||
|     self._cache_lock = threading.RLock() | ||||
|     self._cache = cache.CodeObjectCache() | ||||
| 
 | ||||
|   def transform_ast(self, node, user_context): | ||||
|     """Performs an actual transformation of a function's AST. | ||||
| 
 | ||||
|     Subclasses must implement this method. They must not call it. | ||||
| 
 | ||||
|     The method receives the original AST and generates code according to the | ||||
|     AST that the method returns. For functions, the returned AST is expected to | ||||
|     contain a function with the exact same arguments and closure. The resulting | ||||
|     function will receive the globals, closure and argument defaults of the | ||||
|     input function. | ||||
| 
 | ||||
|     Args: | ||||
|       node: One or more ast.AST nodes representing the AST to be transformed. | ||||
|       user_context: The same value that the caller passed to | ||||
|         `transform_function`. | ||||
|     """ | ||||
|     raise NotImplementedError('subclasses must override this') | ||||
| 
 | ||||
|   def get_transformed_name(self, node): | ||||
|     """Returns a name for the output function. Subclasses may override this.""" | ||||
|     if isinstance(node, gast.Lambda): | ||||
|       return 'lam' | ||||
|     elif isinstance(node, gast.FunctionDef): | ||||
|       # Note that we need to rename the function, to avoid any namespace | ||||
|       # clashes. | ||||
|       return node.name | ||||
|     else: | ||||
|       raise ValueError('Unknown node type {}'.format(node)) | ||||
| 
 | ||||
|   def _erase_arg_defaults(self, node): | ||||
|     """Erase argde fault expressions, which would otherwise be unbound.""" | ||||
|     args = node.args | ||||
|     for i in range(len(args.defaults)): | ||||
|       args.defaults[i] = parser.parse_expression('None') | ||||
|     for i, d in enumerate(args.kw_defaults): | ||||
|       if d is not None: | ||||
|         args.kw_defaults[i] = parser.parse_expression('None') | ||||
|     return node | ||||
| 
 | ||||
|   def _transform_function(self, fn, user_context): | ||||
|     """Performs source code transformation on a function.""" | ||||
|     future_features = inspect_utils.getfutureimports(fn) | ||||
|     node, source = parser.parse_entity(fn, future_features=future_features) | ||||
|     logging.log(3, 'Source code of %s:\n\n%s\n', fn, source) | ||||
| 
 | ||||
|     # In general, the output of inspect.getsource is inexact for lambdas | ||||
|     # because it uses regex matching to adjust the exact location around | ||||
|     # the line number that CPython records. Then, the entire containing line | ||||
|     # is returned, which we may have trouble disambiguating. | ||||
|     # For example: | ||||
|     #   x, y = lambda: 1, lambda: 2 | ||||
|     is_lambda = fn.__name__ == '<lambda>' | ||||
|     if is_lambda: | ||||
|       nodes = ast_util.find_matching_definitions(node, fn) | ||||
|       if len(nodes) != 1: | ||||
|         raise ValueError( | ||||
|             'Unable to identify source code of lambda function {}.' | ||||
|             ' It was defined in this code:\n' | ||||
|             '{}\n' | ||||
|             'This code must contain a single distinguishable lambda.' | ||||
|             ' To avoid this problem, define each lambda in a separate' | ||||
|             ' expression.'.format(fn, source)) | ||||
|       node, = nodes | ||||
| 
 | ||||
|     origin_info.resolve_entity(node, source, fn) | ||||
| 
 | ||||
|     namespace = inspect_utils.getnamespace(fn) | ||||
|     namer = naming.Namer(namespace) | ||||
|     new_name = namer.new_symbol(self.get_transformed_name(node), ()) | ||||
|     entity_info = transformer.EntityInfo( | ||||
|         name=new_name, | ||||
|         source_code=source, | ||||
|         source_file='<fragment>', | ||||
|         future_features=future_features, | ||||
|         namespace=namespace) | ||||
|     context = transformer.Context(entity_info, namer, user_context) | ||||
| 
 | ||||
|     node = self._erase_arg_defaults(node) | ||||
|     node = self.transform_ast(node, context) | ||||
| 
 | ||||
|     if is_lambda: | ||||
|       node = gast.Assign( | ||||
|           targets=[ | ||||
|               gast.Name( | ||||
|                   new_name, | ||||
|                   ctx=gast.Store(), | ||||
|                   annotation=None, | ||||
|                   type_comment=None) | ||||
|           ], | ||||
|           value=node) | ||||
|     else: | ||||
|       node.name = new_name | ||||
| 
 | ||||
|     return node, context | ||||
| 
 | ||||
|   def _cached_factory(self, fn, cache_subkey): | ||||
|     cached_factory = self._cache[fn][cache_subkey] | ||||
|     logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey, | ||||
|                 cached_factory) | ||||
|     return cached_factory | ||||
| 
 | ||||
|   def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals): | ||||
|     """Returns the transformed function factory for a given input.""" | ||||
|     if self._cache.has(fn, cache_subkey): | ||||
|       return self._cached_factory(fn, cache_subkey) | ||||
| 
 | ||||
|     with self._cache_lock: | ||||
|       # Check again under lock. | ||||
|       if self._cache.has(fn, cache_subkey): | ||||
|         return self._cached_factory(fn, cache_subkey) | ||||
| 
 | ||||
|       logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey) | ||||
|       nodes, ctx = self._transform_function(fn, user_context) | ||||
| 
 | ||||
|       if logging.has_verbosity(2): | ||||
|         logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes)) | ||||
| 
 | ||||
|       factory = _TransformedFnFactory( | ||||
|           ctx.info.name, fn.__code__.co_freevars, extra_locals) | ||||
|       factory.create(nodes, ctx.namer, future_features=ctx.info.future_features) | ||||
|       self._cache[fn][cache_subkey] = factory | ||||
|     return factory | ||||
| 
 | ||||
|   def transform_function(self, fn, caching_subkey, user_context, extra_locals): | ||||
|     """Transforms a function. | ||||
| 
 | ||||
|     The `caching_subkey` argument allows mapping each function to multiple | ||||
|     outputs in the cache. This is useful for instance when transformers | ||||
|     can generate multiple variants of output code, typically as a result of | ||||
|     different transformation flags. | ||||
| 
 | ||||
|     Args: | ||||
|       fn: A function or lambda. | ||||
|       caching_subkey: Used for caching. Calls made for functions with the same | ||||
|         code object and caching_subkey will return a cached instance on | ||||
|         subsequent invocations. Using a constant will create unique per-function | ||||
|         entries. | ||||
|       user_context: An opaque object (may be none) that is forwarded to | ||||
|         transform_ast. | ||||
|       extra_locals: A Dict[Text, Any] containing additional variables to make | ||||
|         available to the transformed code. These will be visible as local | ||||
|         variables. | ||||
|     Returns: | ||||
|       A tuple: | ||||
|         * A function or lambda with the same signature and closure as `fn` | ||||
|         * The temporary module into which the transformed function was loaded | ||||
|         * The source map as a | ||||
|             Dict[origin_info.LineLocation, origin_info.OriginInfo] | ||||
| 
 | ||||
|     """ | ||||
|     factory = self._transformed_factory(fn, caching_subkey, user_context, | ||||
|                                         extra_locals) | ||||
| 
 | ||||
|     transformed_fn = factory.instantiate( | ||||
|         globals_=fn.__globals__, | ||||
|         closure=fn.__closure__ or (), | ||||
|         defaults=fn.__defaults__, | ||||
|         kwdefaults=getattr(fn, '__kwdefaults__', None)) | ||||
|     return transformed_fn, factory.module, factory.source_map | ||||
							
								
								
									
										249
									
								
								tensorflow/python/autograph/pyct/transpiler_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										249
									
								
								tensorflow/python/autograph/pyct/transpiler_test.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,249 @@ | ||||
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Tests for transpiler module.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import threading | ||||
| 
 | ||||
| import gast | ||||
| 
 | ||||
| from tensorflow.python.autograph.pyct import transformer | ||||
| from tensorflow.python.autograph.pyct import transpiler | ||||
| from tensorflow.python.platform import test | ||||
| 
 | ||||
| 
 | ||||
| class FlipSignTransformer(transformer.Base): | ||||
| 
 | ||||
|   def visit_BinOp(self, node): | ||||
|     if isinstance(node.op, gast.Add): | ||||
|       node.op = gast.Sub() | ||||
|     return self.generic_visit(node) | ||||
| 
 | ||||
| 
 | ||||
| class TestTranspiler(transpiler.FunctionTranspiler): | ||||
| 
 | ||||
|   def transform_ast(self, node, ctx): | ||||
|     return FlipSignTransformer(ctx).visit(node) | ||||
| 
 | ||||
| 
 | ||||
| global_var_for_test_global = 1 | ||||
| global_var_for_test_namespace_collisions = object() | ||||
| 
 | ||||
| 
 | ||||
| class FunctionTranspilerTest(test.TestCase): | ||||
| 
 | ||||
|   def test_basic(self): | ||||
|     def f(a): | ||||
|       return a + 1 | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 0) | ||||
| 
 | ||||
|   def test_closure(self): | ||||
|     b = 1 | ||||
| 
 | ||||
|     def f(a): | ||||
|       return a + b | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 0) | ||||
|     b = 2 | ||||
|     self.assertEqual(f(1), -1) | ||||
| 
 | ||||
|   def test_global(self): | ||||
|     def f(a): | ||||
|       return a + global_var_for_test_global | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     global global_var_for_test_global | ||||
|     global_var_for_test_global = 1 | ||||
|     self.assertEqual(f(1), 0) | ||||
|     global_var_for_test_global = 2 | ||||
|     self.assertEqual(f(1), -1) | ||||
| 
 | ||||
|   def test_defaults(self): | ||||
|     b = 2 | ||||
|     c = 1 | ||||
| 
 | ||||
|     def f(a, d=c + 1): | ||||
|       return a + b + d | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 1 - 2 - 2) | ||||
|     c = 0 | ||||
|     self.assertEqual(f(1), 1 - 2 - 2)  # Defaults are evaluated at definition. | ||||
|     b = 1 | ||||
|     self.assertEqual(f(1), 1 - 2 - 1) | ||||
| 
 | ||||
|   def test_call_tree(self): | ||||
| 
 | ||||
|     def g(a): | ||||
|       return a + 1 | ||||
| 
 | ||||
|     def f(a): | ||||
|       return g(a) + 1 | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 1 - 1 + 1)  # Only f is converted. | ||||
| 
 | ||||
|   def test_lambda(self): | ||||
|     b = 2 | ||||
|     f = lambda x: (b + (x if x > 0 else -x)) | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 2 - 1) | ||||
|     self.assertEqual(f(-1), 2 - 1) | ||||
| 
 | ||||
|     b = 3 | ||||
| 
 | ||||
|     self.assertEqual(f(1), 3 - 1) | ||||
|     self.assertEqual(f(-1), 3 - 1) | ||||
| 
 | ||||
|   def test_multiple_lambdas(self): | ||||
|     a, b = 1, 2 | ||||
|     # This can be disambiguated by the argument names. | ||||
|     f, _ = (lambda x: a + x, lambda y: b * y) | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 1 - 1) | ||||
| 
 | ||||
|   def test_multiple_lambdas_indistinguishable_definitions(self): | ||||
|     a, b = 1, 2 | ||||
|     f, _ = (lambda x: a * x, lambda x: b * x) | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     with self.assertRaises(ValueError): | ||||
|       tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|   def test_lambda_code_with_removable_garbage(self): | ||||
|     # pylint:disable=g-long-lambda | ||||
|     f = (  # intentional wrap | ||||
|         lambda x: ( | ||||
|             x  # intentional wrap | ||||
|             + 1),)[0] | ||||
|     # pylint:enable=g-long-lambda | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 1 - 1) | ||||
| 
 | ||||
|   def test_nested_functions(self): | ||||
|     b = 2 | ||||
| 
 | ||||
|     def f(x): | ||||
| 
 | ||||
|       def g(x): | ||||
|         return b + x | ||||
| 
 | ||||
|       return g(x) | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 2 - 1) | ||||
| 
 | ||||
|   def test_nested_lambda(self): | ||||
|     b = 2 | ||||
| 
 | ||||
|     def f(x): | ||||
|       g = lambda x: b + x | ||||
|       return g(x) | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     f, _, _ = tr.transform_function(f, object(), None, {}) | ||||
| 
 | ||||
|     self.assertEqual(f(1), 2 - 1) | ||||
| 
 | ||||
|   def test_concurrency(self): | ||||
| 
 | ||||
|     def f(): | ||||
|       pass | ||||
| 
 | ||||
|     outputs = [] | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     cache_key = object() | ||||
|     def conversion_thread(): | ||||
|       _, mod, _ = tr.transform_function(f, cache_key, None, {}) | ||||
|       outputs.append(mod.__name__) | ||||
| 
 | ||||
|     threads = tuple( | ||||
|         threading.Thread(target=conversion_thread) for _ in range(10)) | ||||
|     for t in threads: | ||||
|       t.start() | ||||
|     for t in threads: | ||||
|       t.join() | ||||
| 
 | ||||
|     # Races would potentially create multiple functions / modules | ||||
|     # (non-deterministically, but with high likelihood). | ||||
|     self.assertEqual(len(set(outputs)), 1) | ||||
| 
 | ||||
|   def test_reentrance(self): | ||||
| 
 | ||||
|     def test_fn(): | ||||
|       return 1 + 1 | ||||
| 
 | ||||
|     class ReentrantTranspiler(transpiler.FunctionTranspiler): | ||||
| 
 | ||||
|       def __init__(self): | ||||
|         super(ReentrantTranspiler, self).__init__() | ||||
|         self._recursion_depth = 0 | ||||
| 
 | ||||
|       def transform_ast(self, node, ctx): | ||||
|         self._recursion_depth += 1 | ||||
|         if self._recursion_depth < 2: | ||||
|           self.transform_function(test_fn, object(), None, {}) | ||||
|         return FlipSignTransformer(ctx).visit(node) | ||||
| 
 | ||||
|     tr = ReentrantTranspiler() | ||||
| 
 | ||||
|     f, _, _ = tr.transform_function(test_fn, object(), None, {}) | ||||
|     self.assertEqual(f(), 0) | ||||
| 
 | ||||
|   def test_namespace_collisions_avoided(self): | ||||
| 
 | ||||
|     class TestClass(object): | ||||
| 
 | ||||
|       def global_var_for_test_namespace_collisions(self): | ||||
|         return global_var_for_test_namespace_collisions | ||||
| 
 | ||||
|     tr = TestTranspiler() | ||||
|     obj = TestClass() | ||||
| 
 | ||||
|     f, _, _ = tr.transform_function( | ||||
|         obj.global_var_for_test_namespace_collisions, object(), None, {}) | ||||
|     self.assertIs(f(obj), global_var_for_test_namespace_collisions) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == '__main__': | ||||
|   test.main() | ||||
| @ -52,13 +52,6 @@ def alias_tensors(*args): | ||||
|   raise ValueError('at least one argument required') | ||||
| 
 | ||||
| 
 | ||||
| def capitalize_initial(s): | ||||
|   """Capitalizes the initial of a string only.""" | ||||
|   if s: | ||||
|     return s[0].upper() + s[1:] | ||||
|   return s | ||||
| 
 | ||||
| 
 | ||||
| def get_range_len(start, limit, delta): | ||||
|   dist = ops.convert_to_tensor(limit - start) | ||||
|   unadjusted_len = dist // delta | ||||
|  | ||||
| @ -29,15 +29,6 @@ from tensorflow.python.platform import test | ||||
| 
 | ||||
| class MiscTest(test.TestCase): | ||||
| 
 | ||||
|   def test_capitalize_initial(self): | ||||
|     self.assertEqual('', misc.capitalize_initial('')) | ||||
|     self.assertEqual('A', misc.capitalize_initial('A')) | ||||
|     self.assertEqual('Ab', misc.capitalize_initial('Ab')) | ||||
|     self.assertEqual('AbC', misc.capitalize_initial('AbC')) | ||||
|     self.assertEqual('A', misc.capitalize_initial('a')) | ||||
|     self.assertEqual('Ab', misc.capitalize_initial('ab')) | ||||
|     self.assertEqual('AbC', misc.capitalize_initial('abC')) | ||||
| 
 | ||||
|   @test_util.run_deprecated_v1 | ||||
|   def test_alias_single_tensor(self): | ||||
|     a = constant(1) | ||||
|  | ||||
| @ -198,11 +198,12 @@ def _live_tensors(f, attr_name="inputs"): | ||||
|   """ | ||||
|   node, _ = parser.parse_entity(f, ()) | ||||
|   entity_info = transformer.EntityInfo( | ||||
|       name=f.__name__, | ||||
|       source_code=None, | ||||
|       source_file=None, | ||||
|       future_features=(), | ||||
|       namespace=sys.modules[f.__module__].__dict__) | ||||
|   ctx = transformer.Context(entity_info) | ||||
|   ctx = transformer.Context(entity_info, None, None) | ||||
| 
 | ||||
|   graphs = cfg.build(node) | ||||
|   node = qual_names.resolve(node) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user