Pfor: Add vectorization for TensorListConcatV2.
PiperOrigin-RevId: 331560736
Change-Id: Idef0c9aef951bc05762c13a9820020113e319fea
This commit is contained in:
parent 8b7d5b4842
commit 021440d9d9
@ -44,6 +44,7 @@ from tensorflow.python.ops import cond_v2
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import gen_list_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gradients as gradient_ops
|
||||
from tensorflow.python.ops import image_ops
|
||||
@ -910,6 +911,7 @@ class TensorArrayTest(PForTestCase):
|
||||
self.assertAllClose(actual_grad, computed_grad)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class TensorListTest(PForTestCase):
|
||||
|
||||
def test_create_outside_and_write(self):
|
||||
@ -1009,6 +1011,38 @@ class TensorListTest(PForTestCase):
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
def test_create_inside_and_concat(self):
  """Concatenates a TensorList that is created inside the pfor loop body."""

  def loop_fn(i):
    # Build a two-slot list and fill both slots via scatter; slot 0 varies
    # with the loop index, slot 1 is loop invariant.
    tl = list_ops.tensor_list_reserve([2], 2, dtypes.int32)
    tl = list_ops.tensor_list_scatter([[i, 2]], [0], input_handle=tl)
    tl = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=tl)
    return gen_list_ops.tensor_list_concat_v2(
        tl,
        element_dtype=dtypes.int32,
        element_shape=[2],
        leading_dims=[])

  result = pfor_control_flow_ops.pfor(loop_fn, 2)
  # Concatenated values, one row per loop iteration.
  self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], result[0])
  # Lengths of each list element along the concat axis.
  self.assertAllClose([[2, 2], [2, 2]], result[1])
|
||||
|
||||
def test_create_outside_and_concat(self):
  """Concatenates a TensorList whose handle is created outside the loop."""
  outside_handle = list_ops.tensor_list_reserve([2], 2, dtypes.int32)

  def loop_fn(i):
    # Scatter into the externally created handle; slot 0 depends on the
    # loop index, slot 1 does not.
    tl = list_ops.tensor_list_scatter(
        [[i, 2]], [0], input_handle=outside_handle)
    tl = list_ops.tensor_list_scatter([[1, 2]], [1], input_handle=tl)
    return gen_list_ops.tensor_list_concat_v2(
        tl,
        element_dtype=dtypes.int32,
        element_shape=[2],
        leading_dims=[])

  result = pfor_control_flow_ops.pfor(loop_fn, 2)
  # Concatenated values, one row per loop iteration.
  self.assertAllClose([[0, 2, 1, 2], [1, 2, 1, 2]], result[0])
  # Lengths of each list element along the concat axis.
  self.assertAllClose([[2, 2], [2, 2]], result[1])
|
||||
|
||||
def test_tensor_list_from_tensor(self):
|
||||
t = random_ops.random_uniform([2, 3, 4])
|
||||
|
||||
|
@ -46,6 +46,7 @@ from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gen_image_ops
|
||||
from tensorflow.python.ops import gen_linalg_ops
|
||||
from tensorflow.python.ops import gen_list_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import gen_parsing_ops
|
||||
@ -3665,6 +3666,46 @@ def _convert_tensor_array_set_item(pfor_input):
|
||||
return wrap(_tile_variant(handle, pfor_input), True)
|
||||
|
||||
|
||||
@RegisterPFor("TensorListConcatV2")
def _convert_tensor_list_concat_v2(pfor_input):
  """Vectorizes TensorListConcatV2 for pfor.

  The stacked list handle stores each element with a leading loop dimension.
  To concatenate along the original element axis, this builds a copy of the
  list in which each element has its first two dimensions transposed, runs
  the concat on that copy, and transposes the result back.

  Args:
    pfor_input: the pfor conversion input wrapping the op's inputs/attrs.

  Returns:
    A pair of wrapped outputs: the stacked concatenated tensor and the
    unstacked per-element lengths.
  """
  # The list handle varies per iteration (stacked); the shape/dims inputs
  # are loop invariant (unstacked).
  input_handle = pfor_input.stacked_input(0)
  element_shape = pfor_input.unstacked_input(1)
  leading_dims = pfor_input.unstacked_input(2)
  element_dtype = pfor_input.get_attr("element_dtype")

  handle = _untile_variant(input_handle)
  length = list_ops.tensor_list_length(handle)
  # Note that element_shape attribute can have incomplete shapes. This doesn't
  # seem to work well when creating another list and then doing a concat on it.
  # Hence we try to find the dynamic shape here.
  element_shape = control_flow_ops.cond(
      length > 0, lambda: array_ops.shape(
          list_ops.tensor_list_get_item(handle, 0, element_dtype, None)),
      # NOTE(review): the empty-list fallback hard-codes a rank-2 shape
      # (loop dim + one element dim) -- confirm this matches all callers.
      lambda: constant_op.constant([0, 0], dtype=dtypes.int32))
  # The code below creates a copy of the list with each elements' first two
  # dimensions transposed.
  new_element_shape = array_ops.concat(
      [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0)

  # Create a new TensorList with elements transposed.
  def _transpose_elem(i, h):
    # while_loop body: read element i, swap its first two dims, and write
    # it into the new list handle `h`.
    elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None)
    elem = _transpose_first_two_dims(elem)
    return i + 1, list_ops.tensor_list_set_item(h, i, elem)

  new_handle = list_ops.tensor_list_reserve(new_element_shape, length,
                                            element_dtype)
  new_handle = control_flow_ops.while_loop(lambda i, _: i < length,
                                           _transpose_elem, [0, new_handle])[1]
  output, lengths = gen_list_ops.tensor_list_concat_v2(
      input_handle=new_handle,
      element_dtype=element_dtype,
      element_shape=new_element_shape,
      leading_dims=leading_dims)
  # Undo the per-element transpose so the loop dimension leads again.
  # `output` differs per iteration (stacked=True); `lengths` is the same for
  # every iteration (stacked=False).
  output = _transpose_first_two_dims(output)
  return wrap(output, True), wrap(lengths, False)
|
||||
|
||||
|
||||
@RegisterPFor("TensorListStack")
|
||||
def _convert_tensor_list_stack(pfor_input):
|
||||
handle = pfor_input.stacked_input(0)
|
||||
|
Loading…
Reference in New Issue
Block a user