STT-tensorflow/tensorflow/python/autograph/pyct/transpiler.py
Dan Moldovan ff551c9f20 Internal cleanup: Move the bulk of the source code transformation infrastructure into the generic pyct module.
PiperOrigin-RevId: 305135067
Change-Id: Ifb84546c35a603942fd864769e7320a7ae95da3b
2020-04-06 15:51:41 -07:00

420 lines
15 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Generic source code transformation infrastructure."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
import types
import gast
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import loader
from tensorflow.python.autograph.pyct import naming
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.utils import ag_logging as logging
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
                       outer_factory_name, closure_vars, factory_args,
                       future_features):
  """Embeds an AST in two nested factories that recreate its lexical context.

  `nodes` is expected to define a symbol named `entity_name`. Wrapping it this
  way gives the transformed entity the same lexical scoping as the source
  entity, while still allowing it to receive extra parameters.

  Responsibilities of the inner factory:

    * it creates, at call time, the entity that `nodes` represents;
    * it accepts a custom set of arguments (`factory_args`), which are visible
      to `nodes` as local variables and act as additional parameters;
    * its closure is identical to that of the transformed entity;
    * it returns the symbol given by `entity_name`.

  Responsibilities of the outer factory:

    * it takes no arguments and has no closure;
    * it sets up the lexical scope the inner factory needs, so that the loaded
      code ends up with the intended closure/globals configuration;
    * it returns the inner factory.

  The generated code looks roughly like this:

      from __future__ import future_feature_1
      from __future__ import future_feature_2
      ...

      def outer_factory():
        closure_var_1 = None
        closure_var_2 = None
        ...

        def inner_factory(arg_1, arg_2, ...):
          <<nodes>>
          return entity

        return inner_factory

  The lexical scope is produced by dummy declarations that create local
  variables in the body of the outer factory. This makes the Python parser
  mark those names as free, non-global variables when the code is loaded
  (that is, it creates a cell slot for each symbol). The symbols are
  initialized with None, but those values are not meant to be used: the caller
  is expected to substitute the cells of the source entity. For more details,
  see:
  https://docs.python.org/3/reference/executionmodel.html#binding-of-names

  Args:
    nodes: Tuple[ast.AST], the source code to wrap.
    entity_name: Union[Text, ast.AST], the name of the principal entity that
      `nodes` define.
    inner_factory_name: Text, the name of the inner factory.
    outer_factory_name: Text, the name of the outer factory.
    closure_vars: Iterable[Text], names of the closure variables for the inner
      factory.
    factory_args: Iterable[Text], names of additional arguments for the
      inner factory. Useful to configure variables that the converted code can
      use. Typically, these are modules.
    future_features: Iterable[Text], names of future statements to associate
      the code with.

  Returns:
    ast.AST
  """
  # A single `<name> = None` placeholder statement; expanded once per closure
  # variable so that each name becomes a local of the outer factory.
  closure_stub_template = """
    var_name = None
  """
  dummy_closure_defs = [
      stub_node
      for closure_var in closure_vars
      for stub_node in templates.replace(
          closure_stub_template, var_name=closure_var)
  ]

  # No future features means no import statement at all, hence the empty list.
  future_imports = []
  if future_features:
    feature_aliases = [
        gast.alias(name=feature, asname=None) for feature in future_features
    ]
    future_imports = gast.ImportFrom(
        module='__future__', names=feature_aliases, level=0)

  # Parameter nodes for the signature of the inner factory.
  inner_factory_params = [
      gast.Name(arg_name, ctx=gast.Param(), annotation=None, type_comment=None)
      for arg_name in factory_args
  ]

  factory_template = """
    future_imports
    def outer_factory_name():
      dummy_closure_defs
      def inner_factory_name(factory_args):
        entity_defs
        return entity_name
      return inner_factory_name
  """
  return templates.replace(
      factory_template,
      future_imports=future_imports,
      outer_factory_name=outer_factory_name,
      dummy_closure_defs=dummy_closure_defs,
      inner_factory_name=inner_factory_name,
      factory_args=inner_factory_params,
      entity_defs=nodes,
      entity_name=entity_name)
