[XLA] Fix interface contract of xla::LogDet.

xla::LogDet was previously mixing the sign of the determinant with the sign of the logarithm. These semantics don't make sense.

Create two methods: xla::SLogDet() which returns a (sign, determinant) pair, and xla::LogDet() which returns the log-determinant for matrices with positive determinant, and NaN otherwise.

The semantics of xla::LogDet() match torch.logdet, and PyTorch/XLA is the only user of this function at the moment.

PiperOrigin-RevId: 360508928
Change-Id: Iebad4920635f6602fc7d43eadde2ba496a2b176b
This commit is contained in:
Peter Hawkins 2021-03-02 13:59:28 -08:00 committed by TensorFlower Gardener
parent 7f8f4fe39b
commit 00d31f1d50
3 changed files with 70 additions and 30 deletions

View File

@ -34,10 +34,8 @@ limitations under the License.
namespace xla {
// log(det(A)) = sum(log(vecdiag(QR(A).r))), since R is triangular and Q is
// orthonormal
XlaOp LogDet(XlaOp a) {
return a.builder()->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
SignAndLogDet SLogDet(XlaOp a) {
StatusOr<SignAndLogDet> result = [&]() -> StatusOr<SignAndLogDet> {
TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a));
auto qr = Qr(a);
@ -63,8 +61,20 @@ XlaOp LogDet(XlaOp a) {
One(a.builder(), a_shape.element_type()),
CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()),
{a_shape.rank() - 2});
return sign_diag * log_abs_det * sign_taus;
});
return SignAndLogDet{sign_diag * sign_taus, log_abs_det};
}();
if (!result.ok()) {
XlaOp error = a.builder()->ReportError(result.status());
return SignAndLogDet{error, error};
}
return result.ValueOrDie();
}
XlaOp LogDet(XlaOp a) {
SignAndLogDet slogdet = SLogDet(a);
return Select(
Ge(slogdet.sign, ZerosLike(slogdet.sign)), slogdet.logdet,
FullLike(slogdet.logdet, std::numeric_limits<float>::quiet_NaN()));
}
} // namespace xla

View File

@ -20,8 +20,16 @@ limitations under the License.
namespace xla {
// For matrix a with shape [..., n, n], return log(det(a)) with shape[...].
// Only hermitian positive definite matrices are supported.
// Computes the sign and logarithm of the absolute value of the determinant
// of a batch of square matrices with shape [..., n, n].
struct SignAndLogDet {
XlaOp sign; // Either 1, 0, or -1, depending on the determinant's sign.
XlaOp logdet; // log(abs(det(a)).
};
SignAndLogDet SLogDet(XlaOp a);
// For a batch of matrices with shape [..., n, n], return log(det(a)).
// Returns NaN if a matrix has a negative determinant.
XlaOp LogDet(XlaOp a);
} // namespace xla

View File

@ -41,14 +41,17 @@ XLA_TEST_F(LogDetTest, Simple) {
{10, 63, 166, 310},
});
float expected = 14.1601f;
xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::LogDet(a);
ComputeAndCompareR0<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
xla::SignAndLogDet slogdet = xla::SLogDet(a);
xla::XlaOp logdet = xla::LogDet(a);
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
xla::LiteralUtil::CreateR0<float>(1.f),
xla::LiteralUtil::CreateR0<float>(14.1601f),
xla::LiteralUtil::CreateR0<float>(14.1601f));
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
}
XLA_TEST_F(LogDetTest, SimpleTriangle) {
@ -61,14 +64,18 @@ XLA_TEST_F(LogDetTest, SimpleTriangle) {
{4, 6, 8, 320},
});
float expected = 15.9131355f;
xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::LogDet(a);
xla::SignAndLogDet slogdet = xla::SLogDet(a);
xla::XlaOp logdet = xla::LogDet(a);
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
xla::LiteralUtil::CreateR0<float>(1.f),
xla::LiteralUtil::CreateR0<float>(15.9131355f),
xla::LiteralUtil::CreateR0<float>(15.9131355f));
ComputeAndCompareR0<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
}
XLA_TEST_F(LogDetTest, SimpleBatched) {
@ -87,19 +94,29 @@ XLA_TEST_F(LogDetTest, SimpleBatched) {
{8, 82, 456, 106},
{12, 48, 106, 62},
},
{{2, 2, 3, 4}, {4, 5, 6, 7}, {7, 8, 9, 8}, {10, 11, 12, 13}},
{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
});
std::vector<float> expected = {14.1601, 14.3092};
xla::XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::LogDet(a);
xla::SignAndLogDet slogdet = xla::SLogDet(a);
xla::XlaOp logdet = xla::LogDet(a);
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
xla::LiteralUtil::CreateR1<float>({1.f, 1.f, -1.f, 0.f}),
xla::LiteralUtil::CreateR1<float>(
{14.1601f, 14.3092f, 2.4849f,
-std::numeric_limits<float>::infinity()}),
xla::LiteralUtil::CreateR1<float>(
{14.1601f, 14.3092f, std::numeric_limits<float>::quiet_NaN(),
-std::numeric_limits<float>::infinity()}));
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
}
XLA_TEST_F(LogDetTest, LargerMatricesBatched) {
XLA_TEST_F(LogDetTest, LogdetOfLargerMatricesBatched) {
xla::XlaBuilder builder(TestName());
xla::Array<float> a_vals = {
@ -127,13 +144,18 @@ XLA_TEST_F(LogDetTest, LargerMatricesBatched) {
{-3.5759, -1.5619, 2.4410, 1.3046, 4.2678, 7.3587, -4.0935},
{-1.1187, 0.9150, -1.8253, 0.0390, -2.5684, -4.0778, 4.1447}}};
std::vector<float> expected = {8.93788053, 6.77846303, 7.4852403};
xla::XlaOp a;
auto a_data = CreateParameter<float>(a_vals, 0, "a", &builder, &a);
xla::LogDet(a);
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
xla::SignAndLogDet slogdet = xla::SLogDet(a);
xla::XlaOp logdet = xla::LogDet(a);
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
xla::LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f}),
xla::LiteralUtil::CreateR1<float>({8.93788053, 6.77846303, 7.4852403}),
xla::LiteralUtil::CreateR1<float>({8.93788053, 6.77846303, 7.4852403}));
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
}
} // namespace