272 lines
10 KiB
Python
272 lines
10 KiB
Python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Core conversion logic, serves as main point of access."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
import imp
|
|
import unittest
|
|
|
|
from tensorflow.python.autograph import operators
|
|
from tensorflow.python.autograph import utils
|
|
from tensorflow.python.autograph.converters import asserts
|
|
from tensorflow.python.autograph.converters import break_statements
|
|
from tensorflow.python.autograph.converters import call_trees
|
|
from tensorflow.python.autograph.converters import conditional_expressions
|
|
from tensorflow.python.autograph.converters import continue_statements
|
|
from tensorflow.python.autograph.converters import control_flow
|
|
from tensorflow.python.autograph.converters import directives
|
|
from tensorflow.python.autograph.converters import functions
|
|
from tensorflow.python.autograph.converters import lists
|
|
from tensorflow.python.autograph.converters import logical_expressions
|
|
from tensorflow.python.autograph.converters import return_statements
|
|
from tensorflow.python.autograph.converters import slices
|
|
from tensorflow.python.autograph.converters import variables
|
|
from tensorflow.python.autograph.core import config
|
|
from tensorflow.python.autograph.core import converter
|
|
from tensorflow.python.autograph.core import function_wrappers
|
|
from tensorflow.python.autograph.core import unsupported_features_checker
|
|
from tensorflow.python.autograph.lang import special_functions
|
|
from tensorflow.python.autograph.pyct import anno
|
|
from tensorflow.python.autograph.pyct import cache
|
|
from tensorflow.python.autograph.pyct import cfg
|
|
from tensorflow.python.autograph.pyct import inspect_utils
|
|
from tensorflow.python.autograph.pyct import qual_names
|
|
from tensorflow.python.autograph.pyct import transpiler
|
|
from tensorflow.python.autograph.pyct.static_analysis import activity
|
|
from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
|
|
from tensorflow.python.autograph.utils import ag_logging as logging
|
|
from tensorflow.python.eager import function
|
|
from tensorflow.python.util import tf_inspect
|
|
|
|
|
|
class AutoGraphTranspiler(transpiler.FunctionTranspiler):
  """FunctionTranspiler that runs the AutoGraph converter pipeline."""

  def get_transformed_name(self, node):
    """Returns the parent transpiler's name for `node`, prefixed by 'tf__'."""
    parent_name = super(AutoGraphTranspiler, self).get_transformed_name(node)
    return 'tf__' + parent_name

  def _run_initial_analysis(self, node, ctx):
    """Annotates `node` with the static analysis results the converters use.

    Builds the control-flow graphs first, then resolves qualified names,
    activity and reaching definitions. Finally it duplicates, via `anno.dup`,
    the `DEFINITIONS` annotation under `ORIG_DEFINITIONS`. This is presumably
    so that later passes can still see the pre-conversion definitions.

    Args:
      node: The AST to analyze.
      ctx: The conversion context, forwarded to the analyses.

    Returns:
      The annotated AST.
    """
    # The CFGs are built from the node as it is before qualified-name
    # resolution and are then consumed by the reaching-definitions analysis.
    control_flow_graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_definitions.resolve(node, ctx, control_flow_graphs)
    anno.dup(node, {anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS})
    return node

  def transform_ast(self, node, ctx):
    """Applies the AutoGraph converters to `node` and returns the result.

    The passes run in a fixed order that must be preserved. The `asserts`
    pass only runs when `converter.Feature.ASSERT_STATEMENTS` is enabled in
    `ctx.user.options`. The `lists` and `slices` passes only run when
    `converter.Feature.LISTS` is enabled there.

    Args:
      node: The AST of the function being converted.
      ctx: The conversion context.

    Returns:
      The converted AST.
    """
    # TODO(mdan): Insert list_comprehensions somewhere.
    unsupported_features_checker.verify(node)

    node = self._run_initial_analysis(node, ctx)

    node = functions.transform(node, ctx)
    node = directives.transform(node, ctx)
    node = break_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.ASSERT_STATEMENTS):
      node = asserts.transform(node, ctx)

    # Note: sequencing continue canonicalization before for loop one avoids
    # dealing with the extra loop increment operation that the for
    # canonicalization creates.
    node = continue_statements.transform(node, ctx)
    node = return_statements.transform(node, ctx)
    if ctx.user.options.uses(converter.Feature.LISTS):
      node = lists.transform(node, ctx)
      node = slices.transform(node, ctx)

    node = call_trees.transform(node, ctx)
    node = control_flow.transform(node, ctx)
    node = conditional_expressions.transform(node, ctx)
    node = logical_expressions.transform(node, ctx)
    node = variables.transform(node, ctx)
    return node
