BUGFIX: Properly set input_output_dtype on Circulant.inverse().

PiperOrigin-RevId: 290189697
Change-Id: I3acc7eb170cd0f14acc3f9729489d3b07afd0e70
This commit is contained in:
Ian Langmore 2020-01-16 18:46:46 -08:00 committed by TensorFlower Gardener
parent 80f0540bc8
commit 18645e7a3c
2 changed files with 6 additions and 3 deletions

View File

@ -112,7 +112,8 @@ def _inverse_circulant(circulant_operator):
is_non_singular=circulant_operator.is_non_singular,
is_self_adjoint=circulant_operator.is_self_adjoint,
is_positive_definite=circulant_operator.is_positive_definite,
is_square=True)
is_square=True,
input_output_dtype=circulant_operator.dtype)
@linear_operator_algebra.RegisterInverse(

View File

@ -88,12 +88,14 @@ class LinearOperatorDerivedClassTest(test.TestCase):
dtypes.complex128: 1e-12
}
def assertAC(self, x, y):
def assertAC(self, x, y, check_dtype=False):
"""Derived classes can set _atol, _rtol to get different tolerance."""
dtype = dtypes.as_dtype(x.dtype)
atol = self._atol[dtype]
rtol = self._rtol[dtype]
self.assertAllClose(x, y, atol=atol, rtol=rtol)
if check_dtype:
self.assertDTypeEqual(x, y.dtype)
@staticmethod
def adjoint_options():
@ -565,7 +567,7 @@ def _test_inverse(use_placeholder, shapes_info, dtype):
shapes_info, dtype, use_placeholder=use_placeholder)
op_inverse_v, mat_inverse_v = sess.run([
operator.inverse().to_dense(), linalg.inv(mat)])
self.assertAC(op_inverse_v, mat_inverse_v)
self.assertAC(op_inverse_v, mat_inverse_v, check_dtype=True)
return test_inverse