Slightly refactor the source-to-source translation API to better support non-Python outputs.

PiperOrigin-RevId: 319982764
Change-Id: I49017145719330596b55b0f9190eccf29a9a46c4
Dan Moldovan 2020-07-07 07:57:14 -07:00 committed by TensorFlower Gardener
parent 0d3c1eef7d
commit ca59e0b5d7
6 changed files with 668 additions and 550 deletions

View File

@@ -63,7 +63,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import enum
from tensorflow.python.autograph.pyct import anno
@@ -234,18 +233,17 @@ STANDARD_OPTIONS = ConversionOptions(
optional_features=None)
class ProgramContext(
collections.namedtuple('ProgramContext', ('options', 'autograph_module'))):
class ProgramContext(object):
"""ProgramContext keeps track of converting function hierarchies.
This object is mutable, and is updated during conversion. Not thread safe.
Attributes:
options: ConversionOptions
autograph_module: Module, a reference to the autograph module. This needs to
be specified by the caller to avoid circular dependencies.
autograph_module: Deprecated. Do not use.
"""
pass
def __init__(self, options, autograph_module=None):
self.options = options
self.autograph_module = autograph_module
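A minimal sketch of the new construction, using the STANDARD_OPTIONS defined in this module; the autograph_module argument is now optional and retained only for backward compatibility:
program_ctx = ProgramContext(options=STANDARD_OPTIONS)
assert program_ctx.autograph_module is None  # deprecated, defaults to None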
class Base(transformer.Base):

View File

@@ -28,7 +28,6 @@ import six
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@@ -64,16 +63,26 @@ def is_inside_generated_code():
del frame
class TestingTranspiler(conversion.AutoGraphTranspiler):
class TestingTranspiler(api.PyToTF):
"""Testing version that only applies given transformations."""
def __init__(self, converters):
def __init__(self, converters, ag_overrides):
super(TestingTranspiler, self).__init__()
if isinstance(converters, (list, tuple)):
self._converters = converters
else:
self._converters = (converters,)
self.transformed_ast = None
self._ag_overrides = ag_overrides
def get_extra_locals(self):
retval = super(TestingTranspiler, self).get_extra_locals()
if self._ag_overrides:
modified_ag = imp.new_module('fake_autograph')
modified_ag.__dict__.update(retval['ag__'].__dict__)
modified_ag.__dict__.update(self._ag_overrides)
retval['ag__'] = modified_ag
return retval
def transform_ast(self, node, ctx):
node = self.initial_analysis(node, ctx)
@@ -113,18 +122,8 @@ class TestCase(test.TestCase):
options=converter.ConversionOptions(recursive=True),
autograph_module=api)
conversion.create_custom_vars(program_ctx)
custom_vars = dict(conversion.custom_vars)
if ag_overrides:
modified_ag = imp.new_module('fake_autograph')
modified_ag.__dict__.update(custom_vars['ag__'].__dict__)
modified_ag.__dict__.update(ag_overrides)
custom_vars['ag__'] = modified_ag
tr = TestingTranspiler(converter_module)
transformed, _, _ = tr.transform_function(
f, program_ctx.options, program_ctx, custom_vars)
tr = TestingTranspiler(converter_module, ag_overrides)
transformed, _, _ = tr.transform_function(f, program_ctx)
if include_ast:
return transformed, tr.transformed_ast, tr.transform_ctx
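A usage sketch, assuming control_flow as the converter module under test and f, program_ctx as in the surrounding helper; overrides now reach generated code through get_extra_locals rather than a patched global:
def fake_if_stmt(*args, **kwargs):  # hypothetical stand-in for ag__.if_stmt
  raise AssertionError('if_stmt should not be reached in this test')

tr = TestingTranspiler(control_flow, ag_overrides={'if_stmt': fake_if_stmt})
transformed, _, _ = tr.transform_function(f, program_ctx)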

View File

@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module contains the user-facing API for AutoGraph."""
"""This module contains the user- and codegen-facing API for AutoGraph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import imp
import inspect
import os
import sys
@@ -27,14 +28,38 @@ import traceback
import six
from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.converters import conditional_expressions
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import ag_ctx
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.impl import conversion
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import error_utils
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import origin_info
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.framework import errors_impl
@@ -48,6 +73,11 @@ def is_autograph_strict_conversion_mode():
return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0
#
# Error handling
#
# TODO(mdan): Export this symbol.
class AutoGraphError(errors.PyCTError):
"""Base class for all AutoGraph exceptions."""
@@ -113,6 +143,26 @@ class _ErrorMetadata(error_utils.ErrorMetadataBase):
return StagingError(self.get_message())
def _attach_error_metadata(e, f):
"""Augments an error with the metadata necessary for rewrite."""
if hasattr(e, 'ag_pass_through'):
return
metadata = getattr(e, 'ag_error_metadata', None)
source_map = f.ag_source_map
if metadata is None:
logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
message = '{}: {}'.format(e.__class__.__name__, e)
else:
message = None
cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]
e.ag_error_metadata = _ErrorMetadata(
cause_tb, metadata, message, source_map, __file__)
class StackTraceMapper(tf_stack.StackTraceMapper):
"""Remaps generated code to code it originated from."""
@@ -145,6 +195,106 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
return effective_source_map
#
# Actual source code transformation
#
class PyToTF(transpiler.PyToPy):
"""The TensorFlow AutoGraph transformer."""
def __init__(self):
super(PyToTF, self).__init__()
# TODO(mdan): Move into core or replace with an actual importable module.
# Craft a module that exposes the external API as well as certain
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
ag_internal.ConversionOptions = converter.ConversionOptions
ag_internal.STD = converter.STANDARD_OPTIONS
ag_internal.Feature = converter.Feature
ag_internal.utils = utils
ag_internal.FunctionScope = function_wrappers.FunctionScope
ag_internal.with_function_scope = function_wrappers.with_function_scope
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
# accessible as ag__.<operator>
ag_internal.__dict__.update(special_functions.__dict__)
ag_internal.__dict__.update(operators.__dict__)
self._extra_locals = {'ag__': ag_internal}
def get_transformed_name(self, node):
return 'tf__' + super(PyToTF, self).get_transformed_name(node)
def get_extra_locals(self):
return self._extra_locals
def get_caching_key(self, ctx):
return ctx.options
def initial_analysis(self, node, ctx):
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs)
anno.dup(
node,
{
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
return node
def transform_ast(self, node, ctx):
unsupported_features_checker.verify(node)
node = self.initial_analysis(node, ctx)
node = functions.transform(node, ctx)
node = directives.transform(node, ctx)
node = break_statements.transform(node, ctx)
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
node = asserts.transform(node, ctx)
# Note: sequencing continue canonicalization before the for-loop one avoids
# dealing with the extra loop increment operation that the for-loop
# canonicalization creates.
node = continue_statements.transform(node, ctx)
node = return_statements.transform(node, ctx)
if ctx.user.options.uses(converter.Feature.LISTS):
node = lists.transform(node, ctx)
node = slices.transform(node, ctx)
node = call_trees.transform(node, ctx)
node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx)
node = logical_expressions.transform(node, ctx)
node = variables.transform(node, ctx)
return node
def _convert_actual(entity, program_ctx):
"""Applies AutoGraph to entity."""
# TODO(mdan): Put these extra fields inside __autograph_info__.
if not hasattr(entity, '__code__'):
raise ValueError('Cannot apply autograph to a function that doesn\'t '
'expose a __code__ object. If this is a @tf.function,'
' try passing f.python_function instead.')
transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)
assert not hasattr(transformed, 'ag_module')
assert not hasattr(transformed, 'ag_source_map')
transformed.ag_module = module
transformed.ag_source_map = source_map
return transformed
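A sketch of the resulting end-to-end flow, assuming f is any plain Python function:
def f(x):
  return x + 1

program_ctx = converter.ProgramContext(
    options=converter.ConversionOptions(recursive=True))
converted_f = _convert_actual(f, program_ctx)
# converted_f has f's signature; ag_module and ag_source_map are attached
# for error rewriting.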
#
# Generated code support
#
def autograph_artifact(entity, extras=None):
setattr(entity, 'autograph_info__', extras)
return entity
@@ -154,272 +304,12 @@ def is_autograph_artifact(entity):
return hasattr(entity, 'autograph_info__')
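The artifact marker is just an attribute tag; a minimal round trip:
def g(x):
  return x

g = autograph_artifact(g)
assert is_autograph_artifact(g)  # the conversion machinery will skip g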
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
"""Decorator that applies AutoGraph to a function.
Use in internal APIs.
This API is suitable for higher-order functions internal to the TensorFlow API,
and more generally any function to which AutoGraph is not applied.
Guidance: convert was a decorator meant for use directly by developers, and
will soon be deprecated in favor of tf.function. tf_convert is to be called
from higher-order functions internal to TF.
Args:
f: Callable.
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
convert_by_default: bool, whether to use AutoGraph when the context doesn't
specify.
user_requested: bool, whether to ignore the conversion whitelist. See
ConversionOptions.user_requested.
Returns:
Either `f` or the converted version of `f`.
"""
if is_autograph_artifact(f):
return f
f_wrapper = f
decorators, f = tf_decorator.unwrap(f)
# TODO(mdan): Grab features from context.
# Note: we pass the original context through to convert to properly handle the
# following scenario, which can be used inside TF implementations:
#
# ctx = ag_ctx.control_status_ctx()
# @function(autograph=False) # Low-level graph code
# def inner_fn():
# The context is disabled here, but should be enabled in user_fn
# tf_convert(user_fn, ctx=ctx)
if ctx.status == ag_ctx.Status.ENABLED:
wrapper_factory = convert(
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
elif ctx.status == ag_ctx.Status.DISABLED:
wrapper_factory = do_not_convert
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
if convert_by_default:
wrapper_factory = convert(
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
else:
wrapper_factory = call_with_unspecified_conversion_status
else:
assert False, 'This switch contains all possible cases!'
wrapper = wrapper_factory(f)
if decorators:
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
return autograph_artifact(wrapper)
# TODO(mdan): Make private.
def convert(recursive=False,
optional_features=None,
user_requested=True,
conversion_ctx=ag_ctx.NullCtx()):
"""Decorator that compiles a function to use TensorFlow ops.
The decorator is dynamic - it recompiles the target whenever the decorated
function is called. This means the parameter values are known at conversion.
It also means that repeated calls with different types of parameters will be
correctly processed.
Args:
recursive: bool, whether to recursively convert any functions or classes
that the converted function may use.
optional_features: converter.Feature, allows toggling optional or
experimental features. When set to None, only the core features are
enabled.
user_requested: bool, whether this is a function that the user explicitly
asked to be converted. See ConversionOptions.user_requested.
conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
which `f` is used.
Returns:
Callable, a decorator that converts the given function into an equivalent
function that uses TensorFlow ops.
"""
def decorator(f):
"""Decorator implementation."""
def wrapper(*args, **kwargs):
"""Wrapper that calls the converted version of f."""
options = converter.ConversionOptions(
recursive=recursive,
user_requested=user_requested,
optional_features=optional_features)
try:
with conversion_ctx:
return converted_call(f, args, kwargs, options=options)
except Exception as e: # pylint:disable=broad-except
if hasattr(e, 'ag_error_metadata'):
raise e.ag_error_metadata.to_exception(e)
else:
raise
if inspect.isfunction(f) or inspect.ismethod(f):
wrapper = functools.update_wrapper(wrapper, f)
decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
return autograph_artifact(decorated_wrapper)
return decorator
def call_with_unspecified_conversion_status(func):
"""Decorator that resets the conversion context to the unspecified status."""
def wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
return func(*args, **kwargs)
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)
return autograph_artifact(wrapper)
@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None):
"""Decorator that suppresses the conversion of a function.
Args:
func: function to decorate.
Returns:
If `func` is not None, returns a `Callable` which is equivalent to
`func`, but is not converted by AutoGraph.
If `func` is None, returns a decorator that, when invoked with a
single `func` argument, returns a `Callable` equivalent to the
above case.
"""
if func is None:
return do_not_convert
def wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
return func(*args, **kwargs)
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)
return autograph_artifact(wrapper)
def _attach_metadata(e, f):
"""Augments an error with the metadata necessary for rewrite."""
if hasattr(e, 'ag_pass_through'):
return
metadata = getattr(e, 'ag_error_metadata', None)
source_map = f.ag_source_map
if metadata is None:
logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
message = '{}: {}'.format(e.__class__.__name__, e)
else:
message = None
cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]
e.ag_error_metadata = _ErrorMetadata(
cause_tb, metadata, message, source_map, __file__)
def _call_unconverted(f, args, kwargs, options, update_cache=True):
"""Calls the original function without converting with AutoGraph."""
if update_cache:
conversion.cache_whitelisted(f, options)
if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
return f.__self__.call(args, kwargs)
if kwargs is not None:
return f(*args, **kwargs)
return f(*args)
def _is_of_known_loaded_module(f, module_name):
mod = sys.modules.get(module_name, None)
if mod is None:
return False
if any(v is not None for v in mod.__dict__.values() if f is v):
return True
return False
def _is_known_loaded_type(f, module_name, entity_name):
"""Tests whether the function or method is an instance of a known type."""
if (module_name not in sys.modules or
not hasattr(sys.modules[module_name], entity_name)):
return False
type_entity = getattr(sys.modules[module_name], entity_name)
if isinstance(f, type_entity):
# The method is of this type. Example:
#
# o = ClassType()
# function(o.method)()
return True
# Note: inspect is required here, to avoid unpacking tf.function decorators.
if inspect.ismethod(f):
# The unbound method is of this type. Example:
#
# class ClassType:
# @function
# def method(self):
# ...
# o = ClassType()
# o.method()
if isinstance(f.__func__, type_entity):
return True
return False
def _fall_back_unconverted(f, args, kwargs, options, exc):
"""Falls back to calling the function unconverted, in case of error."""
# TODO(mdan): Consider adding an internal metric.
warning_template = (
'AutoGraph could not transform %s and will run it as-is.\n'
'%s'
'Cause: %s\n'
'To silence this warning, decorate the function with'
' @tf.autograph.experimental.do_not_convert')
if isinstance(exc, errors.UnsupportedLanguageElementError):
if not conversion.is_in_whitelist_cache(f, options):
logging.warn(warning_template, f, '', exc)
else:
file_bug_message = (
'Please report this to the TensorFlow team. When filing the bug, set'
' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
' attach the full output.\n')
logging.warn(warning_template, f, file_bug_message, exc)
return _call_unconverted(f, args, kwargs, options)
def _log_callargs(f, args, kwargs):
"""Logging helper."""
logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
if not six.PY2:
logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)
if kwargs is not None:
callargs = tf_inspect.getcallargs(f, *args, **kwargs)
else:
callargs = tf_inspect.getcallargs(f, *args)
formatted_callargs = '\n'.join(
' {}: {}'.format(k, v) for k, v in callargs.items())
logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)
def converted_call(f,
args,
kwargs,
caller_fn_scope=None,
options=None):
"""Compiles a function call inline.
"""Converts a function call inline.
For internal use only.
@@ -492,40 +382,7 @@ def converted_call(f,
else:
return py_builtins.overload_of(f)(*args)
# TODO(b/122265385): Remove this bypass.
if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
_is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
logging.warn(
'{} appears to be decorated by wrapt, which is not yet supported'
' by AutoGraph. The function will run as-is.'
' You may still apply AutoGraph before the wrapt decorator.'.format(f))
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
return _call_unconverted(f, args, kwargs, options)
if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
return _call_unconverted(f, args, kwargs, options)
# Constructors are permanently whitelisted.
# TODO(mdan): Toggle as experimental feature instead.
# TODO(b/124016764): Remove this limitation.
if inspect_utils.isconstructor(f):
logging.log(2, 'Permanently whitelisted: %s: constructor', f)
return _call_unconverted(f, args, kwargs, options)
# Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
if any(
_is_of_known_loaded_module(f, m)
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
return _call_unconverted(f, args, kwargs, options)
# Custom ops and kernels are also permanently whitelisted.
# See tensorflow.framework.load_library.
if (hasattr(f, '__module__') and
hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
if conversion.is_unsupported(f):
return _call_unconverted(f, args, kwargs, options)
if not options.user_requested and conversion.is_whitelisted(f):
@@ -579,9 +436,8 @@ def converted_call(f,
return _call_unconverted(f, args, kwargs, options)
try:
program_ctx = converter.ProgramContext(
options=options, autograph_module=tf_inspect.getmodule(converted_call))
converted_f = conversion.convert(target_entity, program_ctx)
program_ctx = converter.ProgramContext(options=options)
converted_f = _convert_actual(target_entity, program_ctx)
if logging.has_verbosity(2):
_log_callargs(converted_f, effective_args, kwargs)
except Exception as e: # pylint:disable=broad-except
@@ -597,12 +453,226 @@ def converted_call(f,
else:
result = converted_f(*effective_args)
except Exception as e:
_attach_metadata(e, converted_f)
_attach_error_metadata(e, converted_f)
raise
return result
def _call_unconverted(f, args, kwargs, options, update_cache=True):
"""Calls the original function without converting with AutoGraph."""
if update_cache:
conversion.cache_whitelisted(f, options)
if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
return f.__self__.call(args, kwargs)
if kwargs is not None:
return f(*args, **kwargs)
return f(*args)
def _fall_back_unconverted(f, args, kwargs, options, exc):
"""Falls back to calling the function unconverted, in case of error."""
# TODO(mdan): Consider adding an internal metric.
warning_template = (
'AutoGraph could not transform %s and will run it as-is.\n'
'%s'
'Cause: %s\n'
'To silence this warning, decorate the function with'
' @tf.autograph.experimental.do_not_convert')
if isinstance(exc, errors.UnsupportedLanguageElementError):
if not conversion.is_in_whitelist_cache(f, options):
logging.warn(warning_template, f, '', exc)
else:
file_bug_message = (
'Please report this to the TensorFlow team. When filing the bug, set'
' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
' attach the full output.\n')
logging.warn(warning_template, f, file_bug_message, exc)
return _call_unconverted(f, args, kwargs, options)
#
# TensorFlow integration
#
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
"""Decorator that applies AutoGraph to a function.
Use in internal APIs.
This API is suitable for higher-order functions internal to the TensorFlow API,
and more generally any function to which AutoGraph is not applied.
Guidance: convert was a decorator meant for use directly by developers, and
will soon be deprecated in favor of tf.function. tf_convert is to be called
from higher-order functions internal to TF.
Args:
f: Callable.
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
convert_by_default: bool, whether to use AutoGraph when the context doesn't
specify.
user_requested: bool, whether to ignore the conversion whitelist. See
ConversionOptions.user_requested.
Returns:
Either `f` or the converted version of `f`.
"""
if is_autograph_artifact(f):
return f
f_wrapper = f
decorators, f = tf_decorator.unwrap(f)
# TODO(mdan): Grab features from context.
# Note: we pass the original context through to convert to properly handle the
# following scenario, which can be used inside TF implementations:
#
# ctx = ag_ctx.control_status_ctx()
# @function(autograph=False) # Low-level graph code
# def inner_fn():
# The context is disabled here, but should be enabled in user_fn
# tf_convert(user_fn, ctx=ctx)
if ctx.status == ag_ctx.Status.ENABLED:
wrapper_factory = convert(
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
elif ctx.status == ag_ctx.Status.DISABLED:
wrapper_factory = do_not_convert
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
if convert_by_default:
wrapper_factory = convert(
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
else:
wrapper_factory = call_with_unspecified_conversion_status
else:
assert False, 'This switch contains all possible cases!'
wrapper = wrapper_factory(f)
if decorators:
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
return autograph_artifact(wrapper)
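A sketch of the intended call pattern inside TF internals, with hypothetical names run_user_fn and user_fn; the AutoGraph status is captured at the public API boundary and applied when the callable is wrapped:
def run_user_fn(user_fn, elems):
  ctx = ag_ctx.control_status_ctx()  # capture the caller's AutoGraph status
  converted = tf_convert(user_fn, ctx)
  return [converted(x) for x in elems]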
def call_with_unspecified_conversion_status(func):
"""Decorator that resets the conversion context to the unspecified status."""
def wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
return func(*args, **kwargs)
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)
return autograph_artifact(wrapper)
def _log_callargs(f, args, kwargs):
"""Logging helper."""
logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
if not six.PY2:
logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)
if kwargs is not None:
callargs = tf_inspect.getcallargs(f, *args, **kwargs)
else:
callargs = tf_inspect.getcallargs(f, *args)
formatted_callargs = '\n'.join(
' {}: {}'.format(k, v) for k, v in callargs.items())
logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)
#
# Public API
#
@tf_export('autograph.experimental.do_not_convert')
def do_not_convert(func=None):
"""Decorator that suppresses the conversion of a function.
Args:
func: function to decorate.
Returns:
If `func` is not None, returns a `Callable` which is equivalent to
`func`, but is not converted by AutoGraph.
If `func` is None, returns a decorator that, when invoked with a
single `func` argument, returns a `Callable` equivalent to the
above case.
"""
if func is None:
return do_not_convert
def wrapper(*args, **kwargs):
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
return func(*args, **kwargs)
if inspect.isfunction(func) or inspect.ismethod(func):
wrapper = functools.update_wrapper(wrapper, func)
return autograph_artifact(wrapper)
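For example, the decorator can be applied directly or called on an existing callable such as a lambda:
@do_not_convert
def plain_python_step(x):
  return x + 1  # always runs as plain Python, even under tf.function

also_plain = do_not_convert(lambda x: x + 1)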
# TODO(mdan): Make private.
def convert(recursive=False,
optional_features=None,
user_requested=True,
conversion_ctx=ag_ctx.NullCtx()):
"""Decorator that compiles a function to use TensorFlow ops.
The decorator is dynamic - it recompiles the target whenever the decorated
function is called. This means the parameter values are known at conversion.
It also means that repeated calls with different types of parameters will be
correctly processed.
Args:
recursive: bool, whether to recursively convert any functions or classes
that the converted function may use.
optional_features: converter.Feature, allows toggling optional or
experimental features. When set to None, only the core features are
enabled.
user_requested: bool, whether this is a function that the user explicitly
asked to be converted. See ConversionOptions.user_requested.
conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
which `f` is used.
Returns:
Callable, a decorator that converts the given function into an equivalent
function that uses TensorFlow ops.
"""
def decorator(f):
"""Decorator implementation."""
def wrapper(*args, **kwargs):
"""Wrapper that calls the converted version of f."""
options = converter.ConversionOptions(
recursive=recursive,
user_requested=user_requested,
optional_features=optional_features)
try:
with conversion_ctx:
return converted_call(f, args, kwargs, options=options)
except Exception as e: # pylint:disable=broad-except
if hasattr(e, 'ag_error_metadata'):
raise e.ag_error_metadata.to_exception(e)
else:
raise
if inspect.isfunction(f) or inspect.ismethod(f):
wrapper = functools.update_wrapper(wrapper, f)
decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
return autograph_artifact(decorated_wrapper)
return decorator
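A hypothetical use of the dynamic decorator; with Tensor arguments, the while loop below would be staged as graph ops on each call:
@convert(recursive=True)
def collatz_steps(n):
  steps = 0
  while n > 1:
    n = n // 2 if n % 2 == 0 else 3 * n + 1
    steps += 1
  return steps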
# pylint:disable=line-too-long
@tf_export('autograph.to_graph', v1=[])
def to_graph(entity, recursive=True, experimental_optional_features=None):
@@ -668,9 +738,8 @@ def to_graph(entity, recursive=True, experimental_optional_features=None):
options=converter.ConversionOptions(
recursive=recursive,
user_requested=True,
optional_features=experimental_optional_features),
autograph_module=tf_inspect.getmodule(to_graph))
return autograph_artifact(conversion.convert(entity, program_ctx))
optional_features=experimental_optional_features))
return autograph_artifact(_convert_actual(entity, program_ctx))
except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
logging.error(1, 'Error converting %s', entity, exc_info=True)
raise ConversionError('converting {}: {}: {}'.format(
@@ -845,3 +914,6 @@ def to_code(entity, recursive=True, experimental_optional_features=None):
recursive=recursive,
experimental_optional_features=experimental_optional_features))
return textwrap.dedent(source)
_TRANSPILER = PyToTF()

View File

@@ -19,111 +19,97 @@ from __future__ import division
from __future__ import print_function
import functools
import imp
import inspect
import sys
import unittest
from tensorflow.python.autograph import operators
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.converters import asserts
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.converters import conditional_expressions
from tensorflow.python.autograph.converters import continue_statements
from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import functions
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.converters import variables
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import function_wrappers
from tensorflow.python.autograph.core import unsupported_features_checker
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cache
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transpiler
from tensorflow.python.autograph.pyct.static_analysis import activity
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.eager import function
from tensorflow.python.util import tf_inspect
class AutoGraphTranspiler(transpiler.FunctionTranspiler):
def get_transformed_name(self, node):
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
def initial_analysis(self, node, ctx):
graphs = cfg.build(node)
node = qual_names.resolve(node)
node = activity.resolve(node, ctx, None)
node = reaching_definitions.resolve(node, ctx, graphs)
anno.dup(
node,
{
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
},
)
return node
def transform_ast(self, node, ctx):
unsupported_features_checker.verify(node)
node = self.initial_analysis(node, ctx)
node = functions.transform(node, ctx)
node = directives.transform(node, ctx)
node = break_statements.transform(node, ctx)
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
node = asserts.transform(node, ctx)
# Note: sequencing continue canonicalization before the for-loop one avoids
# dealing with the extra loop increment operation that the for-loop
# canonicalization creates.
node = continue_statements.transform(node, ctx)
node = return_statements.transform(node, ctx)
if ctx.user.options.uses(converter.Feature.LISTS):
node = lists.transform(node, ctx)
node = slices.transform(node, ctx)
node = call_trees.transform(node, ctx)
node = control_flow.transform(node, ctx)
node = conditional_expressions.transform(node, ctx)
node = logical_expressions.transform(node, ctx)
node = variables.transform(node, ctx)
return node
_TRANSPILER = AutoGraphTranspiler()
_WHITELIST_CACHE = cache.UnboundInstanceCache()
custom_vars = None
def _is_of_known_loaded_module(f, module_name):
mod = sys.modules.get(module_name, None)
if mod is None:
return False
if any(v is not None for v in mod.__dict__.values() if f is v):
return True
return False
# TODO(mdan): Superfluous function, remove.
# TODO(mdan): Put these extra fields inside __autograph_info__.
def convert(entity, program_ctx):
"""Applies AutoGraph to entity."""
def _is_known_loaded_type(f, module_name, entity_name):
"""Tests whether the function or method is an instance of a known type."""
if (module_name not in sys.modules or
not hasattr(sys.modules[module_name], entity_name)):
return False
type_entity = getattr(sys.modules[module_name], entity_name)
if isinstance(f, type_entity):
# The method is of this type. Example:
#
# o = ClassType()
# function(o.method)()
return True
# Note: inspect is required here, to avoid unpacking tf.function decorators.
if inspect.ismethod(f):
# The unbound method is of this type. Example:
#
# class ClassType:
# @function
# def method(self):
# ...
# o = ClassType()
# o.method()
if isinstance(f.__func__, type_entity):
return True
return False
if not hasattr(entity, '__code__'):
raise ValueError('Cannot apply autograph to a function that doesn\'t '
'expose a __code__ object. If this is a @tf.function,'
' try passing f.python_function instead.')
create_custom_vars(program_ctx)
transformed, module, source_map = _TRANSPILER.transform_function(
entity, program_ctx.options, program_ctx, custom_vars)
def is_unsupported(o):
"""Checks whether an entity is supported by AutoGraph at all."""
assert not hasattr(transformed, 'ag_module')
assert not hasattr(transformed, 'ag_source_map')
transformed.ag_module = module
transformed.ag_source_map = source_map
return transformed
# TODO(b/122265385): Remove this bypass.
if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
_is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
logging.warn(
'{} appears to be decorated by wrapt, which is not yet supported'
' by AutoGraph. The function will run as-is.'
' You may still apply AutoGraph before the wrapt decorator.'.format(o))
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', o)
return True
if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
logging.log(2, 'Permanently whitelisted: %s: lru_cache', o)
return True
# Constructors are permanently whitelisted.
# TODO(mdan): Toggle as experimental feature instead.
# TODO(b/124016764): Remove this limitation.
if inspect_utils.isconstructor(o):
logging.log(2, 'Permanently whitelisted: %s: constructor', o)
return True
# Other built-in modules are permanently whitelisted.
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
if any(
_is_of_known_loaded_module(o, m)
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', o)
return True
# Custom ops and kernels are also permanently whitelisted.
# See tensorflow.framework.load_library.
if (hasattr(o, '__module__') and
hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')):
logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', o)
return True
return False
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
@@ -246,28 +232,3 @@ def cache_whitelisted(entity, options):
except TypeError:
# Catch-all for entities that are unhashable or don't allow weakrefs.
pass
# TODO(mdan): Move into core or replace with an actual importable module.
# Visible for testing.
def create_custom_vars(program_ctx):
"""Adds namespace references to the module that exposes the api itself."""
global custom_vars
if custom_vars is None:
# Craft a module that exposes parts of the external API as well as certain
# internal modules.
ag_internal = imp.new_module('autograph')
ag_internal.__dict__.update(program_ctx.autograph_module.__dict__)
ag_internal.ConversionOptions = converter.ConversionOptions
ag_internal.STD = converter.STANDARD_OPTIONS
ag_internal.Feature = converter.Feature
ag_internal.utils = utils
ag_internal.FunctionScope = function_wrappers.FunctionScope
ag_internal.with_function_scope = function_wrappers.with_function_scope
# TODO(mdan): Add safeguards against name clashes.
# We don't want to create a submodule because we want the operators to be
# accessible as ag__.<operator>
ag_internal.__dict__.update(special_functions.__dict__)
ag_internal.__dict__.update(operators.__dict__)
custom_vars = {'ag__': ag_internal}

View File

@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import threading
import types
@@ -143,11 +144,11 @@ def _wrap_into_factory(nodes, entity_name, inner_factory_name,
outer_factory_name=outer_factory_name)
class _TransformedFnFactory(object):
"""Helper object that wraps a transformed function factory."""
class _PythonFnFactory(object):
"""Helper object that wraps a Python function factory."""
def __init__(self, name, freevars, extra_locals):
"""Creates a new factory for a transformed function.
"""Creates a new factory for a Python function.
Args:
name: The function name.
@@ -169,7 +170,7 @@ class _TransformedFnFactory(object):
inner_factory_name='inner_factory',
outer_factory_name='outer_factory',
future_features=()):
"""Initializes a transformed function."""
"""Initializes a function."""
if self._unbound_factory is not None:
raise ValueError('double initialization; create a new object instead')
@@ -191,7 +192,7 @@ class _TransformedFnFactory(object):
closure,
defaults=None,
kwdefaults=None):
"""Creates a new instance of the transformed function."""
"""Creates a new function instance."""
if self._unbound_factory is None:
raise ValueError('call create first')
@@ -213,78 +214,70 @@ class _TransformedFnFactory(object):
closure=factory_closure)
# The lint override is a false positive.
transformed_entity = bound_factory(**self._extra_locals) # pylint:disable=not-callable
new_fn = bound_factory(**self._extra_locals) # pylint:disable=not-callable
if defaults:
transformed_entity.__defaults__ = defaults
new_fn.__defaults__ = defaults
if kwdefaults:
transformed_entity.__kwdefaults__ = kwdefaults
new_fn.__kwdefaults__ = kwdefaults
return transformed_entity
return new_fn
class FunctionTranspiler(object):
"""A generic source-to-source transpiler for Python functions.
class GenericTranspiler(object):
"""A generic transpiler for Python functions.
Its `transform_function` API offers a function-in, function-out interface.
Internally, it takes care of parsing, caching and variable binding.
Its interface is the `transform` API, which can process Python function
objects. Internally, it handles parsing.
Users typically subclass this, customizing the transform_ast method.
Usually, instances of this class are singletons, since each instance manages
its own cache. The caching subkey allows managing multiple types of
transformation.
Users typically subclass this, customizing the `transform_ast` method. The
output of `transform_ast` is returned directly by `transform`. Existing
methods like `transform_function` may also be overloaded.
Example:
class MyTransformer(FunctionTranspiler):
class MyTransformer(GenericTranspiler):
def transform_ast(self, node, ctx):
node = <<transform node, usually using ast.NodeTransformer classes>>
return node
def transform(self, obj):
result = <<transform node>>
return result
transformer = MyTransformer()
new_f, module, source_map = transformer.transform_function(f, ...)
# new_f is a function with signature identical to f
The transformed function has access to the same namespace as the original
function. To allow access to internal APIs, users may inject additional
symbols through the `extra_locals` argument of `transform_function`.
result = transformer.transform(f, ...)
# result is the output
"""
def __init__(self):
self._cache_lock = threading.RLock()
self._cache = cache.CodeObjectCache()
def transform_ast(self, node, user_context):
def transform_ast(self, node, ctx):
"""Performs an actual transformation of a function's AST.
Subclasses must implement this method. They must not call it.
The method receives the original AST and generates code according to the
AST that the method returns. For functions, the returned AST is expected to
contain a function with the exact same arguments and closure. The resulting
function will receive the globals, closure and argument defaults of the
input function.
Subclasses must implement this method, and do not usually call it.
Args:
node: One or more ast.AST nodes representing the AST to be transformed.
user_context: The same value that the caller passed to
`transform_function`.
ctx: transformer.Context.
"""
raise NotImplementedError('subclasses must override this')
def get_transformed_name(self, node):
"""Returns a name for the output function. Subclasses may override this."""
if isinstance(node, gast.Lambda):
return 'lam'
elif isinstance(node, gast.FunctionDef):
# Note that we need to rename the function, to avoid any namespace
# clashes.
return node.name
else:
raise ValueError('Unknown node type {}'.format(node))
def transform(self, obj, user_context):
"""Transforms a Python object.
Users typically call this method.
Args:
obj: A Python object, function, type, etc.
user_context: An opaque object (may be None) that is forwarded to
transform_ast, through the ctx.user_context argument.
Returns:
The result of calling `transform_function`.
Raises:
NotImplementedError: if the type of obj is not handled.
"""
if inspect.isfunction(obj) or inspect.ismethod(obj):
return self.transform_function(obj, user_context)
raise NotImplementedError('Non-function: {}'.format(type(obj)))
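This hook is what enables non-Python outputs: a subclass may override transform_function and return any representation it likes, for example plain source text. A hypothetical sketch (PyToSource is not part of this change):
class PyToSource(GenericTranspiler):
  """Hypothetical transpiler that emits source code, not a live function."""

  def transform_ast(self, node, ctx):
    return node  # a real subclass would rewrite the AST here

  def transform_function(self, fn, user_context):
    node, ctx = super(PyToSource, self).transform_function(fn, user_context)
    del ctx
    return parser.unparse(node)

# PyToSource().transform(some_fn, None) returns a source string.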
def _erase_arg_defaults(self, node):
"""Erase argde fault expressions, which would otherwise be unbound."""
@@ -296,8 +289,21 @@ class FunctionTranspiler(object):
args.kw_defaults[i] = parser.parse_expression('None')
return node
def _transform_function(self, fn, user_context):
"""Performs source code transformation on a function."""
def transform_function(self, fn, user_context):
"""Transforms a function.
Subclasses may override this method. The return value is opaque.
The method receives the original AST. The result is passed as-is to the
output of `transform`.
Args:
fn: A function or lambda.
user_context: An opaque object (may be None) that is forwarded to
transform_ast, through the ctx.user_context argument.
Returns:
Any. By default it returns the output of transform_ast.
"""
future_features = inspect_utils.getfutureimports(fn)
node, source = parser.parse_entity(fn, future_features=future_features)
logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)
@@ -333,63 +339,129 @@ class FunctionTranspiler(object):
return node, context
class PyToPy(GenericTranspiler):
"""A generic Python-to-Python transpiler.
Its `transform` method offers a function-in, function-out interface.
Internally, it takes care of parsing, caching and loading of the translated
code.
Users typically subclass this, overriding `transform_ast`.
Usually, instances of this class are singletons, since each instance manages
its own cache. The caching can be controlled by overriding `get_caching_key`.
Example:
class MyTransformer(PyToPy):
def transform_ast(self, node, ctx):
node = <<transform node, usually using ast.NodeTransformer classes>>
return node
transformer = MyTransformer()
new_f, module, source_map = transformer.transform_function(f, ...)
# new_f is a function with signature identical to f
The transformed function has access to the same namespace as the original
function. To allow access to internal APIs, users may inject additional
symbols by overriding `get_extra_locals`.
"""
def __init__(self):
self._cache_lock = threading.RLock()
self._cache = cache.CodeObjectCache()
def get_transformed_name(self, node):
"""Returns a name for the output function. Subclasses may override this."""
if isinstance(node, gast.Lambda):
return 'lam'
elif isinstance(node, gast.FunctionDef):
# Note that we need to rename the function, to avoid any namespace
# clashes.
return node.name
raise ValueError('Unknown node type {}'.format(node))
def get_extra_locals(self):
"""Returns extra static local variables to be made to transformed code.
Subclasses must override this.
Returns:
extra_locals: A Dict[Text, Any] containing additional variables to make
available to the transformed code.
"""
raise NotImplementedError('subclasses must override this')
def get_caching_key(self, user_context):
"""Returns a unique key to use for caching.
Subclasses must override this.
Calls made to `transform_function` with functions that have the same code
object and caching key will return a cached instance on subsequent
invocations.
Args:
user_context: The context object which was passed to `transform`.
Returns:
A hashable object, used as the cache subkey.
"""
raise NotImplementedError('subclasses must override this')
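Putting the three hooks together, a minimal hypothetical subclass; trace__ becomes visible to generated code as a local variable:
class TracingTranspiler(PyToPy):

  def transform_ast(self, node, ctx):
    return node  # identity transform, for illustration only

  def get_extra_locals(self):
    return {'trace__': print}  # generated code may call trace__(...)

  def get_caching_key(self, user_context):
    del user_context
    return 0  # a single cached variant per input function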
def _cached_factory(self, fn, cache_subkey):
cached_factory = self._cache[fn][cache_subkey]
logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
cached_factory)
return cached_factory
def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals):
"""Returns the transformed function factory for a given input."""
if self._cache.has(fn, cache_subkey):
return self._cached_factory(fn, cache_subkey)
def transform_function(self, fn, user_context):
"""Transforms a function. See GenericTranspiler.trasnform_function.
with self._cache_lock:
# Check again under lock.
if self._cache.has(fn, cache_subkey):
return self._cached_factory(fn, cache_subkey)
logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
nodes, ctx = self._transform_function(fn, user_context)
if logging.has_verbosity(2):
logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))
factory = _TransformedFnFactory(
ctx.info.name, fn.__code__.co_freevars, extra_locals)
factory.create(nodes, ctx.namer, future_features=ctx.info.future_features)
self._cache[fn][cache_subkey] = factory
return factory
def transform_function(self, fn, caching_subkey, user_context, extra_locals):
"""Transforms a function.
The `caching_subkey` argument allows mapping each function to multiple
outputs in the cache. This is useful for instance when transformers
can generate multiple variants of output code, typically as a result of
different transformation flags.
This overload wraps the parent's `transform_function`, adding caching and
facilities to instantiate the output as a Python object. It also
adds facilities to make new symbols available to the generated Python code,
visible as local variables - see `get_extra_locals`.
Args:
fn: A function or lambda.
caching_subkey: Used for caching. Calls made for functions with the same
code object and caching_subkey will return a cached instance on
subsequent invocations. Using a constant will create unique per-function
entries.
user_context: An opaque object (may be None) that is forwarded to
transform_ast.
extra_locals: A Dict[Text, Any] containing additional variables to make
available to the transformed code. These will be visible as local
variables.
user_context: An opaque object (may be None) that is forwarded to
transform_ast, through the ctx.user_context argument.
Returns:
A tuple:
* A function or lambda with the same signature and closure as `fn`
* The temporary module into which the transformed function was loaded
* The source map as a
Dict[origin_info.LineLocation, origin_info.OriginInfo]
"""
factory = self._transformed_factory(fn, caching_subkey, user_context,
extra_locals)
cache_subkey = self.get_caching_key(user_context)
if self._cache.has(fn, cache_subkey):
# Fast path: use a lock-free check.
factory = self._cached_factory(fn, cache_subkey)
else:
with self._cache_lock:
# Check again under lock.
if self._cache.has(fn, cache_subkey):
factory = self._cached_factory(fn, cache_subkey)
else:
logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
# TODO(mdan): Confusing overloading pattern. Fix.
nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)
if logging.has_verbosity(2):
logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))
factory = _PythonFnFactory(
ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals())
factory.create(
nodes, ctx.namer, future_features=ctx.info.future_features)
self._cache[fn][cache_subkey] = factory
transformed_fn = factory.instantiate(
globals_=fn.__globals__,

View File

@@ -35,7 +35,14 @@ class FlipSignTransformer(transformer.Base):
return self.generic_visit(node)
class TestTranspiler(transpiler.FunctionTranspiler):
class TestTranspiler(transpiler.PyToPy):
def get_caching_key(self, ctx):
del ctx
return 0
def get_extra_locals(self):
return {}
def transform_ast(self, node, ctx):
return FlipSignTransformer(ctx).visit(node)
@@ -45,14 +52,14 @@ global_var_for_test_global = 1
global_var_for_test_namespace_collisions = object()
class FunctionTranspilerTest(test.TestCase):
class PyToPyTest(test.TestCase):
def test_basic(self):
def f(a):
return a + 1
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 0)
@@ -63,7 +70,7 @@ class FunctionTranspilerTest(test.TestCase):
return a + b
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 0)
b = 2
@@ -74,7 +81,7 @@ class FunctionTranspilerTest(test.TestCase):
return a + global_var_for_test_global
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
global global_var_for_test_global
global_var_for_test_global = 1
@@ -90,7 +97,7 @@ class FunctionTranspilerTest(test.TestCase):
return a + b + d
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 1 - 2 - 2)
c = 0
@@ -107,7 +114,7 @@ class FunctionTranspilerTest(test.TestCase):
return g(a) + 1
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 1 - 1 + 1) # Only f is converted.
@@ -116,7 +123,7 @@ class FunctionTranspilerTest(test.TestCase):
f = lambda x: (b + (x if x > 0 else -x))
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 2 - 1)
self.assertEqual(f(-1), 2 - 1)
@@ -132,7 +139,7 @@ class FunctionTranspilerTest(test.TestCase):
f, _ = (lambda x: a + x, lambda y: b * y)
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 1 - 1)
@@ -147,7 +154,7 @@ class FunctionTranspilerTest(test.TestCase):
return g(x)
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 2 - 1)
@@ -159,7 +166,7 @@ class FunctionTranspilerTest(test.TestCase):
return g(x)
tr = TestTranspiler()
f, _, _ = tr.transform_function(f, object(), None, {})
f, _, _ = tr.transform(f, None)
self.assertEqual(f(1), 2 - 1)
@@ -171,9 +178,11 @@ class FunctionTranspilerTest(test.TestCase):
outputs = []
tr = TestTranspiler()
cache_key = object()
# Note: this is not a test, it's a required invariant.
assert tr.get_caching_key(None) == tr.get_caching_key(None)
def conversion_thread():
_, mod, _ = tr.transform_function(f, cache_key, None, {})
_, mod, _ = tr.transform(f, None)
outputs.append(mod.__name__)
threads = tuple(
@@ -192,21 +201,28 @@ class FunctionTranspilerTest(test.TestCase):
def test_fn():
return 1 + 1
class ReentrantTranspiler(transpiler.FunctionTranspiler):
class ReentrantTranspiler(transpiler.PyToPy):
def __init__(self):
super(ReentrantTranspiler, self).__init__()
self._recursion_depth = 0
def get_caching_key(self, ctx):
del ctx
return 0
def get_extra_locals(self):
return {}
def transform_ast(self, node, ctx):
self._recursion_depth += 1
if self._recursion_depth < 2:
self.transform_function(test_fn, object(), None, {})
self.transform(test_fn, None)
return FlipSignTransformer(ctx).visit(node)
tr = ReentrantTranspiler()
f, _, _ = tr.transform_function(test_fn, object(), None, {})
f, _, _ = tr.transform(test_fn, None)
self.assertEqual(f(), 0)
def test_namespace_collisions_avoided(self):
@@ -219,8 +235,8 @@ class FunctionTranspilerTest(test.TestCase):
tr = TestTranspiler()
obj = TestClass()
f, _, _ = tr.transform_function(
obj.global_var_for_test_namespace_collisions, object(), None, {})
f, _, _ = tr.transform(
obj.global_var_for_test_namespace_collisions, None)
self.assertIs(f(obj), global_var_for_test_namespace_collisions)