pfor: support vectorizing nested control flow.

PiperOrigin-RevId: 315810423
Change-Id: Ida6ad51bfe5b9bbd865d93e8e169299b9b011eb3
This commit is contained in:
A. Unique TensorFlower 2020-06-10 18:35:22 -07:00 committed by TensorFlower Gardener
parent 18dc63baeb
commit 0e8cc390c2
2 changed files with 95 additions and 4 deletions

View File

@ -768,6 +768,16 @@ class LoggingTest(PForTestCase):
class TensorArrayTest(PForTestCase):
def setUp(self):
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
control_flow_v2_toggles.disable_control_flow_v2()
super(TensorArrayTest, self).setUp()
def tearDown(self):
if self._enabled:
control_flow_v2_toggles.enable_control_flow_v2()
super(TensorArrayTest, self).tearDown()
@test_util.run_v1_only("b/122612051")
def test_create_outside_and_read(self):
@ -1088,6 +1098,16 @@ class StackTest(PForTestCase):
# tf.cond.
class WhileV1Test(PForTestCase):
def setUp(self):
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
control_flow_v2_toggles.disable_control_flow_v2()
super(WhileV1Test, self).setUp()
def tearDown(self):
if self._enabled:
control_flow_v2_toggles.enable_control_flow_v2()
super(WhileV1Test, self).tearDown()
def test_while_outside_loop(self):
x = control_flow_ops.while_loop(lambda j: j < 4, lambda j: j + 1, [0])
@ -1474,6 +1494,65 @@ class WhileV2Test(PForTestCase):
self.assertAllClose(_f(x, y, True), _f(x, y, False))
@test_util.run_all_in_graph_and_eager_modes
class NestedControlFlowTest(PForTestCase):
def setUp(self):
self._enabled = control_flow_v2_toggles.control_flow_v2_enabled()
control_flow_v2_toggles.enable_control_flow_v2()
super(NestedControlFlowTest, self).setUp()
def tearDown(self):
if not self._enabled:
control_flow_v2_toggles.disable_control_flow_v2()
super(NestedControlFlowTest, self).tearDown()
def _cond(self, f=None, split=0):
if f is None:
f = lambda x, y: (x, y)
def _f(x, y):
return control_flow_ops.cond(y > split,
lambda: f(x, y),
lambda: (x + 1., y))
return _f
def _while(self, f=None):
if f is None:
f = lambda x, y: (x, y)
def _f(x, y):
return control_flow_ops.while_loop(
lambda j, _: j < y,
lambda j, t: (j + 1, t + array_ops.gather(f(x, y)[0], j)),
[0, x])[1], y
return _f
def _test_helper(self, f):
x = random_ops.random_uniform([5, 5])
y = constant_op.constant([4, -1, 2, -2, 2])
def loop_fn(i):
x_i = array_ops.gather(x, i)
y_i = array_ops.gather(y, i)
return f(x_i, y_i)
self._test_loop_fn(loop_fn, 5)
def test_cond_while(self):
self._test_helper(self._cond(self._while()))
def test_while_cond(self):
self._test_helper(self._while(self._cond()))
def test_while_while(self):
self._test_helper(self._while(self._while()))
def test_cond_cond(self):
self._test_helper(self._cond(self._cond()))
@test_util.run_all_in_graph_and_eager_modes
@test_util.with_control_flow_v2
class StatelessIfTest(PForTestCase):

View File

@ -4014,7 +4014,8 @@ def _convert_if(pfor_input):
# Compute indices for cond being True or False.
if pfor_input.pfor.all_indices_partitioned:
else_indices, then_indices = data_flow_ops.dynamic_partition(
array_ops.range(len(pfor_input.pfor.all_indices)), cond_int, 2)
math_ops.range(pfor_input.pfor.loop_len_vector[0]),
cond_int, 2)
else:
else_indices, then_indices = false_indices, true_indices
# Partition inputs
@ -4103,9 +4104,16 @@ class WhileV2(object):
with ops.name_scope("while_init"):
for inp in self._pfor_input.inputs:
inputs.append(inp.t)
output_tas.append(tensor_array_ops.TensorArray(inp.t.dtype, loop_len))
output_tas.append(tensor_array_ops.TensorArray(
inp.t.dtype,
size=loop_len,
dynamic_size=False,
infer_shape=True))
# See documentation for __call__ for the structure of init_values.
return [True, self._pfor.all_indices] + inputs + output_tas
indices = (
math_ops.range(self._pfor.loop_len_vector[0])
if self._pfor.all_indices_partitioned else self._pfor.all_indices)
return [True, indices] + inputs + output_tas
def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
"""Handles case when condition is pfor loop invariant."""
@ -4170,12 +4178,16 @@ class WhileV2(object):
# However once we make a new input loop variant, we might make other
# outputs loop variant. Hence we need to iterate till we get fixed point.
while True:
if self._pfor.all_indices_partitioned:
indices = array_ops.gather(self._pfor.all_indices, new_indices)
else:
indices = new_indices
body_pfor = PFor(
loop_var=self._pfor.loop_var,
loop_len=array_ops.size(new_indices),
pfor_ops=self._body_func.graph.get_operations(),
fallback_to_while_loop=self._pfor.fallback_to_while_loop,
all_indices=new_indices,
all_indices=indices,
all_indices_partitioned=(self._pfor.all_indices_partitioned or
cond_stacked),
pfor_config=self._pfor.pfor_config)