|
|
|
|
|
|
# Shared transpiler instance; `convert` calls its `transform_function`.
_TRANSPILER = AutoGraphTranspiler()
# Records which (entity, options) pairs were marked as whitelisted. It is
# written by `cache_whitelisted` and queried by `is_in_whitelist_cache`.
_WHITELIST_CACHE = cache.UnboundInstanceCache()

# Extra namespace handed to the transpiler by `convert`. It starts as None and
# is filled once by `_create_custom_vars` with {'ag__': <module>}.
custom_vars = None
|
|
|
|
|
|
# TODO(mdan): Superfluous function, remove.
# TODO(mdan): Put these extra fields inside __autograph_info__.
def convert(entity, program_ctx):
  """Applies AutoGraph to `entity` and returns the transformed result.

  Args:
    entity: The Python entity to convert. It must have a `__code__`
      attribute.
    program_ctx: The program context. It is used to build the shared `ag__`
      namespace (see `_create_custom_vars`). It is also forwarded, together
      with its `options`, to `_TRANSPILER.transform_function`.

  Returns:
    The transformed entity produced by the transpiler. Two attributes are
    attached to it: `ag_module`, the module returned by the transpiler, and
    `ag_source_map`, the source map returned by the transpiler.

  Raises:
    ValueError: if `entity` has no `__code__` attribute.
  """
  exposes_code = hasattr(entity, '__code__')
  if not exposes_code:
    raise ValueError(
        'Cannot apply autograph to a function that doesn\'t expose a '
        '__code__ object. If this is a @tf.function, try passing '
        'f.python_function instead.')

  # Make sure the module-level `custom_vars` namespace exists before it is
  # passed to the transpiler.
  _create_custom_vars(program_ctx)
  transpiled = _TRANSPILER.transform_function(
      entity, program_ctx.options, program_ctx, custom_vars)
  transformed, module, source_map = transpiled

  # The transformed entity must not already carry these attributes;
  # otherwise they would be silently overwritten below.
  for reserved_attr in ('ag_module', 'ag_source_map'):
    assert not hasattr(transformed, reserved_attr)

  transformed.ag_module = module
  transformed.ag_source_map = source_map
  return transformed
