Enable support for loop directives, including shape_invariants.

PiperOrigin-RevId: 285966483
Change-Id: I3eae0b134cd2e954bfa0ac31e6a7411b3a5bb7df
Author: Dan Moldovan, 2019-12-17 06:12:13 -08:00; committed by TensorFlower Gardener
parent 1768c8f2fa
commit e42b937e1a
6 changed files with 149 additions and 73 deletions

tensorflow/python/autograph/converters/control_flow.py

@@ -21,6 +21,7 @@ from __future__ import print_function
 import gast
 
 from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
 from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.autograph.pyct import ast_util
 from tensorflow.python.autograph.pyct import parser
@@ -151,6 +152,20 @@ class ControlFlowTransformer(converter.Base):
     return node
 
+  def _create_loop_options(self, node):
+    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
+      return gast.Dict([], [])
+
+    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
+    if directives.set_loop_options not in loop_directives:
+      return gast.Dict([], [])
+
+    opts_dict = loop_directives[directives.set_loop_options]
+    str_keys, values = zip(*opts_dict.items())
+    keys = [gast.Str(s) for s in str_keys]
+    values = list(values)  # ast and gast don't play well with tuples.
+    return gast.Dict(keys, values)
+
   def _create_undefined_assigns(self, undefined_symbols):
     assignments = []
     for s in undefined_symbols:
@@ -383,8 +398,7 @@ class ControlFlowTransformer(converter.Base):
     composite_symbol_names = tuple(
         gast.Str(str(symbol)) for symbol in composite_loop_vars)
 
-    # TODO(b/140125096): Populate.
-    opts = gast.Dict([], [])
+    opts = self._create_loop_options(node)
 
     # TODO(mdan): Use a single template.
     # If the body and test functions took a single tuple for loop_vars, instead
@@ -507,8 +521,7 @@ class ControlFlowTransformer(converter.Base):
     composite_symbol_names = tuple(
         gast.Str(str(symbol)) for symbol in composite_loop_vars)
 
-    # TODO(b/140125096): Populate.
-    opts = gast.Dict([], [])
+    opts = self._create_loop_options(node)
 
     # TODO(mdan): Use a single template.
     # If the body and test functions took a single tuple for loop_vars, instead
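For context: the new `_create_loop_options` reads the `set_loop_options` directive annotation off the loop node and materializes it as a `gast.Dict` literal, which is then spliced into the loop template as `opts`. A minimal standalone sketch of that mapping (the `build_options_ast` name and the `gast.Num` value are illustrative, not part of the commit):

    import gast

    def build_options_ast(opts_dict):
      # Mirrors _create_loop_options: option names become string literals,
      # option values are kept as the already-parsed AST expressions.
      if not opts_dict:
        return gast.Dict([], [])
      str_keys, values = zip(*opts_dict.items())
      keys = [gast.Str(s) for s in str_keys]
      return gast.Dict(keys, list(values))  # lists, not tuples, for ast/gast

    # E.g. set_loop_options(maximum_iterations=3) becomes the AST of the
    # dict literal {'maximum_iterations': 3}:
    opts_ast = build_options_ast({'maximum_iterations': gast.Num(3)})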

tensorflow/python/autograph/lang/directives.py

@@ -26,7 +26,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.util.tf_export import tf_export
-from tensorflow.tools.docs.doc_controls import do_not_generate_docs
 
 UNSPECIFIED = object()
@@ -47,37 +46,53 @@ def set_element_type(entity, dtype, shape=UNSPECIFIED):
   del shape
 
 
-# TODO(b/140125096): Implement.
-@do_not_generate_docs
 @tf_export('autograph.experimental.set_loop_options')
 def set_loop_options(
     parallel_iterations=UNSPECIFIED,
-    back_prop=UNSPECIFIED,
     swap_memory=UNSPECIFIED,
-    maximum_iterations=UNSPECIFIED):
+    maximum_iterations=UNSPECIFIED,
+    shape_invariants=UNSPECIFIED):
   """Specifies additional arguments to be passed to the enclosing while_loop.
 
   The parameters apply to and only to the immediately enclosing loop. It only
   has effect if the loop is staged as a TF while_loop; otherwise the parameters
   have no effect.
 
-  Usage example:
-
-    @tf.function(autograph=True)
-    def dynamic_rnn(..., parallel_iterations=32):
-      num_steps = ...
-      for t in tf.range(num_steps):
-        tf.autograph.experimental.set_loop_options(
-            parallel_iterations=parallel_iterations)
-        ...
+  Usage:
+
+    >>> @tf.function(autograph=True)
+    ... def f():
+    ...   n = 0
+    ...   for i in tf.range(10):
+    ...     tf.autograph.experimental.set_loop_options(maximum_iterations=3)
+    ...     n += 1
+    ...   return n
+
+    >>> @tf.function(autograph=True)
+    ... def f():
+    ...   v = tf.constant((0,))
+    ...   for i in tf.range(3):
+    ...     tf.autograph.experimental.set_loop_options(
+    ...         shape_invariants=[(v, tf.TensorShape([None]))]
+    ...     )
+    ...     v = tf.concat((v, [i]), 0)
+    ...   return v
+
+  Also see tf.while_loop.
 
   Args:
-    parallel_iterations: See tf.while_loop.
-    back_prop: See tf.while_loop.
-    swap_memory: See tf.while_loop.
-    maximum_iterations: See tf.while_loop.
+    parallel_iterations: The maximum number of iterations allowed to run in
+        parallel at any given time. Note that this does not guarantee parallel
+        execution.
+    swap_memory: Whether to store intermediate values needed for
+        gradients on the CPU instead of GPU.
+    maximum_iterations: Allows limiting the total number of iterations executed
+        by the loop.
+    shape_invariants: Allows controlling the argument with the same name passed
+        to tf.while_loop. Unlike tf.while_loop, this is a list of
+        `(tensor, shape)` pairs.
   """
   del parallel_iterations
-  del back_prop
   del swap_memory
   del maximum_iterations
+  del shape_invariants
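Since the docstring stresses that the options bind only to the immediately enclosing loop, a sketch of what that means for nested loops may help. This is a hypothetical example, not from the commit, and the exact staging of each loop can vary:

    import tensorflow as tf

    @tf.function
    def count_nested():
      total = 0
      for i in tf.range(5):  # outer loop: no options set
        for j in tf.range(100):
          # Binds to the inner loop only; the outer loop is unaffected.
          tf.autograph.experimental.set_loop_options(maximum_iterations=2)
          total += 1
      return total

    # If both loops stage as TF while_loops, each inner loop stops after
    # 2 iterations, so this should evaluate to 5 * 2 = 10.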

tensorflow/python/autograph/operators/control_flow.py

@@ -125,68 +125,91 @@ def _is_subshape(left, right):
   return True
 
 
-def _verify_single_loop_var(name, check_shape, init_loop_var, first_iter_var):
-  """Verifies whether init_loop_var and first_iter_var are consistent."""
-  if isinstance(init_loop_var, (bool, int, float, str)):
-    init_loop_var = ops.convert_to_tensor_v2(init_loop_var)
-  if isinstance(first_iter_var, (bool, int, float, str)):
-    first_iter_var = ops.convert_to_tensor_v2(first_iter_var)
+# TODO(mdan): Remove these verifications once TF ops can properly report names.
+def _verify_single_loop_var(
+    name, check_shape, init, entry, exit_, shape_invariant):
+  """Verifies whether the initial, entry and exit values are consistent."""
+  if isinstance(init, (bool, int, float, str, np.ndarray)):
+    init = ops.convert_to_tensor_v2(init)
+  if isinstance(entry, (bool, int, float, str, np.ndarray)):
+    entry = ops.convert_to_tensor_v2(entry)
+  if isinstance(exit_, (bool, int, float, str)):
+    exit_ = ops.convert_to_tensor_v2(exit_)
 
-  if (not tensor_util.is_tensor(init_loop_var) or
-      not tensor_util.is_tensor(first_iter_var)):
+  if (not tensor_util.is_tensor(entry) or
+      not tensor_util.is_tensor(exit_)):
     return
 
   # TODO(mdan): Properly account for CompositeTensors.
-  if (not hasattr(init_loop_var, 'dtype') or
-      not hasattr(first_iter_var, 'dtype')):
+  if (not hasattr(entry, 'dtype') or
+      not hasattr(exit_, 'dtype')):
     return
-  if (not hasattr(init_loop_var, 'shape') or
-      not hasattr(first_iter_var, 'shape')):
+  if (not hasattr(entry, 'shape') or
+      not hasattr(exit_, 'shape')):
     return
 
-  if init_loop_var.dtype != first_iter_var.dtype:
+  if entry.dtype != exit_.dtype:
     raise TypeError(
         '"{}" has dtype {} before the loop, but dtype {} after one'
         ' iteration. TensorFlow control flow requires it stays the'
         ' same.'.format(
             name,
-            init_loop_var.dtype.name,
-            first_iter_var.dtype.name,
+            entry.dtype.name,
+            exit_.dtype.name,
         ))
 
   if check_shape:
-    init_shape = init_loop_var.shape
-    first_iter_shape = first_iter_var.shape
-    # TODO(b/135183013): Update needed once we support shape_invariants.
-    if not _is_subshape(first_iter_shape, init_shape):
-      raise ValueError(
-          '"{}" has shape {} before the loop, but shape {} after one'
-          ' iteration. TensorFlow control flow requires it stays the'
-          ' same or be more specific.'.format(name, init_shape,
-                                              first_iter_shape))
+    exit_shape = exit_.shape
+    if shape_invariant is None:
+      entry_shape = entry.shape
+      if not _is_subshape(exit_shape, entry_shape):
+        raise ValueError(
+            '"{}" has shape {} before the loop, but shape {} after one'
+            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
+            ' shape invariants.'.format(name, entry_shape, exit_shape))
+    else:
+      init_shape = init.shape
+      if not _is_subshape(init_shape, shape_invariant):
+        raise ValueError(
+            '"{}" has shape {} before the loop, which does not conform with'
+            ' the shape invariant {}.'.format(name, init_shape,
+                                              shape_invariant))
+      if not _is_subshape(exit_shape, shape_invariant):
+        raise ValueError(
+            '"{}" has shape {} after the loop, which does not conform with'
+            ' the shape invariant {}.'.format(
+                name, exit_shape, shape_invariant))
 
 
-def _verify_tf_loop_vars(init_loop_vars,
-                         first_iter_vars,
+def _verify_tf_loop_vars(init_vars,
+                         iter_entry_vars,
+                         iter_exit_vars,
                          symbol_names,
                          opts,
                          check_shapes=True):
   """Verifies loop variables for consistency."""
-  # TODO(b/140125096): Use this.
-  del opts
+  if check_shapes and 'shape_invariants' in opts:
+    shape_invariants = opts['shape_invariants']
+  else:
+    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)
 
-  named_vars = zip(symbol_names, init_loop_vars, first_iter_vars)
-  for name, init_loop_var, first_iter_var in named_vars:
+  named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
+                   shape_invariants)
+  for name, init, entry, exit_, invariant in named_vars:
     try:
-      nest.assert_same_structure(
-          init_loop_var, first_iter_var, expand_composites=True)
+      nest.assert_same_structure(entry, exit_, expand_composites=True)
     except (ValueError, TypeError) as e:
       raise TypeError('"{}" does not have the same nested structure after one'
                       ' iteration.\n\n{}'.format(name, e))
+    if invariant is not None:
+      try:
+        nest.assert_same_structure(init, invariant, expand_composites=False)
+      except (ValueError, TypeError) as e:
+        raise TypeError('"{}" does not have the same nested structure as its'
+                        ' corresponding shape invariant.\n\n{}'.format(name, e))
 
     nest.map_structure(
-        functools.partial(_verify_single_loop_var, name, check_shapes),
-        init_loop_var, first_iter_var)
+        functools.partial(_verify_single_loop_var, name, check_shapes), init,
+        entry, exit_, invariant)
 
 
 def _verify_single_cond_var(name, body_var, orelse_var):
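The verification above leans on `_is_subshape`, whose body lies outside this hunk. As intuition for the conformance rule it enforces (a shape conforms to an invariant when it is at least as specific), here is a rough standalone stand-in; the real helper may differ in edge cases:

    import tensorflow as tf

    def conforms(shape, invariant):
      # A shape conforms if the invariant's rank is unknown, or the ranks
      # match and every known invariant dimension matches exactly.
      if invariant.rank is None:
        return True
      if shape.rank != invariant.rank:
        return False
      return all(i is None or d == i
                 for d, i in zip(shape.as_list(), invariant.as_list()))

    assert conforms(tf.TensorShape([3]), tf.TensorShape([None]))         # ok
    assert not conforms(tf.TensorShape([3, 2]), tf.TensorShape([None]))  # rank changed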
@@ -425,6 +448,8 @@ def _tf_ragged_for_stmt(iter_,
   else:
     n = iter_.row_lengths()[0]
 
+  opts['maximum_iterations'] = n
+
   def while_body(iterate_index, *loop_vars):
     """Main loop body."""
     iterate = iter_[iterate_index]
@@ -566,7 +591,7 @@ def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
     # Note: this verification duplicates that performed in tf_while_stmt,
     # but needs to be done earlier to prevent the tf.cond inside while_body
     # from blowing up first.
-    _verify_tf_loop_vars(loop_vars, new_vars,
+    _verify_tf_loop_vars(init_vars, loop_vars, new_vars,
                          basic_symbol_names + composite_symbol_names, opts)
     return new_vars
@@ -653,20 +678,26 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
   # TODO(mdan): Simplify this - following it is extremely difficult.
 
+  init_state = get_state()
+  aug_init_vars = init_vars, init_state
+
   def scan_body(aug_vars, iterate):
     """The main loop body wrapper. Only calculates the stop condition."""
     loop_vars, state = aug_vars
 
     def true_fn():
+      """Main path - stop condition is not set."""
       set_state(state)
-      outputs = body(iterate, *loop_vars)
+      new_vars = body(iterate, *loop_vars)
+      new_state = get_state()
       _verify_tf_loop_vars(
+          init_vars + init_state,
           loop_vars + state,
-          outputs + state,
+          new_vars + new_state,
          basic_symbol_names + composite_symbol_names,
           opts,
           check_shapes=False)
-      return outputs, get_state()
+      return new_vars, new_state
 
     extra_cond = extra_test(*loop_vars)
     new_vars, new_state = control_flow_ops.cond(
@@ -690,11 +721,9 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
     del extra_cond
     return output_aug_vars, output_state
 
-  init_state = get_state()
-  aug_vars = init_vars, init_state
-  ds = _general_purpose_scan(ds, aug_vars, scan_body)
+  ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
   ds = ds.apply(take_while_ops.take_while(take_while_predicate))
-  final_aug_vars = ds.reduce(aug_vars, reduce_body)
+  final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
   final_vars, final_state = final_aug_vars
   set_state(final_state)
   return final_vars
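The dataset-backed loop above is built from a scan (to thread the loop state through the elements) followed by a take_while (to stop once the extra test fails). A simplified, self-contained illustration of that pattern using the public tf.data counterparts, not AutoGraph's actual helpers:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)

    def scan_fn(total, x):
      # scan threads `total` through; each output element carries the state.
      new_total = total + x
      return new_total, (new_total, x)

    ds = ds.apply(tf.data.experimental.scan(tf.constant(0, tf.int64), scan_fn))
    # Stop consuming once the running total reaches 10 (the "extra test").
    ds = ds.apply(tf.data.experimental.take_while(lambda total, x: total < 10))
    final = ds.reduce(tf.constant(0, tf.int64), lambda _, elem: elem[0])
    # Totals 0, 1, 3, 6 pass the test; `final` ends up as 6.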
@@ -741,6 +770,7 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
     new_state = get_state()
 
     _verify_tf_loop_vars(
+        init_vars + init_state,
         loop_vars + state,
         new_vars + new_state,
         symbol_names,
@@ -824,11 +854,23 @@ def while_stmt(test,
   return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
 
 
+def _shape_invariants_mapping_to_positional_list(mapping, keys):
+  # The keys are not expected to be hashable.
+  mapping = {id(k): (k, v) for k, v in mapping}
+  result = []
+  for k in keys:
+    map_key, map_val = mapping.get(id(k), (None, None))
+    result.append(map_val if map_key is k else None)
+  return tuple(result)
+
+
 def _tf_while_stmt(test, body, get_state, set_state, init_vars,
                    basic_symbol_names, composite_symbol_names, opts):
   """Overload of while_stmt that stages a TF while_stmt."""
   _disallow_undefs_into_loop(*init_vars)
 
+  aug_init_vars = init_vars + get_state()
+
   # TODO(mdan): Simplify this.
   loop_vars_slice = slice(len(init_vars))
   state_slice = slice(len(init_vars), None)
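The new helper matches tensors by identity rather than by hashing, since loop variables (e.g. eager tensors) are not reliably hashable. A usage sketch of the same logic as a standalone function (the `shape_invariants_to_positional` name is illustrative):

    import tensorflow as tf

    def shape_invariants_to_positional(mapping, loop_vars):
      # Same logic as _shape_invariants_mapping_to_positional_list: match
      # by id() plus an `is` check; unmatched vars get None (default spec).
      by_id = {id(k): (k, v) for k, v in mapping}
      result = []
      for var in loop_vars:
        key, val = by_id.get(id(var), (None, None))
        result.append(val if key is var else None)
      return tuple(result)

    v = tf.constant((0,))
    w = tf.constant(1)
    # Only `v` has an invariant; `w` maps to None.
    print(shape_invariants_to_positional(
        [(v, tf.TensorShape([None]))], (v, w)))
    # -> (TensorShape([None]), None)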
@@ -844,7 +886,7 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
     set_state(state)
     loop_vars = body(*aug_loop_vars[loop_vars_slice])
     new_state = loop_vars + get_state()
-    _verify_tf_loop_vars(aug_loop_vars, new_state,
+    _verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_state,
                          basic_symbol_names + composite_symbol_names, opts)
     return new_state
@@ -853,7 +895,10 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
   # This enforces consistency across versions.
   opts['return_same_structure'] = True
 
-  aug_init_vars = init_vars + get_state()
+  if 'shape_invariants' in opts:
+    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
+        opts['shape_invariants'], aug_init_vars)
+
   final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
                                                aug_init_vars, **opts)
   final_state = final_aug_vars[state_slice]
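After the rewrite, the staged call reduces to a plain `tf.while_loop` whose `shape_invariants` is a positional sequence aligned with the loop variables, where `None` means "derive the spec from the initial value" (exactly what the `control_flow_ops.py` change below enables). A direct sketch:

    import tensorflow as tf

    i0 = tf.constant(0)
    v0 = tf.constant((0,))

    def cond(i, v):
      return i < 3

    def body(i, v):
      # `v` grows by one element per iteration, so it needs an invariant.
      return i + 1, tf.concat((v, [i]), 0)

    i_final, v_final = tf.while_loop(
        cond, body, (i0, v0),
        shape_invariants=(None, tf.TensorShape([None])))
    print(v_final)  # [0 0 1 2]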

tensorflow/python/ops/control_flow_ops.py

@@ -503,13 +503,16 @@ def _shape_invariant_to_type_spec(var, shape):
   Returns:
     A `TypeSpec` for `var`, consistent with the given shape.
   """
-  if isinstance(shape, type_spec.TypeSpec):
+  if shape is None:
+    return type_spec.type_spec_from_value(var)
+  elif isinstance(shape, type_spec.TypeSpec):
     if not shape.is_compatible_with(var):
       raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
     return shape
   elif not isinstance(shape, tensor_shape.TensorShape):
-    raise TypeError("Expected shape to be a TypeSpec or TensorShape, got %r"
-                    % shape)
+    raise TypeError(
+        "Expected shape to be a TypeSpec, TensorShape or None, got %r for"
+        " value %r" % (shape, var))
 
   if isinstance(var, ops.Tensor):
     return tensor_spec.TensorSpec(shape, var.dtype)
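To see what the branches produce: a `TensorShape` invariant ends up as a `tf.TensorSpec` with the variable's dtype, and a `None` invariant falls back to the spec of the value itself. A small compatibility check with public APIs, for illustration only:

    import tensorflow as tf

    v = tf.constant([1, 2, 3])

    # TensorShape invariant -> TensorSpec(shape, var.dtype), per the branch above.
    spec = tf.TensorSpec(tf.TensorShape([None]), v.dtype)
    print(spec.is_compatible_with(v))                         # True: [3] matches [None]
    print(tf.TensorSpec([2], v.dtype).is_compatible_with(v))  # False: wrong size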

tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt

@@ -10,6 +10,6 @@ tf_module {
   }
   member_method {
     name: "set_loop_options"
-    argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
+    argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
   }
 }

tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt

@@ -10,6 +10,6 @@ tf_module {
   }
   member_method {
     name: "set_loop_options"
-    argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
+    argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
   }
 }