From 5643e94437952349bc40109270b2447f460e3d3b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 25 Oct 2019 16:10:15 -0700 Subject: [PATCH] Implement Log determinant of matrix op in XLA. PiperOrigin-RevId: 276779923 Change-Id: Ia1e0c9d8bc405874febae6f89eae51aa83009065 --- tensorflow/compiler/xla/client/lib/BUILD | 40 +++++++++ tensorflow/compiler/xla/client/lib/logdet.cc | 45 ++++++++++ tensorflow/compiler/xla/client/lib/logdet.h | 29 +++++++ .../compiler/xla/client/lib/logdet_test.cc | 82 +++++++++++++++++++ 4 files changed, 196 insertions(+) create mode 100644 tensorflow/compiler/xla/client/lib/logdet.cc create mode 100644 tensorflow/compiler/xla/client/lib/logdet.h create mode 100644 tensorflow/compiler/xla/client/lib/logdet_test.cc diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index f15c02106b8..99bccdb3bb8 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -542,3 +542,43 @@ xla_test( "//tensorflow/stream_executor/lib", ], ) + +cc_library( + name = "logdet", + srcs = ["logdet.cc"], + hdrs = ["logdet.h"], + deps = [ + ":arithmetic", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "logdet_test", + srcs = ["logdet_test.cc"], + tags = ["optonly"], + deps = [ + ":logdet", + ":matrix", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array3d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/lib/logdet.cc b/tensorflow/compiler/xla/client/lib/logdet.cc new file mode 100644 index 00000000000..8f37c393922 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/logdet.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/logdet.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +// let G = root(A) be the Cholesky root of the matrix A +// log(det(A)) = 2*sum(log(vecdiag(G))) +XlaOp LogDet(XlaOp a) { + XlaOp cholesky = Cholesky(a, /*bool lower=*/true); + + return ScalarLike(a, 2) * + Einsum(Log(cholesky), "...aa->...", xla::PrecisionConfig::HIGHEST); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/logdet.h b/tensorflow/compiler/xla/client/lib/logdet.h new file mode 100644 index 00000000000..96e598a6475 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/logdet.h @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOGDET_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOGDET_H_ + +#include "tensorflow/compiler/xla/client/xla_builder.h" + +namespace xla { + +// For matrix a with shape [..., n, n], return log(det(a)) with shape[...]. +// Only hermitian positive definite matrices are supported. +XlaOp LogDet(XlaOp a); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LOGDET_H_ diff --git a/tensorflow/compiler/xla/client/lib/logdet_test.cc b/tensorflow/compiler/xla/client/lib/logdet_test.cc new file mode 100644 index 00000000000..54af41f77f6 --- /dev/null +++ b/tensorflow/compiler/xla/client/lib/logdet_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/lib/logdet.h" + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array3d.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace { + +using LogDetTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(LogDetTest, Simple) { + xla::XlaBuilder builder(TestName()); + + xla::Array2D a_vals({ + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }); + + float expected = 14.1601f; + + xla::XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + xla::LogDet(a); + + ComputeAndCompareR0(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); +} + +XLA_TEST_F(LogDetTest, SimpleBatched) { + xla::XlaBuilder builder(TestName()); + + xla::Array3D a_vals({ + { + {4, 6, 8, 10}, + {6, 45, 54, 63}, + {8, 54, 146, 166}, + {10, 63, 166, 310}, + }, + { + {16, 24, 8, 12}, + {24, 61, 82, 48}, + {8, 82, 456, 106}, + {12, 48, 106, 62}, + }, + }); + + std::vector expected = {14.1601, 14.3092}; + + xla::XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + xla::LogDet(a); + + ComputeAndCompareR1(&builder, expected, {a_data.get()}, + xla::ErrorSpec(1e-4)); +} + +} // namespace