class _TransformedFnFactory(object):
  """Owns the loaded factory of a transformed function and binds it on demand.

  Usage is two-step: `create` loads the generated code once, after which
  `instantiate` may be called to obtain a function bound to a particular set
  of globals, closure cells and argument defaults.

  Attributes:
    module: The module returned by `loader.load_ast`, or None before `create`.
    source_map: The source map returned by `loader.load_ast`, or None before
      `create`.
  """

  def __init__(self, name, freevars, extra_locals):
    """Creates a new, still uninitialized, factory for a transformed function.

    Args:
      name: The function name.
      freevars: The list of non-global free variables for the function.
      extra_locals: Dict[Text, Any], names and values for custom variables that
        are accessible to the generated code as local variables.
    """
    self._name = name
    self._freevars = freevars
    self._extra_locals = extra_locals

    # Everything below is filled in by `create`.
    self._unbound_factory = None
    self.module = None
    self.source_map = None

  def create(self,
             nodes,
             namer,
             inner_factory_name='inner_factory',
             outer_factory_name='outer_factory',
             future_features=()):
    """Wraps `nodes` into factories, loads them and keeps the inner factory.

    Args:
      nodes: The AST defining the transformed function.
      namer: Object whose `new_symbol` method is used to obtain the actual
        names of the two factories.
      inner_factory_name: Text, preferred name of the inner factory.
      outer_factory_name: Text, preferred name of the outer factory.
      future_features: Iterable[Text], future statements for the loaded code.

    Raises:
      ValueError: If this object has already been initialized.
    """
    if self._unbound_factory is not None:
      raise ValueError('double initialization; create a new object instead')

    inner_symbol = namer.new_symbol(inner_factory_name, ())
    outer_symbol = namer.new_symbol(outer_factory_name, ())
    wrapped_nodes = _wrap_into_factory(
        nodes,
        self._name,
        inner_symbol,
        outer_symbol,
        self._freevars,
        self._extra_locals.keys(),
        future_features)

    module, _, source_map = loader.load_ast(
        wrapped_nodes, include_source_map=True)

    # Calling the outer factory yields the inner one; only that is retained.
    build_inner_factory = getattr(module, outer_symbol)
    self._unbound_factory = build_inner_factory()
    self.module = module
    self.source_map = source_map

  def instantiate(self,
                  globals_,
                  closure,
                  defaults=None,
                  kwdefaults=None):
    """Builds a transformed function bound to the given globals and closure.

    Args:
      globals_: The globals dict the new function should use.
      closure: Sequence of cells, in the same order as the `freevars` given to
        the constructor.
      defaults: Optional value for the `__defaults__` of the new function.
      kwdefaults: Optional value for the `__kwdefaults__` of the new function.

    Returns:
      The transformed function.

    Raises:
      ValueError: If `create` was not called, or if the number of free
        variables of the loaded factory differs from the size of `closure`.
    """
    if self._unbound_factory is None:
      raise ValueError('call create first')

    code = self._unbound_factory.__code__
    requested_freevars = code.co_freevars

    # Reorder the incoming cells to follow the order used by the loaded code.
    cell_by_name = dict(zip(self._freevars, closure))
    factory_closure = tuple(
        [cell_by_name[var_name] for var_name in requested_freevars])

    if len(factory_closure) != len(closure):
      raise ValueError(
          'closure mismatch, requested {}, but source function had {}'.format(
              self._freevars, requested_freevars))

    # Rebuild the inner factory around the caller's globals and cells.
    bound_factory = types.FunctionType(
        code=code,
        globals=globals_,
        name=self._name,
        argdefs=(),
        closure=factory_closure)

    # The lint override is a false positive.
    new_fn = bound_factory(**self._extra_locals)  # pylint:disable=not-callable

    if defaults:
      new_fn.__defaults__ = defaults
    if kwdefaults:
      new_fn.__kwdefaults__ = kwdefaults
    return new_fn
