pfor: Have pfor fall back to while_loop by default in cases where a conversion
is not supported.

PiperOrigin-RevId: 306319253
Change-Id: I9bc1581c4b854f41df6270ea538ae6d88a704d45
parent 8eaea75c6d
commit 867673d63b
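Not part of the diff below: a minimal sketch of the user-visible effect, assuming the public tf.vectorized_map API after this change; tf.math.top_k is used only because the tests below treat it as an op without a registered pfor converter at the time.

    import tensorflow as tf

    x = tf.random.uniform([3, 2, 4])

    def fn(x_i):
      # top_k is used here only because the tests below rely on it having no
      # registered pfor converter when this change landed.
      return tf.math.top_k(x_i, k=1)

    # New default: the unconverted op is dispatched via a tf.while_loop while
    # the rest of fn stays vectorized, so this call now succeeds.
    values, indices = tf.vectorized_map(fn, x)

    # Opting out restores the previous behavior: an unsupported op raises.
    try:
      tf.vectorized_map(fn, x, fallback_to_while_loop=False)
    except ValueError as e:
      print(e)  # e.g. "No converter defined for TopKV2 ..."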
@@ -1593,27 +1593,6 @@ class JacobianTest(test.TestCase):
     with self.assertRaisesRegexp(RuntimeError, 'persistent'):
       g.jacobian(y, x, experimental_use_pfor=False)
 
-  @test_util.run_v1_only('b/120545219')
-  def testPforException(self):
-    var = variables.Variable([1.])
-
-    @custom_gradient.custom_gradient
-    def op(x):
-      def grad(_):
-        # Note that we perform a stateful operation here that will not be
-        # compatible with parallel for construct.
-        with ops.control_dependencies(
-            [var.assign(random_ops.random_uniform([1]))]):
-          return constant_op.constant(1.)
-      return x, grad
-
-    with backprop.GradientTape() as g:
-      x = constant_op.constant([1., 2.])
-      g.watch(x)
-      y = op(x)
-    with self.assertRaisesRegexp(ValueError, 'No converter'):
-      g.jacobian(y, x, experimental_use_pfor=True)
-
   @test_util.run_v1_only('b/120545219')
   def test_parallel_iterations(self):
     with backprop.GradientTape(persistent=True) as g:
@@ -1723,26 +1702,6 @@ class BatchJacobianTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError, 'must have rank at least 2'):
       g.batch_jacobian(y, x)
 
-  def testPforException(self):
-    var = variables.Variable([1.])
-
-    @custom_gradient.custom_gradient
-    def op(x):
-      def grad(_):
-        # Note that we perform a stateful operation here that will not be
-        # compatible with parallel for construct.
-        with ops.control_dependencies(
-            [var.assign(random_ops.random_uniform([1]))]):
-          return constant_op.constant(1.)
-      return x, grad
-
-    with backprop.GradientTape() as g:
-      x = constant_op.constant([[1.], [2.]])
-      g.watch(x)
-      y = op(x)
-    with self.assertRaisesRegexp(ValueError, 'No converter'):
-      g.batch_jacobian(y, x, experimental_use_pfor=True)
-
   def test_parallel_iterations(self):
     with backprop.GradientTape(persistent=True) as g:
       x = constant_op.constant([[1., 2], [3, 4]])
@@ -210,7 +210,7 @@ class ArrayTest(PForTestCase):
       return array_ops.tile(x1, [i, 1])
 
     with self.assertRaisesRegexp(ValueError, "expected to be loop invariant"):
-      pfor_control_flow_ops.pfor(loop_fn, 2)
+      pfor_control_flow_ops.pfor(loop_fn, 2, fallback_to_while_loop=False)
 
   def test_pack(self):
     x = random_ops.random_uniform([3, 2, 3])
@@ -447,6 +447,20 @@ class ArrayTest(PForTestCase):
 
     self._test_loop_fn(loop_fn, 3)
 
+  def test_strided_slice_loop_variant(self):
+    x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2])
+
+    def loop_fn(i):
+      x_i = array_ops.gather(x, i)
+      return x_i[i:i+1, ...]
+
+    # Test that the fallback to while loop for a ConversionNotImplementedError
+    # is handled.
+    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
+    # Without fallback, ValueError is thrown.
+    with self.assertRaisesRegexp(ValueError, "expected to be loop invariant"):
+      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
+
   def test_depth_to_space(self):
     x = random_ops.random_uniform([2, 3, 2, 2, 12])
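The new test above exercises the fallback through a loop-variant strided slice. A minimal standalone sketch of the same behavior, assuming the internal module path tensorflow.python.ops.parallel_for.control_flow_ops that these tests import:

    import tensorflow as tf
    from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_ops

    x = tf.random.uniform([3, 3, 4])

    def loop_fn(i):
      x_i = tf.gather(x, i)
      # The slice bounds depend on the loop index, so the StridedSlice input
      # that is "expected to be loop invariant" is not; conversion is declined.
      return x_i[i:i + 1, ...]

    # New default: the declined op is dispatched through a tf.while_loop and
    # the result is still stacked, here with shape [3, 1, 4].
    stacked = pfor_ops.pfor(loop_fn, 3)

    # Disabling the fallback restores the previous hard failure.
    try:
      pfor_ops.pfor(loop_fn, 3, fallback_to_while_loop=False)
    except ValueError as e:
      print(e)  # "... expected to be loop invariant"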
@@ -132,7 +132,7 @@ def _is_under_xla_context():
   return False
 
 
-def pfor(loop_fn, iters, parallel_iterations=None):
+def pfor(loop_fn, iters, fallback_to_while_loop=True, parallel_iterations=None):
   """Equivalent to running `loop_fn` `iters` times and stacking the outputs.
 
   `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters`
@@ -167,6 +167,8 @@ def pfor(loop_fn, iters, parallel_iterations=None):
       to something other than None, `loop_fn` may be called more than once
       during graph construction. So it may need to avoid mutating global state.
     iters: Number of iterations for which to run loop_fn.
+    fallback_to_while_loop: If true, on failing to vectorize an operation, pfor
+      falls back to using a tf.while_loop to dispatch the iterations.
     parallel_iterations: A knob to control how many iterations are vectorized
       and dispatched in parallel. The default value of None corresponds to
       vectorizing all the iterations. If `parallel_iterations` is smaller than
@@ -180,7 +182,10 @@ def pfor(loop_fn, iters, parallel_iterations=None):
     ValueError: If parallel_iterations is not None and not an integer > 1.
   """
   def f():
-    return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
+    return _pfor_impl(loop_fn,
+                      iters,
+                      fallback_to_while_loop=fallback_to_while_loop,
+                      parallel_iterations=parallel_iterations)
   # Note that we wrap into a tf.function if in eager execution mode or under
   # XLA compilation. The latter is so that we don't compile operations like
   # tf.placeholder that are created by the loop body.
@@ -219,7 +224,11 @@ def _loop_fn_has_config(loop_fn):
   return PFOR_CONFIG_ARG in argspec.args
 
 
-def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
+def _pfor_impl(loop_fn,
+               iters,
+               fallback_to_while_loop,
+               parallel_iterations=None,
+               pfor_config=None):
   """Implementation of pfor."""
   assert not context.executing_eagerly()
   loop_fn_has_config = _loop_fn_has_config(loop_fn)
@@ -263,7 +272,9 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
     parallel_iterations = None
   if parallel_iterations is None:
     with ops.name_scope("pfor"):
-      converter = PFor(loop_var, iters, new_ops, pfor_config=pfor_config)
+      converter = PFor(loop_var, iters, new_ops,
+                       fallback_to_while_loop=fallback_to_while_loop,
+                       pfor_config=pfor_config)
       outputs = []
       for loop_fn_output in nest.flatten(loop_fn_outputs):
         outputs.append(converter.convert(loop_fn_output))
@@ -278,6 +289,7 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
       # a tf.function and extract the graph from there to vectorize it.
       with ops.name_scope("pfor_untiled"):
         converter = PFor(loop_var, num_remaining_iterations, new_ops,
+                         fallback_to_while_loop=fallback_to_while_loop,
                          pfor_config=pfor_config)
         remaining_outputs = []
         flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
@@ -298,7 +310,10 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
          return nest.flatten(loop_fn(i + offset))
 
        return _pfor_impl(
-          tiled_loop_fn, parallel_iterations, pfor_config=pfor_config)
+          tiled_loop_fn,
+          parallel_iterations,
+          fallback_to_while_loop=fallback_to_while_loop,
+          pfor_config=pfor_config)
 
     tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
                              num_tiled_iterations, parallel_iterations=1)
@@ -318,7 +333,7 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
 
 
 @tf_export("vectorized_map")
-def vectorized_map(fn, elems):
+def vectorized_map(fn, elems, fallback_to_while_loop=True):
   """Parallel map on the list of tensors unpacked from `elems` on dimension 0.
 
@@ -340,8 +355,7 @@ def vectorized_map(fn, elems):
   - Stateful kernels may mostly not be supported since these often imply a
     data dependency. We do support a limited set of such stateful kernels
     though (like RandomFoo, Variable operations like reads, etc).
-  - `fn` has limited support for control flow operations. `tf.cond` in
-    particular is not supported.
+  - `fn` has limited support for control flow operations.
   - `fn` should return nested structure of Tensors or Operations. However
     if an Operation is returned, it should have zero outputs.
   - The shape and dtype of any intermediate or output tensors in the
@@ -389,11 +403,21 @@ def vectorized_map(fn, elems):
     elems: A tensor or (possibly nested) sequence of tensors, each of which will
       be unpacked along their first dimension. The nested sequence of the
       resulting slices will be mapped over by `fn`.
+    fallback_to_while_loop: If true, on failing to vectorize an operation,
+      the unsupported op is wrapped in a tf.while_loop to execute the map
+      iterations. Note that this fallback only happens for unsupported ops and
+      other parts of `fn` are still vectorized. If false, on encountering an
+      unsupported op, a ValueError is thrown. Note that the fallbacks can result
+      in slowdowns since vectorization often yields speedup of one to two orders
+      of magnitude.
 
   Returns:
     A tensor or (possibly nested) sequence of tensors. Each tensor packs the
     results of applying fn to tensors unpacked from elems along the first
     dimension, from first to last.
+
+  Raises:
+    ValueError: If vectorization fails and fallback_to_while_loop is False.
   """
   def loop_fn(i):
     gathered_elems = nest.map_structure(lambda x: array_ops.gather(x, i), elems)
@@ -404,4 +428,5 @@ def vectorized_map(fn, elems):
   batch_size = first_elem.shape.as_list()[0]
   if batch_size is None:
     batch_size = array_ops.shape(first_elem)[0]
-  return pfor(loop_fn, batch_size)
+  return pfor(loop_fn, batch_size,
+              fallback_to_while_loop=fallback_to_while_loop)
@@ -22,7 +22,6 @@ from __future__ import print_function
 import functools
 import time
 
-from absl import flags
 from absl.testing import parameterized
 import numpy as np
 
@@ -82,10 +81,8 @@ class PForTest(PForTestCase):
       return nn.top_k(x_i)
 
     with self.assertRaisesRegexp(ValueError, "No converter defined"):
-      self._test_loop_fn(loop_fn, 3)
-    flags.FLAGS.op_conversion_fallback_to_while_loop = True
-    self._test_loop_fn(loop_fn, 3)
-    flags.FLAGS.op_conversion_fallback_to_while_loop = False
+      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
+    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
 
   def test_parallel_iterations(self):
     for parallel_iterations in [2, 3, 8, 10]:
@@ -65,10 +65,11 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util import object_identity
 
 
+# TODO(agarwal): remove flag.
 flags.DEFINE_bool(
-    "op_conversion_fallback_to_while_loop", False,
-    "If true, falls back to using a while loop for ops for "
-    "which a converter is not defined.")
+    "op_conversion_fallback_to_while_loop", True,
+    "DEPRECATED: Flag is ignored.")
 
 
 def _stack(t, length):
@@ -116,14 +117,17 @@ def _is_stateful_pfor_op(op):
 class WhileOp(object):
   """Object for storing state for converting the outputs of a while_loop."""
 
-  def __init__(self, exit_node, pfor_ops, pfor_config):
+  def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
     """Initializer.
 
     Args:
       exit_node: A tensor output from the while_loop.
       pfor_ops: list of ops inside the current pfor loop.
+      fallback_to_while_loop: If True, fall back to a while loop when conversion
+        of an op is not supported.
       pfor_config: PForConfig object used while constructing loop body.
     """
+    self._fallback_to_while_loop = fallback_to_while_loop
     self._pfor_config = pfor_config
     self._pfor_ops = set(pfor_ops)
     self._pfor_op_ids = set(x._id for x in pfor_ops)
@@ -306,6 +310,7 @@ class WhileOp(object):
         pfor_ops=self._pfor_ops,
         all_indices=indices,
         all_indices_partitioned=cond_stacked,
+        fallback_to_while_loop=self._fallback_to_while_loop,
         pfor_config=self._pfor_config)
     # Map all inputs of Enter nodes in self._direct_enters to their converted
     # values.
@@ -680,6 +685,10 @@ class WhileOp(object):
     return outputs
 
 
+class ConversionNotImplementedError(Exception):
+  pass
+
+
 class _PforInput(object):
   """Input object passed to registered pfor converters."""
 
@@ -763,7 +772,7 @@ class _PforInput(object):
         input_name = "at index %d" % index
       else:
         input_name = "\"%s\"" % op_def.input_arg[index].name
-      raise ValueError(
+      raise ConversionNotImplementedError(
           "Input %s of op \"%s\" expected to be not loop invariant" %
           (input_name, op_type))
     return t
@@ -777,8 +786,9 @@ class _PforInput(object):
         input_name = "at index %d" % index
       else:
         input_name = "\"%s\"" % op_def.input_arg[index].name
-      raise ValueError("Input %s of op \"%s\" expected to be loop invariant" %
-                       (input_name, op_type))
+      raise ConversionNotImplementedError(
+          "Input %s of op \"%s\" expected to be loop invariant" %
+          (input_name, op_type))
     return t
 
   @property
@@ -971,6 +981,8 @@ def _fallback_converter(pfor_input):
         attrs=pfor_input.op.node_def.attr).outputs
 
     outputs = []
+    # TODO(agarwal): Add tf.debugging asserts to check that the shapes across
+    # the different iterations are the same.
    for out, ta in zip(op_outputs, ta_list):
      assert isinstance(out, ops.Tensor)
      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
@@ -1144,6 +1156,7 @@ class PFor(object):
                loop_var,
                loop_len,
                pfor_ops,
+               fallback_to_while_loop,
                all_indices=None,
                all_indices_partitioned=False,
                pfor_config=None):
@@ -1155,6 +1168,8 @@ class PFor(object):
       loop_len: A scalar or scalar Tensor representing the number of iterations
         the loop is run for.
       pfor_ops: List of all ops inside the loop body.
+      fallback_to_while_loop: If True, on failure to vectorize an op, a while
+        loop is used to sequentially execute that op.
       all_indices: If not None, an int32 vector with size `loop_len`
         representing the iteration ids that are still active. These values
         should be unique and sorted. However they may not be contiguous. This is
@@ -1182,6 +1197,7 @@ class PFor(object):
     self._conversion_map[loop_var] = wrap(self.all_indices, True)
     self._pfor_ops = set(pfor_ops)
     self._pfor_op_ids = set(x._id for x in pfor_ops)
+    self._fallback_to_while_loop = fallback_to_while_loop
     self._pfor_config = pfor_config
 
   def op_is_inside_loop(self, op):
@@ -1350,7 +1366,9 @@ class PFor(object):
       is_while_loop = y_op.type == "Exit"
       if is_while_loop:
         while_op = WhileOp(
-            y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config)
+            y, pfor_ops=self._pfor_ops,
+            fallback_to_while_loop=self.fallback_to_while_loop,
+            pfor_config=self._pfor_config)
         is_inside_loop = while_op.is_inside_loop
         # If all nodes in the while_loop graph were created inside the pfor, we
         # treat the whole loop subgraph as a single op (y_op) and try to convert
@@ -1455,20 +1473,26 @@ class PFor(object):
       else:
         converter = _pfor_converter_registry.get(y_op.type, None)
         if converter is None:
-          if flags.FLAGS.op_conversion_fallback_to_while_loop:
+          if self._fallback_to_while_loop:
             converter = _fallback_converter
           else:
             raise ValueError("No converter defined for %s\n%s\ninputs: %s. "
-                             "\nEither add a converter or set "
-                             "--op_conversion_fallback_to_while_loop=True, "
-                             "which may run slower" %
+                             "\nEither add a converter or "
+                             "enable fallback_to_while_loop "
+                             "option to pfor, which may run slower" %
                              (y_op.type, y_op, converted_inputs))
         # TODO(rachelim): Handle the case where some inputs are sparsely
         # stacked. We should only call the converter if it supports handling
         # those inputs.
         pfor_inputs = _PforInput(self, y_op, converted_inputs)
         try:
           new_outputs = converter(pfor_inputs)
+        except ConversionNotImplementedError as e:
+          if self._fallback_to_while_loop:
+            new_outputs = _fallback_converter(pfor_inputs)
+          else:
+            six.reraise(ValueError, ValueError(str(e)), sys.exc_info()[2])
         except Exception as e:  # pylint: disable=broad-except
           logging.error(
               "Got error while pfor was converting op %s"
@@ -1544,6 +1568,10 @@ class PFor(object):
     """
     return self._all_indices_partitioned
 
+  @property
+  def fallback_to_while_loop(self):
+    return self._fallback_to_while_loop
+
 
 # The code below defines converters for different operations. Please see comment
 # for RegisterPFor to see how converters should be defined.
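For converter authors, the new exception also gives a way to decline a conversion instead of failing hard, since the convert() dispatch above catches ConversionNotImplementedError raised by a registered converter. A hypothetical sketch following the RegisterPFor pattern referenced above; "MyCustomOp" and my_custom_op are illustrative placeholders, not part of this commit:

    from tensorflow.python.ops.parallel_for.pfor import (
        ConversionNotImplementedError, RegisterPFor, wrap)

    @RegisterPFor("MyCustomOp")  # hypothetical op type, for illustration only
    def _convert_my_custom_op(pfor_input):
      inp = pfor_input.input(0)
      if not inp.is_stacked:
        # Loop-invariant input: the op can simply be replayed once.
        return wrap(my_custom_op(inp.t), False)
      # Declining the conversion: with fallback_to_while_loop enabled, the op
      # is routed through _fallback_converter (a per-iteration tf.while_loop)
      # instead of always raising ValueError as before.
      raise ConversionNotImplementedError(
          "MyCustomOp with a stacked input is not vectorized.")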
@@ -3657,6 +3685,7 @@ def _convert_partitioned_call(pfor_input):
       loop_var=pfor.loop_var,
       loop_len=pfor.loop_len_vector[0],
       pfor_ops=func.graph.get_operations(),
+      fallback_to_while_loop=pfor.fallback_to_while_loop,
       all_indices=pfor.all_indices,
       all_indices_partitioned=pfor.all_indices_partitioned,
       pfor_config=pfor.pfor_config)
@@ -3684,6 +3713,7 @@ def _outputs_for_branch(func_name, indices, pfor_input, inputs):
       loop_var=pfor_input.pfor.loop_var,
       loop_len=array_ops.size(indices),
       pfor_ops=func.graph.get_operations(),
+      fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop,
       all_indices=indices,
       all_indices_partitioned=partitioned,
       pfor_config=pfor_input.pfor.pfor_config)
@@ -53,10 +53,14 @@ class PForTestCase(test.TestCase):
                     loop_fn,
                     iters,
                     parallel_iterations=None,
+                    fallback_to_while_loop=False,
                     rtol=1e-4,
                     atol=1e-5):
-    t1 = pfor_control_flow_ops.pfor(loop_fn, iters=iters,
-                                    parallel_iterations=parallel_iterations)
+    t1 = pfor_control_flow_ops.pfor(
+        loop_fn,
+        iters=iters,
+        fallback_to_while_loop=fallback_to_while_loop,
+        parallel_iterations=parallel_iterations)
     loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1)
     t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters,
                                         parallel_iterations=parallel_iterations)
@@ -2490,7 +2490,7 @@ tf_module {
   }
   member_method {
     name: "vectorized_map"
-    argspec: "args=[\'fn\', \'elems\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\'], varargs=None, keywords=None, defaults=[\'True\'], "
   }
   member_method {
     name: "verify_tensor_all_finite"
@@ -1130,7 +1130,7 @@ tf_module {
   }
   member_method {
     name: "vectorized_map"
-    argspec: "args=[\'fn\', \'elems\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'fn\', \'elems\', \'fallback_to_while_loop\'], varargs=None, keywords=None, defaults=[\'True\'], "
  }
  member_method {
    name: "where"