[XLA/GPU] [NFC] Minor refactoring of tiling emitter: use self-descriptive names instead of x-loc and y-loc
PiperOrigin-RevId: 294344795 Change-Id: Ia16093aad12b7dfb8ac11ab76580ac9ade98bfb2
This commit is contained in:
parent
61d19090c8
commit
ae6efa0b93
@ -1878,8 +1878,7 @@ bool MayPreventVectorization(const HloInstruction& hlo) {
|
||||
} // namespace
|
||||
|
||||
Status IrEmitterUnnested::EmitTargetElementLoop(
|
||||
const HloInstruction& hlo,
|
||||
const llvm_ir::ElementGenerator& element_generator) {
|
||||
const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
|
||||
int unroll_factor = 1;
|
||||
if (!MayPreventVectorization(hlo)) {
|
||||
unroll_factor = ComputeMaxUnrollFactor(&hlo);
|
||||
@ -1888,7 +1887,7 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
|
||||
std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
|
||||
&hlo, /*implements_whole_instruction=*/true, unroll_factor);
|
||||
Status emit_status =
|
||||
EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
|
||||
EmitTargetElementLoopInThunk(hlo, body_emitter, kernel_thunk.get());
|
||||
thunk_sequence_->emplace_back(std::move(kernel_thunk));
|
||||
|
||||
return emit_status;
|
||||
@ -1912,8 +1911,8 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
|
||||
void IrEmitterUnnested::EmitTile(
|
||||
const KernelMappingScheme& mapping_scheme,
|
||||
const IrArray::Index& tile_origin_index, const string& loop_name,
|
||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
||||
llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
|
||||
llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
||||
llvm::Type* index_ty = tile_width->getType();
|
||||
auto constant = [&](int64 val) {
|
||||
@ -1924,8 +1923,8 @@ void IrEmitterUnnested::EmitTile(
|
||||
int64 tile_size_x = mapping_scheme.GetTileSizeX();
|
||||
|
||||
int64 x_num_steps = tile_size_x / num_threads_x;
|
||||
llvm::Value* start_offset_x =
|
||||
GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
|
||||
llvm::Value* start_offset_x = GetStartOffsetX(
|
||||
mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
|
||||
|
||||
// Using dilated mapping scheme, each thread steps with a stride of number
|
||||
// of threads.
|
||||
@ -1962,22 +1961,25 @@ void IrEmitterUnnested::EmitTile(
|
||||
loop_name + "_y_in_tile",
|
||||
/*start=*/constant(0),
|
||||
/*end=*/
|
||||
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
|
||||
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
|
||||
num_threads_y),
|
||||
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
|
||||
llvm::Value* y_loc =
|
||||
b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
|
||||
llvm::Value* y_loc = b_.CreateAdd(
|
||||
thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
|
||||
for (int64 j = 0; j < x_num_steps; j++) {
|
||||
llvm::Value* x_loc =
|
||||
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
|
||||
IrArray::Index source_idx_x =
|
||||
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
|
||||
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
|
||||
auto emit_element = [&] {
|
||||
return emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
||||
};
|
||||
if (!x_tile_fits) {
|
||||
ksl->If(loop_name + "_x_in_tile",
|
||||
b_.CreateICmpULT(x_loc, tile_width),
|
||||
[&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
|
||||
b_.CreateICmpULT(x_loc, tile_width), emit_element);
|
||||
} else {
|
||||
emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
||||
emit_element();
|
||||
}
|
||||
}
|
||||
});
|
||||
@ -1989,14 +1991,11 @@ void IrEmitterUnnested::EmitTile(
|
||||
// index: The index for the first output element in the normalized tensor. The
|
||||
// normalized tensor is the resulting tensor after collapsing contiguous
|
||||
// dimensions that play the same role in the transpose.
|
||||
// y_loc: The y coordinate within a tile.
|
||||
// x_loc: The x coordinate within a tile.
|
||||
// mapping_scheme: Kernel mapping scheme specifying the tiling
|
||||
void IrEmitterUnnested::EmitTileElementForCopy(
|
||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 /*x_iter_num*/,
|
||||
absl::Span<llvm::Value* const> param_shmem_buffers) {
|
||||
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
|
||||
// TODO(jlebar): Add AA metadata to this load.
|
||||
llvm::Instruction* load_from_shmem_buffer =
|
||||
Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
|
||||
@ -2027,13 +2026,10 @@ static IrArray::Index GetUnnormalizedIndex(
|
||||
// is the resulting tensor after collapsing contiguous dimensions that play
|
||||
// the same role in the transpose.
|
||||
// kernel_info: Other information to support the kernel code generation.
|
||||
// y_loc: The y coordinate within a tile.
|
||||
// x_loc: The x coordinate within a tile.
|
||||
void IrEmitterUnnested::EmitTileElementForFusion(
|
||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 /*x_iter_num*/,
|
||||
absl::Span<llvm::Value* const> param_shmem_buffers) {
|
||||
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
|
||||
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
|
||||
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
|
||||
GetNestedComputer());
|
||||
@ -2545,8 +2541,7 @@ IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
|
||||
}();
|
||||
|
||||
auto emit_tile = [&](const IrArray::Index& tile) {
|
||||
tile_element_generator(thread_id_info.thread_id_y,
|
||||
thread_id_info.thread_id_x, tile, "output",
|
||||
tile_element_generator(thread_id_info, tile, "output",
|
||||
output_tile_bounds[1], output_tile_bounds[2], &ksl);
|
||||
};
|
||||
|
||||
@ -2667,16 +2662,16 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
||||
llvm::Value* x_loc, int64 x_iter_num) {
|
||||
if (hlo->opcode() == HloOpcode::kCopy) {
|
||||
EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
|
||||
x_iter_num, param_shmem_buffers);
|
||||
param_shmem_buffers);
|
||||
} else {
|
||||
CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
|
||||
EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
|
||||
x_iter_num, param_shmem_buffers);
|
||||
param_shmem_buffers);
|
||||
}
|
||||
};
|
||||
|
||||
TileElementGenerator tile_generator =
|
||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
||||
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||
// If shared memory transpose is needed, wait for all threads to reach
|
||||
@ -2689,11 +2684,11 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
||||
Permute({0, 2, 1}, index.dims()), index.GetType());
|
||||
|
||||
// Copy input parameter values to shared memory buffers:
|
||||
// tile[y, x] = input[index]
|
||||
// tile[thread_id_y, thread_id_x] = input[index]
|
||||
// Note that tile_width and tile_height are flipped here because we
|
||||
// are reading a transposed tile.
|
||||
EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
|
||||
tile_width, tile_height,
|
||||
EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
|
||||
thread_id_info, tile_width, tile_height,
|
||||
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
||||
for (int64 id : tiled_param_ids) {
|
||||
@ -2717,8 +2712,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
||||
EmitSyncThreads();
|
||||
}
|
||||
|
||||
EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
|
||||
tile_width, element_generator);
|
||||
EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
|
||||
tile_height, tile_width, element_generator);
|
||||
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
|
||||
|
||||
// If a tile block contains multiple tiles and shared memory buffers are
|
||||
@ -3133,11 +3128,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
|
||||
TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
|
||||
mapping_scheme, index_ty,
|
||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
||||
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
||||
y, x, tile_height, tile_width, emit_reduction_tile);
|
||||
thread_id_info, tile_height, tile_width, emit_reduction_tile);
|
||||
});
|
||||
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
||||
reduce_instructions, reduction_output_shape_indices,
|
||||
|
@ -52,13 +52,25 @@ namespace gpu {
|
||||
class IrEmitterUnnested : public IrEmitter,
|
||||
private ThunkEmitter::EmissionContext {
|
||||
public:
|
||||
struct ThreadIdInfo {
|
||||
// Raw thread id.
|
||||
llvm::Value* thread_id;
|
||||
|
||||
// X-coordinate calculated from thread id: `thread_id % num_threads_x`
|
||||
llvm::Value* thread_id_x;
|
||||
|
||||
// Y-coordinate calculated from thread id: `thread_id / num_threads_x`
|
||||
llvm::Value* thread_id_y;
|
||||
|
||||
// Lane id: `thread_id % kWarpSize`
|
||||
llvm::Value* lane_id;
|
||||
};
|
||||
|
||||
// A function object to generate code to process one element in a tile.
|
||||
//
|
||||
// hlo: the instruction for which the code is generated.
|
||||
// index: the index for the first output element of the current thread.
|
||||
// y_loc: The y coordinate within a tile.
|
||||
// x_loc: The x coordinate within a tile.
|
||||
// kernel_info: Other information to support the kernel code generation.
|
||||
// x_iter_num: When a thread processes N elements in the X dimension, x_iter_num
|
||||
// has a value of 0..N-1 to identify the element being processed.
|
||||
using EmitElementFunction = std::function<void(
|
||||
@ -69,7 +81,7 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
|
||||
// A function to generate the code to emit the entire tile.
|
||||
using TileElementGenerator = std::function<void(
|
||||
llvm::Value* y, llvm::Value* x, const llvm_ir::IrArray::Index& index,
|
||||
const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl)>;
|
||||
|
||||
@ -291,26 +303,27 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
void EmitTile(
|
||||
const KernelMappingScheme& mapping_scheme,
|
||||
const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
|
||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
||||
llvm::Value* thread_id_x, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width,
|
||||
KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
|
||||
llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
|
||||
|
||||
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
||||
// that performs a 0-2-1 transpose.
|
||||
// y_loc: The y coordinate within a tile.
|
||||
// x_loc: The x coordinate within a tile.
|
||||
void EmitTileElementForCopy(
|
||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num,
|
||||
absl::Span<llvm::Value* const> param_shmem_buffers);
|
||||
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
|
||||
|
||||
// Emits code to process a tensor element in a tile for the given kLoop
|
||||
// fusion HLO containing parameters that are 0-2-1 transpose of its outputs.
|
||||
// y_loc: The y coordinate within a tile.
|
||||
// x_loc: The x coordinate within a tile.
|
||||
void EmitTileElementForFusion(
|
||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num,
|
||||
absl::Span<llvm::Value* const> param_shmem_buffers);
|
||||
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
|
||||
|
||||
// Emits code to process a tensor element in a tile for the given input hlo
|
||||
// that is either an unnested kReduce or a kInput fusion.
|
||||
@ -389,20 +402,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
// Sets the return value range to [0, threads_per_block).
|
||||
llvm::Value* EmitThreadId(int64 threads_per_block, llvm::Type* index_ty);
|
||||
|
||||
struct ThreadIdInfo {
|
||||
// Raw thread id.
|
||||
llvm::Value* thread_id;
|
||||
|
||||
// X-coordinate calculated from thread id: `thread_id % num_threads_x`
|
||||
llvm::Value* thread_id_x;
|
||||
|
||||
// Y-coordinate calculated from thread id: `thread_id / num_threads_x`
|
||||
llvm::Value* thread_id_y;
|
||||
|
||||
// Lane id: `thread_id % kWarpSize`
|
||||
llvm::Value* lane_id;
|
||||
};
|
||||
|
||||
// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
|
||||
// id.
|
||||
//
|
||||
|
@ -161,7 +161,7 @@ Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
|
||||
// address-space-based AA in LLVM, it wouldn't help us much here.
|
||||
return b_->CreateLoad(
|
||||
b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
|
||||
tile_param_x_, tile_param_y_}),
|
||||
thread_id_x_, thread_id_y_}),
|
||||
"tiled_buffer");
|
||||
}
|
||||
}
|
||||
|
@ -60,13 +60,13 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
||||
|
||||
FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
|
||||
ElementalIrEmitter* elemental_emitter,
|
||||
llvm::Value* tile_param_x = nullptr,
|
||||
llvm::Value* tile_param_y = nullptr,
|
||||
llvm::Value* thread_id_x = nullptr,
|
||||
llvm::Value* thread_id_y = nullptr,
|
||||
absl::Span<llvm::Value* const> param_shmem_buffers = {})
|
||||
: operand_arrays_(),
|
||||
operand_arrays_generator_(std::move(operand_arrays_generator)),
|
||||
tile_param_x_(tile_param_x),
|
||||
tile_param_y_(tile_param_y),
|
||||
thread_id_x_(thread_id_x),
|
||||
thread_id_y_(thread_id_y),
|
||||
param_shmem_buffers_(param_shmem_buffers.begin(),
|
||||
param_shmem_buffers.end()),
|
||||
elemental_emitter_(elemental_emitter),
|
||||
@ -121,10 +121,10 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
||||
GeneratorForOperandIrArrays operand_arrays_generator_;
|
||||
|
||||
// The x coordinate within a tile.
|
||||
llvm::Value* tile_param_x_;
|
||||
llvm::Value* thread_id_x_;
|
||||
|
||||
// The y coordinate within a tile.
|
||||
llvm::Value* tile_param_y_;
|
||||
llvm::Value* thread_id_y_;
|
||||
|
||||
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
|
||||
// if the parameter is not tiled.
|
||||
|
Loading…
Reference in New Issue
Block a user