[XLA/GPU] [NFC] Minor refactoring of tiling emitter: use self-descriptive names instead of x-loc and y-loc

PiperOrigin-RevId: 294344795
Change-Id: Ia16093aad12b7dfb8ac11ab76580ac9ade98bfb2
This commit is contained in:
George Karpenkov 2020-02-10 17:43:19 -08:00 committed by TensorFlower Gardener
parent 61d19090c8
commit ae6efa0b93
4 changed files with 58 additions and 64 deletions

View File

@@ -1878,8 +1878,7 @@ bool MayPreventVectorization(const HloInstruction& hlo) {
} // namespace
Status IrEmitterUnnested::EmitTargetElementLoop(
const HloInstruction& hlo,
const llvm_ir::ElementGenerator& element_generator) {
const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
int unroll_factor = 1;
if (!MayPreventVectorization(hlo)) {
unroll_factor = ComputeMaxUnrollFactor(&hlo);
@@ -1888,7 +1887,7 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
&hlo, /*implements_whole_instruction=*/true, unroll_factor);
Status emit_status =
EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
EmitTargetElementLoopInThunk(hlo, body_emitter, kernel_thunk.get());
thunk_sequence_->emplace_back(std::move(kernel_thunk));
return emit_status;
@@ -1912,8 +1911,8 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
void IrEmitterUnnested::EmitTile(
const KernelMappingScheme& mapping_scheme,
const IrArray::Index& tile_origin_index, const string& loop_name,
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
llvm::Value* tile_height, llvm::Value* tile_width,
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
llvm::Type* index_ty = tile_width->getType();
auto constant = [&](int64 val) {
@@ -1924,8 +1923,8 @@ void IrEmitterUnnested::EmitTile(
int64 tile_size_x = mapping_scheme.GetTileSizeX();
int64 x_num_steps = tile_size_x / num_threads_x;
llvm::Value* start_offset_x =
GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
llvm::Value* start_offset_x = GetStartOffsetX(
mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
// Using dilated mapping scheme, each thread steps with a stride of number
// of threads.
@@ -1962,22 +1961,25 @@ void IrEmitterUnnested::EmitTile(
loop_name + "_y_in_tile",
/*start=*/constant(0),
/*end=*/
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
num_threads_y),
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
llvm::Value* y_loc =
b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
llvm::Value* y_loc = b_.CreateAdd(
thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
for (int64 j = 0; j < x_num_steps; j++) {
llvm::Value* x_loc =
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
IrArray::Index source_idx_x =
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
auto emit_element = [&] {
return emit_elem_function(source_idx_x, y_loc, x_loc, j);
};
if (!x_tile_fits) {
ksl->If(loop_name + "_x_in_tile",
b_.CreateICmpULT(x_loc, tile_width),
[&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
b_.CreateICmpULT(x_loc, tile_width), emit_element);
} else {
emit_elem_function(source_idx_x, y_loc, x_loc, j);
emit_element();
}
}
});
@@ -1989,14 +1991,11 @@ void IrEmitterUnnested::EmitTile(
// index: The index for the first output element in the normalized tensor. The
// normalized tensor is the resulting tensor after collapsing contiguous
// dimensions that play the same role in the transpose.
// y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile.
// mapping_scheme: Kernel mapping scheme specifying the tiling
void IrEmitterUnnested::EmitTileElementForCopy(
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, int64 /*x_iter_num*/,
absl::Span<llvm::Value* const> param_shmem_buffers) {
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
// TODO(jlebar): Add AA metadata to this load.
llvm::Instruction* load_from_shmem_buffer =
Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
@@ -2027,13 +2026,10 @@ static IrArray::Index GetUnnormalizedIndex(
// is the resulting tensor after collapsing contiguous dimensions that play
// the same role in the transpose.
// kernel_info: Other information to support the kernel code generation.
// y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile.
void IrEmitterUnnested::EmitTileElementForFusion(
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, int64 /*x_iter_num*/,
absl::Span<llvm::Value* const> param_shmem_buffers) {
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
@@ -2545,8 +2541,7 @@ IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
}();
auto emit_tile = [&](const IrArray::Index& tile) {
tile_element_generator(thread_id_info.thread_id_y,
thread_id_info.thread_id_x, tile, "output",
tile_element_generator(thread_id_info, tile, "output",
output_tile_bounds[1], output_tile_bounds[2], &ksl);
};
@@ -2667,16 +2662,16 @@ void IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* x_loc, int64 x_iter_num) {
if (hlo->opcode() == HloOpcode::kCopy) {
EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
x_iter_num, param_shmem_buffers);
param_shmem_buffers);
} else {
CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
x_iter_num, param_shmem_buffers);
param_shmem_buffers);
}
};
TileElementGenerator tile_generator =
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
// If shared memory transpose is needed, wait for all threads to reach
@@ -2689,11 +2684,11 @@ void IrEmitterUnnested::EmitHlo021Tile(
Permute({0, 2, 1}, index.dims()), index.GetType());
// Copy input parameter values to shared memory buffers:
// tile[y, x] = input[index]
// tile[thread_id_y, thread_id_x] = input[index]
// Note that tile_width and tile_height are flipped here because we
// are reading a transposed tile.
EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
tile_width, tile_height,
EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
thread_id_info, tile_width, tile_height,
[&](const IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 /*x_iter_num*/) {
for (int64 id : tiled_param_ids) {
@@ -2717,8 +2712,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
EmitSyncThreads();
}
EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
tile_width, element_generator);
EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
tile_height, tile_width, element_generator);
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
// If a tile block contains multiple tiles and shared memory buffers are
@@ -3133,11 +3128,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
mapping_scheme, index_ty,
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
y, x, tile_height, tile_width, emit_reduction_tile);
thread_id_info, tile_height, tile_width, emit_reduction_tile);
});
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
reduce_instructions, reduction_output_shape_indices,

