[XLA GPU] [NFC] Simplify IrEmitterUnnested::EmitKernel function
The EmitKernel function is very complex, and a large amount of the complexity is brought by the machinery required for the 021 shared memory transposition. However, 021 transposition is only used by the EmitHlo021Tile user, and not by the reduction emitter. This CL achieves considerate logic simplification by moving the required machinery into the callback passed by EmitHlo021Tile, thus making EmitKernel simpler. PiperOrigin-RevId: 259452012
This commit is contained in:
parent
745b24b21d
commit
df6ba21e45
@ -98,10 +98,6 @@ namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
using llvm_ir::KernelMappingScheme;
|
||||
using EmitElementFunction =
|
||||
std::function<void(const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num)>;
|
||||
|
||||
namespace {
|
||||
|
||||
using absl::InlinedVector;
|
||||
@ -2200,41 +2196,6 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
|
||||
return emit_status;
|
||||
}
|
||||
|
||||
std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
|
||||
const HloInstruction& hlo) {
|
||||
std::vector<IrArray> param_arrays;
|
||||
param_arrays.reserve(hlo.operands().size());
|
||||
for (const HloInstruction* param : hlo.operands()) {
|
||||
param_arrays.push_back(GetIrArray(*param, hlo));
|
||||
}
|
||||
return param_arrays;
|
||||
}
|
||||
|
||||
int IrEmitterUnnested::ConstructInputReducedShapeAndCastInputIrArrayToShape(
|
||||
const HloInstruction& hlo, const std::vector<IrArray>& param_arrays,
|
||||
const std::vector<llvm::Value*>& param_buffers,
|
||||
absl::Span<const int64> reduced_output_dims,
|
||||
std::vector<Shape>* param_reduced_shapes,
|
||||
std::vector<IrArray>* param_in_reduced_shape_arrays) {
|
||||
int64 num_params = hlo.operands().size();
|
||||
param_in_reduced_shape_arrays->reserve(num_params);
|
||||
param_reduced_shapes->reserve(num_params);
|
||||
for (int64 id = 0; id < num_params; ++id) {
|
||||
if (param_buffers[id] == nullptr) {
|
||||
param_reduced_shapes->push_back(Shape());
|
||||
param_in_reduced_shape_arrays->push_back(IrArray());
|
||||
continue;
|
||||
}
|
||||
const HloInstruction* param = hlo.operand(id);
|
||||
param_reduced_shapes->push_back(ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
param->shape().element_type(),
|
||||
Permute({0, 2, 1}, reduced_output_dims)));
|
||||
param_in_reduced_shape_arrays->push_back(
|
||||
param_arrays[id].CastToShape((*param_reduced_shapes)[id], &b_));
|
||||
}
|
||||
return num_params;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::tuple<llvm::Value*, int64> GetStartOffsetAndStepForX(
|
||||
@ -2254,12 +2215,12 @@ std::tuple<llvm::Value*, int64> GetStartOffsetAndStepForX(
|
||||
return std::make_tuple(start_offset_x, step_x);
|
||||
}
|
||||
|
||||
void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme,
|
||||
const IrArray::Index& tile_origin_index,
|
||||
const string& loop_name, KernelSupportLibrary* ksl,
|
||||
llvm::IRBuilder<>* builder, llvm::Value* y,
|
||||
llvm::Value* x, llvm::Type* index_ty,
|
||||
const EmitElementFunction& emit_elem_function) {
|
||||
void EmitFullElementalTile(
|
||||
const KernelMappingScheme* mapping_scheme,
|
||||
const IrArray::Index& tile_origin_index, const string& loop_name,
|
||||
KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
|
||||
llvm::Value* x, llvm::Type* index_ty,
|
||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
||||
int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX();
|
||||
int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY();
|
||||
int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
|
||||
@ -2292,14 +2253,13 @@ void EmitFullElementalTile(const KernelMappingScheme* mapping_scheme,
|
||||
});
|
||||
}
|
||||
|
||||
void EmitPartialElementalTile(const KernelMappingScheme* mapping_scheme,
|
||||
const IrArray::Index& tile_origin_index,
|
||||
const string& loop_name,
|
||||
KernelSupportLibrary* ksl,
|
||||
llvm::IRBuilder<>* builder, llvm::Value* y,
|
||||
llvm::Value* x, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, llvm::Type* index_ty,
|
||||
const EmitElementFunction& emit_elem_function) {
|
||||
void EmitPartialElementalTile(
|
||||
const KernelMappingScheme* mapping_scheme,
|
||||
const IrArray::Index& tile_origin_index, const string& loop_name,
|
||||
KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
|
||||
llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
llvm::Type* index_ty,
|
||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
||||
int64 num_threads_x = mapping_scheme->GetNumberOfThreadsForDimensionX();
|
||||
int64 num_threads_y = mapping_scheme->GetNumberOfThreadsForDimensionY();
|
||||
int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
|
||||
@ -2361,7 +2321,7 @@ void EmitTiledElementalCodeWithBoundsCheck(
|
||||
const IrArray::Index& tile_origin_index, const string& loop_name,
|
||||
KernelSupportLibrary* ksl, llvm::IRBuilder<>* builder, llvm::Value* y,
|
||||
llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
const EmitElementFunction& emit_elem_function) {
|
||||
const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
|
||||
int64 tile_size_x = mapping_scheme->GetTileSizeForDimensionX();
|
||||
int64 tile_size_y = mapping_scheme->GetTileSizeForDimensionY();
|
||||
llvm::Type* index_ty = tile_width->getType();
|
||||
@ -2938,10 +2898,10 @@ void IrEmitterUnnested::EmitTileElementForReduction(
|
||||
}
|
||||
|
||||
// Emits a kernel for the hlo instruction using the given tiling scheme.
|
||||
void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile,
|
||||
KernelCodegenInfo* kernel_info,
|
||||
void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info,
|
||||
KernelSupportLibrary* ksl,
|
||||
llvm::Type* index_ty) {
|
||||
llvm::Type* index_ty,
|
||||
TileGenerator emit_one_tile) {
|
||||
KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
|
||||
absl::Span<const int64> dims_in_tile = mapping_scheme->GetDimensionsInTiles();
|
||||
absl::Span<const int64> dims_in_block =
|
||||
@ -2986,8 +2946,6 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile,
|
||||
|
||||
absl::Span<const int64> reduced_dims =
|
||||
mapping_scheme->GetDimensionsInElements();
|
||||
const bool block_contains_multi_tiles =
|
||||
mapping_scheme->GetNumberOfTilesInOneBlock() > 1;
|
||||
|
||||
// Emit the tile with a given tile_index, by calculating the tight bounds for
|
||||
// each dimension of the tile and then calling emit_one_tile.
|
||||
@ -3008,7 +2966,7 @@ void IrEmitterUnnested::EmitBlock(const TileGenerator& emit_one_tile,
|
||||
|
||||
IrArray::Index tile_origin =
|
||||
mapping_scheme->GetElementIndexForTileOrigin(tile_index);
|
||||
emit_one_tile(tile_origin, output_tile_bounds, block_contains_multi_tiles);
|
||||
emit_one_tile(tile_origin, output_tile_bounds);
|
||||
};
|
||||
|
||||
const IrArray::Index starting_block =
|
||||
@ -3051,40 +3009,17 @@ LaunchDimensions IrEmitterUnnested::EmitKernel(
|
||||
const KernelCodeGenerator& kernel_generator,
|
||||
KernelCodegenInfo* kernel_info) {
|
||||
KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
|
||||
|
||||
std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*unnested_hlo);
|
||||
int64 num_params = param_arrays.size();
|
||||
// Allocate shared memory buffers to store the tiled inputs.
|
||||
std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
|
||||
for (int64 id : tiled_param_ids) {
|
||||
const HloInstruction* param = unnested_hlo->operand(id);
|
||||
param_shmem_buffers[id] =
|
||||
mapping_scheme->GetSharedMemoryBufferForElementType(
|
||||
llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(),
|
||||
module_),
|
||||
IrName(unnested_hlo, StrCat("tile", id)));
|
||||
VLOG(3) << "Added shmem buffer for parameter " << id << ": "
|
||||
<< llvm_ir::DumpToString(*param_shmem_buffers[id]);
|
||||
}
|
||||
|
||||
auto reduction_info = dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
|
||||
bool is_column_reduction =
|
||||
(reduction_info && !reduction_info->IsRowReduction());
|
||||
|
||||
LaunchDimensions launch_dimensions(mapping_scheme->GetNumberOfBlocks(),
|
||||
mapping_scheme->GetThreadsPerBlock());
|
||||
|
||||
// TODO(b/110211620): Enable int32 index type for column reduction.
|
||||
auto reduction_info = dynamic_cast<const ReductionCodegenInfo*>(kernel_info);
|
||||
llvm::Type* index_ty =
|
||||
is_column_reduction
|
||||
(reduction_info && !reduction_info->IsRowReduction())
|
||||
? b_.getInt64Ty()
|
||||
: GetIndexTypeForKernel(unnested_hlo,
|
||||
launch_dimensions.launch_bound(), &b_);
|
||||
|
||||
auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
|
||||
return llvm::ConstantInt::get(index_ty, c);
|
||||
};
|
||||
|
||||
// For multioutput fusion, one thread needs to output a tuple with pointers to
|
||||
// all the individual outputs. We could do this at any point in the kernel,
|
||||
// but we do it at the beginning in the hopes of reducing register pressure,
|
||||
@ -3097,17 +3032,6 @@ LaunchDimensions IrEmitterUnnested::EmitKernel(
|
||||
});
|
||||
}
|
||||
|
||||
// For each tiled parameter, cast its input IrArray to the corresponding
|
||||
// reduced shape and keep the reduced shape live during IR emission.
|
||||
std::vector<IrArray> param_in_reduced_shape_arrays;
|
||||
std::vector<Shape> param_reduced_shapes;
|
||||
absl::Span<const int64> reduced_dims =
|
||||
mapping_scheme->GetDimensionsInElements();
|
||||
int num_shapes = ConstructInputReducedShapeAndCastInputIrArrayToShape(
|
||||
*unnested_hlo, param_arrays, param_shmem_buffers, reduced_dims,
|
||||
¶m_reduced_shapes, ¶m_in_reduced_shape_arrays);
|
||||
DCHECK_EQ(num_shapes, num_params);
|
||||
|
||||
// Calculate the starting element coordinate within a tile for the current
|
||||
// thread, (y, x) from thread_id.
|
||||
llvm::Value* x;
|
||||
@ -3118,81 +3042,21 @@ LaunchDimensions IrEmitterUnnested::EmitKernel(
|
||||
mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x
|
||||
: nullptr);
|
||||
kernel_info->SetIndexType(index_ty);
|
||||
|
||||
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
|
||||
// Curry a few parameters to EmitTiledElementalCodeWithBoundsCheck.
|
||||
auto emit_tiled_elemental_code_with_bounds_check =
|
||||
[&](const IrArray::Index& index, const string& loop_name,
|
||||
llvm::Value* tile_height, llvm::Value* tile_width,
|
||||
const EmitElementFunction& emit_elem_function) {
|
||||
EmitTiledElementalCodeWithBoundsCheck(mapping_scheme, index, loop_name,
|
||||
&ksl, &b_, y, x, tile_height,
|
||||
tile_width, emit_elem_function);
|
||||
};
|
||||
|
||||
auto emit_one_tile = [&](const IrArray::Index& output_tile_origin,
|
||||
absl::Span<llvm::Value* const> output_tile_bounds,
|
||||
bool block_contains_multi_tiles) {
|
||||
// Calculate the input tile origin from the output tile origin.
|
||||
const IrArray::Index input_tile_origin(
|
||||
Permute({0, 2, 1}, output_tile_origin.multidim()),
|
||||
Permute({0, 2, 1}, output_tile_origin.dims()),
|
||||
output_tile_origin.GetType());
|
||||
|
||||
// If shared memory transpose is needed, wait for all threads to reach this
|
||||
// point, lest we copy a value from tile to output before the other thread
|
||||
// copies it from input to tile. This is `__syncthreads` in CUDA.
|
||||
if (!tiled_param_ids.empty()) {
|
||||
// Copy input parameter values to shared memory buffers:
|
||||
// tile[y, x] = input[index]
|
||||
// Note that tile_width and tile_height are flipped here because we are
|
||||
// reading a transposed tile.
|
||||
emit_tiled_elemental_code_with_bounds_check(
|
||||
input_tile_origin, "input", output_tile_bounds[2],
|
||||
output_tile_bounds[1],
|
||||
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
||||
for (int64 id : tiled_param_ids) {
|
||||
IrArray& input_in_logical_shape =
|
||||
param_in_reduced_shape_arrays[id];
|
||||
llvm::Value* shmem_buffer = param_shmem_buffers[id];
|
||||
// TODO(jlebar): Add AA metadata to this store. Tile buffers are
|
||||
// global variables, so LLVM can't infer much about it.
|
||||
Store(input_in_logical_shape.EmitReadArrayElement(
|
||||
index, &b_, "input_element"),
|
||||
GEP(shmem_buffer, {index_typed_constant(0), y_loc, x_loc}));
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for all threads to reach this point using `__syncthreads` in CUDA.
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
|
||||
}
|
||||
|
||||
llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
|
||||
kernel_info->SetTiledParamInfo(&tiled_param_info);
|
||||
|
||||
// Write to output[index] by emitting code like normal, except that values
|
||||
// for the tiled parameters are read from the shmem buffers.
|
||||
emit_tiled_elemental_code_with_bounds_check(
|
||||
output_tile_origin, "output", output_tile_bounds[1],
|
||||
output_tile_bounds[2],
|
||||
[&](const IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc,
|
||||
int64 x_iter_num) {
|
||||
kernel_generator.GetTileElementGenerator()(
|
||||
unnested_hlo, index, kernel_info, y_loc, x_loc, x_iter_num);
|
||||
});
|
||||
|
||||
// If a tile block contains multiple tiles and shared memory buffers are
|
||||
// used, we need to wait for all threads to finish using the shared memory
|
||||
// buffer for the current tile before we move on to process the next tile
|
||||
// and overwrite the shared memory buffers.
|
||||
if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
|
||||
}
|
||||
};
|
||||
|
||||
kernel_generator.GetBlockPrologueGenerator()(unnested_hlo, kernel_info);
|
||||
EmitBlock(std::move(emit_one_tile), kernel_info, &ksl, index_ty);
|
||||
EmitBlock(kernel_info, &ksl, index_ty,
|
||||
[&](const IrArray::Index& output_tile_origin,
|
||||
absl::Span<llvm::Value* const> output_tile_bounds) {
|
||||
std::vector<llvm::Value*> param_shmem_buffers(
|
||||
unnested_hlo->operand_count(), nullptr);
|
||||
llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers,
|
||||
y, x);
|
||||
kernel_info->SetTiledParamInfo(&tiled_param_info);
|
||||
kernel_generator.GetTileElementGenerator()(
|
||||
y, x, output_tile_origin, "output", output_tile_bounds[1],
|
||||
output_tile_bounds[2], &ksl);
|
||||
});
|
||||
kernel_generator.GetBlockEpilogueGenerator()(unnested_hlo, kernel_info);
|
||||
return launch_dimensions;
|
||||
}
|
||||
@ -3230,27 +3094,110 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
|
||||
/*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1},
|
||||
/*num_threads_y=*/kNumRows,
|
||||
/*num_threads_x=*/kWarpSize, &b_);
|
||||
TileElementGenerator element_generator;
|
||||
if (hlo->opcode() == HloOpcode::kCopy) {
|
||||
element_generator = [&](HloInstruction* hlo,
|
||||
const llvm_ir::IrArray::Index& index,
|
||||
const KernelCodegenInfo* kernel_info,
|
||||
llvm::Value* y_loc, llvm::Value* x_loc,
|
||||
int64 x_iter_num) {
|
||||
EmitTileElementForCopy(hlo, index, kernel_info, y_loc, x_loc, x_iter_num);
|
||||
};
|
||||
} else {
|
||||
DCHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
|
||||
element_generator =
|
||||
[&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||
const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num) {
|
||||
EmitTileElementForFusion(hlo, index, kernel_info, y_loc, x_loc,
|
||||
x_iter_num);
|
||||
};
|
||||
}
|
||||
KernelCodegenInfo kernel_info(&mapping_scheme);
|
||||
KernelCodeGenerator kernel_generator(std::move(element_generator));
|
||||
|
||||
std::vector<IrArray> param_arrays;
|
||||
|
||||
// For each tiled parameter, cast its input IrArray to the corresponding
|
||||
// reduced shape and keep the reduced shape live during IR emission.
|
||||
std::vector<IrArray> param_in_reduced_shape_arrays;
|
||||
std::vector<llvm::Value*> param_shmem_buffers(hlo->operand_count(), nullptr);
|
||||
|
||||
for (int64 id = 0; id < hlo->operand_count(); id++) {
|
||||
const HloInstruction* param = hlo->operand(id);
|
||||
param_arrays.push_back(GetIrArray(*param, *hlo));
|
||||
|
||||
if (absl::c_linear_search(tiled_param_ids, id)) {
|
||||
param_shmem_buffers[id] =
|
||||
mapping_scheme.GetSharedMemoryBufferForElementType(
|
||||
llvm_ir::PrimitiveTypeToIrType(param->shape().element_type(),
|
||||
module_),
|
||||
IrName(hlo, StrCat("tile", id)));
|
||||
VLOG(3) << "Added shmem buffer for parameter " << id << ": "
|
||||
<< llvm_ir::DumpToString(*param_shmem_buffers[id]);
|
||||
Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
|
||||
param->shape().element_type(),
|
||||
Permute({0, 2, 1}, reduced_output_dims));
|
||||
LOG(ERROR) << "Generated shape: " << reduced_shape.ToString(true);
|
||||
param_in_reduced_shape_arrays.push_back(
|
||||
param_arrays[id].CastToShape(reduced_shape, &b_));
|
||||
} else {
|
||||
param_in_reduced_shape_arrays.push_back(IrArray());
|
||||
}
|
||||
}
|
||||
|
||||
EmitElementFunction element_generator =
|
||||
[&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num) {
|
||||
if (hlo->opcode() == HloOpcode::kCopy) {
|
||||
EmitTileElementForCopy(hlo, index, &kernel_info, y_loc, x_loc,
|
||||
x_iter_num);
|
||||
} else {
|
||||
CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
|
||||
EmitTileElementForFusion(hlo, index, &kernel_info, y_loc, x_loc,
|
||||
x_iter_num);
|
||||
}
|
||||
};
|
||||
|
||||
KernelCodeGenerator kernel_generator(
|
||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||
llvm_ir::TiledParameterInfo tiled_param_info(param_shmem_buffers, y, x);
|
||||
kernel_info.SetTiledParamInfo(&tiled_param_info);
|
||||
|
||||
// If shared memory transpose is needed, wait for all threads to reach
|
||||
// this point, lest we copy a value from tile to output before the other
|
||||
// thread copies it from input to tile. This is `__syncthreads` in CUDA.
|
||||
if (!tiled_param_ids.empty()) {
|
||||
// Calculate the input tile origin from the output tile origin.
|
||||
const IrArray::Index input_tile_origin(
|
||||
Permute({0, 2, 1}, index.multidim()),
|
||||
Permute({0, 2, 1}, index.dims()), index.GetType());
|
||||
|
||||
// Copy input parameter values to shared memory buffers:
|
||||
// tile[y, x] = input[index]
|
||||
// Note that tile_width and tile_height are flipped here because we
|
||||
// are reading a transposed tile.
|
||||
EmitTiledElementalCodeWithBoundsCheck(
|
||||
&mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
|
||||
tile_width, tile_height,
|
||||
[&](const IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 /*x_iter_num*/) {
|
||||
for (int64 id : tiled_param_ids) {
|
||||
IrArray& input_in_logical_shape =
|
||||
param_in_reduced_shape_arrays[id];
|
||||
|
||||
llvm::Value* shmem_buffer = param_shmem_buffers[id];
|
||||
llvm::Value* zero =
|
||||
llvm::ConstantInt::get(kernel_info.GetIndexType(), 0);
|
||||
// TODO(jlebar): Add AA metadata to this store. Tile buffers
|
||||
// are global variables, so LLVM can't infer much about it.
|
||||
Store(input_in_logical_shape.EmitReadArrayElement(
|
||||
index, &b_, "input_element"),
|
||||
GEP(shmem_buffer, {zero, y_loc, x_loc}));
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for all threads to reach this point using `__syncthreads` in
|
||||
// CUDA.
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
|
||||
}
|
||||
|
||||
EmitTiledElementalCodeWithBoundsCheck(&mapping_scheme, index, loop_name,
|
||||
ksl, &b_, y, x, tile_height,
|
||||
tile_width, element_generator);
|
||||
bool block_contains_multi_tiles =
|
||||
mapping_scheme.GetNumberOfTilesInOneBlock() > 1;
|
||||
|
||||
// If a tile block contains multiple tiles and shared memory buffers are
|
||||
// used, we need to wait for all threads to finish using the shared
|
||||
// memory buffer for the current tile before we move on to process the
|
||||
// next tile and overwrite the shared memory buffers.
|
||||
if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
|
||||
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
|
||||
}
|
||||
});
|
||||
return EmitKernel(hlo, tiled_param_ids, kernel_generator, &kernel_info);
|
||||
}
|
||||
|
||||
@ -3679,13 +3626,21 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
std::tie(mapping_scheme, is_row_reduction) =
|
||||
ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce);
|
||||
ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction);
|
||||
EmitElementFunction emit_reduction_tile =
|
||||
[&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num) {
|
||||
EmitTileElementForReduction(unnested_hlo, index, &reduction_info, y_loc,
|
||||
x_loc, x_iter_num);
|
||||
};
|
||||
|
||||
KernelCodeGenerator kernel_generator(
|
||||
/*tile_element_generator=*/
|
||||
[&](HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||
const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num) {
|
||||
EmitTileElementForReduction(hlo, index, kernel_info, y_loc, x_loc,
|
||||
x_iter_num);
|
||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||
EmitTiledElementalCodeWithBoundsCheck(&mapping_scheme, index, loop_name,
|
||||
ksl, &b_, y, x, tile_height,
|
||||
tile_width, emit_reduction_tile);
|
||||
},
|
||||
/*block_prologue_generator=*/
|
||||
[&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
|
||||
|
@ -55,8 +55,7 @@ class IrEmitterUnnested : public IrEmitter {
|
||||
// to a global result to implement reduction.
|
||||
using TileGenerator =
|
||||
std::function<void(const llvm_ir::IrArray::Index& output_tile_origin,
|
||||
absl::Span<llvm::Value* const> output_tile_bounds,
|
||||
bool block_contains_multi_tiles)>;
|
||||
absl::Span<llvm::Value* const> output_tile_bounds)>;
|
||||
// KernelCodegenInfo records the common information to support the code
|
||||
// generation for a kernel to process tensor elements by blocks. A block of
|
||||
// tensor elements may contain one or multiple tiles. The code generators that
|
||||
@ -101,6 +100,7 @@ class IrEmitterUnnested : public IrEmitter {
|
||||
// A function object to finalize the code generation for a tile block.
|
||||
using BlockEpilogueGenerator =
|
||||
std::function<void(HloInstruction* hlo, KernelCodegenInfo* kernel_info)>;
|
||||
|
||||
// A function object to generate code to process one element in a tile.
|
||||
//
|
||||
// hlo: the instruction for which the code is generated for.
|
||||
@ -110,11 +110,15 @@ class IrEmitterUnnested : public IrEmitter {
|
||||
// kernel_info: Other information to support the kernel code generation.
|
||||
// x_iter_num: When a thread process N elements in the X dimension, x_iter_num
|
||||
// has a value of 0..N-1 to identify the element being process.
|
||||
using TileElementGenerator = std::function<void(
|
||||
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
|
||||
const KernelCodegenInfo* kernel_info, llvm::Value* y_loc,
|
||||
using EmitElementFunction = std::function<void(
|
||||
const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num)>;
|
||||
|
||||
using TileElementGenerator = std::function<void(
|
||||
llvm::Value* y, llvm::Value* x, const llvm_ir::IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl)>;
|
||||
|
||||
// KernelCodeGenerator records the code generator objects that generate code
|
||||
// for tile elements or tile block prologue/epilogue.
|
||||
class KernelCodeGenerator {
|
||||
@ -255,9 +259,10 @@ class IrEmitterUnnested : public IrEmitter {
|
||||
absl::Span<const int64> param_ids,
|
||||
const KernelCodeGenerator& kernel_generator,
|
||||
KernelCodegenInfo* kernel_info);
|
||||
void EmitBlock(const TileGenerator& emit_one_tile,
|
||||
KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl,
|
||||
llvm::Type* index_ty);
|
||||
|
||||
void EmitBlock(KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl,
|
||||
llvm::Type* index_ty, TileGenerator emit_one_tile);
|
||||
|
||||
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
||||
// that performs a 0-2-1 transpose.
|
||||
void EmitTileElementForCopy(HloInstruction* hlo,
|
||||
@ -296,24 +301,6 @@ class IrEmitterUnnested : public IrEmitter {
|
||||
absl::Span<HloComputation* const> reducers,
|
||||
absl::Span<llvm::AllocaInst* const> partial_result_addresses);
|
||||
|
||||
// Generates the IrArray for each input of an hlo and returns a vector that
|
||||
// constains such IrArrays.
|
||||
std::vector<llvm_ir::IrArray> ConstructIrArrayForInputs(
|
||||
const HloInstruction& hlo);
|
||||
|
||||
// For each input of the `hlo` instruction, checks its value in
|
||||
// `param_buffers` to find out whether the input has a reduced shape. If the
|
||||
// input has a reduced shape, constructs the reduced shape for the input and
|
||||
// casts the original input IrArray in `param_arrays` to the reduced shape.
|
||||
// Return the total number of inputs.
|
||||
int ConstructInputReducedShapeAndCastInputIrArrayToShape(
|
||||
const HloInstruction& hlo,
|
||||
const std::vector<llvm_ir::IrArray>& param_arrays,
|
||||
const std::vector<llvm::Value*>& param_buffers,
|
||||
absl::Span<const int64> reduced_output_dims,
|
||||
std::vector<Shape>* param_reduced_shapes,
|
||||
std::vector<llvm_ir::IrArray>* param_in_reduced_shape_arrays);
|
||||
|
||||
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
|
||||
// caller needs to make sure `inst` outlives the lifetime of the returned
|
||||
// Thunk object. The kernel implementation will be unrolled if unroll_factor
|
||||
|
Loading…
Reference in New Issue
Block a user