From 93baf2ca3508524e1dedbe24bd38a025d6b9eaf8 Mon Sep 17 00:00:00 2001
From: Vladimir Menshakov
Date: Tue, 19 May 2020 16:09:17 +0100
Subject: [PATCH] Add explicit block_size to TriangularSolveExpander constructor

This small patch allows passing block_size explicitly, removing hardcoded
value of 128.
Provide test for triangular solve expander using different block_size values
---
 tensorflow/compiler/xla/service/BUILD         |  21 ++++
 .../xla/service/triangular_solve_expander.cc  |   5 +-
 .../xla/service/triangular_solve_expander.h   |   4 +
 .../service/triangular_solve_expander_test.cc | 108 ++++++++++++++++++
 4 files changed, 137 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/compiler/xla/service/triangular_solve_expander_test.cc

diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 126b62a8eb2..f5e267b874c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1808,6 +1808,27 @@ cc_library(
     ],
 )
 
+tf_cc_test(
+    name = "triangular_solve_expander_test",
+    size = "medium",
+    srcs = ["triangular_solve_expander_test.cc"],
+    shard_count = 3,
+    deps = [
+        ":hlo",
+        ":triangular_solve_expander",
+        "//tensorflow/compiler/jit:xla_cpu_jit",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:reference_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:verified_hlo_module",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 cc_library(
     name = "cholesky_expander",
     srcs = ["cholesky_expander.cc"],
diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc
index cc483c310e8..d54eb9e78c3 100644
--- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc
+++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc
@@ -454,6 +454,9 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
 
 }  // namespace
 
+TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
+    : block_size_(block_size) {}
+
 bool TriangularSolveExpander::InstructionMatchesPattern(
     HloInstruction* instruction) {
   return instruction->opcode() == HloOpcode::kTriangularSolve;
@@ -496,7 +499,7 @@ StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
 
     BuildTriangularSolve(a, b, options.left_side(), options.lower(),
                          transpose_a, conjugate_a, options.unit_diagonal(),
-                         /*block_size=*/128,
+                         /*block_size=*/block_size_,
                          /*precision=*/PrecisionConfig::HIGHEST);
     TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
 
diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h
index be2374ef8c8..362e8557229 100644
--- a/tensorflow/compiler/xla/service/triangular_solve_expander.h
+++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h
@@ -23,6 +23,8 @@ namespace xla {
 
 class TriangularSolveExpander : public OpExpanderPass {
  public:
+  explicit TriangularSolveExpander(int64 block_size = 128);
+
   absl::string_view name() const override {
     return "triangular_solve_expander";
   }
@@ -34,6 +36,8 @@ class TriangularSolveExpander : public OpExpanderPass {
       HloInstruction* instruction) override;
 
  private:
+  // Block size forwarded to BuildTriangularSolve by ExpandInstruction.
+  const int64 block_size_;
   // Mapping from op signatures to existing computations.
   absl::flat_hash_map<string, HloComputation*> computation_cache_;
 };
diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander_test.cc b/tensorflow/compiler/xla/service/triangular_solve_expander_test.cc
new file mode 100644
index 00000000000..6cc95aba5d5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/triangular_solve_expander_test.cc
@@ -0,0 +1,108 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+class TriangularExpanderTest : public HloTestBase,
+                               public ::testing::WithParamInterface<int32> {};
+
+TEST_P(TriangularExpanderTest, TestBlockSize) {
+  auto block_size = GetParam();
+  std::string hlo_string = R"(
+    HloModule TensorFlowTriangularSolve
+
+    ENTRY main {
+      a = f32[256,256]{1,0} parameter(0)
+      b = f32[256,192]{1,0} parameter(1)
+      ROOT triangular-solve = f32[256,192]{1,0} triangular-solve(a, b),
+                              left_side=true, unit_diagonal=true,
+                              lower=true, transpose_a=NO_TRANSPOSE
+    }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  {
+    TriangularSolveExpander triangular_solve_expander(block_size);
+
+    TF_ASSERT_OK_AND_ASSIGN(
+        bool result, RunHloPass(&triangular_solve_expander, module.get()));
+    EXPECT_TRUE(result);
+  }
+
+  // Exercise the expanded module on a simple lower-bidiagonal matrix `a`:
+  // solve a * x = b, check that the result shape still matches `b`, and
+  // verify the solution by multiplying a * x with the reference matmul and
+  // comparing the product against `b`.
+
+  Array2D<float> a(256, 256);
+  for (int64 row = 0; row < a.dim(0); ++row) {
+    a(row, row) = 1;
+    if (row > 0) {
+      a(row, row - 1) = 0.01;
+    }
+  }
+
+  Array2D<float> b(256, 192);
+  const float kMax = (b.dim(0) * b.dim(1) + 1);
+  for (int64 row = 0; row < b.dim(0); ++row) {
+    for (int64 col = 0; col < b.dim(1); ++col) {
+      b(row, col) = (row + col + 1) / kMax;
+    }
+  }
+  auto la = LiteralUtil::CreateR2FromArray2D(a);
+  auto lb = LiteralUtil::CreateR2FromArray2D(b);
+
+  TF_ASSERT_OK_AND_ASSIGN(Literal lx, Execute(std::move(module), {&la, &lb}));
+
+  auto x_shape = lx.shape();
+  EXPECT_EQ(x_shape.dimensions_size(), 2);
+  EXPECT_EQ(x_shape.dimensions(0), b.dim(0));
+  EXPECT_EQ(x_shape.dimensions(1), b.dim(1));
+
+  Array2D<float> x(x_shape.dimensions(0), x_shape.dimensions(1));
+  x.SetValues(lx.data<float>());
+
+  auto ref_b = ReferenceUtil::MatmulArray2D(a, x);
+  auto ref_lb = LiteralUtil::CreateR2FromArray2D(*ref_b);
+
+  EXPECT_TRUE(
+      LiteralTestUtil::NearOrEqual(ref_lb, lb, ErrorSpec{0.001, 0.001}));
+}
+
+// block_size test limits based on the following considerations:
+// - test at least twice the range of original value
+// - try to test odd values unaligned with matrix dims
+// - full 1-256 range test takes too long to run
+
+INSTANTIATE_TEST_CASE_P(TriangularExpanderTestInstances, TriangularExpanderTest,
+                        ::testing::Range(2, 256, 7));
+
+}  // namespace
+}  // namespace xla