Roll forward with a fix
PiperOrigin-RevId: 304156244 Change-Id: Ib14f8613aa5de72de6f2f8117d98b73d6ed5e297
parent 271da41dfa
commit f7b584639f
@@ -106,6 +106,11 @@ const auto kDimY = KernelMappingScheme::DimY;
 const auto kDimZ = KernelMappingScheme::DimZ;
 const auto kDimTot = KernelMappingScheme::DimTot;
 
+const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX;
+const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX;
+const auto kLinearStridedIndexingX =
+    KernelMappingScheme::LinearStridedIndexingX;
+
 // If a dimensions is smaller than this, untiled transposition may be more
 // efficient.
 const int64 kMinDimensionToTransposeTiled = 16;
@@ -1863,9 +1868,8 @@ namespace {
 bool MayPreventVectorization(const HloInstruction& hlo) {
   if (hlo.opcode() == HloOpcode::kFusion) {
     return absl::c_any_of(hlo.fused_instructions_computation()->instructions(),
-                          [](const HloInstruction* instr) {
+                          [&](const HloInstruction* instr) {
                             switch (instr->opcode()) {
                               case HloOpcode::kReduce:
                               case HloOpcode::kReduceWindow:
                               case HloOpcode::kSort:
                               case HloOpcode::kDot:
@@ -1892,6 +1896,10 @@ bool MayPreventVectorization(const HloInstruction& hlo) {
                               default:
                                 return false;
                             }
+  } else if (hlo.opcode() == HloOpcode::kReduce) {
+    // TODO(nouiz): check if the to_apply() attribute contains instruction
+    // that break LLVM vectorization.
+    return false;
   }
   return true;
 }
@@ -1920,13 +1928,59 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
                                     llvm::Value* thread_id_x,
                                     llvm::Type* index_ty,
                                     llvm::IRBuilder<>* b) {
-  if (mapping_scheme.DilatedX()) {
+  auto constant = [&](int64 val) {
+    return llvm::ConstantInt::get(index_ty, val);
+  };
+  if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) {
     return thread_id_x;
+  } else if (mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) {
+    return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize()));
   }
+  CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX);
   int64 x_num_steps =
       mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
-  return b->CreateMul(thread_id_x,
-                      llvm::ConstantInt::get(index_ty, x_num_steps));
+  return b->CreateMul(thread_id_x, constant(x_num_steps));
 }
 
+// Calls `emit_elem_function()` `x_num_steps` times. If
+// `vector_size`==1, then each element index passed to
+// `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1,
+// then it must be a multiple of `x_num_steps`. In that case, it
+// triggers a different indexing order that is vectorizable by
+// LLVM. It generates many groups of calls to `emit_elem_function`. Each
+// group is separated by `step_x` elements. Inside a group, elements
+// are consecutive. If `check_x_tile_bounds` is true, then it will check
+// if the element index is in bound compared to `tile_width` before
+// calling `emit_elem_function`.
+static void UnrollInnerTileLoop(
+    bool check_x_tile_bounds, int64 x_num_steps, int64 step_x,
+    int64 vector_size, const string& loop_name, KernelSupportLibrary* ksl,
+    llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width,
+    const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
+    const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
+  llvm::Type* index_ty = tile_width->getType();
+  auto constant = [&](int64 val) {
+    return llvm::ConstantInt::get(index_ty, val);
+  };
+  for (int64 j = 0; j < x_num_steps / vector_size; j++) {
+    for (int64 i = 0; i < vector_size; i++) {
+      int64 linear_index = j * vector_size + i;
+      llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
+                                        start_offset_x, "x_loc");
+      IrArray::Index source_idx_x =
+          source_idx.AddOffsetToDim(y_loc, kDimY, b)
+              .AddOffsetToDim(constant(j * step_x * vector_size + i), kDimX, b);
+      auto emit_element = [&] {
+        return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
+      };
+      if (check_x_tile_bounds) {
+        ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
+                emit_element);
+      } else {
+        emit_element();
+      }
+    }
+  }
+}
+
 void IrEmitterUnnested::EmitTile(
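
The new GetStartOffsetX/UnrollInnerTileLoop pair above boils down to simple per-thread index arithmetic. The stand-alone sketch below is not part of the commit; StartOffsetX, tid_x, and the constants are invented for illustration, and kLinearStrided models the new vectorizable order:

#include <cstdint>
#include <cstdio>

// Hypothetical host-side model of the per-thread X offsets generated above.
enum IndexingOrder { kLinear, kStrided, kLinearStrided };

// Mirrors GetStartOffsetX: where thread `tid_x` starts inside the tile.
int64_t StartOffsetX(IndexingOrder order, int64_t tid_x, int64_t x_num_steps,
                     int64_t vector_size) {
  switch (order) {
    case kStrided:
      return tid_x;                // steps are num_threads_x apart
    case kLinearStrided:
      return tid_x * vector_size;  // groups of vector_size consecutive elements
    case kLinear:
    default:
      return tid_x * x_num_steps;  // one contiguous run of x_num_steps elements
  }
}

int main() {
  const int64_t num_threads_x = 4, x_num_steps = 8, vector_size = 2;
  for (int64_t tid_x = 0; tid_x < num_threads_x; ++tid_x) {
    // Mirrors UnrollInnerTileLoop: element visited at unrolled position (j, i)
    // is start + j * step_x * vector_size + i, with step_x == num_threads_x
    // for the strided orders.
    int64_t start = StartOffsetX(kLinearStrided, tid_x, x_num_steps, vector_size);
    std::printf("thread %lld:", static_cast<long long>(tid_x));
    for (int64_t j = 0; j < x_num_steps / vector_size; ++j) {
      for (int64_t i = 0; i < vector_size; ++i) {
        std::printf(" %lld", static_cast<long long>(
                                 start + j * num_threads_x * vector_size + i));
      }
    }
    std::printf("\n");
  }
  return 0;
}

With these numbers thread 0 visits 0 1 8 9 16 17 24 25: each thread reads two consecutive elements per step while neighbouring threads stay coalesced, which is the access shape the kLinearStridedIndexingX path is meant to expose to LLVM.
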
@@ -1951,7 +2005,9 @@ void IrEmitterUnnested::EmitTile(
   // of threads.
   // Otherwise, the stride is one, but we multiply each offset by the limit of
   // number of steps which can be made.
-  int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
+  int64 step_x =
+      mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x;
+  int64 vector_size = mapping_scheme.GetVectorSize();
 
   IrArray::Index source_idx =
       tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
@@ -1987,21 +2043,29 @@ void IrEmitterUnnested::EmitTile(
         llvm::Value* y_loc =
             b_.CreateAdd(thread_id_info.thread_id_y,
                          b_.CreateMul(y_indvar, num_threads_y));
-        for (int64 j = 0; j < x_num_steps; j++) {
-          llvm::Value* x_loc =
-              b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
-          IrArray::Index source_idx_x =
-              source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
-                  .AddOffsetToDim(constant(j * step_x), kDimX, &b_);
-          auto emit_element = [&] {
-            return emit_elem_function(source_idx_x, y_loc, x_loc, j);
-          };
-          if (!x_tile_fits) {
-            ksl->If(loop_name + "_x_in_tile",
-                    b_.CreateICmpULT(x_loc, tile_width), emit_element);
-          } else {
-            emit_element();
-          }
+        auto unrollInnerTileLoop = [&](bool check_x_tile_bounds) {
+          return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps,
+                                     step_x, vector_size, loop_name, ksl,
+                                     start_offset_x, y_loc, tile_width,
+                                     source_idx, &b_, &emit_elem_function);
+        };
+
+        // Only take this path when we unroll in a way vectorizable by
+        // LLVM. Special case when the tile doesn't fit completely for even
+        // row size. For odd row size every other row isn't aligned to the
+        // vectorized size, so it can't be vectorized by LLVM.
+        if (!x_tile_fits &&
+            mapping_scheme.GetIndexingOrder() == kLinearStridedIndexingX) {
+          ksl->If(
+              loop_name + "_is_full_tile",
+              // For the last block, tile_width will be the number of
+              // elements left.
+              b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()),
+                              tile_width),
+              [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/false); },
+              [&] { unrollInnerTileLoop(/*check_x_tile_bounds=*/true); });
+        } else {
+          unrollInnerTileLoop(/*check_x_tile_bounds=*/!x_tile_fits);
         }
       });
 }
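
The `_is_full_tile` branch added above decides, at kernel run time, whether the unrolled body needs per-element bounds checks. A host-side paraphrase of that control flow (illustrative only; TileDispatch and emit_body are invented names, not the emitter itself) is:

#include <cstdint>
#include <functional>

// Hypothetical model of the dispatch EmitTile generates for the
// kLinearStridedIndexingX case: an unchecked body for full tiles and a
// bounds-checked body for the ragged last tile.
void TileDispatch(int64_t tile_size_x, int64_t tile_width, bool x_tile_fits,
                  bool linear_strided,
                  const std::function<void(bool /*check_bounds*/)>& emit_body) {
  if (!x_tile_fits && linear_strided) {
    // In the generated kernel this is a runtime branch; here it is a plain if.
    if (tile_width == tile_size_x) {
      emit_body(/*check_bounds=*/false);  // full tile: no per-element check
    } else {
      emit_body(/*check_bounds=*/true);   // last, partial tile
    }
  } else {
    emit_body(/*check_bounds=*/!x_tile_fits);
  }
}

Keeping the full-tile path free of compares is what the comment above refers to: only that branch is expected to be vectorizable by LLVM.
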
@@ -2035,6 +2099,19 @@ static IrArray::Index GetUnnormalizedIndex(
     const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
     const KernelMappingScheme& kernel_mapping_scheme) {
   DCHECK_EQ(normalized_shape_index.size(), 3);
+  // If the normalization only add a new dimensions of size 1,
+  // generate simpler indexing. LLVM doesn't always simplify the more
+  // complicated indexing and this prevents it from vectorizing some
+  // cases. We do this only for major_to_minor memory layout.
+  if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
+      unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
+      unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
+      unnormalized_shape.layout().minor_to_major(1) == 0) {
+    DCHECK_EQ(normalized_shape_index.dims()[0], 1);
+    auto multidim = normalized_shape_index.multidim();
+    return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
+                          normalized_shape_index.GetType());
+  }
   llvm::Value* linear = normalized_shape_index.Linearize(
       kernel_mapping_scheme.GetDimsInElems(), b_);
   return IrArray::Index(linear, unnormalized_shape, b_);
@@ -2077,21 +2154,6 @@ void IrEmitterUnnested::EmitTileElementForFusion(
   }
 }
 
-// Gets the number of partial results accumulated by a single thread performing
-// reduction.
-static int GetNumberOfPartialResults(
-    const ReductionCodegenInfo& reduction_info) {
-  const KernelMappingScheme& mapping_scheme =
-      reduction_info.GetKernelMappingScheme();
-  if (reduction_info.IsRowReduction()) {
-    return 1;
-  }
-  int64 num_partial_results = mapping_scheme.DilatedX() ? 1 : 2;
-  CHECK_EQ(num_partial_results,
-           (mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX()));
-  return num_partial_results;
-}
-
 void IrEmitterUnnested::EmitPrologueForReduction(
     HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info,
     absl::Span<HloInstruction* const> reduce_instructions,
@@ -2118,7 +2180,7 @@ void IrEmitterUnnested::EmitPrologueForReduction(
     llvm::AllocaInst* reduction_input_address = Alloca(element_type);
     reduction_input_addresses->push_back(reduction_input_address);
 
-    int num_partial_results = GetNumberOfPartialResults(*reduction_info);
+    int num_partial_results = reduction_info->GetNumPartialResults();
     AddressVector* partial_result_addresses =
         reduction_info->GetMutablePartialResultAddresses();
     llvm::AllocaInst* partial_result_address =
@@ -2270,7 +2332,7 @@ void IrEmitterUnnested::EmitEpilogueForReduction(
   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
       reduction_info.GetPartialResultAddresses();
 
-  int num_partial_results = GetNumberOfPartialResults(reduction_info);
+  int num_partial_results = reduction_info.GetNumPartialResults();
 
   // Emit an atomic operation that accumulates the partial reduction to the
   // output element. For row reduction, this is only for lane 0 due to the
@@ -2484,7 +2546,7 @@ void IrEmitterUnnested::EmitTileElementForReduction(
   // GetElementPointer with array types. This enables the vectorization of
   // the computation for different partial results. Use this index if
   // 'num_partial_results > 1'.
-  int num_partial_results = GetNumberOfPartialResults(reduction_info);
+  int num_partial_results = reduction_info.GetNumPartialResults();
   auto index_without_linear = IrArray::Index(
       input_index.multidim(), reduction_operand_shape, input_index.GetType());
 
@@ -2670,7 +2732,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
       /*tile_sizes=*/{1, kWarpSize, kWarpSize},
       /*num_threads_y=*/kNumRows,
       /*num_threads_x=*/kWarpSize,
-      /*is_dilated_x=*/false);
+      /*indexing_order=*/kLinearIndexingX,
+      /*vector_size=*/1);
   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
                                      mapping_scheme.GetThreadsPerBlock());
   llvm::Type* index_type =
@@ -3111,15 +3174,6 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
   std::array<int64, 3> reduction_tiling =
       GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
                          &ir_emitter_context_->device_description());
-  bool dilated_x =
-      reduction_dimensions.is_row_reduction ||
-      !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
-                                            reduction_dimensions.dimensions[2]);
-
-  if (!dilated_x && !reduction_dimensions.is_row_reduction) {
-    // Vectorized loads: a single thread reduces two adjacent columns.
-    reduction_tiling[2] *= 2;
-  }
 
   int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
   int64 num_threads_x = [&] {
@@ -3133,12 +3187,54 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
     return kWarpSize;
   }();
 
+  bool tile_fit = reduction_dimensions.dimensions[kDimX] %
+                      (reduction_tiling[2] * num_threads_x) ==
+                  0;
+
+  int cc_major = 0, cc_minor = 0;
+  ir_emitter_context_->device_description().cuda_compute_capability(&cc_major,
+                                                                     &cc_minor);
+
+  int num_partial_results = 1;
+  KernelMappingScheme::IndexingOrder indexing_order = [&]() {
+    if (reduction_dimensions.is_row_reduction &&
+        // P100, only try to vectorize+coales memory access when the
+        // tile size fits exactly and dtypes <= 32 bits
+        ((cc_major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) ||
+         // On V100, only try to vectorize+coales memory access for
+         // rows of even size. For odd row sizes, every other row
+         // isn't aligned, so it can't be vectorized.
+         (cc_major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) {
+      return kLinearStridedIndexingX;
+    } else if (!reduction_dimensions.is_row_reduction &&
+               IsUnrollingColumnReductionBeneficial(
+                   unnested_hlo, input_shape,
+                   reduction_dimensions.dimensions[2])) {
+      num_partial_results = 2;
+      reduction_tiling[2] *= num_partial_results;
+      return kLinearIndexingX;
+    } else {
+      return kStridedIndexingX;
+    }
+  }();
+
+  int vector_size = 1;
+  if (indexing_order == kLinearStridedIndexingX) {
+    if (reduction_dimensions.dimensions[2] % 2 == 0 &&
+        // Assuming XLA will perform the unrolling and LLVM will vectorize,
+        // disable the unroll for the cases that LLVM doesn't vectorize.
+        !MayPreventVectorization(*unnested_hlo)) {
+      vector_size = 2;
+    } else {
+      indexing_order = kStridedIndexingX;
+    }
+  }
   KernelMappingScheme mapping_scheme(
       reduction_dimensions.dimensions,
       {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
        reduction_tiling[2] * num_threads_x},
-      num_threads_y, num_threads_x, dilated_x);
-  return ReductionCodegenInfo(mapping_scheme,
+      num_threads_y, num_threads_x, indexing_order, vector_size);
+  return ReductionCodegenInfo(mapping_scheme, num_partial_results,
                               reduction_dimensions.is_row_reduction);
 }
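
Read as a whole, the new ComputeReductionCodegenInfo logic is a small device-dependent heuristic. The sketch below condenses only its row-reduction part for reading convenience; ChooseRowReductionIndexing and its struct are invented names, the column-reduction branch is omitted, and may_prevent_vectorization stands in for the MayPreventVectorization call:

#include <cstdint>

// Hypothetical condensed form of the indexing-order heuristic above,
// restricted to row reductions.
enum class IndexingOrder { kLinear, kStrided, kLinearStrided };

struct RowReductionPlan {
  IndexingOrder order;
  int vector_size;  // 2 enables 2-wide vectorized loads, 1 disables them
};

RowReductionPlan ChooseRowReductionIndexing(int cc_major, int dtype_bits,
                                            bool tile_fit, int64_t row_size,
                                            bool may_prevent_vectorization) {
  // P100 (sm_60): vectorize only when the tile divides the row exactly and the
  // element type is at most 32 bits. V100+ (sm_70+): even row sizes only, since
  // odd rows leave every other row misaligned for 2-wide loads.
  bool want_vectorized =
      (cc_major == 6 && dtype_bits <= 32 && tile_fit) ||
      (cc_major >= 7 && row_size % 2 == 0);
  if (want_vectorized && row_size % 2 == 0 && !may_prevent_vectorization) {
    return {IndexingOrder::kLinearStrided, /*vector_size=*/2};
  }
  return {IndexingOrder::kStrided, /*vector_size=*/1};
}

Under this reading, an f32[5,131072] row reduction on sm_70 should pick the strided-linear order with vector_size 2, which is the case the new PTX tests below look for via ld.global.nc.v2.f32.
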

@@ -76,19 +76,33 @@ namespace gpu {
 class KernelMappingScheme {
  public:
   enum { DimZ = 0, DimY, DimX, DimTot };
+  enum IndexingOrder {
+    // Thread reads consecutive elements.
+    LinearIndexingX,
+    // Thread reads strided elements while keeping memory coalescing.
+    StridedIndexingX,
+    // Thread reads a few consecutive elements then take a strided
+    // step. This can trigger vectorized reads and keep memory
+    // coalescing.
+    LinearStridedIndexingX
+  };
+
   KernelMappingScheme(absl::Span<const int64> dims_in_elems,
                       absl::Span<const int64> tile_sizes, int64 num_threads_y,
-                      int64 num_threads_x, bool is_dilated_x)
+                      int64 num_threads_x, IndexingOrder indexing_order,
+                      int vector_size)
       : dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]},
         tile_sizes_{tile_sizes[0], tile_sizes[1], tile_sizes[2]},
         num_threads_x_(num_threads_x),
         num_threads_y_(num_threads_y),
-        dilated_x_(is_dilated_x) {
+        indexing_order_(indexing_order),
+        vector_size_(vector_size) {
     CHECK_EQ(tile_sizes[1] % num_threads_y_, 0);
     CHECK_EQ(tile_sizes[2] % num_threads_x_, 0);
     VLOG(10) << "dims_in_elems_ = " << absl::StrJoin(dims_in_elems_, ",");
-    if (!dilated_x_) {
-      // dilated_x_=false is for the purpose of vectorization, which requires
+    if (indexing_order != LinearIndexingX) {
+      // StridedIndexingX, and LinearStridedIndexingX
+      // is for the purpose of vectorization, which requires
       // GetTileSizeFor(DimX) to be a multiplier of num_threads_x_.
       CHECK_EQ(GetTileSizeFor(DimX) % num_threads_x_, 0);
     }
@@ -118,7 +132,8 @@ class KernelMappingScheme {
     return GetNumThreadsX() * GetNumThreadsY();
   }
 
-  bool DilatedX() const { return dilated_x_; }
+  IndexingOrder GetIndexingOrder() const { return indexing_order_; }
+  int GetVectorSize() const { return vector_size_; }
 
  private:
   // The number of elements in each dimension.
@@ -133,12 +148,17 @@ class KernelMappingScheme {
   // Number of threads used to process elements in the Y direction of a tile.
   const int64 num_threads_y_;
 
-  // When num_threads_x threads process a total of tile_size_x elements in the
-  // X dimension of a tile, each threads process n=tile_size_x/num_threads_x
-  // elements. When dilated_x=false, the n elements processed by a thread are
-  // contiguous. On the other hand, when dilated_x=true the n elements are
-  // dilated by a factor of num_threads_x.
-  const bool dilated_x_;
+  // When num_threads_x threads process a total of tile_size_x
+  // elements in the X dimension of a tile, each threads process
+  // n=tile_size_x/num_threads_x elements.
+  // indexing_order defines which tile's elements each thread reads.
+  const IndexingOrder indexing_order_;
+
+  // vector_size_ only supported for row reduction and must be a divisor
+  // of tile_sizes_[2]/num_threads_x. Interesting values are 2 and 4
+  // to trigger vectorized loads on GPUs while keeping memory
+  // coalescing.
+  const int vector_size_;
 };
 
 // Information to support the code generation for a tiled reduction kernel.
@@ -146,8 +166,15 @@ using AddressVector = absl::InlinedVector<llvm::AllocaInst*, 1>;
 class ReductionCodegenInfo {
  public:
   explicit ReductionCodegenInfo(KernelMappingScheme mapping_scheme,
-                                bool is_row_reduction)
-      : mapping_scheme_(mapping_scheme), is_row_reduction_(is_row_reduction) {}
+                                int num_partial_results, bool is_row_reduction)
+      : mapping_scheme_(mapping_scheme),
+        num_partial_results_(num_partial_results),
+        is_row_reduction_(is_row_reduction) {
+    if (num_partial_results > 1) {
+      CHECK_EQ(num_partial_results, (mapping_scheme.GetTileSizeX() /
+                                     mapping_scheme.GetNumThreadsX()));
+    }
+  }
 
   const KernelMappingScheme& GetKernelMappingScheme() const {
     return mapping_scheme_;
@@ -183,6 +210,7 @@ class ReductionCodegenInfo {
     return reduction_input_addresses_;
   }
 
+  int GetNumPartialResults() const { return num_partial_results_; }
   bool IsRowReduction() const { return is_row_reduction_; }
 
   // Gets a pointer to a mutable shared cache used by reduction.
@@ -201,6 +229,7 @@ class ReductionCodegenInfo {
   const KernelMappingScheme mapping_scheme_;
   AddressVector partial_result_addresses_;
   AddressVector reduction_input_addresses_;
+  int num_partial_results_;
   bool is_row_reduction_;
 };
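
For orientation, this is roughly how the reworked constructors are meant to be called after this change. It is an illustrative sketch only: the helper name and the concrete numbers are made up, and it assumes the in-tree header path shown in the include.

// Path assumed from the surrounding tree layout.
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"

namespace xla {
namespace gpu {

// Hypothetical example: a row reduction tiled as [1, 1, 4096] with 32 threads
// in X, using the strided-linear order so each thread can issue 2-wide loads.
inline ReductionCodegenInfo MakeExampleRowReductionInfo() {
  KernelMappingScheme mapping_scheme(
      /*dims_in_elems=*/{1, 5, 131072},
      /*tile_sizes=*/{1, 1, 4096},
      /*num_threads_y=*/1,
      /*num_threads_x=*/32,
      /*indexing_order=*/KernelMappingScheme::LinearStridedIndexingX,
      /*vector_size=*/2);
  // Row reductions keep a single partial result per thread; only the unrolled
  // column-reduction path passes num_partial_results == 2.
  return ReductionCodegenInfo(mapping_scheme,
                              /*num_partial_results=*/1,
                              /*is_row_reduction=*/true);
}

}  // namespace gpu
}  // namespace xla

These values satisfy the constructor CHECKs above: the X tile size (4096) is a multiple of num_threads_x (32), and vector_size 2 divides tile_size_x / num_threads_x.
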

@@ -164,6 +164,33 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "reduction_vectorization_test",
+    srcs = [
+        "reduction_vectorization_test.cc",
+    ],
+    tags = tf_cuda_tests_tags() + ["no_rocm"],
+    deps = [
+        ":gpu_codegen_test",
+        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/service:gpu_plugin",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service/gpu:gemm_rewriter",
+        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
+        "//tensorflow/compiler/xla/tests:filecheck",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
 tf_cc_test(
     name = "reduction_dimension_grouper_test",
     srcs = [

@@ -0,0 +1,299 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+class ReductionVectorizationTest : public GpuCodegenTest {};
+
+TEST_F(ReductionVectorizationTest, Power2) {
+  const char* hlo_text = R"(
+HloModule ReducePower2
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,131072] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,131072] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 6) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, TileFit) {
+  const char* hlo_text = R"(
+HloModule ReduceTileFit
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,122880] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,122880] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 6) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, EvenColumns) {
+  const char* hlo_text = R"(
+HloModule ReducePower2
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,131070] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,131070] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 7) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v2.f32
+// TODO: Make this a vectorized load
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, DisableOddColumns) {
+  const char* hlo_text = R"(
+HloModule ReduceTileFit
+
+%max_ {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %maximum.7 = f32[] maximum(%x, %y)
+}
+
+ENTRY %main {
+  %param_0 = f32[5,131071] parameter(0)
+  %constant.3 = f32[] constant(0)
+  ROOT %reduce.8 = f32[5] reduce(f32[5,131071] %param_0, f32[] %constant.3), dimensions={1}, to_apply=%max_
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module),
+                                R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v4.f32
+CHECK-NOT: ld.global.nc.u64
+CHECK-NOT: ld.global.u64
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, Exp) {
+  const char* hlo_text = R"(
+HloModule DisableSin
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add.17 = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %arg0.1 = f32[5,131072] parameter(0)
+  %sine = f32[5,131072] exponential(f32[5,131072] %arg0.1)
+  %constant.0 = f32[] constant(0)
+  ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  se::StreamExecutor* executor = backend().default_stream_executor();
+  int cc_major = 0, cc_minor = 0;
+  executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
+                                                           &cc_minor);
+  string expected_ptx;
+  if (cc_major >= 6) {
+    expected_ptx = R"(
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+CHECK: ld.global.nc.v2.f32
+)";
+  } else {
+    expected_ptx = R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+CHECK: ld.global.nc.f32
+)";
+  }
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx);
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+TEST_F(ReductionVectorizationTest, DisableSin) {
+  const char* hlo_text = R"(
+HloModule DisableSin
+
+%add_float {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add.17 = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY %main {
+  %arg0.1 = f32[5,131072] parameter(0)
+  %sine = f32[5,131072] sine(f32[5,131072] %arg0.1)
+  %constant.0 = f32[] constant(0)
+  ROOT %reduce.18 = f32[5] reduce(f32[5,131072] %sine, f32[] %constant.0), dimensions={1}, to_apply=%add_float
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> optimized_module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  CompileAndOptionallyVerifyPtx(std::move(optimized_module),
+                                R"(
+CHECK-NOT: ld.global.nc.v2.f32
+CHECK-NOT: ld.global.nc.v4.f32
+CHECK-NOT: ld.global.nc.u64
+CHECK-NOT: ld.global.u64
+)");
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla