parallel_for: add converter for SplitV
PiperOrigin-RevId: 217766896
This commit is contained in:
parent
ce126591bd
commit
aa7a7f3d91
@ -245,6 +245,16 @@ class ArrayTest(PForTest):
|
||||
|
||||
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 5)
|
||||
|
||||
def test_split_v(self):
|
||||
x = random_ops.random_uniform([3, 6, 3])
|
||||
|
||||
def loop_fn(i):
|
||||
x1 = array_ops.gather(x, i)
|
||||
return (array_ops.split(x1, [2, 1, 3], axis=0),
|
||||
array_ops.split(x1, [3], axis=-1))
|
||||
|
||||
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 4)
|
||||
|
||||
def test_transpose(self):
|
||||
x = random_ops.random_uniform([3, 2, 3, 4])
|
||||
|
||||
|
@ -1616,6 +1616,15 @@ def _convert_split(pfor_input):
|
||||
return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)]
|
||||
|
||||
|
||||
@RegisterPFor("SplitV")
|
||||
def _convert_split_v(pfor_input):
|
||||
t = pfor_input.stacked_input(0)
|
||||
splits = pfor_input.unstacked_input(1)
|
||||
split_dim = pfor_input.unstacked_input(2)
|
||||
split_dim += math_ops.cast(split_dim >= 0, dtypes.int32)
|
||||
return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
|
||||
|
||||
|
||||
@RegisterPFor("Transpose")
|
||||
def _convert_transpose(pfor_input):
|
||||
t = pfor_input.stacked_input(0)
|
||||
|
Loading…
Reference in New Issue
Block a user