Will fix https://github.com/google/jax/issues/775 when XLA is updated. PiperOrigin-RevId: 254010359
231 lines
6.9 KiB
C++
231 lines
6.9 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
|
|
|
|
#include <limits>
|
|
#include <memory>
|
|
#include <numeric>
|
|
#include <vector>
|
|
|
|
#include "tensorflow/compiler/xla/array2d.h"
|
|
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
|
#include "tensorflow/compiler/xla/literal.h"
|
|
#include "tensorflow/compiler/xla/statusor.h"
|
|
#include "tensorflow/compiler/xla/test.h"
|
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
|
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
|
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
|
|
namespace xla {
|
|
namespace {
|
|
|
|
// Fixture for the fixed-input Cholesky tests. It adds nothing to
// ClientLibraryTestBase; the alias only gives those tests their own suite
// name ("CholeskyTest").
using CholeskyTest = ClientLibraryTestBase;
|
|
|
|
// The 3x3 all-ones matrix is rank one, so it is not positive definite and has
// no Cholesky factor. The expected result is a matrix made entirely of NaN.
XLA_TEST_F(CholeskyTest, NonPSDInput) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> all_ones(3, 3, 1.0f);
  const Array2D<float> all_nan(3, 3, kNaN);
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(all_ones, 0, "a", &builder, &input);
  // The factorization is the last op added, so it is the computation's result.
  Cholesky(input, /*lower=*/true);

