pfor: Have pfor fall back to while_loop by default in cases where a conversion is not supported.

PiperOrigin-RevId: 306319253
Change-Id: I9bc1581c4b854f41df6270ea538ae6d88a704d45
A. Unique TensorFlower 2020-04-13 15:17:16 -07:00 committed by TensorFlower Gardener
parent 8eaea75c6d
commit 867673d63b
8 changed files with 102 additions and 73 deletions
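A minimal usage sketch of the new default behavior, assuming tf.math.top_k stands in for an op without a registered pfor converter (mirroring the nn.top_k test updated below); names and shapes are illustrative only:

import tensorflow as tf

def per_row(x):
  # If an op in here has no registered pfor converter, the new default
  # (fallback_to_while_loop=True) dispatches just that op via tf.while_loop
  # instead of raising a ValueError; the rest of the body stays vectorized.
  values, _ = tf.math.top_k(x)
  return values

elems = tf.random.uniform([8, 5])
out = tf.vectorized_map(per_row, elems)
# Opting out restores the old behavior of raising when a converter is missing:
# tf.vectorized_map(per_row, elems, fallback_to_while_loop=False)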

View File

@@ -1593,27 +1593,6 @@ class JacobianTest(test.TestCase):
with self.assertRaisesRegexp(RuntimeError, 'persistent'):
g.jacobian(y, x, experimental_use_pfor=False)
@test_util.run_v1_only('b/120545219')
def testPforException(self):
var = variables.Variable([1.])
@custom_gradient.custom_gradient
def op(x):
def grad(_):
# Note that we perform a stateful operation here that will not be
# compatible with parallel for construct.
with ops.control_dependencies(
[var.assign(random_ops.random_uniform([1]))]):
return constant_op.constant(1.)
return x, grad
with backprop.GradientTape() as g:
x = constant_op.constant([1., 2.])
g.watch(x)
y = op(x)
with self.assertRaisesRegexp(ValueError, 'No converter'):
g.jacobian(y, x, experimental_use_pfor=True)
@test_util.run_v1_only('b/120545219')
def test_parallel_iterations(self):
with backprop.GradientTape(persistent=True) as g:
@@ -1723,26 +1702,6 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
g.batch_jacobian(y, x)
def testPforException(self):
var = variables.Variable([1.])
@custom_gradient.custom_gradient
def op(x):
def grad(_):
# Note that we perform a stateful operation here that will not be
# compatible with parallel for construct.
with ops.control_dependencies(
[var.assign(random_ops.random_uniform([1]))]):
return constant_op.constant(1.)
return x, grad
with backprop.GradientTape() as g:
x = constant_op.constant([[1.], [2.]])
g.watch(x)
y = op(x)
with self.assertRaisesRegexp(ValueError, 'No converter'):
g.batch_jacobian(y, x, experimental_use_pfor=True)
def test_parallel_iterations(self):
with backprop.GradientTape(persistent=True) as g:
x = constant_op.constant([[1., 2], [3, 4]])

View File

@@ -210,7 +210,7 @@ class ArrayTest(PForTestCase):
return array_ops.tile(x1, [i, 1])
with self.assertRaisesRegexp(ValueError, "expected to be loop invariant"):
pfor_control_flow_ops.pfor(loop_fn, 2)
pfor_control_flow_ops.pfor(loop_fn, 2, fallback_to_while_loop=False)
def test_pack(self):
x = random_ops.random_uniform([3, 2, 3])
@@ -447,6 +447,20 @@ class ArrayTest(PForTestCase):
self._test_loop_fn(loop_fn, 3)
def test_strided_slice_loop_variant(self):
x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2])
def loop_fn(i):
x_i = array_ops.gather(x, i)
return x_i[i:i+1, ...]
# Test that the fallback to a while loop for a ConversionNotImplementedError
# is handled.
self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
# Without fallback, ValueError is thrown.
with self.assertRaisesRegexp(ValueError, "expected to be loop invariant"):
self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
def test_depth_to_space(self):
x = random_ops.random_uniform([2, 3, 2, 2, 12])

View File

@@ -132,7 +132,7 @@ def _is_under_xla_context():
return False
def pfor(loop_fn, iters, parallel_iterations=None):
def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
"""Equivalent to running `loop_fn` `iters` times and stacking the outputs.
`pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
@@ -167,6 +167,8 @@ def pfor(loop_fn, iters, parallel_iterations=None):
to something other than None, `loop_fn` may be called more than once
during graph construction. So it may need to avoid mutating global state.
iters: Number of iterations for which to run loop_fn.
fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
falls back to using a tf.while_loop to dispatch the iterations.
parallel_iterations: A knob to control how many iterations are vectorized
and dispatched in parallel. The default value of None corresponds to
vectorizing all the iterations. If `parallel_iterations` is smaller than
@@ -180,7 +182,10 @@ def pfor(loop_fn, iters, parallel_iterations=None):
ValueError: If parallel_iterations is not None and not an integer > 1.
"""
def f():
return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
return _pfor_impl(loop_fn,
iters,
fallback_to_while_loop=fallback_to_while_loop,
parallel_iterations=parallel_iterations)
# Note that we wrap into a tf.function if in eager execution mode or under
# XLA compilation. The latter is so that we don't compile operations like
# tf.placeholder that are created by the loop body.
@@ -219,7 +224,11 @@ def _loop_fn_has_config(loop_fn):
return PFOR_CONFIG_ARG in argspec.args
def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
def _pfor_impl(loop_fn,
iters,
fallback_to_while_loop,
parallel_iterations=None,
pfor_config=None):
"""Implementation of pfor."""
assert not context.executing_eagerly()
loop_fn_has_config = _loop_fn_has_config(loop_fn)
@@ -263,7 +272,9 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
parallel_iterations = None
if parallel_iterations is None:
with ops.name_scope("pfor"):
converter = PFor(loop_var, iters, new_ops, pfor_config=pfor_config)
converter = PFor(loop_var, iters, new_ops,
fallback_to_while_loop=fallback_to_while_loop,
pfor_config=pfor_config)
outputs = []
for loop_fn_output in nest.flatten(loop_fn_outputs):
outputs.append(converter.convert(loop_fn_output))
@@ -278,6 +289,7 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
# a tf.function and extract the graph from there to vectorize it.
with ops.name_scope("pfor_untiled"):
converter = PFor(loop_var, num_remaining_iterations, new_ops,
fallback_to_while_loop=fallback_to_while_loop,
pfor_config=pfor_config)
remaining_outputs = []
flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
@@ -298,7 +310,10 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
return nest.flatten(loop_fn(i + offset))
return _pfor_impl(
tiled_loop_fn, parallel_iterations, pfor_config=pfor_config)
tiled_loop_fn,
parallel_iterations,
fallback_to_while_loop=fallback_to_while_loop,
pfor_config=pfor_config)
tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
num_tiled_iterations, parallel_iterations=1)
@@ -318,7 +333,7 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
@tf_export("vectorized_map")
def vectorized_map(fn, elems):
def vectorized_map(fn, elems, fallback_to_while_loop=True):
"""Parallel map on the list of tensors unpacked from `elems` on dimension 0.
@@ -340,8 +355,7 @@ def vectorized_map(fn, elems):
- Stateful kernels may mostly not be supported since these often imply a
data dependency. We do support a limited set of such stateful kernels
though (like RandomFoo, Variable operations like reads, etc).
- `fn` has limited support for control flow operations. `tf.cond` in
particular is not supported.
- `fn` has limited support for control flow operations.
- `fn` should return nested structure of Tensors or Operations. However
if an Operation is returned, it should have zero outputs.
- The shape and dtype of any intermediate or output tensors in the
@@ -389,11 +403,21 @@ def vectorized_map(fn, elems):
elems: A tensor or (possibly nested) sequence of tensors, each of which will
be unpacked along their first dimension. The nested sequence of the
resulting slices will be mapped over by `fn`.
fallback_to_while_loop: If true, on failing to vectorize an operation,
the unsupported op is wrapped in a tf.while_loop to execute the map
iterations. Note that this fallback only happens for unsupported ops and
other parts of `fn` are still vectorized. If false, on encountering an
unsupported op, a ValueError is thrown. Note that the fallbacks can result
in slowdowns since vectorization often yields speedup of one to two orders
of magnitude.
Returns:
A tensor or (possibly nested) sequence of tensors. Each tensor packs the
results of applying fn to tensors unpacked from elems along the first
dimension, from first to last.
Raises:
ValueError: If vectorization fails and fallback_to_while_loop is False.
"""
def loop_fn(i):
gathered_elems = nest.map_structure(lambda x: array_ops.gather(x, i), elems)
@@ -404,4 +428,5 @@ def vectorized_map(fn, elems):
batch_size = first_elem.shape.as_list()[0]
if batch_size is None:
batch_size = array_ops.shape(first_elem)[0]
return pfor(loop_fn, batch_size)
return pfor(loop_fn, batch_size,
fallback_to_while_loop=fallback_to_while_loop)
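For completeness, a hedged sketch of calling the experimental (non-public) pfor entry point with the new keyword; the loop body and shapes here are illustrative only:

import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = tf.random.uniform([3, 4])

def loop_fn(i):
  # i is the scalar loop-index tensor; each iteration processes one row of x.
  return tf.gather(x, i) * 2.0

out = pfor_control_flow_ops.pfor(loop_fn, 3)  # vectorized; falls back per-op if needed
# Passing fallback_to_while_loop=False instead raises a ValueError when an op
# inside loop_fn has no registered converter:
# pfor_control_flow_ops.pfor(loop_fn, 3, fallback_to_while_loop=False)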

View File

@@ -22,7 +22,6 @@ from __future__ import print_function
import functools
import time
from absl import flags
from absl.testing import parameterized
import numpy as np
@@ -82,10 +81,8 @@ class PForTest(PForTestCase):
return nn.top_k(x_i)
with self.assertRaisesRegexp(ValueError, "No converter defined"):
self._test_loop_fn(loop_fn, 3)
flags.FLAGS.op_conversion_fallback_to_while_loop = True
self._test_loop_fn(loop_fn, 3)
flags.FLAGS.op_conversion_fallback_to_while_loop = False
self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
def test_parallel_iterations(self):
for parallel_iterations in [2, 3, 8, 10]:

View File

@@ -65,10 +65,11 @@ from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
# TODO(agarwal): remove flag.
flags.DEFINE_bool(
"op_conversion_fallback_to_while_loop", False,
"If true, falls back to using a while loop for ops for "
"which a converter is not defined.")
"op_conversion_fallback_to_while_loop", True,
"DEPRECATED: Flag is ignored.")
def _stack(t, length):
@@ -116,14 +117,17 @@ def _is_stateful_pfor_op(op):
class WhileOp(object):
"""Object for storing state for converting the outputs of a while_loop."""
def __init__(self, exit_node, pfor_ops, pfor_config):
def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
"""Initializer.
Args:
exit_node: A tensor output from the while_loop.
pfor_ops: list of ops inside the current pfor loop.
fallback_to_while_loop: If True, fall back to a while loop when conversion of
an op is not supported.
pfor_config: PForConfig object used while constructing loop body.
"""
self._fallback_to_while_loop = fallback_to_while_loop
self._pfor_config = pfor_config
self._pfor_ops = set(pfor_ops)
self._pfor_op_ids = set(x._id for x in pfor_ops)
@@ -306,6 +310,7 @@ class WhileOp(object):
pfor_ops=self._pfor_ops,
all_indices=indices,
all_indices_partitioned=cond_stacked,
fallback_to_while_loop=self._fallback_to_while_loop,
pfor_config=self._pfor_config)
# Map all inputs of Enter nodes in self._direct_enters to their converted
# values.
@@ -680,6 +685,10 @@ class WhileOp(object):
return outputs
class ConversionNotImplementedError(Exception):
pass
class _PforInput(object):
"""Input object passed to registered pfor converters."""
@@ -763,7 +772,7 @@ class _PforInput(object):
input_name = "at index %d" % index
else:
input_name = "\"%s\"" % op_def.input_arg[index].name
raise ValueError(
raise ConversionNotImplementedError(
"Input %s of op \"%s\" expected to be not loop invariant" %
(input_name, op_type))
return t
@@ -777,8 +786,9 @@ class _PforInput(object):
input_name = "at index %d" % index
else:
input_name = "\"%s\"" % op_def.input_arg[index].name
raise ValueError("Input %s of op \"%s\" expected to be loop invariant" %
(input_name, op_type))
raise ConversionNotImplementedError(
"Input %s of op \"%s\" expected to be loop invariant" %
(input_name, op_type))
return t
@property
@@ -971,6 +981,8 @@ def _fallback_converter(pfor_input):
attrs=pfor_input.op.node_def.attr).outputs
outputs = []
# TODO(agarwal): Add tf.debugging asserts to check that the shapes across
# the different iterations are the same.
for out, ta in zip(op_outputs, ta_list):
assert isinstance(out, ops.Tensor)
outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
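For intuition, a standalone sketch (not the actual _fallback_converter) of what falling back to a while loop means: run the body once per iteration inside tf.while_loop and stack the per-iteration outputs with a TensorArray; all names below are illustrative:

import tensorflow as tf

def sequential_fallback(per_iter_fn, iters):
  # Runs per_iter_fn(i) for i in [0, iters) sequentially and stacks the results.
  ta = tf.TensorArray(tf.float32, size=iters)

  def body(i, ta):
    return i + 1, ta.write(i, per_iter_fn(i))

  _, ta = tf.while_loop(lambda i, _: i < iters, body, [0, ta])
  return ta.stack()  # shape: [iters, ...]

# Example: squares of the loop index, one iteration at a time.
result = sequential_fallback(lambda i: tf.cast(i, tf.float32) ** 2, 5)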
@@ -1144,6 +1156,7 @@ class PFor(object):
loop_var,
loop_len,
pfor_ops,
fallback_to_while_loop,
all_indices=None,
all_indices_partitioned=False,
pfor_config=None):
@@ -1155,6 +1168,8 @@ class PFor(object):
loop_len: A scalar or scalar Tensor representing the number of iterations
the loop is run for.
pfor_ops: List of all ops inside the loop body.
fallback_to_while_loop: If True, on failure to vectorize an op, a while
loop is used to sequentially execute that op.
all_indices: If not None, an int32 vector with size `loop_len`
representing the iteration ids that are still active. These values
should be unique and sorted. However they may not be contiguous. This is
@@ -1182,6 +1197,7 @@ class PFor(object):
self._conversion_map[loop_var] = wrap(self.all_indices, True)
self._pfor_ops = set(pfor_ops)
self._pfor_op_ids = set(x._id for x in pfor_ops)
self._fallback_to_while_loop = fallback_to_while_loop
self._pfor_config = pfor_config
def op_is_inside_loop(self, op):
@@ -1350,7 +1366,9 @@ class PFor(object):
is_while_loop = y_op.type == "Exit"
if is_while_loop:
while_op = WhileOp(
y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config)
y, pfor_ops=self._pfor_ops,
fallback_to_while_loop=self.fallback_to_while_loop,
pfor_config=self._pfor_config)
is_inside_loop = while_op.is_inside_loop
# If all nodes in the while_loop graph were created inside the pfor, we
# treat the whole loop subgraph as a single op (y_op) and try to convert
@@ -1455,20 +1473,26 @@ class PFor(object):
else:
converter = _pfor_converter_registry.get(y_op.type, None)
if converter is None:
if flags.FLAGS.op_conversion_fallback_to_while_loop:
if self._fallback_to_while_loop:
converter = _fallback_converter
else:
raise ValueError("No converter defined for %s\n%s\ninputs: %s. "
"\nEither add a converter or set "
"--op_conversion_fallback_to_while_loop=True, "
"which may run slower" %
"\nEither add a converter or "
"enable fallback_to_while_loop "
"option to pfor, which may run slower" %
(y_op.type, y_op, converted_inputs))
# TODO(rachelim): Handle the case where some inputs are sparsely
# stacked. We should only call the converter if it supports handling
# those inputs.
pfor_inputs = _PforInput(self, y_op, converted_inputs)
try:
new_outputs = converter(pfor_inputs)
try:
new_outputs = converter(pfor_inputs)
except ConversionNotImplementedError as e:
if self._fallback_to_while_loop:
new_outputs = _fallback_converter(pfor_inputs)
else:
six.reraise(ValueError, ValueError(str(e)), sys.exc_info()[2])
except Exception as e: # pylint: disable=broad-except
logging.error(
"Got error while pfor was converting op %s"
@@ -1544,6 +1568,10 @@ class PFor(object):
"""
return self._all_indices_partitioned
@property
def fallback_to_while_loop(self):
return self._fallback_to_while_loop
# The code below defines converters for different operations. Please see comment
# for RegisterPFor to see how converters should be defined.
@@ -3657,6 +3685,7 @@ def _convert_partitioned_call(pfor_input):
loop_var=pfor.loop_var,
loop_len=pfor.loop_len_vector[0],
pfor_ops=func.graph.get_operations(),
fallback_to_while_loop=pfor.fallback_to_while_loop,
all_indices=pfor.all_indices,
all_indices_partitioned=pfor.all_indices_partitioned,
pfor_config=pfor.pfor_config)
@@ -3684,6 +3713,7 @@ def _outputs_for_branch(func_name, indices, pfor_input, inputs):
loop_var=pfor_input.pfor.loop_var,
loop_len=array_ops.size(indices),
pfor_ops=func.graph.get_operations(),
fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop,
all_indices=indices,
all_indices_partitioned=partitioned,
pfor_config=pfor_input.pfor.pfor_config)

View File

@@ -53,10 +53,14 @@ class PForTestCase(test.TestCase):
loop_fn,
iters,
parallel_iterations=None,
fallback_to_while_loop=False,
rtol=1e-4,
atol=1e-5):
t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
parallel_iterations=parallel_iterations)
t1 = pfor_control_flow_ops.pfor(
loop_fn,
iters=iters,
fallback_to_while_loop=fallback_to_while_loop,
parallel_iterations=parallel_iterations)
loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
parallel_iterations=parallel_iterations)

View File

@@ -2490,7 +2490,7 @@ tf_module {
}
member_method {
name: "vectorized_map"
argspec: "args=[\'fn\', \'elems\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "verify_tensor_all_finite"

View File

@@ -1130,7 +1130,7 @@ tf_module {
}
member_method {
name: "vectorized_map"
argspec: "args=[\'fn\', \'elems\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\'], varargs=None, keywords=None, defaults=[\'True\'], "
}
member_method {
name: "where"