From 18645e7a3c63f539d3c7746792e4554ec2a7b6cf Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Thu, 16 Jan 2020 18:46:46 -0800 Subject: [PATCH] BUGFIX: Properly set input_output_dtype on Circulant.inverse(). PiperOrigin-RevId: 290189697 Change-Id: I3acc7eb170cd0f14acc3f9729489d3b07afd0e70 --- tensorflow/python/ops/linalg/inverse_registrations.py | 3 ++- tensorflow/python/ops/linalg/linear_operator_test_util.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/linalg/inverse_registrations.py b/tensorflow/python/ops/linalg/inverse_registrations.py index 009b2236ffb..00f2c074943 100644 --- a/tensorflow/python/ops/linalg/inverse_registrations.py +++ b/tensorflow/python/ops/linalg/inverse_registrations.py @@ -112,7 +112,8 @@ def _inverse_circulant(circulant_operator): is_non_singular=circulant_operator.is_non_singular, is_self_adjoint=circulant_operator.is_self_adjoint, is_positive_definite=circulant_operator.is_positive_definite, - is_square=True) + is_square=True, + input_output_dtype=circulant_operator.dtype) @linear_operator_algebra.RegisterInverse( diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py index dc13039ffd3..cbdbe5b3eee 100644 --- a/tensorflow/python/ops/linalg/linear_operator_test_util.py +++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py @@ -88,12 +88,14 @@ class LinearOperatorDerivedClassTest(test.TestCase): dtypes.complex128: 1e-12 } - def assertAC(self, x, y): + def assertAC(self, x, y, check_dtype=False): """Derived classes can set _atol, _rtol to get different tolerance.""" dtype = dtypes.as_dtype(x.dtype) atol = self._atol[dtype] rtol = self._rtol[dtype] self.assertAllClose(x, y, atol=atol, rtol=rtol) + if check_dtype: + self.assertDTypeEqual(x, y.dtype) @staticmethod def adjoint_options(): @@ -565,7 +567,7 @@ def _test_inverse(use_placeholder, shapes_info, dtype): shapes_info, dtype, use_placeholder=use_placeholder) op_inverse_v, mat_inverse_v = sess.run([ operator.inverse().to_dense(), linalg.inv(mat)]) - self.assertAC(op_inverse_v, mat_inverse_v) + self.assertAC(op_inverse_v, mat_inverse_v, check_dtype=True) return test_inverse