diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index fd0dc0ebc2b..fc6908784f9 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -63,7 +63,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections import enum from tensorflow.python.autograph.pyct import anno @@ -234,18 +233,17 @@ STANDARD_OPTIONS = ConversionOptions( optional_features=None) -class ProgramContext( - collections.namedtuple('ProgramContext', ('options', 'autograph_module'))): +class ProgramContext(object): """ProgramContext keeps track of converting function hierarchies. - This object is mutable, and is updated during conversion. Not thread safe. - Attributes: options: ConversionOptions - autograph_module: Module, a reference to the autograph module. This needs to - be specified by the caller to avoid circular dependencies. + autograph_module: Deprecated. Do not use. """ - pass + + def __init__(self, options, autograph_module=None): + self.options = options + self.autograph_module = autograph_module class Base(transformer.Base): diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 22e06000906..4301cf898bf 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -28,7 +28,6 @@ import six from tensorflow.python.autograph.core import config from tensorflow.python.autograph.core import converter from tensorflow.python.autograph.impl import api -from tensorflow.python.autograph.impl import conversion from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -64,16 +63,26 @@ def is_inside_generated_code(): del frame -class TestingTranspiler(conversion.AutoGraphTranspiler): +class TestingTranspiler(api.PyToTF): """Testing version that only applies given transformations.""" - def __init__(self, converters): + def __init__(self, converters, ag_overrides): super(TestingTranspiler, self).__init__() if isinstance(converters, (list, tuple)): self._converters = converters else: self._converters = (converters,) self.transformed_ast = None + self._ag_overrides = ag_overrides + + def get_extra_locals(self): + retval = super(TestingTranspiler, self).get_extra_locals() + if self._ag_overrides: + modified_ag = imp.new_module('fake_autograph') + modified_ag.__dict__.update(retval['ag__'].__dict__) + modified_ag.__dict__.update(self._ag_overrides) + retval['ag__'] = modified_ag + return retval def transform_ast(self, node, ctx): node = self.initial_analysis(node, ctx) @@ -113,18 +122,8 @@ class TestCase(test.TestCase): options=converter.ConversionOptions(recursive=True), autograph_module=api) - conversion.create_custom_vars(program_ctx) - custom_vars = dict(conversion.custom_vars) - - if ag_overrides: - modified_ag = imp.new_module('fake_autograph') - modified_ag.__dict__.update(custom_vars['ag__'].__dict__) - modified_ag.__dict__.update(ag_overrides) - custom_vars['ag__'] = modified_ag - - tr = TestingTranspiler(converter_module) - transformed, _, _ = tr.transform_function( - f, program_ctx.options, program_ctx, custom_vars) + tr = TestingTranspiler(converter_module, ag_overrides) + transformed, _, _ = tr.transform_function(f, program_ctx) if include_ast: return transformed, tr.transformed_ast, tr.transform_ctx diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index 98e19fdde86..a5e1ab1705f 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This module contains the user-facing API for AutoGraph.""" +"""This module contains the user- and codegen-facing API for AutoGraph.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools +import imp import inspect import os import sys @@ -27,14 +28,38 @@ import traceback import six +from tensorflow.python.autograph import operators +from tensorflow.python.autograph import utils +from tensorflow.python.autograph.converters import asserts +from tensorflow.python.autograph.converters import break_statements +from tensorflow.python.autograph.converters import call_trees +from tensorflow.python.autograph.converters import conditional_expressions +from tensorflow.python.autograph.converters import continue_statements +from tensorflow.python.autograph.converters import control_flow +from tensorflow.python.autograph.converters import directives +from tensorflow.python.autograph.converters import functions +from tensorflow.python.autograph.converters import lists +from tensorflow.python.autograph.converters import logical_expressions +from tensorflow.python.autograph.converters import return_statements +from tensorflow.python.autograph.converters import slices +from tensorflow.python.autograph.converters import variables from tensorflow.python.autograph.core import ag_ctx from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.core import function_wrappers +from tensorflow.python.autograph.core import unsupported_features_checker from tensorflow.python.autograph.impl import conversion +from tensorflow.python.autograph.lang import special_functions from tensorflow.python.autograph.operators import py_builtins +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import cfg from tensorflow.python.autograph.pyct import error_utils from tensorflow.python.autograph.pyct import errors from tensorflow.python.autograph.pyct import inspect_utils from tensorflow.python.autograph.pyct import origin_info +from tensorflow.python.autograph.pyct import qual_names +from tensorflow.python.autograph.pyct import transpiler +from tensorflow.python.autograph.pyct.static_analysis import activity +from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions from tensorflow.python.autograph.utils import ag_logging as logging from tensorflow.python.eager import function from tensorflow.python.framework import errors_impl @@ -48,6 +73,11 @@ def is_autograph_strict_conversion_mode(): return int(os.environ.get('AUTOGRAPH_STRICT_CONVERSION', '0')) > 0 +# +# Error handling +# + + # TODO(mdan): Export this symbol. class AutoGraphError(errors.PyCTError): """Base class for all AutoGraph exceptions.""" @@ -113,6 +143,26 @@ class _ErrorMetadata(error_utils.ErrorMetadataBase): return StagingError(self.get_message()) +def _attach_error_metadata(e, f): + """Augments an error with the metadata necessary for rewrite.""" + if hasattr(e, 'ag_pass_through'): + return + + metadata = getattr(e, 'ag_error_metadata', None) + source_map = f.ag_source_map + + if metadata is None: + logging.log(1, 'Caught error in user callable %s', f, exc_info=True) + message = '{}: {}'.format(e.__class__.__name__, e) + else: + message = None + + cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] + + e.ag_error_metadata = _ErrorMetadata( + cause_tb, metadata, message, source_map, __file__) + + class StackTraceMapper(tf_stack.StackTraceMapper): """Remaps generated code to code it originated from.""" @@ -145,6 +195,106 @@ class StackTraceMapper(tf_stack.StackTraceMapper): return effective_source_map +# +# Actual source code transformation +# + + +class PyToTF(transpiler.PyToPy): + """The TensorFlow AutoGraph transformer.""" + + def __init__(self): + super(PyToTF, self).__init__() + + # TODO(mdan): Move into core or replace with an actual importable module. + # Craft a module that exposes the external API as well as certain + # internal modules. + ag_internal = imp.new_module('autograph') + ag_internal.__dict__.update(inspect.getmodule(PyToTF).__dict__) + ag_internal.ConversionOptions = converter.ConversionOptions + ag_internal.STD = converter.STANDARD_OPTIONS + ag_internal.Feature = converter.Feature + ag_internal.utils = utils + ag_internal.FunctionScope = function_wrappers.FunctionScope + ag_internal.with_function_scope = function_wrappers.with_function_scope + # TODO(mdan): Add safeguards against name clashes. + # We don't want to create a submodule because we want the operators to be + # accessible as ag__. + ag_internal.__dict__.update(special_functions.__dict__) + ag_internal.__dict__.update(operators.__dict__) + + self._extra_locals = {'ag__': ag_internal} + + def get_transformed_name(self, node): + return 'tf__' + super(PyToTF, self).get_transformed_name(node) + + def get_extra_locals(self): + return self._extra_locals + + def get_caching_key(self, ctx): + return ctx.options + + def initial_analysis(self, node, ctx): + graphs = cfg.build(node) + node = qual_names.resolve(node) + node = activity.resolve(node, ctx, None) + node = reaching_definitions.resolve(node, ctx, graphs) + anno.dup( + node, + { + anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, + }, + ) + return node + + def transform_ast(self, node, ctx): + unsupported_features_checker.verify(node) + node = self.initial_analysis(node, ctx) + + node = functions.transform(node, ctx) + node = directives.transform(node, ctx) + node = break_statements.transform(node, ctx) + if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS): + node = asserts.transform(node, ctx) + # Note: sequencing continue canonicalization before for loop one avoids + # dealing with the extra loop increment operation that the for + # canonicalization creates. + node = continue_statements.transform(node, ctx) + node = return_statements.transform(node, ctx) + if ctx.user.options.uses(converter.Feature.LISTS): + node = lists.transform(node, ctx) + node = slices.transform(node, ctx) + node = call_trees.transform(node, ctx) + node = control_flow.transform(node, ctx) + node = conditional_expressions.transform(node, ctx) + node = logical_expressions.transform(node, ctx) + node = variables.transform(node, ctx) + return node + + +def _convert_actual(entity, program_ctx): + """Applies AutoGraph to entity.""" + + # TODO(mdan): Put these extra fields inside __autograph_info__. + if not hasattr(entity, '__code__'): + raise ValueError('Cannot apply autograph to a function that doesn\'t ' + 'expose a __code__ object. If this is a @tf.function,' + ' try passing f.python_function instead.') + + transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx) + + assert not hasattr(transformed, 'ag_module') + assert not hasattr(transformed, 'ag_source_map') + transformed.ag_module = module + transformed.ag_source_map = source_map + return transformed + + +# +# Generated code support +# + + def autograph_artifact(entity, extras=None): setattr(entity, 'autograph_info__', extras) return entity @@ -154,272 +304,12 @@ def is_autograph_artifact(entity): return hasattr(entity, 'autograph_info__') -def tf_convert(f, ctx, convert_by_default=True, user_requested=False): - """Decorator that applies AutoGraph to a function. - - Use in internal APIs. - - This API is suitable for high order functions internal to the TensorFlow API, - and more generally any function to which Autograph is not applied. - - Guidance: convert was a decorator meant for use directly by developers, and - will be soon deprecated in favor of tf.function. tf_convert is to be called - from high order functions internal to TF. - - Args: - f: Callable. - ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used. - convert_by_default: bool, whether to use AutoGraph when the context doesn't - specify. - user_requested: bool, whether to ignore the conversion whitelist. See - ConversionOptions.user_requested. - - Returns: - Either `f or the converted version of `f`. - """ - - if is_autograph_artifact(f): - return f - f_wrapper = f - decorators, f = tf_decorator.unwrap(f) - - # TODO(mdan): Grab features from context. - # Note: we pass the original context through to convert to properly handle the - # following scenario, which can be used inside TF implementations: - # - # ctx = ag_ctx.control_status_ctx() - # @function(autograph=False) # Low-level graph code - # def inner_fn(): - # # The context is disabled here, but should be enabled in user user_fn - # tf_convert(user_fn, ctx=ctx) - if ctx.status == ag_ctx.Status.ENABLED: - wrapper_factory = convert( - recursive=True, user_requested=user_requested, conversion_ctx=ctx) - elif ctx.status == ag_ctx.Status.DISABLED: - wrapper_factory = do_not_convert - elif ctx.status == ag_ctx.Status.UNSPECIFIED: - if convert_by_default: - wrapper_factory = convert( - recursive=True, user_requested=user_requested, conversion_ctx=ctx) - else: - wrapper_factory = call_with_unspecified_conversion_status - else: - assert False, 'This switch contains all possible cases!' - wrapper = wrapper_factory(f) - - if decorators: - wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper) - - return autograph_artifact(wrapper) - - -# TODO(mdan): Make private. -def convert(recursive=False, - optional_features=None, - user_requested=True, - conversion_ctx=ag_ctx.NullCtx()): - """Decorator that compiles a function to use TensorFlow ops. - - The decorator is dynamic - it recompiles the target whenever the decorated - function is called. This means the parameter values are known at conversion. - It also means that repeated calls with different types of parameters will be - correctly processed. - - Args: - recursive: bool, whether to recursively convert any functions or classes - that the converted function may use. - optional_features: converted.Feature, allows toggling optional or - experimental features. When set to None, only the core features are - enabled. - user_requested: bool, whether this is a function that the user explicitly - asked to be converted. See ConversionOptions.user_requested. - conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in - which `f` is used. - - Returns: - Callable, a decorator that converts the given function into an equivalent - function that uses TensorFlow ops. - """ - - def decorator(f): - """Decorator implementation.""" - - def wrapper(*args, **kwargs): - """Wrapper that calls the converted version of f.""" - options = converter.ConversionOptions( - recursive=recursive, - user_requested=user_requested, - optional_features=optional_features) - try: - with conversion_ctx: - return converted_call(f, args, kwargs, options=options) - except Exception as e: # pylint:disable=broad-except - if hasattr(e, 'ag_error_metadata'): - raise e.ag_error_metadata.to_exception(e) - else: - raise - - if inspect.isfunction(f) or inspect.ismethod(f): - wrapper = functools.update_wrapper(wrapper, f) - - decorated_wrapper = tf_decorator.make_decorator(f, wrapper) - return autograph_artifact(decorated_wrapper) - - return decorator - - -def call_with_unspecified_conversion_status(func): - """Decorator that resets the conversion context to the unspecified status.""" - def wrapper(*args, **kwargs): - with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED): - return func(*args, **kwargs) - - if inspect.isfunction(func) or inspect.ismethod(func): - wrapper = functools.update_wrapper(wrapper, func) - - return autograph_artifact(wrapper) - - -@tf_export('autograph.experimental.do_not_convert') -def do_not_convert(func=None): - """Decorator that suppresses the conversion of a function. - - Args: - func: function to decorate. - - Returns: - If `func` is not None, returns a `Callable` which is equivalent to - `func`, but is not converted by AutoGraph. - If `func` is None, returns a decorator that, when invoked with a - single `func` argument, returns a `Callable` equivalent to the - above case. - """ - if func is None: - return do_not_convert - - def wrapper(*args, **kwargs): - with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED): - return func(*args, **kwargs) - - if inspect.isfunction(func) or inspect.ismethod(func): - wrapper = functools.update_wrapper(wrapper, func) - - return autograph_artifact(wrapper) - - -def _attach_metadata(e, f): - """Augments an error with the metadata necessary for rewrite.""" - if hasattr(e, 'ag_pass_through'): - return - - metadata = getattr(e, 'ag_error_metadata', None) - source_map = f.ag_source_map - - if metadata is None: - logging.log(1, 'Caught error in user callable %s', f, exc_info=True) - message = '{}: {}'.format(e.__class__.__name__, e) - else: - message = None - - cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:] - - e.ag_error_metadata = _ErrorMetadata( - cause_tb, metadata, message, source_map, __file__) - - -def _call_unconverted(f, args, kwargs, options, update_cache=True): - """Calls the original function without converting with AutoGraph.""" - if update_cache: - conversion.cache_whitelisted(f, options) - - if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget): - return f.__self__.call(args, kwargs) - - if kwargs is not None: - return f(*args, **kwargs) - return f(*args) - - -def _is_of_known_loaded_module(f, module_name): - mod = sys.modules.get(module_name, None) - if mod is None: - return False - if any(v is not None for v in mod.__dict__.values() if f is v): - return True - return False - - -def _is_known_loaded_type(f, module_name, entity_name): - """Tests whether the function or method is an instance of a known type.""" - if (module_name not in sys.modules or - not hasattr(sys.modules[module_name], entity_name)): - return False - type_entity = getattr(sys.modules[module_name], entity_name) - if isinstance(f, type_entity): - # The method if of this type. Example: - # - # o = ClassType() - # function(o.method)() - return True - # Note: inspect is required here, to avoid unpacking tf.function decorators. - if inspect.ismethod(f): - # The the unbound method if of this type. Example: - # - # class ClassType: - # @function - # def method(self): - # ... - # o = ClassType() - # o.method() - if isinstance(f.__func__, type_entity): - return True - return False - - -def _fall_back_unconverted(f, args, kwargs, options, exc): - """Falls back to calling the function unconverted, in case of error.""" - # TODO(mdan): Consider adding an internal metric. - warning_template = ( - 'AutoGraph could not transform %s and will run it as-is.\n' - '%s' - 'Cause: %s\n' - 'To silence this warning, decorate the function with' - ' @tf.autograph.experimental.do_not_convert') - if isinstance(exc, errors.UnsupportedLanguageElementError): - if not conversion.is_in_whitelist_cache(f, options): - logging.warn(warning_template, f, '', exc) - else: - file_bug_message = ( - 'Please report this to the TensorFlow team. When filing the bug, set' - ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' - ' attach the full output.\n') - logging.warn(warning_template, f, file_bug_message, exc) - - return _call_unconverted(f, args, kwargs, options) - - -def _log_callargs(f, args, kwargs): - """Logging helper.""" - logging.log(2, 'Defaults of %s : %s', f, f.__defaults__) - if not six.PY2: - logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__) - - if kwargs is not None: - callargs = tf_inspect.getcallargs(f, *args, **kwargs) - else: - callargs = tf_inspect.getcallargs(f, *args) - - formatted_callargs = '\n'.join( - ' {}: {}'.format(k, v) for k, v in callargs.items()) - logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs) - - def converted_call(f, args, kwargs, caller_fn_scope=None, options=None): - """Compiles a function call inline. + """Converts a function call inline. For internal use only. @@ -492,40 +382,7 @@ def converted_call(f, else: return py_builtins.overload_of(f)(*args) - # TODO(b/122265385): Remove this bypass. - if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or - _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')): - logging.warn( - '{} appears to be decorated by wrapt, which is not yet supported' - ' by AutoGraph. The function will run as-is.' - ' You may still apply AutoGraph before the wrapt decorator.'.format(f)) - logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f) - return _call_unconverted(f, args, kwargs, options) - - if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'): - logging.log(2, 'Permanently whitelisted: %s: lru_cache', f) - return _call_unconverted(f, args, kwargs, options) - - # Constructors are permanently whitelisted. - # TODO(mdan): Toggle as experimental feature instead. - # TODO(b/124016764): Remove this limitation. - if inspect_utils.isconstructor(f): - logging.log(2, 'Permanently whitelisted: %s: constructor', f) - return _call_unconverted(f, args, kwargs, options) - - # Other built-in modules are permanently whitelisted. - # TODO(mdan): Figure out how to do this consistently for all stdlib modules. - if any( - _is_of_known_loaded_module(f, m) - for m in ('collections', 'pdb', 'copy', 'inspect', 're')): - logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f) - return _call_unconverted(f, args, kwargs, options) - - # Custom ops and kernels are also permanently whitelisted. - # See tensorflow.framework.load_library. - if (hasattr(f, '__module__') and - hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')): - logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f) + if conversion.is_unsupported(f): return _call_unconverted(f, args, kwargs, options) if not options.user_requested and conversion.is_whitelisted(f): @@ -579,9 +436,8 @@ def converted_call(f, return _call_unconverted(f, args, kwargs, options) try: - program_ctx = converter.ProgramContext( - options=options, autograph_module=tf_inspect.getmodule(converted_call)) - converted_f = conversion.convert(target_entity, program_ctx) + program_ctx = converter.ProgramContext(options=options) + converted_f = _convert_actual(target_entity, program_ctx) if logging.has_verbosity(2): _log_callargs(converted_f, effective_args, kwargs) except Exception as e: # pylint:disable=broad-except @@ -597,12 +453,226 @@ def converted_call(f, else: result = converted_f(*effective_args) except Exception as e: - _attach_metadata(e, converted_f) + _attach_error_metadata(e, converted_f) raise return result +def _call_unconverted(f, args, kwargs, options, update_cache=True): + """Calls the original function without converting with AutoGraph.""" + if update_cache: + conversion.cache_whitelisted(f, options) + + if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget): + return f.__self__.call(args, kwargs) + + if kwargs is not None: + return f(*args, **kwargs) + return f(*args) + + +def _fall_back_unconverted(f, args, kwargs, options, exc): + """Falls back to calling the function unconverted, in case of error.""" + # TODO(mdan): Consider adding an internal metric. + warning_template = ( + 'AutoGraph could not transform %s and will run it as-is.\n' + '%s' + 'Cause: %s\n' + 'To silence this warning, decorate the function with' + ' @tf.autograph.experimental.do_not_convert') + if isinstance(exc, errors.UnsupportedLanguageElementError): + if not conversion.is_in_whitelist_cache(f, options): + logging.warn(warning_template, f, '', exc) + else: + file_bug_message = ( + 'Please report this to the TensorFlow team. When filing the bug, set' + ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and' + ' attach the full output.\n') + logging.warn(warning_template, f, file_bug_message, exc) + + return _call_unconverted(f, args, kwargs, options) + + +# +# TensorFlow integration +# + + +def tf_convert(f, ctx, convert_by_default=True, user_requested=False): + """Decorator that applies AutoGraph to a function. + + Use in internal APIs. + + This API is suitable for high order functions internal to the TensorFlow API, + and more generally any function to which Autograph is not applied. + + Guidance: convert was a decorator meant for use directly by developers, and + will be soon deprecated in favor of tf.function. tf_convert is to be called + from high order functions internal to TF. + + Args: + f: Callable. + ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used. + convert_by_default: bool, whether to use AutoGraph when the context doesn't + specify. + user_requested: bool, whether to ignore the conversion whitelist. See + ConversionOptions.user_requested. + + Returns: + Either `f or the converted version of `f`. + """ + + if is_autograph_artifact(f): + return f + f_wrapper = f + decorators, f = tf_decorator.unwrap(f) + + # TODO(mdan): Grab features from context. + # Note: we pass the original context through to convert to properly handle the + # following scenario, which can be used inside TF implementations: + # + # ctx = ag_ctx.control_status_ctx() + # @function(autograph=False) # Low-level graph code + # def inner_fn(): + # # The context is disabled here, but should be enabled in user user_fn + # tf_convert(user_fn, ctx=ctx) + if ctx.status == ag_ctx.Status.ENABLED: + wrapper_factory = convert( + recursive=True, user_requested=user_requested, conversion_ctx=ctx) + elif ctx.status == ag_ctx.Status.DISABLED: + wrapper_factory = do_not_convert + elif ctx.status == ag_ctx.Status.UNSPECIFIED: + if convert_by_default: + wrapper_factory = convert( + recursive=True, user_requested=user_requested, conversion_ctx=ctx) + else: + wrapper_factory = call_with_unspecified_conversion_status + else: + assert False, 'This switch contains all possible cases!' + wrapper = wrapper_factory(f) + + if decorators: + wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper) + + return autograph_artifact(wrapper) + + +def call_with_unspecified_conversion_status(func): + """Decorator that resets the conversion context to the unspecified status.""" + def wrapper(*args, **kwargs): + with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED): + return func(*args, **kwargs) + + if inspect.isfunction(func) or inspect.ismethod(func): + wrapper = functools.update_wrapper(wrapper, func) + + return autograph_artifact(wrapper) + + +def _log_callargs(f, args, kwargs): + """Logging helper.""" + logging.log(2, 'Defaults of %s : %s', f, f.__defaults__) + if not six.PY2: + logging.log(2, 'KW defaults of %s : %s', f, f.__kwdefaults__) + + if kwargs is not None: + callargs = tf_inspect.getcallargs(f, *args, **kwargs) + else: + callargs = tf_inspect.getcallargs(f, *args) + + formatted_callargs = '\n'.join( + ' {}: {}'.format(k, v) for k, v in callargs.items()) + logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs) + + +# +# Public API +# + + +@tf_export('autograph.experimental.do_not_convert') +def do_not_convert(func=None): + """Decorator that suppresses the conversion of a function. + + Args: + func: function to decorate. + + Returns: + If `func` is not None, returns a `Callable` which is equivalent to + `func`, but is not converted by AutoGraph. + If `func` is None, returns a decorator that, when invoked with a + single `func` argument, returns a `Callable` equivalent to the + above case. + """ + if func is None: + return do_not_convert + + def wrapper(*args, **kwargs): + with ag_ctx.ControlStatusCtx(status=ag_ctx.Status.DISABLED): + return func(*args, **kwargs) + + if inspect.isfunction(func) or inspect.ismethod(func): + wrapper = functools.update_wrapper(wrapper, func) + + return autograph_artifact(wrapper) + + +# TODO(mdan): Make private. +def convert(recursive=False, + optional_features=None, + user_requested=True, + conversion_ctx=ag_ctx.NullCtx()): + """Decorator that compiles a function to use TensorFlow ops. + + The decorator is dynamic - it recompiles the target whenever the decorated + function is called. This means the parameter values are known at conversion. + It also means that repeated calls with different types of parameters will be + correctly processed. + + Args: + recursive: bool, whether to recursively convert any functions or classes + that the converted function may use. + optional_features: converted.Feature, allows toggling optional or + experimental features. When set to None, only the core features are + enabled. + user_requested: bool, whether this is a function that the user explicitly + asked to be converted. See ConversionOptions.user_requested. + conversion_ctx: Optional ag_ctx.ControlStatusCtx, the Autograph context in + which `f` is used. + + Returns: + Callable, a decorator that converts the given function into an equivalent + function that uses TensorFlow ops. + """ + + def decorator(f): + """Decorator implementation.""" + + def wrapper(*args, **kwargs): + """Wrapper that calls the converted version of f.""" + options = converter.ConversionOptions( + recursive=recursive, + user_requested=user_requested, + optional_features=optional_features) + try: + with conversion_ctx: + return converted_call(f, args, kwargs, options=options) + except Exception as e: # pylint:disable=broad-except + if hasattr(e, 'ag_error_metadata'): + raise e.ag_error_metadata.to_exception(e) + else: + raise + + if inspect.isfunction(f) or inspect.ismethod(f): + wrapper = functools.update_wrapper(wrapper, f) + + decorated_wrapper = tf_decorator.make_decorator(f, wrapper) + return autograph_artifact(decorated_wrapper) + + return decorator + + # pylint:disable=line-too-long @tf_export('autograph.to_graph', v1=[]) def to_graph(entity, recursive=True, experimental_optional_features=None): @@ -668,9 +738,8 @@ def to_graph(entity, recursive=True, experimental_optional_features=None): options=converter.ConversionOptions( recursive=recursive, user_requested=True, - optional_features=experimental_optional_features), - autograph_module=tf_inspect.getmodule(to_graph)) - return autograph_artifact(conversion.convert(entity, program_ctx)) + optional_features=experimental_optional_features)) + return autograph_artifact(_convert_actual(entity, program_ctx)) except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e: logging.error(1, 'Error converting %s', entity, exc_info=True) raise ConversionError('converting {}: {}: {}'.format( @@ -845,3 +914,6 @@ def to_code(entity, recursive=True, experimental_optional_features=None): recursive=recursive, experimental_optional_features=experimental_optional_features)) return textwrap.dedent(source) + + +_TRANSPILER = PyToTF() diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 4d5ddeebcc1..3c1f7e97bde 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -19,111 +19,97 @@ from __future__ import division from __future__ import print_function import functools -import imp +import inspect +import sys import unittest -from tensorflow.python.autograph import operators -from tensorflow.python.autograph import utils -from tensorflow.python.autograph.converters import asserts -from tensorflow.python.autograph.converters import break_statements -from tensorflow.python.autograph.converters import call_trees -from tensorflow.python.autograph.converters import conditional_expressions -from tensorflow.python.autograph.converters import continue_statements -from tensorflow.python.autograph.converters import control_flow -from tensorflow.python.autograph.converters import directives -from tensorflow.python.autograph.converters import functions -from tensorflow.python.autograph.converters import lists -from tensorflow.python.autograph.converters import logical_expressions -from tensorflow.python.autograph.converters import return_statements -from tensorflow.python.autograph.converters import slices -from tensorflow.python.autograph.converters import variables from tensorflow.python.autograph.core import config -from tensorflow.python.autograph.core import converter -from tensorflow.python.autograph.core import function_wrappers -from tensorflow.python.autograph.core import unsupported_features_checker -from tensorflow.python.autograph.lang import special_functions -from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import cache -from tensorflow.python.autograph.pyct import cfg from tensorflow.python.autograph.pyct import inspect_utils -from tensorflow.python.autograph.pyct import qual_names -from tensorflow.python.autograph.pyct import transpiler -from tensorflow.python.autograph.pyct.static_analysis import activity -from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions from tensorflow.python.autograph.utils import ag_logging as logging from tensorflow.python.eager import function from tensorflow.python.util import tf_inspect -class AutoGraphTranspiler(transpiler.FunctionTranspiler): - - def get_transformed_name(self, node): - return 'tf__' + super(AutoGraphTranspiler, self).get_transformed_name(node) - - def initial_analysis(self, node, ctx): - graphs = cfg.build(node) - node = qual_names.resolve(node) - node = activity.resolve(node, ctx, None) - node = reaching_definitions.resolve(node, ctx, graphs) - anno.dup( - node, - { - anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, - }, - ) - return node - - def transform_ast(self, node, ctx): - unsupported_features_checker.verify(node) - node = self.initial_analysis(node, ctx) - - node = functions.transform(node, ctx) - node = directives.transform(node, ctx) - node = break_statements.transform(node, ctx) - if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS): - node = asserts.transform(node, ctx) - # Note: sequencing continue canonicalization before for loop one avoids - # dealing with the extra loop increment operation that the for - # canonicalization creates. - node = continue_statements.transform(node, ctx) - node = return_statements.transform(node, ctx) - if ctx.user.options.uses(converter.Feature.LISTS): - node = lists.transform(node, ctx) - node = slices.transform(node, ctx) - node = call_trees.transform(node, ctx) - node = control_flow.transform(node, ctx) - node = conditional_expressions.transform(node, ctx) - node = logical_expressions.transform(node, ctx) - node = variables.transform(node, ctx) - return node - - -_TRANSPILER = AutoGraphTranspiler() _WHITELIST_CACHE = cache.UnboundInstanceCache() -custom_vars = None +def _is_of_known_loaded_module(f, module_name): + mod = sys.modules.get(module_name, None) + if mod is None: + return False + if any(v is not None for v in mod.__dict__.values() if f is v): + return True + return False -# TODO(mdan): Superfluous function, remove. -# TODO(mdan): Put these extra fields inside __autograph_info__. -def convert(entity, program_ctx): - """Applies AutoGraph to entity.""" +def _is_known_loaded_type(f, module_name, entity_name): + """Tests whether the function or method is an instance of a known type.""" + if (module_name not in sys.modules or + not hasattr(sys.modules[module_name], entity_name)): + return False + type_entity = getattr(sys.modules[module_name], entity_name) + if isinstance(f, type_entity): + # The method if of this type. Example: + # + # o = ClassType() + # function(o.method)() + return True + # Note: inspect is required here, to avoid unpacking tf.function decorators. + if inspect.ismethod(f): + # The the unbound method if of this type. Example: + # + # class ClassType: + # @function + # def method(self): + # ... + # o = ClassType() + # o.method() + if isinstance(f.__func__, type_entity): + return True + return False - if not hasattr(entity, '__code__'): - raise ValueError('Cannot apply autograph to a function that doesn\'t ' - 'expose a __code__ object. If this is a @tf.function,' - ' try passing f.python_function instead.') - create_custom_vars(program_ctx) - transformed, module, source_map = _TRANSPILER.transform_function( - entity, program_ctx.options, program_ctx, custom_vars) +def is_unsupported(o): + """Checks whether an entity is supported by AutoGraph at all.""" - assert not hasattr(transformed, 'ag_module') - assert not hasattr(transformed, 'ag_source_map') - transformed.ag_module = module - transformed.ag_source_map = source_map - return transformed + # TODO(b/122265385): Remove this bypass. + if (_is_known_loaded_type(o, 'wrapt', 'FunctionWrapper') or + _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper')): + logging.warn( + '{} appears to be decorated by wrapt, which is not yet supported' + ' by AutoGraph. The function will run as-is.' + ' You may still apply AutoGraph before the wrapt decorator.'.format(o)) + logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', o) + return True + + if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'): + logging.log(2, 'Permanently whitelisted: %s: lru_cache', o) + return True + + # Constructors are permanently whitelisted. + # TODO(mdan): Toggle as experimental feature instead. + # TODO(b/124016764): Remove this limitation. + if inspect_utils.isconstructor(o): + logging.log(2, 'Permanently whitelisted: %s: constructor', o) + return True + + # Other built-in modules are permanently whitelisted. + # TODO(mdan): Figure out how to do this consistently for all stdlib modules. + if any( + _is_of_known_loaded_module(o, m) + for m in ('collections', 'pdb', 'copy', 'inspect', 're')): + logging.log(2, 'Permanently whitelisted: %s: part of builtin module', o) + return True + + # Custom ops and kernels are also permanently whitelisted. + # See tensorflow.framework.load_library. + if (hasattr(o, '__module__') and + hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')): + logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', o) + return True + + return False # TODO(mdan): allow_namedtuple_subclass should be hardcoded to True. @@ -246,28 +232,3 @@ def cache_whitelisted(entity, options): except TypeError: # Catch-all for entities that are unhashable or don't allow weakrefs. pass - - -# TODO(mdan): Move into core or replace with an actual importable module. -# Visible for testing. -def create_custom_vars(program_ctx): - """Adds namespace references to the module that exposes the api itself.""" - global custom_vars - if custom_vars is None: - # Craft a module that exposes parts of the external API as well as certain - # internal modules. - ag_internal = imp.new_module('autograph') - ag_internal.__dict__.update(program_ctx.autograph_module.__dict__) - ag_internal.ConversionOptions = converter.ConversionOptions - ag_internal.STD = converter.STANDARD_OPTIONS - ag_internal.Feature = converter.Feature - ag_internal.utils = utils - ag_internal.FunctionScope = function_wrappers.FunctionScope - ag_internal.with_function_scope = function_wrappers.with_function_scope - # TODO(mdan): Add safeguards against name clashes. - # We don't want to create a submodule because we want the operators to be - # accessible as ag__. - ag_internal.__dict__.update(special_functions.__dict__) - ag_internal.__dict__.update(operators.__dict__) - - custom_vars = {'ag__': ag_internal} diff --git a/tensorflow/python/autograph/pyct/transpiler.py b/tensorflow/python/autograph/pyct/transpiler.py index f02968585ee..7e0588383ec 100644 --- a/tensorflow/python/autograph/pyct/transpiler.py +++ b/tensorflow/python/autograph/pyct/transpiler.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import inspect import threading import types @@ -143,11 +144,11 @@ def _wrap_into_factory(nodes, entity_name, inner_factory_name, outer_factory_name=outer_factory_name) -class _TransformedFnFactory(object): - """Helper object that wraps a transformed function factory.""" +class _PythonFnFactory(object): + """Helper object that wraps a Python function factory.""" def __init__(self, name, freevars, extra_locals): - """Creates a new factory for a transformed function. + """Creates a new factory for a Python function. Args: name: The function name. @@ -169,7 +170,7 @@ class _TransformedFnFactory(object): inner_factory_name='inner_factory', outer_factory_name='outer_factory', future_features=()): - """Initializes a transformed function.""" + """Initializes a function.""" if self._unbound_factory is not None: raise ValueError('double initialization; create a new object instead') @@ -191,7 +192,7 @@ class _TransformedFnFactory(object): closure, defaults=None, kwdefaults=None): - """Creates a new instance of the transformed function.""" + """Creates a new function instance.""" if self._unbound_factory is None: raise ValueError('call create first') @@ -213,78 +214,70 @@ class _TransformedFnFactory(object): closure=factory_closure) # The lint override is a false positive. - transformed_entity = bound_factory(**self._extra_locals) # pylint:disable=not-callable + new_fn = bound_factory(**self._extra_locals) # pylint:disable=not-callable if defaults: - transformed_entity.__defaults__ = defaults + new_fn.__defaults__ = defaults if kwdefaults: - transformed_entity.__kwdefaults__ = kwdefaults + new_fn.__kwdefaults__ = kwdefaults - return transformed_entity + return new_fn -class FunctionTranspiler(object): - """A generic source-to-source transpiler for Python functions. +class GenericTranspiler(object): + """A generic transpiler for Python functions. - Its interface `transform_function` API offers a function-in, function-out - interface. Internally, it takes care of parsing, caching and variable binding. + Its interface is the `transform` API, which can process Python function + objects. Internally, it handles parsing. - Users typically subclass this, customizing the transform_ast method. - - Usually, instances of this class are singletons, since each instance manages - its own cache. The caching subkey allows managing multiple types of - transformation. + Users typically subclass this, customizing the `transform_ast` method. The + output of transformed_ast is returned directly by `transform`. Existing + methods like `transform_function` may also be overloaded. Example: - class MyTransformer(FunctionTranspiler): + class MyTransformer(GenericTranspiler): - def transform_ast(self, node, ctx): - node = <> - return node + def transform(self, obj): + result = <> + return result transformer = MyTransfomer() - new_f, module, source_map = transformer.transform_function(f, ...) - # new_f is a function with signature identical to f - - The transformed function has access to the same namespace as the original - function. To allow access to internal APIs, users may inject additional - symbols though the `extra_locals` argument of `transform_function`. + result = transformer.transform(f, ...) + # result is the output """ - def __init__(self): - self._cache_lock = threading.RLock() - self._cache = cache.CodeObjectCache() - - def transform_ast(self, node, user_context): + def transform_ast(self, node, ctx): """Performs an actual transformation of a function's AST. - Subclasses must implement this method. They must not call it. - - The method receives the original AST and generates code according to the - AST that the method returns. For functions, the returned AST is expected to - contain a function with the exact same arguments and closure. The resulting - function will receive the globals, closure and argument defaults of the - input function. + Subclasses must implement this method, and do not usually call it. Args: node: One or more ast.AST nodes representing the AST to be transformed. - user_context: The same value that the caller passed to - `transform_function`. + ctx: transformer.Context. """ raise NotImplementedError('subclasses must override this') - def get_transformed_name(self, node): - """Returns a name for the output function. Subclasses may override this.""" - if isinstance(node, gast.Lambda): - return 'lam' - elif isinstance(node, gast.FunctionDef): - # Note that we need to rename the function, to avoid any namespace - # clashes. - return node.name - else: - raise ValueError('Unknown node type {}'.format(node)) + def transform(self, obj, user_context): + """Transforms a Python object. + + Users typically call this method. + + Args: + obj: A Python object, function, type, etc. + user_context: An opaque object (may be None) that is forwarded to + transform_ast, through the ctx.user_context argument. + Returns: + Tre result of calling transform_function. + + Raises: + NotImplementedError: if the type of obj is not handled. + """ + if inspect.isfunction(obj) or inspect.ismethod(obj): + return self.transform_function(obj, user_context) + + raise NotImplementedError('Non-function: {}'.format(type(obj))) def _erase_arg_defaults(self, node): """Erase argde fault expressions, which would otherwise be unbound.""" @@ -296,8 +289,21 @@ class FunctionTranspiler(object): args.kw_defaults[i] = parser.parse_expression('None') return node - def _transform_function(self, fn, user_context): - """Performs source code transformation on a function.""" + def transform_function(self, fn, user_context): + """Transforms a function. + + Subclasses may override this method. The return value is opaque. + + The method receives the original AST. The result is passed as-is to the + output of `transform`. + + Args: + fn: A function or lambda. + user_context: An opaque object (may be None) that is forwarded to + transform_ast, through the ctx.user_context argument. + Returns: + Any. By default it returns the output of transform_ast. + """ future_features = inspect_utils.getfutureimports(fn) node, source = parser.parse_entity(fn, future_features=future_features) logging.log(3, 'Source code of %s:\n\n%s\n', fn, source) @@ -333,63 +339,129 @@ class FunctionTranspiler(object): return node, context + +class PyToPy(GenericTranspiler): + """A generic Python-to-Python transpiler. + + Its `transform` method offers a function-in, function-out interface. + Internally, it takes care of parsing, caching and loading of the translated + code. + + Users typically subclass this, overriding `transform_ast`. + + Usually, instances of this class are singletons, since each instance manages + its own cache. The caching can be controlled by overriding `get_caching_key`. + + Example: + + class MyTransformer(PyToPy): + + def transform_ast(self, node, ctx): + node = <> + return node + + transformer = MyTransfomer() + + new_f, module, source_map = transformer.transform_function(f, ...) + # new_f is a function with signature identical to f + + The transformed function has access to the same namespace as the original + function. To allow access to internal APIs, users may inject additional + symbols by overriding `get_extra_locals`. + """ + + def __init__(self): + self._cache_lock = threading.RLock() + self._cache = cache.CodeObjectCache() + + def get_transformed_name(self, node): + """Returns a name for the output function. Subclasses may override this.""" + if isinstance(node, gast.Lambda): + return 'lam' + elif isinstance(node, gast.FunctionDef): + # Note that we need to rename the function, to avoid any namespace + # clashes. + return node.name + raise ValueError('Unknown node type {}'.format(node)) + + def get_extra_locals(self): + """Returns extra static local variables to be made to transformed code. + + Subclasses must override this. + + Returns: + extra_locals: A Dict[Text, Any] containing additional variables to make + available to the transformed code. + """ + raise NotImplementedError('subclasses must override this') + + def get_caching_key(self, user_context): + """Returns a unique key to use for caching. + + Subclasses must override this. + + Calls made to `transform_function` with functions that have the same code + object and caching key will return a cached instance on subsequent + invocations. + + Args: + user_context: The context object which was passed to `transform`. + + Returns: + extra_locals: A hashable. + """ + raise NotImplementedError('subclasses must override this') + def _cached_factory(self, fn, cache_subkey): cached_factory = self._cache[fn][cache_subkey] logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey, cached_factory) return cached_factory - def _transformed_factory(self, fn, cache_subkey, user_context, extra_locals): - """Returns the transformed function factory for a given input.""" - if self._cache.has(fn, cache_subkey): - return self._cached_factory(fn, cache_subkey) + def transform_function(self, fn, user_context): + """Transforms a function. See GenericTranspiler.trasnform_function. - with self._cache_lock: - # Check again under lock. - if self._cache.has(fn, cache_subkey): - return self._cached_factory(fn, cache_subkey) - - logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey) - nodes, ctx = self._transform_function(fn, user_context) - - if logging.has_verbosity(2): - logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes)) - - factory = _TransformedFnFactory( - ctx.info.name, fn.__code__.co_freevars, extra_locals) - factory.create(nodes, ctx.namer, future_features=ctx.info.future_features) - self._cache[fn][cache_subkey] = factory - return factory - - def transform_function(self, fn, caching_subkey, user_context, extra_locals): - """Transforms a function. - - The `caching_subkey` argument allows mapping each function to multiple - outputs in the cache. This is useful for instance when transformers - can generate multiple variants of output code, typically as a result of - different transformation flags. + This overload wraps the parent's `transform_function`, adding caching and + facilities to instantiate the output as a Python object. It also + adds facilities to make new symbols available to the generated Python code, + visible as local variables - see `get_extra_locals`. Args: fn: A function or lambda. - caching_subkey: Used for caching. Calls made for functions with the same - code object and caching_subkey will return a cached instance on - subsequent invocations. Using a constant will create unique per-function - entries. - user_context: An opaque object (may be none) that is forwarded to - transform_ast. - extra_locals: A Dict[Text, Any] containing additional variables to make - available to the transformed code. These will be visible as local - variables. + user_context: An opaque object (may be None) that is forwarded to + transform_ast, through the ctx.user_context argument. Returns: A tuple: * A function or lambda with the same signature and closure as `fn` * The temporary module into which the transformed function was loaded * The source map as a Dict[origin_info.LineLocation, origin_info.OriginInfo] - """ - factory = self._transformed_factory(fn, caching_subkey, user_context, - extra_locals) + cache_subkey = self.get_caching_key(user_context) + + if self._cache.has(fn, cache_subkey): + # Fast path: use a lock-free check. + factory = self._cached_factory(fn, cache_subkey) + + else: + with self._cache_lock: + # Check again under lock. + if self._cache.has(fn, cache_subkey): + factory = self._cached_factory(fn, cache_subkey) + + else: + logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey) + # TODO(mdan): Confusing overloading pattern. Fix. + nodes, ctx = super(PyToPy, self).transform_function(fn, user_context) + + if logging.has_verbosity(2): + logging.log(2, 'Transformed %s:\n\n%s\n', fn, parser.unparse(nodes)) + + factory = _PythonFnFactory( + ctx.info.name, fn.__code__.co_freevars, self.get_extra_locals()) + factory.create( + nodes, ctx.namer, future_features=ctx.info.future_features) + self._cache[fn][cache_subkey] = factory transformed_fn = factory.instantiate( globals_=fn.__globals__, diff --git a/tensorflow/python/autograph/pyct/transpiler_test.py b/tensorflow/python/autograph/pyct/transpiler_test.py index 501ed6d78b2..b61a4c7a889 100644 --- a/tensorflow/python/autograph/pyct/transpiler_test.py +++ b/tensorflow/python/autograph/pyct/transpiler_test.py @@ -35,7 +35,14 @@ class FlipSignTransformer(transformer.Base): return self.generic_visit(node) -class TestTranspiler(transpiler.FunctionTranspiler): +class TestTranspiler(transpiler.PyToPy): + + def get_caching_key(self, ctx): + del ctx + return 0 + + def get_extra_locals(self): + return {} def transform_ast(self, node, ctx): return FlipSignTransformer(ctx).visit(node) @@ -45,14 +52,14 @@ global_var_for_test_global = 1 global_var_for_test_namespace_collisions = object() -class FunctionTranspilerTest(test.TestCase): +class PyToPyTest(test.TestCase): def test_basic(self): def f(a): return a + 1 tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 0) @@ -63,7 +70,7 @@ class FunctionTranspilerTest(test.TestCase): return a + b tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 0) b = 2 @@ -74,7 +81,7 @@ class FunctionTranspilerTest(test.TestCase): return a + global_var_for_test_global tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) global global_var_for_test_global global_var_for_test_global = 1 @@ -90,7 +97,7 @@ class FunctionTranspilerTest(test.TestCase): return a + b + d tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 1 - 2 - 2) c = 0 @@ -107,7 +114,7 @@ class FunctionTranspilerTest(test.TestCase): return g(a) + 1 tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 1 - 1 + 1) # Only f is converted. @@ -116,7 +123,7 @@ class FunctionTranspilerTest(test.TestCase): f = lambda x: (b + (x if x > 0 else -x)) tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 2 - 1) self.assertEqual(f(-1), 2 - 1) @@ -132,7 +139,7 @@ class FunctionTranspilerTest(test.TestCase): f, _ = (lambda x: a + x, lambda y: b * y) tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 1 - 1) @@ -147,7 +154,7 @@ class FunctionTranspilerTest(test.TestCase): return g(x) tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 2 - 1) @@ -159,7 +166,7 @@ class FunctionTranspilerTest(test.TestCase): return g(x) tr = TestTranspiler() - f, _, _ = tr.transform_function(f, object(), None, {}) + f, _, _ = tr.transform(f, None) self.assertEqual(f(1), 2 - 1) @@ -171,9 +178,11 @@ class FunctionTranspilerTest(test.TestCase): outputs = [] tr = TestTranspiler() - cache_key = object() + # Note: this is not a test, it's a required invariant. + assert tr.get_caching_key(None) == tr.get_caching_key(None) + def conversion_thread(): - _, mod, _ = tr.transform_function(f, cache_key, None, {}) + _, mod, _ = tr.transform(f, None) outputs.append(mod.__name__) threads = tuple( @@ -192,21 +201,28 @@ class FunctionTranspilerTest(test.TestCase): def test_fn(): return 1 + 1 - class ReentrantTranspiler(transpiler.FunctionTranspiler): + class ReentrantTranspiler(transpiler.PyToPy): def __init__(self): super(ReentrantTranspiler, self).__init__() self._recursion_depth = 0 + def get_caching_key(self, ctx): + del ctx + return 0 + + def get_extra_locals(self): + return {} + def transform_ast(self, node, ctx): self._recursion_depth += 1 if self._recursion_depth < 2: - self.transform_function(test_fn, object(), None, {}) + self.transform(test_fn, None) return FlipSignTransformer(ctx).visit(node) tr = ReentrantTranspiler() - f, _, _ = tr.transform_function(test_fn, object(), None, {}) + f, _, _ = tr.transform(test_fn, None) self.assertEqual(f(), 0) def test_namespace_collisions_avoided(self): @@ -219,8 +235,8 @@ class FunctionTranspilerTest(test.TestCase): tr = TestTranspiler() obj = TestClass() - f, _, _ = tr.transform_function( - obj.global_var_for_test_namespace_collisions, object(), None, {}) + f, _, _ = tr.transform( + obj.global_var_for_test_namespace_collisions, None) self.assertIs(f(obj), global_var_for_test_namespace_collisions)