Slightly refactor the source-to-source translation API to better support non-Python outputs.
PiperOrigin-RevId: 319982764 Change-Id: I49017145719330596b55b0f9190eccf29a9a46c4
This commit is contained in:
parent
0d3c1eef7d
commit
ca59e0b5d7
@ -63,7 +63,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import enum
|
||||
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
@ -234,18 +233,17 @@ STANDARD_OPTIONS = ConversionOptions(
|
||||
optional_features=None)
|
||||
|
||||
|
||||
class ProgramContext(
|
||||
collections.namedtuple('ProgramContext', ('options', 'autograph_module'))):
|
||||
class ProgramContext(object):
|
||||
"""ProgramContext keeps track of converting function hierarchies.
|
||||
|
||||
This object is mutable, and is updated during conversion. Not thread safe.
|
||||
|
||||
Attributes:
|
||||
options: ConversionOptions
|
||||
autograph_module: Module, a reference to the autograph module. This needs to
|
||||
be specified by the caller to avoid circular dependencies.
|
||||
autograph_module: Deprecated. Do not use.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __init__(self, options, autograph_module=None):
|
||||
self.options = options
|
||||
self.autograph_module = autograph_module
|
||||
|
||||
|
||||
class Base(transformer.Base):
|
||||
|
@ -28,7 +28,6 @@ import six
|
||||
from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.impl import api
|
||||
from tensorflow.python.autograph.impl import conversion
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -64,16 +63,26 @@ def is_inside_generated_code():
|
||||
del frame
|
||||
|
||||
|
||||
class TestingTranspiler(conversion.AutoGraphTranspiler):
|
||||
class TestingTranspiler(api.PyToTF):
|
||||
"""Testing version that only applies given transformations."""
|
||||
|
||||
def __init__(self, converters):
|
||||
def __init__(self, converters, ag_overrides):
|
||||
super(TestingTranspiler, self).__init__()
|
||||
if isinstance(converters, (list, tuple)):
|
||||
self._converters = converters
|
||||
else:
|
||||
self._converters = (converters,)
|
||||
self.transformed_ast = None
|
||||
self._ag_overrides = ag_overrides
|
||||
|
||||
def get_extra_locals(self):
|
||||
retval = super(TestingTranspiler, self).get_extra_locals()
|
||||
if self._ag_overrides:
|
||||
modified_ag = imp.new_module('fake_autograph')
|
||||
modified_ag.__dict__.update(retval['ag__'].__dict__)
|
||||
modified_ag.__dict__.update(self._ag_overrides)
|
||||
retval['ag__'] = modified_ag
|
||||
return retval
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
node = self.initial_analysis(node, ctx)
|
||||
@ -113,18 +122,8 @@ class TestCase(test.TestCase):
|
||||
options=converter.ConversionOptions(recursive=True),
|
||||
autograph_module=api)
|
||||
|
||||
conversion.create_custom_vars(program_ctx)
|
||||
custom_vars = dict(conversion.custom_vars)
|
||||
|
||||
if ag_overrides:
|
||||
modified_ag = imp.new_module('fake_autograph')
|
||||
modified_ag.__dict__.update(custom_vars['ag__'].__dict__)
|
||||
modified_ag.__dict__.update(ag_overrides)
|
||||
custom_vars['ag__'] = modified_ag
|
||||
|
||||
tr = TestingTranspiler(converter_module)
|
||||
transformed, _, _ = tr.transform_function(
|
||||
f, program_ctx.options, program_ctx, custom_vars)
|
||||
tr = TestingTranspiler(converter_module, ag_overrides)
|
||||
transformed, _, _ = tr.transform_function(f, program_ctx)
|
||||
|
||||
if include_ast:
|
||||
return transformed, tr.transformed_ast, tr.transform_ctx
|
||||
|
@ -12,13 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""This module contains the user-facing API for AutoGraph."""
|
||||
"""This module contains the user- and codegen-facing API for AutoGraph."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import imp
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
@ -27,14 +28,38 @@ import traceback
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.autograph import operators
|
||||
from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.converters import asserts
|
||||
from tensorflow.python.autograph.converters import break_statements
|
||||
from tensorflow.python.autograph.converters import call_trees
|
||||
from tensorflow.python.autograph.converters import conditional_expressions
|
||||
from tensorflow.python.autograph.converters import continue_statements
|
||||
from tensorflow.python.autograph.converters import control_flow
|
||||
from tensorflow.python.autograph.converters import directives
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.converters import lists
|
||||
from tensorflow.python.autograph.converters import logical_expressions
|
||||
from tensorflow.python.autograph.converters import return_statements
|
||||
from tensorflow.python.autograph.converters import slices
|
||||
from tensorflow.python.autograph.converters import variables
|
||||
from tensorflow.python.autograph.core import ag_ctx
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.core import unsupported_features_checker
|
||||
from tensorflow.python.autograph.impl import conversion
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import error_utils
|
||||
from tensorflow.python.autograph.pyct import errors
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import origin_info
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transpiler
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import errors_impl
|
||||
@ -48,6 +73,11 @@ def is_autograph_strict_conversion_mode():
|
||||
return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0
|
||||
|
||||
|
||||
#
|
||||
# Error handling
|
||||
#
|
||||
|
||||
|
||||
# TODO(mdan): Export this symbol.
|
||||
class AutoGraphError(errors.PyCTError):
|
||||
"""Base class for all AutoGraph exceptions."""
|
||||
@ -113,6 +143,26 @@ class _ErrorMetadata(error_utils.ErrorMetadataBase):
|
||||
return StagingError(self.get_message())
|
||||
|
||||
|
||||
def _attach_error_metadata(e, f):
|
||||
"""Augments an error with the metadata necessary for rewrite."""
|
||||
if hasattr(e, 'ag_pass_through'):
|
||||
return
|
||||
|
||||
metadata = getattr(e, 'ag_error_metadata', None)
|
||||
source_map = f.ag_source_map
|
||||
|
||||
if metadata is None:
|
||||
logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
|
||||
message = '{}: {}'.format(e.__class__.__name__, e)
|
||||
else:
|
||||
message = None
|
||||
|
||||
cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]
|
||||
|
||||
e.ag_error_metadata = _ErrorMetadata(
|
||||
cause_tb, metadata, message, source_map, __file__)
|
||||
|
||||
|
||||
class StackTraceMapper(tf_stack.StackTraceMapper):
|
||||
"""Remaps generated code to code it originated from."""
|
||||
|
||||
@ -145,6 +195,106 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
|
||||
return effective_source_map
|
||||
|
||||
|
||||
#
|
||||
# Actual source code transformation
|
||||
#
|
||||
|
||||
|
||||
class PyToTF(transpiler.PyToPy):
|
||||
"""The TensorFlow AutoGraph transformer."""
|
||||
|
||||
def __init__(self):
|
||||
super(PyToTF, self).__init__()
|
||||
|
||||
# TODO(mdan): Move into core or replace with an actual importable module.
|
||||
# Craft a module that exposes the external API as well as certain
|
||||
# internal modules.
|
||||
ag_internal = imp.new_module('autograph')
|
||||
ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__)
|
||||
ag_internal.ConversionOptions = converter.ConversionOptions
|
||||
ag_internal.STD = converter.STANDARD_OPTIONS
|
||||
ag_internal.Feature = converter.Feature
|
||||
ag_internal.utils = utils
|
||||
ag_internal.FunctionScope = function_wrappers.FunctionScope
|
||||
ag_internal.with_function_scope = function_wrappers.with_function_scope
|
||||
# TODO(mdan): Add safeguards against name clashes.
|
||||
# We don't want to create a submodule because we want the operators to be
|
||||
# accessible as ag__.<operator>
|
||||
ag_internal.__dict__.update(special_functions.__dict__)
|
||||
ag_internal.__dict__.update(operators.__dict__)
|
||||
|
||||
self._extra_locals = {'ag__': ag_internal}
|
||||
|
||||
def get_transformed_name(self, node):
|
||||
return 'tf__' + super(PyToTF, self).get_transformed_name(node)
|
||||
|
||||
def get_extra_locals(self):
|
||||
return self._extra_locals
|
||||
|
||||
def get_caching_key(self, ctx):
|
||||
return ctx.options
|
||||
|
||||
def initial_analysis(self, node, ctx):
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
anno.dup(
|
||||
node,
|
||||
{
|
||||
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
||||
},
|
||||
)
|
||||
return node
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
unsupported_features_checker.verify(node)
|
||||
node = self.initial_analysis(node, ctx)
|
||||
|
||||
node = functions.transform(node, ctx)
|
||||
node = directives.transform(node, ctx)
|
||||
node = break_statements.transform(node, ctx)
|
||||
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
|
||||
node = asserts.transform(node, ctx)
|
||||
# Note: sequencing continue canonicalization before for loop one avoids
|
||||
# dealing with the extra loop increment operation that the for
|
||||
# canonicalization creates.
|
||||
node = continue_statements.transform(node, ctx)
|
||||
node = return_statements.transform(node, ctx)
|
||||
if ctx.user.options.uses(converter.Feature.LISTS):
|
||||
node = lists.transform(node, ctx)
|
||||
node = slices.transform(node, ctx)
|
||||
node = call_trees.transform(node, ctx)
|
||||
node = control_flow.transform(node, ctx)
|
||||
node = conditional_expressions.transform(node, ctx)
|
||||
node = logical_expressions.transform(node, ctx)
|
||||
node = variables.transform(node, ctx)
|
||||
return node
|
||||
|
||||
|
||||
def _convert_actual(entity, program_ctx):
|
||||
"""Applies AutoGraph to entity."""
|
||||
|
||||
# TODO(mdan): Put these extra fields inside __autograph_info__.
|
||||
if not hasattr(entity, '__code__'):
|
||||
raise ValueError('Cannot apply autograph to a function that doesn\'t '
|
||||
'expose a __code__ object. If this is a @tf.function,'
|
||||
' try passing f.python_function instead.')
|
||||
|
||||
transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)
|
||||
|
||||
assert not hasattr(transformed, 'ag_module')
|
||||
assert not hasattr(transformed, 'ag_source_map')
|
||||
transformed.ag_module = module
|
||||
transformed.ag_source_map = source_map
|
||||
return transformed
|
||||
|
||||
|
||||
#
|
||||
# Generated code support
|
||||
#
|
||||
|
||||
|
||||
def autograph_artifact(entity, extras=None):
|
||||
setattr(entity, 'autograph_info__', extras)
|
||||
return entity
|
||||
@ -154,272 +304,12 @@ def is_autograph_artifact(entity):
|
||||
return hasattr(entity, 'autograph_info__')
|
||||
|
||||
|
||||
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
|
||||
"""Decorator that applies AutoGraph to a function.
|
||||
|
||||
Use in internal APIs.
|
||||
|
||||
This API is suitable for high order functions internal to the TensorFlow API,
|
||||
and more generally any function to which Autograph is not applied.
|
||||
|
||||
Guidance: convert was a decorator meant for use directly by developers, and
|
||||
will be soon deprecated in favor of tf.function. tf_convert is to be called
|
||||
from high order functions internal to TF.
|
||||
|
||||
Args:
|
||||
f: Callable.
|
||||
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
||||
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
||||
specify.
|
||||
user_requested: bool, whether to ignore the conversion whitelist. See
|
||||
ConversionOptions.user_requested.
|
||||
|
||||
Returns:
|
||||
Either `f or the converted version of `f`.
|
||||
"""
|
||||
|
||||
if is_autograph_artifact(f):
|
||||
return f
|
||||
f_wrapper = f
|
||||
decorators, f = tf_decorator.unwrap(f)
|
||||
|
||||
# TODO(mdan): Grab features from context.
|
||||
# Note: we pass the original context through to convert to properly handle the
|
||||
# following scenario, which can be used inside TF implementations:
|
||||
#
|
||||
# ctx = ag_ctx.control_status_ctx()
|
||||
# @function(autograph=False) # Low-level graph code
|
||||
# def inner_fn():
|
||||
# # The context is disabled here, but should be enabled in user user_fn
|
||||
# tf_convert(user_fn, ctx=ctx)
|
||||
if ctx.status == ag_ctx.Status.ENABLED:
|
||||
wrapper_factory = convert(
|
||||
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
|
||||
elif ctx.status == ag_ctx.Status.DISABLED:
|
||||
wrapper_factory = do_not_convert
|
||||
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
|
||||
if convert_by_default:
|
||||
wrapper_factory = convert(
|
||||
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
|
||||
else:
|
||||
wrapper_factory = call_with_unspecified_conversion_status
|
||||
else:
|
||||
assert False, 'This switch contains all possible cases!'
|
||||
wrapper = wrapper_factory(f)
|
||||
|
||||
if decorators:
|
||||
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
|
||||
|
||||
return autograph_artifact(wrapper)
|
||||
|
||||
|
||||
# TODO(mdan): Make private.
|
||||
def convert(recursive=False,
|
||||
optional_features=None,
|
||||
user_requested=True,
|
||||
conversion_ctx=ag_ctx.NullCtx()):
|
||||
"""Decorator that compiles a function to use TensorFlow ops.
|
||||
|
||||
The decorator is dynamic - it recompiles the target whenever the decorated
|
||||
function is called. This means the parameter values are known at conversion.
|
||||
It also means that repeated calls with different types of parameters will be
|
||||
correctly processed.
|
||||
|
||||
Args:
|
||||
recursive: bool, whether to recursively convert any functions or classes
|
||||
that the converted function may use.
|
||||
optional_features: converted.Feature, allows toggling optional or
|
||||
experimental features. When set to None, only the core features are
|
||||
enabled.
|
||||
user_requested: bool, whether this is a function that the user explicitly
|
||||
asked to be converted. See ConversionOptions.user_requested.
|
||||
conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
|
||||
which `f` is used.
|
||||
|
||||
Returns:
|
||||
Callable, a decorator that converts the given function into an equivalent
|
||||
function that uses TensorFlow ops.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
"""Decorator implementation."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Wrapper that calls the converted version of f."""
|
||||
options = converter.ConversionOptions(
|
||||
recursive=recursive,
|
||||
user_requested=user_requested,
|
||||
optional_features=optional_features)
|
||||
try:
|
||||
with conversion_ctx:
|
||||
return converted_call(f, args, kwargs, options=options)
|
||||
except Exception as e: # pylint:disable=broad-except
|
||||
if hasattr(e, 'ag_error_metadata'):
|
||||
raise e.ag_error_metadata.to_exception(e)
|
||||
else:
|
||||
raise
|
||||
|
||||
if inspect.isfunction(f) or inspect.ismethod(f):
|
||||
wrapper = functools.update_wrapper(wrapper, f)
|
||||
|
||||
decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
|
||||
return autograph_artifact(decorated_wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def call_with_unspecified_conversion_status(func):
|
||||
"""Decorator that resets the conversion context to the unspecified status."""
|
||||
def wrapper(*args, **kwargs):
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isfunction(func) or inspect.ismethod(func):
|
||||
wrapper = functools.update_wrapper(wrapper, func)
|
||||
|
||||
return autograph_artifact(wrapper)
|
||||
|
||||
|
||||
@tf_export('autograph.experimental.do_not_convert')
|
||||
def do_not_convert(func=None):
|
||||
"""Decorator that suppresses the conversion of a function.
|
||||
|
||||
Args:
|
||||
func: function to decorate.
|
||||
|
||||
Returns:
|
||||
If `func` is not None, returns a `Callable` which is equivalent to
|
||||
`func`, but is not converted by AutoGraph.
|
||||
If `func` is None, returns a decorator that, when invoked with a
|
||||
single `func` argument, returns a `Callable` equivalent to the
|
||||
above case.
|
||||
"""
|
||||
if func is None:
|
||||
return do_not_convert
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isfunction(func) or inspect.ismethod(func):
|
||||
wrapper = functools.update_wrapper(wrapper, func)
|
||||
|
||||
return autograph_artifact(wrapper)
|
||||
|
||||
|
||||
def _attach_metadata(e, f):
|
||||
"""Augments an error with the metadata necessary for rewrite."""
|
||||
if hasattr(e, 'ag_pass_through'):
|
||||
return
|
||||
|
||||
metadata = getattr(e, 'ag_error_metadata', None)
|
||||
source_map = f.ag_source_map
|
||||
|
||||
if metadata is None:
|
||||
logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
|
||||
message = '{}: {}'.format(e.__class__.__name__, e)
|
||||
else:
|
||||
message = None
|
||||
|
||||
cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]
|
||||
|
||||
e.ag_error_metadata = _ErrorMetadata(
|
||||
cause_tb, metadata, message, source_map, __file__)
|
||||
|
||||
|
||||
def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
||||
"""Calls the original function without converting with AutoGraph."""
|
||||
if update_cache:
|
||||
conversion.cache_whitelisted(f, options)
|
||||
|
||||
if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
|
||||
return f.__self__.call(args, kwargs)
|
||||
|
||||
if kwargs is not None:
|
||||
return f(*args, **kwargs)
|
||||
return f(*args)
|
||||
|
||||
|
||||
def _is_of_known_loaded_module(f, module_name):
|
||||
mod = sys.modules.get(module_name, None)
|
||||
if mod is None:
|
||||
return False
|
||||
if any(v is not None for v in mod.__dict__.values() if f is v):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_known_loaded_type(f, module_name, entity_name):
|
||||
"""Tests whether the function or method is an instance of a known type."""
|
||||
if (module_name not in sys.modules or
|
||||
not hasattr(sys.modules[module_name], entity_name)):
|
||||
return False
|
||||
type_entity = getattr(sys.modules[module_name], entity_name)
|
||||
if isinstance(f, type_entity):
|
||||
# The method if of this type. Example:
|
||||
#
|
||||
# o = ClassType()
|
||||
# function(o.method)()
|
||||
return True
|
||||
# Note: inspect is required here, to avoid unpacking tf.function decorators.
|
||||
if inspect.ismethod(f):
|
||||
# The the unbound method if of this type. Example:
|
||||
#
|
||||
# class ClassType:
|
||||
# @function
|
||||
# def method(self):
|
||||
# ...
|
||||
# o = ClassType()
|
||||
# o.method()
|
||||
if isinstance(f.__func__, type_entity):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _fall_back_unconverted(f, args, kwargs, options, exc):
|
||||
"""Falls back to calling the function unconverted, in case of error."""
|
||||
# TODO(mdan): Consider adding an internal metric.
|
||||
warning_template = (
|
||||
'AutoGraph could not transform %s and will run it as-is.\n'
|
||||
'%s'
|
||||
'Cause: %s\n'
|
||||
'To silence this warning, decorate the function with'
|
||||
' @tf.autograph.experimental.do_not_convert')
|
||||
if isinstance(exc, errors.UnsupportedLanguageElementError):
|
||||
if not conversion.is_in_whitelist_cache(f, options):
|
||||
logging.warn(warning_template, f, '', exc)
|
||||
else:
|
||||
file_bug_message = (
|
||||
'Please report this to the TensorFlow team. When filing the bug, set'
|
||||
' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
|
||||
' attach the full output.\n')
|
||||
logging.warn(warning_template, f, file_bug_message, exc)
|
||||
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
|
||||
def _log_callargs(f, args, kwargs):
|
||||
"""Logging helper."""
|
||||
logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
|
||||
if not six.PY2:
|
||||
logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)
|
||||
|
||||
if kwargs is not None:
|
||||
callargs = tf_inspect.getcallargs(f, *args, **kwargs)
|
||||
else:
|
||||
callargs = tf_inspect.getcallargs(f, *args)
|
||||
|
||||
formatted_callargs = '\n'.join(
|
||||
' {}: {}'.format(k, v) for k, v in callargs.items())
|
||||
logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)
|
||||
|
||||
|
||||
def converted_call(f,
|
||||
args,
|
||||
kwargs,
|
||||
caller_fn_scope=None,
|
||||
options=None):
|
||||
"""Compiles a function call inline.
|
||||
"""Converts a function call inline.
|
||||
|
||||
For internal use only.
|
||||
|
||||
@ -492,40 +382,7 @@ def converted_call(f,
|
||||
else:
|
||||
return py_builtins.overload_of(f)(*args)
|
||||
|
||||
# TODO(b/122265385): Remove this bypass.
|
||||
if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
|
||||
_is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
|
||||
logging.warn(
|
||||
'{} appears to be decorated by wrapt, which is not yet supported'
|
||||
' by AutoGraph. The function will run as-is.'
|
||||
' You may still apply AutoGraph before the wrapt decorator.'.format(f))
|
||||
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
|
||||
logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
# Constructors are permanently whitelisted.
|
||||
# TODO(mdan): Toggle as experimental feature instead.
|
||||
# TODO(b/124016764): Remove this limitation.
|
||||
if inspect_utils.isconstructor(f):
|
||||
logging.log(2, 'Permanently whitelisted: %s: constructor', f)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
# Other built-in modules are permanently whitelisted.
|
||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||
if any(
|
||||
_is_of_known_loaded_module(f, m)
|
||||
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
|
||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
# Custom ops and kernels are also permanently whitelisted.
|
||||
# See tensorflow.framework.load_library.
|
||||
if (hasattr(f, '__module__') and
|
||||
hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
|
||||
logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
|
||||
if conversion.is_unsupported(f):
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
if not options.user_requested and conversion.is_whitelisted(f):
|
||||
@ -579,9 +436,8 @@ def converted_call(f,
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
try:
|
||||
program_ctx = converter.ProgramContext(
|
||||
options=options, autograph_module=tf_inspect.getmodule(converted_call))
|
||||
converted_f = conversion.convert(target_entity, program_ctx)
|
||||
program_ctx = converter.ProgramContext(options=options)
|
||||
converted_f = _convert_actual(target_entity, program_ctx)
|
||||
if logging.has_verbosity(2):
|
||||
_log_callargs(converted_f, effective_args, kwargs)
|
||||
except Exception as e: # pylint:disable=broad-except
|
||||
@ -597,12 +453,226 @@ def converted_call(f,
|
||||
else:
|
||||
result = converted_f(*effective_args)
|
||||
except Exception as e:
|
||||
_attach_metadata(e, converted_f)
|
||||
_attach_error_metadata(e, converted_f)
|
||||
raise
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _call_unconverted(f, args, kwargs, options, update_cache=True):
|
||||
"""Calls the original function without converting with AutoGraph."""
|
||||
if update_cache:
|
||||
conversion.cache_whitelisted(f, options)
|
||||
|
||||
if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget):
|
||||
return f.__self__.call(args, kwargs)
|
||||
|
||||
if kwargs is not None:
|
||||
return f(*args, **kwargs)
|
||||
return f(*args)
|
||||
|
||||
|
||||
def _fall_back_unconverted(f, args, kwargs, options, exc):
|
||||
"""Falls back to calling the function unconverted, in case of error."""
|
||||
# TODO(mdan): Consider adding an internal metric.
|
||||
warning_template = (
|
||||
'AutoGraph could not transform %s and will run it as-is.\n'
|
||||
'%s'
|
||||
'Cause: %s\n'
|
||||
'To silence this warning, decorate the function with'
|
||||
' @tf.autograph.experimental.do_not_convert')
|
||||
if isinstance(exc, errors.UnsupportedLanguageElementError):
|
||||
if not conversion.is_in_whitelist_cache(f, options):
|
||||
logging.warn(warning_template, f, '', exc)
|
||||
else:
|
||||
file_bug_message = (
|
||||
'Please report this to the TensorFlow team. When filing the bug, set'
|
||||
' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
|
||||
' attach the full output.\n')
|
||||
logging.warn(warning_template, f, file_bug_message, exc)
|
||||
|
||||
return _call_unconverted(f, args, kwargs, options)
|
||||
|
||||
|
||||
#
|
||||
# TensorFlow integration
|
||||
#
|
||||
|
||||
|
||||
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
|
||||
"""Decorator that applies AutoGraph to a function.
|
||||
|
||||
Use in internal APIs.
|
||||
|
||||
This API is suitable for high order functions internal to the TensorFlow API,
|
||||
and more generally any function to which Autograph is not applied.
|
||||
|
||||
Guidance: convert was a decorator meant for use directly by developers, and
|
||||
will be soon deprecated in favor of tf.function. tf_convert is to be called
|
||||
from high order functions internal to TF.
|
||||
|
||||
Args:
|
||||
f: Callable.
|
||||
ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
|
||||
convert_by_default: bool, whether to use AutoGraph when the context doesn't
|
||||
specify.
|
||||
user_requested: bool, whether to ignore the conversion whitelist. See
|
||||
ConversionOptions.user_requested.
|
||||
|
||||
Returns:
|
||||
Either `f or the converted version of `f`.
|
||||
"""
|
||||
|
||||
if is_autograph_artifact(f):
|
||||
return f
|
||||
f_wrapper = f
|
||||
decorators, f = tf_decorator.unwrap(f)
|
||||
|
||||
# TODO(mdan): Grab features from context.
|
||||
# Note: we pass the original context through to convert to properly handle the
|
||||
# following scenario, which can be used inside TF implementations:
|
||||
#
|
||||
# ctx = ag_ctx.control_status_ctx()
|
||||
# @function(autograph=False) # Low-level graph code
|
||||
# def inner_fn():
|
||||
# # The context is disabled here, but should be enabled in user user_fn
|
||||
# tf_convert(user_fn, ctx=ctx)
|
||||
if ctx.status == ag_ctx.Status.ENABLED:
|
||||
wrapper_factory = convert(
|
||||
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
|
||||
elif ctx.status == ag_ctx.Status.DISABLED:
|
||||
wrapper_factory = do_not_convert
|
||||
elif ctx.status == ag_ctx.Status.UNSPECIFIED:
|
||||
if convert_by_default:
|
||||
wrapper_factory = convert(
|
||||
recursive=True, user_requested=user_requested, conversion_ctx=ctx)
|
||||
else:
|
||||
wrapper_factory = call_with_unspecified_conversion_status
|
||||
else:
|
||||
assert False, 'This switch contains all possible cases!'
|
||||
wrapper = wrapper_factory(f)
|
||||
|
||||
if decorators:
|
||||
wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)
|
||||
|
||||
return autograph_artifact(wrapper)
|
||||
|
||||
|
||||
def call_with_unspecified_conversion_status(func):
|
||||
"""Decorator that resets the conversion context to the unspecified status."""
|
||||
def wrapper(*args, **kwargs):
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isfunction(func) or inspect.ismethod(func):
|
||||
wrapper = functools.update_wrapper(wrapper, func)
|
||||
|
||||
return autograph_artifact(wrapper)
|
||||
|
||||
|
||||
def _log_callargs(f, args, kwargs):
|
||||
"""Logging helper."""
|
||||
logging.log(2, 'Defaults of %s : %s', f, f.__defaults__)
|
||||
if not six.PY2:
|
||||
logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__)
|
||||
|
||||
if kwargs is not None:
|
||||
callargs = tf_inspect.getcallargs(f, *args, **kwargs)
|
||||
else:
|
||||
callargs = tf_inspect.getcallargs(f, *args)
|
||||
|
||||
formatted_callargs = '\n'.join(
|
||||
' {}: {}'.format(k, v) for k, v in callargs.items())
|
||||
logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)
|
||||
|
||||
|
||||
#
|
||||
# Public API
|
||||
#
|
||||
|
||||
|
||||
@tf_export('autograph.experimental.do_not_convert')
|
||||
def do_not_convert(func=None):
|
||||
"""Decorator that suppresses the conversion of a function.
|
||||
|
||||
Args:
|
||||
func: function to decorate.
|
||||
|
||||
Returns:
|
||||
If `func` is not None, returns a `Callable` which is equivalent to
|
||||
`func`, but is not converted by AutoGraph.
|
||||
If `func` is None, returns a decorator that, when invoked with a
|
||||
single `func` argument, returns a `Callable` equivalent to the
|
||||
above case.
|
||||
"""
|
||||
if func is None:
|
||||
return do_not_convert
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if inspect.isfunction(func) or inspect.ismethod(func):
|
||||
wrapper = functools.update_wrapper(wrapper, func)
|
||||
|
||||
return autograph_artifact(wrapper)
|
||||
|
||||
|
||||
# TODO(mdan): Make private.
|
||||
def convert(recursive=False,
|
||||
optional_features=None,
|
||||
user_requested=True,
|
||||
conversion_ctx=ag_ctx.NullCtx()):
|
||||
"""Decorator that compiles a function to use TensorFlow ops.
|
||||
|
||||
The decorator is dynamic - it recompiles the target whenever the decorated
|
||||
function is called. This means the parameter values are known at conversion.
|
||||
It also means that repeated calls with different types of parameters will be
|
||||
correctly processed.
|
||||
|
||||
Args:
|
||||
recursive: bool, whether to recursively convert any functions or classes
|
||||
that the converted function may use.
|
||||
optional_features: converted.Feature, allows toggling optional or
|
||||
experimental features. When set to None, only the core features are
|
||||
enabled.
|
||||
user_requested: bool, whether this is a function that the user explicitly
|
||||
asked to be converted. See ConversionOptions.user_requested.
|
||||
conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in
|
||||
which `f` is used.
|
||||
|
||||
Returns:
|
||||
Callable, a decorator that converts the given function into an equivalent
|
||||
function that uses TensorFlow ops.
|
||||
"""
|
||||
|
||||
def decorator(f):
|
||||
"""Decorator implementation."""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Wrapper that calls the converted version of f."""
|
||||
options = converter.ConversionOptions(
|
||||
recursive=recursive,
|
||||
user_requested=user_requested,
|
||||
optional_features=optional_features)
|
||||
try:
|
||||
with conversion_ctx:
|
||||
return converted_call(f, args, kwargs, options=options)
|
||||
except Exception as e: # pylint:disable=broad-except
|
||||
if hasattr(e, 'ag_error_metadata'):
|
||||
raise e.ag_error_metadata.to_exception(e)
|
||||
else:
|
||||
raise
|
||||
|
||||
if inspect.isfunction(f) or inspect.ismethod(f):
|
||||
wrapper = functools.update_wrapper(wrapper, f)
|
||||
|
||||
decorated_wrapper = tf_decorator.make_decorator(f, wrapper)
|
||||
return autograph_artifact(decorated_wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# pylint:disable=line-too-long
|
||||
@tf_export('autograph.to_graph', v1=[])
|
||||
def to_graph(entity, recursive=True, experimental_optional_features=None):
|
||||
@ -668,9 +738,8 @@ def to_graph(entity, recursive=True, experimental_optional_features=None):
|
||||
options=converter.ConversionOptions(
|
||||
recursive=recursive,
|
||||
user_requested=True,
|
||||
optional_features=experimental_optional_features),
|
||||
autograph_module=tf_inspect.getmodule(to_graph))
|
||||
return autograph_artifact(conversion.convert(entity, program_ctx))
|
||||
optional_features=experimental_optional_features))
|
||||
return autograph_artifact(_convert_actual(entity, program_ctx))
|
||||
except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
|
||||
logging.error(1, 'Error converting %s', entity, exc_info=True)
|
||||
raise ConversionError('converting {}: {}: {}'.format(
|
||||
@ -845,3 +914,6 @@ def to_code(entity, recursive=True, experimental_optional_features=None):
|
||||
recursive=recursive,
|
||||
experimental_optional_features=experimental_optional_features))
|
||||
return textwrap.dedent(source)
|
||||
|
||||
|
||||
_TRANSPILER = PyToTF()
|
||||
|
@ -19,111 +19,97 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import imp
|
||||
import inspect
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from tensorflow.python.autograph import operators
|
||||
from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.converters import asserts
|
||||
from tensorflow.python.autograph.converters import break_statements
|
||||
from tensorflow.python.autograph.converters import call_trees
|
||||
from tensorflow.python.autograph.converters import conditional_expressions
|
||||
from tensorflow.python.autograph.converters import continue_statements
|
||||
from tensorflow.python.autograph.converters import control_flow
|
||||
from tensorflow.python.autograph.converters import directives
|
||||
from tensorflow.python.autograph.converters import functions
|
||||
from tensorflow.python.autograph.converters import lists
|
||||
from tensorflow.python.autograph.converters import logical_expressions
|
||||
from tensorflow.python.autograph.converters import return_statements
|
||||
from tensorflow.python.autograph.converters import slices
|
||||
from tensorflow.python.autograph.converters import variables
|
||||
from tensorflow.python.autograph.core import config
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.core import function_wrappers
|
||||
from tensorflow.python.autograph.core import unsupported_features_checker
|
||||
from tensorflow.python.autograph.lang import special_functions
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import cache
|
||||
from tensorflow.python.autograph.pyct import cfg
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.pyct import qual_names
|
||||
from tensorflow.python.autograph.pyct import transpiler
|
||||
from tensorflow.python.autograph.pyct.static_analysis import activity
|
||||
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
||||
from tensorflow.python.autograph.utils import ag_logging as logging
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
class AutoGraphTranspiler(transpiler.FunctionTranspiler):
|
||||
|
||||
def get_transformed_name(self, node):
|
||||
return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node)
|
||||
|
||||
def initial_analysis(self, node, ctx):
|
||||
graphs = cfg.build(node)
|
||||
node = qual_names.resolve(node)
|
||||
node = activity.resolve(node, ctx, None)
|
||||
node = reaching_definitions.resolve(node, ctx, graphs)
|
||||
anno.dup(
|
||||
node,
|
||||
{
|
||||
anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
|
||||
},
|
||||
)
|
||||
return node
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
unsupported_features_checker.verify(node)
|
||||
node = self.initial_analysis(node, ctx)
|
||||
|
||||
node = functions.transform(node, ctx)
|
||||
node = directives.transform(node, ctx)
|
||||
node = break_statements.transform(node, ctx)
|
||||
if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
|
||||
node = asserts.transform(node, ctx)
|
||||
# Note: sequencing continue canonicalization before for loop one avoids
|
||||
# dealing with the extra loop increment operation that the for
|
||||
# canonicalization creates.
|
||||
node = continue_statements.transform(node, ctx)
|
||||
node = return_statements.transform(node, ctx)
|
||||
if ctx.user.options.uses(converter.Feature.LISTS):
|
||||
node = lists.transform(node, ctx)
|
||||
node = slices.transform(node, ctx)
|
||||
node = call_trees.transform(node, ctx)
|
||||
node = control_flow.transform(node, ctx)
|
||||
node = conditional_expressions.transform(node, ctx)
|
||||
node = logical_expressions.transform(node, ctx)
|
||||
node = variables.transform(node, ctx)
|
||||
return node
|
||||
|
||||
|
||||
_TRANSPILER = AutoGraphTranspiler()
|
||||
_WHITELIST_CACHE = cache.UnboundInstanceCache()
|
||||
|
||||
|
||||
custom_vars = None
|
||||
def _is_of_known_loaded_module(f, module_name):
|
||||
mod = sys.modules.get(module_name, None)
|
||||
if mod is None:
|
||||
return False
|
||||
if any(v is not None for v in mod.__dict__.values() if f is v):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# TODO(mdan): Superfluous function, remove.
|
||||
# TODO(mdan): Put these extra fields inside __autograph_info__.
|
||||
def convert(entity, program_ctx):
|
||||
"""Applies AutoGraph to entity."""
|
||||
def _is_known_loaded_type(f, module_name, entity_name):
|
||||
"""Tests whether the function or method is an instance of a known type."""
|
||||
if (module_name not in sys.modules or
|
||||
not hasattr(sys.modules[module_name], entity_name)):
|
||||
return False
|
||||
type_entity = getattr(sys.modules[module_name], entity_name)
|
||||
if isinstance(f, type_entity):
|
||||
# The method if of this type. Example:
|
||||
#
|
||||
# o = ClassType()
|
||||
# function(o.method)()
|
||||
return True
|
||||
# Note: inspect is required here, to avoid unpacking tf.function decorators.
|
||||
if inspect.ismethod(f):
|
||||
# The the unbound method if of this type. Example:
|
||||
#
|
||||
# class ClassType:
|
||||
# @function
|
||||
# def method(self):
|
||||
# ...
|
||||
# o = ClassType()
|
||||
# o.method()
|
||||
if isinstance(f.__func__, type_entity):
|
||||
return True
|
||||
return False
|
||||
|
||||
if not hasattr(entity, '__code__'):
|
||||
raise ValueError('Cannot apply autograph to a function that doesn\'t '
|
||||
'expose a __code__ object. If this is a @tf.function,'
|
||||
' try passing f.python_function instead.')
|
||||
|
||||
create_custom_vars(program_ctx)
|
||||
transformed, module, source_map = _TRANSPILER.transform_function(
|
||||
entity, program_ctx.options, program_ctx, custom_vars)
|
||||
def is_unsupported(o):
|
||||
"""Checks whether an entity is supported by AutoGraph at all."""
|
||||
|
||||
assert not hasattr(transformed, 'ag_module')
|
||||
assert not hasattr(transformed, 'ag_source_map')
|
||||
transformed.ag_module = module
|
||||
transformed.ag_source_map = source_map
|
||||
return transformed
|
||||
# TODO(b/122265385): Remove this bypass.
|
||||
if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or
|
||||
_is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')):
|
||||
logging.warn(
|
||||
'{} appears to be decorated by wrapt, which is not yet supported'
|
||||
' by AutoGraph. The function will run as-is.'
|
||||
' You may still apply AutoGraph before the wrapt decorator.'.format(o))
|
||||
logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', o)
|
||||
return True
|
||||
|
||||
if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'):
|
||||
logging.log(2, 'Permanently whitelisted: %s: lru_cache', o)
|
||||
return True
|
||||
|
||||
# Constructors are permanently whitelisted.
|
||||
# TODO(mdan): Toggle as experimental feature instead.
|
||||
# TODO(b/124016764): Remove this limitation.
|
||||
if inspect_utils.isconstructor(o):
|
||||
logging.log(2, 'Permanently whitelisted: %s: constructor', o)
|
||||
return True
|
||||
|
||||
# Other built-in modules are permanently whitelisted.
|
||||
# TODO(mdan): Figure out how to do this consistently for all stdlib modules.
|
||||
if any(
|
||||
_is_of_known_loaded_module(o, m)
|
||||
for m in ('collections', 'pdb', 'copy', 'inspect', 're')):
|
||||
logging.log(2, 'Permanently whitelisted: %s: part of builtin module', o)
|
||||
return True
|
||||
|
||||
# Custom ops and kernels are also permanently whitelisted.
|
||||
# See tensorflow.framework.load_library.
|
||||
if (hasattr(o, '__module__') and
|
||||
hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')):
|
||||
logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', o)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
|
||||
@ -246,28 +232,3 @@ def cache_whitelisted(entity, options):
|
||||
except TypeError:
|
||||
# Catch-all for entities that are unhashable or don't allow weakrefs.
|
||||
pass
|
||||
|
||||
|
||||
# TODO(mdan): Move into core or replace with an actual importable module.
|
||||
# Visible for testing.
|
||||
def create_custom_vars(program_ctx):
|
||||
"""Adds namespace references to the module that exposes the api itself."""
|
||||
global custom_vars
|
||||
if custom_vars is None:
|
||||
# Craft a module that exposes parts of the external API as well as certain
|
||||
# internal modules.
|
||||
ag_internal = imp.new_module('autograph')
|
||||
ag_internal.__dict__.update(program_ctx.autograph_module.__dict__)
|
||||
ag_internal.ConversionOptions = converter.ConversionOptions
|
||||
ag_internal.STD = converter.STANDARD_OPTIONS
|
||||
ag_internal.Feature = converter.Feature
|
||||
ag_internal.utils = utils
|
||||
ag_internal.FunctionScope = function_wrappers.FunctionScope
|
||||
ag_internal.with_function_scope = function_wrappers.with_function_scope
|
||||
# TODO(mdan): Add safeguards against name clashes.
|
||||
# We don't want to create a submodule because we want the operators to be
|
||||
# accessible as ag__.<operator>
|
||||
ag_internal.__dict__.update(special_functions.__dict__)
|
||||
ag_internal.__dict__.update(operators.__dict__)
|
||||
|
||||
custom_vars = {'ag__': ag_internal}
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import types
|
||||
|
||||
@ -143,11 +144,11 @@ def _wrap_into_factory(nodes, entity_name, inner_factory_name,
|
||||
outer_factory_name=outer_factory_name)
|
||||
|
||||
|
||||
class _TransformedFnFactory(object):
|
||||
"""Helper object that wraps a transformed function factory."""
|
||||
class _PythonFnFactory(object):
|
||||
"""Helper object that wraps a Python function factory."""
|
||||
|
||||
def __init__(self, name, freevars, extra_locals):
|
||||
"""Creates a new factory for a transformed function.
|
||||
"""Creates a new factory for a Python function.
|
||||
|
||||
Args:
|
||||
name: The function name.
|
||||
@ -169,7 +170,7 @@ class _TransformedFnFactory(object):
|
||||
inner_factory_name='inner_factory',
|
||||
outer_factory_name='outer_factory',
|
||||
future_features=()):
|
||||
"""Initializes a transformed function."""
|
||||
"""Initializes a function."""
|
||||
if self._unbound_factory is not None:
|
||||
raise ValueError('double initialization; create a new object instead')
|
||||
|
||||
@ -191,7 +192,7 @@ class _TransformedFnFactory(object):
|
||||
closure,
|
||||
defaults=None,
|
||||
kwdefaults=None):
|
||||
"""Creates a new instance of the transformed function."""
|
||||
"""Creates a new function instance."""
|
||||
if self._unbound_factory is None:
|
||||
raise ValueError('call create first')
|
||||
|
||||
@ -213,78 +214,70 @@ class _TransformedFnFactory(object):
|
||||
closure=factory_closure)
|
||||
|
||||
# The lint override is a false positive.
|
||||
transformed_entity = bound_factory(**self._extra_locals) # pylint:disable=not-callable
|
||||
new_fn = bound_factory(**self._extra_locals) # pylint:disable=not-callable
|
||||
|
||||
if defaults:
|
||||
transformed_entity.__defaults__ = defaults
|
||||
new_fn.__defaults__ = defaults
|
||||
if kwdefaults:
|
||||
transformed_entity.__kwdefaults__ = kwdefaults
|
||||
new_fn.__kwdefaults__ = kwdefaults
|
||||
|
||||
return transformed_entity
|
||||
return new_fn
|
||||
|
||||
|
||||
class FunctionTranspiler(object):
|
||||
"""A generic source-to-source transpiler for Python functions.
|
||||
class GenericTranspiler(object):
|
||||
"""A generic transpiler for Python functions.
|
||||
|
||||
Its interface `transform_function` API offers a function-in, function-out
|
||||
interface. Internally, it takes care of parsing, caching and variable binding.
|
||||
Its interface is the `transform` API, which can process Python function
|
||||
objects. Internally, it handles parsing.
|
||||
|
||||
Users typically subclass this, customizing the transform_ast method.
|
||||
|
||||
Usually, instances of this class are singletons, since each instance manages
|
||||
its own cache. The caching subkey allows managing multiple types of
|
||||
transformation.
|
||||
Users typically subclass this, customizing the `transform_ast` method. The
|
||||
output of transformed_ast is returned directly by `transform`. Existing
|
||||
methods like `transform_function` may also be overloaded.
|
||||
|
||||
Example:
|
||||
|
||||
class MyTransformer(FunctionTranspiler):
|
||||
class MyTransformer(GenericTranspiler):
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
node = <<transform node, usually using ast.NodeTransformer classes>>
|
||||
return node
|
||||
def transform(self, obj):
|
||||
result = <<transform node>>
|
||||
return result
|
||||
|
||||
transformer = MyTransfomer()
|
||||
|
||||
new_f, module, source_map = transformer.transform_function(f, ...)
|
||||
# new_f is a function with signature identical to f
|
||||
|
||||
The transformed function has access to the same namespace as the original
|
||||
function. To allow access to internal APIs, users may inject additional
|
||||
symbols though the `extra_locals` argument of `transform_function`.
|
||||
result = transformer.transform(f, ...)
|
||||
# result is the output
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._cache_lock = threading.RLock()
|
||||
self._cache = cache.CodeObjectCache()
|
||||
|
||||
def transform_ast(self, node, user_context):
|
||||
def transform_ast(self, node, ctx):
|
||||
"""Performs an actual transformation of a function's AST.
|
||||
|
||||
Subclasses must implement this method. They must not call it.
|
||||
|
||||
The method receives the original AST and generates code according to the
|
||||
AST that the method returns. For functions, the returned AST is expected to
|
||||
contain a function with the exact same arguments and closure. The resulting
|
||||
function will receive the globals, closure and argument defaults of the
|
||||
input function.
|
||||
Subclasses must implement this method, and do not usually call it.
|
||||
|
||||
Args:
|
||||
node: One or more ast.AST nodes representing the AST to be transformed.
|
||||
user_context: The same value that the caller passed to
|
||||
`transform_function`.
|
||||
ctx: transformer.Context.
|
||||
"""
|
||||
raise NotImplementedError('subclasses must override this')
|
||||
|
||||
def get_transformed_name(self, node):
|
||||
"""Returns a name for the output function. Subclasses may override this."""
|
||||
if isinstance(node, gast.Lambda):
|
||||
return 'lam'
|
||||
elif isinstance(node, gast.FunctionDef):
|
||||
# Note that we need to rename the function, to avoid any namespace
|
||||
# clashes.
|
||||
return node.name
|
||||
else:
|
||||
raise ValueError('Unknown node type {}'.format(node))
|
||||
def transform(self, obj, user_context):
|
||||
"""Transforms a Python object.
|
||||
|
||||
Users typically call this method.
|
||||
|
||||
Args:
|
||||
obj: A Python object, function, type, etc.
|
||||
user_context: An opaque object (may be None) that is forwarded to
|
||||
transform_ast, through the ctx.user_context argument.
|
||||
Returns:
|
||||
Tre result of calling transform_function.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if the type of obj is not handled.
|
||||
"""
|
||||
if inspect.isfunction(obj) or inspect.ismethod(obj):
|
||||
return self.transform_function(obj, user_context)
|
||||
|
||||
raise NotImplementedError('Non-function: {}'.format(type(obj)))
|
||||
|
||||
def _erase_arg_defaults(self, node):
|
||||
"""Erase argde fault expressions, which would otherwise be unbound."""
|
||||
@ -296,8 +289,21 @@ class FunctionTranspiler(object):
|
||||
args.kw_defaults[i] = parser.parse_expression('None')
|
||||
return node
|
||||
|
||||
def _transform_function(self, fn, user_context):
|
||||
"""Performs source code transformation on a function."""
|
||||
def transform_function(self, fn, user_context):
|
||||
"""Transforms a function.
|
||||
|
||||
Subclasses may override this method. The return value is opaque.
|
||||
|
||||
The method receives the original AST. The result is passed as-is to the
|
||||
output of `transform`.
|
||||
|
||||
Args:
|
||||
fn: A function or lambda.
|
||||
user_context: An opaque object (may be None) that is forwarded to
|
||||
transform_ast, through the ctx.user_context argument.
|
||||
Returns:
|
||||
Any. By default it returns the output of transform_ast.
|
||||
"""
|
||||
future_features = inspect_utils.getfutureimports(fn)
|
||||
node, source = parser.parse_entity(fn, future_features=future_features)
|
||||
logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)
|
||||
@ -333,63 +339,129 @@ class FunctionTranspiler(object):
|
||||
|
||||
return node, context
|
||||
|
||||
|
||||
class PyToPy(GenericTranspiler):
|
||||
"""A generic Python-to-Python transpiler.
|
||||
|
||||
Its `transform` method offers a function-in, function-out interface.
|
||||
Internally, it takes care of parsing, caching and loading of the translated
|
||||
code.
|
||||
|
||||
Users typically subclass this, overriding `transform_ast`.
|
||||
|
||||
Usually, instances of this class are singletons, since each instance manages
|
||||
its own cache. The caching can be controlled by overriding `get_caching_key`.
|
||||
|
||||
Example:
|
||||
|
||||
class MyTransformer(PyToPy):
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
node = <<transform node, usually using ast.NodeTransformer classes>>
|
||||
return node
|
||||
|
||||
transformer = MyTransfomer()
|
||||
|
||||
new_f, module, source_map = transformer.transform_function(f, ...)
|
||||
# new_f is a function with signature identical to f
|
||||
|
||||
The transformed function has access to the same namespace as the original
|
||||
function. To allow access to internal APIs, users may inject additional
|
||||
symbols by overriding `get_extra_locals`.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._cache_lock = threading.RLock()
|
||||
self._cache = cache.CodeObjectCache()
|
||||
|
||||
def get_transformed_name(self, node):
|
||||
"""Returns a name for the output function. Subclasses may override this."""
|
||||
if isinstance(node, gast.Lambda):
|
||||
return 'lam'
|
||||
elif isinstance(node, gast.FunctionDef):
|
||||
# Note that we need to rename the function, to avoid any namespace
|
||||
# clashes.
|
||||
return node.name
|
||||
raise ValueError('Unknown node type {}'.format(node))
|
||||
|
||||
def get_extra_locals(self):
|
||||
"""Returns extra static local variables to be made to transformed code.
|
||||
|
||||
Subclasses must override this.
|
||||
|
||||
Returns:
|
||||
extra_locals: A Dict[Text, Any] containing additional variables to make
|
||||
available to the transformed code.
|
||||
"""
|
||||
raise NotImplementedError('subclasses must override this')
|
||||
|
||||
def get_caching_key(self, user_context):
|
||||
"""Returns a unique key to use for caching.
|
||||
|
||||
Subclasses must override this.
|
||||
|
||||
Calls made to `transform_function` with functions that have the same code
|
||||
object and caching key will return a cached instance on subsequent
|
||||
invocations.
|
||||
|
||||
Args:
|
||||
user_context: The context object which was passed to `transform`.
|
||||
|
||||
Returns:
|
||||
extra_locals: A hashable.
|
||||
"""
|
||||
raise NotImplementedError('subclasses must override this')
|
||||
|
||||
def _cached_factory(self, fn, cache_subkey):
|
||||
cached_factory = self._cache[fn][cache_subkey]
|
||||
logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey,
|
||||
cached_factory)
|
||||
return cached_factory
|
||||
|
||||
def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals):
|
||||
"""Returns the transformed function factory for a given input."""
|
||||
if self._cache.has(fn, cache_subkey):
|
||||
return self._cached_factory(fn, cache_subkey)
|
||||
def transform_function(self, fn, user_context):
|
||||
"""Transforms a function. See GenericTranspiler.trasnform_function.
|
||||
|
||||
with self._cache_lock:
|
||||
# Check again under lock.
|
||||
if self._cache.has(fn, cache_subkey):
|
||||
return self._cached_factory(fn, cache_subkey)
|
||||
|
||||
logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
|
||||
nodes, ctx = self._transform_function(fn, user_context)
|
||||
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))
|
||||
|
||||
factory = _TransformedFnFactory(
|
||||
ctx.info.name, fn.__code__.co_freevars, extra_locals)
|
||||
factory.create(nodes, ctx.namer, future_features=ctx.info.future_features)
|
||||
self._cache[fn][cache_subkey] = factory
|
||||
return factory
|
||||
|
||||
def transform_function(self, fn, caching_subkey, user_context, extra_locals):
|
||||
"""Transforms a function.
|
||||
|
||||
The `caching_subkey` argument allows mapping each function to multiple
|
||||
outputs in the cache. This is useful for instance when transformers
|
||||
can generate multiple variants of output code, typically as a result of
|
||||
different transformation flags.
|
||||
This overload wraps the parent's `transform_function`, adding caching and
|
||||
facilities to instantiate the output as a Python object. It also
|
||||
adds facilities to make new symbols available to the generated Python code,
|
||||
visible as local variables - see `get_extra_locals`.
|
||||
|
||||
Args:
|
||||
fn: A function or lambda.
|
||||
caching_subkey: Used for caching. Calls made for functions with the same
|
||||
code object and caching_subkey will return a cached instance on
|
||||
subsequent invocations. Using a constant will create unique per-function
|
||||
entries.
|
||||
user_context: An opaque object (may be none) that is forwarded to
|
||||
transform_ast.
|
||||
extra_locals: A Dict[Text, Any] containing additional variables to make
|
||||
available to the transformed code. These will be visible as local
|
||||
variables.
|
||||
user_context: An opaque object (may be None) that is forwarded to
|
||||
transform_ast, through the ctx.user_context argument.
|
||||
Returns:
|
||||
A tuple:
|
||||
* A function or lambda with the same signature and closure as `fn`
|
||||
* The temporary module into which the transformed function was loaded
|
||||
* The source map as a
|
||||
Dict[origin_info.LineLocation, origin_info.OriginInfo]
|
||||
|
||||
"""
|
||||
factory = self._transformed_factory(fn, caching_subkey, user_context,
|
||||
extra_locals)
|
||||
cache_subkey = self.get_caching_key(user_context)
|
||||
|
||||
if self._cache.has(fn, cache_subkey):
|
||||
# Fast path: use a lock-free check.
|
||||
factory = self._cached_factory(fn, cache_subkey)
|
||||
|
||||
else:
|
||||
with self._cache_lock:
|
||||
# Check again under lock.
|
||||
if self._cache.has(fn, cache_subkey):
|
||||
factory = self._cached_factory(fn, cache_subkey)
|
||||
|
||||
else:
|
||||
logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
|
||||
# TODO(mdan): Confusing overloading pattern. Fix.
|
||||
nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)
|
||||
|
||||
if logging.has_verbosity(2):
|
||||
logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes))
|
||||
|
||||
factory = _PythonFnFactory(
|
||||
ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals())
|
||||
factory.create(
|
||||
nodes, ctx.namer, future_features=ctx.info.future_features)
|
||||
self._cache[fn][cache_subkey] = factory
|
||||
|
||||
transformed_fn = factory.instantiate(
|
||||
globals_=fn.__globals__,
|
||||
|
@ -35,7 +35,14 @@ class FlipSignTransformer(transformer.Base):
|
||||
return self.generic_visit(node)
|
||||
|
||||
|
||||
class TestTranspiler(transpiler.FunctionTranspiler):
|
||||
class TestTranspiler(transpiler.PyToPy):
|
||||
|
||||
def get_caching_key(self, ctx):
|
||||
del ctx
|
||||
return 0
|
||||
|
||||
def get_extra_locals(self):
|
||||
return {}
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
return FlipSignTransformer(ctx).visit(node)
|
||||
@ -45,14 +52,14 @@ global_var_for_test_global = 1
|
||||
global_var_for_test_namespace_collisions = object()
|
||||
|
||||
|
||||
class FunctionTranspilerTest(test.TestCase):
|
||||
class PyToPyTest(test.TestCase):
|
||||
|
||||
def test_basic(self):
|
||||
def f(a):
|
||||
return a + 1
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 0)
|
||||
|
||||
@ -63,7 +70,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
return a + b
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 0)
|
||||
b = 2
|
||||
@ -74,7 +81,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
return a + global_var_for_test_global
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
global global_var_for_test_global
|
||||
global_var_for_test_global = 1
|
||||
@ -90,7 +97,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
return a + b + d
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 1 - 2 - 2)
|
||||
c = 0
|
||||
@ -107,7 +114,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
return g(a) + 1
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 1 - 1 + 1) # Only f is converted.
|
||||
|
||||
@ -116,7 +123,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
f = lambda x: (b + (x if x > 0 else -x))
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 2 - 1)
|
||||
self.assertEqual(f(-1), 2 - 1)
|
||||
@ -132,7 +139,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
f, _ = (lambda x: a + x, lambda y: b * y)
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 1 - 1)
|
||||
|
||||
@ -147,7 +154,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
return g(x)
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 2 - 1)
|
||||
|
||||
@ -159,7 +166,7 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
return g(x)
|
||||
|
||||
tr = TestTranspiler()
|
||||
f, _, _ = tr.transform_function(f, object(), None, {})
|
||||
f, _, _ = tr.transform(f, None)
|
||||
|
||||
self.assertEqual(f(1), 2 - 1)
|
||||
|
||||
@ -171,9 +178,11 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
outputs = []
|
||||
|
||||
tr = TestTranspiler()
|
||||
cache_key = object()
|
||||
# Note: this is not a test, it's a required invariant.
|
||||
assert tr.get_caching_key(None) == tr.get_caching_key(None)
|
||||
|
||||
def conversion_thread():
|
||||
_, mod, _ = tr.transform_function(f, cache_key, None, {})
|
||||
_, mod, _ = tr.transform(f, None)
|
||||
outputs.append(mod.__name__)
|
||||
|
||||
threads = tuple(
|
||||
@ -192,21 +201,28 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
def test_fn():
|
||||
return 1 + 1
|
||||
|
||||
class ReentrantTranspiler(transpiler.FunctionTranspiler):
|
||||
class ReentrantTranspiler(transpiler.PyToPy):
|
||||
|
||||
def __init__(self):
|
||||
super(ReentrantTranspiler, self).__init__()
|
||||
self._recursion_depth = 0
|
||||
|
||||
def get_caching_key(self, ctx):
|
||||
del ctx
|
||||
return 0
|
||||
|
||||
def get_extra_locals(self):
|
||||
return {}
|
||||
|
||||
def transform_ast(self, node, ctx):
|
||||
self._recursion_depth += 1
|
||||
if self._recursion_depth < 2:
|
||||
self.transform_function(test_fn, object(), None, {})
|
||||
self.transform(test_fn, None)
|
||||
return FlipSignTransformer(ctx).visit(node)
|
||||
|
||||
tr = ReentrantTranspiler()
|
||||
|
||||
f, _, _ = tr.transform_function(test_fn, object(), None, {})
|
||||
f, _, _ = tr.transform(test_fn, None)
|
||||
self.assertEqual(f(), 0)
|
||||
|
||||
def test_namespace_collisions_avoided(self):
|
||||
@ -219,8 +235,8 @@ class FunctionTranspilerTest(test.TestCase):
|
||||
tr = TestTranspiler()
|
||||
obj = TestClass()
|
||||
|
||||
f, _, _ = tr.transform_function(
|
||||
obj.global_var_for_test_namespace_collisions, object(), None, {})
|
||||
f, _, _ = tr.transform(
|
||||
obj.global_var_for_test_namespace_collisions, None)
|
||||
self.assertIs(f(obj), global_var_for_test_namespace_collisions)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user