Parallel for: add converter for MatrixDiagPart.
PiperOrigin-RevId: 220829434
This commit is contained in:
parent
afc175f7b3
commit
b895a35c6a
@ -309,6 +309,14 @@ class ArrayTest(PForTest):
|
||||
|
||||
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32] * 2)
|
||||
|
||||
def test_matrix_diag_part(self):
|
||||
x = random_ops.random_uniform([3, 4, 2])
|
||||
|
||||
def loop_fn(i):
|
||||
return array_ops.matrix_diag_part(array_ops.gather(x, i))
|
||||
|
||||
self._test_loop_fn(loop_fn, 3, loop_fn_dtypes=[dtypes.float32])
|
||||
|
||||
def test_strided_slice(self):
|
||||
x = random_ops.random_uniform([3, 3, 4, 4, 2, 2, 2])
|
||||
|
||||
|
@ -1533,6 +1533,7 @@ def _convert_conv2d_backprop_filter(pfor_input):
|
||||
|
||||
@RegisterPForWithArgs("Identity", array_ops.identity)
|
||||
@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient)
|
||||
@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part)
|
||||
def _convert_identity(pfor_input, op_type, op_func):
|
||||
del op_type
|
||||
return wrap(op_func(*[x.t for x in pfor_input.inputs]), True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user