parallel_for: add mechanism for performing reductions across pfor iterations.
PiperOrigin-RevId: 237331345
This commit is contained in:
parent 87c56c8fea
commit 94367ef18f
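For orientation: the hunks below touch the public pfor entry points, the converter in pfor.py (which gains the new PForConfig class), and the corresponding tests. A minimal sketch of how the mechanism is meant to be used, mirroring the ReductionTest cases added below; it assumes graph mode (reduce_concat asserts that eager execution is not enabled), and the import alias follows the `pfor_control_flow_ops` name used in the tests.

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = random_ops.random_uniform([8, 3])

def loop_fn(i, pfor_config):
  x_i = array_ops.gather(x, i)
  # reduce_mean stacks x_i across all 8 iterations and averages over that
  # stack, so each iteration can subtract the mean of the whole batch.
  return x_i - pfor_config.reduce_mean(x_i)

# Equivalent to x - math_ops.reduce_mean(x, axis=0).
output = pfor_control_flow_ops.pfor(loop_fn, 8)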
@@ -13,10 +13,14 @@
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
@@ -27,7 +31,10 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect


def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
@@ -98,6 +105,9 @@ def _flatten_first_two_dims(x):
  return array_ops.reshape(x, new_shape)


PFOR_CONFIG_ARG = "pfor_config"


def pfor(loop_fn, iters, parallel_iterations=None):
  """Equivalent to running `loop_fn` `iters` times and stacking the outputs.

@@ -127,10 +137,11 @@ def pfor(loop_fn, iters, parallel_iterations=None):

  Args:
    loop_fn: A function that takes an int32 scalar tf.Tensor object representing
      the iteration number, and returns a possibly nested structure of Tensor or
      Operation objects. Note that if setting `parallel_iterations` argument to
      something other than None, `loop_fn` may be called more than once during
      graph construction. So it may need to avoid mutating global state.
      the iteration number, and optionally a keyword argument `pfor_config` set
      to a PForConfig object. It returns a possibly nested structure of Tensor
      or Operation objects. Note that if setting `parallel_iterations` argument
      to something other than None, `loop_fn` may be called more than once
      during graph construction. So it may need to avoid mutating global state.
    iters: Number of iterations for which to run loop_fn.
    parallel_iterations: A knob to control how many iterations are vectorized
      and dispatched in parallel. The default value of None corresponds to
@@ -151,11 +162,37 @@ def pfor(loop_fn, iters, parallel_iterations=None):
  return f()


def _pfor_impl(loop_fn, iters, parallel_iterations=None):
def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument."""
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("loop_fn object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args


def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
  """Implementation of pfor."""
  loop_fn_has_config = _loop_fn_has_config(loop_fn)
  existing_ops = set(ops.get_default_graph().get_operations())
  with ops.name_scope("loop_body"):
    loop_var = array_ops.placeholder(dtypes.int32, shape=[])
    if loop_fn_has_config:
      if pfor_config is None:
        pfor_config = PForConfig()
        pfor_config._set_iters(iters)  # pylint: disable=protected-access
      loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
    else:
      assert pfor_config is None
      loop_fn_outputs = loop_fn(loop_var)
  new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
  iters = ops.convert_to_tensor(iters)
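Aside: whether the loop body opts in to reductions is decided purely from its call signature by `_loop_fn_has_config` above. A standalone sketch of the same dispatch rule using the standard `inspect` module (the real helper uses `tf_inspect`/`tf_decorator`; the toy functions here are made up for illustration):

import functools
import inspect

def plain(i):
  return i

def with_config(i, pfor_config):
  return i

def has_config(fn):
  # A functools.partial only counts as config-taking if "pfor_config" is not
  # already bound in its keywords, matching _loop_fn_has_config above.
  if isinstance(fn, functools.partial):
    args = inspect.getfullargspec(fn.func).args
    return "pfor_config" in args and "pfor_config" not in fn.keywords
  return "pfor_config" in inspect.getfullargspec(fn).args

print(has_config(plain))        # False
print(has_config(with_config))  # True
print(has_config(functools.partial(with_config, pfor_config=None)))  # False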
@@ -169,18 +206,22 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None):
      parallel_iterations = None
  if parallel_iterations is None:
    with ops.name_scope("pfor"):
      converter = PFor(loop_var, iters, new_ops)
      converter = PFor(loop_var, iters, new_ops, pfor_config=pfor_config)
      outputs = []
      for loop_fn_output in nest.flatten(loop_fn_outputs):
        outputs.append(converter.convert(loop_fn_output))
      return nest.pack_sequence_as(loop_fn_outputs, outputs)
  else:
    if pfor_config is not None and pfor_config._has_reductions():  # pylint: disable=protected-access
      raise ValueError("Setting parallel_iterations currently unsupported if"
                       " reductions across iterations are performed.")
    num_tiled_iterations = iters // parallel_iterations
    num_remaining_iterations = iters % parallel_iterations
    # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
    # a tf.function and extract the graph from there to vectorize it.
    with ops.name_scope("pfor_untiled"):
      converter = PFor(loop_var, num_remaining_iterations, new_ops)
      converter = PFor(loop_var, num_remaining_iterations, new_ops,
                       pfor_config=pfor_config)
      remaining_outputs = []
      flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
      for loop_fn_output in flattened_loop_fn_outputs:
@@ -193,10 +234,14 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None):
    def tiled_loop_body(j):
      offset = j * parallel_iterations + num_remaining_iterations

      def tiled_loop_fn(i):
      def tiled_loop_fn(i, pfor_config=None):
        if loop_fn_has_config:
          return nest.flatten(loop_fn(i + offset, pfor_config=pfor_config))
        else:
          return nest.flatten(loop_fn(i + offset))

      return pfor(tiled_loop_fn, parallel_iterations)
      return _pfor_impl(
          tiled_loop_fn, parallel_iterations, pfor_config=pfor_config)

    tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
                             num_tiled_iterations, parallel_iterations=1)
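Aside: a small worked example of the tiling arithmetic in the two hunks above, assuming `iters` is a plain Python int. The leftover iterations are handled by the untiled converter, and the rest by `num_tiled_iterations` sequential passes of width `parallel_iterations`; note the earlier hunk rejects this tiled path when reductions are used.

iters = 10
parallel_iterations = 4

num_tiled_iterations = iters // parallel_iterations     # 2 tiled passes
num_remaining_iterations = iters % parallel_iterations  # 2 leftover iterations

# Iterations 0-1 go through the "pfor_untiled" converter; the remaining 8 run
# as a for_loop over 2 tiled bodies, each vectorizing 4 iterations with
# offset = j * 4 + 2 for tile j.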
@@ -213,7 +258,3 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None):
      else:
        outputs = tiled_outputs
      return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs))

@@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Tests for pfor and for_loop."""
# pylint: disable=g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import time

from absl import flags
@@ -100,6 +102,90 @@ class PForTest(PForTestCase):
      pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)


@test_util.run_all_in_graph_and_eager_modes
class ReductionTest(PForTestCase):

  def test_reduce_concat(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      vectorized_value = pfor_config.reduce_concat(x_i)
      mean_value = math_ops.reduce_mean(vectorized_value, axis=0)
      return x_i - mean_value

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_mean(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_sum(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_sum(x_i)

    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_sum(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_class(self):
    x = random_ops.random_uniform([8, 3])

    class LoopFn(object):

      def __init__(self):
        pass

      def __call__(self, i, pfor_config):
        x_i = array_ops.gather(x, i)
        return x_i - pfor_config.reduce_mean(x_i)

    output = pfor_control_flow_ops.pfor(LoopFn(), 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_reduce_functools_partial(self):
    x = random_ops.random_uniform([8, 3])

    def fn(i, pfor_config, dummy=None):
      del dummy
      x_i = array_ops.gather(x, i)
      return x_i - pfor_config.reduce_mean(x_i)

    loop_fn = functools.partial(fn, dummy=1)
    output = pfor_control_flow_ops.pfor(loop_fn, 8)
    ans = x - math_ops.reduce_mean(x, axis=0)
    output_val, ans_val = self.evaluate([output, ans])
    self.assertAllClose(ans_val, output_val)

  def test_parallel_iterations(self):
    x = random_ops.random_uniform([8, 3])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      return pfor_config.reduce_sum(x_i)

    with self.assertRaisesRegexp(
        ValueError, "parallel_iterations currently unsupported"):
      pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)


@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTestCase):

@@ -965,6 +1051,26 @@ class Benchmarks(test.Benchmark):
      self._run(pfor_outputs, 100, name="pfor_rnn")
      self._run(tf_outputs, 100, name="tf_rnn")

  def benchmark_reduction(self):
    n = 1024
    with ops.Graph().as_default():
      x = random_ops.random_uniform([n, n])
      w = random_ops.random_uniform([n, n])

      def loop_fn(i, pfor_config):
        x_i = array_ops.gather(x, i)
        return math_ops.reduce_sum(
            math_ops.matmul(pfor_config.reduce_concat(x_i), w))

      # Note that output_reduction will be tiled, so there may be some minor
      # overheads compared to output_no_reduction.
      output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
      output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
      # Benchmark to test that reduction does not add overhead and its output is
      # treated as loop invariant.
      self._run(output_reduction, 30, name="matmul_reduction")
      self._run(output_no_reduction, 30, name="matmul_no_reduction")


class SparseTest(PForTestCase):

@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Compiled parallel-for loop."""
# pylint: disable=missing-docstring
# pylint: disable=missing-docstring,g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
@@ -21,6 +21,7 @@ from __future__ import print_function

import collections

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -94,13 +95,15 @@ def _is_stateful_pfor_op(op):
class WhileOp(object):
  """Object for storing state for converting the outputs of a while_loop."""

  def __init__(self, exit_node, pfor_ops):
  def __init__(self, exit_node, pfor_ops, pfor_config):
    """Initializer.

    Args:
      exit_node: A tensor output from the while_loop.
      pfor_ops: list of ops inside the current pfor loop.
      pfor_config: PForConfig object used while constructing loop body.
    """
    self._pfor_config = pfor_config
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set([x._id for x in pfor_ops])
    assert isinstance(exit_node, ops.Tensor)
@@ -281,7 +284,8 @@ class WhileOp(object):
        loop_len,
        pfor_ops=self._pfor_ops,
        all_indices=indices,
        all_indices_partitioned=cond_stacked)
        all_indices_partitioned=cond_stacked,
        pfor_config=self._pfor_config)
    # Map all inputs of Enter nodes in self._direct_enters to their converted
    # values.
    for enter in self._direct_enters:
@@ -903,6 +907,86 @@ def _fallback_converter(pfor_input):
  return tuple([wrap(ta.concat(), True) for ta in ta_list])


class PForConfig(object):
  """A configuration object used to communicate with loop body function."""

  def __init__(self):
    # This may be set to the number of iterations.
    self._maybe_iters = None
    # Map from output placeholder to the unvectorized tensor.
    self._reduce_concat_map = {}
    # Reverse map of `self._reduce_concat_map`.
    self._reverse_reduce_concat_map = {}

  def _has_reductions(self):
    """True if some reductions were performed by loop body."""
    return len(self._reduce_concat_map)

  def _set_iters(self, iters):
    """Set number of pfor iterations."""
    self._maybe_iters = iters

  # TODO(agarwal): handle reductions inside control flow constructs.
  def reduce_concat(self, x):
    """Performs a concat reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has rank one higher than `x`. The value is the vectorized
      version of `x`, i.e. stacking the value of `x` across different pfor
      iterations.
    """
    assert not context.executing_eagerly()
    assert isinstance(x, ops.Tensor)
    if x not in self._reduce_concat_map:
      out_shape = tensor_shape.TensorShape([self._maybe_iters]).concatenate(
          x.shape)
      with ops.control_dependencies([x]):
        # Control dependency to make sure out is converted after x.
        out = array_ops.placeholder(x.dtype, out_shape)
      self._reduce_concat_map[out] = x
      self._reverse_reduce_concat_map[x] = out
      return out
    else:
      return self._reverse_reduce_concat_map[x]

  def reduce_mean(self, x):
    """Performs a mean reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the mean of the values
      of `x` across the pfor iterations.
    """
    y = self.reduce_concat(x)
    return math_ops.reduce_mean(y, axis=0)

  def reduce_sum(self, x):
    """Performs a sum reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the sum of the values
      of `x` across the pfor iterations.
    """
    y = self.reduce_concat(x)
    return math_ops.reduce_sum(y, axis=0)

  def _lookup_reduction(self, pl):
    """Looks up placeholder `pl` in the reduction map."""
    assert isinstance(pl, ops.Tensor)
    return self._reduce_concat_map.get(pl, None)

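Aside: a rough, NumPy-level illustration of what `reduce_concat` promises, not of how the conversion is implemented. Inside the loop body it returns a placeholder standing for the values of `x` stacked across all iterations; after vectorization that placeholder resolves to the actual stacked tensor, so `reduce_mean` and `reduce_sum` are ordinary reductions over axis 0 of it. Toy values, purely for illustration:

import numpy as np

iters = 4
# Pretend iteration i of the loop body produces the per-iteration value x[i].
x = np.arange(iters * 3, dtype=np.float32).reshape(iters, 3)

# reduce_concat(x_i) behaves like the fully stacked array ...
stacked = np.stack([x[i] for i in range(iters)])  # shape (4, 3)
# ... so reduce_mean / reduce_sum are reductions over the iteration axis.
mean_over_iters = stacked.mean(axis=0)            # same rank as each x_i
sum_over_iters = stacked.sum(axis=0)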
class PFor(object):
  """Implementation of rewrite of parallel-for loops.

@@ -941,7 +1025,8 @@ class PFor(object):
               loop_len,
               pfor_ops,
               all_indices=None,
               all_indices_partitioned=False):
               all_indices_partitioned=False,
               pfor_config=None):
    """Creates an object to rewrite a parallel-for loop.

    Args:
@@ -958,6 +1043,7 @@ class PFor(object):
      all_indices_partitioned: If True, this object is being constructed from a
        control flow construct where not all the pfor iterations are guaranteed
        to be active.
      pfor_config: PForConfig object used while constructing the loop body.
    """
    assert isinstance(loop_var, ops.Tensor)
    assert loop_var.op.type == "Placeholder"
@@ -976,6 +1062,7 @@ class PFor(object):
    self._conversion_map[loop_var] = wrap(self.all_indices, True)
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set([x._id for x in pfor_ops])
    self._pfor_config = pfor_config

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
@@ -1113,7 +1200,8 @@ class PFor(object):

      is_while_loop = y_op.type == "Exit"
      if is_while_loop:
        while_op = WhileOp(y, pfor_ops=self._pfor_ops)
        while_op = WhileOp(
            y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config)
        is_inside_loop = while_op.is_inside_loop
        # If all nodes in the while_loop graph were created inside the pfor, we
        # treat the whole loop subgraph as a single op (y_op) and try to convert
@@ -1185,8 +1273,30 @@ class PFor(object):
        control_dependencies = [] if is_while_loop else converted_control_ops
        with ops.control_dependencies(control_dependencies), ops.name_scope(
            y_op.name + "/pfor/"):
          # Op is a placeholder for a reduction.
          if (self._pfor_config is not None and
              self._pfor_config._lookup_reduction(y) is not None):
            # Handle reductions. Map the placeholder to the unvectorized input
            # that is being reduced.
            reduction_input = self._pfor_config._lookup_reduction(y)
            assert isinstance(reduction_input, ops.Tensor), reduction_input
            # Tensor being reduced should already be converted due to a control
            # dependency on the created placeholder.
            # Note that in cases where reduction_input is in an outer context, one
            # needs to locate the corresponding Enter node and use that to lookup
            # the conversion.
            # TODO(agarwal): handle reductions inside control flow constructs.
            assert reduction_input in self._conversion_map, (
                "Unable to handle reduction of %s, possibly as it was used "
                "inside a control flow construct. Note that reductions across "
                "pfor iterations are currently not supported inside control flow "
                "constructs." % reduction_input)
            output = self._conversion_map[reduction_input]
            # If original input is not stacked, we tile it. Also we always mark
            # output as unstacked.
            new_outputs = [wrap(self._unwrap_or_tile(output), False)]
          # None of the inputs and control inputs were converted.
          if (not is_inside_loop or
          elif (not is_inside_loop or
                (not is_stateful and not some_input_converted and
                 not some_control_input_converted)):
            if y == y_op:
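Aside: a toy model (not the real converter; `conversion_map`, `lookup_reduction`, and the tiling are stand-ins) of what the reduction branch above does when it reaches a placeholder created by `PForConfig.reduce_concat`:

import numpy as np

def resolve_reduction(placeholder, conversion_map, lookup_reduction, iters):
  """Maps a reduce_concat placeholder to the stacked value it stands for."""
  reduction_input = lookup_reduction(placeholder)   # the per-iteration tensor
  value, stacked = conversion_map[reduction_input]  # its converted value
  if not stacked:
    # Loop-invariant input: tile it so the reduction still sees `iters` rows.
    value = np.tile(value[None, ...], (iters,) + (1,) * value.ndim)
  # Either way the result is the same for every iteration, so the converter
  # marks it as unstacked (loop invariant).
  return value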