class FunctionTranspiler(object):
  """A generic source-to-source transpiler for Python functions.

  The `transform_function` API is function-in, function-out. Parsing, caching
  and variable binding are handled internally.

  Users typically subclass this and customize the `transform_ast` method.

  Instances are usually singletons, because every instance manages its own
  cache. The caching subkey makes it possible to keep several kinds of
  transformation for the same function.

  Example:

      class MyTransformer(FunctionTranspiler):

        def transform_ast(self, node, ctx):
          node = <<transform node, usually using ast.NodeTransformer classes>>
          return node

      transformer = MyTransformer()

      new_f, module, source_map = transformer.transform_function(f, ...)
      # new_f is a function with signature identical to f

  The transformed function has access to the same namespace as the original
  function. To allow access to internal APIs, users may inject additional
  symbols through the `extra_locals` argument of `transform_function`.
  """

  def __init__(self):
    self._cache = cache.CodeObjectCache()
    # Held while a cache miss is being resolved, see `_transformed_factory`.
    self._cache_lock = threading.RLock()

  def transform_ast(self, node, user_context):
    """Performs an actual transformation of a function's AST.

    Subclasses must implement this method. They must not call it.

    The method receives the original AST, and code is generated from the AST
    that it returns. For functions, the returned AST is expected to contain a
    function with the exact same arguments and closure. The resulting function
    will receive the globals, closure and argument defaults of the input
    function.

    Args:
      node: One or more ast.AST nodes representing the AST to be transformed.
      user_context: The same value that the caller passed to
        `transform_function`.
    """
    raise NotImplementedError('subclasses must override this')

  def get_transformed_name(self, node):
    """Returns a name for the output function. Subclasses may override this.

    Args:
      node: The gast.Lambda or gast.FunctionDef node being transformed.

    Returns:
      'lam' for lambdas, the original name for function definitions.

    Raises:
      ValueError: If `node` is neither a lambda nor a function definition.
    """
    if isinstance(node, gast.Lambda):
      return 'lam'
    if isinstance(node, gast.FunctionDef):
      # The function still has to be renamed to avoid namespace clashes;
      # `_transform_function` does so by passing this value through the namer.
      return node.name
    raise ValueError('Unknown node type {}'.format(node))

  def _erase_arg_defaults(self, node):
    """Erases argument default expressions, which would otherwise be unbound.

    Every positional default, and every keyword-only default that is not None,
    is replaced in place with the parsed expression `None`.

    Args:
      node: A node whose `args` attribute has `defaults` and `kw_defaults`.

    Returns:
      The same `node`.
    """
    arguments = node.args
    for index, _ in enumerate(arguments.defaults):
      arguments.defaults[index] = parser.parse_expression('None')
    for index, kw_default in enumerate(arguments.kw_defaults):
      if kw_default is None:
        # Nothing to erase in this slot.
        continue
      arguments.kw_defaults[index] = parser.parse_expression('None')
    return node

  def _transform_function(self, fn, user_context):
    """Performs source code transformation on a function.

    Args:
      fn: The function or lambda to transform.
      user_context: Opaque value stored in the transformer context that is
        handed to `transform_ast`.

    Returns:
      A tuple (node, context): the transformed AST, which defines the new
      function under a freshly generated name, and the `transformer.Context`
      that was used.

    Raises:
      ValueError: If `fn` is a lambda that cannot be singled out in its source.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    # In general, the output of inspect.getsource is inexact for lambdas
    # because it uses regex matching to adjust the exact location around
    # the line number that CPython records. Then, the entire containing line
    # is returned, which we may have trouble disambiguating.
    # For example:
    #   x, y = lambda: 1, lambda: 2
    is_lambda = fn.__name__ == '<lambda>'
    if is_lambda:
      candidates = ast_util.find_matching_definitions(node, fn)
      if len(candidates) != 1:
        raise ValueError(
            'Unable to identify source code of lambda function {}.'
            ' It was defined in this code:\n'
            '{}\n'
            'This code must contain a single distinguishable lambda.'
            ' To avoid this problem, define each lambda in a separate'
            ' expression.'.format(fn, source))
      (node,) = candidates

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    new_name = namer.new_symbol(self.get_transformed_name(node), ())

    context = transformer.Context(
        transformer.EntityInfo(
            name=new_name,
            source_code=source,
            source_file='<fragment>',
            future_features=future_features,
            namespace=namespace),
        namer,
        user_context)

    node = self.transform_ast(self._erase_arg_defaults(node), context)

    if not is_lambda:
      node.name = new_name
      return node, context

    # A lambda has no name of its own, so it is assigned to the new symbol.
    target = gast.Name(
        new_name, ctx=gast.Store(), annotation=None, type_comment=None)
    return gast.Assign(targets=[target], value=node), context

  def _cached_factory(self, fn, cache_subkey):
    """Returns the factory stored in the cache for `fn` and `cache_subkey`."""
    hit = self._cache[fn][cache_subkey]
    logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey, hit)
    return hit

  def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals):
    """Returns the transformed function factory for a given input.

    A cached factory is returned when one exists. Otherwise the function is
    transformed, and the resulting factory is created and stored in the cache
    while holding the cache lock.

    Args:
      fn: The function or lambda to transform.
      cache_subkey: Secondary cache key, used together with `fn`.
      user_context: Opaque value forwarded to `transform_ast`.
      extra_locals: Dict[Text, Any] of extra variables for the generated code.

    Returns:
      A `_TransformedFnFactory` on which `create` has been called.
    """
    if self._cache.has(fn, cache_subkey):
      return self._cached_factory(fn, cache_subkey)

    with self._cache_lock:
      # Check again under lock.
      if self._cache.has(fn, cache_subkey):
        return self._cached_factory(fn, cache_subkey)

      logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
      nodes, ctx = self._transform_function(fn, user_context)

      # Unparsing is only needed for the log message, so it is skipped unless
      # that verbosity level is enabled.
      if logging.has_verbosity(2):
        logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))

      new_factory = _TransformedFnFactory(
          ctx.info.name, fn.__code__.co_freevars, extra_locals)
      new_factory.create(
          nodes, ctx.namer, future_features=ctx.info.future_features)
      self._cache[fn][cache_subkey] = new_factory
      return new_factory

  def transform_function(self, fn, caching_subkey, user_context, extra_locals):
    """Transforms a function.

    The `caching_subkey` argument allows mapping each function to multiple
    outputs in the cache. This is useful for instance when transformers
    can generate multiple variants of output code, typically as a result of
    different transformation flags.

    Args:
      fn: A function or lambda.
      caching_subkey: Used for caching. Calls made for functions with the same
        code object and caching_subkey will return a cached instance on
        subsequent invocations. Using a constant will create unique per-function
        entries.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast.
      extra_locals: A Dict[Text, Any] containing additional variables to make
        available to the transformed code. These will be visible as local
        variables.

    Returns:
      A tuple:
        * A function or lambda with the same signature and closure as `fn`
        * The temporary module into which the transformed function was loaded
        * The source map as a
            Dict[origin_info.LineLocation, origin_info.OriginInfo]
    """
    factory = self._transformed_factory(fn, caching_subkey, user_context,
                                        extra_locals)

    # The new function reuses the globals, cells and defaults of the original.
    # `__kwdefaults__` is read with getattr because it may be absent.
    source_kwdefaults = getattr(fn, '__kwdefaults__', None)
    source_closure = fn.__closure__ or ()
    new_fn = factory.instantiate(
        globals_=fn.__globals__,
        closure=source_closure,
        defaults=fn.__defaults__,
        kwdefaults=source_kwdefaults)
    return new_fn, factory.module, factory.source_map