pfor: add support for vectorizing "While" and "StatelessWhile" ops.

PiperOrigin-RevId: 309305421
Change-Id: I95caf455fb519fd7a9d736814d624089363cb7b8
A. Unique TensorFlower 2020-04-30 14:56:16 -07:00 committed by TensorFlower Gardener
parent e0701fcb76
commit 679da1ca0e
4 changed files with 605 additions and 17 deletions
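
For context (not part of this commit), a minimal sketch of the kind of loop this change lets pfor vectorize: a V2 while_loop whose trip count differs per pfor iteration, mirroring the new WhileV2Test.test_while case below. The public tf.* aliases and this exact pfor call are illustrative assumptions, not code from the diff.

import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = tf.random.uniform([3, 5])
lengths = tf.constant([4, 0, 2])

def loop_fn(i):
  # Row i of x and its per-iteration trip count.
  x_i = tf.gather(x, i)
  length_i = tf.gather(lengths, i)
  # Sum the first `length_i` entries of row i using a (v2) while_loop.
  return tf.while_loop(
      lambda j, _: j < length_i,
      lambda j, t: (j + 1, t + tf.gather(x_i, j)),
      [0, 0.])[1]

# With control flow v2 (the TF2 default), the "While"/"StatelessWhile" op inside
# loop_fn now gets a dedicated converter instead of relying on the while_loop
# fallback.
out = pfor_control_flow_ops.pfor(loop_fn, 3)  # float32 Tensor of shape [3]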


@@ -114,6 +114,7 @@ cuda_py_test(
":test_util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_v2_toggles",
"//tensorflow/python:gradients",
"//tensorflow/python:logging_ops",
"//tensorflow/python:parsing_ops",
@@ -138,6 +139,8 @@ cuda_py_test(
":test_util",
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_v2_toggles",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python/compiler/xla",


@@ -36,11 +36,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients as gradient_ops
@@ -79,7 +81,7 @@ class PForTest(PForTestCase):
x_i = array_ops.gather(x, i)
return nn.top_k(x_i)
with self.assertRaisesRegexp(ValueError, "No converter defined"):
with self.assertRaisesRegexp(ValueError, "No pfor vectorization"):
self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
@@ -987,8 +989,9 @@ class WhileV1Test(PForTestCase):
def loop_fn(_):
return control_flow_ops.while_loop(
lambda j, x: j < 4, lambda j, x:
(j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]
lambda j, x: j < 4,
lambda j, x: (j + 1, x + random_ops.random_uniform([])),
[0, 0.])[0]
self._test_loop_fn(loop_fn, 3)
@@ -996,8 +999,9 @@ class WhileV1Test(PForTestCase):
def test_while_unstacked_condition(self):
def loop_fn(i):
return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
(j + 1, x + i), [0, 0])
return control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + i), [0, 0])
self._test_loop_fn(loop_fn, 3)
@@ -1011,8 +1015,8 @@ class WhileV1Test(PForTestCase):
lengths_i = array_ops.gather(lengths, i)
_, total = control_flow_ops.while_loop(
lambda j, _: j < lengths_i, lambda j, t:
(j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
lambda j, _: j < lengths_i,
lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
return total
self._test_loop_fn(loop_fn, 3)
@@ -1202,6 +1206,143 @@ def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
return pfor_output, tf_output
@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):
def setUp(self):
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
control_flow_v2_toggles.enable_control_flow_v2()
super(WhileV2Test, self).setUp()
def tearDown(self):
if not self._enabled:
control_flow_v2_toggles.disable_control_flow_v2()
super(WhileV2Test, self).tearDown()
def test_while_outside_loop(self):
def _f():
return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
def loop_fn(i):
return _f() + i
self._test_loop_fn(loop_fn, 3)
def test_invariant_while(self):
def loop_fn(_):
return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
self._test_loop_fn(loop_fn, 3)
def test_invariant_while_with_control_dependency(self):
def loop_fn(i):
with ops.control_dependencies([i]):
return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
[0])
self._test_loop_fn(loop_fn, 3)
def test_while_with_stateful_ops(self):
def loop_fn(_):
j, _ = control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + random_ops.random_uniform([])),
[0, 0.])
return j
self._test_loop_fn(loop_fn, 3)
def test_while_with_variable(self):
v = resource_variable_ops.ResourceVariable(5.)
def loop_fn(_):
_, output = control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + v), [0, 0.])
return output
self._test_loop_fn(loop_fn, 3)
def test_while_unstacked_condition(self):
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + i), [0, 0])
self._test_loop_fn(loop_fn, 3)
def test_while(self):
x = random_ops.random_uniform([3, 5])
lengths = constant_op.constant([4, 0, 2])
def loop_fn(i):
x_i = array_ops.gather(x, i)
lengths_i = array_ops.gather(lengths, i)
return control_flow_ops.while_loop(
lambda j, _: j < lengths_i,
lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
self._test_loop_fn(loop_fn, 3)
def test_while_change_input_invariance(self):
# This tests cases where a loop invariant input to while has loop dependent
# operations applied to it inside the while body.
# It also tests inputs that are passed through.
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, *_: j < i,
lambda j, x, y, z, w: (j + 1, x + i, y + x, z, w),
[0,
constant_op.constant(0),
constant_op.constant(1),
i,
constant_op.constant(2)])
self._test_loop_fn(loop_fn, 3)
def test_while_shape_invariants(self):
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, *_: j < 4,
lambda j, x, y: (j + 1, x + i, y + 1),
[0, constant_op.constant([0, 1]), constant_op.constant([2, 3])],
shape_invariants=[None,
tensor_shape.TensorShape([2]),
tensor_shape.TensorShape([2])])
self._test_loop_fn(loop_fn, 3)
def test_while_jacobian(self):
# Note that we wrap the code below in a tf.function since we don't want the
# while_loop call to be evaluated eagerly using a python loop.
@def_function.function
def _f(x, y, use_pfor):
# out = x @ y @ y @ y @ y, where @ is the matmul operator.
_, out = control_flow_ops.while_loop(
lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
[0, x])
def loop_fn(i):
out_i = array_ops.gather(out, i, axis=1)
grad = gradient_ops.gradients(out_i, x)
return array_ops.reshape(grad[0], [-1])
if use_pfor:
return pfor_control_flow_ops.pfor(loop_fn, iters=3)
else:
return pfor_control_flow_ops.for_loop(loop_fn, iters=3,
loop_fn_dtypes=out.dtype)
x = constant_op.constant(np.random.uniform(size=(1, 3)))
y = constant_op.constant(np.random.uniform(size=(3, 3)))
self.assertAllClose(_f(x, y, True), _f(x, y, False))
@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class StatelessIfTest(PForTestCase):
@@ -1383,8 +1524,9 @@ class Benchmarks(test.Benchmark):
with ops.Graph().as_default():
def loop_fn(i):
_, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
(t + 1, x + i), [0, 0])
_, s = control_flow_ops.while_loop(lambda t, x: t < i,
lambda t, x: (t + 1, x + i),
[0, 0])
return s
iters = 50


@@ -1473,14 +1473,19 @@ class PFor(object):
else:
converter = _pfor_converter_registry.get(y_op.type, None)
if converter is None:
if self._fallback_to_while_loop:
has_variant_outputs = any(x.dtype == dtypes.variant for x in
y_op.outputs)
if self._fallback_to_while_loop and not has_variant_outputs:
converter = _fallback_converter
else:
raise ValueError("No converter defined for %s\n%s\ninputs: %s. "
"\nEither add a converter or "
"enable fallback_to_while_loop "
"option to pfor, which may run slower" %
(y_op.type, y_op, converted_inputs))
message = ("No pfor vectorization defined for %s\n"
"%s\n"
"inputs: %s. " %
(y_op.type, y_op, converted_inputs))
if not self._fallback_to_while_loop:
message += ("Consider enabling the fallback_to_while_loop "
"option to pfor, which may run slower.")
raise ValueError(message)
# TODO(rachelim): Handle the case where some inputs are sparsely
# stacked. We should only call the converter if it supports handling
# those inputs.
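
For reference (also not part of this commit), roughly how a caller sees the behavior controlled by fallback_to_while_loop above, using an op without a registered converter as in the updated PForTest case earlier in this diff; treating fallback_to_while_loop as a keyword argument of pfor is an assumption here.

import tensorflow as tf
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = tf.random.uniform([3, 4])

def loop_fn(i):
  x_i = tf.gather(x, i)
  # Assume TopKV2 has no registered pfor converter (as in the test above).
  return tf.math.top_k(x_i)

# With the fallback disabled, conversion raises
#   ValueError: No pfor vectorization defined for TopKV2 ...
# pfor_control_flow_ops.pfor(loop_fn, 3, fallback_to_while_loop=False)

# With the fallback enabled (the default), the unconverted op runs inside a
# while_loop instead, which may be slower but still produces stacked outputs.
values, indices = pfor_control_flow_ops.pfor(loop_fn, 3, fallback_to_while_loop=True)
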
@@ -3727,9 +3732,12 @@ def _outputs_for_branch(func_name, indices, pfor_input, inputs):
return stacked_outputs
# TODO(agarwal): Currently the converted code aggressively tiles loop variant
# outputs from the then/else branches. Instead, it could do so only if at least
# one of the branch outputs is loop variant.
@RegisterPFor("StatelessIf")
@RegisterPFor("If")
def _convert_stateless_if(pfor_input):
def _convert_if(pfor_input):
cond, cond_stacked, _ = pfor_input.input(0)
inputs = pfor_input.inputs[1:]
then_branch = pfor_input.get_attr("then_branch")
@@ -3780,6 +3788,322 @@ def _convert_stateless_if(pfor_input):
return [wrap(t, True) for t in outputs]
class WhileV2(object):
"""Object for vectorizing V2 while_loop op."""
def __init__(self, pfor_input):
self._pfor_input = pfor_input
self._pfor = pfor_input.pfor
cond_func_name = pfor_input.get_attr("cond").name
self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes(
cond_func_name))
body_func_name = pfor_input.get_attr("body").name
self._body_func = pfor_input.op.graph._get_function(compat.as_bytes(
body_func_name))
if self._cond_func is None or self._body_func is None:
raise ValueError("Error extracting cond and body functions for op %s." % (
self._pfor_input.op))
# Indices of inputs that are passed unchanged through the while loop body.
# Typically these are tensors captured from outside the body context.
self._body_pass_through_indices = set()
for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs,
self._body_func.graph.outputs)):
if id(inp) == id(out):
self._body_pass_through_indices.add(i)
self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations")
def _output_shapes(self):
# Calculate output shapes for the vectorized loop. These will be used as
# shape_invariants. Merges shape inference outputs with the `output_shapes`
# attribute of the op.
output_shapes = [out.shape for out in self._pfor_input.op.outputs]
shapes = self._pfor_input.get_attr("output_shapes")
if not shapes:
shapes = [tensor_shape.TensorShape(None) for _ in output_shapes]
else:
shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
for i, shape in enumerate(shapes):
shape = shape.merge_with(output_shapes[i])
if self._pfor_input.input(i).is_stacked:
shape = tensor_shape.TensorShape([None]).concatenate(shape)
output_shapes[i] = shape
assert len(output_shapes) == self._pfor_input.num_inputs
return output_shapes
def _init_values(self):
"""Create arguments passed to converted while_loop."""
loop_len = self._pfor.loop_len_vector[0]
inputs = []
# TensorArrays for outputs of converted while loop
output_tas = []
with ops.name_scope("while_init"):
for inp in self._pfor_input.inputs:
inputs.append(inp.t)
output_tas.append(tensor_array_ops.TensorArray(inp.t.dtype, loop_len))
# See documentation for __call__ for the structure of init_values.
return [True, self._pfor.all_indices] + inputs + output_tas
def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
"""Handles case when condition is pfor loop invariant."""
# Note that all iterations end together. So we don't need to partition the
# inputs.
not_all_done = array_ops.reshape(conditions, [])
return not_all_done, indices, inputs, output_tas
def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
output_tas):
"""Handles case when condition is pfor loop dependent."""
# Compute if all iterations are done.
not_all_done = math_ops.reduce_any(conditions)
conditions_int = math_ops.cast(conditions, dtypes.int32)
# Partition the indices.
done_indices, new_indices = data_flow_ops.dynamic_partition(
indices, conditions_int, 2)
new_inputs = []
new_output_tas = []
for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
pass_through = i in self._body_pass_through_indices
# Partition the inputs.
if stacked:
done_inp, new_inp = data_flow_ops.dynamic_partition(
inp, conditions_int, 2)
else:
if not pass_through:
done_inp = _stack(inp, [array_ops.size(done_indices)]).t
new_inp = inp
new_inputs.append(new_inp)
out_ta = output_tas[i]
if not pass_through:
# Note that done_indices can be empty. done_inp should also be empty
# in that case.
out_ta = out_ta.scatter(done_indices, done_inp)
new_output_tas.append(out_ta)
assert len(new_output_tas) == len(output_tas)
assert len(new_inputs) == len(inputs)
return not_all_done, new_indices, new_inputs, new_output_tas
def _process_body(self, inputs_stacked, new_indices, cond_stacked,
new_inputs, not_all_done):
"""Convert the body function."""
# This is used to store the indices of inputs to the while op that need to
# be stacked. This stacking may be needed in cases where the input to the
# while_loop is loop_invariant but the corresponding output is not.
mismatching_stacked_indices = []
def true_fn():
"""Converts the body function for all but last iteration."""
wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
zip(new_inputs, inputs_stacked)]
# Note the iterative process below used to figure out loop invariance.
# Here we iterate on the vectorization process till a fixed point. The
# issue is that the while body can take pfor loop invariant inputs but
# return loop variant outputs. For any loop variant output, the
# corresponding input then has to be made loop variant as well (since
# subsequent while iterations will need to see loop variant values).
# However, once we make a new input loop variant, we might make other
# outputs loop variant too. Hence we need to iterate till we reach a
# fixed point.
while True:
body_pfor = PFor(
loop_var=self._pfor.loop_var,
loop_len=array_ops.size(new_indices),
pfor_ops=self._body_func.graph.get_operations(),
fallback_to_while_loop=self._pfor.fallback_to_while_loop,
all_indices=new_indices,
all_indices_partitioned=(self._pfor.all_indices_partitioned or
cond_stacked),
pfor_config=self._pfor.pfor_config)
stacking_mismatch = False
outputs = _convert_function_call(self._body_func,
body_pfor,
wrapped_inputs)
for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
if out.is_stacked != inp.is_stacked:
stacking_mismatch = True
mismatching_stacked_indices.append(i)
wrapped_inputs[i] = _stack(inp.t, [array_ops.size(new_indices)])
if not stacking_mismatch:
if mismatching_stacked_indices:
# We needed to stack some inputs. This code will be abandoned and
# should not get executed. Hence we simply return `new_inputs` to
# make sure the graph construction code completes.
with ops.control_dependencies([
control_flow_ops.Assert(
False, ["pfor ERROR: this branch should never execute"])]):
return [array_ops.identity(x) for x in new_inputs]
else:
return [out.t for out in outputs]
# If all are done, we simply return `new_inputs`. Else we need to run the
# body function.
return control_flow_ops.cond(
not_all_done,
true_fn,
lambda: list(new_inputs)), mismatching_stacked_indices
def __call__(self):
"""Converter for the V2 while_loop.
The conversion of a while_loop is another while_loop.
The arguments to this converted while_loop are as follows:
not_all_done: Boolean scalar Tensor which is True as long as at least one
pfor iteration is still not done.
indices: int32 1-D Tensor storing the ids of the pfor iterations that are
not done.
args: Remaining arguments. These can be divided into 2 categories:
- The first set of arguments correspond one-to-one to the inputs to the
unvectorized while_loop.
- The second set are TensorArrays, corresponding one-to-one to each output
of the unvectorized while_loop. Each TensorArray has `PFor.loop_len`
elements, i.e. the number of pfor iterations. At the end, the i'th
element of each TensorArray will contain the output computed by the i'th
iteration of pfor. Note that elements can be written into these
TensorArrays in any order, depending on when the corresponding pfor
iteration is done.
In each iteration, the while_loop body recomputes the condition for all
active pfor iterations to see which of them are now done. It then partitions
all the inputs and passes them along to the converted body. Values for all
the iterations that are done are written to TensorArrays indexed by the pfor
iteration number. When all iterations are done, the TensorArrays are stacked
to get the final value.
Returns:
List of converted outputs.
"""
output_shapes = self._output_shapes()
# Note that we use these lists as a hack since we need the `body` to compute
# these values during construction of the while_loop graph.
cond_is_stacked = [None]
indices_to_stack = []
def cond(not_all_done, *_):
return not_all_done
def body(not_all_done, indices, *args):
# See documentation for __call__ for the structure of *args.
num_inputs = self._pfor_input.num_inputs
inputs = args[:num_inputs]
output_tas = args[num_inputs:]
inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
assert len(inputs) >= len(output_tas)
assert len(inputs) == len(inputs_stacked)
# Convert condition
with ops.name_scope("while_cond"):
# Note that we set all_indices_partitioned to True here. At this point
# we don't know if indices will be partitioned. Hence we use the
# conservative value.
cond_pfor = PFor(
loop_var=self._pfor.loop_var,
loop_len=array_ops.size(indices),
pfor_ops=self._cond_func.graph.get_operations(),
fallback_to_while_loop=self._pfor.fallback_to_while_loop,
all_indices=indices,
all_indices_partitioned=True,
pfor_config=self._pfor.pfor_config)
wrapped_inputs = [wrap(inp, stacked) for inp, stacked
in zip(inputs, inputs_stacked)]
conditions, cond_stacked, _ = _convert_function_call(
self._cond_func,
cond_pfor,
wrapped_inputs)[0]
cond_is_stacked[0] = cond_stacked
# Recompute the new condition, write outputs of done iterations, and
# partition the inputs if needed.
if not cond_stacked:
(not_all_done, new_indices, new_inputs,
new_output_tas) = self._process_cond_unstacked(conditions, indices,
inputs, output_tas)
else:
(not_all_done, new_indices, new_inputs,
new_output_tas) = self._process_cond_stacked(conditions, indices,
inputs, inputs_stacked,
output_tas)
# Convert body
with ops.name_scope("while_body"):
# Compute the outputs from the body.
new_outputs, mismatching_stacked_indices = self._process_body(
inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done)
indices_to_stack[:] = mismatching_stacked_indices
for i, new_output in enumerate(new_outputs):
new_output.set_shape(output_shapes[i])
new_args = ([not_all_done, new_indices] + new_outputs +
list(new_output_tas))
return tuple(new_args)
# Note that we run the code below in a function since we might abandon the
# generated code in cases where the conversion dictates that some inputs be
# further stacked. Hence we run the graph construction using
# `get_concrete_function` and avoid calling the constructed function if not
# needed.
@def_function.function
def while_fn():
# Create init_values that will be passed to the while_loop.
init_values = self._init_values()
ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
self._pfor_input.outputs]
shape_invariants = (
[tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
+ output_shapes + ta_shape_invariants)
while_outputs = control_flow_ops.while_loop(
cond, body, init_values,
shape_invariants=shape_invariants,
parallel_iterations=self._parallel_iterations)
if indices_to_stack:
# This function will be abandoned.
return while_outputs
else:
num_inputs = self._pfor_input.num_inputs
new_inputs = while_outputs[2:num_inputs+2]
output_tas = while_outputs[num_inputs+2:]
assert cond_is_stacked[0] is not None
outputs = []
for i, inp in enumerate(new_inputs):
if cond_is_stacked[0]:
if i in self._body_pass_through_indices:
outputs.append(init_values[i + 2])
else:
ta = output_tas[i]
outputs.append(ta.stack())
else:
outputs.append(inp)
return outputs
_ = while_fn.get_concrete_function()
if indices_to_stack:
# Need to abandon the current conversion, stack some inputs and restart.
self._pfor_input.stack_inputs(stack_indices=indices_to_stack)
# Note that this call will recurse at most one time. The first call will
# do the required stacking, based on the iterative procedure in
# _process_body, and the next invocation to __call__ should not need to do
# any more stacking.
# We invoke `self()` here as a way to discard any corrupted state.
return self()
else:
outputs = while_fn()
wrapped_outputs = []
for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
if i not in self._body_pass_through_indices and cond_is_stacked[0]:
wrapped_outputs.append(wrap(out, True))
else:
wrapped_outputs.append(wrap(out, inp.is_stacked))
return wrapped_outputs
@RegisterPFor("StatelessWhile")
@RegisterPFor("While")
def _convert_while(pfor_input):
converter = WhileV2(pfor_input)
return converter()
# spectral_ops


@@ -20,13 +20,17 @@ from __future__ import division
from __future__ import print_function
from tensorflow.compiler.tf2xla.python import xla as xla_ops
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.platform import test
@@ -118,5 +122,120 @@ class PForTest(PForTestCase):
self.assertAllClose(ans_val, output_val)
if __name__ == '__main__':
def _make_unstacked(cond, body, pfor_config):
def _cond(*args):
return math_ops.reduce_any(pfor_config.reduce_concat(args[0]))
def _body(*args):
not_done = args[0]
args = args[1:]
not_done = math_ops.logical_and(not_done, cond(*args))
outputs = body(*args)
return (not_done,) + tuple(
array_ops.where_v2(not_done, x, y) for x, y in zip(outputs, args))
return _cond, _body
@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):
def setUp(self):
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
control_flow_v2_toggles.enable_control_flow_v2()
super(WhileV2Test, self).setUp()
def tearDown(self):
if not self._enabled:
control_flow_v2_toggles.disable_control_flow_v2()
super(WhileV2Test, self).tearDown()
def _test_loop_fn(self, loop_fn, iters, force_xla=False):
def f():
return pfor_control_flow_ops.pfor(loop_fn, iters)
@def_function.function
def jit_f():
with jit.experimental_jit_scope():
return f()
out = f()
jit_out = jit_f()
self.run_and_assert_equal(out, jit_out)
# TODO(agarwal): The following may complain about uncompilable nodes. Hence
# these are currently not enabled for all tests.
if force_xla:
out_exp_compile_f = def_function.function(experimental_compile=True)(f)()
self.run_and_assert_equal(out, out_exp_compile_f)
out_xla_compile_f = xla.compile(f, inputs=[])
self.run_and_assert_equal(out, out_xla_compile_f)
def test_stateless_while(self):
x = random_ops.random_uniform([3, 5])
lengths = constant_op.constant([4, 0, 2])
def loop_fn(i):
x_i = array_ops.gather(x, i)
lengths_i = array_ops.gather(lengths, i)
return control_flow_ops.while_loop(
lambda j, _: j < lengths_i,
lambda j, t: (j + 1, t + array_ops.gather(x_i, j)),
[0, 0.])
self._test_loop_fn(loop_fn, 3)
def test_while_with_variable(self):
v = resource_variable_ops.ResourceVariable(5.)
def loop_fn(_):
_, output = control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + v),
[0, 0.])
return output
self._test_loop_fn(loop_fn, 3)
def test_while_unstacked_condition(self):
def loop_fn(i):
return control_flow_ops.while_loop(
lambda j, x: j < 4,
lambda j, x: (j + 1, x + i), [0, 0])
self._test_loop_fn(loop_fn, 3, force_xla=True)
def test_while_force_unstacked_condition(self):
# The while_loop in this setup is similar to the one in test_stateless_while
# whose condition is loop variant. However, here we wrap the cond and body of
# the loop in a way that makes the while_loop condition pfor loop invariant.
# This allows XLA compilation to work since the vectorized code no longer
# needs to perform dynamic partitioning of the inputs.
x = random_ops.random_uniform([3, 5])
lengths = constant_op.constant([4, 0, 2])
def loop_fn(i, pfor_config):
x_i = array_ops.gather(x, i)
lengths_i = array_ops.gather(lengths, i)
def _cond(j, _):
return j < lengths_i
def _body(j, t):
return (j + 1, t + array_ops.gather(x_i, j))
cond, body = _make_unstacked(_cond, _body, pfor_config)
return control_flow_ops.while_loop(
cond,
body,
[True, 0, 0.])
# b/155430349: Enabling force_xla=True triggers a CHECK in debug mode.
self._test_loop_fn(loop_fn, 3, force_xla=False)
if __name__ == "__main__":
test.main()