diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD index 8e29fa1961e..9c1d5a38707 100644 --- a/tensorflow/python/autograph/converters/BUILD +++ b/tensorflow/python/autograph/converters/BUILD @@ -19,7 +19,6 @@ filegroup( py_library( name = "converters", srcs = [ - "arg_defaults.py", "asserts.py", "break_statements.py", "call_trees.py", @@ -48,18 +47,6 @@ py_library( ], ) -py_test( - name = "arg_defaults_test", - srcs = ["arg_defaults_test.py"], - python_version = "PY3", - srcs_version = "PY2AND3", - deps = [ - ":converters", - "//tensorflow/python:client_testlib", - "//tensorflow/python/autograph/core:test_lib", - ], -) - py_test( name = "asserts_test", srcs = ["asserts_test.py"], diff --git a/tensorflow/python/autograph/converters/arg_defaults.py b/tensorflow/python/autograph/converters/arg_defaults.py deleted file mode 100644 index 2f8865ffcb0..00000000000 --- a/tensorflow/python/autograph/converters/arg_defaults.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Modifies the signature to allow resolving the value of default arguments. - -Normally, function symbols are captured either in a function's globals or -closure. This is not true for default arguments, which are evaluated when the -function is defined: - - b = 1 - c = 2 - def f(a=b + 1): - return a + c - -In the above example, the namespace of the function would include `c = 2` but -not `b`. - -If we were to naively generate a new function: - - def new_f(a=b + 1): - return a + c - -The generated code would fail to load unless we exposed a symbol `b`. Capturing -the closure of such an expression is difficult. However, we can capture the -default value of argument `a` with relative ease. - -This converter replaces all default argument expressions with a constant so -that they don't cause loading to fail. This requires that the default values -are reset after loading the transformed function: - - def new_f(a=None): - return a + c - - # ... later, after new_f was loaded ... - new_f.__defaults__ = f.__defaults__ - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.pyct import parser - - -class _Function(object): - pass - - -class ArgDefaultsTransformer(converter.Base): - """Transforms top level argument defaults.""" - - def visit_Lambda(self, node): - self.state[_Function].enter() - node.args = self.visit(node.args) - # Only the top level function is modified - no need to visit the children. - self.state[_Function].exit() - return node - - def visit_FunctionDef(self, node): - self.state[_Function].enter() - node.args = self.visit(node.args) - # Only the top level function is modified - no need to visit the children. - self.state[_Function].exit() - return node - - def visit_arguments(self, node): - if self.state[_Function].level > 2: - return node - - for i in range(len(node.defaults)): - node.defaults[i] = parser.parse_expression('None') - - for i, d in enumerate(node.kw_defaults): - if d is not None: - node.kw_defaults[i] = parser.parse_expression('None') - - # Only the top level function is modified - no need to visit the children. - return node - - -def transform(node, ctx): - """Transform function call to the compiled counterparts. - - Args: - node: AST - ctx: EntityContext - Returns: - A tuple (node, new_names): - node: The transformed AST - new_names: set(string), containing any newly-generated names - """ - return ArgDefaultsTransformer(ctx).visit(node) diff --git a/tensorflow/python/autograph/converters/arg_defaults_test.py b/tensorflow/python/autograph/converters/arg_defaults_test.py deleted file mode 100644 index 6448f3124db..00000000000 --- a/tensorflow/python/autograph/converters/arg_defaults_test.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for arg_defaults module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.autograph.converters import arg_defaults -from tensorflow.python.autograph.core import converter_testing -from tensorflow.python.autograph.pyct import parser -from tensorflow.python.platform import test - - -class ArgDefaultsTransformerTest(converter_testing.TestCase): - - def assertTransformedFirstLineIs(self, node, expected): - self.assertEqual( - parser.unparse(node, include_encoding_marker=False).split('\n')[0], - expected) - - def test_no_args(self): - - def test_fn(): - pass - - node, ctx = self.prepare(test_fn, {}) - node = arg_defaults.transform(node, ctx) - self.assertTransformedFirstLineIs(node, 'def test_fn():') - - def test_no_defaults(self): - - def test_fn(a, b, *c, **e): - return a, b, c, e - - node, ctx = self.prepare(test_fn, {}) - node = arg_defaults.transform(node, ctx) - self.assertTransformedFirstLineIs(node, 'def test_fn(a, b, *c, **e):') - - # TODO(mdan): Add kwonly-arg tests when PY2 is no longer supported. - - def test_arg_defaults(self): - - def test_fn(a, b=1, c=2): - return a, b, c - - node, ctx = self.prepare(test_fn, {}) - node = arg_defaults.transform(node, ctx) - self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, c=None):') - - def test_arg_defaults_with_vararg(self): - - def test_fn(a, b=1, *c): # pylint: disable=keyword-arg-before-vararg - return a, b, c - - node, ctx = self.prepare(test_fn, {}) - node = arg_defaults.transform(node, ctx) - self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, *c):') - - def test_arg_defaults_ignores_inner_lambda(self): - - def test_fn(): - return (lambda x=7: x)() - - node, ctx = self.prepare(test_fn, {}) - node = arg_defaults.transform(node, ctx) - with self.converted(test_fn, arg_defaults, {}) as result: - self.assertEqual(test_fn(), result.test_fn()) - - def test_arg_defaults_ignores_inner_function(self): - - def test_fn(): - def inner_fn(a=3): - return a - return inner_fn() - - node, ctx = self.prepare(test_fn, {}) - node = arg_defaults.transform(node, ctx) - with self.converted(test_fn, arg_defaults, {}) as result: - self.assertEqual(test_fn(), result.test_fn()) - - def test_arg_defaults_ignores_inner_function_returned(self): - - def test_fn(): - def inner_fn(a=3): - return a - return inner_fn - - node, ctx = self.prepare(test_fn, {}) - node = arg_defaults.transform(node, ctx) - with self.converted(test_fn, arg_defaults, {}) as result: - self.assertEqual(test_fn()(), result.test_fn()()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 54804fcef3d..6d59c4bc761 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -191,7 +191,7 @@ class CallTreeTransformer(converter.Base): return node if (full_name == 'print' and - not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): + not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS)): return node template = """ diff --git a/tensorflow/python/autograph/converters/functions.py b/tensorflow/python/autograph/converters/functions.py index c1003badc1d..5ddbb277d10 100644 --- a/tensorflow/python/autograph/converters/functions.py +++ b/tensorflow/python/autograph/converters/functions.py @@ -54,8 +54,8 @@ class FunctionTransformer(converter.Base): # ControlStatusCtx(autograph=ENABLED) when user_requested is True. See # function_wrappers.py. if fn_scope.level == 2: - return self.ctx.program.options - return self.ctx.program.options.call_options() + return self.ctx.user.options + return self.ctx.user.options.call_options() def visit_Lambda(self, node): with self.state[_Function] as fn_scope: diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py index 615dc21052f..92b0ca0718e 100644 --- a/tensorflow/python/autograph/converters/logical_expressions.py +++ b/tensorflow/python/autograph/converters/logical_expressions.py @@ -53,7 +53,7 @@ class LogicalExpressionTransformer(converter.Base): op_type = type(operator) if op_type in LOGICAL_OPERATORS: return LOGICAL_OPERATORS[op_type] - if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS): + if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS): if op_type in EQUALITY_OPERATORS: return EQUALITY_OPERATORS[op_type] return None @@ -83,7 +83,7 @@ class LogicalExpressionTransformer(converter.Base): def visit_Compare(self, node): node = self.generic_visit(node) - if (not self.ctx.program.options.uses( + if (not self.ctx.user.options.uses( converter.Feature.EQUALITY_OPERATORS)): return node diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index bc79d5fe506..77559fd2040 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -253,25 +253,6 @@ class ProgramContext( pass -class EntityContext(transformer.Context): - """Tracks the conversion of a single entity. - - This object is mutable, and is updated during conversion. Not thread safe. - - Attributes: - namer: Namer - info: transformer.EntityInfo - program: ProgramContext, - targe_name: Text - """ - - def __init__(self, namer, entity_info, program_ctx, target_name=None): - super(EntityContext, self).__init__(entity_info) - self.namer = namer - self.program = program_ctx - self.target_name = target_name - - class Base(transformer.Base): """All converters should inherit from this class. diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 8afcbdfb6bd..1b37fb4131c 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -168,12 +168,12 @@ class TestCase(test.TestCase): options=converter.ConversionOptions(recursive=recursive), autograph_module=None) entity_info = transformer.EntityInfo( + name=test_fn.__name__, source_code=source, source_file='<fragment>', future_features=future_features, namespace=namespace) - ctx = converter.EntityContext( - namer, entity_info, program_ctx, 'test_fn') + ctx = transformer.Context(entity_info, namer, program_ctx) origin_info.resolve_entity(node, source, test_fn) node = converter.standard_analysis(node, ctx, is_initial=True) return node, ctx diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 146d4b6ec2c..f444af77cbc 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -29,10 +29,7 @@ import sys import textwrap import traceback -# pylint:disable=g-bad-import-order - import six -# pylint:enable=g-bad-import-order from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import converter @@ -668,7 +665,7 @@ def to_graph(entity, recursive=True, experimental_optional_features=None): user_requested=True, optional_features=experimental_optional_features), autograph_module=tf_inspect.getmodule(to_graph)) - return conversion.convert(entity, program_ctx) + return autograph_artifact(conversion.convert(entity, program_ctx)) except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e: logging.error(1, 'Error converting %s', entity, exc_info=True) raise ConversionError('converting {}: {}: {}'.format( diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index d4706879b0a..fa90a4fa42c 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -18,21 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import functools import imp -import inspect -import sys -import threading -import types import unittest -import weakref - -import gast from tensorflow.python.autograph import operators from tensorflow.python.autograph import utils -from tensorflow.python.autograph.converters import arg_defaults from tensorflow.python.autograph.converters import asserts from tensorflow.python.autograph.converters import break_statements from tensorflow.python.autograph.converters import call_trees @@ -50,304 +41,70 @@ from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.core import function_wrappers from tensorflow.python.autograph.core import unsupported_features_checker from tensorflow.python.autograph.lang import special_functions -from tensorflow.python.autograph.pyct import ast_util +from tensorflow.python.autograph.pyct import cache from tensorflow.python.autograph.pyct import inspect_utils -from tensorflow.python.autograph.pyct import loader -from tensorflow.python.autograph.pyct import naming -from tensorflow.python.autograph.pyct import origin_info -from tensorflow.python.autograph.pyct import parser -from tensorflow.python.autograph.pyct import pretty_printer -from tensorflow.python.autograph.pyct import templates -from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct import transpiler from tensorflow.python.autograph.utils import ag_logging as logging from tensorflow.python.eager import function from tensorflow.python.util import tf_inspect -class _ConvertedEntityFactoryInfo( - collections.namedtuple( - '_ConvertedEntityFactoryInfo', - ('module_name', 'converted_name', 'factory_factory_name', 'source_map')) -): - """Holds metadata about a converted entity stored as a dynamic factory. +class AutoGraphTranspiler(transpiler.FunctionTranspiler): - The dynamic factory is assumed to be created by _wrap_into_dynamic_factory, - be named `factory_factory_name` and located inside the module named as - `module_name`. + def get_transformed_name(self, node): + return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node) - Attributes: - module_name: Text, the name of the module containing the entity. - converted_name: Text, the name of the converted entity. - factory_factory_name: Text, the name of the dynamic factory. - source_map: Dict. - """ + def transform_ast(self, node, ctx): + # TODO(mdan): Insert list_comprehensions somewhere. + unsupported_features_checker.verify(node) - def __str__(self): - return '_ConvertedEntityFactoryInfo({} in {})'.format( - self.converted_name, self.module_name) - - def get_module(self): - return sys.modules[self.module_name] - - def get_factory(self): - assert self.module_name in sys.modules - factory_factory = getattr(sys.modules[self.module_name], - self.factory_factory_name) - return factory_factory() + node = converter.standard_analysis(node, ctx, is_initial=True) + node = converter.apply_(node, ctx, functions) + node = converter.apply_(node, ctx, directives) + node = converter.apply_(node, ctx, break_statements) + if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS): + node = converter.apply_(node, ctx, asserts) + # Note: sequencing continue canonicalization before for loop one avoids + # dealing with the extra loop increment operation that the for + # canonicalization creates. + node = converter.apply_(node, ctx, continue_statements) + node = converter.apply_(node, ctx, return_statements) + if ctx.user.options.uses(converter.Feature.LISTS): + node = converter.apply_(node, ctx, lists) + node = converter.apply_(node, ctx, slices) + node = converter.apply_(node, ctx, call_trees) + node = converter.apply_(node, ctx, control_flow) + node = converter.apply_(node, ctx, conditional_expressions) + node = converter.apply_(node, ctx, logical_expressions) + return node -# TODO(mdan): Add a garbage collection hook for cleaning up modules. -class _FunctionCache(object): - """A hierarchical cache that uses the converted entity as weak key. - - The keys soft references (i.e. they are discarded when the key is - destroyed). The subkeys are normal hashable values. - - This class is generic - see the call site for how the keys and values are - defined. - """ - - __slots__ = ('_cache',) - - def __init__(self): - self._cache = weakref.WeakKeyDictionary() - - def _get_key(self, entity): - raise NotImplementedError('subclasses will override') - - def has(self, entity, subkey): - key = self._get_key(entity) - if key not in self._cache: - return False - return subkey in self._cache[key] - - def __getitem__(self, entity): - key = self._get_key(entity) - if key not in self._cache: - # The bucket needs to be initialized to support this usage: - # cache[key][subkey] = value - self._cache[key] = {} - return self._cache[key] - - def __len__(self): - return len(self._cache) +_TRANSPILER = AutoGraphTranspiler() +_WHITELIST_CACHE = cache.UnboundInstanceCache() -class _CodeObjectCache(_FunctionCache): - """A function cache based on code objects (i.e., the source code). - - Multiple functions may share the same code object, but they may share the - cache because we know they have the exact source code. This properly handles - functions defined in a loop, bound methods, etc. - - Falls back to the function object, if it doesn't have a code object. - """ - - def _get_key(self, entity): - if hasattr(entity, '__code__'): - return entity.__code__ - else: - return entity - - -class _UnboundInstanceCache(_FunctionCache): - """A function cache based on unbound function objects. - - Unlike the _CodeObjectCache, this discriminates between different functions - even if they have the same code. This properly handles decorators that may - masquerade as various functions. Bound functions are not discriminated by - the object they're bound to. - """ - - def _get_key(self, entity): - if inspect.ismethod(entity): - return entity.__func__ - return entity - - -# Using a re-entrant lock to guard against the unlikely possibility that the -# conversion process triggers additional code execution. -_CACHE_LOCK = threading.RLock() - - -_CACHE = _CodeObjectCache() -_WHITELIST_CACHE = _UnboundInstanceCache() - - -# Note: strictly speaking, a simple factory might have been sufficient for -# functions. But the double factory approach allows us to control the closure -# and globals of the converted code in a cleaner fashion. -# TODO(mdan): A simple factory may be sufficient. -def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name, - factory_name, closure_vars, future_features): - """Wraps an AST into the body of a dynamic factory. - - This uses the dynamic factory (factory of factory) pattern to achieve the - following: - - 1. The inner factory, dynamically creates the entity represented by nodes. - 2. The entity is parametrized by `ag__`, the internal AutoGraph module. - 3. The outer factory creates the inner factory with a lexical scope - in which `closure_vars` are bound local variables. This in turn allows the - caller to control the exact closure (i.e. non-global free variables) for - the inner factory. - - The AST is expected to define some symbol named by `entity_name`. - - Args: - nodes: ast.AST - entity_name: Union[Text, ast.AST] - factory_factory_name: Text - factory_name: Text - closure_vars: Iterable[Text] - future_features: Iterable[Text], see EntityInfo.future_features. - - Returns: - ast.AST - """ - if not isinstance(nodes, (list, tuple)): - nodes = (nodes,) - - dummy_closure_defs = [] - for var_name in closure_vars: - template = """ - var_name = None - """ - dummy_closure_defs.extend(templates.replace(template, var_name=var_name)) - - if future_features: - future_imports = gast.ImportFrom( - module='__future__', - names=[gast.alias(name=name, asname=None) for name in future_features], - level=0) - else: - future_imports = [] - - # These dummy symbol declarations create local fariables in a function scope, - # so that the Python parser correctly marks them as free non-global variables - # upon load (that is, it creates cell slots for each symbol). Their values are - # not used, as the cells are swapped with the original entity's cells after - # the code has been loaded. - template = """ - future_imports - def factory_factory_name(): - dummy_closure_defs - def factory_name(ag__, ag_source_map__, ag_module__): - entity_defs - entity_name.ag_source_map = ag_source_map__ - entity_name.ag_module = ag_module__ - entity_name = ag__.autograph_artifact(entity_name) - return entity_name - return factory_name - """ - return templates.replace( - template, - future_imports=future_imports, - factory_factory_name=factory_factory_name, - factory_name=factory_name, - dummy_closure_defs=dummy_closure_defs, - entity_defs=nodes, - entity_name=entity_name) - - -def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names): - """Returns a (possibly cached) factory for the converted result of entity.""" - # The cache subkey encompasses any conversion options on which the generated - # code may depend. - # The cached factory includes the necessary definitions to distinguish - # between the global and non-global free variables. For this reason, the - # cache subkey includes the names of the free non-globals. - subkey = (program_ctx.options, frozenset(free_nonglobal_var_names)) - - with _CACHE_LOCK: - # The cache values are _ConvertedEntityFactoryInfo objects. - if _CACHE.has(entity, subkey): - # TODO(mdan): Check whether the module is still loaded. - converted_entity_info = _CACHE[entity][subkey] - logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity, subkey, - converted_entity_info) - return converted_entity_info - - logging.log(1, 'Entity %s is not cached for subkey %s', entity, subkey) - - nodes, converted_name, entity_info = convert_entity_to_ast( - entity, program_ctx) - - namer = naming.Namer(entity_info.namespace) - factory_factory_name = namer.new_symbol('create_converted_entity_factory', - ()) - factory_name = namer.new_symbol('create_converted_entity', ()) - nodes = _wrap_into_dynamic_factory(nodes, converted_name, - factory_factory_name, factory_name, - free_nonglobal_var_names, - entity_info.future_features) - - module, _, source_map = loader.load_ast(nodes, include_source_map=True) - module_name = module.__name__ - - converted_entity_info = _ConvertedEntityFactoryInfo( - module_name=module_name, - converted_name=converted_name, - factory_factory_name=factory_factory_name, - source_map=source_map) - _CACHE[entity][subkey] = converted_entity_info - return converted_entity_info - - -def _instantiate(entity, converted_entity_info, free_nonglobal_var_names): - """Creates a converted instance and binds it to match original entity.""" - factory = converted_entity_info.get_factory() - - entity_globals = entity.__globals__ - entity_closure = entity.__closure__ or () - assert len(entity_closure) == len(free_nonglobal_var_names) - - # Fit the original entity's cells to match the order of factory's cells. - original_names_and_cells = dict(zip(free_nonglobal_var_names, entity_closure)) - new_factory_cells = tuple( - original_names_and_cells[name] for name in factory.__code__.co_freevars) - - bound_factory = types.FunctionType( - code=factory.__code__, - globals=entity_globals, - name=factory.__name__, - argdefs=(), - closure=new_factory_cells) - - # Two other free vars: the internal "ag__" module and the source - # map. These are wired via the parameters of the factory. - converted_entity = bound_factory( # pylint:disable=not-callable - ag_internal, converted_entity_info.source_map, - converted_entity_info.get_module()) - - # Attach the default argument to the converted function. - converted_entity.__defaults__ = entity.__defaults__ - if hasattr(entity, '__kwdefaults__'): - converted_entity.__kwdefaults__ = entity.__kwdefaults__ - - return converted_entity +custom_vars = None +# TODO(mdan): Superfluous function, remove. +# TODO(mdan): Put these extra fields inside __autograph_info__. def convert(entity, program_ctx): - """Converts an entity into an equivalent entity.""" + """Applies AutoGraph to entity.""" if not hasattr(entity, '__code__'): raise ValueError('Cannot apply autograph to a function that doesn\'t ' 'expose a __code__ object. If this is a @tf.function,' ' try passing f.python_function instead.') - free_nonglobal_var_names = entity.__code__.co_freevars - for i, name in enumerate(free_nonglobal_var_names): - if (name == 'ag__' and - entity.__closure__[i].cell_contents is not ag_internal): - raise ValueError('entity {} uses the reserved symbol "{}"'.format( - entity, name)) - # TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered. + _create_custom_vars(program_ctx) + transformed, module, source_map = _TRANSPILER.transform_function( + entity, program_ctx.options, program_ctx, custom_vars) - converted_entity_info = _convert_with_cache(entity, program_ctx, - free_nonglobal_var_names) - - return _instantiate(entity, converted_entity_info, free_nonglobal_var_names) + assert not hasattr(transformed, 'ag_module') + assert not hasattr(transformed, 'ag_source_map') + transformed.ag_module = module + transformed.ag_source_map = source_map + return transformed # TODO(mdan): allow_namedtuple_subclass should be hardcoded to True. @@ -472,58 +229,15 @@ def cache_whitelisted(entity, options): pass -# TODO(mdan): Rename to convert_*_node to avoid confusion with convert. -def convert_entity_to_ast(o, program_ctx): - """Compile a Python entity into equivalent TensorFlow. - - Args: - o: A Python entity. - program_ctx: A ProgramContext object. - - Returns: - A tuple (ast, new_name, namespace): - * ast: An AST representing an entity with interface equivalent to `o`, - but which when executed it creates TF a graph. - * new_name: The symbol name under which the new entity can be found. - * namespace: A dict mapping all symbols visible to the converted entity, - keyed by their symbol name. - - Raises: - NotImplementedError: if entity is of a type that is not yet supported. - """ - logging.log(1, 'Converting %s', o) - - nodes, name, entity_info = convert_func_to_ast(o, program_ctx) - - if logging.has_verbosity(2): - logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes)) - if logging.has_verbosity(4): - for n in nodes: - logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, - pretty_printer.fmt(n, color=False)) - - return nodes, name, entity_info - - -def _add_reserved_symbol(namespace, name, entity): - if name not in namespace: - namespace[name] = entity - elif namespace[name] != entity: - raise ValueError('The name "%s" is reserved and may not be used.' % name) - - -ag_internal = None - - # TODO(mdan): Move into core or replace with an actual importable module. -def _add_self_references(namespace, autograph_module): +def _create_custom_vars(program_ctx): """Adds namespace references to the module that exposes the api itself.""" - global ag_internal - if ag_internal is None: + global custom_vars + if custom_vars is None: # Craft a module that exposes parts of the external API as well as certain # internal modules. ag_internal = imp.new_module('autograph') - ag_internal.__dict__.update(autograph_module.__dict__) + ag_internal.__dict__.update(program_ctx.autograph_module.__dict__) ag_internal.ConversionOptions = converter.ConversionOptions ag_internal.STD = converter.STANDARD_OPTIONS ag_internal.Feature = converter.Feature @@ -536,102 +250,4 @@ def _add_self_references(namespace, autograph_module): ag_internal.__dict__.update(special_functions.__dict__) ag_internal.__dict__.update(operators.__dict__) - _add_reserved_symbol(namespace, 'ag__', ag_internal) - - -def convert_func_to_ast(f, program_ctx, do_rename=True): - """Specialization of `convert_entity_to_ast` for callable functions.""" - - future_features = inspect_utils.getfutureimports(f) - node, source = parser.parse_entity(f, future_features=future_features) - logging.log(3, 'Source code of %s:\n\n%s\n', f, source) - # Parsed AST should contain future imports and one function def node. - - # In general, the output of inspect.getsource is inexact for lambdas because - # it uses regex matching to adjust the exact location around the line number - # that CPython records. Then, the entire containing line is returned, which - # we may have trouble disambiguating. For example: - # x, y = lambda: 1, lambda: 2 - if f.__name__ == '<lambda>': - nodes = ast_util.find_matching_definitions(node, f) - if len(nodes) != 1: - raise ValueError( - 'Unable to identify source code of lambda function {}. It was' - ' defined on this line: {}, which must contain a single lambda with' - ' matching signature. To avoid ambiguity, define each lambda' - ' in a separate expression.'.format(f, source)) - node, = nodes - - # TODO(znado): Place inside standard_analysis. - origin_info.resolve_entity(node, source, f) - - namespace = inspect_utils.getnamespace(f) - _add_self_references(namespace, program_ctx.autograph_module) - namer = naming.Namer(namespace) - - if isinstance(node, gast.Lambda): - new_name = namer.new_symbol('tf__lambda', ()) - elif do_rename: - new_name = namer.new_symbol('tf__' + f.__name__, ()) - else: - new_name = f.__name__ - - entity_info = transformer.EntityInfo( - source_code=source, - source_file='<fragment>', - future_features=future_features, - namespace=namespace) - context = converter.EntityContext(namer, entity_info, program_ctx, new_name) - node = node_to_graph(node, context) - - if isinstance(node, gast.Lambda): - node = gast.Assign( - targets=[ - gast.Name( - new_name, ctx=gast.Store(), annotation=None, type_comment=None) - ], - value=node) - elif do_rename: - node.name = new_name - else: - assert node.name == new_name - - return (node,), new_name, entity_info - - -def node_to_graph(node, context): - """Convert Python code to equivalent TF graph mode code. - - Args: - node: AST, the code to convert. - context: converter.EntityContext - - Returns: - A tuple (node, deps): - * node: A Python ast node, representing the converted code. - * deps: A set of strings, the fully qualified names of entity - dependencies that this node has. - """ - # TODO(mdan): Insert list_comprehensions somewhere. - unsupported_features_checker.verify(node) - - node = converter.standard_analysis(node, context, is_initial=True) - node = converter.apply_(node, context, functions) - node = converter.apply_(node, context, arg_defaults) - node = converter.apply_(node, context, directives) - node = converter.apply_(node, context, break_statements) - if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS): - node = converter.apply_(node, context, asserts) - # Note: sequencing continue canonicalization before for loop one avoids - # dealing with the extra loop increment operation that the for - # canonicalization creates. - node = converter.apply_(node, context, continue_statements) - node = converter.apply_(node, context, return_statements) - if context.program.options.uses(converter.Feature.LISTS): - node = converter.apply_(node, context, lists) - node = converter.apply_(node, context, slices) - node = converter.apply_(node, context, call_trees) - node = converter.apply_(node, context, control_flow) - node = converter.apply_(node, context, conditional_expressions) - node = converter.apply_(node, context, logical_expressions) - return node + custom_vars = {'ag__': ag_internal} diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py index b0c1e45cc45..24d93e18b24 100644 --- a/tensorflow/python/autograph/impl/conversion_test.py +++ b/tensorflow/python/autograph/impl/conversion_test.py @@ -20,11 +20,9 @@ from __future__ import print_function import imp import sys -import threading import types import weakref -import gast import six from tensorflow.python.autograph import utils @@ -33,7 +31,6 @@ from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.impl import api from tensorflow.python.autograph.impl import conversion from tensorflow.python.autograph.impl.testing import pybind_for_testing -from tensorflow.python.autograph.pyct import parser from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.platform import test @@ -126,156 +123,6 @@ class ConversionTest(test.TestCase): # Note: currently, native bindings are whitelisted by a separate check. self.assertFalse(conversion.is_whitelisted(test_object.method)) - def test_convert_entity_to_ast_callable(self): - b = 2 - - def f(a): - return a + b - - program_ctx = self._simple_program_ctx() - nodes, name, info = conversion.convert_entity_to_ast(f, program_ctx) - fn_node, = nodes - self.assertIsInstance(fn_node, gast.FunctionDef) - self.assertEqual('tf__f', name) - self.assertIs(info.namespace['b'], b) - - def test_convert_entity_to_ast_function_with_defaults(self): - b = 2 - c = 1 - - def f(a, d=c + 1): - return a + b + d - - program_ctx = self._simple_program_ctx() - nodes, name, _ = conversion.convert_entity_to_ast(f, program_ctx) - fn_node, = nodes - self.assertIsInstance(fn_node, gast.FunctionDef) - self.assertEqual('tf__f', name) - self.assertEqual( - parser.unparse(fn_node.args.defaults[0], - include_encoding_marker=False).strip(), 'None') - - def test_convert_entity_to_ast_call_tree(self): - - def g(a): - return a - - def f(a): - return g(a) - - program_ctx = self._simple_program_ctx() - nodes, _, _ = conversion.convert_entity_to_ast(f, program_ctx) - f_node, = nodes - self.assertEqual('tf__f', f_node.name) - - def test_convert_entity_to_ast_lambda(self): - b = 2 - f = lambda x: b * x if x > 0 else -x - - program_ctx = self._simple_program_ctx() - (fn_node,), name, entity_info = conversion.convert_entity_to_ast( - f, program_ctx) - self.assertIsInstance(fn_node, gast.Assign) - self.assertIsInstance(fn_node.value, gast.Lambda) - self.assertEqual('tf__lambda', name) - self.assertIs(entity_info.namespace['b'], b) - - def test_convert_entity_to_ast_multiple_lambdas(self): - a, b = 1, 2 - f, _ = (lambda x: a * x, lambda y: b * y) - - program_ctx = self._simple_program_ctx() - (fn_node,), name, entity_info = conversion.convert_entity_to_ast( - f, program_ctx) - self.assertIsInstance(fn_node, gast.Assign) - self.assertIsInstance(fn_node.value, gast.Lambda) - self.assertEqual('tf__lambda', name) - self.assertIs(entity_info.namespace['a'], a) - - def test_convert_entity_to_ast_multiple_lambdas_ambiguous_definitions(self): - a, b = 1, 2 - f, _ = (lambda x: a * x, lambda x: b * x) - - program_ctx = self._simple_program_ctx() - with self.assertRaises(ValueError): - conversion.convert_entity_to_ast(f, program_ctx) - - def test_convert_entity_to_ast_lambda_code_with_garbage(self): - # pylint:disable=g-long-lambda - f = ( # intentional wrap - lambda x: ( - x # intentional wrap - + 1),)[0] - # pylint:enable=g-long-lambda - - program_ctx = self._simple_program_ctx() - (fn_node,), name, _ = conversion.convert_entity_to_ast(f, program_ctx) - self.assertIsInstance(fn_node, gast.Assign) - self.assertIsInstance(fn_node.value, gast.Lambda) - self.assertEqual('tf__lambda', name) - - def test_convert_entity_to_ast_nested_functions(self): - b = 2 - - def f(x): - - def g(x): - return b * x - - return g(x) - - program_ctx = self._simple_program_ctx() - (fn_node,), name, entity_info = conversion.convert_entity_to_ast( - f, program_ctx) - self.assertIsInstance(fn_node, gast.FunctionDef) - self.assertEqual(fn_node.name, 'tf__f') - self.assertEqual('tf__f', name) - self.assertIs(entity_info.namespace['b'], b) - - def test_convert_concurrency(self): - - def test_fn(): - pass - - generated_file_names = [] - - def conversion_thread(): - new_f = conversion.convert(test_fn, self._simple_program_ctx()) - generated_file_names.append(new_f.__code__.co_filename) - - threads = tuple( - threading.Thread(target=conversion_thread) for _ in range(10)) - for t in threads: - t.start() - for t in threads: - t.join() - - # Races would potentially create multiple files (non-deterministically, - # but with high likelihood). - self.assertEqual(len(set(generated_file_names)), 1) - - def test_convert_reentrance(self): - - def test_fn(): - pass - - # There are no known ways to cause convert to re-enter. So we instrument - # an internal function to do that instead. - old_node_to_graph = conversion.node_to_graph - self.num_conversions = 0 - def node_to_graph_wrapper(node, context): - self.num_conversions += 1 - if self.num_conversions < 2: - conversion.convert(test_fn, self._simple_program_ctx()) - return old_node_to_graph(node, context) - - try: - conversion.node_to_graph = node_to_graph_wrapper - new_f = conversion.convert(test_fn, self._simple_program_ctx()) - self.assertIsNotNone(new_f) - finally: - conversion.node_to_graph = old_node_to_graph - if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD index 735d504f18f..93bd9228f36 100644 --- a/tensorflow/python/autograph/pyct/BUILD +++ b/tensorflow/python/autograph/pyct/BUILD @@ -39,6 +39,7 @@ py_library( "qual_names.py", "templates.py", "transformer.py", + "transpiler.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], @@ -230,3 +231,14 @@ py_test( "@gast_archive//:gast", ], ) + +py_test( + name = "transpiler_test", + srcs = ["transpiler_test.py"], + python_version = "PY3", + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py index ced2ee3a975..d8d13fefb0f 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py @@ -50,8 +50,12 @@ class AnfTestBase(test.TestCase): def _simple_context(self): entity_info = transformer.EntityInfo( - source_code=None, source_file=None, future_features=(), namespace=None) - return transformer.Context(entity_info) + name='test_fn', + source_code=None, + source_file=None, + future_features=(), + namespace=None) + return transformer.Context(entity_info, None, None) def assert_same_ast(self, expected_node, node, msg=None): expected_source = parser.unparse(expected_node, indentation=' ') diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py index 1cf0e18097f..e4a93dbc91d 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py @@ -22,6 +22,7 @@ import gast import six from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import naming from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import transformer @@ -113,11 +114,17 @@ class ScopeTest(test.TestCase): class ActivityAnalyzerTestBase(test.TestCase): def _parse_and_analyze(self, test_fn): + # TODO(mdan): Use a custom FunctionTransformer here. node, source = parser.parse_entity(test_fn, future_features=()) entity_info = transformer.EntityInfo( - source_code=source, source_file=None, future_features=(), namespace={}) + name=test_fn.__name__, + source_code=source, + source_file=None, + future_features=(), + namespace={}) node = qual_names.resolve(node) - ctx = transformer.Context(entity_info) + namer = naming.Namer({}) + ctx = transformer.Context(entity_info, namer, None) node = activity.resolve(node, ctx) return node, entity_info diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py index b4bd1651b21..90bcc67301a 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import naming from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import transformer @@ -35,11 +36,17 @@ global_b = 17 class LivenessAnalyzerTestBase(test.TestCase): def _parse_and_analyze(self, test_fn): + # TODO(mdan): Use a custom FunctionTransformer here. node, source = parser.parse_entity(test_fn, future_features=()) entity_info = transformer.EntityInfo( - source_code=source, source_file=None, future_features=(), namespace={}) + name=test_fn.__name__, + source_code=source, + source_file=None, + future_features=(), + namespace={}) node = qual_names.resolve(node) - ctx = transformer.Context(entity_info) + namer = naming.Namer({}) + ctx = transformer.Context(entity_info, namer, None) node = activity.resolve(node, ctx) graphs = cfg.build(node) liveness.resolve(node, ctx, graphs) diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index 684e6e1e38f..8b00b5c00ee 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -22,6 +22,7 @@ import six from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import cfg +from tensorflow.python.autograph.pyct import naming from tensorflow.python.autograph.pyct import parser from tensorflow.python.autograph.pyct import qual_names from tensorflow.python.autograph.pyct import transformer @@ -37,11 +38,17 @@ global_b = 17 class ReachingDefinitionsAnalyzerTestBase(test.TestCase): def _parse_and_analyze(self, test_fn): + # TODO(mdan): Use a custom FunctionTransformer here. node, source = parser.parse_entity(test_fn, future_features=()) entity_info = transformer.EntityInfo( - source_code=source, source_file=None, future_features=(), namespace={}) + name=test_fn.__name__, + source_code=source, + source_file=None, + future_features=(), + namespace={}) node = qual_names.resolve(node) - ctx = transformer.Context(entity_info) + namer = naming.Namer({}) + ctx = transformer.Context(entity_info, namer, None) node = activity.resolve(node, ctx) graphs = cfg.build(node) node = reaching_definitions.resolve(node, ctx, graphs, diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py index 07dcde7fdc7..33700501cf8 100644 --- a/tensorflow/python/autograph/pyct/transformer.py +++ b/tensorflow/python/autograph/pyct/transformer.py @@ -36,20 +36,26 @@ class Context(object): Attributes: info: EntityInfo, immutable. + namer: naming.Namer. current_origin: origin_info.OriginInfo, holds the OriginInfo of the last AST node to be processed successfully. Useful for error handling. + user: An user-supplied context object. The object is opaque to the + infrastructure, but will pe passed through to all custom transformations. """ - def __init__(self, info): + def __init__(self, info, namer, user_context): self.info = info + self.namer = namer self.current_origin = None + self.user = user_context # TODO(mdan): Move to a standalone file. class EntityInfo( collections.namedtuple( 'EntityInfo', - ('source_code', 'source_file', 'future_features', 'namespace'))): + ('name', 'source_code', 'source_file', 'future_features', 'namespace')) +): """Contains information about a Python entity. Immutable. @@ -57,6 +63,7 @@ class EntityInfo( Examples of entities include functions and classes. Attributes: + name: The name that identifies this entity. source_code: The entity's source code. source_file: The entity's source file. future_features: Tuple[Text], the future features that this entity was diff --git a/tensorflow/python/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py index 4408395f813..30284ba5634 100644 --- a/tensorflow/python/autograph/pyct/transformer_test.py +++ b/tensorflow/python/autograph/pyct/transformer_test.py @@ -31,8 +31,12 @@ class TransformerTest(test.TestCase): def _simple_context(self): entity_info = transformer.EntityInfo( - source_code=None, source_file=None, future_features=(), namespace=None) - return transformer.Context(entity_info) + name='Test_fn', + source_code=None, + source_file=None, + future_features=(), + namespace=None) + return transformer.Context(entity_info, None, None) def assertSameAnno(self, first, second, key): self.assertIs(anno.getanno(first, key), anno.getanno(second, key)) @@ -299,8 +303,12 @@ class CodeGeneratorTest(test.TestCase): def _simple_context(self): entity_info = transformer.EntityInfo( - source_code=None, source_file=None, future_features=(), namespace=None) - return transformer.Context(entity_info) + name='test_fn', + source_code=None, + source_file=None, + future_features=(), + namespace=None) + return transformer.Context(entity_info, None, None) def test_basic_codegen(self): diff --git a/tensorflow/python/autograph/pyct/transpiler.py b/tensorflow/python/autograph/pyct/transpiler.py new file mode 100644 index 00000000000..a024685ba81 --- /dev/null +++ b/tensorflow/python/autograph/pyct/transpiler.py @@ -0,0 +1,419 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generic source code transformation infrastructure.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +import types + +import gast + +from tensorflow.python.autograph.pyct import ast_util +from tensorflow.python.autograph.pyct import cache +from tensorflow.python.autograph.pyct import inspect_utils +from tensorflow.python.autograph.pyct import loader +from tensorflow.python.autograph.pyct import naming +from tensorflow.python.autograph.pyct import origin_info +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.autograph.pyct import templates +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.utils import ag_logging as logging + + +def _wrap_into_factory(nodes, entity_name, inner_factory_name, + outer_factory_name, closure_vars, factory_args, + future_features): + """Wraps an AST into the body of a factory with consistent lexical context. + + The AST is expected to define some symbol with a name given by `entity_name`. + + This mechanism ensures that the resulting transformed entity has lexical + scoping identical to that of the source entity, while allowing extra + parametrization. + + Two nested factories achieve the following: + + 1. The inner factory dynamically creates the entity represented by `nodes`. + 2. The inner factory is parametrized by a custom set of arguments. + 3. The inner factory has a closure identical to that of the transformed + entity. + 4. The inner factory has local variables named like `args`, which `nodes` may + use as additional parameters. + 5. The inner factory returns the variables given by `entity_name`. + 6. The outer factory is niladic. + 7. The outer factory has no closure. + 8. The outer factory creates the necessary lexical scope for the inner + factory, so that the loaded code has the given configuration for + closure/globals. + 9. The outer factory returns the inner factory. + + Roughly speaking, the following code is generated: + + from __future__ import future_feature_1 + from __future__ import future_feature_2 + ... + + def outer_factory(): + closure_var_1 = None + closure_var_2 = None + ... + + def inner_factory(arg_1, arg_2, ...): + <<nodes>> + return entity + + return inner_factory + + The lexical scoping is created using dummy symbol declarations which create + local fariables in the body of the outer factory, so that the Python parser + correctly marks them as free non-global variables upon load (that is, it + creates cell slots for each symbol. Thes symbols are initialized with None, + but their values are not expected to be used; instead, the caller is expected + to replace them with the cells of the source entity. For more details, see: + https://docs.python.org/3/reference/executionmodel.html#binding-of-names + + Args: + nodes: Tuple[ast.AST], the source code to wrap. + entity_name: Union[Text, ast.AST], the name of the principal entity that + `nodes` define. + inner_factory_name: Text, the name of the inner factory. + outer_factory_name: Text, the name of the outer factory. + closure_vars: Iterable[Text], names of the closure variables for the inner + factory. + factory_args: Iterable[Text], names of additional arguments for the + inner factory. Useful to configure variables that the converted code can + use. Typically, these are modules. + future_features: Iterable[Text], names of future statements to associate the + code with. + + Returns: + ast.AST + """ + dummy_closure_defs = [] + for var_name in closure_vars: + template = """ + var_name = None + """ + dummy_closure_defs.extend(templates.replace(template, var_name=var_name)) + + if future_features: + future_imports = gast.ImportFrom( + module='__future__', + names=[gast.alias(name=name, asname=None) for name in future_features], + level=0) + else: + future_imports = [] + + factory_args = [ + gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None) + for name in factory_args + ] + + template = """ + future_imports + def outer_factory_name(): + dummy_closure_defs + def inner_factory_name(factory_args): + entity_defs + return entity_name + return inner_factory_name + """ + return templates.replace( + template, + dummy_closure_defs=dummy_closure_defs, + entity_defs=nodes, + entity_name=entity_name, + factory_args=factory_args, + future_imports=future_imports, + inner_factory_name=inner_factory_name, + outer_factory_name=outer_factory_name) + + +class _TransformedFnFactory(object): + """Helper object that wraps a transformed function factory.""" + + def __init__(self, name, freevars, extra_locals): + """Creates a new factory for a transformed function. + + Args: + name: The function name. + freevars: The list of non-global free variables for the function. + extra_locals: Dict[Text, Any], names and values for custom variables that + are accessible to the generated code as local variables. + """ + self._name = name + self._freevars = freevars + self._extra_locals = extra_locals + + self._unbound_factory = None + self.module = None + self.source_map = None + + def create(self, + nodes, + namer, + inner_factory_name='inner_factory', + outer_factory_name='outer_factory', + future_features=()): + """Initializes a transformed function.""" + if self._unbound_factory is not None: + raise ValueError('double initialization; create a new object instead') + + inner_factory_name = namer.new_symbol(inner_factory_name, ()) + outer_factory_name = namer.new_symbol(outer_factory_name, ()) + nodes = _wrap_into_factory(nodes, self._name, inner_factory_name, + outer_factory_name, self._freevars, + self._extra_locals.keys(), future_features) + + module, _, source_map = loader.load_ast( + nodes, include_source_map=True) + outer_factory = getattr(module, outer_factory_name) + self._unbound_factory = outer_factory() + self.module = module + self.source_map = source_map + + def instantiate(self, + globals_, + closure, + defaults=None, + kwdefaults=None): + """Creates a new instance of the transformed function.""" + if self._unbound_factory is None: + raise ValueError('call create first') + + factory_code = self._unbound_factory.__code__ + factory_freevars = factory_code.co_freevars + closure_map = dict(zip(self._freevars, closure)) + factory_closure = tuple( + closure_map[name] for name in factory_code.co_freevars) + if len(factory_closure) != len(closure): + raise ValueError( + 'closure mismatch, requested {}, but source function had {}'.format( + self._freevars, factory_freevars)) + + bound_factory = types.FunctionType( + code=factory_code, + globals=globals_, + name=self._name, + argdefs=(), + closure=factory_closure) + + # The lint override is a false positive. + transformed_entity = bound_factory(**self._extra_locals) # pylint:disable=not-callable + + if defaults: + transformed_entity.__defaults__ = defaults + if kwdefaults: + transformed_entity.__kwdefaults__ = kwdefaults + + return transformed_entity + + +class FunctionTranspiler(object): + """A generic source-to-source transpiler for Python functions. + + Its interface `transform_function` API offers a function-in, function-out + interface. Internally, it takes care of parsing, caching and variable binding. + + Users typically subclass this, customizing the transform_ast method. + + Usually, instances of this class are singletons, since each instance manages + its own cache. The caching subkey allows managing multiple types of + transformation. + + Example: + + class MyTransformer(FunctionTranspiler): + + def transform_ast(self, node, ctx): + node = <<transform node, usually using ast.NodeTransformer classes>> + return node + + transformer = MyTransfomer() + + new_f, module, source_map = transformer.transform_function(f, ...) + # new_f is a function with signature identical to f + + The transformed function has access to the same namespace as the original + function. To allow access to internal APIs, users may inject additional + symbols though the `extra_locals` argument of `transform_function`. + """ + + def __init__(self): + self._cache_lock = threading.RLock() + self._cache = cache.CodeObjectCache() + + def transform_ast(self, node, user_context): + """Performs an actual transformation of a function's AST. + + Subclasses must implement this method. They must not call it. + + The method receives the original AST and generates code according to the + AST that the method returns. For functions, the returned AST is expected to + contain a function with the exact same arguments and closure. The resulting + function will receive the globals, closure and argument defaults of the + input function. + + Args: + node: One or more ast.AST nodes representing the AST to be transformed. + user_context: The same value that the caller passed to + `transform_function`. + """ + raise NotImplementedError('subclasses must override this') + + def get_transformed_name(self, node): + """Returns a name for the output function. Subclasses may override this.""" + if isinstance(node, gast.Lambda): + return 'lam' + elif isinstance(node, gast.FunctionDef): + # Note that we need to rename the function, to avoid any namespace + # clashes. + return node.name + else: + raise ValueError('Unknown node type {}'.format(node)) + + def _erase_arg_defaults(self, node): + """Erase argde fault expressions, which would otherwise be unbound.""" + args = node.args + for i in range(len(args.defaults)): + args.defaults[i] = parser.parse_expression('None') + for i, d in enumerate(args.kw_defaults): + if d is not None: + args.kw_defaults[i] = parser.parse_expression('None') + return node + + def _transform_function(self, fn, user_context): + """Performs source code transformation on a function.""" + future_features = inspect_utils.getfutureimports(fn) + node, source = parser.parse_entity(fn, future_features=future_features) + logging.log(3, 'Source code of %s:\n\n%s\n', fn, source) + + # In general, the output of inspect.getsource is inexact for lambdas + # because it uses regex matching to adjust the exact location around + # the line number that CPython records. Then, the entire containing line + # is returned, which we may have trouble disambiguating. + # For example: + # x, y = lambda: 1, lambda: 2 + is_lambda = fn.__name__ == '<lambda>' + if is_lambda: + nodes = ast_util.find_matching_definitions(node, fn) + if len(nodes) != 1: + raise ValueError( + 'Unable to identify source code of lambda function {}.' + ' It was defined in this code:\n' + '{}\n' + 'This code must contain a single distinguishable lambda.' + ' To avoid this problem, define each lambda in a separate' + ' expression.'.format(fn, source)) + node, = nodes + + origin_info.resolve_entity(node, source, fn) + + namespace = inspect_utils.getnamespace(fn) + namer = naming.Namer(namespace) + new_name = namer.new_symbol(self.get_transformed_name(node), ()) + entity_info = transformer.EntityInfo( + name=new_name, + source_code=source, + source_file='<fragment>', + future_features=future_features, + namespace=namespace) + context = transformer.Context(entity_info, namer, user_context) + + node = self._erase_arg_defaults(node) + node = self.transform_ast(node, context) + + if is_lambda: + node = gast.Assign( + targets=[ + gast.Name( + new_name, + ctx=gast.Store(), + annotation=None, + type_comment=None) + ], + value=node) + else: + node.name = new_name + + return node, context + + def _cached_factory(self, fn, cache_subkey): + cached_factory = self._cache[fn][cache_subkey] + logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey, + cached_factory) + return cached_factory + + def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals): + """Returns the transformed function factory for a given input.""" + if self._cache.has(fn, cache_subkey): + return self._cached_factory(fn, cache_subkey) + + with self._cache_lock: + # Check again under lock. + if self._cache.has(fn, cache_subkey): + return self._cached_factory(fn, cache_subkey) + + logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey) + nodes, ctx = self._transform_function(fn, user_context) + + if logging.has_verbosity(2): + logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes)) + + factory = _TransformedFnFactory( + ctx.info.name, fn.__code__.co_freevars, extra_locals) + factory.create(nodes, ctx.namer, future_features=ctx.info.future_features) + self._cache[fn][cache_subkey] = factory + return factory + + def transform_function(self, fn, caching_subkey, user_context, extra_locals): + """Transforms a function. + + The `caching_subkey` argument allows mapping each function to multiple + outputs in the cache. This is useful for instance when transformers + can generate multiple variants of output code, typically as a result of + different transformation flags. + + Args: + fn: A function or lambda. + caching_subkey: Used for caching. Calls made for functions with the same + code object and caching_subkey will return a cached instance on + subsequent invocations. Using a constant will create unique per-function + entries. + user_context: An opaque object (may be none) that is forwarded to + transform_ast. + extra_locals: A Dict[Text, Any] containing additional variables to make + available to the transformed code. These will be visible as local + variables. + Returns: + A tuple: + * A function or lambda with the same signature and closure as `fn` + * The temporary module into which the transformed function was loaded + * The source map as a + Dict[origin_info.LineLocation, origin_info.OriginInfo] + + """ + factory = self._transformed_factory(fn, caching_subkey, user_context, + extra_locals) + + transformed_fn = factory.instantiate( + globals_=fn.__globals__, + closure=fn.__closure__ or (), + defaults=fn.__defaults__, + kwdefaults=getattr(fn, '__kwdefaults__', None)) + return transformed_fn, factory.module, factory.source_map diff --git a/tensorflow/python/autograph/pyct/transpiler_test.py b/tensorflow/python/autograph/pyct/transpiler_test.py new file mode 100644 index 00000000000..b1d44dff265 --- /dev/null +++ b/tensorflow/python/autograph/pyct/transpiler_test.py @@ -0,0 +1,249 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for transpiler module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +import gast + +from tensorflow.python.autograph.pyct import transformer +from tensorflow.python.autograph.pyct import transpiler +from tensorflow.python.platform import test + + +class FlipSignTransformer(transformer.Base): + + def visit_BinOp(self, node): + if isinstance(node.op, gast.Add): + node.op = gast.Sub() + return self.generic_visit(node) + + +class TestTranspiler(transpiler.FunctionTranspiler): + + def transform_ast(self, node, ctx): + return FlipSignTransformer(ctx).visit(node) + + +global_var_for_test_global = 1 +global_var_for_test_namespace_collisions = object() + + +class FunctionTranspilerTest(test.TestCase): + + def test_basic(self): + def f(a): + return a + 1 + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 0) + + def test_closure(self): + b = 1 + + def f(a): + return a + b + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 0) + b = 2 + self.assertEqual(f(1), -1) + + def test_global(self): + def f(a): + return a + global_var_for_test_global + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + global global_var_for_test_global + global_var_for_test_global = 1 + self.assertEqual(f(1), 0) + global_var_for_test_global = 2 + self.assertEqual(f(1), -1) + + def test_defaults(self): + b = 2 + c = 1 + + def f(a, d=c + 1): + return a + b + d + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 1 - 2 - 2) + c = 0 + self.assertEqual(f(1), 1 - 2 - 2) # Defaults are evaluated at definition. + b = 1 + self.assertEqual(f(1), 1 - 2 - 1) + + def test_call_tree(self): + + def g(a): + return a + 1 + + def f(a): + return g(a) + 1 + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 1 - 1 + 1) # Only f is converted. + + def test_lambda(self): + b = 2 + f = lambda x: (b + (x if x > 0 else -x)) + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 2 - 1) + self.assertEqual(f(-1), 2 - 1) + + b = 3 + + self.assertEqual(f(1), 3 - 1) + self.assertEqual(f(-1), 3 - 1) + + def test_multiple_lambdas(self): + a, b = 1, 2 + # This can be disambiguated by the argument names. + f, _ = (lambda x: a + x, lambda y: b * y) + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 1 - 1) + + def test_multiple_lambdas_indistinguishable_definitions(self): + a, b = 1, 2 + f, _ = (lambda x: a * x, lambda x: b * x) + + tr = TestTranspiler() + with self.assertRaises(ValueError): + tr.transform_function(f, object(), None, {}) + + def test_lambda_code_with_removable_garbage(self): + # pylint:disable=g-long-lambda + f = ( # intentional wrap + lambda x: ( + x # intentional wrap + + 1),)[0] + # pylint:enable=g-long-lambda + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 1 - 1) + + def test_nested_functions(self): + b = 2 + + def f(x): + + def g(x): + return b + x + + return g(x) + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 2 - 1) + + def test_nested_lambda(self): + b = 2 + + def f(x): + g = lambda x: b + x + return g(x) + + tr = TestTranspiler() + f, _, _ = tr.transform_function(f, object(), None, {}) + + self.assertEqual(f(1), 2 - 1) + + def test_concurrency(self): + + def f(): + pass + + outputs = [] + + tr = TestTranspiler() + cache_key = object() + def conversion_thread(): + _, mod, _ = tr.transform_function(f, cache_key, None, {}) + outputs.append(mod.__name__) + + threads = tuple( + threading.Thread(target=conversion_thread) for _ in range(10)) + for t in threads: + t.start() + for t in threads: + t.join() + + # Races would potentially create multiple functions / modules + # (non-deterministically, but with high likelihood). + self.assertEqual(len(set(outputs)), 1) + + def test_reentrance(self): + + def test_fn(): + return 1 + 1 + + class ReentrantTranspiler(transpiler.FunctionTranspiler): + + def __init__(self): + super(ReentrantTranspiler, self).__init__() + self._recursion_depth = 0 + + def transform_ast(self, node, ctx): + self._recursion_depth += 1 + if self._recursion_depth < 2: + self.transform_function(test_fn, object(), None, {}) + return FlipSignTransformer(ctx).visit(node) + + tr = ReentrantTranspiler() + + f, _, _ = tr.transform_function(test_fn, object(), None, {}) + self.assertEqual(f(), 0) + + def test_namespace_collisions_avoided(self): + + class TestClass(object): + + def global_var_for_test_namespace_collisions(self): + return global_var_for_test_namespace_collisions + + tr = TestTranspiler() + obj = TestClass() + + f, _, _ = tr.transform_function( + obj.global_var_for_test_namespace_collisions, object(), None, {}) + self.assertIs(f(obj), global_var_for_test_namespace_collisions) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/autograph/utils/misc.py b/tensorflow/python/autograph/utils/misc.py index 01c198e6278..984e4da70f2 100644 --- a/tensorflow/python/autograph/utils/misc.py +++ b/tensorflow/python/autograph/utils/misc.py @@ -52,13 +52,6 @@ def alias_tensors(*args): raise ValueError('at least one argument required') -def capitalize_initial(s): - """Capitalizes the initial of a string only.""" - if s: - return s[0].upper() + s[1:] - return s - - def get_range_len(start, limit, delta): dist = ops.convert_to_tensor(limit - start) unadjusted_len = dist // delta diff --git a/tensorflow/python/autograph/utils/misc_test.py b/tensorflow/python/autograph/utils/misc_test.py index 67c1b827228..8cbbb0e65b3 100644 --- a/tensorflow/python/autograph/utils/misc_test.py +++ b/tensorflow/python/autograph/utils/misc_test.py @@ -29,15 +29,6 @@ from tensorflow.python.platform import test class MiscTest(test.TestCase): - def test_capitalize_initial(self): - self.assertEqual('', misc.capitalize_initial('')) - self.assertEqual('A', misc.capitalize_initial('A')) - self.assertEqual('Ab', misc.capitalize_initial('Ab')) - self.assertEqual('AbC', misc.capitalize_initial('AbC')) - self.assertEqual('A', misc.capitalize_initial('a')) - self.assertEqual('Ab', misc.capitalize_initial('ab')) - self.assertEqual('AbC', misc.capitalize_initial('abC')) - @test_util.run_deprecated_v1 def test_alias_single_tensor(self): a = constant(1) diff --git a/tensorflow/python/eager/gradient_input_output_exclusions.py b/tensorflow/python/eager/gradient_input_output_exclusions.py index 2340ad41715..983f10551ba 100644 --- a/tensorflow/python/eager/gradient_input_output_exclusions.py +++ b/tensorflow/python/eager/gradient_input_output_exclusions.py @@ -198,11 +198,12 @@ def _live_tensors(f, attr_name="inputs"): """ node, _ = parser.parse_entity(f, ()) entity_info = transformer.EntityInfo( + name=f.__name__, source_code=None, source_file=None, future_features=(), namespace=sys.modules[f.__module__].__dict__) - ctx = transformer.Context(entity_info) + ctx = transformer.Context(entity_info, None, None) graphs = cfg.build(node) node = qual_names.resolve(node)