BUGFIX: Properly set input_output_dtype on Circulant.inverse().
PiperOrigin-RevId: 290189697 Change-Id: I3acc7eb170cd0f14acc3f9729489d3b07afd0e70
This commit is contained in:
parent
80f0540bc8
commit
18645e7a3c
@ -112,7 +112,8 @@ def _inverse_circulant(circulant_operator):
|
||||
is_non_singular=circulant_operator.is_non_singular,
|
||||
is_self_adjoint=circulant_operator.is_self_adjoint,
|
||||
is_positive_definite=circulant_operator.is_positive_definite,
|
||||
is_square=True)
|
||||
is_square=True,
|
||||
input_output_dtype=circulant_operator.dtype)
|
||||
|
||||
|
||||
@linear_operator_algebra.RegisterInverse(
|
||||
|
||||
@ -88,12 +88,14 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
dtypes.complex128: 1e-12
|
||||
}
|
||||
|
||||
def assertAC(self, x, y):
|
||||
def assertAC(self, x, y, check_dtype=False):
|
||||
"""Derived classes can set _atol, _rtol to get different tolerance."""
|
||||
dtype = dtypes.as_dtype(x.dtype)
|
||||
atol = self._atol[dtype]
|
||||
rtol = self._rtol[dtype]
|
||||
self.assertAllClose(x, y, atol=atol, rtol=rtol)
|
||||
if check_dtype:
|
||||
self.assertDTypeEqual(x, y.dtype)
|
||||
|
||||
@staticmethod
|
||||
def adjoint_options():
|
||||
@ -565,7 +567,7 @@ def _test_inverse(use_placeholder, shapes_info, dtype):
|
||||
shapes_info, dtype, use_placeholder=use_placeholder)
|
||||
op_inverse_v, mat_inverse_v = sess.run([
|
||||
operator.inverse().to_dense(), linalg.inv(mat)])
|
||||
self.assertAC(op_inverse_v, mat_inverse_v)
|
||||
self.assertAC(op_inverse_v, mat_inverse_v, check_dtype=True)
|
||||
return test_inverse
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user