diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index 8f652e9c509..83bf86a5635 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -13,10 +13,14 @@ # limitations under the License. # ============================================================================== """for_loop and pfor ops.""" +# pylint: disable=g-direct-tensorflow-import + from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes @@ -27,7 +31,10 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.parallel_for.pfor import PFor +from tensorflow.python.ops.parallel_for.pfor import PForConfig from tensorflow.python.util import nest +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None): @@ -98,6 +105,9 @@ def _flatten_first_two_dims(x): return array_ops.reshape(x, new_shape) +PFOR_CONFIG_ARG = "pfor_config" + + def pfor(loop_fn, iters, parallel_iterations=None): """Equivalent to running `loop_fn` `iters` times and stacking the outputs. @@ -127,10 +137,11 @@ def pfor(loop_fn, iters, parallel_iterations=None): Args: loop_fn: A function that takes an int32 scalar tf.Tensor object representing - the iteration number, and returns a possibly nested structure of Tensor or - Operation objects. Note that if setting `parallel_iterations` argument to - something other than None, `loop_fn` may be called more than once during - graph construction. So it may need to avoid mutating global state. + the iteration number, and optionally a keyword argument `pfor_config` set + to a PForConfig object. It returns a possibly nested structure of Tensor + or Operation objects. Note that if setting `parallel_iterations` argument + to something other than None, `loop_fn` may be called more than once + during graph construction. So it may need to avoid mutating global state. iters: Number of iterations for which to run loop_fn. parallel_iterations: A knob to control how many iterations are vectorized and dispatched in parallel. 
The default value of None corresponds to @@ -151,12 +162,38 @@ def pfor(loop_fn, iters, parallel_iterations=None): return f() -def _pfor_impl(loop_fn, iters, parallel_iterations=None): +def _loop_fn_has_config(loop_fn): + """Test if `loop_fn` has a `pfor_config` argument.""" + if tf_inspect.isfunction(loop_fn): + argspec = tf_inspect.getargspec(loop_fn) + return PFOR_CONFIG_ARG in argspec.args + elif isinstance(loop_fn, functools.partial): + fn = loop_fn.func + argspec = tf_inspect.getargspec(fn) + return (PFOR_CONFIG_ARG in argspec.args and + PFOR_CONFIG_ARG not in loop_fn.keywords) + else: + loop_class = tf_decorator.unwrap(loop_fn)[1] + if not hasattr(loop_class, "__call__"): + raise ValueError("loop_fn object did not have a __call__ method") + argspec = tf_inspect.getargspec(loop_class.__call__) + return PFOR_CONFIG_ARG in argspec.args + + +def _pfor_impl(loop_fn, iters, parallel_iterations=None, pfor_config=None): """Implementation of pfor.""" + loop_fn_has_config = _loop_fn_has_config(loop_fn) existing_ops = set(ops.get_default_graph().get_operations()) with ops.name_scope("loop_body"): loop_var = array_ops.placeholder(dtypes.int32, shape=[]) - loop_fn_outputs = loop_fn(loop_var) + if loop_fn_has_config: + if pfor_config is None: + pfor_config = PForConfig() + pfor_config._set_iters(iters) # pylint: disable=protected-access + loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config}) + else: + assert pfor_config is None + loop_fn_outputs = loop_fn(loop_var) new_ops = set(ops.get_default_graph().get_operations()) - existing_ops iters = ops.convert_to_tensor(iters) if parallel_iterations is not None: @@ -169,18 +206,22 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None): parallel_iterations = None if parallel_iterations is None: with ops.name_scope("pfor"): - converter = PFor(loop_var, iters, new_ops) + converter = PFor(loop_var, iters, new_ops, pfor_config=pfor_config) outputs = [] for loop_fn_output in nest.flatten(loop_fn_outputs): outputs.append(converter.convert(loop_fn_output)) return nest.pack_sequence_as(loop_fn_outputs, outputs) else: + if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access + raise ValueError("Setting parallel_iterations currently unsupported if" + " reductions across iterations are performed.") num_tiled_iterations = iters // parallel_iterations num_remaining_iterations = iters % parallel_iterations # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside # a tf.function and extract the graph from there to vectorize it. 
with ops.name_scope("pfor_untiled"): - converter = PFor(loop_var, num_remaining_iterations, new_ops) + converter = PFor(loop_var, num_remaining_iterations, new_ops, + pfor_config=pfor_config) remaining_outputs = [] flattened_loop_fn_outputs = nest.flatten(loop_fn_outputs) for loop_fn_output in flattened_loop_fn_outputs: @@ -193,10 +234,14 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None): def tiled_loop_body(j): offset = j * parallel_iterations + num_remaining_iterations - def tiled_loop_fn(i): - return nest.flatten(loop_fn(i + offset)) + def tiled_loop_fn(i, pfor_config=None): + if loop_fn_has_config: + return nest.flatten(loop_fn(i + offset, pfor_config=pfor_config)) + else: + return nest.flatten(loop_fn(i + offset)) - return pfor(tiled_loop_fn, parallel_iterations) + return _pfor_impl( + tiled_loop_fn, parallel_iterations, pfor_config=pfor_config) tiled_outputs = for_loop(tiled_loop_body, loop_fn_dtypes, num_tiled_iterations, parallel_iterations=1) @@ -213,7 +258,3 @@ def _pfor_impl(loop_fn, iters, parallel_iterations=None): else: outputs = tiled_outputs return nest.pack_sequence_as(loop_fn_outputs, nest.flatten(outputs)) - - - - diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 1b1b336bd0e..ef877c35446 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -13,11 +13,13 @@ # limitations under the License. # ============================================================================== """Tests for pfor and for_loop.""" +# pylint: disable=g-direct-tensorflow-import from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import time from absl import flags @@ -100,6 +102,90 @@ class PForTest(PForTestCase): pfor_control_flow_ops.pfor(lambda i: 1, 8, parallel_iterations=1) +@test_util.run_all_in_graph_and_eager_modes +class ReductionTest(PForTestCase): + + def test_reduce_concat(self): + x = random_ops.random_uniform([8, 3]) + + def loop_fn(i, pfor_config): + x_i = array_ops.gather(x, i) + vectorized_value = pfor_config.reduce_concat(x_i) + mean_value = math_ops.reduce_mean(vectorized_value, axis=0) + return x_i - mean_value + + output = pfor_control_flow_ops.pfor(loop_fn, 8) + ans = x - math_ops.reduce_mean(x, axis=0) + output_val, ans_val = self.evaluate([output, ans]) + self.assertAllClose(ans_val, output_val) + + def test_reduce_mean(self): + x = random_ops.random_uniform([8, 3]) + + def loop_fn(i, pfor_config): + x_i = array_ops.gather(x, i) + return x_i - pfor_config.reduce_mean(x_i) + + output = pfor_control_flow_ops.pfor(loop_fn, 8) + ans = x - math_ops.reduce_mean(x, axis=0) + output_val, ans_val = self.evaluate([output, ans]) + self.assertAllClose(ans_val, output_val) + + def test_reduce_sum(self): + x = random_ops.random_uniform([8, 3]) + + def loop_fn(i, pfor_config): + x_i = array_ops.gather(x, i) + return x_i - pfor_config.reduce_sum(x_i) + + output = pfor_control_flow_ops.pfor(loop_fn, 8) + ans = x - math_ops.reduce_sum(x, axis=0) + output_val, ans_val = self.evaluate([output, ans]) + self.assertAllClose(ans_val, output_val) + + def test_reduce_class(self): + x = random_ops.random_uniform([8, 3]) + + class LoopFn(object): + + def __init__(self): + pass + + def __call__(self, i, pfor_config): + x_i = array_ops.gather(x, i) + return x_i - pfor_config.reduce_mean(x_i) + + output = pfor_control_flow_ops.pfor(LoopFn(), 8) + 
ans = x - math_ops.reduce_mean(x, axis=0) + output_val, ans_val = self.evaluate([output, ans]) + self.assertAllClose(ans_val, output_val) + + def test_reduce_functools_partial(self): + x = random_ops.random_uniform([8, 3]) + + def fn(i, pfor_config, dummy=None): + del dummy + x_i = array_ops.gather(x, i) + return x_i - pfor_config.reduce_mean(x_i) + + loop_fn = functools.partial(fn, dummy=1) + output = pfor_control_flow_ops.pfor(loop_fn, 8) + ans = x - math_ops.reduce_mean(x, axis=0) + output_val, ans_val = self.evaluate([output, ans]) + self.assertAllClose(ans_val, output_val) + + def test_parallel_iterations(self): + x = random_ops.random_uniform([8, 3]) + + def loop_fn(i, pfor_config): + x_i = array_ops.gather(x, i) + return pfor_config.reduce_sum(x_i) + + with self.assertRaisesRegexp( + ValueError, "parallel_iterations currently unsupported"): + pfor_control_flow_ops.pfor(loop_fn, 8, parallel_iterations=2) + + @test_util.run_all_in_graph_and_eager_modes class BitwiseTest(PForTestCase): @@ -965,6 +1051,26 @@ class Benchmarks(test.Benchmark): self._run(pfor_outputs, 100, name="pfor_rnn") self._run(tf_outputs, 100, name="tf_rnn") + def benchmark_reduction(self): + n = 1024 + with ops.Graph().as_default(): + x = random_ops.random_uniform([n, n]) + w = random_ops.random_uniform([n, n]) + + def loop_fn(i, pfor_config): + x_i = array_ops.gather(x, i) + return math_ops.reduce_sum( + math_ops.matmul(pfor_config.reduce_concat(x_i), w)) + + # Note that output_reduction will be tiled, so there may be some minor + # overheads compared to output_no_reduction. + output_reduction = pfor_control_flow_ops.pfor(loop_fn, n) + output_no_reduction = math_ops.reduce_sum(math_ops.matmul(x, w)) + # Benchmark to test that reduction does not add overhead and its output is + # treated as loop invariant. + self._run(output_reduction, 30, name="matmul_reduction") + self._run(output_no_reduction, 30, name="matmul_no_reduction") + class SparseTest(PForTestCase): diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index b9f7a0ffca5..d6ba04917ae 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """Compiled parallel-for loop.""" -# pylint: disable=missing-docstring +# pylint: disable=missing-docstring,g-direct-tensorflow-import from __future__ import absolute_import from __future__ import division @@ -21,6 +21,7 @@ from __future__ import print_function import collections +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -94,13 +95,15 @@ def _is_stateful_pfor_op(op): class WhileOp(object): """Object for storing state for converting the outputs of a while_loop.""" - def __init__(self, exit_node, pfor_ops): + def __init__(self, exit_node, pfor_ops, pfor_config): """Initializer. Args: exit_node: A tensor output from the while_loop. pfor_ops: list of ops inside the current pfor loop. + pfor_config: PForConfig object used while constructing loop body. 
""" + self._pfor_config = pfor_config self._pfor_ops = set(pfor_ops) self._pfor_op_ids = set([x._id for x in pfor_ops]) assert isinstance(exit_node, ops.Tensor) @@ -281,7 +284,8 @@ class WhileOp(object): loop_len, pfor_ops=self._pfor_ops, all_indices=indices, - all_indices_partitioned=cond_stacked) + all_indices_partitioned=cond_stacked, + pfor_config=self._pfor_config) # Map all inputs of Enter nodes in self._direct_enters to their converted # values. for enter in self._direct_enters: @@ -903,6 +907,86 @@ def _fallback_converter(pfor_input): return tuple([wrap(ta.concat(), True) for ta in ta_list]) +class PForConfig(object): + """A configuration object used to communicate with loop body function.""" + + def __init__(self): + # This may be set to the number of iterations. + self._maybe_iters = None + # Map from output placeholder to the unvectorized tensor. + self._reduce_concat_map = {} + # Reverse map of `self._reduce_concat_map`. + self._reverse_reduce_concat_map = {} + + def _has_reductions(self): + """True if some reductions where performed by loop body.""" + return len(self._reduce_concat_map) + + def _set_iters(self, iters): + """Set number of pfor iterations.""" + self._maybe_iters = iters + + # TODO(agarwal): handle reductions inside control flow constructs. + def reduce_concat(self, x): + """Performs a concat reduction on `x` across pfor iterations. + + Note that this currently may not work inside a control flow construct. + Args: + x: an unvectorized Tensor. + + Returns: + A Tensor that has rank one higher than `x`. The value is the vectorized + version of `x`, i.e. stacking the value of `x` across different pfor + iterations. + """ + assert not context.executing_eagerly() + assert isinstance(x, ops.Tensor) + if x not in self._reduce_concat_map: + out_shape = tensor_shape.TensorShape([self._maybe_iters]).concatenate( + x.shape) + with ops.control_dependencies([x]): + # Control dependency to make sure out is converted after x. + out = array_ops.placeholder(x.dtype, out_shape) + self._reduce_concat_map[out] = x + self._reverse_reduce_concat_map[x] = out + return out + else: + return self._reverse_reduce_concat_map[x] + + def reduce_mean(self, x): + """Performs a mean reduction on `x` across pfor iterations. + + Note that this currently may not work inside a control flow construct. + Args: + x: an unvectorized Tensor. + + Returns: + A Tensor that has same rank as `x`. The value is the mean of the values + of `x` across the pfor iterations. + """ + y = self.reduce_concat(x) + return math_ops.reduce_mean(y, axis=0) + + def reduce_sum(self, x): + """Performs a sum reduction on `x` across pfor iterations. + + Note that this currently may not work inside a control flow construct. + Args: + x: an unvectorized Tensor. + + Returns: + A Tensor that has same rank as `x`. The value is the sum of the values + of `x` across the pfor iterations. + """ + y = self.reduce_concat(x) + return math_ops.reduce_sum(y, axis=0) + + def _lookup_reduction(self, pl): + """Lookups Placeholder `pl` in the reduction map.""" + assert isinstance(pl, ops.Tensor) + return self._reduce_concat_map.get(pl, None) + + class PFor(object): """Implementation of rewrite of parallel-for loops. @@ -941,7 +1025,8 @@ class PFor(object): loop_len, pfor_ops, all_indices=None, - all_indices_partitioned=False): + all_indices_partitioned=False, + pfor_config=None): """Creates an object to rewrite a parallel-for loop. 
Args: @@ -958,6 +1043,7 @@ class PFor(object): all_indices_partitioned: If True, this object is being constructed from a control flow construct where not all the pfor iterations are guaranteed to be active. + pfor_config: PForConfig object used while constructing the loop body. """ assert isinstance(loop_var, ops.Tensor) assert loop_var.op.type == "Placeholder" @@ -976,6 +1062,7 @@ class PFor(object): self._conversion_map[loop_var] = wrap(self.all_indices, True) self._pfor_ops = set(pfor_ops) self._pfor_op_ids = set([x._id for x in pfor_ops]) + self._pfor_config = pfor_config def op_is_inside_loop(self, op): """True if op was created inside the pfor loop body.""" @@ -1113,7 +1200,8 @@ class PFor(object): is_while_loop = y_op.type == "Exit" if is_while_loop: - while_op = WhileOp(y, pfor_ops=self._pfor_ops) + while_op = WhileOp( + y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config) is_inside_loop = while_op.is_inside_loop # If all nodes in the while_loop graph were created inside the pfor, we # treat the whole loop subgraph as a single op (y_op) and try to convert @@ -1185,10 +1273,32 @@ class PFor(object): control_dependencies = [] if is_while_loop else converted_control_ops with ops.control_dependencies(control_dependencies), ops.name_scope( y_op.name + "/pfor/"): + # Op is a placeholder for a reduction. + if (self._pfor_config is not None and + self._pfor_config._lookup_reduction(y) is not None): + # Handle reductions. Map the placeholder to the unvectorized input + # that is being reduced. + reduction_input = self._pfor_config._lookup_reduction(y) + assert isinstance(reduction_input, ops.Tensor), reduction_input + # Tensor being reduced should already be converted due to a control + # dependency on the created placeholder. + # Note that in cases where reduction_input is in an outer context, one + # needs to locate the corresponding Enter node and use that to lookup + # the conversion. + # TODO(agarwal): handle reductions inside control flow constructs. + assert reduction_input in self._conversion_map, ( + "Unable to handle reduction of %s, possibly as it was used " + "inside a control flow construct. Note that reductions across " + "pfor iterations are currently not supported inside control flow " + "constructs." % reduction_input) + output = self._conversion_map[reduction_input] + # If original input is not stacked, we tile it. Also we always mark + # output as unstacked. + new_outputs = [wrap(self._unwrap_or_tile(output), False)] # None of the inputs and control inputs were converted. - if (not is_inside_loop or - (not is_stateful and not some_input_converted and - not some_control_input_converted)): + elif (not is_inside_loop or + (not is_stateful and not some_input_converted and + not some_control_input_converted)): if y == y_op: assert not isinstance(y_op, WhileOp) new_outputs = y_op
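For reference, a minimal usage sketch (not part of the patch; adapted from ReductionTest.test_reduce_mean in the test diff above) showing how a loop body opts into the new `pfor_config` keyword argument and performs a cross-iteration reduction. It assumes graph mode and the same internal TF imports used elsewhere in this patch.

# Illustrative example only: loop body requesting a reduction across pfor
# iterations via the new `pfor_config` argument.
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = random_ops.random_uniform([8, 3])

def loop_fn(i, pfor_config):
  # Each iteration gathers one row of `x` and subtracts the mean taken
  # across all pfor iterations, obtained through pfor_config.reduce_mean.
  x_i = array_ops.gather(x, i)
  return x_i - pfor_config.reduce_mean(x_i)

# Equivalent to x - math_ops.reduce_mean(x, axis=0), expressed per iteration.
# Note that combining reductions with parallel_iterations raises a ValueError,
# as exercised by ReductionTest.test_parallel_iterations above.
output = pfor_control_flow_ops.pfor(loop_fn, 8)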