Add pfor converter for Squeeze.
PiperOrigin-RevId: 243937790
This commit is contained in:
parent
0500369bf5
commit
c16cf12009
@ -190,6 +190,17 @@ class ArrayTest(PForTestCase):
|
|||||||
|
|
||||||
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 4)
|
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 4)
|
||||||
|
|
||||||
|
def test_squeeze(self):
|
||||||
|
x = random_ops.random_uniform([5, 1, 2, 1])
|
||||||
|
|
||||||
|
def loop_fn(i):
|
||||||
|
x1 = array_ops.gather(x, i)
|
||||||
|
return (array_ops.squeeze(x1, axis=0),
|
||||||
|
array_ops.squeeze(x1, axis=-1),
|
||||||
|
array_ops.squeeze(x1))
|
||||||
|
|
||||||
|
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3)
|
||||||
|
|
||||||
def test_transpose(self):
|
def test_transpose(self):
|
||||||
x = random_ops.random_uniform([3, 2, 3, 4])
|
x = random_ops.random_uniform([3, 2, 3, 4])
|
||||||
|
|
||||||
|
@ -1746,6 +1746,14 @@ def _convert_split_v(pfor_input):
|
|||||||
return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
|
return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)]
|
||||||
|
|
||||||
|
|
||||||
|
@RegisterPFor("Squeeze")
|
||||||
|
def _convert_squeeze(pfor_input):
|
||||||
|
t = pfor_input.stacked_input(0)
|
||||||
|
squeeze_dims = pfor_input.get_attr("squeeze_dims")
|
||||||
|
squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims]
|
||||||
|
return wrap(array_ops.squeeze(t, axis=squeeze_dims), True)
|
||||||
|
|
||||||
|
|
||||||
@RegisterPFor("Transpose")
|
@RegisterPFor("Transpose")
|
||||||
def _convert_transpose(pfor_input):
|
def _convert_transpose(pfor_input):
|
||||||
t = pfor_input.stacked_input(0)
|
t = pfor_input.stacked_input(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user