Add pfor converter for SelfAdjointEigV2.
PiperOrigin-RevId: 270816561
This commit is contained in:
parent
f65f08771e
commit
c9c2fcaf4c
tensorflow/python/ops/parallel_for
@ -655,6 +655,14 @@ class LinalgTest(PForTestCase):
|
||||
|
||||
self._test_loop_fn(loop_fn, 2)
|
||||
|
||||
def test_self_adjoint_eig(self):
|
||||
z = random_ops.random_normal([2, 3, 3])
|
||||
x = z + array_ops.matrix_transpose(z) # Ensure self-adjoint.
|
||||
|
||||
def loop_fn(i):
|
||||
return linalg_ops.self_adjoint_eig(array_ops.gather(x, i))
|
||||
|
||||
self._test_loop_fn(loop_fn, 2, loop_fn_dtypes=[dtypes.float32] * 2)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -2784,6 +2784,12 @@ def _convert_matrix_triangular_solve(pfor_input):
|
||||
return wrap(output, True)
|
||||
|
||||
|
||||
@RegisterPFor("SelfAdjointEigV2")
|
||||
def _convert_self_adjoint_eig(pfor_input):
|
||||
t = pfor_input.stacked_input(0)
|
||||
return [wrap(x, True) for x in linalg_ops.self_adjoint_eig(t)]
|
||||
|
||||
|
||||
# logging_ops
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user