pfor: add support for vectorizing "While" and "StatelessWhile" ops.
PiperOrigin-RevId: 309305421
Change-Id: I95caf455fb519fd7a9d736814d624089363cb7b8

commit 679da1ca0e
parent e0701fcb76
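
As a usage sketch of what this enables (adapted from the new tests in this change; the tensors, shapes and loop bounds are illustrative, and control flow v2 must be enabled so that while_loop lowers to a "While"/"StatelessWhile" op): a while_loop whose trip count varies per pfor iteration can now be vectorized directly instead of requiring a converter-less fallback.

# Illustrative sketch only; loop_fn mirrors test_while / test_stateless_while below.
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = random_ops.random_uniform([3, 5])
lengths = constant_op.constant([4, 0, 2])

def loop_fn(i):
  x_i = array_ops.gather(x, i)
  lengths_i = array_ops.gather(lengths, i)
  # Sum the first `lengths_i` entries of row i. The trip count differs per
  # pfor iteration, so the loop condition is loop variant.
  _, total = control_flow_ops.while_loop(
      lambda j, _: j < lengths_i,
      lambda j, t: (j + 1, t + array_ops.gather(x_i, j)),
      [0, 0.])
  return total

# All three rows are handled by a single vectorized while_loop.
out = pfor_control_flow_ops.pfor(loop_fn, 3)

Outputs of iterations that finish early are written to per-iteration TensorArrays and stacked at the end, as described in the WhileV2 converter added below.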
@@ -114,6 +114,7 @@ cuda_py_test(
        ":test_util",
        "//tensorflow/core:protos_all_py",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:control_flow_v2_toggles",
        "//tensorflow/python:gradients",
        "//tensorflow/python:logging_ops",
        "//tensorflow/python:parsing_ops",
@@ -138,6 +139,8 @@ cuda_py_test(
        ":test_util",
        "//tensorflow/compiler/tf2xla/python:xla",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_v2_toggles",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:random_ops",
        "//tensorflow/python/compiler/xla",
@@ -36,11 +36,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients as gradient_ops
@@ -79,7 +81,7 @@ class PForTest(PForTestCase):
      x_i = array_ops.gather(x, i)
      return nn.top_k(x_i)

    with self.assertRaisesRegexp(ValueError, "No converter defined"):
    with self.assertRaisesRegexp(ValueError, "No pfor vectorization"):
      self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=False)
    self._test_loop_fn(loop_fn, 3, fallback_to_while_loop=True)
@@ -987,8 +989,9 @@ class WhileV1Test(PForTestCase):

    def loop_fn(_):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4, lambda j, x:
          (j + 1, x + random_ops.random_uniform([])), [0, 0.])[0]
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + random_ops.random_uniform([])),
          [0, 0.])[0]

    self._test_loop_fn(loop_fn, 3)
@@ -996,8 +999,9 @@ class WhileV1Test(PForTestCase):
  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(lambda j, x: j < 4, lambda j, x:
                                         (j + 1, x + i), [0, 0])
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)
@@ -1011,8 +1015,8 @@ class WhileV1Test(PForTestCase):
      lengths_i = array_ops.gather(lengths, i)

      _, total = control_flow_ops.while_loop(
          lambda j, _: j < lengths_i, lambda j, t:
          (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])
      return total

    self._test_loop_fn(loop_fn, 3)
@@ -1202,6 +1206,143 @@ def create_dynamic_lstm(cell_fn, batch_size, state_size, max_steps):
  return pfor_output, tf_output


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def test_while_outside_loop(self):
    def _f():
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    def loop_fn(i):
      return _f() + i

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while(self):

    def loop_fn(_):
      return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])

    self._test_loop_fn(loop_fn, 3)

  def test_invariant_while_with_control_dependency(self):

    def loop_fn(i):
      with ops.control_dependencies([i]):
        return control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1,
                                           [0])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_stateful_ops(self):

    def loop_fn(_):
      j, _ = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + random_ops.random_uniform([])),
          [0, 0.])
      return j

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + v), [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3)

  def test_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)), [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_change_input_invariance(self):
    # This tests cases where a loop invariant input to while has loop dependent
    # operations applied to it inside the while body.
    # It also tests inputs that are passed through.
    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < i,
          lambda j, x, y, z, w: (j + 1, x + i, y + x, z, w),
          [0,
           constant_op.constant(0),
           constant_op.constant(1),
           i,
           constant_op.constant(2)])

    self._test_loop_fn(loop_fn, 3)

  def test_while_shape_invariants(self):
    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, *_: j < 4,
          lambda j, x, y: (j + 1, x + i, y + 1),
          [0, constant_op.constant([0, 1]), constant_op.constant([2, 3])],
          shape_invariants=[None,
                            tensor_shape.TensorShape([2]),
                            tensor_shape.TensorShape([2])])

    self._test_loop_fn(loop_fn, 3)

  def test_while_jacobian(self):
    # Note that we wrap the code below in a tf.function since we don't want the
    # while_loop call to be evaluated eagerly using a python loop.
    @def_function.function
    def _f(x, y, use_pfor):
      # out = x @ y @ y @ y @ y, where @ is matmul operator.
      _, out = control_flow_ops.while_loop(
          lambda i, _: i < 4, lambda i, out: (i + 1, math_ops.matmul(out, y)),
          [0, x])

      def loop_fn(i):
        out_i = array_ops.gather(out, i, axis=1)
        grad = gradient_ops.gradients(out_i, x)
        return array_ops.reshape(grad[0], [-1])

      if use_pfor:
        return pfor_control_flow_ops.pfor(loop_fn, iters=3)
      else:
        return pfor_control_flow_ops.for_loop(loop_fn, iters=3,
                                              loop_fn_dtypes=out.dtype)

    x = constant_op.constant(np.random.uniform(size=(1, 3)))
    y = constant_op.constant(np.random.uniform(size=(3, 3)))
    self.assertAllClose(_f(x, y, True), _f(x, y, False))


@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class StatelessIfTest(PForTestCase):
@@ -1383,8 +1524,9 @@ class Benchmarks(test.Benchmark):
    with ops.Graph().as_default():

      def loop_fn(i):
        _, s = control_flow_ops.while_loop(lambda t, x: t < i, lambda t, x:
                                           (t + 1, x + i), [0, 0])
        _, s = control_flow_ops.while_loop(lambda t, x: t < i,
                                           lambda t, x: (t + 1, x + i),
                                           [0, 0])
        return s

      iters = 50
@@ -1473,14 +1473,19 @@ class PFor(object):
        else:
          converter = _pfor_converter_registry.get(y_op.type, None)
          if converter is None:
            if self._fallback_to_while_loop:
            has_variant_outputs = any(x.dtype == dtypes.variant for x in
                                      y_op.outputs)
            if self._fallback_to_while_loop and not has_variant_outputs:
              converter = _fallback_converter
            else:
              raise ValueError("No converter defined for %s\n%s\ninputs: %s. "
                               "\nEither add a converter or "
                               "enable fallback_to_while_loop "
                               "option to pfor, which may run slower" %
                               (y_op.type, y_op, converted_inputs))
              message = ("No pfor vectorization defined for %s\n"
                         "%s\n"
                         "inputs: %s. " %
                         (y_op.type, y_op, converted_inputs))
              if not self._fallback_to_while_loop:
                message += ("Consider enabling the fallback_to_while_loop "
                            "option to pfor, which may run slower.")
              raise ValueError(message)
          # TODO(rachelim): Handle the case where some inputs are sparsely
          # stacked. We should only call the converter if it supports handling
          # those inputs.
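
For illustration, a hedged sketch of the behavior this message points at (the op and names mirror the updated PForTest case above; pfor's fallback_to_while_loop keyword is assumed from how the tests pass it through): when an op in the loop body has no registered vectorizer, pfor can run the body under a tf.while_loop instead of raising.

# Sketch only, mirroring the PForTest case above that expects this error.
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops

x = constant_op.constant([[3., 1., 2.], [4., 6., 5.]])

def loop_fn(i):
  x_i = array_ops.gather(x, i)
  # nn.top_k has no registered pfor converter, per the updated test above.
  return nn.top_k(x_i)

# With the fallback disabled this raises the "No pfor vectorization defined"
# error built here; with it enabled the body runs under a tf.while_loop.
out = pfor_control_flow_ops.pfor(loop_fn, 2, fallback_to_while_loop=True)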
@@ -3727,9 +3732,12 @@ def _outputs_for_branch(func_name, indices, pfor_input, inputs):
  return stacked_outputs


# TODO(agarwal): Currently the converted code aggressively tiles loop variant
# outputs from the then/else branches. Instead, it could do so only if at least
# one of the branch outputs is loop variant.
@RegisterPFor("StatelessIf")
@RegisterPFor("If")
def _convert_stateless_if(pfor_input):
def _convert_if(pfor_input):
  cond, cond_stacked, _ = pfor_input.input(0)
  inputs = pfor_input.inputs[1:]
  then_branch = pfor_input.get_attr("then_branch")
@@ -3780,6 +3788,322 @@ def _convert_stateless_if(pfor_input):
  return [wrap(t, True) for t in outputs]


class WhileV2(object):
  """Object for vectorizing V2 while_loop op."""

  def __init__(self, pfor_input):
    self._pfor_input = pfor_input
    self._pfor = pfor_input.pfor
    cond_func_name = pfor_input.get_attr("cond").name
    self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes(
        cond_func_name))
    body_func_name = pfor_input.get_attr("body").name
    self._body_func = pfor_input.op.graph._get_function(compat.as_bytes(
        body_func_name))
    if self._cond_func is None or self._body_func is None:
      raise ValueError("Error extracting cond and body functions for op %s." % (
          self._pfor_input.op))
    # Indices of inputs that are passed unchanged through the while loop body.
    # Typically these are tensors captured from outside the body context.
    self._body_pass_through_indices = set()
    for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs,
                                       self._body_func.graph.outputs)):
      if id(inp) == id(out):
        self._body_pass_through_indices.add(i)
    self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations")

  def _output_shapes(self):
    # Calculate output shape for vectorized loop. This will be used as
    # shape_invariant. Merges shape inference outputs with the `output_shapes`
    # attribute of the op.
    output_shapes = [out.shape for out in self._pfor_input.op.outputs]
    shapes = self._pfor_input.get_attr("output_shapes")
    if not shapes:
      shapes = [tensor_shape.TensorShape(None) for _ in output_shapes]
    else:
      shapes = [tensor_shape.TensorShape(shape) for shape in shapes]
    for i, shape in enumerate(shapes):
      shape = shape.merge_with(output_shapes[i])
      if self._pfor_input.input(i).is_stacked:
        shape = tensor_shape.TensorShape([None]).concatenate(shape)
      output_shapes[i] = shape
    assert len(output_shapes) == self._pfor_input.num_inputs
    return output_shapes

  def _init_values(self):
    """Create arguments passed to converted while_loop."""
    loop_len = self._pfor.loop_len_vector[0]
    inputs = []
    # TensorArrays for outputs of converted while loop
    output_tas = []

    with ops.name_scope("while_init"):
      for inp in self._pfor_input.inputs:
        inputs.append(inp.t)
        output_tas.append(tensor_array_ops.TensorArray(inp.t.dtype, loop_len))
    # See documentation for __call__ for the structure of init_values.
    return [True, self._pfor.all_indices] + inputs + output_tas

  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
    """Handles case when condition is pfor loop invariant."""
    # Note that all iterations end together. So we don't need to partition the
    # inputs.
    not_all_done = array_ops.reshape(conditions, [])
    return not_all_done, indices, inputs, output_tas

  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
                            output_tas):
    """Handles case when condition is pfor loop dependent."""
    # Compute if all iterations are done.
    not_all_done = math_ops.reduce_any(conditions)
    conditions_int = math_ops.cast(conditions, dtypes.int32)
    # Partition the indices.
    done_indices, new_indices = data_flow_ops.dynamic_partition(
        indices, conditions_int, 2)

    new_inputs = []
    new_output_tas = []
    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
      pass_through = i in self._body_pass_through_indices
      # Partition the inputs.
      if stacked:
        done_inp, new_inp = data_flow_ops.dynamic_partition(
            inp, conditions_int, 2)
      else:
        if not pass_through:
          done_inp = _stack(inp, [array_ops.size(done_indices)]).t
        new_inp = inp

      new_inputs.append(new_inp)
      out_ta = output_tas[i]
      if not pass_through:
        # Note that done_indices can be empty. done_inp should also be empty
        # in that case.
        out_ta = out_ta.scatter(done_indices, done_inp)
      new_output_tas.append(out_ta)

    assert len(new_output_tas) == len(output_tas)
    assert len(new_inputs) == len(inputs)
    return not_all_done, new_indices, new_inputs, new_output_tas

  def _process_body(self, inputs_stacked, new_indices, cond_stacked,
                    new_inputs, not_all_done):
    """Convert the body function."""
    # This is used to store the indices of inputs to the while op that need to
    # be stacked. This stacking may be needed in cases where the input to the
    # while_loop is loop_invariant but the corresponding output is not.
    mismatching_stacked_indices = []

    def true_fn():
      """Converts the body function for all but last iteration."""
      wrapped_inputs = [wrap(inp, stacked) for inp, stacked in
                        zip(new_inputs, inputs_stacked)]
      # Note the iterative process below to figure out loop invariance.
      # Here we iterate on vectorization process till a fixed point. The issue
      # is that the while body can take pfor loop invariant inputs but return
      # loop variant outputs. For any loop variant output, the corresponding
      # input has to be then made loop variant (since subsequent while
      # iterations will need to see loop variant values).
      # However once we make a new input loop variant, we might make other
      # outputs loop variant. Hence we need to iterate till we get fixed point.
      while True:
        body_pfor = PFor(
            loop_var=self._pfor.loop_var,
            loop_len=array_ops.size(new_indices),
            pfor_ops=self._body_func.graph.get_operations(),
            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
            all_indices=new_indices,
            all_indices_partitioned=(self._pfor.all_indices_partitioned or
                                     cond_stacked),
            pfor_config=self._pfor.pfor_config)
        stacking_mismatch = False
        outputs = _convert_function_call(self._body_func,
                                         body_pfor,
                                         wrapped_inputs)
        for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)):
          if out.is_stacked != inp.is_stacked:
            stacking_mismatch = True
            mismatching_stacked_indices.append(i)
            wrapped_inputs[i] = _stack(inp.t, [array_ops.size(new_indices)])
        if not stacking_mismatch:
          if mismatching_stacked_indices:
            # We needed to stack some inputs. This code will be abandoned and
            # should not get executed. Hence we simply return `new_inputs` to
            # make sure the graph construction code completes.
            with ops.control_dependencies([
                control_flow_ops.Assert(
                    False, ["pfor ERROR: this branch should never execute"])]):
              return [array_ops.identity(x) for x in new_inputs]
          else:
            return [out.t for out in outputs]

    # If all are done, we simply return `new_inputs`. Else we need to run the
    # body function.
    return control_flow_ops.cond(
        not_all_done,
        true_fn,
        lambda: list(new_inputs)), mismatching_stacked_indices

  def __call__(self):
    """Converter for the V2 while_loop.

    The conversion of a while_loop is another while_loop.

    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the id of the pfor iterations that are not
      done.
    args: Remaining arguments. These can be divided into 2 categories:
      - The first set of arguments correspond one-to-one to the inputs to the
        unvectorized while_loop.
      - The second set are TensorArrays, corresponding one-to-one to each output
        of the unvectorized while_loop. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the i'th
        iteration of pfor. Note that elements can be written into these tensor
        arrays in any order, depending on when the corresponding pfor iteration
        is done.

    In each iteration, the while_loop body recomputes the condition for all
    active pfor iterations to see which of them are now done. It then partitions
    all the inputs and passes them along to the converted body. Values for all
    the iterations that are done are written to TensorArrays indexed by the pfor
    iteration number. When all iterations are done, the TensorArrays are stacked
    to get the final value.

    Returns:
      List of converted outputs.
    """
    output_shapes = self._output_shapes()
    # Note that we use these lists as a hack since we need the `body` to compute
    # these values during construction of the while_loop graph.
    cond_is_stacked = [None]
    indices_to_stack = []

    def cond(not_all_done, *_):
      return not_all_done

    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
      num_inputs = self._pfor_input.num_inputs
      inputs = args[:num_inputs]
      output_tas = args[num_inputs:]
      inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs]
      assert len(inputs) >= len(output_tas)
      assert len(inputs) == len(inputs_stacked)
      # Convert condition
      with ops.name_scope("while_cond"):
        # Note that we set all_indices_partitioned to True here. At this point
        # we don't know if indices will be partitioned. Hence we use the
        # conservative value.
        cond_pfor = PFor(
            loop_var=self._pfor.loop_var,
            loop_len=array_ops.size(indices),
            pfor_ops=self._cond_func.graph.get_operations(),
            fallback_to_while_loop=self._pfor.fallback_to_while_loop,
            all_indices=indices,
            all_indices_partitioned=True,
            pfor_config=self._pfor.pfor_config)

        wrapped_inputs = [wrap(inp, stacked) for inp, stacked
                          in zip(inputs, inputs_stacked)]
        conditions, cond_stacked, _ = _convert_function_call(
            self._cond_func,
            cond_pfor,
            wrapped_inputs)[0]
        cond_is_stacked[0] = cond_stacked

      # Recompute the new condition, write outputs of done iterations, and
      # partition the inputs if needed.
      if not cond_stacked:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_unstacked(conditions, indices,
                                                        inputs, output_tas)
      else:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_stacked(conditions, indices,
                                                      inputs, inputs_stacked,
                                                      output_tas)
      # Convert body
      with ops.name_scope("while_body"):
        # Compute the outputs from the body.
        new_outputs, mismatching_stacked_indices = self._process_body(
            inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done)

      indices_to_stack[:] = mismatching_stacked_indices
      for i, new_output in enumerate(new_outputs):
        new_output.set_shape(output_shapes[i])
      new_args = ([not_all_done, new_indices] + new_outputs +
                  list(new_output_tas))
      return tuple(new_args)

    # Note that we run the code below in a function since we might abandon the
    # generated code in cases where the conversion dictates that some inputs be
    # further stacked. Hence we run the graph construction using
    # `get_concrete_function` and avoid calling the constructed function if not
    # needed.
    @def_function.function
    def while_fn():
      # Create init_values that will be passed to the while_loop.
      init_values = self._init_values()
      ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in
                             self._pfor_input.outputs]
      shape_invariants = (
          [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])]
          + output_shapes + ta_shape_invariants)

      while_outputs = control_flow_ops.while_loop(
          cond, body, init_values,
          shape_invariants=shape_invariants,
          parallel_iterations=self._parallel_iterations)
      if indices_to_stack:
        # This function will be abandoned.
        return while_outputs
      else:
        num_inputs = self._pfor_input.num_inputs
        new_inputs = while_outputs[2:num_inputs+2]
        output_tas = while_outputs[num_inputs+2:]
        assert cond_is_stacked[0] is not None
        outputs = []
        for i, inp in enumerate(new_inputs):
          if cond_is_stacked[0]:
            if i in self._body_pass_through_indices:
              outputs.append(init_values[i + 2])
            else:
              ta = output_tas[i]
              outputs.append(ta.stack())
          else:
            outputs.append(inp)
        return outputs

    _ = while_fn.get_concrete_function()
    if indices_to_stack:
      # Need to abandon the current conversion, stack some inputs and restart.
      self._pfor_input.stack_inputs(stack_indices=indices_to_stack)
      # Note that this call will recurse at most one time. The first call will
      # do the required stacking, based on the iterative procedure in
      # _process_body, and the next invocation to __call__ should not need to do
      # any more stacking.
      # We invoke `self()` here as a way to discard any corrupted state.
      return self()
    else:
      outputs = while_fn()
      wrapped_outputs = []
      for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)):
        if i not in self._body_pass_through_indices and cond_is_stacked[0]:
          wrapped_outputs.append(wrap(out, True))
        else:
          wrapped_outputs.append(wrap(out, inp.is_stacked))
      return wrapped_outputs


@RegisterPFor("StatelessWhile")
@RegisterPFor("While")
def _convert_while(pfor_input):
  converter = WhileV2(pfor_input)
  return converter()


# spectral_ops
@@ -20,13 +20,17 @@ from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.python import xla as xla_ops
from tensorflow.python.compiler.xla import jit
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
from tensorflow.python.platform import test
@@ -118,5 +122,120 @@ class PForTest(PForTestCase):
    self.assertAllClose(ans_val, output_val)


if __name__ == '__main__':
def _make_unstacked(cond, body, pfor_config):

  def _cond(*args):
    return math_ops.reduce_any(pfor_config.reduce_concat(args[0]))

  def _body(*args):
    not_done = args[0]
    args = args[1:]
    not_done = math_ops.logical_and(not_done, cond(*args))
    outputs = body(*args)
    return (not_done,) + tuple(
        array_ops.where_v2(not_done, x, y) for x, y in zip(outputs, args))

  return _cond, _body


@test_util.run_all_in_graph_and_eager_modes
class WhileV2Test(PForTestCase):

  def setUp(self):
    self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
    control_flow_v2_toggles.enable_control_flow_v2()
    super(WhileV2Test, self).setUp()

  def tearDown(self):
    if not self._enabled:
      control_flow_v2_toggles.disable_control_flow_v2()
    super(WhileV2Test, self).tearDown()

  def _test_loop_fn(self, loop_fn, iters, force_xla=False):

    def f():
      return pfor_control_flow_ops.pfor(loop_fn, iters)

    @def_function.function
    def jit_f():
      with jit.experimental_jit_scope():
        return f()

    out = f()
    jit_out = jit_f()
    self.run_and_assert_equal(out, jit_out)
    # TODO(agarwal): The following may complain about uncompilable nodes. Hence
    # these are currently not enabled for all tests.
    if force_xla:
      out_exp_compile_f = def_function.function(experimental_compile=True)(f)()
      self.run_and_assert_equal(out, out_exp_compile_f)
      out_xla_compile_f = xla.compile(f, inputs=[])
      self.run_and_assert_equal(out, out_xla_compile_f)

  def test_stateless_while(self):
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      return control_flow_ops.while_loop(
          lambda j, _: j < lengths_i,
          lambda j, t: (j + 1, t + array_ops.gather(x_i, j)),
          [0, 0.])

    self._test_loop_fn(loop_fn, 3)

  def test_while_with_variable(self):
    v = resource_variable_ops.ResourceVariable(5.)

    def loop_fn(_):
      _, output = control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + v),
          [0, 0.])
      return output

    self._test_loop_fn(loop_fn, 3)

  def test_while_unstacked_condition(self):

    def loop_fn(i):
      return control_flow_ops.while_loop(
          lambda j, x: j < 4,
          lambda j, x: (j + 1, x + i), [0, 0])

    self._test_loop_fn(loop_fn, 3, force_xla=True)

  def test_while_force_unstacked_condition(self):
    # The while_loop in this setup is similar to the one in test_stateless_while
    # whose condition is loop variant. However here we wrap the cond and body of
    # the loop in a way that makes the while_loop condition pfor loop invariant.
    # This allows xla compilation to work since the vectorized code no longer
    # needs to perform dynamic partitioning of the inputs.
    x = random_ops.random_uniform([3, 5])
    lengths = constant_op.constant([4, 0, 2])

    def loop_fn(i, pfor_config):
      x_i = array_ops.gather(x, i)
      lengths_i = array_ops.gather(lengths, i)

      def _cond(j, _):
        return j < lengths_i

      def _body(j, t):
        return (j + 1, t + array_ops.gather(x_i, j))

      cond, body = _make_unstacked(_cond, _body, pfor_config)
      return control_flow_ops.while_loop(
          cond,
          body,
          [True, 0, 0.])

    # b/155430349: Enabling force_xla=True triggers a CHECK in debug mode.
    self._test_loop_fn(loop_fn, 3, force_xla=False)


if __name__ == "__main__":
  test.main()
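
For completeness, a sketch of the public entry point built on pfor: tf.vectorized_map maps a function over the leading axis using this same machinery, so with this change a data-dependent while_loop inside the mapped function can be vectorized as well (illustrative code only, not part of this change):

import tensorflow as tf

def per_row_sum(args):
  row, length = args
  # Data-dependent while_loop: sums the first `length` entries of `row`.
  _, total = tf.while_loop(
      lambda j, _: j < length,
      lambda j, t: (j + 1, t + row[j]),
      [0, 0.])
  return total

x = tf.random.uniform([3, 5])
lengths = tf.constant([4, 0, 2])
sums = tf.vectorized_map(per_row_sum, (x, lengths))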