[XLA] Fix underflow in norm calculation in Householder reflection for complex inputs.

If a complex value's squared norm was denormal but had a non-zero imaginary part, the Householder reflection computation could yield NaNs. By using a more accurate norm, we can avoid the underflow in this case.

PiperOrigin-RevId: 359180409
Change-Id: I2b6963800da551ab50b4e3e52a06cf92d75c0ee9
This commit is contained in:
Peter Hawkins 2021-02-23 18:20:13 -08:00 committed by TensorFlower Gardener
parent ea6f79938f
commit dc4d330cfe
3 changed files with 51 additions and 6 deletions

View File

@ -300,6 +300,7 @@ xla_test(
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
@ -134,4 +135,24 @@ XLA_TEST_F(QrTest, SimpleBatched) {
xla::ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(QrTest, SubnormalComplex) {
  // TensorFloat-32 rounding would wipe out the tiny values this test relies
  // on, so force full-precision execution.
  tensorflow::enable_tensor_float_32_execution(false);
  // Regression test: a complex entry whose squared magnitude is denormal but
  // whose imaginary part is nonzero must not yield NaNs in the Householder /
  // QR computation.
  xla::Array2D<xla::complex64> input({
      {xla::complex64(4e-20, 5e-23), 6, 80},
      {0, 45, 54},
      {0, 54, 146},
  });
  xla::XlaBuilder builder(TestName());
  xla::XlaOp param, q, r;
  auto param_data =
      CreateParameter<xla::complex64>(input, 0, "a", &builder, &param);
  xla::QrExplicit(param, /*full_matrices=*/true, q, r);
  // Reconstruct Q * R; the product should round-trip back to the input.
  xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
  ComputeAndCompare<xla::complex64>(&builder, input, {param_data.get()},
                                    xla::ErrorSpec(1e-4, 1e-4));
}
} // namespace

View File

@ -47,6 +47,33 @@ std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
return output;
}
// Computes sqrt(x^2 + y^2 + ...) while avoiding overflow/underflow: each
// term is divided by the largest magnitude before squaring, then the result
// is rescaled. Equivalent NumPy for three arguments:
//   def norm(x, y, z):
//     xabs, yabs, zabs = np.abs(x), np.abs(y), np.abs(z)
//     w = np.maximum(np.maximum(xabs, yabs), zabs)
//     if w == 0:
//       return 0
//     return w * np.sqrt((xabs / w)**2 + (yabs / w)**2 + (zabs / w)**2)
XlaOp Norm(std::vector<XlaOp> xs) {
  CHECK(!xs.empty());
  // Replace each input by its magnitude and find the largest one, w.
  for (XlaOp& x : xs) {
    x = Abs(x);
  }
  XlaOp w = xs[0];
  for (size_t i = 1; i < xs.size(); ++i) {
    w = xla::Max(w, xs[i]);
  }
  // Accumulate the squared scaled magnitudes (|x_i| / w)^2.
  XlaOp sum_of_squares = Square(xs[0] / w);
  for (size_t i = 1; i < xs.size(); ++i) {
    sum_of_squares = xla::Add(sum_of_squares, Square(xs[i] / w));
  }
  // When every input is zero, w == 0 and the scaled terms are NaN; the
  // Select masks those lanes and returns 0 instead.
  return Select(Eq(w, ZerosLike(w)), ZerosLike(w), w * Sqrt(sum_of_squares));
}
// Computes a Householder reflection of the form:
// H = I - tau v v.T.
// such that
@ -102,15 +129,13 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
XlaOp sigma_is_zero;
if (primitive_util::IsComplexType(type)) {
// sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
// TODO(phawkins): this calculation may be numerically unstable.
auto x_squared = Real(x_after_k * Conj(x_after_k));
auto sigma =
Reduce(x_squared, ScalarLike(x_squared, 0.0),
CreateScalarAddComputation(
primitive_util::ComplexComponentType(type), builder),
{minor_dim});
// mu = np.sqrt(x[k]*np.conj(x[k]) + sigma)
auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);
auto mu = Norm({Real(alpha), Imag(alpha), Sqrt(sigma)});
sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));
@ -122,11 +147,9 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
*tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
} else {
// sigma = np.dot(x[k+1:], x[k+1:])
// TODO(phawkins): this calculation may be numerically unstable.
auto sigma = Reduce(x_after_k * x_after_k, zero,
CreateScalarAddComputation(type, builder), {minor_dim});
// mu = np.sqrt(x[k]*x[k] + sigma)
auto mu = Sqrt(Square(alpha) + sigma);
auto mu = Norm({alpha, Sqrt(sigma)});
sigma_is_zero = Eq(sigma, zero);
XlaOp one = ScalarLike(x, 1.0);