[XLA] Fix interface contract of xla::LogDet.
xla::LogDet was previously mixing the sign of the determinant with the sign of the logarithm. These semantics don't make sense. Create two methods: xla::SLogDet() which returns a (sign, determinant) pair, and xla::LogDet() which returns the log-determinant for matrices with positive determinant, and NaN otherwise. The semantics of xla::LogDet() match torch.logdet, and PyTorch/XLA is the only user of this function at the moment. PiperOrigin-RevId: 360508928 Change-Id: Iebad4920635f6602fc7d43eadde2ba496a2b176b
This commit is contained in:
parent
7f8f4fe39b
commit
00d31f1d50
@ -34,10 +34,8 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// log(det(A)) = sum(log(vecdiag(QR(A).r))), since R is triangular and Q is
|
||||
// orthonormal
|
||||
XlaOp LogDet(XlaOp a) {
|
||||
return a.builder()->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
SignAndLogDet SLogDet(XlaOp a) {
|
||||
StatusOr<SignAndLogDet> result = [&]() -> StatusOr<SignAndLogDet> {
|
||||
TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a));
|
||||
auto qr = Qr(a);
|
||||
|
||||
@ -63,8 +61,20 @@ XlaOp LogDet(XlaOp a) {
|
||||
One(a.builder(), a_shape.element_type()),
|
||||
CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()),
|
||||
{a_shape.rank() - 2});
|
||||
return sign_diag * log_abs_det * sign_taus;
|
||||
});
|
||||
return SignAndLogDet{sign_diag * sign_taus, log_abs_det};
|
||||
}();
|
||||
if (!result.ok()) {
|
||||
XlaOp error = a.builder()->ReportError(result.status());
|
||||
return SignAndLogDet{error, error};
|
||||
}
|
||||
return result.ValueOrDie();
|
||||
}
|
||||
|
||||
XlaOp LogDet(XlaOp a) {
|
||||
SignAndLogDet slogdet = SLogDet(a);
|
||||
return Select(
|
||||
Ge(slogdet.sign, ZerosLike(slogdet.sign)), slogdet.logdet,
|
||||
FullLike(slogdet.logdet, std::numeric_limits<float>::quiet_NaN()));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -20,8 +20,16 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// For matrix a with shape [..., n, n], return log(det(a)) with shape[...].
|
||||
// Only hermitian positive definite matrices are supported.
|
||||
// Computes the sign and logarithm of the absolute value of the determinant
|
||||
// of a batch of square matrices with shape [..., n, n].
|
||||
struct SignAndLogDet {
|
||||
XlaOp sign; // Either 1, 0, or -1, depending on the determinant's sign.
|
||||
XlaOp logdet; // log(abs(det(a)).
|
||||
};
|
||||
SignAndLogDet SLogDet(XlaOp a);
|
||||
|
||||
// For a batch of matrices with shape [..., n, n], return log(det(a)).
|
||||
// Returns NaN if a matrix has a negative determinant.
|
||||
XlaOp LogDet(XlaOp a);
|
||||
|
||||
} // namespace xla
|
||||
|
@ -41,14 +41,17 @@ XLA_TEST_F(LogDetTest, Simple) {
|
||||
{10, 63, 166, 310},
|
||||
});
|
||||
|
||||
float expected = 14.1601f;
|
||||
|
||||
xla::XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||
xla::LogDet(a);
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
xla::SignAndLogDet slogdet = xla::SLogDet(a);
|
||||
xla::XlaOp logdet = xla::LogDet(a);
|
||||
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
|
||||
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
|
||||
xla::LiteralUtil::CreateR0<float>(1.f),
|
||||
xla::LiteralUtil::CreateR0<float>(14.1601f),
|
||||
xla::LiteralUtil::CreateR0<float>(14.1601f));
|
||||
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(LogDetTest, SimpleTriangle) {
|
||||
@ -61,14 +64,18 @@ XLA_TEST_F(LogDetTest, SimpleTriangle) {
|
||||
{4, 6, 8, 320},
|
||||
});
|
||||
|
||||
float expected = 15.9131355f;
|
||||
|
||||
xla::XlaOp a;
|
||||
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||
xla::LogDet(a);
|
||||
xla::SignAndLogDet slogdet = xla::SLogDet(a);
|
||||
xla::XlaOp logdet = xla::LogDet(a);
|
||||
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
|
||||
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
|
||||
xla::LiteralUtil::CreateR0<float>(1.f),
|
||||
xla::LiteralUtil::CreateR0<float>(15.9131355f),
|
||||
xla::LiteralUtil::CreateR0<float>(15.9131355f));
|
||||
|
||||
ComputeAndCompareR0<float>(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(LogDetTest, SimpleBatched) {
|
||||
@ -87,19 +94,29 @@ XLA_TEST_F(LogDetTest, SimpleBatched) {
|
||||
{8, 82, 456, 106},
|
||||
{12, 48, 106, 62},
|
||||
},
|
||||
{{2, 2, 3, 4}, {4, 5, 6, 7}, {7, 8, 9, 8}, {10, 11, 12, 13}},
|
||||
{{0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}},
|
||||
});
|
||||
|
||||
std::vector<float> expected = {14.1601, 14.3092};
|
||||
|
||||
xla::XlaOp a;
|
||||
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||
xla::LogDet(a);
|
||||
xla::SignAndLogDet slogdet = xla::SLogDet(a);
|
||||
xla::XlaOp logdet = xla::LogDet(a);
|
||||
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
|
||||
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
|
||||
xla::LiteralUtil::CreateR1<float>({1.f, 1.f, -1.f, 0.f}),
|
||||
xla::LiteralUtil::CreateR1<float>(
|
||||
{14.1601f, 14.3092f, 2.4849f,
|
||||
-std::numeric_limits<float>::infinity()}),
|
||||
xla::LiteralUtil::CreateR1<float>(
|
||||
{14.1601f, 14.3092f, std::numeric_limits<float>::quiet_NaN(),
|
||||
-std::numeric_limits<float>::infinity()}));
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(LogDetTest, LargerMatricesBatched) {
|
||||
XLA_TEST_F(LogDetTest, LogdetOfLargerMatricesBatched) {
|
||||
xla::XlaBuilder builder(TestName());
|
||||
|
||||
xla::Array<float> a_vals = {
|
||||
@ -127,13 +144,18 @@ XLA_TEST_F(LogDetTest, LargerMatricesBatched) {
|
||||
{-3.5759, -1.5619, 2.4410, 1.3046, 4.2678, 7.3587, -4.0935},
|
||||
{-1.1187, 0.9150, -1.8253, 0.0390, -2.5684, -4.0778, 4.1447}}};
|
||||
|
||||
std::vector<float> expected = {8.93788053, 6.77846303, 7.4852403};
|
||||
|
||||
xla::XlaOp a;
|
||||
auto a_data = CreateParameter<float>(a_vals, 0, "a", &builder, &a);
|
||||
xla::LogDet(a);
|
||||
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
xla::SignAndLogDet slogdet = xla::SLogDet(a);
|
||||
xla::XlaOp logdet = xla::LogDet(a);
|
||||
xla::Tuple(&builder, {slogdet.sign, slogdet.logdet, logdet});
|
||||
xla::Literal expected = xla::LiteralUtil::MakeTupleOwned(
|
||||
xla::LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f}),
|
||||
xla::LiteralUtil::CreateR1<float>({8.93788053, 6.77846303, 7.4852403}),
|
||||
xla::LiteralUtil::CreateR1<float>({8.93788053, 6.77846303, 7.4852403}));
|
||||
|
||||
ComputeAndCompareLiteral(&builder, expected, {a_data.get()},
|
||||
xla::ErrorSpec(1e-4));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user