diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py index 5bf488cd209..f2812f3ff1c 100644 --- a/tensorflow/python/autograph/converters/control_flow.py +++ b/tensorflow/python/autograph/converters/control_flow.py @@ -21,6 +21,7 @@ from __future__ import print_function import gast from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.lang import directives from tensorflow.python.autograph.pyct import anno from tensorflow.python.autograph.pyct import ast_util from tensorflow.python.autograph.pyct import parser @@ -151,6 +152,20 @@ class ControlFlowTransformer(converter.Base): return node + def _create_loop_options(self, node): + if not anno.hasanno(node, anno.Basic.DIRECTIVES): + return gast.Dict([], []) + + loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES) + if directives.set_loop_options not in loop_directives: + return gast.Dict([], []) + + opts_dict = loop_directives[directives.set_loop_options] + str_keys, values = zip(*opts_dict.items()) + keys = [gast.Str(s) for s in str_keys] + values = list(values) # ast and gast don't play well with tuples. + return gast.Dict(keys, values) + def _create_undefined_assigns(self, undefined_symbols): assignments = [] for s in undefined_symbols: @@ -383,8 +398,7 @@ class ControlFlowTransformer(converter.Base): composite_symbol_names = tuple( gast.Str(str(symbol)) for symbol in composite_loop_vars) - # TODO(b/140125096): Populate. - opts = gast.Dict([], []) + opts = self._create_loop_options(node) # TODO(mdan): Use a single template. # If the body and test functions took a single tuple for loop_vars, instead @@ -507,8 +521,7 @@ class ControlFlowTransformer(converter.Base): composite_symbol_names = tuple( gast.Str(str(symbol)) for symbol in composite_loop_vars) - # TODO(b/140125096): Populate. - opts = gast.Dict([], []) + opts = self._create_loop_options(node) # TODO(mdan): Use a single template. # If the body and test functions took a single tuple for loop_vars, instead diff --git a/tensorflow/python/autograph/lang/directives.py b/tensorflow/python/autograph/lang/directives.py index 5373a7cd187..26b5ffa97ac 100644 --- a/tensorflow/python/autograph/lang/directives.py +++ b/tensorflow/python/autograph/lang/directives.py @@ -26,7 +26,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.util.tf_export import tf_export -from tensorflow.tools.docs.doc_controls import do_not_generate_docs UNSPECIFIED = object() @@ -47,37 +46,53 @@ def set_element_type(entity, dtype, shape=UNSPECIFIED): del shape -# TODO(b/140125096): Implement. -@do_not_generate_docs @tf_export('autograph.experimental.set_loop_options') def set_loop_options( parallel_iterations=UNSPECIFIED, - back_prop=UNSPECIFIED, swap_memory=UNSPECIFIED, - maximum_iterations=UNSPECIFIED): + maximum_iterations=UNSPECIFIED, + shape_invariants=UNSPECIFIED): """Specifies additional arguments to be passed to the enclosing while_loop. The parameters apply to and only to the immediately enclosing loop. It only has effect if the loop is staged as a TF while_loop; otherwise the parameters have no effect. - Usage example: + Usage: - @tf.function(autograph=True) - def dynamic_rnn(..., parallel_iterations=32): - num_steps = ... - for t in tf.range(num_steps): - tf.autograph.experimental.set_loop_options( - parallel_iterations=parallel_iterations) - ... + >>> @tf.function(autograph=True) + ... def f(): + ... n = 0 + ... for i in tf.range(10): + ... tf.autograph.experimental.set_loop_options(maximum_iterations=3) + ... n += 1 + ... return n + + >>> @tf.function(autograph=True) + ... def f(): + ... v = tf.constant((0,)) + ... for i in tf.range(3): + ... tf.autograph.experimental.set_loop_options( + ... shape_invariants=[(v, tf.TensorShape([None]))] + ... ) + ... v = tf.concat((v, [i]), 0) + ... return v + + Also see tf.while_loop. Args: - parallel_iterations: See tf.while_loop. - back_prop: See tf.while_loop. - swap_memory: See tf.while_loop. - maximum_iterations: See tf.while_loop. + parallel_iterations: The maximum number of iterations allowed to run in + parallel at any given time. Note that this does not guarantee parallel + execution. + swap_memory: Whether to store intermediate values needed for + gradients on the CPU instead of GPU. + maximum_iterations: Allows limiting the total number of iterations executed + by the loop. + shape_invariants: Allows controlling the argument with the same name passed + to tf.while_loop. Unlike tf.while_loop, this is a list of + `(tensor, shape)` pairs. """ del parallel_iterations - del back_prop del swap_memory del maximum_iterations + del shape_invariants diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py index c862379e1d0..f9b2ff9338e 100644 --- a/tensorflow/python/autograph/operators/control_flow.py +++ b/tensorflow/python/autograph/operators/control_flow.py @@ -125,68 +125,91 @@ def _is_subshape(left, right): return True -def _verify_single_loop_var(name, check_shape, init_loop_var, first_iter_var): - """Verifies whether init_loop_var and first_iter_var are consistent.""" - if isinstance(init_loop_var, (bool, int, float, str)): - init_loop_var = ops.convert_to_tensor_v2(init_loop_var) +# TODO(mdan): Remove these verifications once TF ops can properly report names. +def _verify_single_loop_var( + name, check_shape, init, entry, exit_, shape_invariant): + """Verifies whether the initial, entry and exit values are consistent.""" + if isinstance(init, (bool, int, float, str, np.ndarray)): + init = ops.convert_to_tensor_v2(init) + if isinstance(entry, (bool, int, float, str, np.ndarray)): + entry = ops.convert_to_tensor_v2(entry) + if isinstance(exit_, (bool, int, float, str)): + exit_ = ops.convert_to_tensor_v2(exit_) - if isinstance(first_iter_var, (bool, int, float, str)): - first_iter_var = ops.convert_to_tensor_v2(first_iter_var) - - if (not tensor_util.is_tensor(init_loop_var) or - not tensor_util.is_tensor(first_iter_var)): + if (not tensor_util.is_tensor(entry) or + not tensor_util.is_tensor(exit_)): return # TODO(mdan): Properly account for CompositeTensors. - if (not hasattr(init_loop_var, 'dtype') or - not hasattr(first_iter_var, 'dtype')): + if (not hasattr(entry, 'dtype') or + not hasattr(exit_, 'dtype')): return - if (not hasattr(init_loop_var, 'shape') or - not hasattr(first_iter_var, 'shape')): + if (not hasattr(entry, 'shape') or + not hasattr(exit_, 'shape')): return - if init_loop_var.dtype != first_iter_var.dtype: + if entry.dtype != exit_.dtype: raise TypeError( '"{}" has dtype {} before the loop, but dtype {} after one' ' iteration. TensorFlow control flow requires it stays the' ' same.'.format( name, - init_loop_var.dtype.name, - first_iter_var.dtype.name, + entry.dtype.name, + exit_.dtype.name, )) - if check_shape: - init_shape = init_loop_var.shape - first_iter_shape = first_iter_var.shape - # TODO(b/135183013): Update needed once we support shape_invariants. - if not _is_subshape(first_iter_shape, init_shape): - raise ValueError( - '"{}" has shape {} before the loop, but shape {} after one' - ' iteration. TensorFlow control flow requires it stays the' - ' same or be more specific.'.format(name, init_shape, - first_iter_shape)) + exit_shape = exit_.shape + if shape_invariant is None: + entry_shape = entry.shape + if not _is_subshape(exit_shape, entry_shape): + raise ValueError( + '"{}" has shape {} before the loop, but shape {} after one' + ' iteration. Use tf.autograph.experimental.set_loop_options to set' + ' shape invariants.'.format(name, entry_shape, exit_shape)) + else: + init_shape = init.shape + if not _is_subshape(init_shape, shape_invariant): + raise ValueError( + '"{}" has shape {} before the loop, which does not conform with' + ' the shape invariant {}.'.format(name, init_shape, + shape_invariant)) + if not _is_subshape(exit_shape, shape_invariant): + raise ValueError( + '"{}" has shape {} after the loop, which does not conform with' + ' the shape invariant {}.'.format( + name, exit_shape, shape_invariant)) -def _verify_tf_loop_vars(init_loop_vars, - first_iter_vars, +def _verify_tf_loop_vars(init_vars, + iter_entry_vars, + iter_exit_vars, symbol_names, opts, check_shapes=True): """Verifies loop variables for consistency.""" - # TODO(b/140125096): Use this. - del opts + if check_shapes and 'shape_invariants' in opts: + shape_invariants = opts['shape_invariants'] + else: + shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars) - named_vars = zip(symbol_names, init_loop_vars, first_iter_vars) - for name, init_loop_var, first_iter_var in named_vars: + named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars, + shape_invariants) + for name, init, entry, exit_, invariant in named_vars: try: - nest.assert_same_structure( - init_loop_var, first_iter_var, expand_composites=True) + nest.assert_same_structure(entry, exit_, expand_composites=True) except (ValueError, TypeError) as e: raise TypeError('"{}" does not have the same nested structure after one' ' iteration.\n\n{}'.format(name, e)) + if invariant is not None: + try: + nest.assert_same_structure(init, invariant, expand_composites=False) + except (ValueError, TypeError) as e: + raise TypeError('"{}" does not have the same nested structure as its' + ' corresponding shape invariant.\n\n{}'.format(name, e)) + nest.map_structure( - functools.partial(_verify_single_loop_var, name, check_shapes), - init_loop_var, first_iter_var) + functools.partial(_verify_single_loop_var, name, check_shapes), init, + entry, exit_, invariant) def _verify_single_cond_var(name, body_var, orelse_var): @@ -425,6 +448,8 @@ def _tf_ragged_for_stmt(iter_, else: n = iter_.row_lengths()[0] + opts['maximum_iterations'] = n + def while_body(iterate_index, *loop_vars): """Main loop body.""" iterate = iter_[iterate_index] @@ -566,7 +591,7 @@ def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state, # Note: this verification duplicates that perfrmed in tf_while_stmt, # but needs to be done earlier to prevent the tf.cond inside while_body # from blowing up first. - _verify_tf_loop_vars(loop_vars, new_vars, + _verify_tf_loop_vars(init_vars, loop_vars, new_vars, basic_symbol_names + composite_symbol_names, opts) return new_vars @@ -653,20 +678,26 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state, # TODO(mdan): Simplify this - following it is extremely difficult. + init_state = get_state() + aug_init_vars = init_vars, init_state + def scan_body(aug_vars, iterate): """The main loop body wrapper. Only calculates the stop condition.""" loop_vars, state = aug_vars def true_fn(): + """Main path - stop condition is not set.""" set_state(state) - outputs = body(iterate, *loop_vars) + new_vars = body(iterate, *loop_vars) + new_state = get_state() _verify_tf_loop_vars( + init_vars + init_state, loop_vars + state, - outputs + state, + new_vars + new_state, basic_symbol_names + composite_symbol_names, opts, check_shapes=False) - return outputs, get_state() + return new_vars, new_state extra_cond = extra_test(*loop_vars) new_vars, new_state = control_flow_ops.cond( @@ -690,11 +721,9 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state, del extra_cond return output_aug_vars, output_state - init_state = get_state() - aug_vars = init_vars, init_state - ds = _general_purpose_scan(ds, aug_vars, scan_body) + ds = _general_purpose_scan(ds, aug_init_vars, scan_body) ds = ds.apply(take_while_ops.take_while(take_while_predicate)) - final_aug_vars = ds.reduce(aug_vars, reduce_body) + final_aug_vars = ds.reduce(aug_init_vars, reduce_body) final_vars, final_state = final_aug_vars set_state(final_state) return final_vars @@ -741,6 +770,7 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars, new_state = get_state() _verify_tf_loop_vars( + init_vars + init_state, loop_vars + state, new_vars + new_state, symbol_names, @@ -824,11 +854,23 @@ def while_stmt(test, return _py_while_stmt(test, body, get_state, set_state, init_vars, opts) +def _shape_invariants_mapping_to_positional_list(mapping, keys): + # The keys are not expected to be hashable. + mapping = {id(k): (k, v) for k, v in mapping} + result = [] + for k in keys: + map_key, map_val = mapping.get(id(k), (None, None)) + result.append(map_val if map_key is k else None) + return tuple(result) + + def _tf_while_stmt(test, body, get_state, set_state, init_vars, basic_symbol_names, composite_symbol_names, opts): """Overload of while_stmt that stages a TF while_stmt.""" _disallow_undefs_into_loop(*init_vars) + aug_init_vars = init_vars + get_state() + # TODO(mdan): Simplify this. loop_vars_slice = slice(len(init_vars)) state_slice = slice(len(init_vars), None) @@ -844,7 +886,7 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars, set_state(state) loop_vars = body(*aug_loop_vars[loop_vars_slice]) new_state = loop_vars + get_state() - _verify_tf_loop_vars(aug_loop_vars, new_state, + _verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_state, basic_symbol_names + composite_symbol_names, opts) return new_state @@ -853,7 +895,10 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars, # This enforces consistency across versions. opts['return_same_structure'] = True - aug_init_vars = init_vars + get_state() + if 'shape_invariants' in opts: + opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list( + opts['shape_invariants'], aug_init_vars) + final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body, aug_init_vars, **opts) final_state = final_aug_vars[state_slice] diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index c04c55457b1..d33a9ad5b20 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -503,13 +503,16 @@ def _shape_invariant_to_type_spec(var, shape): Returns: A `TypeSpec` for `var`, consistent with the given shape. """ - if isinstance(shape, type_spec.TypeSpec): + if shape is None: + return type_spec.type_spec_from_value(var) + elif isinstance(shape, type_spec.TypeSpec): if not shape.is_compatible_with(var): raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) return shape elif not isinstance(shape, tensor_shape.TensorShape): - raise TypeError("Expected shape to be a TypeSpec or TensorShape, got %r" - % shape) + raise TypeError( + "Expected shape to be a TypeSpec, TensorShape or None, got %r for" + " value %r" % (shape, var)) if isinstance(var, ops.Tensor): return tensor_spec.TensorSpec(shape, var.dtype) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt index 1454a2d9567..5450d6448c8 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.autograph.experimental.pbtxt @@ -10,6 +10,6 @@ tf_module { } member_method { name: "set_loop_options" - argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], " + argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt index 1454a2d9567..5450d6448c8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.autograph.experimental.pbtxt @@ -10,6 +10,6 @@ tf_module { } member_method { name: "set_loop_options" - argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], " + argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], " } }