diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index 1826924b47e..171369b724a 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -309,6 +309,14 @@ class ArrayTest(PForTest): self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2) + def test_matrix_diag_part(self): + x = random_ops.random_uniform([3, 4, 2]) + + def loop_fn(i): + return array_ops.matrix_diag_part(array_ops.gather(x, i)) + + self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32]) + def test_strided_slice(self): x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2]) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 72441908ec2..e6f140a9410 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -1533,6 +1533,7 @@ def _convert_conv2d_backprop_filter(pfor_input): @RegisterPForWithArgs("Identity", array_ops.identity) @RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) +@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) def _convert_identity(pfor_input, op_type, op_func): del op_type return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)