parallel_for: add mechanism for performing reductions across pfor iterations.

PiperOrigin-RevId: 237331345
Authored by A. Unique TensorFlower on 2019-03-07 14:56:42 -08:00; committed by TensorFlower Gardener
parent 87c56c8fea
commit 94367ef18f
3 changed files with 280 additions and 23 deletions
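
In short, `loop_fn` may now declare an optional `pfor_config` keyword argument and use it to reduce values across all pfor iterations. A minimal graph-mode usage sketch, mirroring the `reduce_mean` test added below (the import alias and variable names are illustrative assumptions; only the `pfor` / `pfor_config` calls come from this change):

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = random_ops.random_uniform([8, 3])

def loop_fn(i, pfor_config):
  # reduce_mean stacks x_i across all 8 iterations and takes the mean,
  # so each output row is x_i centered by the per-column mean of x.
  x_i = array_ops.gather(x, i)
  return x_i - pfor_config.reduce_mean(x_i)

output = pfor_control_flow_ops.pfor(loop_fn, 8)
# Equivalent to x - math_ops.reduce_mean(x, axis=0). Reductions require graph
# mode, since reduce_concat asserts it is not executing eagerly.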

@@ -13,10 +13,14 @@
# limitations under the License.
# ==============================================================================
"""for_loop and pfor ops."""
# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
@@ -27,7 +31,10 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.parallel_for.pfor import PFor
from tensorflow.python.ops.parallel_for.pfor import PForConfig
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
@@ -98,6 +105,9 @@ def _flatten_first_two_dims(x):
return array_ops.reshape(x, new_shape)
PFOR_CONFIG_ARG = "pfor_config"
def pfor(loop_fn, iters, parallel_iterations=None):
"""Equivalent to running `loop_fn` `iters` times and stacking the outputs.
@@ -127,10 +137,11 @@ def pfor(loop_fn, iters, parallel_iterations=None):
Args:
loop_fn: A function that takes an int32 scalar tf.Tensor object representing
the iteration number, and optionally a keyword argument `pfor_config` set
to a PForConfig object. It returns a possibly nested structure of Tensor
or Operation objects. Note that if setting `parallel_iterations` argument
to something other than None, `loop_fn` may be called more than once
during graph construction. So it may need to avoid mutating global state.
iters: Number of iterations for which to run loop_fn.
parallel_iterations: A knob to control how many iterations are vectorized
and dispatched in parallel. The default value of None corresponds to
@@ -151,11 +162,37 @@ def pfor(loop_fn, iters, parallel_iterations=None):
return f()
def _loop_fn_has_config(loop_fn):
"""Test if `loop_fn` has a `pfor_config` argument."""
if tf_inspect.isfunction(loop_fn):
argspec = tf_inspect.getargspec(loop_fn)
return PFOR_CONFIG_ARG in argspec.args
elif isinstance(loop_fn, functools.partial):
fn = loop_fn.func
argspec = tf_inspect.getargspec(fn)
return (PFOR_CONFIG_ARG in argspec.args and
PFOR_CONFIG_ARG not in loop_fn.keywords)
else:
loop_class = tf_decorator.unwrap(loop_fn)[1]
if not hasattr(loop_class, "__call__"):
raise ValueError("loop_fn object did not have a __call__ method")
argspec = tf_inspect.getargspec(loop_class.__call__)
return PFOR_CONFIG_ARG in argspec.args
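
A standalone sketch of the detection logic above, using the standard `inspect` module in place of `tf_inspect` (an assumption for illustration); it shows the three `loop_fn` forms the helper recognizes and the partial-with-bound-keyword case it rejects:

import functools
import inspect

PFOR_CONFIG_ARG = "pfor_config"

def has_config(fn):
  # Mirrors _loop_fn_has_config: plain functions, functools.partial objects,
  # and callable objects are inspected for a `pfor_config` argument.
  if inspect.isfunction(fn):
    return PFOR_CONFIG_ARG in inspect.getfullargspec(fn).args
  if isinstance(fn, functools.partial):
    spec = inspect.getfullargspec(fn.func)
    return PFOR_CONFIG_ARG in spec.args and PFOR_CONFIG_ARG not in fn.keywords
  return PFOR_CONFIG_ARG in inspect.getfullargspec(type(fn).__call__).args

def f(i, pfor_config):
  return i

class G(object):
  def __call__(self, i, pfor_config):
    return i

print(has_config(f))                                       # True
print(has_config(G()))                                      # True
print(has_config(functools.partial(f, pfor_config=None)))   # False: already bound
print(has_config(lambda i: i))                               # False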
def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None):
"""Implementation of pfor.""" """Implementation of pfor."""
loop_fn_has_config = _loop_fn_has_config(loop_fn)
existing_ops = set(ops.get_default_graph().get_operations())
with ops.name_scope("loop_body"):
loop_var = array_ops.placeholder(dtypes.int32, shape=[])
if loop_fn_has_config:
if pfor_config is None:
pfor_config = PForConfig()
pfor_config._set_iters(iters) # pylint: disable=protected-access
loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config})
else:
assert pfor_config is None
loop_fn_outputs = loop_fn(loop_var)
new_ops = set(ops.get_default_graph().get_operations()) - existing_ops
iters = ops.convert_to_tensor(iters)
@@ -169,18 +206,22 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None):
parallel_iterations = None
if parallel_iterations is None:
with ops.name_scope("pfor"):
converter = PFor(loop_var, iters, new_ops, pfor_config=pfor_config)
outputs = []
for loop_fn_output in nest.flatten(loop_fn_outputs):
outputs.append(converter.convert(loop_fn_output))
return nest.pack_sequence_as(loop_fn_outputs, outputs)
else:
if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access
raise ValueError("Setting parallel_iterations currently unsupported if"
" reductions across iterations are performed.")
num_tiled_iterations = iters // parallel_iterations
num_remaining_iterations = iters % parallel_iterations
# TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside
# a tf.function and extract the graph from there to vectorize it.
with ops.name_scope("pfor_untiled"):
converter = PFor(loop_var, num_remaining_iterations, new_ops,
pfor_config=pfor_config)
remaining_outputs = []
flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs)
for loop_fn_output in flattened_loop_fn_outputs:
@@ -193,10 +234,14 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None):
def tiled_loop_body(j):
offset = j * parallel_iterations + num_remaining_iterations
def tiled_loop_fn(i, pfor_config=None):
if loop_fn_has_config:
return nest.flatten(loop_fn(i + offset, pfor_config=pfor_config))
else:
return nest.flatten(loop_fn(i + offset))
return _pfor_impl(
tiled_loop_fn, parallel_iterations, pfor_config=pfor_config)
tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes,
num_tiled_iterations, parallel_iterations=1)
@@ -213,7 +258,3 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None):
else:
outputs = tiled_outputs
return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs))

@@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Tests for pfor and for_loop."""
# pylint: disable=g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import time
from absl import flags
@@ -100,6 +102,90 @@ class PForTest(PForTestCase):
pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1)
@test_util.run_all_in_graph_and_eager_modes
class ReductionTest(PForTestCase):
def test_reduce_concat(self):
x = random_ops.random_uniform([8, 3])
def loop_fn(i, pfor_config):
x_i = array_ops.gather(x, i)
vectorized_value = pfor_config.reduce_concat(x_i)
mean_value = math_ops.reduce_mean(vectorized_value, axis=0)
return x_i - mean_value
output = pfor_control_flow_ops.pfor(loop_fn, 8)
ans = x - math_ops.reduce_mean(x, axis=0)
output_val, ans_val = self.evaluate([output, ans])
self.assertAllClose(ans_val, output_val)
def test_reduce_mean(self):
x = random_ops.random_uniform([8, 3])
def loop_fn(i, pfor_config):
x_i = array_ops.gather(x, i)
return x_i - pfor_config.reduce_mean(x_i)
output = pfor_control_flow_ops.pfor(loop_fn, 8)
ans = x - math_ops.reduce_mean(x, axis=0)
output_val, ans_val = self.evaluate([output, ans])
self.assertAllClose(ans_val, output_val)
def test_reduce_sum(self):
x = random_ops.random_uniform([8, 3])
def loop_fn(i, pfor_config):
x_i = array_ops.gather(x, i)
return x_i - pfor_config.reduce_sum(x_i)
output = pfor_control_flow_ops.pfor(loop_fn, 8)
ans = x - math_ops.reduce_sum(x, axis=0)
output_val, ans_val = self.evaluate([output, ans])
self.assertAllClose(ans_val, output_val)
def test_reduce_class(self):
x = random_ops.random_uniform([8, 3])
class LoopFn(object):
def __init__(self):
pass
def __call__(self, i, pfor_config):
x_i = array_ops.gather(x, i)
return x_i - pfor_config.reduce_mean(x_i)
output = pfor_control_flow_ops.pfor(LoopFn(), 8)
ans = x - math_ops.reduce_mean(x, axis=0)
output_val, ans_val = self.evaluate([output, ans])
self.assertAllClose(ans_val, output_val)
def test_reduce_functools_partial(self):
x = random_ops.random_uniform([8, 3])
def fn(i, pfor_config, dummy=None):
del dummy
x_i = array_ops.gather(x, i)
return x_i - pfor_config.reduce_mean(x_i)
loop_fn = functools.partial(fn, dummy=1)
output = pfor_control_flow_ops.pfor(loop_fn, 8)
ans = x - math_ops.reduce_mean(x, axis=0)
output_val, ans_val = self.evaluate([output, ans])
self.assertAllClose(ans_val, output_val)
def test_parallel_iterations(self):
x = random_ops.random_uniform([8, 3])
def loop_fn(i, pfor_config):
x_i = array_ops.gather(x, i)
return pfor_config.reduce_sum(x_i)
with self.assertRaisesRegexp(
ValueError, "parallel_iterations currently unsupported"):
pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2)
@test_util.run_all_in_graph_and_eager_modes
class BitwiseTest(PForTestCase):
@@ -965,6 +1051,26 @@ class Benchmarks(test.Benchmark):
self._run(pfor_outputs, 100, name="pfor_rnn")
self._run(tf_outputs, 100, name="tf_rnn")
def benchmark_reduction(self):
n = 1024
with ops.Graph().as_default():
x = random_ops.random_uniform([n, n])
w = random_ops.random_uniform([n, n])
def loop_fn(i, pfor_config):
x_i = array_ops.gather(x, i)
return math_ops.reduce_sum(
math_ops.matmul(pfor_config.reduce_concat(x_i), w))
# Note that output_reduction will be tiled, so there may be some minor
# overheads compared to output_no_reduction.
output_reduction = pfor_control_flow_ops.pfor(loop_fn, n)
output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w))
# Benchmark to test that reduction does not add overhead and its output is
# treated as loop invariant.
self._run(output_reduction, 30, name="matmul_reduction")
self._run(output_no_reduction, 30, name="matmul_no_reduction")
class SparseTest(PForTestCase):

@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Compiled parallel-for loop."""
# pylint: disable=missing-docstring,g-direct-tensorflow-import
from __future__ import absolute_import
from __future__ import division
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -94,13 +95,15 @@ def _is_stateful_pfor_op(op):
class WhileOp(object):
"""Object for storing state for converting the outputs of a while_loop."""
def __init__(self, exit_node, pfor_ops, pfor_config):
"""Initializer.
Args:
exit_node: A tensor output from the while_loop.
pfor_ops: list of ops inside the current pfor loop.
pfor_config: PForConfig object used while constructing loop body.
""" """
self._pfor_config = pfor_config
self._pfor_ops = set(pfor_ops)
self._pfor_op_ids = set([x._id for x in pfor_ops])
assert isinstance(exit_node, ops.Tensor)
@@ -281,7 +284,8 @@ class WhileOp(object):
loop_len,
pfor_ops=self._pfor_ops,
all_indices=indices,
all_indices_partitioned=cond_stacked,
pfor_config=self._pfor_config)
# Map all inputs of Enter nodes in self._direct_enters to their converted
# values.
for enter in self._direct_enters:
@@ -903,6 +907,86 @@ def _fallback_converter(pfor_input):
return tuple([wrap(ta.concat(), True) for ta in ta_list])
class PForConfig(object):
"""A configuration object used to communicate with loop body function."""
def __init__(self):
# This may be set to the number of iterations.
self._maybe_iters = None
# Map from output placeholder to the unvectorized tensor.
self._reduce_concat_map = {}
# Reverse map of `self._reduce_concat_map`.
self._reverse_reduce_concat_map = {}
def _has_reductions(self):
"""True if some reductions where performed by loop body."""
return len(self._reduce_concat_map)
def _set_iters(self, iters):
"""Set number of pfor iterations."""
self._maybe_iters = iters
# TODO(agarwal): handle reductions inside control flow constructs.
def reduce_concat(self, x):
"""Performs a concat reduction on `x` across pfor iterations.
Note that this currently may not work inside a control flow construct.
Args:
x: an unvectorized Tensor.
Returns:
A Tensor that has rank one higher than `x`. The value is the vectorized
version of `x`, i.e. stacking the value of `x` across different pfor
iterations.
"""
assert not context.executing_eagerly()
assert isinstance(x, ops.Tensor)
if x not in self._reduce_concat_map:
out_shape = tensor_shape.TensorShape([self._maybe_iters]).concatenate(
x.shape)
with ops.control_dependencies([x]):
# Control dependency to make sure out is converted after x.
out = array_ops.placeholder(x.dtype, out_shape)
self._reduce_concat_map[out] = x
self._reverse_reduce_concat_map[x] = out
return out
else:
return self._reverse_reduce_concat_map[x]
def reduce_mean(self, x):
"""Performs a mean reduction on `x` across pfor iterations.
Note that this currently may not work inside a control flow construct.
Args:
x: an unvectorized Tensor.
Returns:
A Tensor that has same rank as `x`. The value is the mean of the values
of `x` across the pfor iterations.
"""
y = self.reduce_concat(x)
return math_ops.reduce_mean(y, axis=0)
def reduce_sum(self, x):
"""Performs a sum reduction on `x` across pfor iterations.
Note that this currently may not work inside a control flow construct.
Args:
x: an unvectorized Tensor.
Returns:
A Tensor that has same rank as `x`. The value is the sum of the values
of `x` across the pfor iterations.
"""
y = self.reduce_concat(x)
return math_ops.reduce_sum(y, axis=0)
def _lookup_reduction(self, pl):
"""Lookups Placeholder `pl` in the reduction map."""
assert isinstance(pl, ops.Tensor)
return self._reduce_concat_map.get(pl, None)
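
Note that `reduce_mean` and `reduce_sum` above are thin wrappers around `reduce_concat`, so a loop body can build other reductions the same way. A hedged sketch of a max reduction, reusing the graph-mode setup (x, array_ops, math_ops, pfor_control_flow_ops) from the example near the top of this page:

def loop_fn(i, pfor_config):
  x_i = array_ops.gather(x, i)
  # reduce_concat returns the value of x_i stacked across all iterations
  # (shape [iters] + x_i.shape); ops applied to it are loop invariant.
  stacked = pfor_config.reduce_concat(x_i)
  return x_i - math_ops.reduce_max(stacked, axis=0)

output = pfor_control_flow_ops.pfor(loop_fn, 8)
# Equivalent to x - math_ops.reduce_max(x, axis=0).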
class PFor(object):
"""Implementation of rewrite of parallel-for loops.
@@ -941,7 +1025,8 @@ class PFor(object):
loop_len,
pfor_ops,
all_indices=None,
all_indices_partitioned=False,
pfor_config=None):
"""Creates an object to rewrite a parallel-for loop.
Args:
@@ -958,6 +1043,7 @@ class PFor(object):
all_indices_partitioned: If True, this object is being constructed from a
control flow construct where not all the pfor iterations are guaranteed
to be active.
pfor_config: PForConfig object used while constructing the loop body.
"""
assert isinstance(loop_var, ops.Tensor)
assert loop_var.op.type == "Placeholder"
@@ -976,6 +1062,7 @@ class PFor(object):
self._conversion_map[loop_var] = wrap(self.all_indices, True)
self._pfor_ops = set(pfor_ops)
self._pfor_op_ids = set([x._id for x in pfor_ops])
self._pfor_config = pfor_config
def op_is_inside_loop(self, op):
"""True if op was created inside the pfor loop body."""
@@ -1113,7 +1200,8 @@ class PFor(object):
is_while_loop = y_op.type == "Exit"
if is_while_loop:
while_op = WhileOp(
y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config)
is_inside_loop = while_op.is_inside_loop
# If all nodes in the while_loop graph were created inside the pfor, we
# treat the whole loop subgraph as a single op (y_op) and try to convert
@@ -1185,8 +1273,30 @@ class PFor(object):
control_dependencies = [] if is_while_loop else converted_control_ops
with ops.control_dependencies(control_dependencies), ops.name_scope(
y_op.name + "/pfor/"):
# Op is a placeholder for a reduction.
if (self._pfor_config is not None and
self._pfor_config._lookup_reduction(y) is not None):
# Handle reductions. Map the placeholder to the unvectorized input
# that is being reduced.
reduction_input = self._pfor_config._lookup_reduction(y)
assert isinstance(reduction_input, ops.Tensor), reduction_input
# Tensor being reduced should already be converted due to a control
# dependency on the created placeholder.
# Note that in cases where reduction_input is in an outer context, one
# needs to locate the corresponding Enter node and use that to lookup
# the conversion.
# TODO(agarwal): handle reductions inside control flow constructs.
assert reduction_input in self._conversion_map, (
"Unable to handle reduction of %s, possibly as it was used "
"inside a control flow construct. Note that reductions across "
"pfor iterations are currently not supported inside control flow "
"constructs." % reduction_input)
output = self._conversion_map[reduction_input]
# If original input is not stacked, we tile it. Also we always mark
# output as unstacked.
new_outputs = [wrap(self._unwrap_or_tile(output), False)]
# None of the inputs and control inputs were converted.
elif (not is_inside_loop or
(not is_stateful and not some_input_converted and
not some_control_input_converted)):
if y == y_op: