diff --git a/tensorflow/python/ops/parallel_for/BUILD b/tensorflow/python/ops/parallel_for/BUILD
index 88ddf7a7ec8..3f75ec7e581 100644
--- a/tensorflow/python/ops/parallel_for/BUILD
+++ b/tensorflow/python/ops/parallel_for/BUILD
@@ -114,6 +114,7 @@ cuda_py_test(
         ":test_util",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client_testlib",
+        "//tensorflow/python:control_flow_v2_toggles",
        "//tensorflow/python:gradients",
         "//tensorflow/python:logging_ops",
         "//tensorflow/python:parsing_ops",
@@ -138,6 +139,8 @@ cuda_py_test(
         ":test_util",
         "//tensorflow/compiler/tf2xla/python:xla",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:random_ops",
         "//tensorflow/python/compiler/xla",
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index cb84f4a16b0..11380b2dac2 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -36,11 +36,13 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import bitwise_ops
 from tensorflow.python.ops import cond_v2
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradients as gradient_ops
@@ -79,7 +81,7 @@ class PForTest(PForTestCase):
       x_i = array_ops.gather(x, i)
       return nn.top_k(x_i)
 
-    with self.assertRaisesRegexp(ValueError, "No converter defined"):
+    with self.assertRaisesRegexp(ValueError, "No pfor vectorization"):
       self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
     self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
 
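The renamed error message surfaces through the public API as well. As a sketch of the behavior this test pins down, using the public `tf.vectorized_map` wrapper and assuming a build of this vintage where `top_k` has no registered converter and `vectorized_map` exposes the `fallback_to_while_loop` flag:

```python
import tensorflow as tf

x = tf.random.uniform([3, 5])

def loop_fn(x_i):
  # top_k has no pfor converter registered, so vectorization must either
  # fall back to a while_loop or raise.
  return tf.math.top_k(x_i)

# With the fallback enabled (the default), the map still runs, just without
# cross-iteration vectorization of the unsupported op.
values, indices = tf.vectorized_map(loop_fn, x)

# With the fallback disabled, an unconvertible op raises a ValueError
# mentioning "No pfor vectorization".
try:
  tf.vectorized_map(loop_fn, x, fallback_to_while_loop=False)
except ValueError as e:
  print(e)
```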
@@ -987,8 +989,9 @@ class WhileV1Test(PForTestCase):
 
     def loop_fn(_):
       return control_flow_ops.while_loop(
-          lambda j, x: j < 4, lambda j, x:
-          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]
+          lambda j, x: j < 4,
+          lambda j, x: (j + 1, x + random_ops.random_uniform([])),
+          [0, 0.])[0]
 
     self._test_loop_fn(loop_fn, 3)
 
@@ -996,8 +999,9 @@ class WhileV1Test(PForTestCase):
   def test_while_unstacked_condition(self):
 
     def loop_fn(i):
-      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
-                                         (j + 1, x + i), [0, 0])
+      return control_flow_ops.while_loop(
+          lambda j, x: j < 4,
+          lambda j, x: (j + 1, x + i), [0, 0])
 
     self._test_loop_fn(loop_fn, 3)
 
@@ -1011,8 +1015,8 @@ class WhileV1Test(PForTestCase):
       lengths_i = array_ops.gather(lengths, i)
 
       _, total = control_flow_ops.while_loop(
-          lambda j, _: j < lengths_i, lambda j, t:
-          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
+          lambda j, _: j < lengths_i,
+          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
       return total
 
     self._test_loop_fn(loop_fn, 3)
@@ -1202,6 +1206,143 @@ def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
   return pfor_output, tf_output
 
 
+@test_util.run_all_in_graph_and_eager_modes
+class WhileV2Test(PForTestCase):
+
+  def setUp(self):
+    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
+    control_flow_v2_toggles.enable_control_flow_v2()
+    super(WhileV2Test, self).setUp()
+
+  def tearDown(self):
+    if not self._enabled:
+      control_flow_v2_toggles.disable_control_flow_v2()
+    super(WhileV2Test, self).tearDown()
+
+  def test_while_outside_loop(self):
+
+    def _f():
+      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
+
+    def loop_fn(i):
+      return _f() + i
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_invariant_while(self):
+
+    def loop_fn(_):
+      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_invariant_while_with_control_dependency(self):
+
+    def loop_fn(i):
+      with ops.control_dependencies([i]):
+        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
+                                           [0])
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_with_stateful_ops(self):
+
+    def loop_fn(_):
+      j, _ = control_flow_ops.while_loop(
+          lambda j, x: j < 4,
+          lambda j, x: (j + 1, x + random_ops.random_uniform([])),
+          [0, 0.])
+      return j
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_with_variable(self):
+    v = resource_variable_ops.ResourceVariable(5.)
+
+    def loop_fn(_):
+      _, output = control_flow_ops.while_loop(
+          lambda j, x: j < 4,
+          lambda j, x: (j + 1, x + v), [0, 0.])
+      return output
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_unstacked_condition(self):
+
+    def loop_fn(i):
+      return control_flow_ops.while_loop(
+          lambda j, x: j < 4,
+          lambda j, x: (j + 1, x + i), [0, 0])
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while(self):
+    x = random_ops.random_uniform([3, 5])
+    lengths = constant_op.constant([4, 0, 2])
+
+    def loop_fn(i):
+      x_i = array_ops.gather(x, i)
+      lengths_i = array_ops.gather(lengths, i)
+
+      return control_flow_ops.while_loop(
+          lambda j, _: j < lengths_i,
+          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_change_input_invariance(self):
+    # This tests cases where a loop invariant input to while has loop dependent
+    # operations applied to it inside the while body.
+    # It also tests inputs that are passed through.
+
+    def loop_fn(i):
+      return control_flow_ops.while_loop(
+          lambda j, *_: j < i,
+          lambda j, x, y, z, w: (j + 1, x + i, y + x, z, w),
+          [0,
+           constant_op.constant(0),
+           constant_op.constant(1),
+           i,
+           constant_op.constant(2)])
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_shape_invariants(self):
+
+    def loop_fn(i):
+      return control_flow_ops.while_loop(
+          lambda j, *_: j < 4,
+          lambda j, x, y: (j + 1, x + i, y + 1),
+          [0, constant_op.constant([0, 1]), constant_op.constant([2, 3])],
+          shape_invariants=[None,
+                            tensor_shape.TensorShape([2]),
+                            tensor_shape.TensorShape([2])])
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_jacobian(self):
+    # Note that we wrap the code below in a tf.function since we don't want the
+    # while_loop call to be evaluated eagerly using a python loop.
+    @def_function.function
+    def _f(x, y, use_pfor):
+      # out = x @ y @ y @ y @ y, where @ is matmul operator.
+      _, out = control_flow_ops.while_loop(
+          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
+          [0, x])
+
+      def loop_fn(i):
+        out_i = array_ops.gather(out, i, axis=1)
+        grad = gradient_ops.gradients(out_i, x)
+        return array_ops.reshape(grad[0], [-1])
+
+      if use_pfor:
+        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
+      else:
+        return pfor_control_flow_ops.for_loop(loop_fn, iters=3,
+                                              loop_fn_dtypes=out.dtype)
+
+    x = constant_op.constant(np.random.uniform(size=(1, 3)))
+    y = constant_op.constant(np.random.uniform(size=(3, 3)))
+    self.assertAllClose(_f(x, y, True), _f(x, y, False))
+
+
 @test_util.run_all_in_graph_and_eager_modes
 @test_util.with_control_flow_v2
 class StatelessIfTest(PForTestCase):
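The Jacobian pattern in `test_while_jacobian` mirrors what `tf.GradientTape.jacobian` does through the public API, which by default vectorizes the per-element gradient calls with pfor. A rough public-API equivalent of the computation above (shapes illustrative):

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0]])  # shape (1, 3)
y = tf.random.uniform([3, 3])

with tf.GradientTape() as tape:
  tape.watch(x)
  out = x
  for _ in range(4):  # out = x @ y @ y @ y @ y
    out = tf.matmul(out, y)

# jacobian() vectorizes the per-element gradient computation with pfor,
# falling back to a while_loop where no converter is registered.
jac = tape.jacobian(out, x)  # shape (1, 3, 1, 3)
```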
@@ -1383,8 +1524,9 @@ class Benchmarks(test.Benchmark):
     with ops.Graph().as_default():
 
       def loop_fn(i):
-        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
-                                           (t + 1, x + i), [0, 0])
+        _, s = control_flow_ops.while_loop(lambda t, x: t < i,
+                                           lambda t, x: (t + 1, x + i),
+                                           [0, 0])
         return s
 
       iters = 50
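Before the converter internals in pfor.py below, here is the user-visible capability this change adds, sketched with public TF2 APIs: a `tf.while_loop` with a data-dependent trip count inside `tf.vectorized_map`, which previously had no While/StatelessWhile converter under v2 control flow:

```python
import tensorflow as tf

x = tf.random.uniform([3, 5])
lengths = tf.constant([4, 0, 2])

def loop_fn(args):
  x_i, length_i = args
  # Per-row prefix sum with a data-dependent trip count. Under v2 control
  # flow this lowers to a While/StatelessWhile op, which the new WhileV2
  # converter vectorizes across the three rows.
  _, total = tf.while_loop(
      lambda j, _: j < length_i,
      lambda j, t: (j + 1, t + x_i[j]),
      [0, 0.])
  return total

totals = tf.vectorized_map(loop_fn, (x, lengths))  # shape (3,)
```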
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index 5c21620dc66..bece477e754 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -1473,14 +1473,19 @@ class PFor(object):
         else:
           converter = _pfor_converter_registry.get(y_op.type, None)
         if converter is None:
-          if self._fallback_to_while_loop:
+          has_variant_outputs = any(x.dtype == dtypes.variant for x in
+                                    y_op.outputs)
+          if self._fallback_to_while_loop and not has_variant_outputs:
             converter = _fallback_converter
           else:
-            raise ValueError("No converter defined for %s\n%s\ninputs: %s. "
-                             "\nEither add a converter or "
-                             "enable fallback_to_while_loop "
-                             "option to pfor, which may run slower" %
-                             (y_op.type, y_op, converted_inputs))
+            message = ("No pfor vectorization defined for %s\n"
+                       "%s\n"
+                       "inputs: %s. " %
+                       (y_op.type, y_op, converted_inputs))
+            if not self._fallback_to_while_loop:
+              message += ("Consider enabling the fallback_to_while_loop "
+                          "option to pfor, which may run slower.")
+            raise ValueError(message)
         # TODO(rachelim): Handle the case where some inputs are sparsely
         # stacked. We should only call the converter if it supports handling
         # those inputs.
@@ -3727,9 +3732,12 @@ def _outputs_for_branch(func_name, indices, pfor_input, inputs):
   return stacked_outputs
 
 
+# TODO(agarwal): Currently the converted code aggressively tiles loop variant
+# outputs from the then/else branches. Instead, it could do so only if at least
+# one of the branch outputs is loop variant.
 @RegisterPFor("StatelessIf")
 @RegisterPFor("If")
-def _convert_stateless_if(pfor_input):
+def _convert_if(pfor_input):
   cond, cond_stacked, _ = pfor_input.input(0)
   inputs = pfor_input.inputs[1:]
   then_branch = pfor_input.get_attr("then_branch")
@@ -3780,6 +3788,322 @@ def _convert_stateless_if(pfor_input):
   return [wrap(t, True) for t in outputs]
 
 
+class WhileV2(object):
+  """Object for vectorizing V2 while_loop op."""
+
+  def __init__(self, pfor_input):
+    self._pfor_input = pfor_input
+    self._pfor = pfor_input.pfor
+    cond_func_name = pfor_input.get_attr("cond").name
+    self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes(
+        cond_func_name))
+    body_func_name = pfor_input.get_attr("body").name
+    self._body_func = pfor_input.op.graph._get_function(compat.as_bytes(
+        body_func_name))
+    if self._cond_func is None or self._body_func is None:
+      raise ValueError("Error extracting cond and body functions for op %s." %
+                       self._pfor_input.op)
+    # Indices of inputs that are passed unchanged through the while loop body.
+    # Typically these are tensors captured from outside the body context.
+    self._body_pass_through_indices = set()
+    for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs,
+                                       self._body_func.graph.outputs)):
+      if id(inp) == id(out):
+        self._body_pass_through_indices.add(i)
+    self._parallel_iterations = self._pfor_input.get_attr(
+        "parallel_iterations")
+
+  def _output_shapes(self):
+    # Calculate output shape for vectorized loop. This will be used as
+    # shape_invariant. Merges shape inference outputs with the
+    # `output_shapes` attribute of the op.
+    output_shapes = [out.shape for out in self._pfor_input.op.outputs]
+    shapes = self._pfor_input.get_attr("output_shapes")
+    if not shapes:
+      shapes = [tensor_shape.TensorShape(None) for _ in output_shapes]
+    else:
+      shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
+    for i, shape in enumerate(shapes):
+      shape = shape.merge_with(output_shapes[i])
+      if self._pfor_input.input(i).is_stacked:
+        shape = tensor_shape.TensorShape([None]).concatenate(shape)
+      output_shapes[i] = shape
+    assert len(output_shapes) == self._pfor_input.num_inputs
+    return output_shapes
+
+  def _init_values(self):
+    """Create arguments passed to converted while_loop."""
+    loop_len = self._pfor.loop_len_vector[0]
+    inputs = []
+    # TensorArrays for outputs of converted while loop
+    output_tas = []
+
+    with ops.name_scope("while_init"):
+      for inp in self._pfor_input.inputs:
+        inputs.append(inp.t)
+        output_tas.append(tensor_array_ops.TensorArray(inp.t.dtype, loop_len))
+    # See documentation for __call__ for the structure of init_values.
+    return [True, self._pfor.all_indices] + inputs + output_tas
+
+  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
+    """Handles case when condition is pfor loop invariant."""
+    # Note that all iterations end together. So we don't need to partition the
+    # inputs.
+    not_all_done = array_ops.reshape(conditions, [])
+    return not_all_done, indices, inputs, output_tas
+
+  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
+                            output_tas):
+    """Handles case when condition is pfor loop dependent."""
+    # Compute if all iterations are done.
+    not_all_done = math_ops.reduce_any(conditions)
+    conditions_int = math_ops.cast(conditions, dtypes.int32)
+    # Partition the indices.
+    done_indices, new_indices = data_flow_ops.dynamic_partition(
+        indices, conditions_int, 2)
+
+    new_inputs = []
+    new_output_tas = []
+    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
+      pass_through = i in self._body_pass_through_indices
+      # Partition the inputs.
+      if stacked:
+        done_inp, new_inp = data_flow_ops.dynamic_partition(
+            inp, conditions_int, 2)
+      else:
+        if not pass_through:
+          done_inp = _stack(inp, [array_ops.size(done_indices)]).t
+        new_inp = inp
+
+      new_inputs.append(new_inp)
+      out_ta = output_tas[i]
+      if not pass_through:
+        # Note that done_indices can be empty. done_inp should also be empty
+        # in that case.
+        out_ta = out_ta.scatter(done_indices, done_inp)
+      new_output_tas.append(out_ta)
+
+    assert len(new_output_tas) == len(output_tas)
+    assert len(new_inputs) == len(inputs)
+    return not_all_done, new_indices, new_inputs, new_output_tas
+
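A standalone sketch of the partitioning mechanic `_process_cond_stacked` relies on, with made-up values: `dynamic_partition` splits the iteration ids (and each stacked loop variable) into a "done" group, which gets scattered into the output TensorArrays, and an "active" group, which continues through the loop:

```python
import tensorflow as tf

indices = tf.constant([0, 1, 2, 3])
# conditions[i] is True while pfor iteration i still wants to run.
conditions = tf.constant([True, False, True, False])
conditions_int = tf.cast(conditions, tf.int32)

# Partition 0 collects finished iterations, partition 1 the active ones.
done_indices, active_indices = tf.dynamic_partition(
    indices, conditions_int, 2)
# done_indices == [1, 3], active_indices == [0, 2]

# A stacked loop variable is partitioned the same way, so finished rows can
# be scattered into the output TensorArray while active rows continue.
values = tf.constant([10., 11., 12., 13.])
done_vals, active_vals = tf.dynamic_partition(values, conditions_int, 2)
# done_vals == [11., 13.], active_vals == [10., 12.]
```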
+  def _process_body(self, inputs_stacked, new_indices, cond_stacked,
+                    new_inputs, not_all_done):
+    """Convert the body function."""
+    # This is used to store the indices of inputs to the while op that need to
+    # be stacked. This stacking may be needed in cases where the input to the
+    # while_loop is loop invariant but the corresponding output is not.
+    mismatching_stacked_indices = []
+
+    def true_fn():
+      """Converts the body function for all but last iteration."""
+      wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
+                        zip(new_inputs, inputs_stacked)]
+      # Note the iterative process below to figure out loop invariance.
+      # Here we iterate on the vectorization process till a fixed point. The
+      # issue is that the while body can take pfor loop invariant inputs but
+      # return loop variant outputs. For any loop variant output, the
+      # corresponding input has to be then made loop variant (since subsequent
+      # while iterations will need to see loop variant values).
+      # However, once we make a new input loop variant, we might make other
+      # outputs loop variant. Hence we need to iterate till we reach a fixed
+      # point.
+      while True:
+        body_pfor = PFor(
+            loop_var=self._pfor.loop_var,
+            loop_len=array_ops.size(new_indices),
+            pfor_ops=self._body_func.graph.get_operations(),
+            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
+            all_indices=new_indices,
+            all_indices_partitioned=(self._pfor.all_indices_partitioned or
+                                     cond_stacked),
+            pfor_config=self._pfor.pfor_config)
+        stacking_mismatch = False
+        outputs = _convert_function_call(self._body_func,
+                                         body_pfor,
+                                         wrapped_inputs)
+        for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
+          if out.is_stacked != inp.is_stacked:
+            stacking_mismatch = True
+            mismatching_stacked_indices.append(i)
+            wrapped_inputs[i] = _stack(inp.t, [array_ops.size(new_indices)])
+        if not stacking_mismatch:
+          if mismatching_stacked_indices:
+            # We needed to stack some inputs. This code will be abandoned and
+            # should not get executed. Hence we simply return `new_inputs` to
+            # make sure the graph construction code completes.
+            with ops.control_dependencies([
+                control_flow_ops.Assert(
+                    False, ["pfor ERROR: this branch should never execute"])]):
+              return [array_ops.identity(x) for x in new_inputs]
+          else:
+            return [out.t for out in outputs]
+
+    # If all are done, we simply return `new_inputs`. Else we need to run the
+    # body function.
+    return control_flow_ops.cond(
+        not_all_done,
+        true_fn,
+        lambda: list(new_inputs)), mismatching_stacked_indices
+
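A concrete case that exercises this fixed-point iteration, sketched with public APIs: the carried value `acc` is loop invariant on entry (the same initial value for every pfor iteration), but the body adds the per-iteration value `i`, so the first conversion pass discovers a stacking mismatch and the input must be promoted to stacked before re-converting:

```python
import tensorflow as tf

def loop_fn(i):
  # `acc` starts at 0 for every pfor iteration: loop invariant on entry.
  # The body computes acc + i, which differs per iteration, so the converter
  # must promote `acc` from unstacked to stacked and re-convert the body.
  _, acc = tf.while_loop(
      lambda j, acc: j < 4,
      lambda j, acc: (j + 1, acc + i),
      [0, 0])
  return acc

out = tf.vectorized_map(loop_fn, tf.constant([0, 1, 2]))  # [0, 4, 8]
```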
+  def __call__(self):
+    """Converter for the V2 while_loop.
+
+    The conversion of a while_loop is another while_loop.
+
+    The arguments to this converted while_loop are as follows:
+    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
+      are done.
+    indices: int32 1-D Tensor storing the ids of the pfor iterations that are
+      not done.
+    args: Remaining arguments. These can be divided into 2 categories:
+      - The first set of arguments corresponds one-to-one to the inputs to the
+        unvectorized while_loop.
+      - The second set are TensorArrays, corresponding one-to-one to each
+        output of the unvectorized while_loop. Each TensorArray has
+        `PFor.loop_len` elements, i.e. the number of pfor iterations. At the
+        end, the i'th element of each TensorArray will contain the output
+        computed by the i'th iteration of pfor. Note that elements can be
+        written into these tensor arrays in any order, depending on when the
+        corresponding pfor iteration is done.
+    In each iteration, the while_loop body recomputes the condition for all
+    active pfor iterations to see which of them are now done. It then
+    partitions all the inputs and passes them along to the converted body.
+    Values for all the iterations that are done are written to TensorArrays
+    indexed by the pfor iteration number. When all iterations are done, the
+    TensorArrays are stacked to get the final value.
+
+    Returns:
+      List of converted outputs.
+    """
+    output_shapes = self._output_shapes()
+    # Note that we use these lists as a hack since we need the `body` to
+    # compute these values during construction of the while_loop graph.
+    cond_is_stacked = [None]
+    indices_to_stack = []
+
+    def cond(not_all_done, *_):
+      return not_all_done
+
+    def body(not_all_done, indices, *args):
+      # See documentation for __call__ for the structure of *args.
+      num_inputs = self._pfor_input.num_inputs
+      inputs = args[:num_inputs]
+      output_tas = args[num_inputs:]
+      inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
+      assert len(inputs) >= len(output_tas)
+      assert len(inputs) == len(inputs_stacked)
+      # Convert condition
+      with ops.name_scope("while_cond"):
+        # Note that we set all_indices_partitioned to True here. At this point
+        # we don't know if indices will be partitioned. Hence we use the
+        # conservative value.
+        cond_pfor = PFor(
+            loop_var=self._pfor.loop_var,
+            loop_len=array_ops.size(indices),
+            pfor_ops=self._cond_func.graph.get_operations(),
+            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
+            all_indices=indices,
+            all_indices_partitioned=True,
+            pfor_config=self._pfor.pfor_config)
+
+        wrapped_inputs = [wrap(inp, stacked) for inp, stacked
+                          in zip(inputs, inputs_stacked)]
+        conditions, cond_stacked, _ = _convert_function_call(
+            self._cond_func,
+            cond_pfor,
+            wrapped_inputs)[0]
+        cond_is_stacked[0] = cond_stacked
+
+      # Recompute the new condition, write outputs of done iterations, and
+      # partition the inputs if needed.
+      if not cond_stacked:
+        (not_all_done, new_indices, new_inputs,
+         new_output_tas) = self._process_cond_unstacked(conditions, indices,
+                                                        inputs, output_tas)
+      else:
+        (not_all_done, new_indices, new_inputs,
+         new_output_tas) = self._process_cond_stacked(conditions, indices,
+                                                      inputs, inputs_stacked,
+                                                      output_tas)
+      # Convert body
+      with ops.name_scope("while_body"):
+        # Compute the outputs from the body.
+        new_outputs, mismatching_stacked_indices = self._process_body(
+            inputs_stacked, new_indices, cond_stacked, new_inputs,
+            not_all_done)
+
+      indices_to_stack[:] = mismatching_stacked_indices
+      for i, new_output in enumerate(new_outputs):
+        new_output.set_shape(output_shapes[i])
+      new_args = ([not_all_done, new_indices] + new_outputs +
+                  list(new_output_tas))
+      return tuple(new_args)
+
+    # Note that we run the code below in a function since we might abandon the
+    # generated code in cases where the conversion dictates that some inputs
+    # be further stacked. Hence we run the graph construction using
+    # `get_concrete_function` and avoid calling the constructed function if
+    # not needed.
+    @def_function.function
+    def while_fn():
+      # Create init_values that will be passed to the while_loop.
+      init_values = self._init_values()
+      ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
+                             self._pfor_input.outputs]
+      shape_invariants = (
+          [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
+          + output_shapes + ta_shape_invariants)
+
+      while_outputs = control_flow_ops.while_loop(
+          cond, body, init_values,
+          shape_invariants=shape_invariants,
+          parallel_iterations=self._parallel_iterations)
+      if indices_to_stack:
+        # This function will be abandoned.
+        return while_outputs
+      else:
+        num_inputs = self._pfor_input.num_inputs
+        new_inputs = while_outputs[2:num_inputs+2]
+        output_tas = while_outputs[num_inputs+2:]
+        assert cond_is_stacked[0] is not None
+        outputs = []
+        for i, inp in enumerate(new_inputs):
+          if cond_is_stacked[0]:
+            if i in self._body_pass_through_indices:
+              outputs.append(init_values[i + 2])
+            else:
+              ta = output_tas[i]
+              outputs.append(ta.stack())
+          else:
+            outputs.append(inp)
+        return outputs
+
+    _ = while_fn.get_concrete_function()
+    if indices_to_stack:
+      # Need to abandon the current conversion, stack some inputs and restart.
+      self._pfor_input.stack_inputs(stack_indices=indices_to_stack)
+      # Note that this call will recurse at most one time. The first call will
+      # do the required stacking, based on the iterative procedure in
+      # _process_body, and the next invocation to __call__ should not need to
+      # do any more stacking.
+      # We invoke `self()` here as a way to discard any corrupted state.
+      return self()
+    else:
+      outputs = while_fn()
+      wrapped_outputs = []
+      for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
+        if i not in self._body_pass_through_indices and cond_is_stacked[0]:
+          wrapped_outputs.append(wrap(out, True))
+        else:
+          wrapped_outputs.append(wrap(out, inp.is_stacked))
+      return wrapped_outputs
+
+
+@RegisterPFor("StatelessWhile")
+@RegisterPFor("While")
+def _convert_while(pfor_input):
+  converter = WhileV2(pfor_input)
+  return converter()
+
+
 # spectral_ops
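The output-collection scheme described in the `__call__` docstring rests on two `TensorArray` methods: `scatter` writes finished iterations at their original pfor ids, in whatever order they complete, and `stack` reassembles everything in order at the end. A minimal standalone sketch:

```python
import tensorflow as tf

# One slot per pfor iteration.
ta = tf.TensorArray(tf.float32, size=4)

# Iterations 1 and 3 finish first; write their results at their original ids.
ta = ta.scatter([1, 3], tf.constant([11., 13.]))
# Iterations 0 and 2 finish later.
ta = ta.scatter([0, 2], tf.constant([10., 12.]))

# stack() recovers the outputs in pfor iteration order.
result = ta.stack()  # [10., 11., 12., 13.]
```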
diff --git a/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py
index 9d0fac6db4c..b1762e2f55f 100644
--- a/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py
@@ -20,13 +20,17 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.compiler.tf2xla.python import xla as xla_ops
+from tensorflow.python.compiler.xla import jit
 from tensorflow.python.compiler.xla import xla
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
 from tensorflow.python.ops.parallel_for.test_util import PForTestCase
 from tensorflow.python.platform import test
@@ -118,5 +122,120 @@ class PForTest(PForTestCase):
     self.assertAllClose(ans_val, output_val)
 
 
-if __name__ == '__main__':
+def _make_unstacked(cond, body, pfor_config):
+
+  def _cond(*args):
+    return math_ops.reduce_any(pfor_config.reduce_concat(args[0]))
+
+  def _body(*args):
+    not_done = args[0]
+    args = args[1:]
+    not_done = math_ops.logical_and(not_done, cond(*args))
+    outputs = body(*args)
+    return (not_done,) + tuple(
+        array_ops.where_v2(not_done, x, y) for x, y in zip(outputs, args))
+
+  return _cond, _body
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class WhileV2Test(PForTestCase):
+
+  def setUp(self):
+    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
+    control_flow_v2_toggles.enable_control_flow_v2()
+    super(WhileV2Test, self).setUp()
+
+  def tearDown(self):
+    if not self._enabled:
+      control_flow_v2_toggles.disable_control_flow_v2()
+    super(WhileV2Test, self).tearDown()
+
+  def _test_loop_fn(self, loop_fn, iters, force_xla=False):
+
+    def f():
+      return pfor_control_flow_ops.pfor(loop_fn, iters)
+
+    @def_function.function
+    def jit_f():
+      with jit.experimental_jit_scope():
+        return f()
+
+    out = f()
+    jit_out = jit_f()
+    self.run_and_assert_equal(out, jit_out)
+    # TODO(agarwal): The following may complain about uncompilable nodes.
+    # Hence these are currently not enabled for all tests.
+    if force_xla:
+      out_exp_compile_f = def_function.function(experimental_compile=True)(f)()
+      self.run_and_assert_equal(out, out_exp_compile_f)
+      out_xla_compile_f = xla.compile(f, inputs=[])
+      self.run_and_assert_equal(out, out_xla_compile_f)
+
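For reference, the two compilation paths `_test_loop_fn` exercises correspond roughly to the following public APIs (a sketch, assuming a TF version of this era where `experimental_compile` is the spelling that was later renamed `jit_compile`):

```python
import tensorflow as tf

def f():
  return tf.vectorized_map(lambda v: v * 2.0, tf.constant([1.0, 2.0, 3.0]))

# Auto-clustering: ops inside the scope are grouped into XLA clusters.
@tf.function
def jit_f():
  with tf.xla.experimental.jit_scope():
    return f()

# Whole-function XLA compilation (later renamed jit_compile=True).
compiled_f = tf.function(f, experimental_compile=True)

assert all(jit_f().numpy() == compiled_f().numpy())
```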
+  def test_stateless_while(self):
+    x = random_ops.random_uniform([3, 5])
+    lengths = constant_op.constant([4, 0, 2])
+
+    def loop_fn(i):
+      x_i = array_ops.gather(x, i)
+      lengths_i = array_ops.gather(lengths, i)
+
+      return control_flow_ops.while_loop(
+          lambda j, _: j < lengths_i,
+          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)),
+          [0, 0.])
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_with_variable(self):
+    v = resource_variable_ops.ResourceVariable(5.)
+
+    def loop_fn(_):
+      _, output = control_flow_ops.while_loop(
+          lambda j, x: j < 4,
+          lambda j, x: (j + 1, x + v),
+          [0, 0.])
+      return output
+
+    self._test_loop_fn(loop_fn, 3)
+
+  def test_while_unstacked_condition(self):
+
+    def loop_fn(i):
+      return control_flow_ops.while_loop(
+          lambda j, x: j < 4,
+          lambda j, x: (j + 1, x + i), [0, 0])
+
+    self._test_loop_fn(loop_fn, 3, force_xla=True)
+
+  def test_while_force_unstacked_condition(self):
+    # The while_loop in this setup is similar to the one in
+    # test_stateless_while whose condition is loop variant. However, here we
+    # wrap the cond and body of the loop in a way that makes the while_loop
+    # condition pfor loop invariant. This allows xla compilation to work since
+    # the vectorized code no longer needs to perform dynamic partitioning of
+    # the inputs.
+    x = random_ops.random_uniform([3, 5])
+    lengths = constant_op.constant([4, 0, 2])
+
+    def loop_fn(i, pfor_config):
+      x_i = array_ops.gather(x, i)
+      lengths_i = array_ops.gather(lengths, i)
+
+      def _cond(j, _):
+        return j < lengths_i
+
+      def _body(j, t):
+        return (j + 1, t + array_ops.gather(x_i, j))
+
+      cond, body = _make_unstacked(_cond, _body, pfor_config)
+      return control_flow_ops.while_loop(
+          cond,
+          body,
+          [True, 0, 0.])
+
+    # b/155430349: Enabling force_xla=True triggers a CHECK in debug mode.
+    self._test_loop_fn(loop_fn, 3, force_xla=False)
+
+
+if __name__ == "__main__":
   test.main()
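A conceptual sketch of what the `_make_unstacked` wrapper buys: all pfor iterations step in lockstep until every one is done, with `where_v2` freezing the state of finished iterations, so the loop condition itself is no longer per-iteration. The values below are made up to mirror `test_while_force_unstacked_condition`, and the eager Python loop stands in for the wrapped `while_loop`:

```python
import tensorflow as tf

# Three loop instances, each summing x[b, :lengths[b]].
x = tf.constant([[1., 2., 3., 4., 0.],
                 [9., 9., 9., 9., 9.],
                 [5., 6., 0., 0., 0.]])
lengths = tf.constant([4, 0, 2])

j = tf.zeros([3], tf.int32)
t = tf.zeros([3])
not_done = j < lengths
# Eager Python loop for demonstration only: any instance still running keeps
# the whole batch stepping; finished instances are frozen by tf.where.
while bool(tf.reduce_any(not_done)):
  gathered = tf.gather(x, tf.minimum(j, 4), batch_dims=1)
  t = tf.where(not_done, t + gathered, t)
  j = tf.where(not_done, j + 1, j)
  not_done = j < lengths

print(t)  # [10., 0., 11.]
```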