Internal cleanup: Move the bulk of the source code transformation infrastructure into the generic pyct module.
PiperOrigin-RevId: 305135067 Change-Id: Ifb84546c35a603942fd864769e7320a7ae95da3b
This commit is contained in:
parent
fc94412b39
commit
ff551c9f20
@ -19,7 +19,6 @@ filegroup(
|
|||||||
py_library(
|
py_library(
|
||||||
name = "converters",
|
name = "converters",
|
||||||
srcs = [
|
srcs = [
|
||||||
"arg_defaults.py",
|
|
||||||
"asserts.py",
|
"asserts.py",
|
||||||
"break_statements.py",
|
"break_statements.py",
|
||||||
"call_trees.py",
|
"call_trees.py",
|
||||||
@ -48,18 +47,6 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
py_test(
|
|
||||||
name = "arg_defaults_test",
|
|
||||||
srcs = ["arg_defaults_test.py"],
|
|
||||||
python_version = "PY3",
|
|
||||||
srcs_version = "PY2AND3",
|
|
||||||
deps = [
|
|
||||||
":converters",
|
|
||||||
"//tensorflow/python:client_testlib",
|
|
||||||
"//tensorflow/python/autograph/core:test_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "asserts_test",
|
name = "asserts_test",
|
||||||
srcs = ["asserts_test.py"],
|
srcs = ["asserts_test.py"],
|
||||||
|
@ -1,105 +0,0 @@
|
|||||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Modifies the signature to allow resolving the value of default arguments.
|
|
||||||
|
|
||||||
Normally, function symbols are captured either in a function's globals or
|
|
||||||
closure. This is not true for default arguments, which are evaluated when the
|
|
||||||
function is defined:
|
|
||||||
|
|
||||||
b = 1
|
|
||||||
c = 2
|
|
||||||
def f(a=b + 1):
|
|
||||||
return a + c
|
|
||||||
|
|
||||||
In the above example, the namespace of the function would include `c = 2` but
|
|
||||||
not `b`.
|
|
||||||
|
|
||||||
If we were to naively generate a new function:
|
|
||||||
|
|
||||||
def new_f(a=b + 1):
|
|
||||||
return a + c
|
|
||||||
|
|
||||||
The generated code would fail to load unless we exposed a symbol `b`. Capturing
|
|
||||||
the closure of such an expression is difficult. However, we can capture the
|
|
||||||
default value of argument `a` with relative ease.
|
|
||||||
|
|
||||||
This converter replaces all default argument expressions with a constant so
|
|
||||||
that they don't cause loading to fail. This requires that the default values
|
|
||||||
are reset after loading the transformed function:
|
|
||||||
|
|
||||||
def new_f(a=None):
|
|
||||||
return a + c
|
|
||||||
|
|
||||||
# ... later, after new_f was loaded ...
|
|
||||||
new_f.__defaults__ = f.__defaults__
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from tensorflow.python.autograph.core import converter
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
|
|
||||||
|
|
||||||
class _Function(object):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ArgDefaultsTransformer(converter.Base):
|
|
||||||
"""Transforms top level argument defaults."""
|
|
||||||
|
|
||||||
def visit_Lambda(self, node):
|
|
||||||
self.state[_Function].enter()
|
|
||||||
node.args = self.visit(node.args)
|
|
||||||
# Only the top level function is modified - no need to visit the children.
|
|
||||||
self.state[_Function].exit()
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_FunctionDef(self, node):
|
|
||||||
self.state[_Function].enter()
|
|
||||||
node.args = self.visit(node.args)
|
|
||||||
# Only the top level function is modified - no need to visit the children.
|
|
||||||
self.state[_Function].exit()
|
|
||||||
return node
|
|
||||||
|
|
||||||
def visit_arguments(self, node):
|
|
||||||
if self.state[_Function].level > 2:
|
|
||||||
return node
|
|
||||||
|
|
||||||
for i in range(len(node.defaults)):
|
|
||||||
node.defaults[i] = parser.parse_expression('None')
|
|
||||||
|
|
||||||
for i, d in enumerate(node.kw_defaults):
|
|
||||||
if d is not None:
|
|
||||||
node.kw_defaults[i] = parser.parse_expression('None')
|
|
||||||
|
|
||||||
# Only the top level function is modified - no need to visit the children.
|
|
||||||
return node
|
|
||||||
|
|
||||||
|
|
||||||
def transform(node, ctx):
|
|
||||||
"""Transform function call to the compiled counterparts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node: AST
|
|
||||||
ctx: EntityContext
|
|
||||||
Returns:
|
|
||||||
A tuple (node, new_names):
|
|
||||||
node: The transformed AST
|
|
||||||
new_names: set(string), containing any newly-generated names
|
|
||||||
"""
|
|
||||||
return ArgDefaultsTransformer(ctx).visit(node)
|
|
@ -1,108 +0,0 @@
|
|||||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Tests for arg_defaults module."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
from tensorflow.python.autograph.converters import arg_defaults
|
|
||||||
from tensorflow.python.autograph.core import converter_testing
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
from tensorflow.python.platform import test
|
|
||||||
|
|
||||||
|
|
||||||
class ArgDefaultsTransformerTest(converter_testing.TestCase):
|
|
||||||
|
|
||||||
def assertTransformedFirstLineIs(self, node, expected):
|
|
||||||
self.assertEqual(
|
|
||||||
parser.unparse(node, include_encoding_marker=False).split('\n')[0],
|
|
||||||
expected)
|
|
||||||
|
|
||||||
def test_no_args(self):
|
|
||||||
|
|
||||||
def test_fn():
|
|
||||||
pass
|
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
|
||||||
node = arg_defaults.transform(node, ctx)
|
|
||||||
self.assertTransformedFirstLineIs(node, 'def test_fn():')
|
|
||||||
|
|
||||||
def test_no_defaults(self):
|
|
||||||
|
|
||||||
def test_fn(a, b, *c, **e):
|
|
||||||
return a, b, c, e
|
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
|
||||||
node = arg_defaults.transform(node, ctx)
|
|
||||||
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b, *c, **e):')
|
|
||||||
|
|
||||||
# TODO(mdan): Add kwonly-arg tests when PY2 is no longer supported.
|
|
||||||
|
|
||||||
def test_arg_defaults(self):
|
|
||||||
|
|
||||||
def test_fn(a, b=1, c=2):
|
|
||||||
return a, b, c
|
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
|
||||||
node = arg_defaults.transform(node, ctx)
|
|
||||||
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, c=None):')
|
|
||||||
|
|
||||||
def test_arg_defaults_with_vararg(self):
|
|
||||||
|
|
||||||
def test_fn(a, b=1, *c): # pylint: disable=keyword-arg-before-vararg
|
|
||||||
return a, b, c
|
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
|
||||||
node = arg_defaults.transform(node, ctx)
|
|
||||||
self.assertTransformedFirstLineIs(node, 'def test_fn(a, b=None, *c):')
|
|
||||||
|
|
||||||
def test_arg_defaults_ignores_inner_lambda(self):
|
|
||||||
|
|
||||||
def test_fn():
|
|
||||||
return (lambda x=7: x)()
|
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
|
||||||
node = arg_defaults.transform(node, ctx)
|
|
||||||
with self.converted(test_fn, arg_defaults, {}) as result:
|
|
||||||
self.assertEqual(test_fn(), result.test_fn())
|
|
||||||
|
|
||||||
def test_arg_defaults_ignores_inner_function(self):
|
|
||||||
|
|
||||||
def test_fn():
|
|
||||||
def inner_fn(a=3):
|
|
||||||
return a
|
|
||||||
return inner_fn()
|
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
|
||||||
node = arg_defaults.transform(node, ctx)
|
|
||||||
with self.converted(test_fn, arg_defaults, {}) as result:
|
|
||||||
self.assertEqual(test_fn(), result.test_fn())
|
|
||||||
|
|
||||||
def test_arg_defaults_ignores_inner_function_returned(self):
|
|
||||||
|
|
||||||
def test_fn():
|
|
||||||
def inner_fn(a=3):
|
|
||||||
return a
|
|
||||||
return inner_fn
|
|
||||||
|
|
||||||
node, ctx = self.prepare(test_fn, {})
|
|
||||||
node = arg_defaults.transform(node, ctx)
|
|
||||||
with self.converted(test_fn, arg_defaults, {}) as result:
|
|
||||||
self.assertEqual(test_fn()(), result.test_fn()())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
test.main()
|
|
@ -191,7 +191,7 @@ class CallTreeTransformer(converter.Base):
|
|||||||
return node
|
return node
|
||||||
|
|
||||||
if (full_name == 'print' and
|
if (full_name == 'print' and
|
||||||
not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
not self.ctx.user.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
template = """
|
template = """
|
||||||
|
@ -54,8 +54,8 @@ class FunctionTransformer(converter.Base):
|
|||||||
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
|
# ControlStatusCtx(autograph=ENABLED) when user_requested is True. See
|
||||||
# function_wrappers.py.
|
# function_wrappers.py.
|
||||||
if fn_scope.level == 2:
|
if fn_scope.level == 2:
|
||||||
return self.ctx.program.options
|
return self.ctx.user.options
|
||||||
return self.ctx.program.options.call_options()
|
return self.ctx.user.options.call_options()
|
||||||
|
|
||||||
def visit_Lambda(self, node):
|
def visit_Lambda(self, node):
|
||||||
with self.state[_Function] as fn_scope:
|
with self.state[_Function] as fn_scope:
|
||||||
|
@ -53,7 +53,7 @@ class LogicalExpressionTransformer(converter.Base):
|
|||||||
op_type = type(operator)
|
op_type = type(operator)
|
||||||
if op_type in LOGICAL_OPERATORS:
|
if op_type in LOGICAL_OPERATORS:
|
||||||
return LOGICAL_OPERATORS[op_type]
|
return LOGICAL_OPERATORS[op_type]
|
||||||
if self.ctx.program.options.uses(converter.Feature.EQUALITY_OPERATORS):
|
if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
|
||||||
if op_type in EQUALITY_OPERATORS:
|
if op_type in EQUALITY_OPERATORS:
|
||||||
return EQUALITY_OPERATORS[op_type]
|
return EQUALITY_OPERATORS[op_type]
|
||||||
return None
|
return None
|
||||||
@ -83,7 +83,7 @@ class LogicalExpressionTransformer(converter.Base):
|
|||||||
def visit_Compare(self, node):
|
def visit_Compare(self, node):
|
||||||
node = self.generic_visit(node)
|
node = self.generic_visit(node)
|
||||||
|
|
||||||
if (not self.ctx.program.options.uses(
|
if (not self.ctx.user.options.uses(
|
||||||
converter.Feature.EQUALITY_OPERATORS)):
|
converter.Feature.EQUALITY_OPERATORS)):
|
||||||
return node
|
return node
|
||||||
|
|
||||||
|
@ -253,25 +253,6 @@ class ProgramContext(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class EntityContext(transformer.Context):
|
|
||||||
"""Tracks the conversion of a single entity.
|
|
||||||
|
|
||||||
This object is mutable, and is updated during conversion. Not thread safe.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
namer: Namer
|
|
||||||
info: transformer.EntityInfo
|
|
||||||
program: ProgramContext,
|
|
||||||
targe_name: Text
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, namer, entity_info, program_ctx, target_name=None):
|
|
||||||
super(EntityContext, self).__init__(entity_info)
|
|
||||||
self.namer = namer
|
|
||||||
self.program = program_ctx
|
|
||||||
self.target_name = target_name
|
|
||||||
|
|
||||||
|
|
||||||
class Base(transformer.Base):
|
class Base(transformer.Base):
|
||||||
"""All converters should inherit from this class.
|
"""All converters should inherit from this class.
|
||||||
|
|
||||||
|
@ -168,12 +168,12 @@ class TestCase(test.TestCase):
|
|||||||
options=converter.ConversionOptions(recursive=recursive),
|
options=converter.ConversionOptions(recursive=recursive),
|
||||||
autograph_module=None)
|
autograph_module=None)
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
|
name=test_fn.__name__,
|
||||||
source_code=source,
|
source_code=source,
|
||||||
source_file='<fragment>',
|
source_file='<fragment>',
|
||||||
future_features=future_features,
|
future_features=future_features,
|
||||||
namespace=namespace)
|
namespace=namespace)
|
||||||
ctx = converter.EntityContext(
|
ctx = transformer.Context(entity_info, namer, program_ctx)
|
||||||
namer, entity_info, program_ctx, 'test_fn')
|
|
||||||
origin_info.resolve_entity(node, source, test_fn)
|
origin_info.resolve_entity(node, source, test_fn)
|
||||||
node = converter.standard_analysis(node, ctx, is_initial=True)
|
node = converter.standard_analysis(node, ctx, is_initial=True)
|
||||||
return node, ctx
|
return node, ctx
|
||||||
|
@ -29,10 +29,7 @@ import sys
|
|||||||
import textwrap
|
import textwrap
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
# pylint:disable=g-bad-import-order
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
# pylint:enable=g-bad-import-order
|
|
||||||
|
|
||||||
from tensorflow.python.autograph.core import ag_ctx
|
from tensorflow.python.autograph.core import ag_ctx
|
||||||
from tensorflow.python.autograph.core import converter
|
from tensorflow.python.autograph.core import converter
|
||||||
@ -668,7 +665,7 @@ def to_graph(entity, recursive=True, experimental_optional_features=None):
|
|||||||
user_requested=True,
|
user_requested=True,
|
||||||
optional_features=experimental_optional_features),
|
optional_features=experimental_optional_features),
|
||||||
autograph_module=tf_inspect.getmodule(to_graph))
|
autograph_module=tf_inspect.getmodule(to_graph))
|
||||||
return conversion.convert(entity, program_ctx)
|
return autograph_artifact(conversion.convert(entity, program_ctx))
|
||||||
except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
|
except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
|
||||||
logging.error(1, 'Error converting %s', entity, exc_info=True)
|
logging.error(1, 'Error converting %s', entity, exc_info=True)
|
||||||
raise ConversionError('converting {}: {}: {}'.format(
|
raise ConversionError('converting {}: {}: {}'.format(
|
||||||
|
@ -18,21 +18,12 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
|
||||||
import functools
|
import functools
|
||||||
import imp
|
import imp
|
||||||
import inspect
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import types
|
|
||||||
import unittest
|
import unittest
|
||||||
import weakref
|
|
||||||
|
|
||||||
import gast
|
|
||||||
|
|
||||||
from tensorflow.python.autograph import operators
|
from tensorflow.python.autograph import operators
|
||||||
from tensorflow.python.autograph import utils
|
from tensorflow.python.autograph import utils
|
||||||
from tensorflow.python.autograph.converters import arg_defaults
|
|
||||||
from tensorflow.python.autograph.converters import asserts
|
from tensorflow.python.autograph.converters import asserts
|
||||||
from tensorflow.python.autograph.converters import break_statements
|
from tensorflow.python.autograph.converters import break_statements
|
||||||
from tensorflow.python.autograph.converters import call_trees
|
from tensorflow.python.autograph.converters import call_trees
|
||||||
@ -50,304 +41,70 @@ from tensorflow.python.autograph.core import converter
|
|||||||
from tensorflow.python.autograph.core import function_wrappers
|
from tensorflow.python.autograph.core import function_wrappers
|
||||||
from tensorflow.python.autograph.core import unsupported_features_checker
|
from tensorflow.python.autograph.core import unsupported_features_checker
|
||||||
from tensorflow.python.autograph.lang import special_functions
|
from tensorflow.python.autograph.lang import special_functions
|
||||||
from tensorflow.python.autograph.pyct import ast_util
|
from tensorflow.python.autograph.pyct import cache
|
||||||
from tensorflow.python.autograph.pyct import inspect_utils
|
from tensorflow.python.autograph.pyct import inspect_utils
|
||||||
from tensorflow.python.autograph.pyct import loader
|
from tensorflow.python.autograph.pyct import transpiler
|
||||||
from tensorflow.python.autograph.pyct import naming
|
|
||||||
from tensorflow.python.autograph.pyct import origin_info
|
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
from tensorflow.python.autograph.pyct import pretty_printer
|
|
||||||
from tensorflow.python.autograph.pyct import templates
|
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
|
||||||
from tensorflow.python.autograph.utils import ag_logging as logging
|
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
|
|
||||||
class _ConvertedEntityFactoryInfo(
|
class AutoGraphTranspiler(transpiler.FunctionTranspiler):
|
||||||
collections.namedtuple(
|
|
||||||
'_ConvertedEntityFactoryInfo',
|
|
||||||
('module_name', 'converted_name', 'factory_factory_name', 'source_map'))
|
|
||||||
):
|
|
||||||
"""Holds metadata about a converted entity stored as a dynamic factory.
|
|
||||||
|
|
||||||
The dynamic factory is assumed to be created by _wrap_into_dynamic_factory,
|
def get_transformed_name(self, node):
|
||||||
be named `factory_factory_name` and located inside the module named as
|
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
|
||||||
`module_name`.
|
|
||||||
|
|
||||||
Attributes:
|
def transform_ast(self, node, ctx):
|
||||||
module_name: Text, the name of the module containing the entity.
|
# TODO(mdan): Insert list_comprehensions somewhere.
|
||||||
converted_name: Text, the name of the converted entity.
|
unsupported_features_checker.verify(node)
|
||||||
factory_factory_name: Text, the name of the dynamic factory.
|
|
||||||
source_map: Dict.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __str__(self):
|
node = converter.standard_analysis(node, ctx, is_initial=True)
|
||||||
return '_ConvertedEntityFactoryInfo({} in {})'.format(
|
node = converter.apply_(node, ctx, functions)
|
||||||
self.converted_name, self.module_name)
|
node = converter.apply_(node, ctx, directives)
|
||||||
|
node = converter.apply_(node, ctx, break_statements)
|
||||||
def get_module(self):
|
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
|
||||||
return sys.modules[self.module_name]
|
node = converter.apply_(node, ctx, asserts)
|
||||||
|
# Note: sequencing continue canonicalization before for loop one avoids
|
||||||
def get_factory(self):
|
# dealing with the extra loop increment operation that the for
|
||||||
assert self.module_name in sys.modules
|
# canonicalization creates.
|
||||||
factory_factory = getattr(sys.modules[self.module_name],
|
node = converter.apply_(node, ctx, continue_statements)
|
||||||
self.factory_factory_name)
|
node = converter.apply_(node, ctx, return_statements)
|
||||||
return factory_factory()
|
if ctx.user.options.uses(converter.Feature.LISTS):
|
||||||
|
node = converter.apply_(node, ctx, lists)
|
||||||
|
node = converter.apply_(node, ctx, slices)
|
||||||
|
node = converter.apply_(node, ctx, call_trees)
|
||||||
|
node = converter.apply_(node, ctx, control_flow)
|
||||||
|
node = converter.apply_(node, ctx, conditional_expressions)
|
||||||
|
node = converter.apply_(node, ctx, logical_expressions)
|
||||||
|
return node
|
||||||
|
|
||||||
|
|
||||||
# TODO(mdan): Add a garbage collection hook for cleaning up modules.
|
_TRANSPILER = AutoGraphTranspiler()
|
||||||
class _FunctionCache(object):
|
_WHITELIST_CACHE = cache.UnboundInstanceCache()
|
||||||
"""A hierarchical cache that uses the converted entity as weak key.
|
|
||||||
|
|
||||||
The keys soft references (i.e. they are discarded when the key is
|
|
||||||
destroyed). The subkeys are normal hashable values.
|
|
||||||
|
|
||||||
This class is generic - see the call site for how the keys and values are
|
|
||||||
defined.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ('_cache',)
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._cache = weakref.WeakKeyDictionary()
|
|
||||||
|
|
||||||
def _get_key(self, entity):
|
|
||||||
raise NotImplementedError('subclasses will override')
|
|
||||||
|
|
||||||
def has(self, entity, subkey):
|
|
||||||
key = self._get_key(entity)
|
|
||||||
if key not in self._cache:
|
|
||||||
return False
|
|
||||||
return subkey in self._cache[key]
|
|
||||||
|
|
||||||
def __getitem__(self, entity):
|
|
||||||
key = self._get_key(entity)
|
|
||||||
if key not in self._cache:
|
|
||||||
# The bucket needs to be initialized to support this usage:
|
|
||||||
# cache[key][subkey] = value
|
|
||||||
self._cache[key] = {}
|
|
||||||
return self._cache[key]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._cache)
|
|
||||||
|
|
||||||
|
|
||||||
class _CodeObjectCache(_FunctionCache):
|
custom_vars = None
|
||||||
"""A function cache based on code objects (i.e., the source code).
|
|
||||||
|
|
||||||
Multiple functions may share the same code object, but they may share the
|
|
||||||
cache because we know they have the exact source code. This properly handles
|
|
||||||
functions defined in a loop, bound methods, etc.
|
|
||||||
|
|
||||||
Falls back to the function object, if it doesn't have a code object.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_key(self, entity):
|
|
||||||
if hasattr(entity, '__code__'):
|
|
||||||
return entity.__code__
|
|
||||||
else:
|
|
||||||
return entity
|
|
||||||
|
|
||||||
|
|
||||||
class _UnboundInstanceCache(_FunctionCache):
|
|
||||||
"""A function cache based on unbound function objects.
|
|
||||||
|
|
||||||
Unlike the _CodeObjectCache, this discriminates between different functions
|
|
||||||
even if they have the same code. This properly handles decorators that may
|
|
||||||
masquerade as various functions. Bound functions are not discriminated by
|
|
||||||
the object they're bound to.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _get_key(self, entity):
|
|
||||||
if inspect.ismethod(entity):
|
|
||||||
return entity.__func__
|
|
||||||
return entity
|
|
||||||
|
|
||||||
|
|
||||||
# Using a re-entrant lock to guard against the unlikely possibility that the
|
|
||||||
# conversion process triggers additional code execution.
|
|
||||||
_CACHE_LOCK = threading.RLock()
|
|
||||||
|
|
||||||
|
|
||||||
_CACHE = _CodeObjectCache()
|
|
||||||
_WHITELIST_CACHE = _UnboundInstanceCache()
|
|
||||||
|
|
||||||
|
|
||||||
# Note: strictly speaking, a simple factory might have been sufficient for
|
|
||||||
# functions. But the double factory approach allows us to control the closure
|
|
||||||
# and globals of the converted code in a cleaner fashion.
|
|
||||||
# TODO(mdan): A simple factory may be sufficient.
|
|
||||||
def _wrap_into_dynamic_factory(nodes, entity_name, factory_factory_name,
|
|
||||||
factory_name, closure_vars, future_features):
|
|
||||||
"""Wraps an AST into the body of a dynamic factory.
|
|
||||||
|
|
||||||
This uses the dynamic factory (factory of factory) pattern to achieve the
|
|
||||||
following:
|
|
||||||
|
|
||||||
1. The inner factory, dynamically creates the entity represented by nodes.
|
|
||||||
2. The entity is parametrized by `ag__`, the internal AutoGraph module.
|
|
||||||
3. The outer factory creates the inner factory with a lexical scope
|
|
||||||
in which `closure_vars` are bound local variables. This in turn allows the
|
|
||||||
caller to control the exact closure (i.e. non-global free variables) for
|
|
||||||
the inner factory.
|
|
||||||
|
|
||||||
The AST is expected to define some symbol named by `entity_name`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
nodes: ast.AST
|
|
||||||
entity_name: Union[Text, ast.AST]
|
|
||||||
factory_factory_name: Text
|
|
||||||
factory_name: Text
|
|
||||||
closure_vars: Iterable[Text]
|
|
||||||
future_features: Iterable[Text], see EntityInfo.future_features.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ast.AST
|
|
||||||
"""
|
|
||||||
if not isinstance(nodes, (list, tuple)):
|
|
||||||
nodes = (nodes,)
|
|
||||||
|
|
||||||
dummy_closure_defs = []
|
|
||||||
for var_name in closure_vars:
|
|
||||||
template = """
|
|
||||||
var_name = None
|
|
||||||
"""
|
|
||||||
dummy_closure_defs.extend(templates.replace(template, var_name=var_name))
|
|
||||||
|
|
||||||
if future_features:
|
|
||||||
future_imports = gast.ImportFrom(
|
|
||||||
module='__future__',
|
|
||||||
names=[gast.alias(name=name, asname=None) for name in future_features],
|
|
||||||
level=0)
|
|
||||||
else:
|
|
||||||
future_imports = []
|
|
||||||
|
|
||||||
# These dummy symbol declarations create local fariables in a function scope,
|
|
||||||
# so that the Python parser correctly marks them as free non-global variables
|
|
||||||
# upon load (that is, it creates cell slots for each symbol). Their values are
|
|
||||||
# not used, as the cells are swapped with the original entity's cells after
|
|
||||||
# the code has been loaded.
|
|
||||||
template = """
|
|
||||||
future_imports
|
|
||||||
def factory_factory_name():
|
|
||||||
dummy_closure_defs
|
|
||||||
def factory_name(ag__, ag_source_map__, ag_module__):
|
|
||||||
entity_defs
|
|
||||||
entity_name.ag_source_map = ag_source_map__
|
|
||||||
entity_name.ag_module = ag_module__
|
|
||||||
entity_name = ag__.autograph_artifact(entity_name)
|
|
||||||
return entity_name
|
|
||||||
return factory_name
|
|
||||||
"""
|
|
||||||
return templates.replace(
|
|
||||||
template,
|
|
||||||
future_imports=future_imports,
|
|
||||||
factory_factory_name=factory_factory_name,
|
|
||||||
factory_name=factory_name,
|
|
||||||
dummy_closure_defs=dummy_closure_defs,
|
|
||||||
entity_defs=nodes,
|
|
||||||
entity_name=entity_name)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
|
|
||||||
"""Returns a (possibly cached) factory for the converted result of entity."""
|
|
||||||
# The cache subkey encompasses any conversion options on which the generated
|
|
||||||
# code may depend.
|
|
||||||
# The cached factory includes the necessary definitions to distinguish
|
|
||||||
# between the global and non-global free variables. For this reason, the
|
|
||||||
# cache subkey includes the names of the free non-globals.
|
|
||||||
subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))
|
|
||||||
|
|
||||||
with _CACHE_LOCK:
|
|
||||||
# The cache values are _ConvertedEntityFactoryInfo objects.
|
|
||||||
if _CACHE.has(entity, subkey):
|
|
||||||
# TODO(mdan): Check whether the module is still loaded.
|
|
||||||
converted_entity_info = _CACHE[entity][subkey]
|
|
||||||
logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity, subkey,
|
|
||||||
converted_entity_info)
|
|
||||||
return converted_entity_info
|
|
||||||
|
|
||||||
logging.log(1, 'Entity %s is not cached for subkey %s', entity, subkey)
|
|
||||||
|
|
||||||
nodes, converted_name, entity_info = convert_entity_to_ast(
|
|
||||||
entity, program_ctx)
|
|
||||||
|
|
||||||
namer = naming.Namer(entity_info.namespace)
|
|
||||||
factory_factory_name = namer.new_symbol('create_converted_entity_factory',
|
|
||||||
())
|
|
||||||
factory_name = namer.new_symbol('create_converted_entity', ())
|
|
||||||
nodes = _wrap_into_dynamic_factory(nodes, converted_name,
|
|
||||||
factory_factory_name, factory_name,
|
|
||||||
free_nonglobal_var_names,
|
|
||||||
entity_info.future_features)
|
|
||||||
|
|
||||||
module, _, source_map = loader.load_ast(nodes, include_source_map=True)
|
|
||||||
module_name = module.__name__
|
|
||||||
|
|
||||||
converted_entity_info = _ConvertedEntityFactoryInfo(
|
|
||||||
module_name=module_name,
|
|
||||||
converted_name=converted_name,
|
|
||||||
factory_factory_name=factory_factory_name,
|
|
||||||
source_map=source_map)
|
|
||||||
_CACHE[entity][subkey] = converted_entity_info
|
|
||||||
return converted_entity_info
|
|
||||||
|
|
||||||
|
|
||||||
def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
|
|
||||||
"""Creates a converted instance and binds it to match original entity."""
|
|
||||||
factory = converted_entity_info.get_factory()
|
|
||||||
|
|
||||||
entity_globals = entity.__globals__
|
|
||||||
entity_closure = entity.__closure__ or ()
|
|
||||||
assert len(entity_closure) == len(free_nonglobal_var_names)
|
|
||||||
|
|
||||||
# Fit the original entity's cells to match the order of factory's cells.
|
|
||||||
original_names_and_cells = dict(zip(free_nonglobal_var_names, entity_closure))
|
|
||||||
new_factory_cells = tuple(
|
|
||||||
original_names_and_cells[name] for name in factory.__code__.co_freevars)
|
|
||||||
|
|
||||||
bound_factory = types.FunctionType(
|
|
||||||
code=factory.__code__,
|
|
||||||
globals=entity_globals,
|
|
||||||
name=factory.__name__,
|
|
||||||
argdefs=(),
|
|
||||||
closure=new_factory_cells)
|
|
||||||
|
|
||||||
# Two other free vars: the internal "ag__" module and the source
|
|
||||||
# map. These are wired via the parameters of the factory.
|
|
||||||
converted_entity = bound_factory( # pylint:disable=not-callable
|
|
||||||
ag_internal, converted_entity_info.source_map,
|
|
||||||
converted_entity_info.get_module())
|
|
||||||
|
|
||||||
# Attach the default argument to the converted function.
|
|
||||||
converted_entity.__defaults__ = entity.__defaults__
|
|
||||||
if hasattr(entity, '__kwdefaults__'):
|
|
||||||
converted_entity.__kwdefaults__ = entity.__kwdefaults__
|
|
||||||
|
|
||||||
return converted_entity
|
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(mdan): Superfluous function, remove.
|
||||||
|
# TODO(mdan): Put these extra fields inside __autograph_info__.
|
||||||
def convert(entity, program_ctx):
|
def convert(entity, program_ctx):
|
||||||
"""Converts an entity into an equivalent entity."""
|
"""Applies AutoGraph to entity."""
|
||||||
|
|
||||||
if not hasattr(entity, '__code__'):
|
if not hasattr(entity, '__code__'):
|
||||||
raise ValueError('Cannot apply autograph to a function that doesn\'t '
|
raise ValueError('Cannot apply autograph to a function that doesn\'t '
|
||||||
'expose a __code__ object. If this is a @tf.function,'
|
'expose a __code__ object. If this is a @tf.function,'
|
||||||
' try passing f.python_function instead.')
|
' try passing f.python_function instead.')
|
||||||
free_nonglobal_var_names = entity.__code__.co_freevars
|
|
||||||
|
|
||||||
for i, name in enumerate(free_nonglobal_var_names):
|
_create_custom_vars(program_ctx)
|
||||||
if (name == 'ag__' and
|
transformed, module, source_map = _TRANSPILER.transform_function(
|
||||||
entity.__closure__[i].cell_contents is not ag_internal):
|
entity, program_ctx.options, program_ctx, custom_vars)
|
||||||
raise ValueError('entity {} uses the reserved symbol "{}"'.format(
|
|
||||||
entity, name))
|
|
||||||
# TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered.
|
|
||||||
|
|
||||||
converted_entity_info = _convert_with_cache(entity, program_ctx,
|
assert not hasattr(transformed, 'ag_module')
|
||||||
free_nonglobal_var_names)
|
assert not hasattr(transformed, 'ag_source_map')
|
||||||
|
transformed.ag_module = module
|
||||||
return _instantiate(entity, converted_entity_info, free_nonglobal_var_names)
|
transformed.ag_source_map = source_map
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
|
||||||
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
||||||
@ -472,58 +229,15 @@ def cache_whitelisted(entity, options):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# TODO(mdan): Rename to convert_*_node to avoid confusion with convert.
|
|
||||||
def convert_entity_to_ast(o, program_ctx):
|
|
||||||
"""Compile a Python entity into equivalent TensorFlow.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
o: A Python entity.
|
|
||||||
program_ctx: A ProgramContext object.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (ast, new_name, namespace):
|
|
||||||
* ast: An AST representing an entity with interface equivalent to `o`,
|
|
||||||
but which when executed it creates TF a graph.
|
|
||||||
* new_name: The symbol name under which the new entity can be found.
|
|
||||||
* namespace: A dict mapping all symbols visible to the converted entity,
|
|
||||||
keyed by their symbol name.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotImplementedError: if entity is of a type that is not yet supported.
|
|
||||||
"""
|
|
||||||
logging.log(1, 'Converting %s', o)
|
|
||||||
|
|
||||||
nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
|
|
||||||
|
|
||||||
if logging.has_verbosity(2):
|
|
||||||
logging.log(2, 'Compiled output of %s:\n\n%s\n', o, parser.unparse(nodes))
|
|
||||||
if logging.has_verbosity(4):
|
|
||||||
for n in nodes:
|
|
||||||
logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
|
|
||||||
pretty_printer.fmt(n, color=False))
|
|
||||||
|
|
||||||
return nodes, name, entity_info
|
|
||||||
|
|
||||||
|
|
||||||
def _add_reserved_symbol(namespace, name, entity):
|
|
||||||
if name not in namespace:
|
|
||||||
namespace[name] = entity
|
|
||||||
elif namespace[name] != entity:
|
|
||||||
raise ValueError('The name "%s" is reserved and may not be used.' % name)
|
|
||||||
|
|
||||||
|
|
||||||
ag_internal = None
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(mdan): Move into core or replace with an actual importable module.
|
# TODO(mdan): Move into core or replace with an actual importable module.
|
||||||
def _add_self_references(namespace, autograph_module):
|
def _create_custom_vars(program_ctx):
|
||||||
"""Adds namespace references to the module that exposes the api itself."""
|
"""Adds namespace references to the module that exposes the api itself."""
|
||||||
global ag_internal
|
global custom_vars
|
||||||
if ag_internal is None:
|
if custom_vars is None:
|
||||||
# Craft a module that exposes parts of the external API as well as certain
|
# Craft a module that exposes parts of the external API as well as certain
|
||||||
# internal modules.
|
# internal modules.
|
||||||
ag_internal = imp.new_module('autograph')
|
ag_internal = imp.new_module('autograph')
|
||||||
ag_internal.__dict__.update(autograph_module.__dict__)
|
ag_internal.__dict__.update(program_ctx.autograph_module.__dict__)
|
||||||
ag_internal.ConversionOptions = converter.ConversionOptions
|
ag_internal.ConversionOptions = converter.ConversionOptions
|
||||||
ag_internal.STD = converter.STANDARD_OPTIONS
|
ag_internal.STD = converter.STANDARD_OPTIONS
|
||||||
ag_internal.Feature = converter.Feature
|
ag_internal.Feature = converter.Feature
|
||||||
@ -536,102 +250,4 @@ def _add_self_references(namespace, autograph_module):
|
|||||||
ag_internal.__dict__.update(special_functions.__dict__)
|
ag_internal.__dict__.update(special_functions.__dict__)
|
||||||
ag_internal.__dict__.update(operators.__dict__)
|
ag_internal.__dict__.update(operators.__dict__)
|
||||||
|
|
||||||
_add_reserved_symbol(namespace, 'ag__', ag_internal)
|
custom_vars = {'ag__': ag_internal}
|
||||||
|
|
||||||
|
|
||||||
def convert_func_to_ast(f, program_ctx, do_rename=True):
|
|
||||||
"""Specialization of `convert_entity_to_ast` for callable functions."""
|
|
||||||
|
|
||||||
future_features = inspect_utils.getfutureimports(f)
|
|
||||||
node, source = parser.parse_entity(f, future_features=future_features)
|
|
||||||
logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
|
|
||||||
# Parsed AST should contain future imports and one function def node.
|
|
||||||
|
|
||||||
# In general, the output of inspect.getsource is inexact for lambdas because
|
|
||||||
# it uses regex matching to adjust the exact location around the line number
|
|
||||||
# that CPython records. Then, the entire containing line is returned, which
|
|
||||||
# we may have trouble disambiguating. For example:
|
|
||||||
# x, y = lambda: 1, lambda: 2
|
|
||||||
if f.__name__ == '<lambda>':
|
|
||||||
nodes = ast_util.find_matching_definitions(node, f)
|
|
||||||
if len(nodes) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
'Unable to identify source code of lambda function {}. It was'
|
|
||||||
' defined on this line: {}, which must contain a single lambda with'
|
|
||||||
' matching signature. To avoid ambiguity, define each lambda'
|
|
||||||
' in a separate expression.'.format(f, source))
|
|
||||||
node, = nodes
|
|
||||||
|
|
||||||
# TODO(znado): Place inside standard_analysis.
|
|
||||||
origin_info.resolve_entity(node, source, f)
|
|
||||||
|
|
||||||
namespace = inspect_utils.getnamespace(f)
|
|
||||||
_add_self_references(namespace, program_ctx.autograph_module)
|
|
||||||
namer = naming.Namer(namespace)
|
|
||||||
|
|
||||||
if isinstance(node, gast.Lambda):
|
|
||||||
new_name = namer.new_symbol('tf__lambda', ())
|
|
||||||
elif do_rename:
|
|
||||||
new_name = namer.new_symbol('tf__' + f.__name__, ())
|
|
||||||
else:
|
|
||||||
new_name = f.__name__
|
|
||||||
|
|
||||||
entity_info = transformer.EntityInfo(
|
|
||||||
source_code=source,
|
|
||||||
source_file='<fragment>',
|
|
||||||
future_features=future_features,
|
|
||||||
namespace=namespace)
|
|
||||||
context = converter.EntityContext(namer, entity_info, program_ctx, new_name)
|
|
||||||
node = node_to_graph(node, context)
|
|
||||||
|
|
||||||
if isinstance(node, gast.Lambda):
|
|
||||||
node = gast.Assign(
|
|
||||||
targets=[
|
|
||||||
gast.Name(
|
|
||||||
new_name, ctx=gast.Store(), annotation=None, type_comment=None)
|
|
||||||
],
|
|
||||||
value=node)
|
|
||||||
elif do_rename:
|
|
||||||
node.name = new_name
|
|
||||||
else:
|
|
||||||
assert node.name == new_name
|
|
||||||
|
|
||||||
return (node,), new_name, entity_info
|
|
||||||
|
|
||||||
|
|
||||||
def node_to_graph(node, context):
|
|
||||||
"""Convert Python code to equivalent TF graph mode code.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node: AST, the code to convert.
|
|
||||||
context: converter.EntityContext
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple (node, deps):
|
|
||||||
* node: A Python ast node, representing the converted code.
|
|
||||||
* deps: A set of strings, the fully qualified names of entity
|
|
||||||
dependencies that this node has.
|
|
||||||
"""
|
|
||||||
# TODO(mdan): Insert list_comprehensions somewhere.
|
|
||||||
unsupported_features_checker.verify(node)
|
|
||||||
|
|
||||||
node = converter.standard_analysis(node, context, is_initial=True)
|
|
||||||
node = converter.apply_(node, context, functions)
|
|
||||||
node = converter.apply_(node, context, arg_defaults)
|
|
||||||
node = converter.apply_(node, context, directives)
|
|
||||||
node = converter.apply_(node, context, break_statements)
|
|
||||||
if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS):
|
|
||||||
node = converter.apply_(node, context, asserts)
|
|
||||||
# Note: sequencing continue canonicalization before for loop one avoids
|
|
||||||
# dealing with the extra loop increment operation that the for
|
|
||||||
# canonicalization creates.
|
|
||||||
node = converter.apply_(node, context, continue_statements)
|
|
||||||
node = converter.apply_(node, context, return_statements)
|
|
||||||
if context.program.options.uses(converter.Feature.LISTS):
|
|
||||||
node = converter.apply_(node, context, lists)
|
|
||||||
node = converter.apply_(node, context, slices)
|
|
||||||
node = converter.apply_(node, context, call_trees)
|
|
||||||
node = converter.apply_(node, context, control_flow)
|
|
||||||
node = converter.apply_(node, context, conditional_expressions)
|
|
||||||
node = converter.apply_(node, context, logical_expressions)
|
|
||||||
return node
|
|
||||||
|
@ -20,11 +20,9 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import imp
|
import imp
|
||||||
import sys
|
import sys
|
||||||
import threading
|
|
||||||
import types
|
import types
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
import gast
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python.autograph import utils
|
from tensorflow.python.autograph import utils
|
||||||
@ -33,7 +31,6 @@ from tensorflow.python.autograph.core import converter
|
|||||||
from tensorflow.python.autograph.impl import api
|
from tensorflow.python.autograph.impl import api
|
||||||
from tensorflow.python.autograph.impl import conversion
|
from tensorflow.python.autograph.impl import conversion
|
||||||
from tensorflow.python.autograph.impl.testing import pybind_for_testing
|
from tensorflow.python.autograph.impl.testing import pybind_for_testing
|
||||||
from tensorflow.python.autograph.pyct import parser
|
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -126,156 +123,6 @@ class ConversionTest(test.TestCase):
|
|||||||
# Note: currently, native bindings are whitelisted by a separate check.
|
# Note: currently, native bindings are whitelisted by a separate check.
|
||||||
self.assertFalse(conversion.is_whitelisted(test_object.method))
|
self.assertFalse(conversion.is_whitelisted(test_object.method))
|
||||||
|
|
||||||
def test_convert_entity_to_ast_callable(self):
|
|
||||||
b = 2
|
|
||||||
|
|
||||||
def f(a):
|
|
||||||
return a + b
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
nodes, name, info = conversion.convert_entity_to_ast(f, program_ctx)
|
|
||||||
fn_node, = nodes
|
|
||||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
|
||||||
self.assertEqual('tf__f', name)
|
|
||||||
self.assertIs(info.namespace['b'], b)
|
|
||||||
|
|
||||||
def test_convert_entity_to_ast_function_with_defaults(self):
|
|
||||||
b = 2
|
|
||||||
c = 1
|
|
||||||
|
|
||||||
def f(a, d=c + 1):
|
|
||||||
return a + b + d
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
nodes, name, _ = conversion.convert_entity_to_ast(f, program_ctx)
|
|
||||||
fn_node, = nodes
|
|
||||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
|
||||||
self.assertEqual('tf__f', name)
|
|
||||||
self.assertEqual(
|
|
||||||
parser.unparse(fn_node.args.defaults[0],
|
|
||||||
include_encoding_marker=False).strip(), 'None')
|
|
||||||
|
|
||||||
def test_convert_entity_to_ast_call_tree(self):
|
|
||||||
|
|
||||||
def g(a):
|
|
||||||
return a
|
|
||||||
|
|
||||||
def f(a):
|
|
||||||
return g(a)
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
nodes, _, _ = conversion.convert_entity_to_ast(f, program_ctx)
|
|
||||||
f_node, = nodes
|
|
||||||
self.assertEqual('tf__f', f_node.name)
|
|
||||||
|
|
||||||
def test_convert_entity_to_ast_lambda(self):
|
|
||||||
b = 2
|
|
||||||
f = lambda x: b * x if x > 0 else -x
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
|
|
||||||
f, program_ctx)
|
|
||||||
self.assertIsInstance(fn_node, gast.Assign)
|
|
||||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
|
||||||
self.assertEqual('tf__lambda', name)
|
|
||||||
self.assertIs(entity_info.namespace['b'], b)
|
|
||||||
|
|
||||||
def test_convert_entity_to_ast_multiple_lambdas(self):
|
|
||||||
a, b = 1, 2
|
|
||||||
f, _ = (lambda x: a * x, lambda y: b * y)
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
|
|
||||||
f, program_ctx)
|
|
||||||
self.assertIsInstance(fn_node, gast.Assign)
|
|
||||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
|
||||||
self.assertEqual('tf__lambda', name)
|
|
||||||
self.assertIs(entity_info.namespace['a'], a)
|
|
||||||
|
|
||||||
def test_convert_entity_to_ast_multiple_lambdas_ambiguous_definitions(self):
|
|
||||||
a, b = 1, 2
|
|
||||||
f, _ = (lambda x: a * x, lambda x: b * x)
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
conversion.convert_entity_to_ast(f, program_ctx)
|
|
||||||
|
|
||||||
def test_convert_entity_to_ast_lambda_code_with_garbage(self):
|
|
||||||
# pylint:disable=g-long-lambda
|
|
||||||
f = ( # intentional wrap
|
|
||||||
lambda x: (
|
|
||||||
x # intentional wrap
|
|
||||||
+ 1),)[0]
|
|
||||||
# pylint:enable=g-long-lambda
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
(fn_node,), name, _ = conversion.convert_entity_to_ast(f, program_ctx)
|
|
||||||
self.assertIsInstance(fn_node, gast.Assign)
|
|
||||||
self.assertIsInstance(fn_node.value, gast.Lambda)
|
|
||||||
self.assertEqual('tf__lambda', name)
|
|
||||||
|
|
||||||
def test_convert_entity_to_ast_nested_functions(self):
|
|
||||||
b = 2
|
|
||||||
|
|
||||||
def f(x):
|
|
||||||
|
|
||||||
def g(x):
|
|
||||||
return b * x
|
|
||||||
|
|
||||||
return g(x)
|
|
||||||
|
|
||||||
program_ctx = self._simple_program_ctx()
|
|
||||||
(fn_node,), name, entity_info = conversion.convert_entity_to_ast(
|
|
||||||
f, program_ctx)
|
|
||||||
self.assertIsInstance(fn_node, gast.FunctionDef)
|
|
||||||
self.assertEqual(fn_node.name, 'tf__f')
|
|
||||||
self.assertEqual('tf__f', name)
|
|
||||||
self.assertIs(entity_info.namespace['b'], b)
|
|
||||||
|
|
||||||
def test_convert_concurrency(self):
|
|
||||||
|
|
||||||
def test_fn():
|
|
||||||
pass
|
|
||||||
|
|
||||||
generated_file_names = []
|
|
||||||
|
|
||||||
def conversion_thread():
|
|
||||||
new_f = conversion.convert(test_fn, self._simple_program_ctx())
|
|
||||||
generated_file_names.append(new_f.__code__.co_filename)
|
|
||||||
|
|
||||||
threads = tuple(
|
|
||||||
threading.Thread(target=conversion_thread) for _ in range(10))
|
|
||||||
for t in threads:
|
|
||||||
t.start()
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
# Races would potentially create multiple files (non-deterministically,
|
|
||||||
# but with high likelihood).
|
|
||||||
self.assertEqual(len(set(generated_file_names)), 1)
|
|
||||||
|
|
||||||
def test_convert_reentrance(self):
|
|
||||||
|
|
||||||
def test_fn():
|
|
||||||
pass
|
|
||||||
|
|
||||||
# There are no known ways to cause convert to re-enter. So we instrument
|
|
||||||
# an internal function to do that instead.
|
|
||||||
old_node_to_graph = conversion.node_to_graph
|
|
||||||
self.num_conversions = 0
|
|
||||||
def node_to_graph_wrapper(node, context):
|
|
||||||
self.num_conversions += 1
|
|
||||||
if self.num_conversions < 2:
|
|
||||||
conversion.convert(test_fn, self._simple_program_ctx())
|
|
||||||
return old_node_to_graph(node, context)
|
|
||||||
|
|
||||||
try:
|
|
||||||
conversion.node_to_graph = node_to_graph_wrapper
|
|
||||||
new_f = conversion.convert(test_fn, self._simple_program_ctx())
|
|
||||||
self.assertIsNotNone(new_f)
|
|
||||||
finally:
|
|
||||||
conversion.node_to_graph = old_node_to_graph
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -39,6 +39,7 @@ py_library(
|
|||||||
"qual_names.py",
|
"qual_names.py",
|
||||||
"templates.py",
|
"templates.py",
|
||||||
"transformer.py",
|
"transformer.py",
|
||||||
|
"transpiler.py",
|
||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
@ -230,3 +231,14 @@ py_test(
|
|||||||
"@gast_archive//:gast",
|
"@gast_archive//:gast",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "transpiler_test",
|
||||||
|
srcs = ["transpiler_test.py"],
|
||||||
|
python_version = "PY3",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":pyct",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -50,8 +50,12 @@ class AnfTestBase(test.TestCase):
|
|||||||
|
|
||||||
def _simple_context(self):
|
def _simple_context(self):
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=None, source_file=None, future_features=(), namespace=None)
|
name='test_fn',
|
||||||
return transformer.Context(entity_info)
|
source_code=None,
|
||||||
|
source_file=None,
|
||||||
|
future_features=(),
|
||||||
|
namespace=None)
|
||||||
|
return transformer.Context(entity_info, None, None)
|
||||||
|
|
||||||
def assert_same_ast(self, expected_node, node, msg=None):
|
def assert_same_ast(self, expected_node, node, msg=None):
|
||||||
expected_source = parser.unparse(expected_node, indentation=' ')
|
expected_source = parser.unparse(expected_node, indentation=' ')
|
||||||
|
@ -22,6 +22,7 @@ import gast
|
|||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
|
from tensorflow.python.autograph.pyct import naming
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.pyct import qual_names
|
from tensorflow.python.autograph.pyct import qual_names
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
@ -113,11 +114,17 @@ class ScopeTest(test.TestCase):
|
|||||||
class ActivityAnalyzerTestBase(test.TestCase):
|
class ActivityAnalyzerTestBase(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
|
# TODO(mdan): Use a custom FunctionTransformer here.
|
||||||
node, source = parser.parse_entity(test_fn, future_features=())
|
node, source = parser.parse_entity(test_fn, future_features=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source, source_file=None, future_features=(), namespace={})
|
name=test_fn.__name__,
|
||||||
|
source_code=source,
|
||||||
|
source_file=None,
|
||||||
|
future_features=(),
|
||||||
|
namespace={})
|
||||||
node = qual_names.resolve(node)
|
node = qual_names.resolve(node)
|
||||||
ctx = transformer.Context(entity_info)
|
namer = naming.Namer({})
|
||||||
|
ctx = transformer.Context(entity_info, namer, None)
|
||||||
node = activity.resolve(node, ctx)
|
node = activity.resolve(node, ctx)
|
||||||
return node, entity_info
|
return node, entity_info
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import cfg
|
from tensorflow.python.autograph.pyct import cfg
|
||||||
|
from tensorflow.python.autograph.pyct import naming
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.pyct import qual_names
|
from tensorflow.python.autograph.pyct import qual_names
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
@ -35,11 +36,17 @@ global_b = 17
|
|||||||
class LivenessAnalyzerTestBase(test.TestCase):
|
class LivenessAnalyzerTestBase(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
|
# TODO(mdan): Use a custom FunctionTransformer here.
|
||||||
node, source = parser.parse_entity(test_fn, future_features=())
|
node, source = parser.parse_entity(test_fn, future_features=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source, source_file=None, future_features=(), namespace={})
|
name=test_fn.__name__,
|
||||||
|
source_code=source,
|
||||||
|
source_file=None,
|
||||||
|
future_features=(),
|
||||||
|
namespace={})
|
||||||
node = qual_names.resolve(node)
|
node = qual_names.resolve(node)
|
||||||
ctx = transformer.Context(entity_info)
|
namer = naming.Namer({})
|
||||||
|
ctx = transformer.Context(entity_info, namer, None)
|
||||||
node = activity.resolve(node, ctx)
|
node = activity.resolve(node, ctx)
|
||||||
graphs = cfg.build(node)
|
graphs = cfg.build(node)
|
||||||
liveness.resolve(node, ctx, graphs)
|
liveness.resolve(node, ctx, graphs)
|
||||||
|
@ -22,6 +22,7 @@ import six
|
|||||||
|
|
||||||
from tensorflow.python.autograph.pyct import anno
|
from tensorflow.python.autograph.pyct import anno
|
||||||
from tensorflow.python.autograph.pyct import cfg
|
from tensorflow.python.autograph.pyct import cfg
|
||||||
|
from tensorflow.python.autograph.pyct import naming
|
||||||
from tensorflow.python.autograph.pyct import parser
|
from tensorflow.python.autograph.pyct import parser
|
||||||
from tensorflow.python.autograph.pyct import qual_names
|
from tensorflow.python.autograph.pyct import qual_names
|
||||||
from tensorflow.python.autograph.pyct import transformer
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
@ -37,11 +38,17 @@ global_b = 17
|
|||||||
class ReachingDefinitionsAnalyzerTestBase(test.TestCase):
|
class ReachingDefinitionsAnalyzerTestBase(test.TestCase):
|
||||||
|
|
||||||
def _parse_and_analyze(self, test_fn):
|
def _parse_and_analyze(self, test_fn):
|
||||||
|
# TODO(mdan): Use a custom FunctionTransformer here.
|
||||||
node, source = parser.parse_entity(test_fn, future_features=())
|
node, source = parser.parse_entity(test_fn, future_features=())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=source, source_file=None, future_features=(), namespace={})
|
name=test_fn.__name__,
|
||||||
|
source_code=source,
|
||||||
|
source_file=None,
|
||||||
|
future_features=(),
|
||||||
|
namespace={})
|
||||||
node = qual_names.resolve(node)
|
node = qual_names.resolve(node)
|
||||||
ctx = transformer.Context(entity_info)
|
namer = naming.Namer({})
|
||||||
|
ctx = transformer.Context(entity_info, namer, None)
|
||||||
node = activity.resolve(node, ctx)
|
node = activity.resolve(node, ctx)
|
||||||
graphs = cfg.build(node)
|
graphs = cfg.build(node)
|
||||||
node = reaching_definitions.resolve(node, ctx, graphs,
|
node = reaching_definitions.resolve(node, ctx, graphs,
|
||||||
|
@ -36,20 +36,26 @@ class Context(object):
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
info: EntityInfo, immutable.
|
info: EntityInfo, immutable.
|
||||||
|
namer: naming.Namer.
|
||||||
current_origin: origin_info.OriginInfo, holds the OriginInfo of the last
|
current_origin: origin_info.OriginInfo, holds the OriginInfo of the last
|
||||||
AST node to be processed successfully. Useful for error handling.
|
AST node to be processed successfully. Useful for error handling.
|
||||||
|
user: An user-supplied context object. The object is opaque to the
|
||||||
|
infrastructure, but will pe passed through to all custom transformations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, info):
|
def __init__(self, info, namer, user_context):
|
||||||
self.info = info
|
self.info = info
|
||||||
|
self.namer = namer
|
||||||
self.current_origin = None
|
self.current_origin = None
|
||||||
|
self.user = user_context
|
||||||
|
|
||||||
|
|
||||||
# TODO(mdan): Move to a standalone file.
|
# TODO(mdan): Move to a standalone file.
|
||||||
class EntityInfo(
|
class EntityInfo(
|
||||||
collections.namedtuple(
|
collections.namedtuple(
|
||||||
'EntityInfo',
|
'EntityInfo',
|
||||||
('source_code', 'source_file', 'future_features', 'namespace'))):
|
('name', 'source_code', 'source_file', 'future_features', 'namespace'))
|
||||||
|
):
|
||||||
"""Contains information about a Python entity.
|
"""Contains information about a Python entity.
|
||||||
|
|
||||||
Immutable.
|
Immutable.
|
||||||
@ -57,6 +63,7 @@ class EntityInfo(
|
|||||||
Examples of entities include functions and classes.
|
Examples of entities include functions and classes.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
name: The name that identifies this entity.
|
||||||
source_code: The entity's source code.
|
source_code: The entity's source code.
|
||||||
source_file: The entity's source file.
|
source_file: The entity's source file.
|
||||||
future_features: Tuple[Text], the future features that this entity was
|
future_features: Tuple[Text], the future features that this entity was
|
||||||
|
@ -31,8 +31,12 @@ class TransformerTest(test.TestCase):
|
|||||||
|
|
||||||
def _simple_context(self):
|
def _simple_context(self):
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=None, source_file=None, future_features=(), namespace=None)
|
name='Test_fn',
|
||||||
return transformer.Context(entity_info)
|
source_code=None,
|
||||||
|
source_file=None,
|
||||||
|
future_features=(),
|
||||||
|
namespace=None)
|
||||||
|
return transformer.Context(entity_info, None, None)
|
||||||
|
|
||||||
def assertSameAnno(self, first, second, key):
|
def assertSameAnno(self, first, second, key):
|
||||||
self.assertIs(anno.getanno(first, key), anno.getanno(second, key))
|
self.assertIs(anno.getanno(first, key), anno.getanno(second, key))
|
||||||
@ -299,8 +303,12 @@ class CodeGeneratorTest(test.TestCase):
|
|||||||
|
|
||||||
def _simple_context(self):
|
def _simple_context(self):
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
source_code=None, source_file=None, future_features=(), namespace=None)
|
name='test_fn',
|
||||||
return transformer.Context(entity_info)
|
source_code=None,
|
||||||
|
source_file=None,
|
||||||
|
future_features=(),
|
||||||
|
namespace=None)
|
||||||
|
return transformer.Context(entity_info, None, None)
|
||||||
|
|
||||||
def test_basic_codegen(self):
|
def test_basic_codegen(self):
|
||||||
|
|
||||||
|
419
tensorflow/python/autograph/pyct/transpiler.py
Normal file
419
tensorflow/python/autograph/pyct/transpiler.py
Normal file
@ -0,0 +1,419 @@
|
|||||||
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Generic source code transformation infrastructure."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import types
|
||||||
|
|
||||||
|
import gast
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.pyct import ast_util
|
||||||
|
from tensorflow.python.autograph.pyct import cache
|
||||||
|
from tensorflow.python.autograph.pyct import inspect_utils
|
||||||
|
from tensorflow.python.autograph.pyct import loader
|
||||||
|
from tensorflow.python.autograph.pyct import naming
|
||||||
|
from tensorflow.python.autograph.pyct import origin_info
|
||||||
|
from tensorflow.python.autograph.pyct import parser
|
||||||
|
from tensorflow.python.autograph.pyct import templates
|
||||||
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
|
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_into_factory(nodes, entity_name, inner_factory_name,
|
||||||
|
outer_factory_name, closure_vars, factory_args,
|
||||||
|
future_features):
|
||||||
|
"""Wraps an AST into the body of a factory with consistent lexical context.
|
||||||
|
|
||||||
|
The AST is expected to define some symbol with a name given by `entity_name`.
|
||||||
|
|
||||||
|
This mechanism ensures that the resulting transformed entity has lexical
|
||||||
|
scoping identical to that of the source entity, while allowing extra
|
||||||
|
parametrization.
|
||||||
|
|
||||||
|
Two nested factories achieve the following:
|
||||||
|
|
||||||
|
1. The inner factory dynamically creates the entity represented by `nodes`.
|
||||||
|
2. The inner factory is parametrized by a custom set of arguments.
|
||||||
|
3. The inner factory has a closure identical to that of the transformed
|
||||||
|
entity.
|
||||||
|
4. The inner factory has local variables named like `args`, which `nodes` may
|
||||||
|
use as additional parameters.
|
||||||
|
5. The inner factory returns the variables given by `entity_name`.
|
||||||
|
6. The outer factory is niladic.
|
||||||
|
7. The outer factory has no closure.
|
||||||
|
8. The outer factory creates the necessary lexical scope for the inner
|
||||||
|
factory, so that the loaded code has the given configuration for
|
||||||
|
closure/globals.
|
||||||
|
9. The outer factory returns the inner factory.
|
||||||
|
|
||||||
|
Roughly speaking, the following code is generated:
|
||||||
|
|
||||||
|
from __future__ import future_feature_1
|
||||||
|
from __future__ import future_feature_2
|
||||||
|
...
|
||||||
|
|
||||||
|
def outer_factory():
|
||||||
|
closure_var_1 = None
|
||||||
|
closure_var_2 = None
|
||||||
|
...
|
||||||
|
|
||||||
|
def inner_factory(arg_1, arg_2, ...):
|
||||||
|
<<nodes>>
|
||||||
|
return entity
|
||||||
|
|
||||||
|
return inner_factory
|
||||||
|
|
||||||
|
The lexical scoping is created using dummy symbol declarations which create
|
||||||
|
local fariables in the body of the outer factory, so that the Python parser
|
||||||
|
correctly marks them as free non-global variables upon load (that is, it
|
||||||
|
creates cell slots for each symbol. Thes symbols are initialized with None,
|
||||||
|
but their values are not expected to be used; instead, the caller is expected
|
||||||
|
to replace them with the cells of the source entity. For more details, see:
|
||||||
|
https://docs.python.org/3/reference/executionmodel.html#binding-of-names
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nodes: Tuple[ast.AST], the source code to wrap.
|
||||||
|
entity_name: Union[Text, ast.AST], the name of the principal entity that
|
||||||
|
`nodes` define.
|
||||||
|
inner_factory_name: Text, the name of the inner factory.
|
||||||
|
outer_factory_name: Text, the name of the outer factory.
|
||||||
|
closure_vars: Iterable[Text], names of the closure variables for the inner
|
||||||
|
factory.
|
||||||
|
factory_args: Iterable[Text], names of additional arguments for the
|
||||||
|
inner factory. Useful to configure variables that the converted code can
|
||||||
|
use. Typically, these are modules.
|
||||||
|
future_features: Iterable[Text], names of future statements to associate the
|
||||||
|
code with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ast.AST
|
||||||
|
"""
|
||||||
|
dummy_closure_defs = []
|
||||||
|
for var_name in closure_vars:
|
||||||
|
template = """
|
||||||
|
var_name = None
|
||||||
|
"""
|
||||||
|
dummy_closure_defs.extend(templates.replace(template, var_name=var_name))
|
||||||
|
|
||||||
|
if future_features:
|
||||||
|
future_imports = gast.ImportFrom(
|
||||||
|
module='__future__',
|
||||||
|
names=[gast.alias(name=name, asname=None) for name in future_features],
|
||||||
|
level=0)
|
||||||
|
else:
|
||||||
|
future_imports = []
|
||||||
|
|
||||||
|
factory_args = [
|
||||||
|
gast.Name(name, ctx=gast.Param(), annotation=None, type_comment=None)
|
||||||
|
for name in factory_args
|
||||||
|
]
|
||||||
|
|
||||||
|
template = """
|
||||||
|
future_imports
|
||||||
|
def outer_factory_name():
|
||||||
|
dummy_closure_defs
|
||||||
|
def inner_factory_name(factory_args):
|
||||||
|
entity_defs
|
||||||
|
return entity_name
|
||||||
|
return inner_factory_name
|
||||||
|
"""
|
||||||
|
return templates.replace(
|
||||||
|
template,
|
||||||
|
dummy_closure_defs=dummy_closure_defs,
|
||||||
|
entity_defs=nodes,
|
||||||
|
entity_name=entity_name,
|
||||||
|
factory_args=factory_args,
|
||||||
|
future_imports=future_imports,
|
||||||
|
inner_factory_name=inner_factory_name,
|
||||||
|
outer_factory_name=outer_factory_name)
|
||||||
|
|
||||||
|
|
||||||
|
class _TransformedFnFactory(object):
|
||||||
|
"""Helper object that wraps a transformed function factory."""
|
||||||
|
|
||||||
|
def __init__(self, name, freevars, extra_locals):
|
||||||
|
"""Creates a new factory for a transformed function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The function name.
|
||||||
|
freevars: The list of non-global free variables for the function.
|
||||||
|
extra_locals: Dict[Text, Any], names and values for custom variables that
|
||||||
|
are accessible to the generated code as local variables.
|
||||||
|
"""
|
||||||
|
self._name = name
|
||||||
|
self._freevars = freevars
|
||||||
|
self._extra_locals = extra_locals
|
||||||
|
|
||||||
|
self._unbound_factory = None
|
||||||
|
self.module = None
|
||||||
|
self.source_map = None
|
||||||
|
|
||||||
|
def create(self,
|
||||||
|
nodes,
|
||||||
|
namer,
|
||||||
|
inner_factory_name='inner_factory',
|
||||||
|
outer_factory_name='outer_factory',
|
||||||
|
future_features=()):
|
||||||
|
"""Initializes a transformed function."""
|
||||||
|
if self._unbound_factory is not None:
|
||||||
|
raise ValueError('double initialization; create a new object instead')
|
||||||
|
|
||||||
|
inner_factory_name = namer.new_symbol(inner_factory_name, ())
|
||||||
|
outer_factory_name = namer.new_symbol(outer_factory_name, ())
|
||||||
|
nodes = _wrap_into_factory(nodes, self._name, inner_factory_name,
|
||||||
|
outer_factory_name, self._freevars,
|
||||||
|
self._extra_locals.keys(), future_features)
|
||||||
|
|
||||||
|
module, _, source_map = loader.load_ast(
|
||||||
|
nodes, include_source_map=True)
|
||||||
|
outer_factory = getattr(module, outer_factory_name)
|
||||||
|
self._unbound_factory = outer_factory()
|
||||||
|
self.module = module
|
||||||
|
self.source_map = source_map
|
||||||
|
|
||||||
|
def instantiate(self,
|
||||||
|
globals_,
|
||||||
|
closure,
|
||||||
|
defaults=None,
|
||||||
|
kwdefaults=None):
|
||||||
|
"""Creates a new instance of the transformed function."""
|
||||||
|
if self._unbound_factory is None:
|
||||||
|
raise ValueError('call create first')
|
||||||
|
|
||||||
|
factory_code = self._unbound_factory.__code__
|
||||||
|
factory_freevars = factory_code.co_freevars
|
||||||
|
closure_map = dict(zip(self._freevars, closure))
|
||||||
|
factory_closure = tuple(
|
||||||
|
closure_map[name] for name in factory_code.co_freevars)
|
||||||
|
if len(factory_closure) != len(closure):
|
||||||
|
raise ValueError(
|
||||||
|
'closure mismatch, requested {}, but source function had {}'.format(
|
||||||
|
self._freevars, factory_freevars))
|
||||||
|
|
||||||
|
bound_factory = types.FunctionType(
|
||||||
|
code=factory_code,
|
||||||
|
globals=globals_,
|
||||||
|
name=self._name,
|
||||||
|
argdefs=(),
|
||||||
|
closure=factory_closure)
|
||||||
|
|
||||||
|
# The lint override is a false positive.
|
||||||
|
transformed_entity = bound_factory(**self._extra_locals) # pylint:disable=not-callable
|
||||||
|
|
||||||
|
if defaults:
|
||||||
|
transformed_entity.__defaults__ = defaults
|
||||||
|
if kwdefaults:
|
||||||
|
transformed_entity.__kwdefaults__ = kwdefaults
|
||||||
|
|
||||||
|
return transformed_entity
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionTranspiler(object):
|
||||||
|
"""A generic source-to-source transpiler for Python functions.
|
||||||
|
|
||||||
|
Its interface `transform_function` API offers a function-in, function-out
|
||||||
|
interface. Internally, it takes care of parsing, caching and variable binding.
|
||||||
|
|
||||||
|
Users typically subclass this, customizing the transform_ast method.
|
||||||
|
|
||||||
|
Usually, instances of this class are singletons, since each instance manages
|
||||||
|
its own cache. The caching subkey allows managing multiple types of
|
||||||
|
transformation.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
class MyTransformer(FunctionTranspiler):
|
||||||
|
|
||||||
|
def transform_ast(self, node, ctx):
|
||||||
|
node = <<transform node, usually using ast.NodeTransformer classes>>
|
||||||
|
return node
|
||||||
|
|
||||||
|
transformer = MyTransfomer()
|
||||||
|
|
||||||
|
new_f, module, source_map = transformer.transform_function(f, ...)
|
||||||
|
# new_f is a function with signature identical to f
|
||||||
|
|
||||||
|
The transformed function has access to the same namespace as the original
|
||||||
|
function. To allow access to internal APIs, users may inject additional
|
||||||
|
symbols though the `extra_locals` argument of `transform_function`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._cache_lock = threading.RLock()
|
||||||
|
self._cache = cache.CodeObjectCache()
|
||||||
|
|
||||||
|
def transform_ast(self, node, user_context):
|
||||||
|
"""Performs an actual transformation of a function's AST.
|
||||||
|
|
||||||
|
Subclasses must implement this method. They must not call it.
|
||||||
|
|
||||||
|
The method receives the original AST and generates code according to the
|
||||||
|
AST that the method returns. For functions, the returned AST is expected to
|
||||||
|
contain a function with the exact same arguments and closure. The resulting
|
||||||
|
function will receive the globals, closure and argument defaults of the
|
||||||
|
input function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node: One or more ast.AST nodes representing the AST to be transformed.
|
||||||
|
user_context: The same value that the caller passed to
|
||||||
|
`transform_function`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('subclasses must override this')
|
||||||
|
|
||||||
|
def get_transformed_name(self, node):
|
||||||
|
"""Returns a name for the output function. Subclasses may override this."""
|
||||||
|
if isinstance(node, gast.Lambda):
|
||||||
|
return 'lam'
|
||||||
|
elif isinstance(node, gast.FunctionDef):
|
||||||
|
# Note that we need to rename the function, to avoid any namespace
|
||||||
|
# clashes.
|
||||||
|
return node.name
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown node type {}'.format(node))
|
||||||
|
|
||||||
|
def _erase_arg_defaults(self, node):
|
||||||
|
"""Erase argde fault expressions, which would otherwise be unbound."""
|
||||||
|
args = node.args
|
||||||
|
for i in range(len(args.defaults)):
|
||||||
|
args.defaults[i] = parser.parse_expression('None')
|
||||||
|
for i, d in enumerate(args.kw_defaults):
|
||||||
|
if d is not None:
|
||||||
|
args.kw_defaults[i] = parser.parse_expression('None')
|
||||||
|
return node
|
||||||
|
|
||||||
|
def _transform_function(self, fn, user_context):
|
||||||
|
"""Performs source code transformation on a function."""
|
||||||
|
future_features = inspect_utils.getfutureimports(fn)
|
||||||
|
node, source = parser.parse_entity(fn, future_features=future_features)
|
||||||
|
logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)
|
||||||
|
|
||||||
|
# In general, the output of inspect.getsource is inexact for lambdas
|
||||||
|
# because it uses regex matching to adjust the exact location around
|
||||||
|
# the line number that CPython records. Then, the entire containing line
|
||||||
|
# is returned, which we may have trouble disambiguating.
|
||||||
|
# For example:
|
||||||
|
# x, y = lambda: 1, lambda: 2
|
||||||
|
is_lambda = fn.__name__ == '<lambda>'
|
||||||
|
if is_lambda:
|
||||||
|
nodes = ast_util.find_matching_definitions(node, fn)
|
||||||
|
if len(nodes) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
'Unable to identify source code of lambda function {}.'
|
||||||
|
' It was defined in this code:\n'
|
||||||
|
'{}\n'
|
||||||
|
'This code must contain a single distinguishable lambda.'
|
||||||
|
' To avoid this problem, define each lambda in a separate'
|
||||||
|
' expression.'.format(fn, source))
|
||||||
|
node, = nodes
|
||||||
|
|
||||||
|
origin_info.resolve_entity(node, source, fn)
|
||||||
|
|
||||||
|
namespace = inspect_utils.getnamespace(fn)
|
||||||
|
namer = naming.Namer(namespace)
|
||||||
|
new_name = namer.new_symbol(self.get_transformed_name(node), ())
|
||||||
|
entity_info = transformer.EntityInfo(
|
||||||
|
name=new_name,
|
||||||
|
source_code=source,
|
||||||
|
source_file='<fragment>',
|
||||||
|
future_features=future_features,
|
||||||
|
namespace=namespace)
|
||||||
|
context = transformer.Context(entity_info, namer, user_context)
|
||||||
|
|
||||||
|
node = self._erase_arg_defaults(node)
|
||||||
|
node = self.transform_ast(node, context)
|
||||||
|
|
||||||
|
if is_lambda:
|
||||||
|
node = gast.Assign(
|
||||||
|
targets=[
|
||||||
|
gast.Name(
|
||||||
|
new_name,
|
||||||
|
ctx=gast.Store(),
|
||||||
|
annotation=None,
|
||||||
|
type_comment=None)
|
||||||
|
],
|
||||||
|
value=node)
|
||||||
|
else:
|
||||||
|
node.name = new_name
|
||||||
|
|
||||||
|
return node, context
|
||||||
|
|
||||||
|
def _cached_factory(self, fn, cache_subkey):
|
||||||
|
cached_factory = self._cache[fn][cache_subkey]
|
||||||
|
logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
|
||||||
|
cached_factory)
|
||||||
|
return cached_factory
|
||||||
|
|
||||||
|
def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals):
|
||||||
|
"""Returns the transformed function factory for a given input."""
|
||||||
|
if self._cache.has(fn, cache_subkey):
|
||||||
|
return self._cached_factory(fn, cache_subkey)
|
||||||
|
|
||||||
|
with self._cache_lock:
|
||||||
|
# Check again under lock.
|
||||||
|
if self._cache.has(fn, cache_subkey):
|
||||||
|
return self._cached_factory(fn, cache_subkey)
|
||||||
|
|
||||||
|
logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
|
||||||
|
nodes, ctx = self._transform_function(fn, user_context)
|
||||||
|
|
||||||
|
if logging.has_verbosity(2):
|
||||||
|
logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))
|
||||||
|
|
||||||
|
factory = _TransformedFnFactory(
|
||||||
|
ctx.info.name, fn.__code__.co_freevars, extra_locals)
|
||||||
|
factory.create(nodes, ctx.namer, future_features=ctx.info.future_features)
|
||||||
|
self._cache[fn][cache_subkey] = factory
|
||||||
|
return factory
|
||||||
|
|
||||||
|
def transform_function(self, fn, caching_subkey, user_context, extra_locals):
|
||||||
|
"""Transforms a function.
|
||||||
|
|
||||||
|
The `caching_subkey` argument allows mapping each function to multiple
|
||||||
|
outputs in the cache. This is useful for instance when transformers
|
||||||
|
can generate multiple variants of output code, typically as a result of
|
||||||
|
different transformation flags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: A function or lambda.
|
||||||
|
caching_subkey: Used for caching. Calls made for functions with the same
|
||||||
|
code object and caching_subkey will return a cached instance on
|
||||||
|
subsequent invocations. Using a constant will create unique per-function
|
||||||
|
entries.
|
||||||
|
user_context: An opaque object (may be none) that is forwarded to
|
||||||
|
transform_ast.
|
||||||
|
extra_locals: A Dict[Text, Any] containing additional variables to make
|
||||||
|
available to the transformed code. These will be visible as local
|
||||||
|
variables.
|
||||||
|
Returns:
|
||||||
|
A tuple:
|
||||||
|
* A function or lambda with the same signature and closure as `fn`
|
||||||
|
* The temporary module into which the transformed function was loaded
|
||||||
|
* The source map as a
|
||||||
|
Dict[origin_info.LineLocation, origin_info.OriginInfo]
|
||||||
|
|
||||||
|
"""
|
||||||
|
factory = self._transformed_factory(fn, caching_subkey, user_context,
|
||||||
|
extra_locals)
|
||||||
|
|
||||||
|
transformed_fn = factory.instantiate(
|
||||||
|
globals_=fn.__globals__,
|
||||||
|
closure=fn.__closure__ or (),
|
||||||
|
defaults=fn.__defaults__,
|
||||||
|
kwdefaults=getattr(fn, '__kwdefaults__', None))
|
||||||
|
return transformed_fn, factory.module, factory.source_map
|
249
tensorflow/python/autograph/pyct/transpiler_test.py
Normal file
249
tensorflow/python/autograph/pyct/transpiler_test.py
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for transpiler module."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import gast
|
||||||
|
|
||||||
|
from tensorflow.python.autograph.pyct import transformer
|
||||||
|
from tensorflow.python.autograph.pyct import transpiler
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class FlipSignTransformer(transformer.Base):
|
||||||
|
|
||||||
|
def visit_BinOp(self, node):
|
||||||
|
if isinstance(node.op, gast.Add):
|
||||||
|
node.op = gast.Sub()
|
||||||
|
return self.generic_visit(node)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranspiler(transpiler.FunctionTranspiler):
|
||||||
|
|
||||||
|
def transform_ast(self, node, ctx):
|
||||||
|
return FlipSignTransformer(ctx).visit(node)
|
||||||
|
|
||||||
|
|
||||||
|
global_var_for_test_global = 1
|
||||||
|
global_var_for_test_namespace_collisions = object()
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionTranspilerTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_basic(self):
|
||||||
|
def f(a):
|
||||||
|
return a + 1
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 0)
|
||||||
|
|
||||||
|
def test_closure(self):
|
||||||
|
b = 1
|
||||||
|
|
||||||
|
def f(a):
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 0)
|
||||||
|
b = 2
|
||||||
|
self.assertEqual(f(1), -1)
|
||||||
|
|
||||||
|
def test_global(self):
|
||||||
|
def f(a):
|
||||||
|
return a + global_var_for_test_global
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
global global_var_for_test_global
|
||||||
|
global_var_for_test_global = 1
|
||||||
|
self.assertEqual(f(1), 0)
|
||||||
|
global_var_for_test_global = 2
|
||||||
|
self.assertEqual(f(1), -1)
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
b = 2
|
||||||
|
c = 1
|
||||||
|
|
||||||
|
def f(a, d=c + 1):
|
||||||
|
return a + b + d
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 1 - 2 - 2)
|
||||||
|
c = 0
|
||||||
|
self.assertEqual(f(1), 1 - 2 - 2) # Defaults are evaluated at definition.
|
||||||
|
b = 1
|
||||||
|
self.assertEqual(f(1), 1 - 2 - 1)
|
||||||
|
|
||||||
|
def test_call_tree(self):
|
||||||
|
|
||||||
|
def g(a):
|
||||||
|
return a + 1
|
||||||
|
|
||||||
|
def f(a):
|
||||||
|
return g(a) + 1
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 1 - 1 + 1) # Only f is converted.
|
||||||
|
|
||||||
|
def test_lambda(self):
|
||||||
|
b = 2
|
||||||
|
f = lambda x: (b + (x if x > 0 else -x))
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 2 - 1)
|
||||||
|
self.assertEqual(f(-1), 2 - 1)
|
||||||
|
|
||||||
|
b = 3
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 3 - 1)
|
||||||
|
self.assertEqual(f(-1), 3 - 1)
|
||||||
|
|
||||||
|
def test_multiple_lambdas(self):
|
||||||
|
a, b = 1, 2
|
||||||
|
# This can be disambiguated by the argument names.
|
||||||
|
f, _ = (lambda x: a + x, lambda y: b * y)
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 1 - 1)
|
||||||
|
|
||||||
|
def test_multiple_lambdas_indistinguishable_definitions(self):
|
||||||
|
a, b = 1, 2
|
||||||
|
f, _ = (lambda x: a * x, lambda x: b * x)
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
def test_lambda_code_with_removable_garbage(self):
|
||||||
|
# pylint:disable=g-long-lambda
|
||||||
|
f = ( # intentional wrap
|
||||||
|
lambda x: (
|
||||||
|
x # intentional wrap
|
||||||
|
+ 1),)[0]
|
||||||
|
# pylint:enable=g-long-lambda
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 1 - 1)
|
||||||
|
|
||||||
|
def test_nested_functions(self):
|
||||||
|
b = 2
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
|
||||||
|
def g(x):
|
||||||
|
return b + x
|
||||||
|
|
||||||
|
return g(x)
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 2 - 1)
|
||||||
|
|
||||||
|
def test_nested_lambda(self):
|
||||||
|
b = 2
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
g = lambda x: b + x
|
||||||
|
return g(x)
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||||
|
|
||||||
|
self.assertEqual(f(1), 2 - 1)
|
||||||
|
|
||||||
|
def test_concurrency(self):
|
||||||
|
|
||||||
|
def f():
|
||||||
|
pass
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
cache_key = object()
|
||||||
|
def conversion_thread():
|
||||||
|
_, mod, _ = tr.transform_function(f, cache_key, None, {})
|
||||||
|
outputs.append(mod.__name__)
|
||||||
|
|
||||||
|
threads = tuple(
|
||||||
|
threading.Thread(target=conversion_thread) for _ in range(10))
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
# Races would potentially create multiple functions / modules
|
||||||
|
# (non-deterministically, but with high likelihood).
|
||||||
|
self.assertEqual(len(set(outputs)), 1)
|
||||||
|
|
||||||
|
def test_reentrance(self):
|
||||||
|
|
||||||
|
def test_fn():
|
||||||
|
return 1 + 1
|
||||||
|
|
||||||
|
class ReentrantTranspiler(transpiler.FunctionTranspiler):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(ReentrantTranspiler, self).__init__()
|
||||||
|
self._recursion_depth = 0
|
||||||
|
|
||||||
|
def transform_ast(self, node, ctx):
|
||||||
|
self._recursion_depth += 1
|
||||||
|
if self._recursion_depth < 2:
|
||||||
|
self.transform_function(test_fn, object(), None, {})
|
||||||
|
return FlipSignTransformer(ctx).visit(node)
|
||||||
|
|
||||||
|
tr = ReentrantTranspiler()
|
||||||
|
|
||||||
|
f, _, _ = tr.transform_function(test_fn, object(), None, {})
|
||||||
|
self.assertEqual(f(), 0)
|
||||||
|
|
||||||
|
def test_namespace_collisions_avoided(self):
|
||||||
|
|
||||||
|
class TestClass(object):
|
||||||
|
|
||||||
|
def global_var_for_test_namespace_collisions(self):
|
||||||
|
return global_var_for_test_namespace_collisions
|
||||||
|
|
||||||
|
tr = TestTranspiler()
|
||||||
|
obj = TestClass()
|
||||||
|
|
||||||
|
f, _, _ = tr.transform_function(
|
||||||
|
obj.global_var_for_test_namespace_collisions, object(), None, {})
|
||||||
|
self.assertIs(f(obj), global_var_for_test_namespace_collisions)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -52,13 +52,6 @@ def alias_tensors(*args):
|
|||||||
raise ValueError('at least one argument required')
|
raise ValueError('at least one argument required')
|
||||||
|
|
||||||
|
|
||||||
def capitalize_initial(s):
|
|
||||||
"""Capitalizes the initial of a string only."""
|
|
||||||
if s:
|
|
||||||
return s[0].upper() + s[1:]
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
def get_range_len(start, limit, delta):
|
def get_range_len(start, limit, delta):
|
||||||
dist = ops.convert_to_tensor(limit - start)
|
dist = ops.convert_to_tensor(limit - start)
|
||||||
unadjusted_len = dist // delta
|
unadjusted_len = dist // delta
|
||||||
|
@ -29,15 +29,6 @@ from tensorflow.python.platform import test
|
|||||||
|
|
||||||
class MiscTest(test.TestCase):
|
class MiscTest(test.TestCase):
|
||||||
|
|
||||||
def test_capitalize_initial(self):
|
|
||||||
self.assertEqual('', misc.capitalize_initial(''))
|
|
||||||
self.assertEqual('A', misc.capitalize_initial('A'))
|
|
||||||
self.assertEqual('Ab', misc.capitalize_initial('Ab'))
|
|
||||||
self.assertEqual('AbC', misc.capitalize_initial('AbC'))
|
|
||||||
self.assertEqual('A', misc.capitalize_initial('a'))
|
|
||||||
self.assertEqual('Ab', misc.capitalize_initial('ab'))
|
|
||||||
self.assertEqual('AbC', misc.capitalize_initial('abC'))
|
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def test_alias_single_tensor(self):
|
def test_alias_single_tensor(self):
|
||||||
a = constant(1)
|
a = constant(1)
|
||||||
|
@ -198,11 +198,12 @@ def _live_tensors(f, attr_name="inputs"):
|
|||||||
"""
|
"""
|
||||||
node, _ = parser.parse_entity(f, ())
|
node, _ = parser.parse_entity(f, ())
|
||||||
entity_info = transformer.EntityInfo(
|
entity_info = transformer.EntityInfo(
|
||||||
|
name=f.__name__,
|
||||||
source_code=None,
|
source_code=None,
|
||||||
source_file=None,
|
source_file=None,
|
||||||
future_features=(),
|
future_features=(),
|
||||||
namespace=sys.modules[f.__module__].__dict__)
|
namespace=sys.modules[f.__module__].__dict__)
|
||||||
ctx = transformer.Context(entity_info)
|
ctx = transformer.Context(entity_info, None, None)
|
||||||
|
|
||||||
graphs = cfg.build(node)
|
graphs = cfg.build(node)
|
||||||
node = qual_names.resolve(node)
|
node = qual_names.resolve(node)
|
||||||
|
Loading…
Reference in New Issue
Block a user