Internal cleanup: Move the bulk of the source code transformation infrastructure into the generic pyct module.

PiperOrigin-RevId: 305135067
Change-Id: Ifb84546c35a603942fd864769e7320a7ae95da3b
Dan Moldovan 2020-04-06 15:48:25 -07:00 committed by TensorFlower Gardener
parent fc94412b39
commit ff551c9f20
23 changed files with 790 additions and 870 deletions

View File

@ -19,7 +19,6 @@ filegroup(
py_library(
name = "converters",
srcs = [
"arg_defaults.py",
"asserts.py",
"break_statements.py",
"call_trees.py",
@ -48,18 +47,6 @@ py_library(
],
)
py_test(
name = "arg_defaults_test",
srcs = ["arg_defaults_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":converters",
"//tensorflow/python:client_testlib",
"//tensorflow/python/autograph/core:test_lib",
],
)
py_test(
name = "asserts_test",
srcs = ["asserts_test.py"],

View File

@ -1,105 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modifies the signature to allow resolving the value of default arguments.
Normally, function symbols are captured either in a function's globals or
closure. This is not true for default arguments, which are evaluated when the
function is defined:
b = 1
c = 2
def f(a=b + 1):
return a + c
In the above example, the namespace of the function would include `c = 2` but
not `b`.
If we were to naively generate a new function:
def new_f(a=b + 1):
return a + c
The generated code would fail to load unless we exposed a symbol `b`. Capturing
the closure of such an expression is difficult. However, we can capture the
default value of argument `a` with relative ease.
This converter replaces all default argument expressions with a constant so
that they don't cause loading to fail. This requires that the default values
are reset after loading the transformed function:
def new_f(a=None):
return a + c
# ... later, after new_f was loaded ...
new_f.__defaults__ = f.__defaults__
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import parser
class _Function(object):
pass
class ArgDefaultsTransformer(converter.Base):
"""Transforms top level argument defaults."""
def visit_Lambda(self, node):
self.state[_Function].enter()
node.args = self.visit(node.args)
# Only the top level function is modified - no need to visit the children.
self.state[_Function].exit()
return node
def visit_FunctionDef(self, node):
self.state[_Function].enter()
node.args = self.visit(node.args)
# Only the top level function is modified - no need to visit the children.
self.state[_Function].exit()
return node
def visit_arguments(self, node):
if self.state[_Function].level > 2:
return node
for i in range(len(node.defaults)):
node.defaults[i] = parser.parse_expression('None')
for i, d in enumerate(node.kw_defaults):
if d is not None:
node.kw_defaults[i] = parser.parse_expression('None')
# Only the top level function is modified - no need to visit the children.
return node
def transform(node, ctx):
"""Transform function call to the compiled counterparts.
Args:
node: AST
ctx: EntityContext
Returns:
A tuple (node, new_names):
node: The transformed AST
new_names: set(string), containing any newly-generated names
"""
return ArgDefaultsTransformer(ctx).visit(node)
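To make the module docstring above concrete, here is a minimal sketch of the erase-then-restore pattern the converter relies on (illustrative names, not part of this commit):

# Hedged illustration of the pattern described in the docstring above.
b = 1
c = 2

def f(a=b + 1):
  return a + c

# A naive regeneration of `f` would fail to load, because `b` is not part
# of the function's globals or closure. The converter therefore emits:
def new_f(a=None):
  return a + c

# ...and after `new_f` is loaded, the caller restores the original,
# already-evaluated default values:
new_f.__defaults__ = f.__defaults__
assert new_f() == f()  # both see a == 2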

View File

@ -1,108 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for arg_defaults module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.converters import arg_defaults
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.platform import test
class ArgDefaultsTransformerTest(converter_testing.TestCase):
def assertTransformedFirstLineIs(self, node, expected):
self.assertEqual(
parser.unparse(node, include_encoding_marker=False).split('\n')[0],
expected)
def test_no_args(self):
def test_fn():
pass
node, ctx = self.prepare(test_fn, {})
node = arg_defaults.transform(node, ctx)
self.assertTransformedFirstLineIs(node, 'def test_fn():')
def test_no_defaults(self):
def test_fn(a, b, *c, **e):
return a, b, c, e
node, ctx = self.prepare(test_fn, {})
node = arg_defaults.transform(node, ctx)
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b, *c, **e):')
# TODO(mdan): Add kwonly-arg tests when PY2 is no longer supported.
def test_arg_defaults(self):
def test_fn(a, b=1, c=2):
return a, b, c
node, ctx = self.prepare(test_fn, {})
node = arg_defaults.transform(node, ctx)
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, c=None):')
def test_arg_defaults_with_vararg(self):
def test_fn(a, b=1, *c): # pylint: disable=keyword-arg-before-vararg
return a, b, c
node, ctx = self.prepare(test_fn, {})
node = arg_defaults.transform(node, ctx)
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, *c):')
def test_arg_defaults_ignores_inner_lambda(self):
def test_fn():
return (lambda x=7: x)()
node, ctx = self.prepare(test_fn, {})
node = arg_defaults.transform(node, ctx)
with self.converted(test_fn, arg_defaults, {}) as result:
self.assertEqual(test_fn(), result.test_fn())
def test_arg_defaults_ignores_inner_function(self):
def test_fn():
def inner_fn(a=3):
return a
return inner_fn()
node, ctx = self.prepare(test_fn, {})
node = arg_defaults.transform(node, ctx)
with self.converted(test_fn, arg_defaults, {}) as result:
self.assertEqual(test_fn(), result.test_fn())
def test_arg_defaults_ignores_inner_function_returned(self):
def test_fn():
def inner_fn(a=3):
return a
return inner_fn
node, ctx = self.prepare(test_fn, {})
node = arg_defaults.transform(node, ctx)
with self.converted(test_fn, arg_defaults, {}) as result:
self.assertEqual(test_fn()(), result.test_fn()())
if __name__ == '__main__':
test.main()

View File

@ -191,7 +191,7 @@ class CallTreeTransformer(converter.Base):
return node
if (full_name == 'print' and
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
return node
template = """

View File

@ -54,8 +54,8 @@ class FunctionTransformer(converter.Base):
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
# function_wrappers.py.
if fn_scope.level == 2:
return self.ctx.program.options
return self.ctx.program.options.call_options()
return self.ctx.user.options
return self.ctx.user.options.call_options()
def visit_Lambda(self, node):
with self.state[_Function] as fn_scope:

View File

@ -53,7 +53,7 @@ class LogicalExpressionTransformer(converter.Base):
op_type = type(operator)
if op_type in LOGICAL_OPERATORS:
return LOGICAL_OPERATORS[op_type]
if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
if op_type in EQUALITY_OPERATORS:
return EQUALITY_OPERATORS[op_type]
return None
@ -83,7 +83,7 @@ class LogicalExpressionTransformer(converter.Base):
def visit_Compare(self, node):
node = self.generic_visit(node)
if (not self.ctx.program.options.uses(
if (not self.ctx.user.options.uses(
converter.Feature.EQUALITY_OPERATORS)):
return node

View File

@ -253,25 +253,6 @@ class ProgramContext(
pass
class EntityContext(transformer.Context):
"""Tracks the conversion of a single entity.
This object is mutable, and is updated during conversion. Not thread safe.
Attributes:
namer: Namer
info: transformer.EntityInfo
program: ProgramContext
target_name: Text
"""
def __init__(self, namer, entity_info, program_ctx, target_name=None):
super(EntityContext, self).__init__(entity_info)
self.namer = namer
self.program = program_ctx
self.target_name = target_name
class Base(transformer.Base):
"""All converters should inherit from this class.

View File

@ -168,12 +168,12 @@ class TestCase(test.TestCase):
options=converter.ConversionOptions(recursive=recursive),
autograph_module=None)
entity_info = transformer.EntityInfo(
name=test_fn.__name__,
source_code=source,
source_file='<fragment>',
future_features=future_features,
namespace=namespace)
ctx = converter.EntityContext(
namer, entity_info, program_ctx, 'test_fn')
ctx = transformer.Context(entity_info, namer, program_ctx)
origin_info.resolve_entity(node, source, test_fn)
node = converter.standard_analysis(node, ctx, is_initial=True)
return node, ctx

View File

@ -29,10 +29,7 @@ import sys
import textwrap
import traceback
# pylint:disable=g-bad-import-order
import six
# pylint:enable=g-bad-import-order
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
@ -668,7 +665,7 @@ def to_graph(entity, recursive=True, experimental_optional_features=None):
user_requested=True,
optional_features=experimental_optional_features),
autograph_module=tf_inspect.getmodule(to_graph))
return conversion.convert(entity, program_ctx)
return autograph_artifact(conversion.convert(entity, program_ctx))
except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
logging.error(1, 'Error converting %s', entity, exc_info=True)
raise ConversionError('converting {}: {}: {}'.format(

View File

@ -18,21 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
import imp
import inspect
import sys
import threading
import types
import unittest
import weakref
import gast
from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.converters import arg_defaults
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import call_trees
@ -50,304 +41,70 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.util import tf_inspect
class _ConvertedEntityFactoryInfo(
collections.namedtuple(
'_ConvertedEntityFactoryInfo',
('module_name', 'converted_name', 'factory_factory_name', 'source_map'))
):
"""Holds metadata about a converted entity stored as a dynamic factory.
class AutoGraphTranspiler(transpiler.FunctionTranspiler):
The dynamic factory is assumed to be created by _wrap_into_dynamic_factory,
be named `factory_factory_name` and located inside the module named as
`module_name`.
def get_transformed_name(self, node):
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
Attributes:
module_name: Text, the name of the module containing the entity.
converted_name: Text, the name of the converted entity.
factory_factory_name: Text, the name of the dynamic factory.
source_map: Dict.
"""
def transform_ast(self, node, ctx):
# TODO(mdan): Insert list_comprehensions somewhere.
unsupported_features_checker.verify(node)
def __str__(self):
return '_ConvertedEntityFactoryInfo({} in {})'.format(
self.converted_name, self.module_name)
def get_module(self):
return sys.modules[self.module_name]
def get_factory(self):
assert self.module_name in sys.modules
factory_factory = getattr(sys.modules[self.module_name],
self.factory_factory_name)
return factory_factory()
node = converter.standard_analysis(node, ctx, is_initial=True)
node = converter.apply_(node, ctx, functions)
node = converter.apply_(node, ctx, directives)
node = converter.apply_(node, ctx, break_statements)
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
node = converter.apply_(node, ctx, asserts)
# Note: sequencing continue canonicalization before for loop one avoids
# dealing with the extra loop increment operation that the for
# canonicalization creates.
node = converter.apply_(node, ctx, continue_statements)
node = converter.apply_(node, ctx, return_statements)
if ctx.user.options.uses(converter.Feature.LISTS):
node = converter.apply_(node, ctx, lists)
node = converter.apply_(node, ctx, slices)
node = converter.apply_(node, ctx, call_trees)
node = converter.apply_(node, ctx, control_flow)
node = converter.apply_(node, ctx, conditional_expressions)
node = converter.apply_(node, ctx, logical_expressions)
return node
# TODO(mdan): Add a garbage collection hook for cleaning up modules.
class _FunctionCache(object):
"""A hierarchical cache that uses the converted entity as weak key.
The keys are soft references (i.e. they are discarded when the key is
destroyed). The subkeys are normal hashable values.
This class is generic - see the call site for how the keys and values are
defined.
"""
__slots__ = ('_cache',)
def __init__(self):
self._cache = weakref.WeakKeyDictionary()
def _get_key(self, entity):
raise NotImplementedError('subclasses will override')
def has(self, entity, subkey):
key = self._get_key(entity)
if key not in self._cache:
return False
return subkey in self._cache[key]
def __getitem__(self, entity):
key = self._get_key(entity)
if key not in self._cache:
# The bucket needs to be initialized to support this usage:
# cache[key][subkey] = value
self._cache[key] = {}
return self._cache[key]
def __len__(self):
return len(self._cache)
_TRANSPILER = AutoGraphTranspiler()
_WHITELIST_CACHE = cache.UnboundInstanceCache()
class _CodeObjectCache(_FunctionCache):
"""A function cache based on code objects (i.e., the source code).
Multiple functions may share the same code object, but they may share the
cache because we know they have the exact same source code. This properly handles
functions defined in a loop, bound methods, etc.
Falls back to the function object, if it doesn't have a code object.
"""
def _get_key(self, entity):
if hasattr(entity, '__code__'):
return entity.__code__
else:
return entity
class _UnboundInstanceCache(_FunctionCache):
"""A function cache based on unbound function objects.
Unlike the _CodeObjectCache, this discriminates between different functions
even if they have the same code. This properly handles decorators that may
masquerade as various functions. Bound functions are not discriminated by
the object they're bound to.
"""
def _get_key(self, entity):
if inspect.ismethod(entity):
return entity.__func__
return entity
# Using a re-entrant lock to guard against the unlikely possibility that the
# conversion process triggers additional code execution.
_CACHE_LOCK = threading.RLock()
_CACHE = _CodeObjectCache()
_WHITELIST_CACHE = _UnboundInstanceCache()
# Note: strictly speaking, a simple factory might have been sufficient for
# functions. But the double factory approach allows us to control the closure
# and globals of the converted code in a cleaner fashion.
# TODO(mdan): A simple factory may be sufficient.
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
factory_name, closure_vars, future_features):
"""Wraps an AST into the body of a dynamic factory.
This uses the dynamic factory (factory of factory) pattern to achieve the
following:
1. The inner factory, dynamically creates the entity represented by nodes.
2. The entity is parametrized by `ag__`, the internal AutoGraph module.
3. The outer factory creates the inner factory with a lexical scope
in which `closure_vars` are bound local variables. This in turn allows the
caller to control the exact closure (i.e. non-global free variables) for
the inner factory.
The AST is expected to define some symbol named by `entity_name`.
Args:
nodes: ast.AST
entity_name: Union[Text, ast.AST]
factory_factory_name: Text
factory_name: Text
closure_vars: Iterable[Text]
future_features: Iterable[Text], see EntityInfo.future_features.
Returns:
ast.AST
"""
if not isinstance(nodes, (list, tuple)):
nodes = (nodes,)
dummy_closure_defs = []
for var_name in closure_vars:
template = """
var_name = None
"""
dummy_closure_defs.extend(templates.replace(template, var_name=var_name))
if future_features:
future_imports = gast.ImportFrom(
module='__future__',
names=[gast.alias(name=name, asname=None) for name in future_features],
level=0)
else:
future_imports = []
# These dummy symbol declarations create local variables in a function scope,
# so that the Python parser correctly marks them as free non-global variables
# upon load (that is, it creates cell slots for each symbol). Their values are
# not used, as the cells are swapped with the original entity's cells after
# the code has been loaded.
template = """
future_imports
def factory_factory_name():
dummy_closure_defs
def factory_name(ag__, ag_source_map__, ag_module__):
entity_defs
entity_name.ag_source_map = ag_source_map__
entity_name.ag_module = ag_module__
entity_name = ag__.autograph_artifact(entity_name)
return entity_name
return factory_name
"""
return templates.replace(
template,
future_imports=future_imports,
factory_factory_name=factory_factory_name,
factory_name=factory_name,
dummy_closure_defs=dummy_closure_defs,
entity_defs=nodes,
entity_name=entity_name)
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
"""Returns a (possibly cached) factory for the converted result of entity."""
# The cache subkey encompasses any conversion options on which the generated
# code may depend.
# The cached factory includes the necessary definitions to distinguish
# between the global and non-global free variables. For this reason, the
# cache subkey includes the names of the free non-globals.
subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))
with _CACHE_LOCK:
# The cache values are _ConvertedEntityFactoryInfo objects.
if _CACHE.has(entity, subkey):
# TODO(mdan): Check whether the module is still loaded.
converted_entity_info = _CACHE[entity][subkey]
logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity, subkey,
converted_entity_info)
return converted_entity_info
logging.log(1, 'Entity %s is not cached for subkey %s', entity, subkey)
nodes, converted_name, entity_info = convert_entity_to_ast(
entity, program_ctx)
namer = naming.Namer(entity_info.namespace)
factory_factory_name = namer.new_symbol('create_converted_entity_factory',
())
factory_name = namer.new_symbol('create_converted_entity', ())
nodes = _wrap_into_dynamic_factory(nodes, converted_name,
factory_factory_name, factory_name,
free_nonglobal_var_names,
entity_info.future_features)
module, _, source_map = loader.load_ast(nodes, include_source_map=True)
module_name = module.__name__
converted_entity_info = _ConvertedEntityFactoryInfo(
module_name=module_name,
converted_name=converted_name,
factory_factory_name=factory_factory_name,
source_map=source_map)
_CACHE[entity][subkey] = converted_entity_info
return converted_entity_info
def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
"""Creates a converted instance and binds it to match original entity."""
factory = converted_entity_info.get_factory()
entity_globals = entity.__globals__
entity_closure = entity.__closure__ or ()
assert len(entity_closure) == len(free_nonglobal_var_names)
# Fit the original entity's cells to match the order of factory's cells.
original_names_and_cells = dict(zip(free_nonglobal_var_names, entity_closure))
new_factory_cells = tuple(
original_names_and_cells[name] for name in factory.__code__.co_freevars)
bound_factory = types.FunctionType(
code=factory.__code__,
globals=entity_globals,
name=factory.__name__,
argdefs=(),
closure=new_factory_cells)
# Two other free vars: the internal "ag__" module and the source
# map. These are wired via the parameters of the factory.
converted_entity = bound_factory( # pylint:disable=not-callable
ag_internal, converted_entity_info.source_map,
converted_entity_info.get_module())
# Attach the default argument to the converted function.
converted_entity.__defaults__ = entity.__defaults__
if hasattr(entity, '__kwdefaults__'):
converted_entity.__kwdefaults__ = entity.__kwdefaults__
return converted_entity
custom_vars = None
# TODO(mdan): Superfluous function, remove.
# TODO(mdan): Put these extra fields inside __autograph_info__.
def convert(entity, program_ctx):
"""Converts an entity into an equivalent entity."""
"""Applies AutoGraph to entity."""
if not hasattr(entity, '__code__'):
raise ValueError('Cannot apply autograph to a function that doesn\'t '
'expose a __code__ object. If this is a @tf.function,'
' try passing f.python_function instead.')
free_nonglobal_var_names = entity.__code__.co_freevars
for i, name in enumerate(free_nonglobal_var_names):
if (name == 'ag__' and
entity.__closure__[i].cell_contents is not ag_internal):
raise ValueError('entity {} uses the reserved symbol "{}"'.format(
entity, name))
# TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered.
_create_custom_vars(program_ctx)
transformed, module, source_map = _TRANSPILER.transform_function(
entity, program_ctx.options, program_ctx, custom_vars)
converted_entity_info = _convert_with_cache(entity, program_ctx,
free_nonglobal_var_names)
return _instantiate(entity, converted_entity_info, free_nonglobal_var_names)
assert not hasattr(transformed, 'ag_module')
assert not hasattr(transformed, 'ag_source_map')
transformed.ag_module = module
transformed.ag_source_map = source_map
return transformed
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
@ -472,58 +229,15 @@ def cache_whitelisted(entity, options):
pass
# TODO(mdan): Rename to convert_*_node to avoid confusion with convert.
def convert_entity_to_ast(o, program_ctx):
"""Compile a Python entity into equivalent TensorFlow.
Args:
o: A Python entity.
program_ctx: A ProgramContext object.
Returns:
A tuple (ast, new_name, namespace):
* ast: An AST representing an entity with interface equivalent to `o`,
but which when executed creates a TF graph.
* new_name: The symbol name under which the new entity can be found.
* namespace: A dict mapping all symbols visible to the converted entity,
keyed by their symbol name.
Raises:
NotImplementedError: if entity is of a type that is not yet supported.
"""
logging.log(1, 'Converting %s', o)
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
if logging.has_verbosity(2):
logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
if logging.has_verbosity(4):
for n in nodes:
logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
pretty_printer.fmt(n, color=False))
return nodes, name, entity_info
def _add_reserved_symbol(namespace, name, entity):
if name not in namespace:
namespace[name] = entity
elif namespace[name] != entity:
raise ValueError('The name "%s" is reserved and may not be used.' % name)
ag_internal = None
# TODO(mdan): Move into core or replace with an actual importable module.
def _add_self_references(namespace, autograph_module):
def _create_custom_vars(program_ctx):
"""Adds namespace references to the module that exposes the api itself."""
global ag_internal
if ag_internal is None:
global custom_vars
if custom_vars is None:
# Craft a module that exposes parts of the external API as well as certain
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.__dict__.update(autograph_module.__dict__)
ag_internal.__dict__.update(program_ctx.autograph_module.__dict__)
ag_internal.ConversionOptions = converter.ConversionOptions
ag_internal.STD = converter.STANDARD_OPTIONS
ag_internal.Feature = converter.Feature
@ -536,102 +250,4 @@ def _add_self_references(namespace, autograph_module):
ag_internal.__dict__.update(special_functions.__dict__)
ag_internal.__dict__.update(operators.__dict__)
_add_reserved_symbol(namespace, 'ag__', ag_internal)
def convert_func_to_ast(f, program_ctx, do_rename=True):
"""Specialization of `convert_entity_to_ast` for callable functions."""
future_features = inspect_utils.getfutureimports(f)
node, source = parser.parse_entity(f, future_features=future_features)
logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
# Parsed AST should contain future imports and one function def node.
# In general, the output of inspect.getsource is inexact for lambdas because
# it uses regex matching to adjust the exact location around the line number
# that CPython records. Then, the entire containing line is returned, which
# we may have trouble disambiguating. For example:
# x, y = lambda: 1, lambda: 2
if f.__name__ == '<lambda>':
nodes = ast_util.find_matching_definitions(node, f)
if len(nodes) != 1:
raise ValueError(
'Unable to identify source code of lambda function {}. It was'
' defined on this line: {}, which must contain a single lambda with'
' matching signature. To avoid ambiguity, define each lambda'
' in a separate expression.'.format(f, source))
node, = nodes
# TODO(znado): Place inside standard_analysis.
origin_info.resolve_entity(node, source, f)
namespace = inspect_utils.getnamespace(f)
_add_self_references(namespace, program_ctx.autograph_module)
namer = naming.Namer(namespace)
if isinstance(node, gast.Lambda):
new_name = namer.new_symbol('tf__lambda', ())
elif do_rename:
new_name = namer.new_symbol('tf__' + f.__name__, ())
else:
new_name = f.__name__
entity_info = transformer.EntityInfo(
source_code=source,
source_file='<fragment>',
future_features=future_features,
namespace=namespace)
context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
node = node_to_graph(node, context)
if isinstance(node, gast.Lambda):
node = gast.Assign(
targets=[
gast.Name(
new_name, ctx=gast.Store(), annotation=None, type_comment=None)
],
value=node)
elif do_rename:
node.name = new_name
else:
assert node.name == new_name
return (node,), new_name, entity_info
def node_to_graph(node, context):
"""Convert Python code to equivalent TF graph mode code.
Args:
node: AST, the code to convert.
context: converter.EntityContext
Returns:
A Python ast node, representing the converted code.
"""
# TODO(mdan): Insert list_comprehensions somewhere.
unsupported_features_checker.verify(node)
node = converter.standard_analysis(node, context, is_initial=True)
node = converter.apply_(node, context, functions)
node = converter.apply_(node, context, arg_defaults)
node = converter.apply_(node, context, directives)
node = converter.apply_(node, context, break_statements)
if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
node = converter.apply_(node, context, asserts)
# Note: sequencing continue canonicalization before for loop one avoids
# dealing with the extra loop increment operation that the for
# canonicalization creates.
node = converter.apply_(node, context, continue_statements)
node = converter.apply_(node, context, return_statements)
if context.program.options.uses(converter.Feature.LISTS):
node = converter.apply_(node, context, lists)
node = converter.apply_(node, context, slices)
node = converter.apply_(node, context, call_trees)
node = converter.apply_(node, context, control_flow)
node = converter.apply_(node, context, conditional_expressions)
node = converter.apply_(node, context, logical_expressions)
return node
custom_vars = {'ag__': ag_internal}
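The net effect of this file's changes: `convert` no longer manages its own cache and factory machinery, and instead delegates to the shared transpiler. A simplified sketch of the new call path, condensed from the diff above (not a verbatim excerpt):

# Simplified sketch of the new conversion flow in this commit.
_TRANSPILER = AutoGraphTranspiler()          # subclasses FunctionTranspiler

def convert(entity, program_ctx):
  _create_custom_vars(program_ctx)           # lazily builds {'ag__': module}
  transformed, module, source_map = _TRANSPILER.transform_function(
      entity,                                # the function to convert
      program_ctx.options,                   # caching subkey
      program_ctx,                           # opaque user_context
      custom_vars)                           # extra locals, e.g. ag__
  transformed.ag_module = module
  transformed.ag_source_map = source_map
  return transformed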

View File

@ -20,11 +20,9 @@ from __future__ import print_function
import imp
import sys
import threading
import types
import weakref
import gast
import six
from tensorflow.python.autograph import utils
@ -33,7 +31,6 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.impl.testing import pybind_for_testing
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
@ -126,156 +123,6 @@ class ConversionTest(test.TestCase):
# Note: currently, native bindings are whitelisted by a separate check.
self.assertFalse(conversion.is_whitelisted(test_object.method))
def test_convert_entity_to_ast_callable(self):
b = 2
def f(a):
return a + b
program_ctx = self._simple_program_ctx()
nodes, name, info = conversion.convert_entity_to_ast(f, program_ctx)
fn_node, = nodes
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertIs(info.namespace['b'], b)
def test_convert_entity_to_ast_function_with_defaults(self):
b = 2
c = 1
def f(a, d=c + 1):
return a + b + d
program_ctx = self._simple_program_ctx()
nodes, name, _ = conversion.convert_entity_to_ast(f, program_ctx)
fn_node, = nodes
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertEqual(
parser.unparse(fn_node.args.defaults[0],
include_encoding_marker=False).strip(), 'None')
def test_convert_entity_to_ast_call_tree(self):
def g(a):
return a
def f(a):
return g(a)
program_ctx = self._simple_program_ctx()
nodes, _, _ = conversion.convert_entity_to_ast(f, program_ctx)
f_node, = nodes
self.assertEqual('tf__f', f_node.name)
def test_convert_entity_to_ast_lambda(self):
b = 2
f = lambda x: b * x if x > 0 else -x
program_ctx = self._simple_program_ctx()
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
f, program_ctx)
self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name)
self.assertIs(entity_info.namespace['b'], b)
def test_convert_entity_to_ast_multiple_lambdas(self):
a, b = 1, 2
f, _ = (lambda x: a * x, lambda y: b * y)
program_ctx = self._simple_program_ctx()
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
f, program_ctx)
self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name)
self.assertIs(entity_info.namespace['a'], a)
def test_convert_entity_to_ast_multiple_lambdas_ambiguous_definitions(self):
a, b = 1, 2
f, _ = (lambda x: a * x, lambda x: b * x)
program_ctx = self._simple_program_ctx()
with self.assertRaises(ValueError):
conversion.convert_entity_to_ast(f, program_ctx)
def test_convert_entity_to_ast_lambda_code_with_garbage(self):
# pylint:disable=g-long-lambda
f = ( # intentional wrap
lambda x: (
x # intentional wrap
+ 1),)[0]
# pylint:enable=g-long-lambda
program_ctx = self._simple_program_ctx()
(fn_node,), name, _ = conversion.convert_entity_to_ast(f, program_ctx)
self.assertIsInstance(fn_node, gast.Assign)
self.assertIsInstance(fn_node.value, gast.Lambda)
self.assertEqual('tf__lambda', name)
def test_convert_entity_to_ast_nested_functions(self):
b = 2
def f(x):
def g(x):
return b * x
return g(x)
program_ctx = self._simple_program_ctx()
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
f, program_ctx)
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual(fn_node.name, 'tf__f')
self.assertEqual('tf__f', name)
self.assertIs(entity_info.namespace['b'], b)
def test_convert_concurrency(self):
def test_fn():
pass
generated_file_names = []
def conversion_thread():
new_f = conversion.convert(test_fn, self._simple_program_ctx())
generated_file_names.append(new_f.__code__.co_filename)
threads = tuple(
threading.Thread(target=conversion_thread) for _ in range(10))
for t in threads:
t.start()
for t in threads:
t.join()
# Races would potentially create multiple files (non-deterministically,
# but with high likelihood).
self.assertEqual(len(set(generated_file_names)), 1)
def test_convert_reentrance(self):
def test_fn():
pass
# There are no known ways to cause convert to re-enter. So we instrument
# an internal function to do that instead.
old_node_to_graph = conversion.node_to_graph
self.num_conversions = 0
def node_to_graph_wrapper(node, context):
self.num_conversions += 1
if self.num_conversions < 2:
conversion.convert(test_fn, self._simple_program_ctx())
return old_node_to_graph(node, context)
try:
conversion.node_to_graph = node_to_graph_wrapper
new_f = conversion.convert(test_fn, self._simple_program_ctx())
self.assertIsNotNone(new_f)
finally:
conversion.node_to_graph = old_node_to_graph
if __name__ == '__main__':
test.main()

View File

@ -39,6 +39,7 @@ py_library(
"qual_names.py",
"templates.py",
"transformer.py",
"transpiler.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
@ -230,3 +231,14 @@ py_test(
"@gast_archive//:gast",
],
)
py_test(
name = "transpiler_test",
srcs = ["transpiler_test.py"],
python_version = "PY3",
srcs_version = "PY2AND3",
deps = [
":pyct",
"//tensorflow/python:client_testlib",
],
)

View File

@ -50,8 +50,12 @@ class AnfTestBase(test.TestCase):
def _simple_context(self):
entity_info = transformer.EntityInfo(
source_code=None, source_file=None, future_features=(), namespace=None)
return transformer.Context(entity_info)
name='test_fn',
source_code=None,
source_file=None,
future_features=(),
namespace=None)
return transformer.Context(entity_info, None, None)
def assert_same_ast(self, expected_node, node, msg=None):
expected_source = parser.unparse(expected_node, indentation=' ')

View File

@ -22,6 +22,7 @@ import gast
import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
@ -113,11 +114,17 @@ class ScopeTest(test.TestCase):
class ActivityAnalyzerTestBase(test.TestCase):
def _parse_and_analyze(self, test_fn):
# TODO(mdan): Use a custom FunctionTransformer here.
node, source = parser.parse_entity(test_fn, future_features=())
entity_info = transformer.EntityInfo(
source_code=source, source_file=None, future_features=(), namespace={})
name=test_fn.__name__,
source_code=source,
source_file=None,
future_features=(),
namespace={})
node = qual_names.resolve(node)
ctx = transformer.Context(entity_info)
namer = naming.Namer({})
ctx = transformer.Context(entity_info, namer, None)
node = activity.resolve(node, ctx)
return node, entity_info

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
@ -35,11 +36,17 @@ global_b = 17
class LivenessAnalyzerTestBase(test.TestCase):
def _parse_and_analyze(self, test_fn):
# TODO(mdan): Use a custom FunctionTransformer here.
node, source = parser.parse_entity(test_fn, future_features=())
entity_info = transformer.EntityInfo(
source_code=source, source_file=None, future_features=(), namespace={})
name=test_fn.__name__,
source_code=source,
source_file=None,
future_features=(),
namespace={})
node = qual_names.resolve(node)
ctx = transformer.Context(entity_info)
namer = naming.Namer({})
ctx = transformer.Context(entity_info, namer, None)
node = activity.resolve(node, ctx)
graphs = cfg.build(node)
liveness.resolve(node, ctx, graphs)

View File

@ -22,6 +22,7 @@ import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
@ -37,11 +38,17 @@ global_b = 17
class ReachingDefinitionsAnalyzerTestBase(test.TestCase):
def _parse_and_analyze(self, test_fn):
# TODO(mdan): Use a custom FunctionTransformer here.
node, source = parser.parse_entity(test_fn, future_features=())
entity_info = transformer.EntityInfo(
source_code=source, source_file=None, future_features=(), namespace={})
name=test_fn.__name__,
source_code=source,
source_file=None,
future_features=(),
namespace={})
node = qual_names.resolve(node)
ctx = transformer.Context(entity_info)
namer = naming.Namer({})
ctx = transformer.Context(entity_info, namer, None)
node = activity.resolve(node, ctx)
graphs = cfg.build(node)
node = reaching_definitions.resolve(node, ctx, graphs,

View File

@ -36,20 +36,26 @@ class Context(object):
Attributes:
info: EntityInfo, immutable.
namer: naming.Namer.
current_origin: origin_info.OriginInfo, holds the OriginInfo of the last
AST node to be processed successfully. Useful for error handling.
user: A user-supplied context object. The object is opaque to the
infrastructure, but will be passed through to all custom transformations.
"""
def __init__(self, info):
def __init__(self, info, namer, user_context):
self.info = info
self.namer = namer
self.current_origin = None
self.user = user_context
# TODO(mdan): Move to a standalone file.
class EntityInfo(
collections.namedtuple(
'EntityInfo',
('source_code', 'source_file', 'future_features', 'namespace'))):
('name', 'source_code', 'source_file', 'future_features', 'namespace'))
):
"""Contains information about a Python entity.
Immutable.
@ -57,6 +63,7 @@ class EntityInfo(
Examples of entities include functions and classes.
Attributes:
name: The name that identifies this entity.
source_code: The entity's source code.
source_file: The entity's source file.
future_features: Tuple[Text], the future features that this entity was
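All call sites updated in this commit follow the same pattern; a minimal sketch of constructing the extended `Context` (the names `source` and `my_fn` are illustrative):

entity_info = transformer.EntityInfo(
    name='my_fn',                 # new required field in this commit
    source_code=source,
    source_file='<fragment>',
    future_features=(),
    namespace={})
namer = naming.Namer({})
# Context now takes a Namer and an opaque user context in addition to the
# EntityInfo; passing None for both preserves the old two-argument behavior.
ctx = transformer.Context(entity_info, namer, None)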

View File

@ -31,8 +31,12 @@ class TransformerTest(test.TestCase):
def _simple_context(self):
entity_info = transformer.EntityInfo(
source_code=None, source_file=None, future_features=(), namespace=None)
return transformer.Context(entity_info)
name='test_fn',
source_code=None,
source_file=None,
future_features=(),
namespace=None)
return transformer.Context(entity_info, None, None)
def assertSameAnno(self, first, second, key):
self.assertIs(anno.getanno(first, key), anno.getanno(second, key))
@ -299,8 +303,12 @@ class CodeGeneratorTest(test.TestCase):
def _simple_context(self):
entity_info = transformer.EntityInfo(
source_code=None, source_file=None, future_features=(), namespace=None)
return transformer.Context(entity_info)
name='test_fn',
source_code=None,
source_file=None,
future_features=(),
namespace=None)
return transformer.Context(entity_info, None, None)
def test_basic_codegen(self):

View File

@ -0,0 +1,419 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic source code transformation infrastructure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import types
import gast
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.utils import ag_logging as logging
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
outer_factory_name, closure_vars, factory_args,
future_features):
"""Wraps an AST into the body of a factory with consistent lexical context.
The AST is expected to define some symbol with a name given by `entity_name`.
This mechanism ensures that the resulting transformed entity has lexical
scoping identical to that of the source entity, while allowing extra
parametrization.
Two nested factories achieve the following:
1. The inner factory dynamically creates the entity represented by `nodes`.
2. The inner factory is parametrized by a custom set of arguments.
3. The inner factory has a closure identical to that of the transformed
entity.
4. The inner factory has local variables named after `factory_args`, which
`nodes` may use as additional parameters.
5. The inner factory returns the variables given by `entity_name`.
6. The outer factory is niladic.
7. The outer factory has no closure.
8. The outer factory creates the necessary lexical scope for the inner
factory, so that the loaded code has the given configuration for
closure/globals.
9. The outer factory returns the inner factory.
Roughly speaking, the following code is generated:
from __future__ import future_feature_1
from __future__ import future_feature_2
...
def outer_factory():
closure_var_1 = None
closure_var_2 = None
...
def inner_factory(arg_1, arg_2, ...):
<<nodes>>
return entity
return inner_factory
The lexical scoping is created using dummy symbol declarations which create
local variables in the body of the outer factory, so that the Python parser
correctly marks them as free non-global variables upon load (that is, it
creates cell slots for each symbol). These symbols are initialized with None,
but their values are not expected to be used; instead, the caller is expected
to replace them with the cells of the source entity. For more details, see:
https://docs.python.org/3/reference/executionmodel.html#binding-of-names
Args:
nodes: Tuple[ast.AST], the source code to wrap.
entity_name: Union[Text, ast.AST], the name of the principal entity that
`nodes` define.
inner_factory_name: Text, the name of the inner factory.
outer_factory_name: Text, the name of the outer factory.
closure_vars: Iterable[Text], names of the closure variables for the inner
factory.
factory_args: Iterable[Text], names of additional arguments for the
inner factory. Useful to configure variables that the converted code can
use. Typically, these are modules.
future_features: Iterable[Text], names of future statements to associate the
code with.
Returns:
ast.AST
"""
dummy_closure_defs = []
for var_name in closure_vars:
template = """
var_name = None
"""
dummy_closure_defs.extend(templates.replace(template, var_name=var_name))
if future_features:
future_imports = gast.ImportFrom(
module='__future__',
names=[gast.alias(name=name, asname=None) for name in future_features],
level=0)
else:
future_imports = []
factory_args = [
gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None)
for name in factory_args
]
template = """
future_imports
def outer_factory_name():
dummy_closure_defs
def inner_factory_name(factory_args):
entity_defs
return entity_name
return inner_factory_name
"""
return templates.replace(
template,
dummy_closure_defs=dummy_closure_defs,
entity_defs=nodes,
entity_name=entity_name,
factory_args=factory_args,
future_imports=future_imports,
inner_factory_name=inner_factory_name,
outer_factory_name=outer_factory_name)
class _TransformedFnFactory(object):
"""Helper object that wraps a transformed function factory."""
def __init__(self, name, freevars, extra_locals):
"""Creates a new factory for a transformed function.
Args:
name: The function name.
freevars: The list of non-global free variables for the function.
extra_locals: Dict[Text, Any], names and values for custom variables that
are accessible to the generated code as local variables.
"""
self._name = name
self._freevars = freevars
self._extra_locals = extra_locals
self._unbound_factory = None
self.module = None
self.source_map = None
def create(self,
nodes,
namer,
inner_factory_name='inner_factory',
outer_factory_name='outer_factory',
future_features=()):
"""Initializes a transformed function."""
if self._unbound_factory is not None:
raise ValueError('double initialization; create a new object instead')
inner_factory_name = namer.new_symbol(inner_factory_name, ())
outer_factory_name = namer.new_symbol(outer_factory_name, ())
nodes = _wrap_into_factory(nodes, self._name, inner_factory_name,
outer_factory_name, self._freevars,
self._extra_locals.keys(), future_features)
module, _, source_map = loader.load_ast(
nodes, include_source_map=True)
outer_factory = getattr(module, outer_factory_name)
self._unbound_factory = outer_factory()
self.module = module
self.source_map = source_map
def instantiate(self,
globals_,
closure,
defaults=None,
kwdefaults=None):
"""Creates a new instance of the transformed function."""
if self._unbound_factory is None:
raise ValueError('call create first')
factory_code = self._unbound_factory.__code__
factory_freevars = factory_code.co_freevars
closure_map = dict(zip(self._freevars, closure))
factory_closure = tuple(
closure_map[name] for name in factory_code.co_freevars)
if len(factory_closure) != len(closure):
raise ValueError(
'closure mismatch, requested {}, but source function had {}'.format(
self._freevars, factory_freevars))
bound_factory = types.FunctionType(
code=factory_code,
globals=globals_,
name=self._name,
argdefs=(),
closure=factory_closure)
# The lint override is a false positive.
transformed_entity = bound_factory(**self._extra_locals) # pylint:disable=not-callable
if defaults:
transformed_entity.__defaults__ = defaults
if kwdefaults:
transformed_entity.__kwdefaults__ = kwdefaults
return transformed_entity
class FunctionTranspiler(object):
"""A generic source-to-source transpiler for Python functions.
Its `transform_function` method offers a function-in, function-out interface.
Internally, it takes care of parsing, caching, and variable binding.
Users typically subclass this, customizing the transform_ast method.
Usually, instances of this class are singletons, since each instance manages
its own cache. The caching subkey allows managing multiple types of
transformation.
Example:
class MyTransformer(FunctionTranspiler):
def transform_ast(self, node, ctx):
node = <<transform node, usually using ast.NodeTransformer classes>>
return node
transformer = MyTransformer()
new_f, module, source_map = transformer.transform_function(f, ...)
# new_f is a function with signature identical to f
The transformed function has access to the same namespace as the original
function. To allow access to internal APIs, users may inject additional
symbols through the `extra_locals` argument of `transform_function`.
"""
def __init__(self):
self._cache_lock = threading.RLock()
self._cache = cache.CodeObjectCache()
def transform_ast(self, node, user_context):
"""Performs an actual transformation of a function's AST.
Subclasses must implement this method. They must not call it.
The method receives the original AST and generates code according to the
AST that the method returns. For functions, the returned AST is expected to
contain a function with the exact same arguments and closure. The resulting
function will receive the globals, closure and argument defaults of the
input function.
Args:
node: One or more ast.AST nodes representing the AST to be transformed.
user_context: The same value that the caller passed to
`transform_function`.
"""
raise NotImplementedError('subclasses must override this')
def get_transformed_name(self, node):
"""Returns a name for the output function. Subclasses may override this."""
if isinstance(node, gast.Lambda):
return 'lam'
elif isinstance(node, gast.FunctionDef):
# Note that we need to rename the function, to avoid any namespace
# clashes.
return node.name
else:
raise ValueError('Unknown node type {}'.format(node))
def _erase_arg_defaults(self, node):
"""Erase argde fault expressions, which would otherwise be unbound."""
args = node.args
for i in range(len(args.defaults)):
args.defaults[i] = parser.parse_expression('None')
for i, d in enumerate(args.kw_defaults):
if d is not None:
args.kw_defaults[i] = parser.parse_expression('None')
return node
def _transform_function(self, fn, user_context):
"""Performs source code transformation on a function."""
future_features = inspect_utils.getfutureimports(fn)
node, source = parser.parse_entity(fn, future_features=future_features)
logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)
# In general, the output of inspect.getsource is inexact for lambdas
# because it uses regex matching to adjust the exact location around
# the line number that CPython records. Then, the entire containing line
# is returned, which we may have trouble disambiguating.
# For example:
# x, y = lambda: 1, lambda: 2
is_lambda = fn.__name__ == '<lambda>'
if is_lambda:
nodes = ast_util.find_matching_definitions(node, fn)
if len(nodes) != 1:
raise ValueError(
'Unable to identify source code of lambda function {}.'
' It was defined in this code:\n'
'{}\n'
'This code must contain a single distinguishable lambda.'
' To avoid this problem, define each lambda in a separate'
' expression.'.format(fn, source))
node, = nodes
origin_info.resolve_entity(node, source, fn)
namespace = inspect_utils.getnamespace(fn)
namer = naming.Namer(namespace)
new_name = namer.new_symbol(self.get_transformed_name(node), ())
entity_info = transformer.EntityInfo(
name=new_name,
source_code=source,
source_file='<fragment>',
future_features=future_features,
namespace=namespace)
context = transformer.Context(entity_info, namer, user_context)
node = self._erase_arg_defaults(node)
node = self.transform_ast(node, context)
if is_lambda:
node = gast.Assign(
targets=[
gast.Name(
new_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=node)
else:
node.name = new_name
return node, context
def _cached_factory(self, fn, cache_subkey):
cached_factory = self._cache[fn][cache_subkey]
logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
cached_factory)
return cached_factory
def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals):
"""Returns the transformed function factory for a given input."""
if self._cache.has(fn, cache_subkey):
return self._cached_factory(fn, cache_subkey)
with self._cache_lock:
# Check again under lock.
if self._cache.has(fn, cache_subkey):
return self._cached_factory(fn, cache_subkey)
logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
nodes, ctx = self._transform_function(fn, user_context)
if logging.has_verbosity(2):
logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))
factory = _TransformedFnFactory(
ctx.info.name, fn.__code__.co_freevars, extra_locals)
factory.create(nodes, ctx.namer, future_features=ctx.info.future_features)
self._cache[fn][cache_subkey] = factory
return factory
def transform_function(self, fn, caching_subkey, user_context, extra_locals):
"""Transforms a function.
The `caching_subkey` argument allows mapping each function to multiple
outputs in the cache. This is useful for instance when transformers
can generate multiple variants of output code, typically as a result of
different transformation flags.
Args:
fn: A function or lambda.
caching_subkey: Used for caching. Calls made for functions with the same
code object and caching_subkey will return a cached instance on
subsequent invocations. Using a constant will create unique per-function
entries.
user_context: An opaque object (may be None) that is forwarded to
transform_ast.
extra_locals: A Dict[Text, Any] containing additional variables to make
available to the transformed code. These will be visible as local
variables.
Returns:
A tuple:
* A function or lambda with the same signature and closure as `fn`
* The temporary module into which the transformed function was loaded
* The source map as a
Dict[origin_info.LineLocation, origin_info.OriginInfo]
"""
factory = self._transformed_factory(fn, caching_subkey, user_context,
extra_locals)
transformed_fn = factory.instantiate(
globals_=fn.__globals__,
closure=fn.__closure__ or (),
defaults=fn.__defaults__,
kwdefaults=getattr(fn, '__kwdefaults__', None))
return transformed_fn, factory.module, factory.source_map
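The subtle step in `instantiate` above is the cell swap: the dummy closure variables created by the outer factory are replaced with the original function's cells via `types.FunctionType`. A standalone sketch of that mechanism, independent of this module (illustrative code, not part of the commit):

import types

def original():
  x = 10
  def inner():
    return x + 1
  return inner

fn = original()
# Rebuild `fn` with the same code object but an explicitly supplied closure.
# `instantiate` does the same thing, except the closure cells come from the
# source function while the code object comes from the transformed factory.
clone = types.FunctionType(
    code=fn.__code__,
    globals=fn.__globals__,
    name=fn.__name__,
    argdefs=fn.__defaults__,
    closure=fn.__closure__)   # reuse the original cells
assert clone() == fn() == 11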

View File

@ -0,0 +1,249 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transpiler module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import gast
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.platform import test
class FlipSignTransformer(transformer.Base):
def visit_BinOp(self, node):
if isinstance(node.op, gast.Add):
node.op = gast.Sub()
return self.generic_visit(node)
class TestTranspiler(transpiler.FunctionTranspiler):
def transform_ast(self, node, ctx):
return FlipSignTransformer(ctx).visit(node)
global_var_for_test_global = 1
global_var_for_test_namespace_collisions = object()
class FunctionTranspilerTest(test.TestCase):
def test_basic(self):
def f(a):
return a + 1
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 0)
def test_closure(self):
b = 1
def f(a):
return a + b
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 0)
b = 2
self.assertEqual(f(1), -1)
def test_global(self):
def f(a):
return a + global_var_for_test_global
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
global global_var_for_test_global
global_var_for_test_global = 1
self.assertEqual(f(1), 0)
global_var_for_test_global = 2
self.assertEqual(f(1), -1)
def test_defaults(self):
b = 2
c = 1
def f(a, d=c + 1):
return a + b + d
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 1 - 2 - 2)
c = 0
self.assertEqual(f(1), 1 - 2 - 2) # Defaults are evaluated at definition.
b = 1
self.assertEqual(f(1), 1 - 2 - 1)
def test_call_tree(self):
def g(a):
return a + 1
def f(a):
return g(a) + 1
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 1 - 1 + 1) # Only f is converted.
def test_lambda(self):
b = 2
f = lambda x: (b + (x if x > 0 else -x))
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 2 - 1)
self.assertEqual(f(-1), 2 - 1)
b = 3
self.assertEqual(f(1), 3 - 1)
self.assertEqual(f(-1), 3 - 1)
def test_multiple_lambdas(self):
a, b = 1, 2
# This can be disambiguated by the argument names.
f, _ = (lambda x: a + x, lambda y: b * y)
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 1 - 1)
def test_multiple_lambdas_indistinguishable_definitions(self):
a, b = 1, 2
f, _ = (lambda x: a * x, lambda x: b * x)
tr = TestTranspiler()
with self.assertRaises(ValueError):
tr.transform_function(f, object(), None, {})
def test_lambda_code_with_removable_garbage(self):
# pylint:disable=g-long-lambda
f = ( # intentional wrap
lambda x: (
x # intentional wrap
+ 1),)[0]
# pylint:enable=g-long-lambda
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 1 - 1)
def test_nested_functions(self):
b = 2
def f(x):
def g(x):
return b + x
return g(x)
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 2 - 1)
def test_nested_lambda(self):
b = 2
def f(x):
g = lambda x: b + x
return g(x)
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
self.assertEqual(f(1), 2 - 1)
def test_concurrency(self):
def f():
pass
outputs = []
tr = TestTranspiler()
cache_key = object()
def conversion_thread():
_, mod, _ = tr.transform_function(f, cache_key, None, {})
outputs.append(mod.__name__)
threads = tuple(
threading.Thread(target=conversion_thread) for _ in range(10))
for t in threads:
t.start()
for t in threads:
t.join()
# Races would potentially create multiple functions / modules
# (non-deterministically, but with high likelihood).
self.assertEqual(len(set(outputs)), 1)
def test_reentrance(self):
def test_fn():
return 1 + 1
class ReentrantTranspiler(transpiler.FunctionTranspiler):
def __init__(self):
super(ReentrantTranspiler, self).__init__()
self._recursion_depth = 0
def transform_ast(self, node, ctx):
self._recursion_depth += 1
if self._recursion_depth < 2:
self.transform_function(test_fn, object(), None, {})
return FlipSignTransformer(ctx).visit(node)
tr = ReentrantTranspiler()
f, _, _ = tr.transform_function(test_fn, object(), None, {})
self.assertEqual(f(), 0)
def test_namespace_collisions_avoided(self):
class TestClass(object):
def global_var_for_test_namespace_collisions(self):
return global_var_for_test_namespace_collisions
tr = TestTranspiler()
obj = TestClass()
f, _, _ = tr.transform_function(
obj.global_var_for_test_namespace_collisions, object(), None, {})
self.assertIs(f(obj), global_var_for_test_namespace_collisions)
if __name__ == '__main__':
test.main()

View File

@ -52,13 +52,6 @@ def alias_tensors(*args):
raise ValueError('at least one argument required')
def capitalize_initial(s):
"""Capitalizes the initial of a string only."""
if s:
return s[0].upper() + s[1:]
return s
def get_range_len(start, limit, delta):
dist = ops.convert_to_tensor(limit - start)
unadjusted_len = dist // delta

View File

@ -29,15 +29,6 @@ from tensorflow.python.platform import test
class MiscTest(test.TestCase):
def test_capitalize_initial(self):
self.assertEqual('', misc.capitalize_initial(''))
self.assertEqual('A', misc.capitalize_initial('A'))
self.assertEqual('Ab', misc.capitalize_initial('Ab'))
self.assertEqual('AbC', misc.capitalize_initial('AbC'))
self.assertEqual('A', misc.capitalize_initial('a'))
self.assertEqual('Ab', misc.capitalize_initial('ab'))
self.assertEqual('AbC', misc.capitalize_initial('abC'))
@test_util.run_deprecated_v1
def test_alias_single_tensor(self):
a = constant(1)

View File

@ -198,11 +198,12 @@ def _live_tensors(f, attr_name="inputs"):
"""
node, _ = parser.parse_entity(f, ())
entity_info = transformer.EntityInfo(
name=f.__name__,
source_code=None,
source_file=None,
future_features=(),
namespace=sys.modules[f.__module__].__dict__)
ctx = transformer.Context(entity_info)
ctx = transformer.Context(entity_info, None, None)
graphs = cfg.build(node)
node = qual_names.resolve(node)