From 34b86e53d7faaa62a9b62946a3a0a6d65c517eba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 1 Apr 2020 04:31:10 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 304157828 Change-Id: I72bd4c0fd81c0e42c3fe202c4237a443e6bdb135 --- .../xla/service/gpu/ir_emitter_unnested.cc | 198 +++--------- .../xla/service/gpu/kernel_mapping_scheme.h | 55 +--- .../compiler/xla/service/gpu/tests/BUILD | 27 -- .../gpu/tests/reduction_vectorization_test.cc | 299 ------------------ 4 files changed, 64 insertions(+), 515 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 089f604fad4..528a847b3ed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -106,11 +106,6 @@ const auto kDimY = KernelMappingScheme::DimY; const auto kDimZ = KernelMappingScheme::DimZ; const auto kDimTot = KernelMappingScheme::DimTot; -const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX; -const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX; -const auto kLinearStridedIndexingX = - KernelMappingScheme::LinearStridedIndexingX; - // If a dimensions is smaller than this, untiled transposition may be more // efficient. const int64 kMinDimensionToTransposeTiled = 16; @@ -1868,8 +1863,9 @@ namespace { bool MayPreventVectorization(const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kFusion) { return absl::c_any_of(hlo.fused_instructions_computation()->instructions(), - [&](const HloInstruction* instr) { + [](const HloInstruction* instr) { switch (instr->opcode()) { + case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kSort: case HloOpcode::kDot: @@ -1896,10 +1892,6 @@ bool MayPreventVectorization(const HloInstruction& hlo) { default: return false; } - } else if (hlo.opcode() == HloOpcode::kReduce) { - // TODO(nouiz): check if the to_apply() attribute contains instruction - // that break LLVM vectorization. - return false; } return true; } @@ -1928,59 +1920,13 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme, llvm::Value* thread_id_x, llvm::Type* index_ty, llvm::IRBuilder<>* b) { - auto constant = [&](int64 val) { - return llvm::ConstantInt::get(index_ty, val); - }; - if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) { + if (mapping_scheme.DilatedX()) { return thread_id_x; - } else if (mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) { - return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize())); } - CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX); int64 x_num_steps = mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX(); - return b->CreateMul(thread_id_x, constant(x_num_steps)); -} - -// Calls `emit_elem_function()` `x_num_steps` times. If -// `vector_size`==1, then each element index passed to -// `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1, -// then it must be a multiple of `x_num_steps`. In that case, it -// triggers a different indexing order that is vectorizable by -// LLVM. It generates many groups of calls to `emit_elem_function`. Each -// group is separated by `step_x` elements. Inside a group, elements -// are consecutive. 
If `check_x_tile_bounds` is true, then it will check -// if the element index is in bound compared to `tile_width` before -// calling `emit_elem_function`. -static void UnrollInnerTileLoop( - bool check_x_tile_bounds, int64 x_num_steps, int64 step_x, - int64 vector_size, const string& loop_name, KernelSupportLibrary* ksl, - llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width, - const IrArray::Index& source_idx, llvm::IRBuilder<>* b, - const IrEmitterUnnested::EmitElementFunction* emit_elem_function) { - llvm::Type* index_ty = tile_width->getType(); - auto constant = [&](int64 val) { - return llvm::ConstantInt::get(index_ty, val); - }; - for (int64 j = 0; j < x_num_steps / vector_size; j++) { - for (int64 i = 0; i < vector_size; i++) { - int64 linear_index = j * vector_size + i; - llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i), - start_offset_x, "x_loc"); - IrArray::Index source_idx_x = - source_idx.AddOffsetToDim(y_loc, kDimY, b) - .AddOffsetToDim(constant(j * step_x * vector_size + i), kDimX, b); - auto emit_element = [&] { - return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index); - }; - if (check_x_tile_bounds) { - ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width), - emit_element); - } else { - emit_element(); - } - } - } + return b->CreateMul(thread_id_x, + llvm::ConstantInt::get(index_ty, x_num_steps)); } void IrEmitterUnnested::EmitTile( @@ -2005,9 +1951,7 @@ void IrEmitterUnnested::EmitTile( // of threads. // Otherwise, the stride is one, but we multiply each offset by the limit of // number of steps which can be made. - int64 step_x = - mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x; - int64 vector_size = mapping_scheme.GetVectorSize(); + int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1; IrArray::Index source_idx = tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_); @@ -2043,29 +1987,21 @@ void IrEmitterUnnested::EmitTile( llvm::Value* y_loc = b_.CreateAdd(thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y)); - auto unrollInnerTileLoop = [&](bool check_x_tile_bounds) { - return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps, - step_x, vector_size, loop_name, ksl, - start_offset_x, y_loc, tile_width, - source_idx, &b_, &emit_elem_function); - }; - - // Only take this path when we unroll in a way vectorizable by - // LLVM. Special case when the tile doesn't fit completely for even - // row size. For odd row size every other row isn't aligned to the - // vectorized size, so it can't be vectorized by LLVM. - if (!x_tile_fits && - mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) { - ksl->If( - loop_name + "_is_full_tile", - // For the last block, tile_width will be the number of - // elements left. 
- b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()), - tile_width), - [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/false); }, - [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/true); }); - } else { - unrollInnerTileLoop(/*check_x_tile_bounds=*/!x_tile_fits); + for (int64 j = 0; j < x_num_steps; j++) { + llvm::Value* x_loc = + b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc"); + IrArray::Index source_idx_x = + source_idx.AddOffsetToDim(y_loc, kDimY, &b_) + .AddOffsetToDim(constant(j * step_x), kDimX, &b_); + auto emit_element = [&] { + return emit_elem_function(source_idx_x, y_loc, x_loc, j); + }; + if (!x_tile_fits) { + ksl->If(loop_name + "_x_in_tile", + b_.CreateICmpULT(x_loc, tile_width), emit_element); + } else { + emit_element(); + } } }); } @@ -2099,19 +2035,6 @@ static IrArray::Index GetUnnormalizedIndex( const Shape& unnormalized_shape, llvm::IRBuilder<>* b_, const KernelMappingScheme& kernel_mapping_scheme) { DCHECK_EQ(normalized_shape_index.size(), 3); - // If the normalization only add a new dimensions of size 1, - // generate simpler indexing. LLVM doesn't always simplify the more - // complicated indexing and this prevents it from vectorizing some - // cases. We do this only for major_to_minor memory layout. - if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() && - unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] && - unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] && - unnormalized_shape.layout().minor_to_major(1) == 0) { - DCHECK_EQ(normalized_shape_index.dims()[0], 1); - auto multidim = normalized_shape_index.multidim(); - return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape, - normalized_shape_index.GetType()); - } llvm::Value* linear = normalized_shape_index.Linearize( kernel_mapping_scheme.GetDimsInElems(), b_); return IrArray::Index(linear, unnormalized_shape, b_); @@ -2154,6 +2077,21 @@ void IrEmitterUnnested::EmitTileElementForFusion( } } +// Gets the number of partial results accumulated by a single thread performing +// reduction. +static int GetNumberOfPartialResults( + const ReductionCodegenInfo& reduction_info) { + const KernelMappingScheme& mapping_scheme = + reduction_info.GetKernelMappingScheme(); + if (reduction_info.IsRowReduction()) { + return 1; + } + int64 num_partial_results = mapping_scheme.DilatedX() ? 1 : 2; + CHECK_EQ(num_partial_results, + (mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX())); + return num_partial_results; +} + void IrEmitterUnnested::EmitPrologueForReduction( HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info, absl::Span reduce_instructions, @@ -2180,7 +2118,7 @@ void IrEmitterUnnested::EmitPrologueForReduction( llvm::AllocaInst* reduction_input_address = Alloca(element_type); reduction_input_addresses->push_back(reduction_input_address); - int num_partial_results = reduction_info->GetNumPartialResults(); + int num_partial_results = GetNumberOfPartialResults(*reduction_info); AddressVector* partial_result_addresses = reduction_info->GetMutablePartialResultAddresses(); llvm::AllocaInst* partial_result_address = @@ -2332,7 +2270,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction( absl::Span partial_result_addresses = reduction_info.GetPartialResultAddresses(); - int num_partial_results = reduction_info.GetNumPartialResults(); + int num_partial_results = GetNumberOfPartialResults(reduction_info); // Emit an atomic operation that accumulates the partial reduction to the // output element. 
For row reduction, this is only for lane 0 due to the @@ -2546,7 +2484,7 @@ void IrEmitterUnnested::EmitTileElementForReduction( // GetElementPointer with array types. This enables the vectorization of // the computation for different partial results. Use this index if // 'num_partial_results > 1'. - int num_partial_results = reduction_info.GetNumPartialResults(); + int num_partial_results = GetNumberOfPartialResults(reduction_info); auto index_without_linear = IrArray::Index( input_index.multidim(), reduction_operand_shape, input_index.GetType()); @@ -2732,8 +2670,7 @@ void IrEmitterUnnested::EmitHlo021Tile( /*tile_sizes=*/{1, kWarpSize, kWarpSize}, /*num_threads_y=*/kNumRows, /*num_threads_x=*/kWarpSize, - /*indexing_order=*/kLinearIndexingX, - /*vector_size=*/1); + /*is_dilated_x=*/false); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); llvm::Type* index_type = @@ -3174,6 +3111,15 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( std::array reduction_tiling = GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits, &ir_emitter_context_->device_description()); + bool dilated_x = + reduction_dimensions.is_row_reduction || + !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape, + reduction_dimensions.dimensions[2]); + + if (!dilated_x && !reduction_dimensions.is_row_reduction) { + // Vectorized loads: a single thread reduces two adjacent columns. + reduction_tiling[2] *= 2; + } int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize; int64 num_threads_x = [&] { @@ -3187,54 +3133,12 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( return kWarpSize; }(); - bool tile_fit = reduction_dimensions.dimensions[kDimX] % - (reduction_tiling[2] * num_threads_x) == - 0; - - int cc_major = 0, cc_minor = 0; - ir_emitter_context_->device_description().cuda_compute_capability(&cc_major, - &cc_minor); - - int num_partial_results = 1; - KernelMappingScheme::IndexingOrder indexing_order = [&]() { - if (reduction_dimensions.is_row_reduction && - // P100, only try to vectorize+coales memory access when the - // tile size fits exactly and dtypes <= 32 bits - ((cc_major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) || - // On V100, only try to vectorize+coales memory access for - // rows of even size. For odd row sizes, every other row - // isn't aligned, so it can't be vectorized. - (cc_major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) { - return kLinearStridedIndexingX; - } else if (!reduction_dimensions.is_row_reduction && - IsUnrollingColumnReductionBeneficial( - unnested_hlo, input_shape, - reduction_dimensions.dimensions[2])) { - num_partial_results = 2; - reduction_tiling[2] *= num_partial_results; - return kLinearIndexingX; - } else { - return kStridedIndexingX; - } - }(); - - int vector_size = 1; - if (indexing_order == kLinearStridedIndexingX) { - if (reduction_dimensions.dimensions[2] % 2 == 0 && - // Assuming XLA will perform the unrolling and LLVM will vectorize, - // disable the unroll for the cases that LLVM doesn't vectorize. 
- !MayPreventVectorization(*unnested_hlo)) { - vector_size = 2; - } else { - indexing_order = kStridedIndexingX; - } - } KernelMappingScheme mapping_scheme( reduction_dimensions.dimensions, {reduction_tiling[0], reduction_tiling[1] * num_threads_y, reduction_tiling[2] * num_threads_x}, - num_threads_y, num_threads_x, indexing_order, vector_size); - return ReductionCodegenInfo(mapping_scheme, num_partial_results, + num_threads_y, num_threads_x, dilated_x); + return ReductionCodegenInfo(mapping_scheme, reduction_dimensions.is_row_reduction); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h index cd690c910c2..eeab8d4dc80 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h @@ -76,33 +76,19 @@ namespace gpu { class KernelMappingScheme { public: enum { DimZ = 0, DimY, DimX, DimTot }; - enum IndexingOrder { - // Thread reads consecutive elements. - LinearIndexingX, - // Thread reads strided elements while keeping memory coalescing. - StridedIndexingX, - // Thread reads a few consecutive elements then take a strided - // step. This can trigger vectorized reads and keep memory - // coalescing. - LinearStridedIndexingX - }; - KernelMappingScheme(absl::Span dims_in_elems, absl::Span tile_sizes, int64 num_threads_y, - int64 num_threads_x, IndexingOrder indexing_order, - int vector_size) + int64 num_threads_x, bool is_dilated_x) : dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]}, tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]}, num_threads_x_(num_threads_x), num_threads_y_(num_threads_y), - indexing_order_(indexing_order), - vector_size_(vector_size) { + dilated_x_(is_dilated_x) { CHECK_EQ(tile_sizes[1] % num_threads_y_, 0); CHECK_EQ(tile_sizes[2] % num_threads_x_, 0); VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ","); - if (indexing_order != LinearIndexingX) { - // StridedIndexingX, and LinearStridedIndexingX - // is for the purpose of vectorization, which requires + if (!dilated_x_) { + // dilated_x_=false is for the purpose of vectorization, which requires // GetTileSizeFor(DimX) to be a multiplier of num_threads_x_. CHECK_EQ(GetTileSizeFor(DimX) % num_threads_x_, 0); } @@ -132,8 +118,7 @@ class KernelMappingScheme { return GetNumThreadsX() * GetNumThreadsY(); } - IndexingOrder GetIndexingOrder() const { return indexing_order_; } - int GetVectorSize() const { return vector_size_; } + bool DilatedX() const { return dilated_x_; } private: // The number of elements in each dimension. @@ -148,17 +133,12 @@ class KernelMappingScheme { // Number of threads used to process elements in the Y direction of a tile. const int64 num_threads_y_; - // When num_threads_x threads process a total of tile_size_x - // elements in the X dimension of a tile, each threads process - // n=tile_size_x/num_threads_x elements. - // indexing_order defines which tile's elements each thread reads. - const IndexingOrder indexing_order_; - - // vector_size_ only supported for row reduction and must be a divisor - // of tile_sizes_[2]/num_threads_x. Interesting values are 2 and 4 - // to trigger vectorized loads on GPUs while keeping memory - // coalescing. - const int vector_size_; + // When num_threads_x threads process a total of tile_size_x elements in the + // X dimension of a tile, each threads process n=tile_size_x/num_threads_x + // elements. 
When dilated_x=false, the n elements processed by a thread are + // contiguous. On the other hand, when dilated_x=true the n elements are + // dilated by a factor of num_threads_x. + const bool dilated_x_; }; // Information to support the code generation for a tiled reduction kernel. @@ -166,15 +146,8 @@ using AddressVector = absl::InlinedVector; class ReductionCodegenInfo { public: explicit ReductionCodegenInfo(KernelMappingScheme mapping_scheme, - int num_partial_results, bool is_row_reduction) - : mapping_scheme_(mapping_scheme), - num_partial_results_(num_partial_results), - is_row_reduction_(is_row_reduction) { - if (num_partial_results > 1) { - CHECK_EQ(num_partial_results, (mapping_scheme.GetTileSizeX() / - mapping_scheme.GetNumThreadsX())); - } - } + bool is_row_reduction) + : mapping_scheme_(mapping_scheme), is_row_reduction_(is_row_reduction) {} const KernelMappingScheme& GetKernelMappingScheme() const { return mapping_scheme_; @@ -210,7 +183,6 @@ class ReductionCodegenInfo { return reduction_input_addresses_; } - int GetNumPartialResults() const { return num_partial_results_; } bool IsRowReduction() const { return is_row_reduction_; } // Gets a pointer to a mutable shared cache used by reduction. @@ -229,7 +201,6 @@ class ReductionCodegenInfo { const KernelMappingScheme mapping_scheme_; AddressVector partial_result_addresses_; AddressVector reduction_input_addresses_; - int num_partial_results_; bool is_row_reduction_; }; diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index e04dba418d9..1fd51c78988 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -164,33 +164,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "reduction_vectorization_test", - srcs = [ - "reduction_vectorization_test.cc", - ], - tags = tf_cuda_tests_tags() + ["no_rocm"], - deps = [ - ":gpu_codegen_test", - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla/service:gpu_plugin", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service/gpu:gemm_rewriter", - "//tensorflow/compiler/xla/service/gpu:gpu_executable", - "//tensorflow/compiler/xla/tests:filecheck", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/memory", - ], -) - tf_cc_test( name = "reduction_dimension_grouper_test", srcs = [ diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc deleted file mode 100644 index 5f27df08b90..00000000000 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_vectorization_test.cc +++ /dev/null @@ -1,299 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" -#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tests/filecheck.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/stream_executor/lib/statusor.h" - -namespace xla { -namespace gpu { - -namespace { - -class ReductionVectorizationTest : public GpuCodegenTest {}; - -TEST_F(ReductionVectorizationTest, Power2) { - const char* hlo_text = R"( -HloModule ReducePower2 - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) -} - -ENTRY %main { - %param_0 = f32[5,131072] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %reduce.8 = f32[5] reduce(f32[5,131072] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - ParseAndReturnVerifiedModule(hlo_text)); - se::StreamExecutor* executor = backend().default_stream_executor(); - int cc_major = 0, cc_minor = 0; - executor->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor); - string expected_ptx; - if (cc_major >= 6) { - expected_ptx = R"( -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -)"; - } else { - expected_ptx = R"( -CHECK-NOT: ld.global.nc.v2.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -)"; - } - CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(ReductionVectorizationTest, TileFit) { - const char* hlo_text = R"( -HloModule ReduceTileFit - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) -} - -ENTRY %main { - %param_0 = f32[5,122880] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %reduce.8 = f32[5] reduce(f32[5,122880] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - ParseAndReturnVerifiedModule(hlo_text)); - se::StreamExecutor* executor = backend().default_stream_executor(); - int cc_major = 0, cc_minor = 0; - executor->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor); - string expected_ptx; - if (cc_major >= 6) { - expected_ptx = R"( -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -)"; - } else { - expected_ptx = R"( -CHECK-NOT: ld.global.nc.v2.f32 -CHECK: 
ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -)"; - } - CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(ReductionVectorizationTest, EvenColumns) { - const char* hlo_text = R"( -HloModule ReducePower2 - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) -} - -ENTRY %main { - %param_0 = f32[5,131070] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %reduce.8 = f32[5] reduce(f32[5,131070] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - ParseAndReturnVerifiedModule(hlo_text)); - se::StreamExecutor* executor = backend().default_stream_executor(); - int cc_major = 0, cc_minor = 0; - executor->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor); - string expected_ptx; - if (cc_major >= 7) { - expected_ptx = R"( -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -CHECK-NOT: ld.global.nc.v2.f32 -// TODO: Make this a vectorized load -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -)"; - } else { - expected_ptx = R"( -CHECK-NOT: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -)"; - } - CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(ReductionVectorizationTest, DisableOddColumns) { - const char* hlo_text = R"( -HloModule ReduceTileFit - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(%x, %y) -} - -ENTRY %main { - %param_0 = f32[5,131071] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %reduce.8 = f32[5] reduce(f32[5,131071] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_ -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - ParseAndReturnVerifiedModule(hlo_text)); - CompileAndOptionallyVerifyPtx(std::move(optimized_module), - R"( -CHECK-NOT: ld.global.nc.v2.f32 -CHECK-NOT: ld.global.nc.v4.f32 -CHECK-NOT: ld.global.nc.u64 -CHECK-NOT: ld.global.u64 -)"); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(ReductionVectorizationTest, Exp) { - const char* hlo_text = R"( -HloModule DisableSin - -%add_float { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add.17 = f32[] add(f32[] %x, f32[] %y) -} - -ENTRY %main { - %arg0.1 = f32[5,131072] parameter(0) - %sine = f32[5,131072] exponential(f32[5,131072] %arg0.1) - %constant.0 = f32[] constant(0) - ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - ParseAndReturnVerifiedModule(hlo_text)); - se::StreamExecutor* executor = backend().default_stream_executor(); - int cc_major = 0, cc_minor = 0; - executor->GetDeviceDescription().cuda_compute_capability(&cc_major, - &cc_minor); - string expected_ptx; - if (cc_major >= 6) { - expected_ptx = R"( -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 -CHECK: ld.global.nc.v2.f32 
-CHECK: ld.global.nc.v2.f32 -)"; - } else { - expected_ptx = R"( -CHECK-NOT: ld.global.nc.v2.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -CHECK: ld.global.nc.f32 -)"; - } - CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(ReductionVectorizationTest, DisableSin) { - const char* hlo_text = R"( -HloModule DisableSin - -%add_float { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add.17 = f32[] add(f32[] %x, f32[] %y) -} - -ENTRY %main { - %arg0.1 = f32[5,131072] parameter(0) - %sine = f32[5,131072] sine(f32[5,131072] %arg0.1) - %constant.0 = f32[] constant(0) - ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - ParseAndReturnVerifiedModule(hlo_text)); - CompileAndOptionallyVerifyPtx(std::move(optimized_module), - R"( -CHECK-NOT: ld.global.nc.v2.f32 -CHECK-NOT: ld.global.nc.v4.f32 -CHECK-NOT: ld.global.nc.u64 -CHECK-NOT: ld.global.u64 -)"); - - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -} // namespace -} // namespace gpu -} // namespace xla
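Not part of the patch: a minimal standalone sketch (helper name ElementX and the driver values in main() are hypothetical) of the two X-indexing orders that the restored KernelMappingScheme dilated_x flag selects between in GetStartOffsetX and EmitTile. With dilated_x=true a thread starts at thread_id_x and strides by num_threads_x, so adjacent threads touch adjacent elements on every step (coalesced); with dilated_x=false it starts at thread_id_x * x_num_steps and walks x_num_steps contiguous elements, the layout LLVM can vectorize.

// Illustrative only -- not part of the patch above.
#include <cstdint>
#include <cstdio>

// Tile-local X offset touched by `thread_id_x` on step `j`, mirroring the
// start-offset and step_x computations in GetStartOffsetX / EmitTile.
// `x_num_steps` is tile_size_x / num_threads_x.
int64_t ElementX(bool dilated_x, int64_t thread_id_x, int64_t j,
                 int64_t num_threads_x, int64_t x_num_steps) {
  const int64_t start = dilated_x ? thread_id_x : thread_id_x * x_num_steps;
  const int64_t step = dilated_x ? num_threads_x : 1;
  return start + j * step;
}

int main() {
  const int64_t num_threads_x = 4, x_num_steps = 2;  // tile_size_x = 8
  const bool modes[] = {true, false};
  for (bool dilated : modes) {
    std::printf("dilated_x=%d:\n", dilated ? 1 : 0);
    for (int64_t tid = 0; tid < num_threads_x; ++tid) {
      for (int64_t j = 0; j < x_num_steps; ++j) {
        std::printf("  thread %lld step %lld -> x=%lld\n",
                    static_cast<long long>(tid), static_cast<long long>(j),
                    static_cast<long long>(
                        ElementX(dilated, tid, j, num_threads_x, x_num_steps)));
      }
    }
  }
  return 0;
}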
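Also illustrative only, assuming a base column-reduction tile factor of 1 (not stated in the patch): a sketch of the invariant behind the re-added GetNumberOfPartialResults. After ComputeReductionCodegenInfo doubles reduction_tiling[2] for a non-dilated column reduction, each thread accumulates exactly GetTileSizeX() / GetNumThreadsX() == 2 partial results, which is what the CHECK_EQ in GetNumberOfPartialResults enforces.

// Illustrative only -- not part of the patch above.
#include <cassert>
#include <cstdint>

// Mirrors GetNumberOfPartialResults: one partial result per thread for row
// reductions or dilated column reductions, two for unrolled column reductions.
int NumPartialResults(bool is_row_reduction, bool dilated_x,
                      int64_t tile_size_x, int64_t num_threads_x) {
  if (is_row_reduction) return 1;
  const int num = dilated_x ? 1 : 2;
  assert(num == tile_size_x / num_threads_x);  // the CHECK_EQ in the patch
  return num;
}

int main() {
  const int64_t num_threads_x = 32;  // kWarpSize
  int64_t tiling_x = 1;              // assumed base tile factor
  const bool dilated_x = false;      // unrolled column reduction
  if (!dilated_x) tiling_x *= 2;     // mirrors reduction_tiling[2] *= 2
  const int64_t tile_size_x = tiling_x * num_threads_x;
  assert(NumPartialResults(/*is_row_reduction=*/false, dilated_x, tile_size_x,
                           num_threads_x) == 2);
  return 0;
}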