Add a flag for controlling the optional features that are applied during conversion. Turn control dependencies and lists off for defun.

PiperOrigin-RevId: 217188171
Dan Moldovan 2018-10-15 12:32:35 -07:00 committed by TensorFlower Gardener
parent 8b21130fb6
commit 0e0ea561eb
5 changed files with 69 additions and 18 deletions
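
In short: ConversionOptions gains an optional_features argument (a single Feature enum value or a collection, defaulting to Feature.ALL), converter passes query it through the new ConversionOptions.uses() method, and defun now passes an empty set so that list conversion and automatic control dependencies stay off when tracing function graphs. A minimal sketch of the resulting API, using only symbols introduced in the diffs below:

    from tensorflow.python.autograph.core.converter import ConversionOptions, Feature

    # A single Feature or any collection is accepted; the constructor
    # normalizes both forms to a frozenset.
    opts = ConversionOptions(optional_features=Feature.LISTS)
    assert opts.uses(Feature.LISTS)
    assert not opts.uses(Feature.AUTO_CONTROL_DEPS)

    # The default, Feature.ALL, acts as a wildcard in uses().
    assert ConversionOptions().uses(Feature.AUTO_CONTROL_DEPS)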

tensorflow/python/autograph/__init__.py View File

@@ -25,6 +25,7 @@ from __future__ import print_function
 from tensorflow.python.autograph import operators
 from tensorflow.python.autograph import utils
 from tensorflow.python.autograph.core.converter import ConversionOptions
+from tensorflow.python.autograph.core.converter import Feature
 from tensorflow.python.autograph.core.errors import GraphConstructionError
 from tensorflow.python.autograph.core.errors import improved_errors
 from tensorflow.python.autograph.core.errors import TfRuntimeError
@@ -44,6 +45,7 @@ from tensorflow.python.util.all_util import remove_undocumented
 _allowed_symbols = [
     # Main API
     'ConversionOptions',
+    'Feature',
     'RunMode',
     'convert',
     'converted_call',

tensorflow/python/autograph/core/converter.py View File

@@ -89,6 +89,19 @@ from tensorflow.python.autograph.pyct.static_analysis import type_info
 # TODO(mdan): Add a test specific to this converter.
+class Feature(Enum):
+  """Constants to use when selecting AutoGraph features."""
+
+  ALL = 'Enable all features.'
+
+  AUTO_CONTROL_DEPS = (
+      'Insertion of control dependencies in the generated code.')
+  LISTS = 'Convert list idioms, like initializers, slices, append, etc.'
+
+  def __repr__(self):
+    return self.name
+
+
 class ConversionOptions(object):
   """Immutable container for global conversion flags.
@@ -103,18 +116,31 @@ class ConversionOptions(object):
     force_conversion: bool, whether to force converting the target entity. When
       force_conversion is turned off, the converter may decide to return the
       function as-is.
+    optional_features: Union[Feature, Set[Feature]], controls the use of
+      optional features in the conversion process. See Feature for available
+      options.
   """
 
   def __init__(self,
               recursive=False,
               verbose=False,
               strip_decorators=None,
-              force_conversion=False):
+              force_conversion=False,
+              optional_features=Feature.ALL):
     self.recursive = recursive
     self.verbose = verbose
     self.strip_decorators = strip_decorators or ()
     self.force_conversion = force_conversion
+    if not isinstance(optional_features, (set, list, tuple)):
+      optional_features = (optional_features,)
+    optional_features = frozenset(optional_features)
+    self.optional_features = optional_features
 
+  def uses(self, feature):
+    return (Feature.ALL in self.optional_features or
+            feature in self.optional_features)
+
   def to_ast(self, namespace):
     """Returns a representation of this object as an AST node.
@@ -132,8 +158,9 @@ class ConversionOptions(object):
       constructor_name(
           recursive=recursive_val,
           verbose=verbose_val,
-          strip_decorators=strip_decorator_names,
-          force_conversion=force_conversion_val)
+          strip_decorators=strip_decorators_val,
+          force_conversion=force_conversion_val,
+          optional_features=optional_features_val)
     """
 
     def as_qualified_name(o):
@@ -143,8 +170,15 @@ class ConversionOptions(object):
             o, namespace))
       return name
 
-    strip_decorators_code = '({})'.format(', '.join(
-        tuple(as_qualified_name(o) for o in self.strip_decorators)))
+    def list_of_names(values):
+      return parser.parse_expression('({})'.format(', '.join(
+          tuple(as_qualified_name(v) for v in values))))
+
+    def list_of_features(values):
+      return parser.parse_expression('({})'.format(', '.join(
+          'ag__.Feature.{}'.format(v)
+          for v in Feature.__members__
+          if Feature[v] in values)))
 
     expr_ast = templates.replace(
         template,
@@ -152,9 +186,10 @@ class ConversionOptions(object):
             as_qualified_name(ConversionOptions)),
         recursive_val=parser.parse_expression(str(self.recursive)),
         verbose_val=parser.parse_expression(str(self.verbose)),
-        strip_decorator_names=parser.parse_expression(strip_decorators_code),
+        strip_decorators_val=list_of_names(self.strip_decorators),
         force_conversion_val=parser.parse_expression(
-            str(self.force_conversion)))
+            str(self.force_conversion)),
+        optional_features_val=list_of_features(self.optional_features))
 
     return expr_ast[0].value
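
For intuition, the feature serializer used by to_ast can be reproduced in isolation. A standalone sketch (feature_tuple_source is a hypothetical helper mirroring list_of_features above, minus the parser round-trip):

    import enum

    class Feature(enum.Enum):
      ALL = 'Enable all features.'
      AUTO_CONTROL_DEPS = 'Insertion of control dependencies in the generated code.'
      LISTS = 'Convert list idioms, like initializers, slices, append, etc.'

    def feature_tuple_source(values):
      # Render the enabled features as source text against the ag__
      # namespace alias, in enum declaration order. Feature.__members__
      # yields names, so the lookup maps each name back to its member.
      return '({})'.format(', '.join(
          'ag__.Feature.{}'.format(name)
          for name in Feature.__members__
          if Feature[name] in values))

    print(feature_tuple_source({Feature.LISTS}))  # prints: (ag__.Feature.LISTS)

Note that a single enabled feature renders as a parenthesized name rather than a one-element tuple; that is harmless here because the constructor accepts a bare Feature and normalizes it.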

tensorflow/python/autograph/impl/api.py View File

@@ -28,6 +28,7 @@ from tensorflow.python.autograph.operators import py_builtins
 from tensorflow.python.autograph.pyct import compiler
 from tensorflow.python.autograph.pyct import inspect_utils
 from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import tf_decorator
 from tensorflow.python.util import tf_inspect
@@ -66,6 +67,7 @@ def convert(recursive=False, verbose=False):
         recursive=recursive,
         verbose=verbose,
         force_conversion=True,
+        optional_features=converter.Feature.ALL,
     ), *args, **kwargs)
 
   wrapper = tf_decorator.make_decorator(f, wrapper)
@@ -142,6 +144,9 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
 # TODO(mdan): Move to a private, undocumented module.
 def converted_call(f, owner, options, *args, **kwargs):
   """Compiles a function call inline. For internal use only."""
+  if options.verbose:
+    logging.info('Converted call: {}; owner: {}'.format(f, owner))
+
   if owner is not None:
     if not isinstance(f, str):
       raise ValueError(
@@ -233,7 +238,8 @@ def converted_call(f, owner, options, *args, **kwargs):
       arg_values=arg_values,
       arg_types=arg_types,
       partial_types=partial_types,
-      strip_decorators=options.strip_decorators)
+      strip_decorators=options.strip_decorators,
+      optional_features=options.optional_features)
 
   return converted_f(*effective_args, **kwargs)
@@ -246,7 +252,8 @@ def to_graph(e,
             arg_values=None,
             arg_types=None,
             partial_types=None,
-            strip_decorators=None):
+            strip_decorators=None,
+            optional_features=converter.Feature.ALL):
   """Converts a Python entity into equivalent code that uses TensorFlow ops.
 
   Supported Python entities include:
@@ -267,6 +274,8 @@ def to_graph(e,
     partial_types: Set[Type], reserved for internal use.
     strip_decorators: Tuple[Callable], same as
       ConversionOptions.strip_decorators.
+    optional_features: Union[Feature, Set[Feature]], same as
+      ConversionOptions.optional_features.
 
   Returns:
     Union[Callable, Type], the converted entity, which is the same kind as e
@@ -284,7 +293,8 @@ def to_graph(e,
       options=converter.ConversionOptions(
           recursive=recursive,
           verbose=verbose,
-          strip_decorators=strip_decorators),
+          strip_decorators=strip_decorators,
+          optional_features=optional_features),
       partial_types=partial_types,
       autograph_module=tf_inspect.getmodule(to_graph),
       uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
@@ -295,7 +305,7 @@ def to_graph(e,
   for dep in reversed(program_ctx.conversion_order):
     nodes.extend(program_ctx.dependency_cache[dep])
 
-  compiled_module, compiled_src = compiler.ast_to_object(
+  compiled_module, _ = compiler.ast_to_object(
       nodes,
       source_prefix=program_ctx.required_imports,
       include_source_map=True)
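
At the to_graph level the new argument is simply forwarded into ConversionOptions. A hypothetical call that keeps automatic control dependencies but opts out of list conversion (repeat_elems is a placeholder function; the import paths assume this file's location in the tree):

    from tensorflow.python.autograph.core import converter
    from tensorflow.python.autograph.impl import api

    def repeat_elems(x):
      elems = []
      for i in range(3):
        elems.append(x * i)
      return elems

    # With LISTS omitted, the list idioms above are left as plain Python.
    graph_fn = api.to_graph(
        repeat_elems,
        optional_features=converter.Feature.AUTO_CONTROL_DEPS)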

tensorflow/python/autograph/impl/conversion.py View File

@@ -347,16 +347,17 @@ def node_to_graph(node, context, rewrite_errors=True):
   # dealing with the extra loop increment operation that the for
   # canonicalization creates.
   node = converter.apply_(node, context, continue_statements)
-  context.info.namespace['len'] = len
   node = converter.apply_(node, context, return_statements)
-  node = converter.apply_(node, context, lists)
-  node = converter.apply_(node, context, slices)
+  if context.program.options.uses(converter.Feature.LISTS):
+    node = converter.apply_(node, context, lists)
+    node = converter.apply_(node, context, slices)
   node = converter.apply_(node, context, builtin_functions)
   node = converter.apply_(node, context, call_trees)
   node = converter.apply_(node, context, control_flow)
   node = converter.apply_(node, context, conditional_expressions)
   node = converter.apply_(node, context, logical_expressions)
-  node = converter.apply_(node, context, side_effect_guards)
+  if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS):
+    node = converter.apply_(node, context, side_effect_guards)
   node = converter.apply_(node, context, function_scopes)
   if rewrite_errors:
     node = converter.apply_(node, context, error_handlers)
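
The guards above evaluate purely against the options object, so the configuration that defun sets up (see the last file below) can be checked directly:

    from tensorflow.python.autograph.core.converter import ConversionOptions, Feature

    # An empty selection fails both guards: the lists, slices and
    # side_effect_guards passes are all skipped during conversion.
    defun_opts = ConversionOptions(optional_features=())
    assert not defun_opts.uses(Feature.LISTS)
    assert not defun_opts.uses(Feature.AUTO_CONTROL_DEPS)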

tensorflow/python/eager/function.py View File

@@ -960,10 +960,13 @@ def func_graph_from_py_func(name,
     try:
       if experimental_autograph:
         func_outputs = autograph.converted_call(
-            python_func,
+            python_func, None,
             autograph.ConversionOptions(
-                verbose=True, recursive=True, strip_decorators=(defun,)),
-            *func_args, **func_kwargs)
+                verbose=True,
+                recursive=True,
+                strip_decorators=(defun,),
+                optional_features=(),
+            ), *func_args, **func_kwargs)
       else:
         func_outputs = python_func(*func_args, **func_kwargs)
       # invariant: `func_outputs` contains only Tensors and `None`s.
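
Net effect for defun: traced functions still get control-flow conversion, but the LISTS and AUTO_CONTROL_DEPS passes are skipped. A small check of the options object built above, with a local stand-in for the real defun reference (illustration only):

    from tensorflow.python.autograph.core import converter

    def defun(f):  # stand-in for the real decorator
      return f

    opts = converter.ConversionOptions(
        verbose=True,
        recursive=True,
        strip_decorators=(defun,),
        optional_features=(),
    )
    assert opts.optional_features == frozenset()
    assert not opts.uses(converter.Feature.AUTO_CONTROL_DEPS)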