pfor: support vectorizing nested control flow.
PiperOrigin-RevId: 315810423 Change-Id: Ida6ad51bfe5b9bbd865d93e8e169299b9b011eb3
This commit is contained in:
parent
18dc63baeb
commit
0e8cc390c2
@ -768,6 +768,16 @@ class LoggingTest(PForTestCase):
|
||||
|
||||
class TensorArrayTest(PForTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
|
||||
control_flow_v2_toggles.disable_control_flow_v2()
|
||||
super(TensorArrayTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
if self._enabled:
|
||||
control_flow_v2_toggles.enable_control_flow_v2()
|
||||
super(TensorArrayTest, self).tearDown()
|
||||
|
||||
@test_util.run_v1_only("b/122612051")
|
||||
def test_create_outside_and_read(self):
|
||||
|
||||
@ -1088,6 +1098,16 @@ class StackTest(PForTestCase):
|
||||
# tf.cond.
|
||||
class WhileV1Test(PForTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
|
||||
control_flow_v2_toggles.disable_control_flow_v2()
|
||||
super(WhileV1Test, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
if self._enabled:
|
||||
control_flow_v2_toggles.enable_control_flow_v2()
|
||||
super(WhileV1Test, self).tearDown()
|
||||
|
||||
def test_while_outside_loop(self):
|
||||
|
||||
x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
|
||||
@ -1474,6 +1494,65 @@ class WhileV2Test(PForTestCase):
|
||||
self.assertAllClose(_f(x, y, True), _f(x, y, False))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class NestedControlFlowTest(PForTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
|
||||
control_flow_v2_toggles.enable_control_flow_v2()
|
||||
super(NestedControlFlowTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
if not self._enabled:
|
||||
control_flow_v2_toggles.disable_control_flow_v2()
|
||||
super(NestedControlFlowTest, self).tearDown()
|
||||
|
||||
def _cond(self, f=None, split=0):
|
||||
if f is None:
|
||||
f = lambda x, y: (x, y)
|
||||
|
||||
def _f(x, y):
|
||||
return control_flow_ops.cond(y > split,
|
||||
lambda: f(x, y),
|
||||
lambda: (x + 1., y))
|
||||
return _f
|
||||
|
||||
def _while(self, f=None):
|
||||
if f is None:
|
||||
f = lambda x, y: (x, y)
|
||||
|
||||
def _f(x, y):
|
||||
return control_flow_ops.while_loop(
|
||||
lambda j, _: j < y,
|
||||
lambda j, t: (j + 1, t + array_ops.gather(f(x, y)[0], j)),
|
||||
[0, x])[1], y
|
||||
|
||||
return _f
|
||||
|
||||
def _test_helper(self, f):
|
||||
x = random_ops.random_uniform([5, 5])
|
||||
y = constant_op.constant([4, -1, 2, -2, 2])
|
||||
|
||||
def loop_fn(i):
|
||||
x_i = array_ops.gather(x, i)
|
||||
y_i = array_ops.gather(y, i)
|
||||
return f(x_i, y_i)
|
||||
|
||||
self._test_loop_fn(loop_fn, 5)
|
||||
|
||||
def test_cond_while(self):
|
||||
self._test_helper(self._cond(self._while()))
|
||||
|
||||
def test_while_cond(self):
|
||||
self._test_helper(self._while(self._cond()))
|
||||
|
||||
def test_while_while(self):
|
||||
self._test_helper(self._while(self._while()))
|
||||
|
||||
def test_cond_cond(self):
|
||||
self._test_helper(self._cond(self._cond()))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@test_util.with_control_flow_v2
|
||||
class StatelessIfTest(PForTestCase):
|
||||
|
||||
@ -4014,7 +4014,8 @@ def _convert_if(pfor_input):
|
||||
# Compute indices for cond being True or False.
|
||||
if pfor_input.pfor.all_indices_partitioned:
|
||||
else_indices, then_indices = data_flow_ops.dynamic_partition(
|
||||
array_ops.range(len(pfor_input.pfor.all_indices)), cond_int, 2)
|
||||
math_ops.range(pfor_input.pfor.loop_len_vector[0]),
|
||||
cond_int, 2)
|
||||
else:
|
||||
else_indices, then_indices = false_indices, true_indices
|
||||
# Partition inputs
|
||||
@ -4103,9 +4104,16 @@ class WhileV2(object):
|
||||
with ops.name_scope("while_init"):
|
||||
for inp in self._pfor_input.inputs:
|
||||
inputs.append(inp.t)
|
||||
output_tas.append(tensor_array_ops.TensorArray(inp.t.dtype, loop_len))
|
||||
output_tas.append(tensor_array_ops.TensorArray(
|
||||
inp.t.dtype,
|
||||
size=loop_len,
|
||||
dynamic_size=False,
|
||||
infer_shape=True))
|
||||
# See documentation for __call__ for the structure of init_values.
|
||||
return [True, self._pfor.all_indices] + inputs + output_tas
|
||||
indices = (
|
||||
math_ops.range(self._pfor.loop_len_vector[0])
|
||||
if self._pfor.all_indices_partitioned else self._pfor.all_indices)
|
||||
return [True, indices] + inputs + output_tas
|
||||
|
||||
def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
|
||||
"""Handles case when condition is pfor loop invariant."""
|
||||
@ -4170,12 +4178,16 @@ class WhileV2(object):
|
||||
# However once we make a new input loop variant, we might make other
|
||||
# outputs loop variant. Hence we need to iterate till we get fixed point.
|
||||
while True:
|
||||
if self._pfor.all_indices_partitioned:
|
||||
indices = array_ops.gather(self._pfor.all_indices, new_indices)
|
||||
else:
|
||||
indices = new_indices
|
||||
body_pfor = PFor(
|
||||
loop_var=self._pfor.loop_var,
|
||||
loop_len=array_ops.size(new_indices),
|
||||
pfor_ops=self._body_func.graph.get_operations(),
|
||||
fallback_to_while_loop=self._pfor.fallback_to_while_loop,
|
||||
all_indices=new_indices,
|
||||
all_indices=indices,
|
||||
all_indices_partitioned=(self._pfor.all_indices_partitioned or
|
||||
cond_stacked),
|
||||
pfor_config=self._pfor.pfor_config)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user