parallel_for: add converter for SplitV

PiperOrigin-RevId: 217766896
This commit is contained in:
A. Unique TensorFlower 2018-10-18 14:15:29 -07:00 committed by TensorFlower Gardener
parent ce126591bd
commit aa7a7f3d91
2 changed files with 19 additions and 0 deletions

View File

@ -245,6 +245,16 @@ class ArrayTest(PForTest):
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 5)
def test_split_v(self):
x = random_ops.random_uniform([3, 6, 3])
def loop_fn(i):
x1 = array_ops.gather(x, i)
return (array_ops.split(x1, [2, 1, 3], axis=0),
array_ops.split(x1, [3], axis=-1))
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 4)
def test_transpose(self):
x = random_ops.random_uniform([3, 2, 3, 4])

View File

@ -1616,6 +1616,15 @@ def _convert_split(pfor_input):
return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
@RegisterPFor("SplitV")
def _convert_split_v(pfor_input):
t = pfor_input.stacked_input(0)
splits = pfor_input.unstacked_input(1)
split_dim = pfor_input.unstacked_input(2)
split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
@RegisterPFor("Transpose")
def _convert_transpose(pfor_input):
t = pfor_input.stacked_input(0)