Implement Log determinant of matrix op in XLA.

PiperOrigin-RevId: 276779923
Change-Id: Ia1e0c9d8bc405874febae6f89eae51aa83009065
This commit is contained in:
A. Unique TensorFlower 2019-10-25 16:10:15 -07:00 committed by TensorFlower Gardener
parent 4a8c898cdf
commit 5643e94437
4 changed files with 196 additions and 0 deletions

View File

@ -542,3 +542,43 @@ xla_test(
"//tensorflow/stream_executor/lib",
],
)
cc_library(
name = "logdet",
srcs = ["logdet.cc"],
hdrs = ["logdet.h"],
deps = [
":arithmetic",
":constants",
":loops",
":math",
":matrix",
":slicing",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
],
)
xla_test(
name = "logdet_test",
srcs = ["logdet_test.cc"],
tags = ["optonly"],
deps = [
":logdet",
":matrix",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

View File

@ -0,0 +1,45 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/logdet.h"
#include <memory>
#include <vector>
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
// let G = root(A) be the Cholesky root of the matrix A
// log(det(A)) = 2*sum(log(vecdiag(G)))
XlaOp LogDet(XlaOp a) {
XlaOp cholesky = Cholesky(a, /*bool lower=*/true);
return ScalarLike(a, 2) *
Einsum(Log(cholesky), "...aa->...", xla::PrecisionConfig::HIGHEST);
}
} // namespace xla

View File

@ -0,0 +1,29 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOGDET_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOGDET_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace xla {
// For matrix a with shape [..., n, n], return log(det(a)) with shape[...].
// Only hermitian positive definite matrices are supported.
XlaOp LogDet(XlaOp a);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOGDET_H_

View File

@ -0,0 +1,82 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/logdet.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace {
using LogDetTest = xla::ClientLibraryTestBase;
XLA_TEST_F(LogDetTest, Simple) {
xla::XlaBuilder builder(TestName());
xla::Array2D<float> a_vals({
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
});
float expected = 14.1601f;
xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::LogDet(a);
ComputeAndCompareR0<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
}
XLA_TEST_F(LogDetTest, SimpleBatched) {
xla::XlaBuilder builder(TestName());
xla::Array3D<float> a_vals({
{
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
},
{
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
},
});
std::vector<float> expected = {14.1601, 14.3092};
xla::XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::LogDet(a);
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4));
}
} // namespace