|
|
|
|
|
|
# TODO(mdan): allow_namedtuple_subclass should be hardcoded to True.
def is_whitelisted(
    o, check_call_override=True, allow_namedtuple_subclass=False):
  """Checks whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  The checks run in the order below. The first one that reaches a verdict
  wins:
    1. If `o`'s module has a `__name__`, the rules in
       `config.CONVERSION_RULES` are consulted. The first rule whose action
       is CONVERT (not whitelisted) or DO_NOT_CONVERT (whitelisted) decides.
    2. Generator functions are whitelisted, and a warning is logged.
    3. Callable non-class objects are whitelisted if their `__call__` is.
       This check only runs when `check_call_override` is set.
    4. Methods are whitelisted if their class is a `unittest.TestCase`
       subclass, or if their defining class is whitelisted.
    5. Namedtuple types are whitelisted, subject to
       `allow_namedtuple_subclass`.
  Anything else is not whitelisted.

  Args:
    o: A Python entity.
    check_call_override: Reserved for internal use. When set to `False`, it
      disables the rule according to which classes are whitelisted if their
      __call__ method is whitelisted.
    allow_namedtuple_subclass: Reserved for internal use. When `True`,
      namedtuple subclasses are not whitelisted; only types with no
      namedtuple base are. When `False`, namedtuples and their subclasses are
      both whitelisted.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  # functools.partial objects have no __module__ attribute, so
  # tf_inspect.getmodule would return None for them; use functools instead.
  module = (functools if isinstance(o, functools.partial)
            else tf_inspect.getmodule(o))

  # Builtins are one example of callables without a __module__. For those,
  # `module` has no __name__ and the conversion rules are skipped.
  if hasattr(module, '__name__'):
    for rule in config.CONVERSION_RULES:
      verdict = rule.get_action(module)
      if verdict == config.Action.CONVERT:
        logging.log(2, 'Not whitelisted: %s: %s', o, rule)
        return False
      if verdict == config.Action.DO_NOT_CONVERT:
        logging.log(2, 'Whitelisted: %s: %s', o, rule)
        return True

  # tf_inspect.isgeneratorfunction crashes on objects without __code__,
  # hence the hasattr guard in front of it.
  is_generator = hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o)
  if is_generator:
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  is_callable_instance = (
      check_call_override and not tf_inspect.isclass(o) and
      hasattr(o, '__call__'))
  if is_callable_instance:
    # Callable objects are whitelisted if their __call__ method is. The
    # recursion stops once `o.__call__` has the same type as `o` (e.g. for
    # function objects), which would otherwise recurse forever.
    if (type(o) != type(o.__call__)  # pylint: disable=unidiomatic-typecheck
        and is_whitelisted(o.__call__)):
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even when bound
    # through a user subclass. If `tf.Foo` is whitelisted and has a method
    # `bar`, and `baz` is an instance of `class Custom(tf.Foo): pass`, then
    # `baz.bar` is whitelisted too. It stops being whitelisted if `Custom`
    # overrides `bar`.
    owner = inspect_utils.getmethodclass(o)
    if owner is function.TfMethodTarget:
      owner = o.__self__.target_class
    if owner is not None:
      if issubclass(owner, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner = inspect_utils.getdefiningclass(o, owner)
      owner_whitelisted = is_whitelisted(
          owner, check_call_override=False, allow_namedtuple_subclass=True)
      if owner_whitelisted:
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o, owner)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if not allow_namedtuple_subclass:
      logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
      return True
    has_namedtuple_base = any(
        inspect_utils.isnamedtuple(base) for base in o.__bases__)
    if not has_namedtuple_base:
      logging.log(2, 'Whitelisted: %s: named tuple', o)
      return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
|
|
|
|
|
|
def is_in_whitelist_cache(entity, options):
  """Returns whether `_WHITELIST_CACHE` holds an entry for (entity, options).

  Entities that the cache rejects with a TypeError are reported as not
  cached. Per the original comment, these are unhashable entities or ones
  that do not allow weakrefs.
  """
  try:
    cached = _WHITELIST_CACHE.has(entity, options)
  except TypeError:
    cached = False
  return cached
|
|
|
|
|
|
def cache_whitelisted(entity, options):
  """Records in `_WHITELIST_CACHE` that `entity` is whitelisted under `options`.

  Best effort: if the cache rejects the entity with a TypeError, nothing is
  recorded. Per the original comment, these are unhashable entities or ones
  that do not allow weakrefs.
  """
  try:
    entity_entries = _WHITELIST_CACHE[entity]
    entity_entries[options] = True
  except TypeError:
    pass
|
|
|
|
|
|
# TODO(mdan): Move into core or replace with an actual importable module.
def _create_custom_vars(program_ctx):
  """Adds namespace references to the module that exposes the api itself.

  On the first call, builds a module named 'autograph' and stores it in the
  module-level `custom_vars` as `{'ag__': module}`. Later calls return
  immediately, because `custom_vars` is then no longer None.

  The module is filled in layers, in this order:
    1. the contents of `program_ctx.autograph_module`;
    2. a few converter / function-wrapper helpers;
    3. everything in `special_functions`;
    4. everything in `operators`.
  Later layers overwrite earlier ones on name clashes.

  Args:
    program_ctx: The program context. Only its `autograph_module` attribute
      is read, and only on the first call.
  """
  global custom_vars
  import types  # pylint: disable=g-import-not-at-top

  if custom_vars is not None:
    return

  # Craft a module that exposes parts of the external API as well as certain
  # internal modules. types.ModuleType replaces imp.new_module: both create
  # an empty module, but `imp` is deprecated since Python 3.4 and was removed
  # in Python 3.12.
  ag_internal = types.ModuleType('autograph')
  ag_internal.__dict__.update(program_ctx.autograph_module.__dict__)
  ag_internal.ConversionOptions = converter.ConversionOptions
  ag_internal.STD = converter.STANDARD_OPTIONS
  ag_internal.Feature = converter.Feature
  ag_internal.utils = utils
  ag_internal.FunctionScope = function_wrappers.FunctionScope
  ag_internal.with_function_scope = function_wrappers.with_function_scope
  # TODO(mdan): Add safeguards against name clashes.
  # We don't want to create a submodule because we want the operators to be
  # accessible as ag__.<operator>
  # NOTE(review): these whole-__dict__ updates also copy dunder entries such
  # as __name__ from the source modules into ag_internal.
  ag_internal.__dict__.update(special_functions.__dict__)
  ag_internal.__dict__.update(operators.__dict__)

  custom_vars = {'ag__': ag_internal}
|