  ComputeAndCompareR2<float>(&builder, all_nan, {input_data.get()}, tolerance);
}
|
|
|
|
// Lower-triangular factorization of a fixed 4x4 matrix. The strictly upper
// triangle of the input is filled with NaN. Presumably this checks that only
// the lower triangle is read: a NaN leaking into the factor would fail the
// comparison.
XLA_TEST_F(CholeskyTest, Lower) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> lower_input({
      {4, kNaN, kNaN, kNaN},
      {6, 45, kNaN, kNaN},
      {8, 54, 146, kNaN},
      {10, 63, 166, 310},
  });
  // L such that L * L^T equals the symmetric matrix encoded by lower_input.
  const Array2D<float> want({
      {2, 0, 0, 0},
      {3, 6, 0, 0},
      {4, 7, 9, 0},
      {5, 8, 10, 11},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(lower_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR2<float>(&builder, want, {input_data.get()}, tolerance);
}
|
|
|
|
// Upper-triangular factorization of the same matrix as the Lower test. Here
// the strictly lower triangle of the input is NaN, so a read of it would most
// likely surface as NaN in the result.
XLA_TEST_F(CholeskyTest, Upper) {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  const Array2D<float> upper_input({
      {4, 6, 8, 10},
      {kNaN, 45, 54, 63},
      {kNaN, kNaN, 146, 166},
      {kNaN, kNaN, kNaN, 310},
  });
  // U such that U^T * U equals the symmetric matrix encoded by upper_input.
  const Array2D<float> want({
      {2, 3, 4, 5},
      {0, 6, 7, 8},
      {0, 0, 9, 10},
      {0, 0, 0, 11},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(upper_input, 0, "a", &builder, &input);
  UpperTriangle(Cholesky(input, /*lower=*/false));

  ComputeAndCompareR2<float>(&builder, want, {input_data.get()}, tolerance);
}
|
|
|
|
// Lower-triangular factorization of a fully populated symmetric 4x4 matrix.
XLA_TEST_F(CholeskyTest, Simple2) {
  const Array2D<float> symmetric_input({
      {16, 24, 8, 12},
      {24, 61, 82, 48},
      {8, 82, 456, 106},
      {12, 48, 106, 62},
  });
  // L such that L * L^T == symmetric_input.
  const Array2D<float> want({
      {4, 0, 0, 0},
      {6, 5, 0, 0},
      {2, 14, 16, 0},
      {3, 6, 1, 4},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR2Parameter<float>(symmetric_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR2<float>(&builder, want, {input_data.get()}, tolerance);
}
|
|
|
|
// Factors a batch of two 4x4 matrices in one call. Batch element 0 is the
// symmetric matrix from the Lower test and element 1 is the Simple2 matrix.
// Each expected factor matches the one in its single-matrix test.
XLA_TEST_F(CholeskyTest, SimpleBatched) {
  const Array3D<float> batched_input({
      {{4, 6, 8, 10},
       {6, 45, 54, 63},
       {8, 54, 146, 166},
       {10, 63, 166, 310}},
      {{16, 24, 8, 12},
       {24, 61, 82, 48},
       {8, 82, 456, 106},
       {12, 48, 106, 62}},
  });
  const Array3D<float> want({
      {{2, 0, 0, 0},
       {3, 6, 0, 0},
       {4, 7, 9, 0},
       {5, 8, 10, 11}},
      {{4, 0, 0, 0},
       {6, 5, 0, 0},
       {2, 14, 16, 0},
       {3, 6, 1, 4}},
  });
  const ErrorSpec tolerance(1e-4, 1e-4);

  XlaBuilder builder(TestName());
  XlaOp input;
  auto input_data =
      CreateR3Parameter<float>(batched_input, 0, "a", &builder, &input);
  LowerTriangle(Cholesky(input, /*lower=*/true));

  ComputeAndCompareR3<float>(&builder, want, {input_data.get()}, tolerance);
}
|
|
|
|
// Parameters for RandomCholeskyTest, as (batch size, matrix dimension n, lower).
// The random test factors a batch of shape [batch size, n, n]. When `lower` is
// true it keeps the lower-triangular factor; otherwise it keeps the upper one.
using CholeskyTestCase = std::tuple<int64, int64, bool>;
|
|
|
|
class RandomCholeskyTest
|
|
: public ClientLibraryTestBase,
|
|
public ::testing::WithParamInterface<CholeskyTestCase> {};
|
|
|
|
// Randomized round-trip test. For a CholeskyTestCase (batch size, n, lower):
//   1. Form a batch of matrices input * input^T of shape [batch size, n, n]
//      from a random literal.
//   2. Factor each matrix.
//   3. Check that the factor reproduces the matrix it was computed from.
XLA_TEST_P(RandomCholeskyTest, Random) {
  XlaBuilder builder(TestName());

  const auto& test_params = GetParam();
  const int64 batch_size = std::get<0>(test_params);
  const int64 n = std::get<1>(test_params);
  const bool lower = std::get<2>(test_params);
  std::vector<int64> dimensions = {batch_size, n, n};
  Shape shape = ShapeUtil::MakeShape(F32, dimensions);
  TF_ASSERT_OK_AND_ASSIGN(
      auto literal, LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));

  auto input = Parameter(&builder, 0, shape, "input");
  // input * input^T is symmetric positive semi-definite. It is positive
  // definite whenever input is nonsingular, which random data makes very
  // likely.
  auto matrix =
      BatchDot(input, TransposeInMinorDims(input), PrecisionConfig::HIGHEST);

  auto cholesky = Triangle(Cholesky(matrix, lower), lower);

  // Rebuild the matrix from its factor: L * L^T for a lower factor, U^T * U
  // for an upper one.
  XlaOp verification;
  if (lower) {
    verification = BatchDot(cholesky, TransposeInMinorDims(cholesky),
                            PrecisionConfig::HIGHEST);
  } else {
    verification = BatchDot(TransposeInMinorDims(cholesky), cholesky,
                            PrecisionConfig::HIGHEST);
  }

  // Sum the squared residual entries over the batch and both matrix
  // dimensions. This is the squared Frobenius norm of
  // (matrix - reconstruction), accumulated across the batch, and it should be
  // ~0. This Reduce is the last op added, so it is the computation's result.
  auto delta = matrix - verification;
  Reduce(delta * delta, ConstantR0<float>(&builder, 0.0),
         CreateScalarAddComputation(F32, &builder), {0, 1, 2});

  TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
  ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
|
|
|
|
// Each case is (batch size, matrix dimension, lower). The cases cover a 1x1
// matrix, both factor orientations, and multi-element batches.
INSTANTIATE_TEST_SUITE_P(
    RandomCholeskyTestInstance, RandomCholeskyTest,
    ::testing::ValuesIn(std::vector<CholeskyTestCase>{
        CholeskyTestCase{1, 1, true},    // smallest possible matrix
        CholeskyTestCase{1, 2, true},
        CholeskyTestCase{1, 50, true},
        CholeskyTestCase{1, 50, false},  // upper-triangular factor
        CholeskyTestCase{10, 5, true},   // batch of small matrices
        CholeskyTestCase{5, 10, false},
        CholeskyTestCase{2, 20, true},
    }));
|
|
|
|
} // namespace
|
|
} // namespace xla
|