Merge pull request #39690 from whoozle:triangular-solve-expander-block-size
PiperOrigin-RevId: 321351533 Change-Id: Ic824da029bf209f66daf427881bfeb0392925a8c
This commit is contained in:
commit
53b57424ba
@ -1871,6 +1871,27 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "triangular_solve_expander_test",
|
||||
size = "medium",
|
||||
srcs = ["triangular_solve_expander_test.cc"],
|
||||
shard_count = 3,
|
||||
deps = [
|
||||
":hlo",
|
||||
":triangular_solve_expander",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:reference_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:verified_hlo_module",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/core:test",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cholesky_expander",
|
||||
srcs = ["cholesky_expander.cc"],
|
||||
|
@ -454,6 +454,9 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
|
||||
|
||||
} // namespace
|
||||
|
||||
TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
|
||||
: block_size_(block_size) {}
|
||||
|
||||
bool TriangularSolveExpander::InstructionMatchesPattern(
|
||||
HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kTriangularSolve;
|
||||
@ -496,7 +499,7 @@ StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
|
||||
|
||||
BuildTriangularSolve(a, b, options.left_side(), options.lower(),
|
||||
transpose_a, conjugate_a, options.unit_diagonal(),
|
||||
/*block_size=*/128,
|
||||
/*block_size=*/block_size_,
|
||||
/*precision=*/PrecisionConfig::HIGHEST);
|
||||
TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
|
||||
|
||||
|
@ -23,6 +23,8 @@ namespace xla {
|
||||
|
||||
class TriangularSolveExpander : public OpExpanderPass {
|
||||
public:
|
||||
explicit TriangularSolveExpander(int64 block_size = 128);
|
||||
|
||||
absl::string_view name() const override {
|
||||
return "triangular_solve_expander";
|
||||
}
|
||||
@ -34,6 +36,8 @@ class TriangularSolveExpander : public OpExpanderPass {
|
||||
HloInstruction* instruction) override;
|
||||
|
||||
private:
|
||||
// Block size for BuildTriangularSolve
|
||||
const int64 block_size_;
|
||||
// Mapping from op signatures to existing computations.
|
||||
absl::flat_hash_map<string, HloComputation*> computation_cache_;
|
||||
};
|
||||
|
@ -0,0 +1,108 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/reference_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class TriangularExpanderTest : public HloTestBase,
|
||||
public ::testing::WithParamInterface<int32> {};
|
||||
|
||||
TEST_P(TriangularExpanderTest, TestBlockSize) {
|
||||
auto block_size = GetParam();
|
||||
std::string hlo_string = R"(
|
||||
HloModule TensorFlowTriangularSolve
|
||||
|
||||
ENTRY main {
|
||||
a = f32[256,256]{1,0} parameter(0)
|
||||
b = f32[256,192]{1,0} parameter(1)
|
||||
ROOT triangular-solve = f32[256,192]{1,0} triangular-solve(a, b),
|
||||
left_side=true, unit_diagonal=true,
|
||||
lower=true, transpose_a=NO_TRANSPOSE
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
{
|
||||
TriangularSolveExpander triangular_solve_expander(block_size);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool result, RunHloPass(&triangular_solve_expander, module.get()));
|
||||
EXPECT_TRUE(result);
|
||||
}
|
||||
|
||||
// To test triangular solver expander we generate simple bidiagonal matrix:
|
||||
// Solve a * x = b.
|
||||
// Check that shape is still valid.
|
||||
// Use reference matrix multiplication to test validity of result.
|
||||
|
||||
Array2D<float> a(256, 256);
|
||||
for (int64 row = 0; row < a.dim(0); ++row) {
|
||||
a(row, row) = 1;
|
||||
if (row > 0) {
|
||||
a(row, row - 1) = 0.01;
|
||||
}
|
||||
}
|
||||
|
||||
Array2D<float> b(256, 192);
|
||||
const float kMax = static_cast<float>(b.dim(0) * b.dim(1) + 1);
|
||||
for (int64 row = 0; row < b.dim(0); ++row) {
|
||||
for (int64 col = 0; col < b.dim(1); ++col) {
|
||||
b(row, col) = static_cast<float>(row + col + 1) / kMax;
|
||||
}
|
||||
}
|
||||
auto la = LiteralUtil::CreateR2FromArray2D(a);
|
||||
auto lb = LiteralUtil::CreateR2FromArray2D(b);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal lx, Execute(std::move(module), {&la, &lb}));
|
||||
|
||||
auto x_shape = lx.shape();
|
||||
EXPECT_EQ(x_shape.dimensions_size(), 2);
|
||||
EXPECT_EQ(x_shape.dimensions(0), b.dim(0));
|
||||
EXPECT_EQ(x_shape.dimensions(1), b.dim(1));
|
||||
|
||||
Array2D<float> x(x_shape.dimensions(0), x_shape.dimensions(1));
|
||||
x.SetValues(lx.data<float>());
|
||||
|
||||
auto ref_b = ReferenceUtil::MatmulArray2D(a, x);
|
||||
auto ref_lb = LiteralUtil::CreateR2FromArray2D(*ref_b);
|
||||
|
||||
EXPECT_TRUE(
|
||||
LiteralTestUtil::NearOrEqual(ref_lb, lb, ErrorSpec{0.001, 0.001}));
|
||||
}
|
||||
|
||||
// block_size test limits based on the following considerations:
|
||||
// - test at least twice the range of original value
|
||||
// - try to test odd values unaligned with matrix dims
|
||||
// - full 1-256 range test takes too long to run
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(TriangularExpanderTestInstances, TriangularExpanderTest,
|
||||
::testing::Range(2, 256, 7));
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user