[XLA/GPU] [NFC] Minor refactoring of tiling emitter: use self-descriptive names instead of x-loc and y-loc
PiperOrigin-RevId: 294344795 Change-Id: Ia16093aad12b7dfb8ac11ab76580ac9ade98bfb2
This commit is contained in:
parent
61d19090c8
commit
ae6efa0b93
@ -1878,8 +1878,7 @@ bool MayPreventVectorization(const HloInstruction& hlo) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status IrEmitterUnnested::EmitTargetElementLoop(
|
Status IrEmitterUnnested::EmitTargetElementLoop(
|
||||||
const HloInstruction& hlo,
|
const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
|
||||||
const llvm_ir::ElementGenerator& element_generator) {
|
|
||||||
int unroll_factor = 1;
|
int unroll_factor = 1;
|
||||||
if (!MayPreventVectorization(hlo)) {
|
if (!MayPreventVectorization(hlo)) {
|
||||||
unroll_factor = ComputeMaxUnrollFactor(&hlo);
|
unroll_factor = ComputeMaxUnrollFactor(&hlo);
|
||||||
@ -1888,7 +1887,7 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
|
|||||||
std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
|
std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
|
||||||
&hlo, /*implements_whole_instruction=*/true, unroll_factor);
|
&hlo, /*implements_whole_instruction=*/true, unroll_factor);
|
||||||
Status emit_status =
|
Status emit_status =
|
||||||
EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
|
EmitTargetElementLoopInThunk(hlo, body_emitter, kernel_thunk.get());
|
||||||
thunk_sequence_->emplace_back(std::move(kernel_thunk));
|
thunk_sequence_->emplace_back(std::move(kernel_thunk));
|
||||||
|
|
||||||
return emit_status;
|
return emit_status;
|
||||||
@ -1912,8 +1911,8 @@ static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
|
|||||||
void IrEmitterUnnested::EmitTile(
|
void IrEmitterUnnested::EmitTile(
|
||||||
const KernelMappingScheme& mapping_scheme,
|
const KernelMappingScheme& mapping_scheme,
|
||||||
const IrArray::Index& tile_origin_index, const string& loop_name,
|
const IrArray::Index& tile_origin_index, const string& loop_name,
|
||||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
|
||||||
llvm::Value* thread_id_x, llvm::Value* tile_height, llvm::Value* tile_width,
|
llvm::Value* tile_height, llvm::Value* tile_width,
|
||||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
||||||
llvm::Type* index_ty = tile_width->getType();
|
llvm::Type* index_ty = tile_width->getType();
|
||||||
auto constant = [&](int64 val) {
|
auto constant = [&](int64 val) {
|
||||||
@ -1924,8 +1923,8 @@ void IrEmitterUnnested::EmitTile(
|
|||||||
int64 tile_size_x = mapping_scheme.GetTileSizeX();
|
int64 tile_size_x = mapping_scheme.GetTileSizeX();
|
||||||
|
|
||||||
int64 x_num_steps = tile_size_x / num_threads_x;
|
int64 x_num_steps = tile_size_x / num_threads_x;
|
||||||
llvm::Value* start_offset_x =
|
llvm::Value* start_offset_x = GetStartOffsetX(
|
||||||
GetStartOffsetX(mapping_scheme, thread_id_x, index_ty, &b_);
|
mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
|
||||||
|
|
||||||
// Using the dilated mapping scheme, each thread steps with a stride equal to
|
// Using the dilated mapping scheme, each thread steps with a stride equal to
|
||||||
// the number of threads.
|
// the number of threads.
|
||||||
@ -1962,22 +1961,25 @@ void IrEmitterUnnested::EmitTile(
|
|||||||
loop_name + "_y_in_tile",
|
loop_name + "_y_in_tile",
|
||||||
/*start=*/constant(0),
|
/*start=*/constant(0),
|
||||||
/*end=*/
|
/*end=*/
|
||||||
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_y), num_threads_y),
|
ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
|
||||||
|
num_threads_y),
|
||||||
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
|
/*step=*/constant(1), [&](llvm::Value* y_indvar) {
|
||||||
llvm::Value* y_loc =
|
llvm::Value* y_loc = b_.CreateAdd(
|
||||||
b_.CreateAdd(thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
|
thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
|
||||||
for (int64 j = 0; j < x_num_steps; j++) {
|
for (int64 j = 0; j < x_num_steps; j++) {
|
||||||
llvm::Value* x_loc =
|
llvm::Value* x_loc =
|
||||||
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
|
b_.CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
|
||||||
IrArray::Index source_idx_x =
|
IrArray::Index source_idx_x =
|
||||||
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
|
source_idx.AddOffsetToDim(y_loc, kDimY, &b_)
|
||||||
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
|
.AddOffsetToDim(constant(j * step_x), kDimX, &b_);
|
||||||
|
auto emit_element = [&] {
|
||||||
|
return emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
||||||
|
};
|
||||||
if (!x_tile_fits) {
|
if (!x_tile_fits) {
|
||||||
ksl->If(loop_name + "_x_in_tile",
|
ksl->If(loop_name + "_x_in_tile",
|
||||||
b_.CreateICmpULT(x_loc, tile_width),
|
b_.CreateICmpULT(x_loc, tile_width), emit_element);
|
||||||
[&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
|
|
||||||
} else {
|
} else {
|
||||||
emit_elem_function(source_idx_x, y_loc, x_loc, j);
|
emit_element();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -1989,14 +1991,11 @@ void IrEmitterUnnested::EmitTile(
|
|||||||
// index: The index for the first output element in the normalized tensor. The
|
// index: The index for the first output element in the normalized tensor. The
|
||||||
// normalized tensor is the resulting tensor after collapsing contiguous
|
// normalized tensor is the resulting tensor after collapsing contiguous
|
||||||
// dimensions that play the same role in the transpose.
|
// dimensions that play the same role in the transpose.
|
||||||
// y_loc: The y coordinate within a tile.
|
|
||||||
// x_loc: The x coordinate within a tile.
|
|
||||||
// mapping_scheme: Kernel mapping scheme specifying the tiling
|
// mapping_scheme: Kernel mapping scheme specifying the tiling
|
||||||
void IrEmitterUnnested::EmitTileElementForCopy(
|
void IrEmitterUnnested::EmitTileElementForCopy(
|
||||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||||
llvm::Value* x_loc, int64 /*x_iter_num*/,
|
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
|
||||||
absl::Span<llvm::Value* const> param_shmem_buffers) {
|
|
||||||
// TODO(jlebar): Add AA metadata to this load.
|
// TODO(jlebar): Add AA metadata to this load.
|
||||||
llvm::Instruction* load_from_shmem_buffer =
|
llvm::Instruction* load_from_shmem_buffer =
|
||||||
Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
|
Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
|
||||||
@ -2027,13 +2026,10 @@ static IrArray::Index GetUnnormalizedIndex(
|
|||||||
// is the resulting tensor after collapsing contiguous dimensions that play
|
// is the resulting tensor after collapsing contiguous dimensions that play
|
||||||
// the same role in the transpose.
|
// the same role in the transpose.
|
||||||
// kernel_info: Other information to support the kernel code generation.
|
// kernel_info: Other information to support the kernel code generation.
|
||||||
// y_loc: The y coordinate within a tile.
|
|
||||||
// x_loc: The x coordinate within a tile.
|
|
||||||
void IrEmitterUnnested::EmitTileElementForFusion(
|
void IrEmitterUnnested::EmitTileElementForFusion(
|
||||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||||
llvm::Value* x_loc, int64 /*x_iter_num*/,
|
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
|
||||||
absl::Span<llvm::Value* const> param_shmem_buffers) {
|
|
||||||
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
|
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
|
||||||
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
|
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
|
||||||
GetNestedComputer());
|
GetNestedComputer());
|
||||||
@ -2545,8 +2541,7 @@ IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
|
|||||||
}();
|
}();
|
||||||
|
|
||||||
auto emit_tile = [&](const IrArray::Index& tile) {
|
auto emit_tile = [&](const IrArray::Index& tile) {
|
||||||
tile_element_generator(thread_id_info.thread_id_y,
|
tile_element_generator(thread_id_info, tile, "output",
|
||||||
thread_id_info.thread_id_x, tile, "output",
|
|
||||||
output_tile_bounds[1], output_tile_bounds[2], &ksl);
|
output_tile_bounds[1], output_tile_bounds[2], &ksl);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -2667,16 +2662,16 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
|||||||
llvm::Value* x_loc, int64 x_iter_num) {
|
llvm::Value* x_loc, int64 x_iter_num) {
|
||||||
if (hlo->opcode() == HloOpcode::kCopy) {
|
if (hlo->opcode() == HloOpcode::kCopy) {
|
||||||
EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
|
EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
|
||||||
x_iter_num, param_shmem_buffers);
|
param_shmem_buffers);
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
|
CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
|
||||||
EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
|
EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
|
||||||
x_iter_num, param_shmem_buffers);
|
param_shmem_buffers);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TileElementGenerator tile_generator =
|
TileElementGenerator tile_generator =
|
||||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
|
||||||
const string& loop_name, llvm::Value* tile_height,
|
const string& loop_name, llvm::Value* tile_height,
|
||||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||||
// If shared memory transpose is needed, wait for all threads to reach
|
// If shared memory transpose is needed, wait for all threads to reach
|
||||||
@ -2689,11 +2684,11 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
|||||||
Permute({0, 2, 1}, index.dims()), index.GetType());
|
Permute({0, 2, 1}, index.dims()), index.GetType());
|
||||||
|
|
||||||
// Copy input parameter values to shared memory buffers:
|
// Copy input parameter values to shared memory buffers:
|
||||||
// tile[y, x] = input[index]
|
// tile[thread_id_y, thread_id_x] = input[index]
|
||||||
// Note that tile_width and tile_height are flipped here because we
|
// Note that tile_width and tile_height are flipped here because we
|
||||||
// are reading a transposed tile.
|
// are reading a transposed tile.
|
||||||
EmitTile(mapping_scheme, input_tile_origin, "input", ksl, y, x,
|
EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
|
||||||
tile_width, tile_height,
|
thread_id_info, tile_width, tile_height,
|
||||||
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
||||||
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
||||||
for (int64 id : tiled_param_ids) {
|
for (int64 id : tiled_param_ids) {
|
||||||
@ -2717,8 +2712,8 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
|||||||
EmitSyncThreads();
|
EmitSyncThreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
EmitTile(mapping_scheme, index, loop_name, ksl, y, x, tile_height,
|
EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
|
||||||
tile_width, element_generator);
|
tile_height, tile_width, element_generator);
|
||||||
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
|
bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
|
||||||
|
|
||||||
// If a tile block contains multiple tiles and shared memory buffers are
|
// If a tile block contains multiple tiles and shared memory buffers are
|
||||||
@ -3133,11 +3128,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
|||||||
|
|
||||||
TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
|
TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
|
||||||
mapping_scheme, index_ty,
|
mapping_scheme, index_ty,
|
||||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
|
||||||
const string& loop_name, llvm::Value* tile_height,
|
const string& loop_name, llvm::Value* tile_height,
|
||||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||||
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
||||||
y, x, tile_height, tile_width, emit_reduction_tile);
|
thread_id_info, tile_height, tile_width, emit_reduction_tile);
|
||||||
});
|
});
|
||||||
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
||||||
reduce_instructions, reduction_output_shape_indices,
|
reduce_instructions, reduction_output_shape_indices,
|
||||||
|
@ -52,13 +52,25 @@ namespace gpu {
|
|||||||
class IrEmitterUnnested : public IrEmitter,
|
class IrEmitterUnnested : public IrEmitter,
|
||||||
private ThunkEmitter::EmissionContext {
|
private ThunkEmitter::EmissionContext {
|
||||||
public:
|
public:
|
||||||
|
struct ThreadIdInfo {
|
||||||
|
// Raw thread id.
|
||||||
|
llvm::Value* thread_id;
|
||||||
|
|
||||||
|
// X-coordinate calculated from thread id: `thread_id % num_threads_x`
|
||||||
|
llvm::Value* thread_id_x;
|
||||||
|
|
||||||
|
// Y-coordinate calculated from thread id: `thread_id / num_threads_x`
|
||||||
|
llvm::Value* thread_id_y;
|
||||||
|
|
||||||
|
// Lane id: `thread_id % kWarpSize`
|
||||||
|
llvm::Value* lane_id;
|
||||||
|
};
|
||||||
|
|
||||||
// A function object to generate code to process one element in a tile.
|
// A function object to generate code to process one element in a tile.
|
||||||
//
|
//
|
||||||
// hlo: the instruction for which the code is generated for.
|
|
||||||
// index: the index for the first output element of the current thread.
|
// index: the index for the first output element of the current thread.
|
||||||
// y_loc: The y coordinate within a tile.
|
// y_loc: The y coordinate within a tile.
|
||||||
// x_loc: The x coordinate within a tile.
|
// x_loc: The x coordinate within a tile.
|
||||||
// kernel_info: Other information to support the kernel code generation.
|
|
||||||
// x_iter_num: When a thread processes N elements in the X dimension, x_iter_num
|
// x_iter_num: When a thread processes N elements in the X dimension, x_iter_num
|
||||||
// has a value of 0..N-1 to identify the element being processed.
|
// has a value of 0..N-1 to identify the element being processed.
|
||||||
using EmitElementFunction = std::function<void(
|
using EmitElementFunction = std::function<void(
|
||||||
@ -69,7 +81,7 @@ class IrEmitterUnnested : public IrEmitter,
|
|||||||
|
|
||||||
// A function to generate the code to emit the entire tile.
|
// A function to generate the code to emit the entire tile.
|
||||||
using TileElementGenerator = std::function<void(
|
using TileElementGenerator = std::function<void(
|
||||||
llvm::Value* y, llvm::Value* x, const llvm_ir::IrArray::Index& index,
|
const ThreadIdInfo& thread_id_info, const llvm_ir::IrArray::Index& index,
|
||||||
const string& loop_name, llvm::Value* tile_height,
|
const string& loop_name, llvm::Value* tile_height,
|
||||||
llvm::Value* tile_width, KernelSupportLibrary* ksl)>;
|
llvm::Value* tile_width, KernelSupportLibrary* ksl)>;
|
||||||
|
|
||||||
@ -291,26 +303,27 @@ class IrEmitterUnnested : public IrEmitter,
|
|||||||
void EmitTile(
|
void EmitTile(
|
||||||
const KernelMappingScheme& mapping_scheme,
|
const KernelMappingScheme& mapping_scheme,
|
||||||
const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
|
const llvm_ir::IrArray::Index& tile_origin_index, const string& loop_name,
|
||||||
KernelSupportLibrary* ksl, llvm::Value* thread_id_y,
|
KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
|
||||||
llvm::Value* thread_id_x, llvm::Value* tile_height,
|
llvm::Value* tile_height, llvm::Value* tile_width,
|
||||||
llvm::Value* tile_width,
|
|
||||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
|
const IrEmitterUnnested::EmitElementFunction& emit_elem_function);
|
||||||
|
|
||||||
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
||||||
// that performs a 0-2-1 transpose.
|
// that performs a 0-2-1 transpose.
|
||||||
|
// y_loc: The y coordinate within a tile.
|
||||||
|
// x_loc: The x coordinate within a tile.
|
||||||
void EmitTileElementForCopy(
|
void EmitTileElementForCopy(
|
||||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||||
llvm::Value* x_loc, int64 x_iter_num,
|
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
|
||||||
absl::Span<llvm::Value* const> param_shmem_buffers);
|
|
||||||
|
|
||||||
// Emits code to process a tensor element in a tile for the given kLoop
|
// Emits code to process a tensor element in a tile for the given kLoop
|
||||||
// fusion HLO containing parameters that are 0-2-1 transpose of its outputs.
|
// fusion HLO containing parameters that are 0-2-1 transpose of its outputs.
|
||||||
|
// y_loc: The y coordinate within a tile.
|
||||||
|
// x_loc: The x coordinate within a tile.
|
||||||
void EmitTileElementForFusion(
|
void EmitTileElementForFusion(
|
||||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||||
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
|
||||||
llvm::Value* x_loc, int64 x_iter_num,
|
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
|
||||||
absl::Span<llvm::Value* const> param_shmem_buffers);
|
|
||||||
|
|
||||||
// Emits code to process a tensor element in a tile for the given input hlo
|
// Emits code to process a tensor element in a tile for the given input hlo
|
||||||
// that is either an unnested kReduce or a kInput fusion.
|
// that is either an unnested kReduce or a kInput fusion.
|
||||||
@ -389,20 +402,6 @@ class IrEmitterUnnested : public IrEmitter,
|
|||||||
// Sets the return value range to [0, threads_per_block).
|
// Sets the return value range to [0, threads_per_block).
|
||||||
llvm::Value* EmitThreadId(int64 threads_per_block, llvm::Type* index_ty);
|
llvm::Value* EmitThreadId(int64 threads_per_block, llvm::Type* index_ty);
|
||||||
|
|
||||||
struct ThreadIdInfo {
|
|
||||||
// Raw thread id.
|
|
||||||
llvm::Value* thread_id;
|
|
||||||
|
|
||||||
// X-coordinate calculated from thread id: `thread_id % num_threads_x`
|
|
||||||
llvm::Value* thread_id_x;
|
|
||||||
|
|
||||||
// Y-coordinate calculated from thread id: `thread_id / num_threads_x`
|
|
||||||
llvm::Value* thread_id_y;
|
|
||||||
|
|
||||||
// Lane id: `thread_id % kWarpSize`
|
|
||||||
llvm::Value* lane_id;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
|
// Emits the LLVM values for thread_id, thread_id.x, thread_id.y and lane
|
||||||
// id.
|
// id.
|
||||||
//
|
//
|
||||||
|
@ -161,7 +161,7 @@ Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
|
|||||||
// address-space-based AA in LLVM, it wouldn't help us much here.
|
// address-space-based AA in LLVM, it wouldn't help us much here.
|
||||||
return b_->CreateLoad(
|
return b_->CreateLoad(
|
||||||
b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
|
b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
|
||||||
tile_param_x_, tile_param_y_}),
|
thread_id_x_, thread_id_y_}),
|
||||||
"tiled_buffer");
|
"tiled_buffer");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -60,13 +60,13 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
|||||||
|
|
||||||
FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
|
FusedIrEmitter(GeneratorForOperandIrArrays operand_arrays_generator,
|
||||||
ElementalIrEmitter* elemental_emitter,
|
ElementalIrEmitter* elemental_emitter,
|
||||||
llvm::Value* tile_param_x = nullptr,
|
llvm::Value* thread_id_x = nullptr,
|
||||||
llvm::Value* tile_param_y = nullptr,
|
llvm::Value* thread_id_y = nullptr,
|
||||||
absl::Span<llvm::Value* const> param_shmem_buffers = {})
|
absl::Span<llvm::Value* const> param_shmem_buffers = {})
|
||||||
: operand_arrays_(),
|
: operand_arrays_(),
|
||||||
operand_arrays_generator_(std::move(operand_arrays_generator)),
|
operand_arrays_generator_(std::move(operand_arrays_generator)),
|
||||||
tile_param_x_(tile_param_x),
|
thread_id_x_(thread_id_x),
|
||||||
tile_param_y_(tile_param_y),
|
thread_id_y_(thread_id_y),
|
||||||
param_shmem_buffers_(param_shmem_buffers.begin(),
|
param_shmem_buffers_(param_shmem_buffers.begin(),
|
||||||
param_shmem_buffers.end()),
|
param_shmem_buffers.end()),
|
||||||
elemental_emitter_(elemental_emitter),
|
elemental_emitter_(elemental_emitter),
|
||||||
@ -121,10 +121,10 @@ class FusedIrEmitter : public ConstDfsHloVisitorWithDefault {
|
|||||||
GeneratorForOperandIrArrays operand_arrays_generator_;
|
GeneratorForOperandIrArrays operand_arrays_generator_;
|
||||||
|
|
||||||
// The x coordinate within a tile.
|
// The x coordinate within a tile.
|
||||||
llvm::Value* tile_param_x_;
|
llvm::Value* thread_id_x_;
|
||||||
|
|
||||||
// The y coordinate within a tile.
|
// The y coordinate within a tile.
|
||||||
llvm::Value* tile_param_y_;
|
llvm::Value* thread_id_y_;
|
||||||
|
|
||||||
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
|
// Param_buffers_[i] stores the tile buffer for the ith parameter or nullptr
|
||||||
// if the parameter is not tiled.
|
// if the parameter is not tiled.
|
||||||
|
Loading…
Reference in New Issue
Block a user