[XLA] Fix underflow in norm calculation in Householder reflection for complex inputs.
If a complex value's squared norm was denormal but had a non-zero imaginary part, the Householder reflection computation could yield NaNs. By using a more accurate norm, we can avoid the underflow in this case. PiperOrigin-RevId: 359180409 Change-Id: I2b6963800da551ab50b4e3e52a06cf92d75c0ee9
This commit is contained in:
parent
ea6f79938f
commit
dc4d330cfe
@ -300,6 +300,7 @@ xla_test(
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/tensor_float_32_utils.h"
|
||||
@ -134,4 +135,24 @@ XLA_TEST_F(QrTest, SimpleBatched) {
|
||||
xla::ErrorSpec(1e-4, 1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(QrTest, SubnormalComplex) {
  // Run at full float32 precision; presumably TF32's reduced mantissa would
  // swamp the tiny values this test depends on — confirmed by the explicit
  // opt-out below.
  tensorflow::enable_tensor_float_32_execution(false);

  // Regression test: a complex entry whose squared norm is denormal but whose
  // imaginary part is non-zero used to yield NaNs in the QR decomposition's
  // Householder reflection.
  xla::Array2D<xla::complex64> input({
      {xla::complex64(4e-20, 5e-23), 6, 80},
      {0, 45, 54},
      {0, 54, 146},
  });

  xla::XlaBuilder builder(TestName());
  xla::XlaOp x, q, r;
  auto param_data = CreateParameter<xla::complex64>(input, 0, "a", &builder, &x);

  // Decompose, then rebuild A = Q * R and check it matches the input.
  xla::QrExplicit(x, /*full_matrices=*/true, q, r);
  xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
  ComputeAndCompare<xla::complex64>(&builder, input, {param_data.get()},
                                    xla::ErrorSpec(1e-4, 1e-4));
}
|
||||
|
||||
} // namespace
|
||||
|
@ -47,6 +47,33 @@ std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
|
||||
return output;
|
||||
}
|
||||
|
||||
// Computes sqrt(x^2 + y^2 + ...), avoiding overflow/underflow.
|
||||
// e.g. for 3 arguments:
|
||||
// def norm(x, y, z):
|
||||
// xabs = np.abs(x)
|
||||
// yabs = np.abs(y)
|
||||
// zabs = np.abs(z)
|
||||
// w = np.maximum(np.maximum(xabs, yabs), zabs)
|
||||
// if w == 0:
|
||||
// return 0
|
||||
// else:
|
||||
// return w * np.sqrt((xabs / w)**2 + (yabs / w) ** 2 + (zabs / w) ** 2)
|
||||
XlaOp Norm(std::vector<XlaOp> xs) {
|
||||
CHECK(!xs.empty());
|
||||
XlaOp w;
|
||||
for (size_t i = 0; i < xs.size(); ++i) {
|
||||
xs[i] = Abs(xs[i]);
|
||||
w = i == 0 ? xs[i] : xla::Max(w, xs[i]);
|
||||
}
|
||||
|
||||
XlaOp out;
|
||||
for (size_t i = 0; i < xs.size(); ++i) {
|
||||
XlaOp t = Square(xs[i] / w);
|
||||
out = i == 0 ? t : xla::Add(out, t);
|
||||
}
|
||||
return Select(Eq(w, ZerosLike(w)), ZerosLike(w), w * Sqrt(out));
|
||||
}
|
||||
|
||||
// Computes a Householder reflection of the form:
|
||||
// H = I - tau v v.T.
|
||||
// such that
|
||||
@ -102,15 +129,13 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
|
||||
XlaOp sigma_is_zero;
|
||||
if (primitive_util::IsComplexType(type)) {
|
||||
// sigma = np.dot(x[k+1:], np.conj(x[k+1:]))
|
||||
// TODO(phawkins): this calculation may be numerically unstable.
|
||||
auto x_squared = Real(x_after_k * Conj(x_after_k));
|
||||
auto sigma =
|
||||
Reduce(x_squared, ScalarLike(x_squared, 0.0),
|
||||
CreateScalarAddComputation(
|
||||
primitive_util::ComplexComponentType(type), builder),
|
||||
{minor_dim});
|
||||
// mu = np.sqrt(x[k]*np.conj(x[k]) + sigma)
|
||||
auto mu = Sqrt(Real(alpha * Conj(alpha)) + sigma);
|
||||
auto mu = Norm({Real(alpha), Imag(alpha), Sqrt(sigma)});
|
||||
|
||||
sigma_is_zero = Eq(sigma, ScalarLike(sigma, 0));
|
||||
sigma_is_zero = And(sigma_is_zero, Eq(Imag(alpha), ScalarLike(sigma, 0)));
|
||||
@ -122,11 +147,9 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
|
||||
*tau = Complex((*beta - Real(alpha)) / *beta, -Imag(alpha) / *beta);
|
||||
} else {
|
||||
// sigma = np.dot(x[k+1:], x[k+1:])
|
||||
// TODO(phawkins): this calculation may be numerically unstable.
|
||||
auto sigma = Reduce(x_after_k * x_after_k, zero,
|
||||
CreateScalarAddComputation(type, builder), {minor_dim});
|
||||
// mu = np.sqrt(x[k]*x[k] + sigma)
|
||||
auto mu = Sqrt(Square(alpha) + sigma);
|
||||
auto mu = Norm({alpha, Sqrt(sigma)});
|
||||
sigma_is_zero = Eq(sigma, zero);
|
||||
|
||||
XlaOp one = ScalarLike(x, 1.0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user