Add a flag for controlling the features that are applied at conversion. Turn control dependencies and lists off for defun.
PiperOrigin-RevId: 217188171
This commit is contained in:
parent
8b21130fb6
commit
0e0ea561eb
@ -25,6 +25,7 @@ from __future__ import print_function
|
||||
from tensorflow.python.autograph import operators
|
||||
from tensorflow.python.autograph import utils
|
||||
from tensorflow.python.autograph.core.converter import ConversionOptions
|
||||
from tensorflow.python.autograph.core.converter import Feature
|
||||
from tensorflow.python.autograph.core.errors import GraphConstructionError
|
||||
from tensorflow.python.autograph.core.errors import improved_errors
|
||||
from tensorflow.python.autograph.core.errors import TfRuntimeError
|
||||
@ -44,6 +45,7 @@ from tensorflow.python.util.all_util import remove_undocumented
|
||||
_allowed_symbols = [
|
||||
# Main API
|
||||
'ConversionOptions',
|
||||
'Feature',
|
||||
'RunMode',
|
||||
'convert',
|
||||
'converted_call',
|
||||
|
||||
@ -89,6 +89,19 @@ from tensorflow.python.autograph.pyct.static_analysis import type_info
|
||||
# TODO(mdan): Add a test specific to this converter.
|
||||
|
||||
|
||||
class Feature(Enum):
|
||||
"""Constants to use when selecting AutoGraph features."""
|
||||
|
||||
ALL = 'Enable all features.'
|
||||
|
||||
AUTO_CONTROL_DEPS = (
|
||||
'Insert of control dependencies in the generated code.')
|
||||
LISTS = 'Convert list idioms, like initializers, slices, append, etc.'
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class ConversionOptions(object):
|
||||
"""Immutable container for global conversion flags.
|
||||
|
||||
@ -103,18 +116,31 @@ class ConversionOptions(object):
|
||||
force_conversion: bool, whether to force convertinng the target entity. When
|
||||
force_conversion is turned off, the converter may decide to return the
|
||||
function as-is.
|
||||
optional_features: Union[Feature, Set[Feature]], controls the use of
|
||||
optional features in the conversion process. See Feature for available
|
||||
options.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
recursive=False,
|
||||
verbose=False,
|
||||
strip_decorators=None,
|
||||
force_conversion=False):
|
||||
force_conversion=False,
|
||||
optional_features=Feature.ALL):
|
||||
self.recursive = recursive
|
||||
self.verbose = verbose
|
||||
self.strip_decorators = strip_decorators or ()
|
||||
self.force_conversion = force_conversion
|
||||
|
||||
if not isinstance(optional_features, (set, list, tuple)):
|
||||
optional_features = (optional_features,)
|
||||
optional_features = frozenset(optional_features)
|
||||
self.optional_features = optional_features
|
||||
|
||||
def uses(self, feature):
|
||||
return (Feature.ALL in self.optional_features or
|
||||
feature in self.optional_features)
|
||||
|
||||
def to_ast(self, namespace):
|
||||
"""Returns a representation of this object as an AST node.
|
||||
|
||||
@ -132,8 +158,9 @@ class ConversionOptions(object):
|
||||
constructor_name(
|
||||
recursive=recursive_val,
|
||||
verbose=verbose_val,
|
||||
strip_decorators=strip_decorator_names,
|
||||
force_conversion=force_conversion_val)
|
||||
strip_decorators=strip_decorators_val,
|
||||
force_conversion=force_conversion_val,
|
||||
optional_features=optional_features_val)
|
||||
"""
|
||||
|
||||
def as_qualified_name(o):
|
||||
@ -143,8 +170,15 @@ class ConversionOptions(object):
|
||||
o, namespace))
|
||||
return name
|
||||
|
||||
strip_decorators_code = '({})'.format(', '.join(
|
||||
tuple(as_qualified_name(o) for o in self.strip_decorators)))
|
||||
def list_of_names(values):
|
||||
return parser.parse_expression('({})'.format(', '.join(
|
||||
tuple(as_qualified_name(v) for v in values))))
|
||||
|
||||
def list_of_features(values):
|
||||
return parser.parse_expression('({})'.format(', '.join(
|
||||
'ag__.Feature.{}'.format(v)
|
||||
for v in Feature.__members__
|
||||
if v in values)))
|
||||
|
||||
expr_ast = templates.replace(
|
||||
template,
|
||||
@ -152,9 +186,10 @@ class ConversionOptions(object):
|
||||
as_qualified_name(ConversionOptions)),
|
||||
recursive_val=parser.parse_expression(str(self.recursive)),
|
||||
verbose_val=parser.parse_expression(str(self.verbose)),
|
||||
strip_decorator_names=parser.parse_expression(strip_decorators_code),
|
||||
strip_decorators_val=list_of_names(self.strip_decorators),
|
||||
force_conversion_val=parser.parse_expression(
|
||||
str(self.force_conversion)))
|
||||
str(self.force_conversion)),
|
||||
optional_features_val=list_of_features(self.optional_features))
|
||||
return expr_ast[0].value
|
||||
|
||||
|
||||
|
||||
@ -28,6 +28,7 @@ from tensorflow.python.autograph.operators import py_builtins
|
||||
from tensorflow.python.autograph.pyct import compiler
|
||||
from tensorflow.python.autograph.pyct import inspect_utils
|
||||
from tensorflow.python.autograph.utils import py_func
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
|
||||
@ -66,6 +67,7 @@ def convert(recursive=False, verbose=False):
|
||||
recursive=recursive,
|
||||
verbose=verbose,
|
||||
force_conversion=True,
|
||||
optional_features=converter.Feature.ALL,
|
||||
), *args, **kwargs)
|
||||
|
||||
wrapper = tf_decorator.make_decorator(f, wrapper)
|
||||
@ -142,6 +144,9 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
|
||||
# TODO(mdan): Move to a private, undocumented module.
|
||||
def converted_call(f, owner, options, *args, **kwargs):
|
||||
"""Compiles a function call inline. For internal use only."""
|
||||
if options.verbose:
|
||||
logging.info('Converted call: {}; owner: {}'.format(f, owner))
|
||||
|
||||
if owner is not None:
|
||||
if not isinstance(f, str):
|
||||
raise ValueError(
|
||||
@ -233,7 +238,8 @@ def converted_call(f, owner, options, *args, **kwargs):
|
||||
arg_values=arg_values,
|
||||
arg_types=arg_types,
|
||||
partial_types=partial_types,
|
||||
strip_decorators=options.strip_decorators)
|
||||
strip_decorators=options.strip_decorators,
|
||||
optional_features=options.optional_features)
|
||||
return converted_f(*effective_args, **kwargs)
|
||||
|
||||
|
||||
@ -246,7 +252,8 @@ def to_graph(e,
|
||||
arg_values=None,
|
||||
arg_types=None,
|
||||
partial_types=None,
|
||||
strip_decorators=None):
|
||||
strip_decorators=None,
|
||||
optional_features=converter.Feature.ALL):
|
||||
"""Converts a Python entity into equivalent code that uses TensorFlow ops.
|
||||
|
||||
Supported Python entities include:
|
||||
@ -267,6 +274,8 @@ def to_graph(e,
|
||||
partial_types: Set[Type], reserved for internal use.
|
||||
strip_decorators: Tuple[Callable], same as
|
||||
ConversionOptions.strip_decorators.
|
||||
optional_features: Union[Feature, Set[Feature]], same as
|
||||
ConversionOptions.optional_features.
|
||||
|
||||
Returns:
|
||||
Union[Callable, Type], the converted entity, which is the same kind as e
|
||||
@ -284,7 +293,8 @@ def to_graph(e,
|
||||
options=converter.ConversionOptions(
|
||||
recursive=recursive,
|
||||
verbose=verbose,
|
||||
strip_decorators=strip_decorators),
|
||||
strip_decorators=strip_decorators,
|
||||
optional_features=optional_features),
|
||||
partial_types=partial_types,
|
||||
autograph_module=tf_inspect.getmodule(to_graph),
|
||||
uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
|
||||
@ -295,7 +305,7 @@ def to_graph(e,
|
||||
for dep in reversed(program_ctx.conversion_order):
|
||||
nodes.extend(program_ctx.dependency_cache[dep])
|
||||
|
||||
compiled_module, compiled_src = compiler.ast_to_object(
|
||||
compiled_module, _ = compiler.ast_to_object(
|
||||
nodes,
|
||||
source_prefix=program_ctx.required_imports,
|
||||
include_source_map=True)
|
||||
|
||||
@ -347,16 +347,17 @@ def node_to_graph(node, context, rewrite_errors=True):
|
||||
# dealing with the extra loop increment operation that the for
|
||||
# canonicalization creates.
|
||||
node = converter.apply_(node, context, continue_statements)
|
||||
context.info.namespace['len'] = len
|
||||
node = converter.apply_(node, context, return_statements)
|
||||
node = converter.apply_(node, context, lists)
|
||||
node = converter.apply_(node, context, slices)
|
||||
if context.program.options.uses(converter.Feature.LISTS):
|
||||
node = converter.apply_(node, context, lists)
|
||||
node = converter.apply_(node, context, slices)
|
||||
node = converter.apply_(node, context, builtin_functions)
|
||||
node = converter.apply_(node, context, call_trees)
|
||||
node = converter.apply_(node, context, control_flow)
|
||||
node = converter.apply_(node, context, conditional_expressions)
|
||||
node = converter.apply_(node, context, logical_expressions)
|
||||
node = converter.apply_(node, context, side_effect_guards)
|
||||
if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
|
||||
node = converter.apply_(node, context, side_effect_guards)
|
||||
node = converter.apply_(node, context, function_scopes)
|
||||
if rewrite_errors:
|
||||
node = converter.apply_(node, context, error_handlers)
|
||||
|
||||
@ -960,10 +960,13 @@ def func_graph_from_py_func(name,
|
||||
try:
|
||||
if experimental_autograph:
|
||||
func_outputs = autograph.converted_call(
|
||||
python_func,
|
||||
python_func, None,
|
||||
autograph.ConversionOptions(
|
||||
verbose=True, recursive=True, strip_decorators=(defun,)),
|
||||
*func_args, **func_kwargs)
|
||||
verbose=True,
|
||||
recursive=True,
|
||||
strip_decorators=(defun,),
|
||||
optional_features=(),
|
||||
), *func_args, **func_kwargs)
|
||||
else:
|
||||
func_outputs = python_func(*func_args, **func_kwargs)
|
||||
# invariant: `func_outputs` contains only Tensors and `None`s.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user