Merge pull request #36384 from nouiz:xla_reduction_branch
PiperOrigin-RevId: 293563326 Change-Id: I8ee92fc645e3c443c6f2ff04a9db89a33208df5d
This commit is contained in:
commit
af49800eda
@ -1908,23 +1908,47 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
|
|||||||
llvm::ConstantInt::get(index_ty, x_num_steps));
|
llvm::ConstantInt::get(index_ty, x_num_steps));
|
||||||
}
|
}
|
||||||
|
|
||||||
void IrEmitterUnnested::EmitTile(
|
// Emits code to process up to
|
||||||
|
// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
|
||||||
|
// given `emit_elem_function` is the function to emit code to process one
|
||||||
|
// element, `y` and `x` are the intra-tile coordinates for the first element
|
||||||
|
// to process, and `index` is the index for the origin of the tile. Information
|
||||||
|
// about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits
|
||||||
|
// bounds check to ensure that each processed element is within the boundary
|
||||||
|
// defined by `tile_width` and `tile_height`.
|
||||||
|
//
|
||||||
|
// Pseudocode:
|
||||||
|
//
|
||||||
|
// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
|
||||||
|
// for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
|
||||||
|
// if (dilated) {
|
||||||
|
// x_loc = x + j * num_threads_x;
|
||||||
|
// } else {
|
||||||
|
// x_loc = x * (tile_size_x / num_threads_x) + j;
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// if (x_loc < tile_width) {
|
||||||
|
// emit_elem_function(y + y_loc, x_loc);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
static void EmitTile(
|
||||||
const KernelMappingScheme& mapping_scheme,
|
const KernelMappingScheme& mapping_scheme,
|
||||||
const IrArray::Index& tile_origin_index, const string& loop_name,
|
const IrArray::Index& tile_origin_index, const string& loop_name,
|
||||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
|
||||||
llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
|
llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
|
||||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
||||||
llvm::Type* index_ty = tile_width->getType();
|
llvm::Type* index_ty = tile_width->getType();
|
||||||
auto constant = [&](int64 val) {
|
auto constant = [&](int64 val) {
|
||||||
return llvm::ConstantInt::get(index_ty, val);
|
return llvm::ConstantInt::get(index_ty, val);
|
||||||
};
|
};
|
||||||
int64 num_threads_x = mapping_scheme.GetNumThreadsX();
|
int64 num_threads_x = mapping_scheme.GetNumThreadsX();
|
||||||
llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
|
int64 num_threads_y = mapping_scheme.GetNumThreadsY();
|
||||||
int64 tile_size_x = mapping_scheme.GetTileSizeX();
|
int64 tile_size_x = mapping_scheme.GetTileSizeX();
|
||||||
|
|
||||||
int64 x_num_steps = tile_size_x / num_threads_x;
|
int64 x_num_steps = tile_size_x / num_threads_x;
|
||||||
llvm::Value* start_offset_x =
|
llvm::Value* start_offset_x = GetStartOffsetX(mapping_scheme, x, index_ty, b);
|
||||||
GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
|
|
||||||
|
|
||||||
// Using dilated mapping scheme, each thread steps with a stride of number
|
// Using dilated mapping scheme, each thread steps with a stride of number
|
||||||
// of threads.
|
// of threads.
|
||||||
@ -1933,42 +1957,33 @@ void IrEmitterUnnested::EmitTile(
|
|||||||
int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
|
int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
|
||||||
|
|
||||||
IrArray::Index source_idx =
|
IrArray::Index source_idx =
|
||||||
tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
|
tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, b);
|
||||||
|
|
||||||
auto ceil_of_ratio = [&](llvm::Value* a, llvm::Value* b) {
|
// True when all threads will always execute all instructions.
|
||||||
return b_.CreateUDiv(b_.CreateAdd(b_.CreateAdd(a, b), constant(-1)), b);
|
// So we do not need to emit condition.
|
||||||
};
|
bool always_full_tile = mapping_scheme.GetDimsInElems()[2] % tile_size_x == 0;
|
||||||
|
|
||||||
// The outer loop below is simply doing:
|
ksl->For(loop_name + "_y_in_tile",
|
||||||
//
|
/*start=*/y,
|
||||||
// for (int y_loc=thread_id_y; y_loc<tile_height; y_loc+=num_threads_y)
|
/*end=*/tile_height,
|
||||||
//
|
/*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
|
||||||
//
|
IrArray::Index source_idx_y =
|
||||||
// However, in order to avoid an LLVM optimization triggering the ptxas bug,
|
source_idx.AddOffsetToDim(y_loc, kDimY, b);
|
||||||
// we write this loop in a convoluted way:
|
|
||||||
//
|
|
||||||
// y_bound = ceil_of_ratio(tile_height - thread_id_y, num_threads_y)
|
|
||||||
// for (int y_indvar=0; y_indvar<y_bound; y_indvar+=1)
|
|
||||||
// y_loc = thread_id_y + y_indvar * num_threads_y
|
|
||||||
//
|
|
||||||
// TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
|
|
||||||
// workaround.
|
|
||||||
ksl->For(
|
|
||||||
loop_name + "_y_in_tile",
|
|
||||||
/*start=*/constant(0),
|
|
||||||
/*end=*/
|
|
||||||
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
|
|
||||||
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
|
|
||||||
llvm::Value* y_loc =
|
|
||||||
b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
|
|
||||||
for (int64 j = 0; j < x_num_steps; j++) {
|
for (int64 j = 0; j < x_num_steps; j++) {
|
||||||
llvm::Value* x_loc =
|
llvm::Value* x_loc =
|
||||||
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
|
b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
|
||||||
IrArray::Index source_idx_x =
|
IrArray::Index source_idx_x =
|
||||||
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
|
source_idx_y.AddOffsetToDim(constant(j * step_x), kDimX, b);
|
||||||
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
|
// The if-statement below always evaluates to true for the blocks
|
||||||
ksl->If(loop_name + "_x_in_tile", b_.CreateICmpULT(x_loc, tile_width),
|
// where the entire processed tile fits within the input buffer.
|
||||||
[&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
|
if (!always_full_tile) {
|
||||||
|
ksl->If(loop_name + "_x_in_tile",
|
||||||
|
b->CreateICmpULT(x_loc, tile_width), [&] {
|
||||||
|
emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -2608,7 +2623,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
|||||||
// tile[y, x] = input[index]
|
// tile[y, x] = input[index]
|
||||||
// Note that tile_width and tile_height are flipped here because we
|
// Note that tile_width and tile_height are flipped here because we
|
||||||
// are reading a transposed tile.
|
// are reading a transposed tile.
|
||||||
EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
|
xla::gpu::EmitTile(
|
||||||
|
mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
|
||||||
tile_width, tile_height,
|
tile_width, tile_height,
|
||||||
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
||||||
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
||||||
@ -2617,8 +2633,7 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
|||||||
param_in_reduced_shape_arrays[id];
|
param_in_reduced_shape_arrays[id];
|
||||||
|
|
||||||
llvm::Value* shmem_buffer = param_shmem_buffers[id];
|
llvm::Value* shmem_buffer = param_shmem_buffers[id];
|
||||||
llvm::Value* zero =
|
llvm::Value* zero = llvm::ConstantInt::get(index_type, 0);
|
||||||
llvm::ConstantInt::get(index_type, 0);
|
|
||||||
// TODO(jlebar): Add AA metadata to this store. Tile
|
// TODO(jlebar): Add AA metadata to this store. Tile
|
||||||
// buffers are global variables, so LLVM can't infer much
|
// buffers are global variables, so LLVM can't infer much
|
||||||
// about it.
|
// about it.
|
||||||
@ -2633,8 +2648,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
|||||||
EmitSyncThreads();
|
EmitSyncThreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
|
xla::gpu::EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x,
|
||||||
tile_width, element_generator);
|
tile_height, tile_width, element_generator);
|
||||||
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
|
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
|
||||||
|
|
||||||
// If a tile block contains multiple tiles and shared memory buffers are
|
// If a tile block contains multiple tiles and shared memory buffers are
|
||||||
@ -3059,8 +3074,9 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
|||||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
||||||
const string& loop_name, llvm::Value* tile_height,
|
const string& loop_name, llvm::Value* tile_height,
|
||||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||||
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
xla::gpu::EmitTile(reduction_info.GetKernelMappingScheme(), index,
|
||||||
y, x, tile_height, tile_width, emit_reduction_tile);
|
loop_name, ksl, &b_, y, x, tile_height, tile_width,
|
||||||
|
emit_reduction_tile);
|
||||||
});
|
});
|
||||||
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
||||||
reduce_instructions, reduction_output_shape_indices,
|
reduce_instructions, reduction_output_shape_indices,
|
||||||
|
|||||||
@ -235,39 +235,6 @@ class IrEmitterUnnested : public IrEmitter,
|
|||||||
const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
|
const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
|
||||||
const TileElementGenerator& tile_element_generator);
|
const TileElementGenerator& tile_element_generator);
|
||||||
|
|
||||||
// Emits code to process up to
|
|
||||||
// (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
|
|
||||||
// given `emit_elem_function` is the function to emit code to process one
|
|
||||||
// element, `thread_id_y` and `thread_id_x` are the intra-tile coordinates for
|
|
||||||
// the first element to process, and `index` is the index for the origin of
|
|
||||||
// the tile. Information about tile_size_x/y and num_threads_x/y are stored in
|
|
||||||
// `mapping_scheme`. Emits bounds check to ensure that each processed element
|
|
||||||
// is within the boundary defined by `tile_width` and `tile_height`.
|
|
||||||
//
|
|
||||||
// Pseudocode:
|
|
||||||
//
|
|
||||||
// for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
|
|
||||||
// for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
|
|
||||||
// if (dilated) {
|
|
||||||
// x_loc = x + j * num_threads_x;
|
|
||||||
// } else {
|
|
||||||
// x_loc = x * (tile_size_x / num_threads_x) + j;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// if (x_loc < tile_width) {
|
|
||||||
// emit_elem_function(y + y_loc, x_loc);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
void EmitTile(
|
|
||||||
const KernelMappingScheme& mapping_scheme,
|
|
||||||
const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
|
|
||||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
|
||||||
llvm::Value* thread_id_x, llvm::Value* tile_height,
|
|
||||||
llvm::Value* tile_width,
|
|
||||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
|
|
||||||
|
|
||||||
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
||||||
// that performs a 0-2-1 transpose.
|
// that performs a 0-2-1 transpose.
|
||||||
void EmitTileElementForCopy(
|
void EmitTileElementForCopy(
|
||||||
|
|||||||
@ -60,11 +60,15 @@ Status KernelSupportLibrary::IfWithStatus(
|
|||||||
absl::string_view name, llvm::Value* condition,
|
absl::string_view name, llvm::Value* condition,
|
||||||
const std::function<Status()>& true_block_generator,
|
const std::function<Status()>& true_block_generator,
|
||||||
const std::function<Status()>& false_block_generator) {
|
const std::function<Status()>& false_block_generator) {
|
||||||
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_);
|
llvm_ir::LlvmIfData if_data =
|
||||||
|
llvm_ir::EmitIfThenElse(condition, name, b_,
|
||||||
|
/*emit_else=*/false_block_generator != nullptr);
|
||||||
b_->SetInsertPoint(&if_data.true_block->back());
|
b_->SetInsertPoint(&if_data.true_block->back());
|
||||||
TF_RETURN_IF_ERROR(true_block_generator());
|
TF_RETURN_IF_ERROR(true_block_generator());
|
||||||
|
if (false_block_generator != nullptr) {
|
||||||
b_->SetInsertPoint(&if_data.false_block->back());
|
b_->SetInsertPoint(&if_data.false_block->back());
|
||||||
TF_RETURN_IF_ERROR(false_block_generator());
|
TF_RETURN_IF_ERROR(false_block_generator());
|
||||||
|
}
|
||||||
llvm_ir::SetToLastInsertPoint(if_data.after_block, b_);
|
llvm_ir::SetToLastInsertPoint(if_data.after_block, b_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -203,12 +203,11 @@ class KernelSupportLibrary {
|
|||||||
// `true_block_generator()`;
|
// `true_block_generator()`;
|
||||||
// else
|
// else
|
||||||
// `false_block_generator()`;
|
// `false_block_generator()`;
|
||||||
|
// The else is skipped if false_block_generator is null.
|
||||||
Status IfWithStatus(
|
Status IfWithStatus(
|
||||||
absl::string_view name, llvm::Value* condition,
|
absl::string_view name, llvm::Value* condition,
|
||||||
const std::function<Status()>& true_block_generator,
|
const std::function<Status()>& true_block_generator,
|
||||||
const std::function<Status()>& false_block_generator = []() -> Status {
|
const std::function<Status()>& false_block_generator = nullptr);
|
||||||
return Status::OK();
|
|
||||||
});
|
|
||||||
|
|
||||||
Status IfWithStatus(
|
Status IfWithStatus(
|
||||||
llvm::Value* condition,
|
llvm::Value* condition,
|
||||||
@ -220,16 +219,16 @@ class KernelSupportLibrary {
|
|||||||
false_block_generator);
|
false_block_generator);
|
||||||
}
|
}
|
||||||
|
|
||||||
void If(
|
void If(llvm::Value* condition,
|
||||||
llvm::Value* condition, const std::function<void()>& true_block_generator,
|
const std::function<void()>& true_block_generator,
|
||||||
const std::function<void()>& false_block_generator = []() {}) {
|
const std::function<void()>& false_block_generator = nullptr) {
|
||||||
If("", condition, true_block_generator, false_block_generator);
|
If("", condition, true_block_generator, false_block_generator);
|
||||||
}
|
}
|
||||||
|
|
||||||
void If(
|
void If(absl::string_view name, llvm::Value* condition,
|
||||||
absl::string_view name, llvm::Value* condition,
|
|
||||||
const std::function<void()>& true_block_generator,
|
const std::function<void()>& true_block_generator,
|
||||||
const std::function<void()>& false_block_generator = []() {}) {
|
const std::function<void()>& false_block_generator = nullptr) {
|
||||||
|
if (false_block_generator != nullptr) {
|
||||||
TF_CHECK_OK(IfWithStatus(
|
TF_CHECK_OK(IfWithStatus(
|
||||||
name, condition,
|
name, condition,
|
||||||
[&]() {
|
[&]() {
|
||||||
@ -240,6 +239,12 @@ class KernelSupportLibrary {
|
|||||||
false_block_generator();
|
false_block_generator();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
|
} else {
|
||||||
|
TF_CHECK_OK(IfWithStatus(name, condition, [&]() {
|
||||||
|
true_block_generator();
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
using ArgumentVector = absl::Span<llvm::Value* const>;
|
using ArgumentVector = absl::Span<llvm::Value* const>;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user