diff --git a/tensorflow/python/ops/parallel_for/array_test.py b/tensorflow/python/ops/parallel_for/array_test.py index 7f0c0f5b992..883f28cb05d 100644 --- a/tensorflow/python/ops/parallel_for/array_test.py +++ b/tensorflow/python/ops/parallel_for/array_test.py @@ -190,6 +190,17 @@ class ArrayTest(PForTestCase): self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 4) + def test_squeeze(self): + x = random_ops.random_uniform([5, 1, 2, 1]) + + def loop_fn(i): + x1 = array_ops.gather(x, i) + return (array_ops.squeeze(x1, axis=0), + array_ops.squeeze(x1, axis=-1), + array_ops.squeeze(x1)) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 3) + def test_transpose(self): x = random_ops.random_uniform([3, 2, 3, 4]) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 1d421f344e9..d5eb10bf19f 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -1746,6 +1746,14 @@ def _convert_split_v(pfor_input): return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] +@RegisterPFor("Squeeze") +def _convert_squeeze(pfor_input): + t = pfor_input.stacked_input(0) + squeeze_dims = pfor_input.get_attr("squeeze_dims") + squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims] + return wrap(array_ops.squeeze(t, axis=squeeze_dims), True) + + @RegisterPFor("Transpose") def _convert_transpose(pfor_input): t = pfor_input.stacked_input(0)