Internal cleanup: Move the bulk of the source code transformation infrastructure into the generic pyct module.
PiperOrigin-RevId: 305135067 Change-Id: Ifb84546c35a603942fd864769e7320a7ae95da3b
This commit is contained in:
parent
fc94412b39
commit
ff551c9f20
tensorflow/python
autograph
converters
core
impl
pyct
BUILD
common_transformers
static_analysis
transformer.pytransformer_test.pytranspiler.pytranspiler_test.pyutils
eager
@ -19,7 +19,6 @@ filegroup(
|
||||
py_library(
|
||||
name = "converters",
|
||||
srcs = [
|
||||
"arg_defaults.py",
|
||||
"asserts.py",
|
||||
"break_statements.py",
|
||||
"call_trees.py",
|
||||
@ -48,18 +47,6 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "arg_defaults_test",
|
||||
srcs = ["arg_defaults_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":converters",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/autograph/core:test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "asserts_test",
|
||||
srcs = ["asserts_test.py"],
|
||||
|
@ -1,105 +0,0 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Modifies the signature to allow resolving the value of default arguments.
|
||||
|
||||
Normally, function symbols are captured either in a function's globals or
|
||||
closure. This is not true for default arguments, which are evaluated when the
|
||||
function is defined:
|
||||
|
||||
b = 1
|
||||
c = 2
|
||||
def f(a=b + 1):
|
||||
return a + c
|
||||
|
||||
In the above example, the namespace of the function would include `c = 2` but
|
||||
not `b`.
|
||||
|
||||
If we were to naively generate a new function:
|
||||
|
||||
def new_f(a=b + 1):
|
||||
return a + c
|
||||
|
||||
The generated code would fail to load unless we exposed a symbol `b`. Capturing
|
||||
the closure of such an expression is difficult. However, we can capture the
|
||||
default value of argument `a` with relative ease.
|
||||
|
||||
This converter replaces all default argument expressions with a constant so
|
||||
that they don't cause loading to fail. This requires that the default values
|
||||
are reset after loading the transformed function:
|
||||
|
||||
def new_f(a=None):
|
||||
return a + c
|
||||
|
||||
# ... later, after new_f was loaded ...
|
||||
new_f.__defaults__ = f.__defaults__
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
|
||||
|
||||
class _Function(object):
|
||||
pass
|
||||
|
||||
|
||||
class ArgDefaultsTransformer(converter.Base):
|
||||
"""Transforms top level argument defaults."""
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
self.state[_Function].enter()
|
||||
node.args = self.visit(node.args)
|
||||
# Only the top level function is modified - no need to visit the children.
|
||||
self.state[_Function].exit()
|
||||
return node
|
||||
|
||||
def visit_FunctionDef(self, node):
|
||||
self.state[_Function].enter()
|
||||
node.args = self.visit(node.args)
|
||||
# Only the top level function is modified - no need to visit the children.
|
||||
self.state[_Function].exit()
|
||||
return node
|
||||
|
||||
def visit_arguments(self, node):
|
||||
if self.state[_Function].level > 2:
|
||||
return node
|
||||
|
||||
for i in range(len(node.defaults)):
|
||||
node.defaults[i] = parser.parse_expression('None')
|
||||
|
||||
for i, d in enumerate(node.kw_defaults):
|
||||
if d is not None:
|
||||
node.kw_defaults[i] = parser.parse_expression('None')
|
||||
|
||||
# Only the top level function is modified - no need to visit the children.
|
||||
return node
|
||||
|
||||
|
||||
def transform(node, ctx):
|
||||
"""Transform function call to the compiled counterparts.
|
||||
|
||||
Args:
|
||||
node: AST
|
||||
ctx: EntityContext
|
||||
Returns:
|
||||
A tuple (node, new_names):
|
||||
node: The transformed AST
|
||||
new_names: set(string), containing any newly-generated names
|
||||
"""
|
||||
return ArgDefaultsTransformer(ctx).visit(node)
|
@ -1,108 +0,0 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for arg_defaults module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.converters import arg_defaults
|
||||
from tensorflow.python.autograph.core import converter_testing
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ArgDefaultsTransformerTest(converter_testing.TestCase):
|
||||
|
||||
def assertTransformedFirstLineIs(self, node, expected):
|
||||
self.assertEqual(
|
||||
parser.unparse(node, include_encoding_marker=False).split('\n')[0],
|
||||
expected)
|
||||
|
||||
def test_no_args(self):
|
||||
|
||||
def test_fn():
|
||||
pass
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = arg_defaults.transform(node, ctx)
|
||||
self.assertTransformedFirstLineIs(node, 'def test_fn():')
|
||||
|
||||
def test_no_defaults(self):
|
||||
|
||||
def test_fn(a, b, *c, **e):
|
||||
return a, b, c, e
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = arg_defaults.transform(node, ctx)
|
||||
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b, *c, **e):')
|
||||
|
||||
# TODO(mdan): Add kwonly-arg tests when PY2 is no longer supported.
|
||||
|
||||
def test_arg_defaults(self):
|
||||
|
||||
def test_fn(a, b=1, c=2):
|
||||
return a, b, c
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = arg_defaults.transform(node, ctx)
|
||||
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, c=None):')
|
||||
|
||||
def test_arg_defaults_with_vararg(self):
|
||||
|
||||
def test_fn(a, b=1, *c): # pylint: disable=keyword-arg-before-vararg
|
||||
return a, b, c
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = arg_defaults.transform(node, ctx)
|
||||
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, *c):')
|
||||
|
||||
def test_arg_defaults_ignores_inner_lambda(self):
|
||||
|
||||
def test_fn():
|
||||
return (lambda x=7: x)()
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = arg_defaults.transform(node, ctx)
|
||||
with self.converted(test_fn, arg_defaults, {}) as result:
|
||||
self.assertEqual(test_fn(), result.test_fn())
|
||||
|
||||
def test_arg_defaults_ignores_inner_function(self):
|
||||
|
||||
def test_fn():
|
||||
def inner_fn(a=3):
|
||||
return a
|
||||
return inner_fn()
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = arg_defaults.transform(node, ctx)
|
||||
with self.converted(test_fn, arg_defaults, {}) as result:
|
||||
self.assertEqual(test_fn(), result.test_fn())
|
||||
|
||||
def test_arg_defaults_ignores_inner_function_returned(self):
|
||||
|
||||
def test_fn():
|
||||
def inner_fn(a=3):
|
||||
return a
|
||||
return inner_fn
|
||||
|
||||
node, ctx = self.prepare(test_fn, {})
|
||||
node = arg_defaults.transform(node, ctx)
|
||||
with self.converted(test_fn, arg_defaults, {}) as result:
|
||||
self.assertEqual(test_fn()(), result.test_fn()())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -191,7 +191,7 @@ class CallTreeTransformer(converter.Base):
|
||||
return node
|
||||
|
||||
if (full_name == 'print' and
|
||||
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
||||
not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
||||
return node
|
||||
|
||||
template = """
|
||||
|
@ -54,8 +54,8 @@ class FunctionTransformer(converter.Base):
|
||||
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
|
||||
# function_wrappers.py.
|
||||
if fn_scope.level == 2:
|
||||
return self.ctx.program.options
|
||||
return self.ctx.program.options.call_options()
|
||||
return self.ctx.user.options
|
||||
return self.ctx.user.options.call_options()
|
||||
|
||||
def visit_Lambda(self, node):
|
||||
with self.state[_Function] as fn_scope:
|
||||
|
@ -53,7 +53,7 @@ class LogicalExpressionTransformer(converter.Base):
|
||||
op_type = type(operator)
|
||||
if op_type in LOGICAL_OPERATORS:
|
||||
return LOGICAL_OPERATORS[op_type]
|
||||
if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
|
||||
if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
|
||||
if op_type in EQUALITY_OPERATORS:
|
||||
return EQUALITY_OPERATORS[op_type]
|
||||
return None
|
||||
@ -83,7 +83,7 @@ class LogicalExpressionTransformer(converter.Base):
|
||||
def visit_Compare(self, node):
|
||||
node = self.generic_visit(node)
|
||||
|
||||
if (not self.ctx.program.options.uses(
|
||||
if (not self.ctx.user.options.uses(
|
||||
converter.Feature.EQUALITY_OPERATORS)):
|
||||
return node
|
||||
|
||||
|
@ -253,25 +253,6 @@ class ProgramContext(
|
||||
pass
|
||||
|
||||
|
||||
class EntityContext(transformer.Context):
|
||||
"""Tracks the conversion of a single entity.
|
||||
|
||||
This object is mutable, and is updated during conversion. Not thread safe.
|
||||
|
||||
Attributes:
|
||||
namer: Namer
|
||||
info: transformer.EntityInfo
|
||||
program: ProgramContext,
|
||||
targe_name: Text
|
||||
"""
|
||||
|
||||
def __init__(self, namer, entity_info, program_ctx, target_name=None):
|
||||
super(EntityContext, self).__init__(entity_info)
|
||||
self.namer = namer
|
||||
self.program = program_ctx
|
||||
self.target_name = target_name
|
||||
|
||||
|
||||
class Base(transformer.Base):
|
||||
"""All converters should inherit from this class.
|
||||
|
||||
|
@ -168,12 +168,12 @@ class TestCase(test.TestCase):
|
||||
options=converter.ConversionOptions(recursive=recursive),
|
||||
autograph_module=None)
|
||||
entity_info = transformer.EntityInfo(
|
||||
name=test_fn.__name__,
|
||||
source_code=source,
|
||||
source_file='<fragment>',
|
||||
future_features=future_features,
|
||||
namespace=namespace)
|
||||
ctx = converter.EntityContext(
|
||||
namer, entity_info, program_ctx, 'test_fn')
|
||||
ctx = transformer.Context(entity_info, namer, program_ctx)
|
||||
origin_info.resolve_entity(node, source, test_fn)
|
||||
node = converter.standard_analysis(node, ctx, is_initial=True)
|
||||
return node, ctx
|
||||
|
@ -29,10 +29,7 @@ import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
|
||||
# pylint:disable=g-bad-import-order
|
||||
|
||||
import six
|
||||
# pylint:enable=g-bad-import-order
|
||||
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.core import converter
|
||||
@ -668,7 +665,7 @@ def to_graph(entity, recursive=True, experimental_optional_features=None):
|
||||
user_requested=True,
|
||||
optional_features=experimental_optional_features),
|
||||
autograph_module=tf_inspect.getmodule(to_graph))
|
||||
return conversion.convert(entity, program_ctx)
|
||||
return autograph_artifact(conversion.convert(entity, program_ctx))
|
||||
except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
|
||||
logging.error(1, 'Error converting %s', entity, exc_info=True)
|
||||
raise ConversionError('converting {}: {}: {}'.format(
|
||||
|
@ -18,21 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import imp
|
||||
import inspect
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import unittest
|
||||
import weakref
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph import operators
|
||||
from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.converters import arg_defaults
|
||||
from tensorflow.python.autograph.converters import asserts
|
||||
from tensorflow.python.autograph.converters import break_statements
|
||||
from tensorflow.python.autograph.converters import call_trees
|
||||
@ -50,304 +41,70 @@ from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.core import unsupported_features_checker
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import cache
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import pretty_printer
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct import transpiler
|
||||
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
class _ConvertedEntityFactoryInfo(
|
||||
collections.namedtuple(
|
||||
'_ConvertedEntityFactoryInfo',
|
||||
('module_name', 'converted_name', 'factory_factory_name', 'source_map'))
|
||||
):
|
||||
"""Holds metadata about a converted entity stored as a dynamic factory.
|
||||
class AutoGraphTranspiler(transpiler.FunctionTranspiler):
|
||||
|
||||
The dynamic factory is assumed to be created by _wrap_into_dynamic_factory,
|
||||
be named `factory_factory_name` and located inside the module named as
|
||||
`module_name`.
|
||||
def get_transformed_name(self, node):
|
||||
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
|
||||
|
||||
Attributes:
|
||||
module_name: Text, the name of the module containing the entity.
|
||||
converted_name: Text, the name of the converted entity.
|
||||
factory_factory_name: Text, the name of the dynamic factory.
|
||||
source_map: Dict.
|
||||
"""
|
||||
def transform_ast(self, node, ctx):
|
||||
# TODO(mdan): Insert list_comprehensions somewhere.
|
||||
unsupported_features_checker.verify(node)
|
||||
|
||||
def __str__(self):
|
||||
return '_ConvertedEntityFactoryInfo({} in {})'.format(
|
||||
self.converted_name, self.module_name)
|
||||
|
||||
def get_module(self):
|
||||
return sys.modules[self.module_name]
|
||||
|
||||
def get_factory(self):
|
||||
assert self.module_name in sys.modules
|
||||
factory_factory = getattr(sys.modules[self.module_name],
|
||||
self.factory_factory_name)
|
||||
return factory_factory()
|
||||
node = converter.standard_analysis(node, ctx, is_initial=True)
|
||||
node = converter.apply_(node, ctx, functions)
|
||||
node = converter.apply_(node, ctx, directives)
|
||||
node = converter.apply_(node, ctx, break_statements)
|
||||
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
|
||||
node = converter.apply_(node, ctx, asserts)
|
||||
# Note: sequencing continue canonicalization before for loop one avoids
|
||||
# dealing with the extra loop increment operation that the for
|
||||
# canonicalization creates.
|
||||
node = converter.apply_(node, ctx, continue_statements)
|
||||
node = converter.apply_(node, ctx, return_statements)
|
||||
if ctx.user.options.uses(converter.Feature.LISTS):
|
||||
node = converter.apply_(node, ctx, lists)
|
||||
node = converter.apply_(node, ctx, slices)
|
||||
node = converter.apply_(node, ctx, call_trees)
|
||||
node = converter.apply_(node, ctx, control_flow)
|
||||
node = converter.apply_(node, ctx, conditional_expressions)
|
||||
node = converter.apply_(node, ctx, logical_expressions)
|
||||
return node
|
||||
|
||||
|
||||
# TODO(mdan): Add a garbage collection hook for cleaning up modules.
|
||||
class _FunctionCache(object):
|
||||
"""A hierarchical cache that uses the converted entity as weak key.
|
||||
|
||||
The keys soft references (i.e. they are discarded when the key is
|
||||
destroyed). The subkeys are normal hashable values.
|
||||
|
||||
This class is generic - see the call site for how the keys and values are
|
||||
defined.
|
||||
"""
|
||||
|
||||
__slots__ = ('_cache',)
|
||||
|
||||
def __init__(self):
|
||||
self._cache = weakref.WeakKeyDictionary()
|
||||
|
||||
def _get_key(self, entity):
|
||||
raise NotImplementedError('subclasses will override')
|
||||
|
||||
def has(self, entity, subkey):
|
||||
key = self._get_key(entity)
|
||||
if key not in self._cache:
|
||||
return False
|
||||
return subkey in self._cache[key]
|
||||
|
||||
def __getitem__(self, entity):
|
||||
key = self._get_key(entity)
|
||||
if key not in self._cache:
|
||||
# The bucket needs to be initialized to support this usage:
|
||||
# cache[key][subkey] = value
|
||||
self._cache[key] = {}
|
||||
return self._cache[key]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._cache)
|
||||
_TRANSPILER = AutoGraphTranspiler()
|
||||
_WHITELIST_CACHE = cache.UnboundInstanceCache()
|
||||
|
||||
|
||||
class _CodeObjectCache(_FunctionCache):
|
||||
"""A function cache based on code objects (i.e., the source code).
|
||||
|
||||
Multiple functions may share the same code object, but they may share the
|
||||
cache because we know they have the exact source code. This properly handles
|
||||
functions defined in a loop, bound methods, etc.
|
||||
|
||||
Falls back to the function object, if it doesn't have a code object.
|
||||
"""
|
||||
|
||||
def _get_key(self, entity):
|
||||
if hasattr(entity, '__code__'):
|
||||
return entity.__code__
|
||||
else:
|
||||
return entity
|
||||
|
||||
|
||||
class _UnboundInstanceCache(_FunctionCache):
|
||||
"""A function cache based on unbound function objects.
|
||||
|
||||
Unlike the _CodeObjectCache, this discriminates between different functions
|
||||
even if they have the same code. This properly handles decorators that may
|
||||
masquerade as various functions. Bound functions are not discriminated by
|
||||
the object they're bound to.
|
||||
"""
|
||||
|
||||
def _get_key(self, entity):
|
||||
if inspect.ismethod(entity):
|
||||
return entity.__func__
|
||||
return entity
|
||||
|
||||
|
||||
# Using a re-entrant lock to guard against the unlikely possibility that the
|
||||
# conversion process triggers additional code execution.
|
||||
_CACHE_LOCK = threading.RLock()
|
||||
|
||||
|
||||
_CACHE = _CodeObjectCache()
|
||||
_WHITELIST_CACHE = _UnboundInstanceCache()
|
||||
|
||||
|
||||
# Note: strictly speaking, a simple factory might have been sufficient for
|
||||
# functions. But the double factory approach allows us to control the closure
|
||||
# and globals of the converted code in a cleaner fashion.
|
||||
# TODO(mdan): A simple factory may be sufficient.
|
||||
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
|
||||
factory_name, closure_vars, future_features):
|
||||
"""Wraps an AST into the body of a dynamic factory.
|
||||
|
||||
This uses the dynamic factory (factory of factory) pattern to achieve the
|
||||
following:
|
||||
|
||||
1. The inner factory, dynamically creates the entity represented by nodes.
|
||||
2. The entity is parametrized by `ag__`, the internal AutoGraph module.
|
||||
3. The outer factory creates the inner factory with a lexical scope
|
||||
in which `closure_vars` are bound local variables. This in turn allows the
|
||||
caller to control the exact closure (i.e. non-global free variables) for
|
||||
the inner factory.
|
||||
|
||||
The AST is expected to define some symbol named by `entity_name`.
|
||||
|
||||
Args:
|
||||
nodes: ast.AST
|
||||
entity_name: Union[Text, ast.AST]
|
||||
factory_factory_name: Text
|
||||
factory_name: Text
|
||||
closure_vars: Iterable[Text]
|
||||
future_features: Iterable[Text], see EntityInfo.future_features.
|
||||
|
||||
Returns:
|
||||
ast.AST
|
||||
"""
|
||||
if not isinstance(nodes, (list, tuple)):
|
||||
nodes = (nodes,)
|
||||
|
||||
dummy_closure_defs = []
|
||||
for var_name in closure_vars:
|
||||
template = """
|
||||
var_name = None
|
||||
"""
|
||||
dummy_closure_defs.extend(templates.replace(template, var_name=var_name))
|
||||
|
||||
if future_features:
|
||||
future_imports = gast.ImportFrom(
|
||||
module='__future__',
|
||||
names=[gast.alias(name=name, asname=None) for name in future_features],
|
||||
level=0)
|
||||
else:
|
||||
future_imports = []
|
||||
|
||||
# These dummy symbol declarations create local fariables in a function scope,
|
||||
# so that the Python parser correctly marks them as free non-global variables
|
||||
# upon load (that is, it creates cell slots for each symbol). Their values are
|
||||
# not used, as the cells are swapped with the original entity's cells after
|
||||
# the code has been loaded.
|
||||
template = """
|
||||
future_imports
|
||||
def factory_factory_name():
|
||||
dummy_closure_defs
|
||||
def factory_name(ag__, ag_source_map__, ag_module__):
|
||||
entity_defs
|
||||
entity_name.ag_source_map = ag_source_map__
|
||||
entity_name.ag_module = ag_module__
|
||||
entity_name = ag__.autograph_artifact(entity_name)
|
||||
return entity_name
|
||||
return factory_name
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
future_imports=future_imports,
|
||||
factory_factory_name=factory_factory_name,
|
||||
factory_name=factory_name,
|
||||
dummy_closure_defs=dummy_closure_defs,
|
||||
entity_defs=nodes,
|
||||
entity_name=entity_name)
|
||||
|
||||
|
||||
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
|
||||
"""Returns a (possibly cached) factory for the converted result of entity."""
|
||||
# The cache subkey encompasses any conversion options on which the generated
|
||||
# code may depend.
|
||||
# The cached factory includes the necessary definitions to distinguish
|
||||
# between the global and non-global free variables. For this reason, the
|
||||
# cache subkey includes the names of the free non-globals.
|
||||
subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))
|
||||
|
||||
with _CACHE_LOCK:
|
||||
# The cache values are _ConvertedEntityFactoryInfo objects.
|
||||
if _CACHE.has(entity, subkey):
|
||||
# TODO(mdan): Check whether the module is still loaded.
|
||||
converted_entity_info = _CACHE[entity][subkey]
|
||||
logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity, subkey,
|
||||
converted_entity_info)
|
||||
return converted_entity_info
|
||||
|
||||
logging.log(1, 'Entity %s is not cached for subkey %s', entity, subkey)
|
||||
|
||||
nodes, converted_name, entity_info = convert_entity_to_ast(
|
||||
entity, program_ctx)
|
||||
|
||||
namer = naming.Namer(entity_info.namespace)
|
||||
factory_factory_name = namer.new_symbol('create_converted_entity_factory',
|
||||
())
|
||||
factory_name = namer.new_symbol('create_converted_entity', ())
|
||||
nodes = _wrap_into_dynamic_factory(nodes, converted_name,
|
||||
factory_factory_name, factory_name,
|
||||
free_nonglobal_var_names,
|
||||
entity_info.future_features)
|
||||
|
||||
module, _, source_map = loader.load_ast(nodes, include_source_map=True)
|
||||
module_name = module.__name__
|
||||
|
||||
converted_entity_info = _ConvertedEntityFactoryInfo(
|
||||
module_name=module_name,
|
||||
converted_name=converted_name,
|
||||
factory_factory_name=factory_factory_name,
|
||||
source_map=source_map)
|
||||
_CACHE[entity][subkey] = converted_entity_info
|
||||
return converted_entity_info
|
||||
|
||||
|
||||
def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
|
||||
"""Creates a converted instance and binds it to match original entity."""
|
||||
factory = converted_entity_info.get_factory()
|
||||
|
||||
entity_globals = entity.__globals__
|
||||
entity_closure = entity.__closure__ or ()
|
||||
assert len(entity_closure) == len(free_nonglobal_var_names)
|
||||
|
||||
# Fit the original entity's cells to match the order of factory's cells.
|
||||
original_names_and_cells = dict(zip(free_nonglobal_var_names, entity_closure))
|
||||
new_factory_cells = tuple(
|
||||
original_names_and_cells[name] for name in factory.__code__.co_freevars)
|
||||
|
||||
bound_factory = types.FunctionType(
|
||||
code=factory.__code__,
|
||||
globals=entity_globals,
|
||||
name=factory.__name__,
|
||||
argdefs=(),
|
||||
closure=new_factory_cells)
|
||||
|
||||
# Two other free vars: the internal "ag__" module and the source
|
||||
# map. These are wired via the parameters of the factory.
|
||||
converted_entity = bound_factory( # pylint:disable=not-callable
|
||||
ag_internal, converted_entity_info.source_map,
|
||||
converted_entity_info.get_module())
|
||||
|
||||
# Attach the default argument to the converted function.
|
||||
converted_entity.__defaults__ = entity.__defaults__
|
||||
if hasattr(entity, '__kwdefaults__'):
|
||||
converted_entity.__kwdefaults__ = entity.__kwdefaults__
|
||||
|
||||
return converted_entity
|
||||
custom_vars = None
|
||||
|
||||
|
||||
# TODO(mdan): Superfluous function, remove.
|
||||
# TODO(mdan): Put these extra fields inside __autograph_info__.
|
||||
def convert(entity, program_ctx):
|
||||
"""Converts an entity into an equivalent entity."""
|
||||
"""Applies AutoGraph to entity."""
|
||||
|
||||
if not hasattr(entity, '__code__'):
|
||||
raise ValueError('Cannot apply autograph to a function that doesn\'t '
|
||||
'expose a __code__ object. If this is a @tf.function,'
|
||||
' try passing f.python_function instead.')
|
||||
free_nonglobal_var_names = entity.__code__.co_freevars
|
||||
|
||||
for i, name in enumerate(free_nonglobal_var_names):
|
||||
if (name == 'ag__' and
|
||||
entity.__closure__[i].cell_contents is not ag_internal):
|
||||
raise ValueError('entity {} uses the reserved symbol "{}"'.format(
|
||||
entity, name))
|
||||
# TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered.
|
||||
_create_custom_vars(program_ctx)
|
||||
transformed, module, source_map = _TRANSPILER.transform_function(
|
||||
entity, program_ctx.options, program_ctx, custom_vars)
|
||||
|
||||
converted_entity_info = _convert_with_cache(entity, program_ctx,
|
||||
free_nonglobal_var_names)
|
||||
|
||||
return _instantiate(entity, converted_entity_info, free_nonglobal_var_names)
|
||||
assert not hasattr(transformed, 'ag_module')
|
||||
assert not hasattr(transformed, 'ag_source_map')
|
||||
transformed.ag_module = module
|
||||
transformed.ag_source_map = source_map
|
||||
return transformed
|
||||
|
||||
|
||||
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
||||
@ -472,58 +229,15 @@ def cache_whitelisted(entity, options):
|
||||
pass
|
||||
|
||||
|
||||
# TODO(mdan): Rename to convert_*_node to avoid confusion with convert.
|
||||
def convert_entity_to_ast(o, program_ctx):
|
||||
"""Compile a Python entity into equivalent TensorFlow.
|
||||
|
||||
Args:
|
||||
o: A Python entity.
|
||||
program_ctx: A ProgramContext object.
|
||||
|
||||
Returns:
|
||||
A tuple (ast, new_name, namespace):
|
||||
* ast: An AST representing an entity with interface equivalent to `o`,
|
||||
but which when executed it creates TF a graph.
|
||||
* new_name: The symbol name under which the new entity can be found.
|
||||
* namespace: A dict mapping all symbols visible to the converted entity,
|
||||
keyed by their symbol name.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if entity is of a type that is not yet supported.
|
||||
"""
|
||||
logging.log(1, 'Converting %s', o)
|
||||
|
||||
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
|
||||
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
|
||||
if logging.has_verbosity(4):
|
||||
for n in nodes:
|
||||
logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
|
||||
pretty_printer.fmt(n, color=False))
|
||||
|
||||
return nodes, name, entity_info
|
||||
|
||||
|
||||
def _add_reserved_symbol(namespace, name, entity):
|
||||
if name not in namespace:
|
||||
namespace[name] = entity
|
||||
elif namespace[name] != entity:
|
||||
raise ValueError('The name "%s" is reserved and may not be used.' % name)
|
||||
|
||||
|
||||
ag_internal = None
|
||||
|
||||
|
||||
# TODO(mdan): Move into core or replace with an actual importable module.
|
||||
def _add_self_references(namespace, autograph_module):
|
||||
def _create_custom_vars(program_ctx):
|
||||
"""Adds namespace references to the module that exposes the api itself."""
|
||||
global ag_internal
|
||||
if ag_internal is None:
|
||||
global custom_vars
|
||||
if custom_vars is None:
|
||||
# Craft a module that exposes parts of the external API as well as certain
|
||||
# internal modules.
|
||||
ag_internal = imp.new_module('autograph')
|
||||
ag_internal.__dict__.update(autograph_module.__dict__)
|
||||
ag_internal.__dict__.update(program_ctx.autograph_module.__dict__)
|
||||
ag_internal.ConversionOptions = converter.ConversionOptions
|
||||
ag_internal.STD = converter.STANDARD_OPTIONS
|
||||
ag_internal.Feature = converter.Feature
|
||||
@ -536,102 +250,4 @@ def _add_self_references(namespace, autograph_module):
|
||||
ag_internal.__dict__.update(special_functions.__dict__)
|
||||
ag_internal.__dict__.update(operators.__dict__)
|
||||
|
||||
_add_reserved_symbol(namespace, 'ag__', ag_internal)
|
||||
|
||||
|
||||
def convert_func_to_ast(f, program_ctx, do_rename=True):
|
||||
"""Specialization of `convert_entity_to_ast` for callable functions."""
|
||||
|
||||
future_features = inspect_utils.getfutureimports(f)
|
||||
node, source = parser.parse_entity(f, future_features=future_features)
|
||||
logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
|
||||
# Parsed AST should contain future imports and one function def node.
|
||||
|
||||
# In general, the output of inspect.getsource is inexact for lambdas because
|
||||
# it uses regex matching to adjust the exact location around the line number
|
||||
# that CPython records. Then, the entire containing line is returned, which
|
||||
# we may have trouble disambiguating. For example:
|
||||
# x, y = lambda: 1, lambda: 2
|
||||
if f.__name__ == '<lambda>':
|
||||
nodes = ast_util.find_matching_definitions(node, f)
|
||||
if len(nodes) != 1:
|
||||
raise ValueError(
|
||||
'Unable to identify source code of lambda function {}. It was'
|
||||
' defined on this line: {}, which must contain a single lambda with'
|
||||
' matching signature. To avoid ambiguity, define each lambda'
|
||||
' in a separate expression.'.format(f, source))
|
||||
node, = nodes
|
||||
|
||||
# TODO(znado): Place inside standard_analysis.
|
||||
origin_info.resolve_entity(node, source, f)
|
||||
|
||||
namespace = inspect_utils.getnamespace(f)
|
||||
_add_self_references(namespace, program_ctx.autograph_module)
|
||||
namer = naming.Namer(namespace)
|
||||
|
||||
if isinstance(node, gast.Lambda):
|
||||
new_name = namer.new_symbol('tf__lambda', ())
|
||||
elif do_rename:
|
||||
new_name = namer.new_symbol('tf__' + f.__name__, ())
|
||||
else:
|
||||
new_name = f.__name__
|
||||
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=source,
|
||||
source_file='<fragment>',
|
||||
future_features=future_features,
|
||||
namespace=namespace)
|
||||
context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
|
||||
node = node_to_graph(node, context)
|
||||
|
||||
if isinstance(node, gast.Lambda):
|
||||
node = gast.Assign(
|
||||
targets=[
|
||||
gast.Name(
|
||||
new_name, ctx=gast.Store(), annotation=None, type_comment=None)
|
||||
],
|
||||
value=node)
|
||||
elif do_rename:
|
||||
node.name = new_name
|
||||
else:
|
||||
assert node.name == new_name
|
||||
|
||||
return (node,), new_name, entity_info
|
||||
|
||||
|
||||
def node_to_graph(node, context):
|
||||
"""Convert Python code to equivalent TF graph mode code.
|
||||
|
||||
Args:
|
||||
node: AST, the code to convert.
|
||||
context: converter.EntityContext
|
||||
|
||||
Returns:
|
||||
A tuple (node, deps):
|
||||
* node: A Python ast node, representing the converted code.
|
||||
* deps: A set of strings, the fully qualified names of entity
|
||||
dependencies that this node has.
|
||||
"""
|
||||
# TODO(mdan): Insert list_comprehensions somewhere.
|
||||
unsupported_features_checker.verify(node)
|
||||
|
||||
node = converter.standard_analysis(node, context, is_initial=True)
|
||||
node = converter.apply_(node, context, functions)
|
||||
node = converter.apply_(node, context, arg_defaults)
|
||||
node = converter.apply_(node, context, directives)
|
||||
node = converter.apply_(node, context, break_statements)
|
||||
if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
|
||||
node = converter.apply_(node, context, asserts)
|
||||
# Note: sequencing continue canonicalization before for loop one avoids
|
||||
# dealing with the extra loop increment operation that the for
|
||||
# canonicalization creates.
|
||||
node = converter.apply_(node, context, continue_statements)
|
||||
node = converter.apply_(node, context, return_statements)
|
||||
if context.program.options.uses(converter.Feature.LISTS):
|
||||
node = converter.apply_(node, context, lists)
|
||||
node = converter.apply_(node, context, slices)
|
||||
node = converter.apply_(node, context, call_trees)
|
||||
node = converter.apply_(node, context, control_flow)
|
||||
node = converter.apply_(node, context, conditional_expressions)
|
||||
node = converter.apply_(node, context, logical_expressions)
|
||||
return node
|
||||
custom_vars = {'ag__': ag_internal}
|
||||
|
@ -20,11 +20,9 @@ from __future__ import print_function
|
||||
|
||||
import imp
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import weakref
|
||||
|
||||
import gast
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph import utils
|
||||
@ -33,7 +31,6 @@ from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.impl import api
|
||||
from tensorflow.python.autograph.impl import conversion
|
||||
from tensorflow.python.autograph.impl.testing import pybind_for_testing
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.platform import test
|
||||
@ -126,156 +123,6 @@ class ConversionTest(test.TestCase):
|
||||
# Note: currently, native bindings are whitelisted by a separate check.
|
||||
self.assertFalse(conversion.is_whitelisted(test_object.method))
|
||||
|
||||
def test_convert_entity_to_ast_callable(self):
|
||||
b = 2
|
||||
|
||||
def f(a):
|
||||
return a + b
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
nodes, name, info = conversion.convert_entity_to_ast(f, program_ctx)
|
||||
fn_node, = nodes
|
||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
||||
self.assertEqual('tf__f', name)
|
||||
self.assertIs(info.namespace['b'], b)
|
||||
|
||||
def test_convert_entity_to_ast_function_with_defaults(self):
|
||||
b = 2
|
||||
c = 1
|
||||
|
||||
def f(a, d=c + 1):
|
||||
return a + b + d
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
nodes, name, _ = conversion.convert_entity_to_ast(f, program_ctx)
|
||||
fn_node, = nodes
|
||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
||||
self.assertEqual('tf__f', name)
|
||||
self.assertEqual(
|
||||
parser.unparse(fn_node.args.defaults[0],
|
||||
include_encoding_marker=False).strip(), 'None')
|
||||
|
||||
def test_convert_entity_to_ast_call_tree(self):
|
||||
|
||||
def g(a):
|
||||
return a
|
||||
|
||||
def f(a):
|
||||
return g(a)
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
nodes, _, _ = conversion.convert_entity_to_ast(f, program_ctx)
|
||||
f_node, = nodes
|
||||
self.assertEqual('tf__f', f_node.name)
|
||||
|
||||
def test_convert_entity_to_ast_lambda(self):
|
||||
b = 2
|
||||
f = lambda x: b * x if x > 0 else -x
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
|
||||
f, program_ctx)
|
||||
self.assertIsInstance(fn_node, gast.Assign)
|
||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
||||
self.assertEqual('tf__lambda', name)
|
||||
self.assertIs(entity_info.namespace['b'], b)
|
||||
|
||||
def test_convert_entity_to_ast_multiple_lambdas(self):
|
||||
a, b = 1, 2
|
||||
f, _ = (lambda x: a * x, lambda y: b * y)
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
|
||||
f, program_ctx)
|
||||
self.assertIsInstance(fn_node, gast.Assign)
|
||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
||||
self.assertEqual('tf__lambda', name)
|
||||
self.assertIs(entity_info.namespace['a'], a)
|
||||
|
||||
def test_convert_entity_to_ast_multiple_lambdas_ambiguous_definitions(self):
|
||||
a, b = 1, 2
|
||||
f, _ = (lambda x: a * x, lambda x: b * x)
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
with self.assertRaises(ValueError):
|
||||
conversion.convert_entity_to_ast(f, program_ctx)
|
||||
|
||||
def test_convert_entity_to_ast_lambda_code_with_garbage(self):
|
||||
# pylint:disable=g-long-lambda
|
||||
f = ( # intentional wrap
|
||||
lambda x: (
|
||||
x # intentional wrap
|
||||
+ 1),)[0]
|
||||
# pylint:enable=g-long-lambda
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
(fn_node,), name, _ = conversion.convert_entity_to_ast(f, program_ctx)
|
||||
self.assertIsInstance(fn_node, gast.Assign)
|
||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
||||
self.assertEqual('tf__lambda', name)
|
||||
|
||||
def test_convert_entity_to_ast_nested_functions(self):
|
||||
b = 2
|
||||
|
||||
def f(x):
|
||||
|
||||
def g(x):
|
||||
return b * x
|
||||
|
||||
return g(x)
|
||||
|
||||
program_ctx = self._simple_program_ctx()
|
||||
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
|
||||
f, program_ctx)
|
||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
||||
self.assertEqual(fn_node.name, 'tf__f')
|
||||
self.assertEqual('tf__f', name)
|
||||
self.assertIs(entity_info.namespace['b'], b)
|
||||
|
||||
def test_convert_concurrency(self):
|
||||
|
||||
def test_fn():
|
||||
pass
|
||||
|
||||
generated_file_names = []
|
||||
|
||||
def conversion_thread():
|
||||
new_f = conversion.convert(test_fn, self._simple_program_ctx())
|
||||
generated_file_names.append(new_f.__code__.co_filename)
|
||||
|
||||
threads = tuple(
|
||||
threading.Thread(target=conversion_thread) for _ in range(10))
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Races would potentially create multiple files (non-deterministically,
|
||||
# but with high likelihood).
|
||||
self.assertEqual(len(set(generated_file_names)), 1)
|
||||
|
||||
def test_convert_reentrance(self):
|
||||
|
||||
def test_fn():
|
||||
pass
|
||||
|
||||
# There are no known ways to cause convert to re-enter. So we instrument
|
||||
# an internal function to do that instead.
|
||||
old_node_to_graph = conversion.node_to_graph
|
||||
self.num_conversions = 0
|
||||
def node_to_graph_wrapper(node, context):
|
||||
self.num_conversions += 1
|
||||
if self.num_conversions < 2:
|
||||
conversion.convert(test_fn, self._simple_program_ctx())
|
||||
return old_node_to_graph(node, context)
|
||||
|
||||
try:
|
||||
conversion.node_to_graph = node_to_graph_wrapper
|
||||
new_f = conversion.convert(test_fn, self._simple_program_ctx())
|
||||
self.assertIsNotNone(new_f)
|
||||
finally:
|
||||
conversion.node_to_graph = old_node_to_graph
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -39,6 +39,7 @@ py_library(
|
||||
"qual_names.py",
|
||||
"templates.py",
|
||||
"transformer.py",
|
||||
"transpiler.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
@ -230,3 +231,14 @@ py_test(
|
||||
"@gast_archive//:gast",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "transpiler_test",
|
||||
srcs = ["transpiler_test.py"],
|
||||
python_version = "PY3",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":pyct",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
@ -50,8 +50,12 @@ class AnfTestBase(test.TestCase):
|
||||
|
||||
def _simple_context(self):
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=None, source_file=None, future_features=(), namespace=None)
|
||||
return transformer.Context(entity_info)
|
||||
name='test_fn',
|
||||
source_code=None,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace=None)
|
||||
return transformer.Context(entity_info, None, None)
|
||||
|
||||
def assert_same_ast(self, expected_node, node, msg=None):
|
||||
expected_source = parser.unparse(expected_node, indentation=' ')
|
||||
|
@ -22,6 +22,7 @@ import gast
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
@ -113,11 +114,17 @@ class ScopeTest(test.TestCase):
|
||||
class ActivityAnalyzerTestBase(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn):
|
||||
# TODO(mdan): Use a custom FunctionTransformer here.
|
||||
node, source = parser.parse_entity(test_fn, future_features=())
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=source, source_file=None, future_features=(), namespace={})
|
||||
name=test_fn.__name__,
|
||||
source_code=source,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace={})
|
||||
node = qual_names.resolve(node)
|
||||
ctx = transformer.Context(entity_info)
|
||||
namer = naming.Namer({})
|
||||
ctx = transformer.Context(entity_info, namer, None)
|
||||
node = activity.resolve(node, ctx)
|
||||
return node, entity_info
|
||||
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
@ -35,11 +36,17 @@ global_b = 17
|
||||
class LivenessAnalyzerTestBase(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn):
|
||||
# TODO(mdan): Use a custom FunctionTransformer here.
|
||||
node, source = parser.parse_entity(test_fn, future_features=())
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=source, source_file=None, future_features=(), namespace={})
|
||||
name=test_fn.__name__,
|
||||
source_code=source,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace={})
|
||||
node = qual_names.resolve(node)
|
||||
ctx = transformer.Context(entity_info)
|
||||
namer = naming.Namer({})
|
||||
ctx = transformer.Context(entity_info, namer, None)
|
||||
node = activity.resolve(node, ctx)
|
||||
graphs = cfg.build(node)
|
||||
liveness.resolve(node, ctx, graphs)
|
||||
|
@ -22,6 +22,7 @@ import six
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
@ -37,11 +38,17 @@ global_b = 17
|
||||
class ReachingDefinitionsAnalyzerTestBase(test.TestCase):
|
||||
|
||||
def _parse_and_analyze(self, test_fn):
|
||||
# TODO(mdan): Use a custom FunctionTransformer here.
|
||||
node, source = parser.parse_entity(test_fn, future_features=())
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=source, source_file=None, future_features=(), namespace={})
|
||||
name=test_fn.__name__,
|
||||
source_code=source,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace={})
|
||||
node = qual_names.resolve(node)
|
||||
ctx = transformer.Context(entity_info)
|
||||
namer = naming.Namer({})
|
||||
ctx = transformer.Context(entity_info, namer, None)
|
||||
node = activity.resolve(node, ctx)
|
||||
graphs = cfg.build(node)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs,
|
||||
|
@ -36,20 +36,26 @@ class Context(object):
|
||||
|
||||
Attributes:
|
||||
info: EntityInfo, immutable.
|
||||
namer: naming.Namer.
|
||||
current_origin: origin_info.OriginInfo, holds the OriginInfo of the last
|
||||
AST node to be processed successfully. Useful for error handling.
|
||||
user: An user-supplied context object. The object is opaque to the
|
||||
infrastructure, but will pe passed through to all custom transformations.
|
||||
"""
|
||||
|
||||
def __init__(self, info):
|
||||
def __init__(self, info, namer, user_context):
|
||||
self.info = info
|
||||
self.namer = namer
|
||||
self.current_origin = None
|
||||
self.user = user_context
|
||||
|
||||
|
||||
# TODO(mdan): Move to a standalone file.
|
||||
class EntityInfo(
|
||||
collections.namedtuple(
|
||||
'EntityInfo',
|
||||
('source_code', 'source_file', 'future_features', 'namespace'))):
|
||||
('name', 'source_code', 'source_file', 'future_features', 'namespace'))
|
||||
):
|
||||
"""Contains information about a Python entity.
|
||||
|
||||
Immutable.
|
||||
@ -57,6 +63,7 @@ class EntityInfo(
|
||||
Examples of entities include functions and classes.
|
||||
|
||||
Attributes:
|
||||
name: The name that identifies this entity.
|
||||
source_code: The entity's source code.
|
||||
source_file: The entity's source file.
|
||||
future_features: Tuple[Text], the future features that this entity was
|
||||
|
@ -31,8 +31,12 @@ class TransformerTest(test.TestCase):
|
||||
|
||||
def _simple_context(self):
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=None, source_file=None, future_features=(), namespace=None)
|
||||
return transformer.Context(entity_info)
|
||||
name='Test_fn',
|
||||
source_code=None,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace=None)
|
||||
return transformer.Context(entity_info, None, None)
|
||||
|
||||
def assertSameAnno(self, first, second, key):
|
||||
self.assertIs(anno.getanno(first, key), anno.getanno(second, key))
|
||||
@ -299,8 +303,12 @@ class CodeGeneratorTest(test.TestCase):
|
||||
|
||||
def _simple_context(self):
|
||||
entity_info = transformer.EntityInfo(
|
||||
source_code=None, source_file=None, future_features=(), namespace=None)
|
||||
return transformer.Context(entity_info)
|
||||
name='test_fn',
|
||||
source_code=None,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace=None)
|
||||
return transformer.Context(entity_info, None, None)
|
||||
|
||||
def test_basic_codegen(self):
|
||||
|
||||
|
419
tensorflow/python/autograph/pyct/transpiler.py
Normal file
419
tensorflow/python/autograph/pyct/transpiler.py
Normal file
@ -0,0 +1,419 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Generic source code transformation infrastructure."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
import types
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import cache
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import loader
|
||||
from tensorflow.python.autograph.pyct import naming
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
from tensorflow.python.autograph.pyct import templates
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||
|
||||
|
||||
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
|
||||
outer_factory_name, closure_vars, factory_args,
|
||||
future_features):
|
||||
"""Wraps an AST into the body of a factory with consistent lexical context.
|
||||
|
||||
The AST is expected to define some symbol with a name given by `entity_name`.
|
||||
|
||||
This mechanism ensures that the resulting transformed entity has lexical
|
||||
scoping identical to that of the source entity, while allowing extra
|
||||
parametrization.
|
||||
|
||||
Two nested factories achieve the following:
|
||||
|
||||
1. The inner factory dynamically creates the entity represented by `nodes`.
|
||||
2. The inner factory is parametrized by a custom set of arguments.
|
||||
3. The inner factory has a closure identical to that of the transformed
|
||||
entity.
|
||||
4. The inner factory has local variables named like `args`, which `nodes` may
|
||||
use as additional parameters.
|
||||
5. The inner factory returns the variables given by `entity_name`.
|
||||
6. The outer factory is niladic.
|
||||
7. The outer factory has no closure.
|
||||
8. The outer factory creates the necessary lexical scope for the inner
|
||||
factory, so that the loaded code has the given configuration for
|
||||
closure/globals.
|
||||
9. The outer factory returns the inner factory.
|
||||
|
||||
Roughly speaking, the following code is generated:
|
||||
|
||||
from __future__ import future_feature_1
|
||||
from __future__ import future_feature_2
|
||||
...
|
||||
|
||||
def outer_factory():
|
||||
closure_var_1 = None
|
||||
closure_var_2 = None
|
||||
...
|
||||
|
||||
def inner_factory(arg_1, arg_2, ...):
|
||||
<<nodes>>
|
||||
return entity
|
||||
|
||||
return inner_factory
|
||||
|
||||
The lexical scoping is created using dummy symbol declarations which create
|
||||
local fariables in the body of the outer factory, so that the Python parser
|
||||
correctly marks them as free non-global variables upon load (that is, it
|
||||
creates cell slots for each symbol. Thes symbols are initialized with None,
|
||||
but their values are not expected to be used; instead, the caller is expected
|
||||
to replace them with the cells of the source entity. For more details, see:
|
||||
https://docs.python.org/3/reference/executionmodel.html#binding-of-names
|
||||
|
||||
Args:
|
||||
nodes: Tuple[ast.AST], the source code to wrap.
|
||||
entity_name: Union[Text, ast.AST], the name of the principal entity that
|
||||
`nodes` define.
|
||||
inner_factory_name: Text, the name of the inner factory.
|
||||
outer_factory_name: Text, the name of the outer factory.
|
||||
closure_vars: Iterable[Text], names of the closure variables for the inner
|
||||
factory.
|
||||
factory_args: Iterable[Text], names of additional arguments for the
|
||||
inner factory. Useful to configure variables that the converted code can
|
||||
use. Typically, these are modules.
|
||||
future_features: Iterable[Text], names of future statements to associate the
|
||||
code with.
|
||||
|
||||
Returns:
|
||||
ast.AST
|
||||
"""
|
||||
dummy_closure_defs = []
|
||||
for var_name in closure_vars:
|
||||
template = """
|
||||
var_name = None
|
||||
"""
|
||||
dummy_closure_defs.extend(templates.replace(template, var_name=var_name))
|
||||
|
||||
if future_features:
|
||||
future_imports = gast.ImportFrom(
|
||||
module='__future__',
|
||||
names=[gast.alias(name=name, asname=None) for name in future_features],
|
||||
level=0)
|
||||
else:
|
||||
future_imports = []
|
||||
|
||||
factory_args = [
|
||||
gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None)
|
||||
for name in factory_args
|
||||
]
|
||||
|
||||
template = """
|
||||
future_imports
|
||||
def outer_factory_name():
|
||||
dummy_closure_defs
|
||||
def inner_factory_name(factory_args):
|
||||
entity_defs
|
||||
return entity_name
|
||||
return inner_factory_name
|
||||
"""
|
||||
return templates.replace(
|
||||
template,
|
||||
dummy_closure_defs=dummy_closure_defs,
|
||||
entity_defs=nodes,
|
||||
entity_name=entity_name,
|
||||
factory_args=factory_args,
|
||||
future_imports=future_imports,
|
||||
inner_factory_name=inner_factory_name,
|
||||
outer_factory_name=outer_factory_name)
|
||||
|
||||
|
||||
class _TransformedFnFactory(object):
|
||||
"""Helper object that wraps a transformed function factory."""
|
||||
|
||||
def __init__(self, name, freevars, extra_locals):
|
||||
"""Creates a new factory for a transformed function.
|
||||
|
||||
Args:
|
||||
name: The function name.
|
||||
freevars: The list of non-global free variables for the function.
|
||||
extra_locals: Dict[Text, Any], names and values for custom variables that
|
||||
are accessible to the generated code as local variables.
|
||||
"""
|
||||
self._name = name
|
||||
self._freevars = freevars
|
||||
self._extra_locals = extra_locals
|
||||
|
||||
self._unbound_factory = None
|
||||
self.module = None
|
||||
self.source_map = None
|
||||
|
||||
def create(self,
|
||||
nodes,
|
||||
namer,
|
||||
inner_factory_name='inner_factory',
|
||||
outer_factory_name='outer_factory',
|
||||
future_features=()):
|
||||
"""Initializes a transformed function."""
|
||||
if self._unbound_factory is not None:
|
||||
raise ValueError('double initialization; create a new object instead')
|
||||
|
||||
inner_factory_name = namer.new_symbol(inner_factory_name, ())
|
||||
outer_factory_name = namer.new_symbol(outer_factory_name, ())
|
||||
nodes = _wrap_into_factory(nodes, self._name, inner_factory_name,
|
||||
outer_factory_name, self._freevars,
|
||||
self._extra_locals.keys(), future_features)
|
||||
|
||||
module, _, source_map = loader.load_ast(
|
||||
nodes, include_source_map=True)
|
||||
outer_factory = getattr(module, outer_factory_name)
|
||||
self._unbound_factory = outer_factory()
|
||||
self.module = module
|
||||
self.source_map = source_map
|
||||
|
||||
def instantiate(self,
|
||||
globals_,
|
||||
closure,
|
||||
defaults=None,
|
||||
kwdefaults=None):
|
||||
"""Creates a new instance of the transformed function."""
|
||||
if self._unbound_factory is None:
|
||||
raise ValueError('call create first')
|
||||
|
||||
factory_code = self._unbound_factory.__code__
|
||||
factory_freevars = factory_code.co_freevars
|
||||
closure_map = dict(zip(self._freevars, closure))
|
||||
factory_closure = tuple(
|
||||
closure_map[name] for name in factory_code.co_freevars)
|
||||
if len(factory_closure) != len(closure):
|
||||
raise ValueError(
|
||||
'closure mismatch, requested {}, but source function had {}'.format(
|
||||
self._freevars, factory_freevars))
|
||||
|
||||
bound_factory = types.FunctionType(
|
||||
code=factory_code,
|
||||
globals=globals_,
|
||||
name=self._name,
|
||||
argdefs=(),
|
||||
closure=factory_closure)
|
||||
|
||||
# The lint override is a false positive.
|
||||
transformed_entity = bound_factory(**self._extra_locals) # pylint:disable=not-callable
|
||||
|
||||
if defaults:
|
||||
transformed_entity.__defaults__ = defaults
|
||||
if kwdefaults:
|
||||
transformed_entity.__kwdefaults__ = kwdefaults
|
||||
|
||||
return transformed_entity
|
||||
|
||||
|
||||
class FunctionTranspiler(object):
|
||||
"""A generic source-to-source transpiler for Python functions.
|
||||
|
||||
Its interface `transform_function` API offers a function-in, function-out
|
||||
interface. Internally, it takes care of parsing, caching and variable binding.
|
||||
|
||||
Users typically subclass this, customizing the transform_ast method.
|
||||
|
||||
Usually, instances of this class are singletons, since each instance manages
|
||||
its own cache. The caching subkey allows managing multiple types of
|
||||
transformation.
|
||||
|
||||
Example:
|
||||
|
||||
class MyTransformer(FunctionTranspiler):
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
node = <<transform node, usually using ast.NodeTransformer classes>>
|
||||
return node
|
||||
|
||||
transformer = MyTransfomer()
|
||||
|
||||
new_f, module, source_map = transformer.transform_function(f, ...)
|
||||
# new_f is a function with signature identical to f
|
||||
|
||||
The transformed function has access to the same namespace as the original
|
||||
function. To allow access to internal APIs, users may inject additional
|
||||
symbols though the `extra_locals` argument of `transform_function`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._cache_lock = threading.RLock()
|
||||
self._cache = cache.CodeObjectCache()
|
||||
|
||||
def transform_ast(self, node, user_context):
|
||||
"""Performs an actual transformation of a function's AST.
|
||||
|
||||
Subclasses must implement this method. They must not call it.
|
||||
|
||||
The method receives the original AST and generates code according to the
|
||||
AST that the method returns. For functions, the returned AST is expected to
|
||||
contain a function with the exact same arguments and closure. The resulting
|
||||
function will receive the globals, closure and argument defaults of the
|
||||
input function.
|
||||
|
||||
Args:
|
||||
node: One or more ast.AST nodes representing the AST to be transformed.
|
||||
user_context: The same value that the caller passed to
|
||||
`transform_function`.
|
||||
"""
|
||||
raise NotImplementedError('subclasses must override this')
|
||||
|
||||
def get_transformed_name(self, node):
|
||||
"""Returns a name for the output function. Subclasses may override this."""
|
||||
if isinstance(node, gast.Lambda):
|
||||
return 'lam'
|
||||
elif isinstance(node, gast.FunctionDef):
|
||||
# Note that we need to rename the function, to avoid any namespace
|
||||
# clashes.
|
||||
return node.name
|
||||
else:
|
||||
raise ValueError('Unknown node type {}'.format(node))
|
||||
|
||||
def _erase_arg_defaults(self, node):
|
||||
"""Erase argde fault expressions, which would otherwise be unbound."""
|
||||
args = node.args
|
||||
for i in range(len(args.defaults)):
|
||||
args.defaults[i] = parser.parse_expression('None')
|
||||
for i, d in enumerate(args.kw_defaults):
|
||||
if d is not None:
|
||||
args.kw_defaults[i] = parser.parse_expression('None')
|
||||
return node
|
||||
|
||||
def _transform_function(self, fn, user_context):
|
||||
"""Performs source code transformation on a function."""
|
||||
future_features = inspect_utils.getfutureimports(fn)
|
||||
node, source = parser.parse_entity(fn, future_features=future_features)
|
||||
logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)
|
||||
|
||||
# In general, the output of inspect.getsource is inexact for lambdas
|
||||
# because it uses regex matching to adjust the exact location around
|
||||
# the line number that CPython records. Then, the entire containing line
|
||||
# is returned, which we may have trouble disambiguating.
|
||||
# For example:
|
||||
# x, y = lambda: 1, lambda: 2
|
||||
is_lambda = fn.__name__ == '<lambda>'
|
||||
if is_lambda:
|
||||
nodes = ast_util.find_matching_definitions(node, fn)
|
||||
if len(nodes) != 1:
|
||||
raise ValueError(
|
||||
'Unable to identify source code of lambda function {}.'
|
||||
' It was defined in this code:\n'
|
||||
'{}\n'
|
||||
'This code must contain a single distinguishable lambda.'
|
||||
' To avoid this problem, define each lambda in a separate'
|
||||
' expression.'.format(fn, source))
|
||||
node, = nodes
|
||||
|
||||
origin_info.resolve_entity(node, source, fn)
|
||||
|
||||
namespace = inspect_utils.getnamespace(fn)
|
||||
namer = naming.Namer(namespace)
|
||||
new_name = namer.new_symbol(self.get_transformed_name(node), ())
|
||||
entity_info = transformer.EntityInfo(
|
||||
name=new_name,
|
||||
source_code=source,
|
||||
source_file='<fragment>',
|
||||
future_features=future_features,
|
||||
namespace=namespace)
|
||||
context = transformer.Context(entity_info, namer, user_context)
|
||||
|
||||
node = self._erase_arg_defaults(node)
|
||||
node = self.transform_ast(node, context)
|
||||
|
||||
if is_lambda:
|
||||
node = gast.Assign(
|
||||
targets=[
|
||||
gast.Name(
|
||||
new_name,
|
||||
ctx=gast.Store(),
|
||||
annotation=None,
|
||||
type_comment=None)
|
||||
],
|
||||
value=node)
|
||||
else:
|
||||
node.name = new_name
|
||||
|
||||
return node, context
|
||||
|
||||
def _cached_factory(self, fn, cache_subkey):
|
||||
cached_factory = self._cache[fn][cache_subkey]
|
||||
logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
|
||||
cached_factory)
|
||||
return cached_factory
|
||||
|
||||
def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals):
|
||||
"""Returns the transformed function factory for a given input."""
|
||||
if self._cache.has(fn, cache_subkey):
|
||||
return self._cached_factory(fn, cache_subkey)
|
||||
|
||||
with self._cache_lock:
|
||||
# Check again under lock.
|
||||
if self._cache.has(fn, cache_subkey):
|
||||
return self._cached_factory(fn, cache_subkey)
|
||||
|
||||
logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
|
||||
nodes, ctx = self._transform_function(fn, user_context)
|
||||
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))
|
||||
|
||||
factory = _TransformedFnFactory(
|
||||
ctx.info.name, fn.__code__.co_freevars, extra_locals)
|
||||
factory.create(nodes, ctx.namer, future_features=ctx.info.future_features)
|
||||
self._cache[fn][cache_subkey] = factory
|
||||
return factory
|
||||
|
||||
def transform_function(self, fn, caching_subkey, user_context, extra_locals):
|
||||
"""Transforms a function.
|
||||
|
||||
The `caching_subkey` argument allows mapping each function to multiple
|
||||
outputs in the cache. This is useful for instance when transformers
|
||||
can generate multiple variants of output code, typically as a result of
|
||||
different transformation flags.
|
||||
|
||||
Args:
|
||||
fn: A function or lambda.
|
||||
caching_subkey: Used for caching. Calls made for functions with the same
|
||||
code object and caching_subkey will return a cached instance on
|
||||
subsequent invocations. Using a constant will create unique per-function
|
||||
entries.
|
||||
user_context: An opaque object (may be none) that is forwarded to
|
||||
transform_ast.
|
||||
extra_locals: A Dict[Text, Any] containing additional variables to make
|
||||
available to the transformed code. These will be visible as local
|
||||
variables.
|
||||
Returns:
|
||||
A tuple:
|
||||
* A function or lambda with the same signature and closure as `fn`
|
||||
* The temporary module into which the transformed function was loaded
|
||||
* The source map as a
|
||||
Dict[origin_info.LineLocation, origin_info.OriginInfo]
|
||||
|
||||
"""
|
||||
factory = self._transformed_factory(fn, caching_subkey, user_context,
|
||||
extra_locals)
|
||||
|
||||
transformed_fn = factory.instantiate(
|
||||
globals_=fn.__globals__,
|
||||
closure=fn.__closure__ or (),
|
||||
defaults=fn.__defaults__,
|
||||
kwdefaults=getattr(fn, '__kwdefaults__', None))
|
||||
return transformed_fn, factory.module, factory.source_map
|
249
tensorflow/python/autograph/pyct/transpiler_test.py
Normal file
249
tensorflow/python/autograph/pyct/transpiler_test.py
Normal file
@ -0,0 +1,249 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for transpiler module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import threading
|
||||
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.pyct import transformer
|
||||
from tensorflow.python.autograph.pyct import transpiler
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class FlipSignTransformer(transformer.Base):
|
||||
|
||||
def visit_BinOp(self, node):
|
||||
if isinstance(node.op, gast.Add):
|
||||
node.op = gast.Sub()
|
||||
return self.generic_visit(node)
|
||||
|
||||
|
||||
class TestTranspiler(transpiler.FunctionTranspiler):
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
return FlipSignTransformer(ctx).visit(node)
|
||||
|
||||
|
||||
global_var_for_test_global = 1
|
||||
global_var_for_test_namespace_collisions = object()
|
||||
|
||||
|
||||
class FunctionTranspilerTest(test.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
def f(a):
|
||||
return a + 1
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 0)
|
||||
|
||||
def test_closure(self):
|
||||
b = 1
|
||||
|
||||
def f(a):
|
||||
return a + b
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 0)
|
||||
b = 2
|
||||
self.assertEqual(f(1), -1)
|
||||
|
||||
def test_global(self):
|
||||
def f(a):
|
||||
return a + global_var_for_test_global
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
global global_var_for_test_global
|
||||
global_var_for_test_global = 1
|
||||
self.assertEqual(f(1), 0)
|
||||
global_var_for_test_global = 2
|
||||
self.assertEqual(f(1), -1)
|
||||
|
||||
def test_defaults(self):
|
||||
b = 2
|
||||
c = 1
|
||||
|
||||
def f(a, d=c + 1):
|
||||
return a + b + d
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 1 - 2 - 2)
|
||||
c = 0
|
||||
self.assertEqual(f(1), 1 - 2 - 2) # Defaults are evaluated at definition.
|
||||
b = 1
|
||||
self.assertEqual(f(1), 1 - 2 - 1)
|
||||
|
||||
def test_call_tree(self):
|
||||
|
||||
def g(a):
|
||||
return a + 1
|
||||
|
||||
def f(a):
|
||||
return g(a) + 1
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 1 - 1 + 1) # Only f is converted.
|
||||
|
||||
def test_lambda(self):
|
||||
b = 2
|
||||
f = lambda x: (b + (x if x > 0 else -x))
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 2 - 1)
|
||||
self.assertEqual(f(-1), 2 - 1)
|
||||
|
||||
b = 3
|
||||
|
||||
self.assertEqual(f(1), 3 - 1)
|
||||
self.assertEqual(f(-1), 3 - 1)
|
||||
|
||||
def test_multiple_lambdas(self):
|
||||
a, b = 1, 2
|
||||
# This can be disambiguated by the argument names.
|
||||
f, _ = (lambda x: a + x, lambda y: b * y)
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 1 - 1)
|
||||
|
||||
def test_multiple_lambdas_indistinguishable_definitions(self):
|
||||
a, b = 1, 2
|
||||
f, _ = (lambda x: a * x, lambda x: b * x)
|
||||
|
||||
tr = TestTranspiler()
|
||||
with self.assertRaises(ValueError):
|
||||
tr.transform_function(f, object(), None, {})
|
||||
|
||||
def test_lambda_code_with_removable_garbage(self):
|
||||
# pylint:disable=g-long-lambda
|
||||
f = ( # intentional wrap
|
||||
lambda x: (
|
||||
x # intentional wrap
|
||||
+ 1),)[0]
|
||||
# pylint:enable=g-long-lambda
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 1 - 1)
|
||||
|
||||
def test_nested_functions(self):
|
||||
b = 2
|
||||
|
||||
def f(x):
|
||||
|
||||
def g(x):
|
||||
return b + x
|
||||
|
||||
return g(x)
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 2 - 1)
|
||||
|
||||
def test_nested_lambda(self):
|
||||
b = 2
|
||||
|
||||
def f(x):
|
||||
g = lambda x: b + x
|
||||
return g(x)
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
|
||||
self.assertEqual(f(1), 2 - 1)
|
||||
|
||||
def test_concurrency(self):
|
||||
|
||||
def f():
|
||||
pass
|
||||
|
||||
outputs = []
|
||||
|
||||
tr = TestTranspiler()
|
||||
cache_key = object()
|
||||
def conversion_thread():
|
||||
_, mod, _ = tr.transform_function(f, cache_key, None, {})
|
||||
outputs.append(mod.__name__)
|
||||
|
||||
threads = tuple(
|
||||
threading.Thread(target=conversion_thread) for _ in range(10))
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Races would potentially create multiple functions / modules
|
||||
# (non-deterministically, but with high likelihood).
|
||||
self.assertEqual(len(set(outputs)), 1)
|
||||
|
||||
def test_reentrance(self):
|
||||
|
||||
def test_fn():
|
||||
return 1 + 1
|
||||
|
||||
class ReentrantTranspiler(transpiler.FunctionTranspiler):
|
||||
|
||||
def __init__(self):
|
||||
super(ReentrantTranspiler, self).__init__()
|
||||
self._recursion_depth = 0
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
self._recursion_depth += 1
|
||||
if self._recursion_depth < 2:
|
||||
self.transform_function(test_fn, object(), None, {})
|
||||
return FlipSignTransformer(ctx).visit(node)
|
||||
|
||||
tr = ReentrantTranspiler()
|
||||
|
||||
f, _, _ = tr.transform_function(test_fn, object(), None, {})
|
||||
self.assertEqual(f(), 0)
|
||||
|
||||
def test_namespace_collisions_avoided(self):
|
||||
|
||||
class TestClass(object):
|
||||
|
||||
def global_var_for_test_namespace_collisions(self):
|
||||
return global_var_for_test_namespace_collisions
|
||||
|
||||
tr = TestTranspiler()
|
||||
obj = TestClass()
|
||||
|
||||
f, _, _ = tr.transform_function(
|
||||
obj.global_var_for_test_namespace_collisions, object(), None, {})
|
||||
self.assertIs(f(obj), global_var_for_test_namespace_collisions)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -52,13 +52,6 @@ def alias_tensors(*args):
|
||||
raise ValueError('at least one argument required')
|
||||
|
||||
|
||||
def capitalize_initial(s):
|
||||
"""Capitalizes the initial of a string only."""
|
||||
if s:
|
||||
return s[0].upper() + s[1:]
|
||||
return s
|
||||
|
||||
|
||||
def get_range_len(start, limit, delta):
|
||||
dist = ops.convert_to_tensor(limit - start)
|
||||
unadjusted_len = dist // delta
|
||||
|
@ -29,15 +29,6 @@ from tensorflow.python.platform import test
|
||||
|
||||
class MiscTest(test.TestCase):
|
||||
|
||||
def test_capitalize_initial(self):
|
||||
self.assertEqual('', misc.capitalize_initial(''))
|
||||
self.assertEqual('A', misc.capitalize_initial('A'))
|
||||
self.assertEqual('Ab', misc.capitalize_initial('Ab'))
|
||||
self.assertEqual('AbC', misc.capitalize_initial('AbC'))
|
||||
self.assertEqual('A', misc.capitalize_initial('a'))
|
||||
self.assertEqual('Ab', misc.capitalize_initial('ab'))
|
||||
self.assertEqual('AbC', misc.capitalize_initial('abC'))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def test_alias_single_tensor(self):
|
||||
a = constant(1)
|
||||
|
@ -198,11 +198,12 @@ def _live_tensors(f, attr_name="inputs"):
|
||||
"""
|
||||
node, _ = parser.parse_entity(f, ())
|
||||
entity_info = transformer.EntityInfo(
|
||||
name=f.__name__,
|
||||
source_code=None,
|
||||
source_file=None,
|
||||
future_features=(),
|
||||
namespace=sys.modules[f.__module__].__dict__)
|
||||
ctx = transformer.Context(entity_info)
|
||||
ctx = transformer.Context(entity_info, None, None)
|
||||
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
|
Loading…
Reference in New Issue
Block a user