Pfor: Add vectorization for TensorListConcatV2.

PiperOrigin-RevId: 331560736
Change-Id: Idef0c9aef951bc05762c13a9820020113e319fea
This commit is contained in:
A. Unique TensorFlower 2020-09-14 09:17:19 -07:00 committed by TensorFlower Gardener
parent 8b7d5b4842
commit 021440d9d9
2 changed files with 75 additions and 0 deletions

View File

@ -44,6 +44,7 @@ from tensorflow.python.ops import cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import image_ops
@ -910,6 +911,7 @@ class TensorArrayTest(PForTestCase):
self.assertAllClose(actual_grad, computed_grad)
@test_util.run_all_in_graph_and_eager_modes
class TensorListTest(PForTestCase):
def test_create_outside_and_write(self):
@ -1009,6 +1011,38 @@ class TensorListTest(PForTestCase):
self._test_loop_fn(loop_fn, 2)
def test_create_inside_and_concat(self):
def loop_fn(i):
handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=handle)
handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
return gen_list_ops.tensor_list_concat_v2(
handle,
element_dtype=dtypes.int32,
element_shape=[2],
leading_dims=[])
output = pfor_control_flow_ops.pfor(loop_fn, 2)
self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
self.assertAllClose([[2, 2], [2, 2]], output[1])
def test_create_outside_and_concat(self):
h = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
def loop_fn(i):
handle = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=h)
handle = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=handle)
return gen_list_ops.tensor_list_concat_v2(
handle,
element_dtype=dtypes.int32,
element_shape=[2],
leading_dims=[])
output = pfor_control_flow_ops.pfor(loop_fn, 2)
self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], output[0])
self.assertAllClose([[2, 2], [2, 2]], output[1])
def test_tensor_list_from_tensor(self):
t = random_ops.random_uniform([2, 3, 4])

View File

@ -46,6 +46,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
@ -3665,6 +3666,46 @@ def _convert_tensor_array_set_item(pfor_input):
return wrap(_tile_variant(handle, pfor_input), True)
@RegisterPFor("TensorListConcatV2")
def _convert_tensor_list_concat_v2(pfor_input):
input_handle = pfor_input.stacked_input(0)
element_shape = pfor_input.unstacked_input(1)
leading_dims = pfor_input.unstacked_input(2)
element_dtype = pfor_input.get_attr("element_dtype")
handle = _untile_variant(input_handle)
length = list_ops.tensor_list_length(handle)
# Note that element_shape attribute can have incomplete shapes. This doesn't
# seem to work well when creating another list and then doing a concat on it.
# Hence we try to find the dynamic shape here.
element_shape = control_flow_ops.cond(
length > 0, lambda: array_ops.shape(
list_ops.tensor_list_get_item(handle, 0, element_dtype, None)),
lambda: constant_op.constant([0, 0], dtype=dtypes.int32))
# The code below creates a copy of the list with each elements' first two
# dimensions transposed.
new_element_shape = array_ops.concat(
[element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0)
# Create a new TensorList with elements transposed.
def _transpose_elem(i, h):
elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None)
elem = _transpose_first_two_dims(elem)
return i + 1, list_ops.tensor_list_set_item(h, i, elem)
new_handle = list_ops.tensor_list_reserve(new_element_shape, length,
element_dtype)
new_handle = control_flow_ops.while_loop(lambda i, _: i < length,
_transpose_elem, [0, new_handle])[1]
output, lengths = gen_list_ops.tensor_list_concat_v2(
input_handle=new_handle,
element_dtype=element_dtype,
element_shape=new_element_shape,
leading_dims=leading_dims)
output = _transpose_first_two_dims(output)
return wrap(output, True), wrap(lengths, False)
@RegisterPFor("TensorListStack")
def _convert_tensor_list_stack(pfor_input):
handle = pfor_input.stacked_input(0)