View File

@@ -52,13 +52,25 @@ namespace gpu {
class IrEmitterUnnested : public IrEmitter,
private ThunkEmitter::EmissionContext {
public:
struct ThreadIdInfo {
// Raw thread id.
llvm::Value* thread_id;
// X-coordinate calculated from thread id: `thread_id % num_threads_x`
llvm::Value* thread_id_x;
// Y-coordinate calculated from thread id: `thread_id / num_threads_x`
llvm::Value* thread_id_y;
// Lane id: `thread_id % kWarpSize`
llvm::Value* lane_id;
};
// A function object to generate code to process one element in a tile.
//
// hlo: the instruction for which the code is generated for.
// index: the index for the first output element of the current thread.
// y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile.
// kernel_info: Other information to support the kernel code generation.
// x_iter_num: When a thread process N elements in the X dimension, x_iter_num
// has a value of 0..N-1 to identify the element being process.
using EmitElementFunction = std::function<void(
@@ -69,7 +81,7 @@ class IrEmitterUnnested : public IrEmitter,
// A function to generate the code to emit the entire tile.
using TileElementGenerator = std::function<void(
llvm::Value* y, llvm::Value* x, const llvm_ir::IrArray::Index& index,
const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl)>;
@@ -291,26 +303,27 @@ class IrEmitterUnnested : public IrEmitter,
void EmitTile(
const KernelMappingScheme& mapping_scheme,
const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
llvm::Value* thread_id_x, llvm::Value* tile_height,
llvm::Value* tile_width,
KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
llvm::Value* tile_height, llvm::Value* tile_width,
const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
// Emits code to process a tensor element in a tile for the given kCopy HLO
// that performs a 0-2-1 transpose.
// y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile.
void EmitTileElementForCopy(
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, int64 x_iter_num,
absl::Span<llvm::Value* const> param_shmem_buffers);
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
// Emits code to process a tensor element in a tile for the given kLoop
// fusion HLO containing parameters that are 0-2-1 transpose of its outputs.
// y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile.
void EmitTileElementForFusion(
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, int64 x_iter_num,
absl::Span<llvm::Value* const> param_shmem_buffers);
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
// Emits code to process a tensor element in a tile for the given input hlo
// that is either a unnested kReduce or a kInput fusion.
@@ -389,20 +402,6 @@ class IrEmitterUnnested : public IrEmitter,
// Sets the return value range to [0, threads_per_block).
llvm::Value* EmitThreadId(int64 threads_per_block, llvm::Type* index_ty);
struct ThreadIdInfo {
// Raw thread id.
llvm::Value* thread_id;
// X-coordinate calculated from thread id: `thread_id % num_threads_x`
llvm::Value* thread_id_x;
// Y-coordinate calculated from thread id: `thread_id / num_threads_x`
llvm::Value* thread_id_y;
// Lane id: `thread_id % kWarpSize`
llvm::Value* lane_id;
};
// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
// id.
//

View File

@@ -161,7 +161,7 @@ Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
// address-space-based AA in LLVM, it wouldn't help us much here.
return b_->CreateLoad(
b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
tile_param_x_, tile_param_y_}),
thread_id_x_, thread_id_y_}),
"tiled_buffer");
}
}

View File

@@ -60,13 +60,13 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
ElementalIrEmitter* elemental_emitter,
llvm::Value* tile_param_x = nullptr,
llvm::Value* tile_param_y = nullptr,
llvm::Value* thread_id_x = nullptr,
llvm::Value* thread_id_y = nullptr,
absl::Span<llvm::Value* const> param_shmem_buffers = {})
: operand_arrays_(),
operand_arrays_generator_(std::move(operand_arrays_generator)),
tile_param_x_(tile_param_x),
tile_param_y_(tile_param_y),
thread_id_x_(thread_id_x),
thread_id_y_(thread_id_y),
param_shmem_buffers_(param_shmem_buffers.begin(),
param_shmem_buffers.end()),
elemental_emitter_(elemental_emitter),
@@ -121,10 +121,10 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
GeneratorForOperandIrArrays operand_arrays_generator_;
// The x coordinate within a tile.
llvm::Value* tile_param_x_;
llvm::Value* thread_id_x_;
// The y coordinate within a tile.
llvm::Value* tile_param_y_;
llvm::Value* thread_id_y_;
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
// if the parameter is not tiled.