Add explicit block_size to TriangularSolveExpander constructor
This small patch allows passing block_size explicitly, removing hardcoded value of 128. Provide test for triangular solve expander using different block_size values
This commit is contained in:
		
							parent
							
								
									e0b19f6ef2
								
							
						
					
					
						commit
						93baf2ca35
					
				@ -1808,6 +1808,27 @@ cc_library(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cc_test(
 | 
			
		||||
    name = "triangular_solve_expander_test",
 | 
			
		||||
    size = "medium",
 | 
			
		||||
    srcs = ["triangular_solve_expander_test.cc"],
 | 
			
		||||
    shard_count = 3,
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hlo",
 | 
			
		||||
        ":triangular_solve_expander",
 | 
			
		||||
        "//tensorflow/compiler/jit:xla_cpu_jit",
 | 
			
		||||
        "//tensorflow/compiler/xla:literal",
 | 
			
		||||
        "//tensorflow/compiler/xla:reference_util",
 | 
			
		||||
        "//tensorflow/compiler/xla:test",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:hlo_test_base",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:verified_hlo_module",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
        "@com_google_absl//absl/memory",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "cholesky_expander",
 | 
			
		||||
    srcs = ["cholesky_expander.cc"],
 | 
			
		||||
 | 
			
		||||
@ -454,6 +454,9 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
 | 
			
		||||
    : block_size_(block_size) {}
 | 
			
		||||
 | 
			
		||||
bool TriangularSolveExpander::InstructionMatchesPattern(
 | 
			
		||||
    HloInstruction* instruction) {
 | 
			
		||||
  return instruction->opcode() == HloOpcode::kTriangularSolve;
 | 
			
		||||
@ -496,7 +499,7 @@ StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
 | 
			
		||||
 | 
			
		||||
    BuildTriangularSolve(a, b, options.left_side(), options.lower(),
 | 
			
		||||
                         transpose_a, conjugate_a, options.unit_diagonal(),
 | 
			
		||||
                         /*block_size=*/128,
 | 
			
		||||
                         /*block_size=*/block_size_,
 | 
			
		||||
                         /*precision=*/PrecisionConfig::HIGHEST);
 | 
			
		||||
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -23,6 +23,8 @@ namespace xla {
 | 
			
		||||
 | 
			
		||||
class TriangularSolveExpander : public OpExpanderPass {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit TriangularSolveExpander(int64 block_size = 128);
 | 
			
		||||
 | 
			
		||||
  absl::string_view name() const override {
 | 
			
		||||
    return "triangular_solve_expander";
 | 
			
		||||
  }
 | 
			
		||||
@ -34,6 +36,8 @@ class TriangularSolveExpander : public OpExpanderPass {
 | 
			
		||||
      HloInstruction* instruction) override;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  // Block size for BuildTriangularSolve
 | 
			
		||||
  const int64 block_size_;
 | 
			
		||||
  // Mapping from op signatures to existing computations.
 | 
			
		||||
  absl::flat_hash_map<string, HloComputation*> computation_cache_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,108 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
 | 
			
		||||
 | 
			
		||||
#include <memory>
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/literal.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/reference_util.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/test.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/types.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status_test_util.h"
 | 
			
		||||
 | 
			
		||||
namespace xla {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
class TriangularExpanderTest : public HloTestBase,
 | 
			
		||||
                               public ::testing::WithParamInterface<int32> {};
 | 
			
		||||
 | 
			
		||||
TEST_P(TriangularExpanderTest, TestBlockSize) {
 | 
			
		||||
  auto block_size = GetParam();
 | 
			
		||||
  std::string hlo_string = R"(
 | 
			
		||||
    HloModule TensorFlowTriangularSolve
 | 
			
		||||
 | 
			
		||||
    ENTRY main {
 | 
			
		||||
      a = f32[256,256]{1,0} parameter(0)
 | 
			
		||||
      b = f32[256,192]{1,0} parameter(1)
 | 
			
		||||
      ROOT triangular-solve = f32[256,192]{1,0} triangular-solve(a, b),
 | 
			
		||||
                                    left_side=true, unit_diagonal=true,
 | 
			
		||||
                                    lower=true, transpose_a=NO_TRANSPOSE
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(auto module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(hlo_string));
 | 
			
		||||
 | 
			
		||||
  {
 | 
			
		||||
    TriangularSolveExpander triangular_solve_expander(block_size);
 | 
			
		||||
 | 
			
		||||
    TF_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
        bool result, RunHloPass(&triangular_solve_expander, module.get()));
 | 
			
		||||
    EXPECT_TRUE(result);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // To test triangular solver expander we generate simple bidiagonal matrix:
 | 
			
		||||
  // Solve a * x = b.
 | 
			
		||||
  // Check that shape is still valid.
 | 
			
		||||
  // Use reference matrix multiplication to test validity of result.
 | 
			
		||||
 | 
			
		||||
  Array2D<float> a(256, 256);
 | 
			
		||||
  for (int64 row = 0; row < a.dim(0); ++row) {
 | 
			
		||||
    a(row, row) = 1;
 | 
			
		||||
    if (row > 0) {
 | 
			
		||||
      a(row, row - 1) = 0.01;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Array2D<float> b(256, 192);
 | 
			
		||||
  const float kMax = (b.dim(0) * b.dim(1) + 1);
 | 
			
		||||
  for (int64 row = 0; row < b.dim(0); ++row) {
 | 
			
		||||
    for (int64 col = 0; col < b.dim(1); ++col) {
 | 
			
		||||
      b(row, col) = (row + col + 1) / kMax;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  auto la = LiteralUtil::CreateR2FromArray2D(a);
 | 
			
		||||
  auto lb = LiteralUtil::CreateR2FromArray2D(b);
 | 
			
		||||
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(Literal lx, Execute(std::move(module), {&la, &lb}));
 | 
			
		||||
 | 
			
		||||
  auto x_shape = lx.shape();
 | 
			
		||||
  EXPECT_EQ(x_shape.dimensions_size(), 2);
 | 
			
		||||
  EXPECT_EQ(x_shape.dimensions(0), b.dim(0));
 | 
			
		||||
  EXPECT_EQ(x_shape.dimensions(1), b.dim(1));
 | 
			
		||||
 | 
			
		||||
  Array2D<float> x(x_shape.dimensions(0), x_shape.dimensions(1));
 | 
			
		||||
  x.SetValues(lx.data<float>());
 | 
			
		||||
 | 
			
		||||
  auto ref_b = ReferenceUtil::MatmulArray2D(a, x);
 | 
			
		||||
  auto ref_lb = LiteralUtil::CreateR2FromArray2D(*ref_b);
 | 
			
		||||
 | 
			
		||||
  EXPECT_TRUE(
 | 
			
		||||
      LiteralTestUtil::NearOrEqual(ref_lb, lb, ErrorSpec{0.001, 0.001}));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// block_size test limits based on the following considerations:
 | 
			
		||||
// - test at least twice the range of original value
 | 
			
		||||
// - try to test odd values unaligned with matrix dims
 | 
			
		||||
// - full 1-256 range test takes too long to run
 | 
			
		||||
 | 
			
		||||
INSTANTIATE_TEST_CASE_P(TriangularExpanderTestInstances, TriangularExpanderTest,
 | 
			
		||||
                        ::testing::Range(2, 256, 7));
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user