Enable support for loop directives, include shape_invariants.
PiperOrigin-RevId: 285966483 Change-Id: I3eae0b134cd2e954bfa0ac31e6a7411b3a5bb7df
This commit is contained in:
parent
1768c8f2fa
commit
e42b937e1a
tensorflow
python
tools/api/golden
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import gast
|
||||
|
||||
from tensorflow.python.autograph.core import converter
|
||||
from tensorflow.python.autograph.lang import directives
|
||||
from tensorflow.python.autograph.pyct import anno
|
||||
from tensorflow.python.autograph.pyct import ast_util
|
||||
from tensorflow.python.autograph.pyct import parser
|
||||
@ -151,6 +152,20 @@ class ControlFlowTransformer(converter.Base):
|
||||
|
||||
return node
|
||||
|
||||
def _create_loop_options(self, node):
|
||||
if not anno.hasanno(node, anno.Basic.DIRECTIVES):
|
||||
return gast.Dict([], [])
|
||||
|
||||
loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
|
||||
if directives.set_loop_options not in loop_directives:
|
||||
return gast.Dict([], [])
|
||||
|
||||
opts_dict = loop_directives[directives.set_loop_options]
|
||||
str_keys, values = zip(*opts_dict.items())
|
||||
keys = [gast.Str(s) for s in str_keys]
|
||||
values = list(values) # ast and gast don't play well with tuples.
|
||||
return gast.Dict(keys, values)
|
||||
|
||||
def _create_undefined_assigns(self, undefined_symbols):
|
||||
assignments = []
|
||||
for s in undefined_symbols:
|
||||
@ -383,8 +398,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
composite_symbol_names = tuple(
|
||||
gast.Str(str(symbol)) for symbol in composite_loop_vars)
|
||||
|
||||
# TODO(b/140125096): Populate.
|
||||
opts = gast.Dict([], [])
|
||||
opts = self._create_loop_options(node)
|
||||
|
||||
# TODO(mdan): Use a single template.
|
||||
# If the body and test functions took a single tuple for loop_vars, instead
|
||||
@ -507,8 +521,7 @@ class ControlFlowTransformer(converter.Base):
|
||||
composite_symbol_names = tuple(
|
||||
gast.Str(str(symbol)) for symbol in composite_loop_vars)
|
||||
|
||||
# TODO(b/140125096): Populate.
|
||||
opts = gast.Dict([], [])
|
||||
opts = self._create_loop_options(node)
|
||||
|
||||
# TODO(mdan): Use a single template.
|
||||
# If the body and test functions took a single tuple for loop_vars, instead
|
||||
|
@ -26,7 +26,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.tools.docs.doc_controls import do_not_generate_docs
|
||||
|
||||
UNSPECIFIED = object()
|
||||
|
||||
@ -47,37 +46,53 @@ def set_element_type(entity, dtype, shape=UNSPECIFIED):
|
||||
del shape
|
||||
|
||||
|
||||
# TODO(b/140125096): Implement.
|
||||
@do_not_generate_docs
|
||||
@tf_export('autograph.experimental.set_loop_options')
|
||||
def set_loop_options(
|
||||
parallel_iterations=UNSPECIFIED,
|
||||
back_prop=UNSPECIFIED,
|
||||
swap_memory=UNSPECIFIED,
|
||||
maximum_iterations=UNSPECIFIED):
|
||||
maximum_iterations=UNSPECIFIED,
|
||||
shape_invariants=UNSPECIFIED):
|
||||
"""Specifies additional arguments to be passed to the enclosing while_loop.
|
||||
|
||||
The parameters apply to and only to the immediately enclosing loop. It only
|
||||
has effect if the loop is staged as a TF while_loop; otherwise the parameters
|
||||
have no effect.
|
||||
|
||||
Usage example:
|
||||
Usage:
|
||||
|
||||
@tf.function(autograph=True)
|
||||
def dynamic_rnn(..., parallel_iterations=32):
|
||||
num_steps = ...
|
||||
for t in tf.range(num_steps):
|
||||
tf.autograph.experimental.set_loop_options(
|
||||
parallel_iterations=parallel_iterations)
|
||||
...
|
||||
>>> @tf.function(autograph=True)
|
||||
... def f():
|
||||
... n = 0
|
||||
... for i in tf.range(10):
|
||||
... tf.autograph.experimental.set_loop_options(maximum_iterations=3)
|
||||
... n += 1
|
||||
... return n
|
||||
|
||||
>>> @tf.function(autograph=True)
|
||||
... def f():
|
||||
... v = tf.constant((0,))
|
||||
... for i in tf.range(3):
|
||||
... tf.autograph.experimental.set_loop_options(
|
||||
... shape_invariants=[(v, tf.TensorShape([None]))]
|
||||
... )
|
||||
... v = tf.concat((v, [i]), 0)
|
||||
... return v
|
||||
|
||||
Also see tf.while_loop.
|
||||
|
||||
Args:
|
||||
parallel_iterations: See tf.while_loop.
|
||||
back_prop: See tf.while_loop.
|
||||
swap_memory: See tf.while_loop.
|
||||
maximum_iterations: See tf.while_loop.
|
||||
parallel_iterations: The maximum number of iterations allowed to run in
|
||||
parallel at any given time. Note that this does not guarantee parallel
|
||||
execution.
|
||||
swap_memory: Whether to store intermediate values needed for
|
||||
gradients on the CPU instead of GPU.
|
||||
maximum_iterations: Allows limiting the total number of iterations executed
|
||||
by the loop.
|
||||
shape_invariants: Allows controlling the argument with the same name passed
|
||||
to tf.while_loop. Unlike tf.while_loop, this is a list of
|
||||
`(tensor, shape)` pairs.
|
||||
"""
|
||||
del parallel_iterations
|
||||
del back_prop
|
||||
del swap_memory
|
||||
del maximum_iterations
|
||||
del shape_invariants
|
||||
|
@ -125,68 +125,91 @@ def _is_subshape(left, right):
|
||||
return True
|
||||
|
||||
|
||||
def _verify_single_loop_var(name, check_shape, init_loop_var, first_iter_var):
|
||||
"""Verifies whether init_loop_var and first_iter_var are consistent."""
|
||||
if isinstance(init_loop_var, (bool, int, float, str)):
|
||||
init_loop_var = ops.convert_to_tensor_v2(init_loop_var)
|
||||
# TODO(mdan): Remove these verifications once TF ops can properly report names.
|
||||
def _verify_single_loop_var(
|
||||
name, check_shape, init, entry, exit_, shape_invariant):
|
||||
"""Verifies whether the initial, entry and exit values are consistent."""
|
||||
if isinstance(init, (bool, int, float, str, np.ndarray)):
|
||||
init = ops.convert_to_tensor_v2(init)
|
||||
if isinstance(entry, (bool, int, float, str, np.ndarray)):
|
||||
entry = ops.convert_to_tensor_v2(entry)
|
||||
if isinstance(exit_, (bool, int, float, str)):
|
||||
exit_ = ops.convert_to_tensor_v2(exit_)
|
||||
|
||||
if isinstance(first_iter_var, (bool, int, float, str)):
|
||||
first_iter_var = ops.convert_to_tensor_v2(first_iter_var)
|
||||
|
||||
if (not tensor_util.is_tensor(init_loop_var) or
|
||||
not tensor_util.is_tensor(first_iter_var)):
|
||||
if (not tensor_util.is_tensor(entry) or
|
||||
not tensor_util.is_tensor(exit_)):
|
||||
return
|
||||
|
||||
# TODO(mdan): Properly account for CompositeTensors.
|
||||
if (not hasattr(init_loop_var, 'dtype') or
|
||||
not hasattr(first_iter_var, 'dtype')):
|
||||
if (not hasattr(entry, 'dtype') or
|
||||
not hasattr(exit_, 'dtype')):
|
||||
return
|
||||
if (not hasattr(init_loop_var, 'shape') or
|
||||
not hasattr(first_iter_var, 'shape')):
|
||||
if (not hasattr(entry, 'shape') or
|
||||
not hasattr(exit_, 'shape')):
|
||||
return
|
||||
|
||||
if init_loop_var.dtype != first_iter_var.dtype:
|
||||
if entry.dtype != exit_.dtype:
|
||||
raise TypeError(
|
||||
'"{}" has dtype {} before the loop, but dtype {} after one'
|
||||
' iteration. TensorFlow control flow requires it stays the'
|
||||
' same.'.format(
|
||||
name,
|
||||
init_loop_var.dtype.name,
|
||||
first_iter_var.dtype.name,
|
||||
entry.dtype.name,
|
||||
exit_.dtype.name,
|
||||
))
|
||||
|
||||
if check_shape:
|
||||
init_shape = init_loop_var.shape
|
||||
first_iter_shape = first_iter_var.shape
|
||||
# TODO(b/135183013): Update needed once we support shape_invariants.
|
||||
if not _is_subshape(first_iter_shape, init_shape):
|
||||
raise ValueError(
|
||||
'"{}" has shape {} before the loop, but shape {} after one'
|
||||
' iteration. TensorFlow control flow requires it stays the'
|
||||
' same or be more specific.'.format(name, init_shape,
|
||||
first_iter_shape))
|
||||
exit_shape = exit_.shape
|
||||
if shape_invariant is None:
|
||||
entry_shape = entry.shape
|
||||
if not _is_subshape(exit_shape, entry_shape):
|
||||
raise ValueError(
|
||||
'"{}" has shape {} before the loop, but shape {} after one'
|
||||
' iteration. Use tf.autograph.experimental.set_loop_options to set'
|
||||
' shape invariants.'.format(name, entry_shape, exit_shape))
|
||||
else:
|
||||
init_shape = init.shape
|
||||
if not _is_subshape(init_shape, shape_invariant):
|
||||
raise ValueError(
|
||||
'"{}" has shape {} before the loop, which does not conform with'
|
||||
' the shape invariant {}.'.format(name, init_shape,
|
||||
shape_invariant))
|
||||
if not _is_subshape(exit_shape, shape_invariant):
|
||||
raise ValueError(
|
||||
'"{}" has shape {} after the loop, which does not conform with'
|
||||
' the shape invariant {}.'.format(
|
||||
name, exit_shape, shape_invariant))
|
||||
|
||||
|
||||
def _verify_tf_loop_vars(init_loop_vars,
|
||||
first_iter_vars,
|
||||
def _verify_tf_loop_vars(init_vars,
|
||||
iter_entry_vars,
|
||||
iter_exit_vars,
|
||||
symbol_names,
|
||||
opts,
|
||||
check_shapes=True):
|
||||
"""Verifies loop variables for consistency."""
|
||||
# TODO(b/140125096): Use this.
|
||||
del opts
|
||||
if check_shapes and 'shape_invariants' in opts:
|
||||
shape_invariants = opts['shape_invariants']
|
||||
else:
|
||||
shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)
|
||||
|
||||
named_vars = zip(symbol_names, init_loop_vars, first_iter_vars)
|
||||
for name, init_loop_var, first_iter_var in named_vars:
|
||||
named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
|
||||
shape_invariants)
|
||||
for name, init, entry, exit_, invariant in named_vars:
|
||||
try:
|
||||
nest.assert_same_structure(
|
||||
init_loop_var, first_iter_var, expand_composites=True)
|
||||
nest.assert_same_structure(entry, exit_, expand_composites=True)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError('"{}" does not have the same nested structure after one'
|
||||
' iteration.\n\n{}'.format(name, e))
|
||||
if invariant is not None:
|
||||
try:
|
||||
nest.assert_same_structure(init, invariant, expand_composites=False)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError('"{}" does not have the same nested structure as its'
|
||||
' corresponding shape invariant.\n\n{}'.format(name, e))
|
||||
|
||||
nest.map_structure(
|
||||
functools.partial(_verify_single_loop_var, name, check_shapes),
|
||||
init_loop_var, first_iter_var)
|
||||
functools.partial(_verify_single_loop_var, name, check_shapes), init,
|
||||
entry, exit_, invariant)
|
||||
|
||||
|
||||
def _verify_single_cond_var(name, body_var, orelse_var):
|
||||
@ -425,6 +448,8 @@ def _tf_ragged_for_stmt(iter_,
|
||||
else:
|
||||
n = iter_.row_lengths()[0]
|
||||
|
||||
opts['maximum_iterations'] = n
|
||||
|
||||
def while_body(iterate_index, *loop_vars):
|
||||
"""Main loop body."""
|
||||
iterate = iter_[iterate_index]
|
||||
@ -566,7 +591,7 @@ def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
|
||||
# Note: this verification duplicates that perfrmed in tf_while_stmt,
|
||||
# but needs to be done earlier to prevent the tf.cond inside while_body
|
||||
# from blowing up first.
|
||||
_verify_tf_loop_vars(loop_vars, new_vars,
|
||||
_verify_tf_loop_vars(init_vars, loop_vars, new_vars,
|
||||
basic_symbol_names + composite_symbol_names, opts)
|
||||
return new_vars
|
||||
|
||||
@ -653,20 +678,26 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
|
||||
|
||||
# TODO(mdan): Simplify this - following it is extremely difficult.
|
||||
|
||||
init_state = get_state()
|
||||
aug_init_vars = init_vars, init_state
|
||||
|
||||
def scan_body(aug_vars, iterate):
|
||||
"""The main loop body wrapper. Only calculates the stop condition."""
|
||||
loop_vars, state = aug_vars
|
||||
|
||||
def true_fn():
|
||||
"""Main path - stop condition is not set."""
|
||||
set_state(state)
|
||||
outputs = body(iterate, *loop_vars)
|
||||
new_vars = body(iterate, *loop_vars)
|
||||
new_state = get_state()
|
||||
_verify_tf_loop_vars(
|
||||
init_vars + init_state,
|
||||
loop_vars + state,
|
||||
outputs + state,
|
||||
new_vars + new_state,
|
||||
basic_symbol_names + composite_symbol_names,
|
||||
opts,
|
||||
check_shapes=False)
|
||||
return outputs, get_state()
|
||||
return new_vars, new_state
|
||||
|
||||
extra_cond = extra_test(*loop_vars)
|
||||
new_vars, new_state = control_flow_ops.cond(
|
||||
@ -690,11 +721,9 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
|
||||
del extra_cond
|
||||
return output_aug_vars, output_state
|
||||
|
||||
init_state = get_state()
|
||||
aug_vars = init_vars, init_state
|
||||
ds = _general_purpose_scan(ds, aug_vars, scan_body)
|
||||
ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
|
||||
ds = ds.apply(take_while_ops.take_while(take_while_predicate))
|
||||
final_aug_vars = ds.reduce(aug_vars, reduce_body)
|
||||
final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
|
||||
final_vars, final_state = final_aug_vars
|
||||
set_state(final_state)
|
||||
return final_vars
|
||||
@ -741,6 +770,7 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
|
||||
new_state = get_state()
|
||||
|
||||
_verify_tf_loop_vars(
|
||||
init_vars + init_state,
|
||||
loop_vars + state,
|
||||
new_vars + new_state,
|
||||
symbol_names,
|
||||
@ -824,11 +854,23 @@ def while_stmt(test,
|
||||
return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
|
||||
|
||||
|
||||
def _shape_invariants_mapping_to_positional_list(mapping, keys):
|
||||
# The keys are not expected to be hashable.
|
||||
mapping = {id(k): (k, v) for k, v in mapping}
|
||||
result = []
|
||||
for k in keys:
|
||||
map_key, map_val = mapping.get(id(k), (None, None))
|
||||
result.append(map_val if map_key is k else None)
|
||||
return tuple(result)
|
||||
|
||||
|
||||
def _tf_while_stmt(test, body, get_state, set_state, init_vars,
|
||||
basic_symbol_names, composite_symbol_names, opts):
|
||||
"""Overload of while_stmt that stages a TF while_stmt."""
|
||||
_disallow_undefs_into_loop(*init_vars)
|
||||
|
||||
aug_init_vars = init_vars + get_state()
|
||||
|
||||
# TODO(mdan): Simplify this.
|
||||
loop_vars_slice = slice(len(init_vars))
|
||||
state_slice = slice(len(init_vars), None)
|
||||
@ -844,7 +886,7 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
|
||||
set_state(state)
|
||||
loop_vars = body(*aug_loop_vars[loop_vars_slice])
|
||||
new_state = loop_vars + get_state()
|
||||
_verify_tf_loop_vars(aug_loop_vars, new_state,
|
||||
_verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_state,
|
||||
basic_symbol_names + composite_symbol_names, opts)
|
||||
|
||||
return new_state
|
||||
@ -853,7 +895,10 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
|
||||
# This enforces consistency across versions.
|
||||
opts['return_same_structure'] = True
|
||||
|
||||
aug_init_vars = init_vars + get_state()
|
||||
if 'shape_invariants' in opts:
|
||||
opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
|
||||
opts['shape_invariants'], aug_init_vars)
|
||||
|
||||
final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
|
||||
aug_init_vars, **opts)
|
||||
final_state = final_aug_vars[state_slice]
|
||||
|
@ -503,13 +503,16 @@ def _shape_invariant_to_type_spec(var, shape):
|
||||
Returns:
|
||||
A `TypeSpec` for `var`, consistent with the given shape.
|
||||
"""
|
||||
if isinstance(shape, type_spec.TypeSpec):
|
||||
if shape is None:
|
||||
return type_spec.type_spec_from_value(var)
|
||||
elif isinstance(shape, type_spec.TypeSpec):
|
||||
if not shape.is_compatible_with(var):
|
||||
raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
|
||||
return shape
|
||||
elif not isinstance(shape, tensor_shape.TensorShape):
|
||||
raise TypeError("Expected shape to be a TypeSpec or TensorShape, got %r"
|
||||
% shape)
|
||||
raise TypeError(
|
||||
"Expected shape to be a TypeSpec, TensorShape or None, got %r for"
|
||||
" value %r" % (shape, var))
|
||||
|
||||
if isinstance(var, ops.Tensor):
|
||||
return tensor_spec.TensorSpec(shape, var.dtype)
|
||||
|
@ -10,6 +10,6 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "set_loop_options"
|
||||
argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
||||
argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,6 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "set_loop_options"
|
||||
argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
||||
argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user