diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index b700961795c..5dc792d3288 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -300,6 +300,7 @@ xla_test(
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
diff --git a/tensorflow/compiler/xla/client/lib/qr_test.cc b/tensorflow/compiler/xla/client/lib/qr_test.cc
index 0f50144766d..f71c63d6fd9 100644
--- a/tensorflow/compiler/xla/client/lib/qr_test.cc
+++ b/tensorflow/compiler/xla/client/lib/qr_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/tensor_float_32_utils.h"
@@ -134,4 +135,24 @@ XLA_TEST_F(QrTest, SimpleBatched) {
                     xla::ErrorSpec(1e-4, 1e-4));
 }
 
+XLA_TEST_F(QrTest, SubnormalComplex) {
+  tensorflow::enable_tensor_float_32_execution(false);
+
+  // Verifies that we don't get NaNs in the case that the norm of a complex
+  // number would be denormal but its imaginary value is not exactly 0.
+  xla::Array2D<xla::complex64> a_vals({
+      {xla::complex64(4e-20, 5e-23), 6, 80},
+      {0, 45, 54},
+      {0, 54, 146},
+  });
+
+  xla::XlaBuilder builder(TestName());
+  xla::XlaOp a, q, r;
+  auto a_data = CreateParameter(a_vals, 0, "a", &builder, &a);
+  xla::QrExplicit(a, /*full_matrices=*/true, q, r);
+  xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
+  ComputeAndCompare(&builder, a_vals, {a_data.get()},
+                    xla::ErrorSpec(1e-4, 1e-4));
+}
+
 }  // namespace
diff --git a/tensorflow/compiler/xla/service/qr_expander.cc b/tensorflow/compiler/xla/service/qr_expander.cc
index 42dcebb20b3..9146e014fd1 100644
--- a/tensorflow/compiler/xla/service/qr_expander.cc
+++ b/tensorflow/compiler/xla/service/qr_expander.cc
@@ -47,6 +47,33 @@ std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
   return output;
 }
 
+// Computes sqrt(x^2 + y^2 + ...), avoiding overflow/underflow.
+// e.g. for 3 arguments:
+// def norm(x, y, z):
+//   xabs = np.abs(x)
+//   yabs = np.abs(y)
+//   zabs = np.abs(z)
+//   w = np.maximum(np.maximum(xabs, yabs), zabs)
+//   if w == 0:
+//     return 0
+//   else:
+//     return w * np.sqrt((xabs / w)**2 + (yabs / w) ** 2 + (zabs / w) ** 2)
+XlaOp Norm(std::vector<XlaOp> xs) {
+  CHECK(!xs.empty());
+  XlaOp w;
+  for (size_t i = 0; i < xs.size(); ++i) {
+    xs[i] = Abs(xs[i]);
+    w = i == 0 ? xs[i] : xla::Max(w, xs[i]);
+  }
+
+  XlaOp out;
+  for (size_t i = 0; i < xs.size(); ++i) {
+    XlaOp t = Square(xs[i] / w);
+    out = i == 0 ? t : xla::Add(out, t);
+  }
+  return Select(Eq(w, ZerosLike(w)), ZerosLike(w), w * Sqrt(out));
+}
+
 // Computes a Householder reflection of the form:
 // H = I - tau v v.T.
 // such that
@@ -102,15 +129,13 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
   XlaOp sigma_is_zero;
   if (primitive_util::IsComplexType(type)) {
     // sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
-    // TODO(phawkins): this calculation may be numerically unstable.
     auto x_squared = Real(x_after_k * Conj(x_after_k));
     auto sigma = Reduce(x_squared, ScalarLike(x_squared, 0.0),
                         CreateScalarAddComputation(
                             primitive_util::ComplexComponentType(type),
                             builder),
                         {minor_dim});
-    // mu = np.sqrt(x[k]*np.con(x[k]) + sigma)
-    auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);
+    auto mu = Norm({Real(alpha), Imag(alpha), Sqrt(sigma)});
 
     sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
     sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));
@@ -122,11 +147,9 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
     *tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
   } else {
     // sigma = np.dot(x[k+1:], x[k+1:])
-    // TODO(phawkins): this calculation may be numerically unstable.
     auto sigma = Reduce(x_after_k * x_after_k, zero,
                         CreateScalarAddComputation(type, builder), {minor_dim});
-    // mu = np.sqrt(x[k]*x[k] + sigma)
-    auto mu = Sqrt(Square(alpha) + sigma);
+    auto mu = Norm({alpha, Sqrt(sigma)});
 
     sigma_is_zero = Eq(sigma, zero);
     XlaOp one = ScalarLike(x, 1.0);