[XLA GPU] [NFC] Remove TileGenerator abstraction, TileElementGenerator is sufficient
PiperOrigin-RevId: 265970092
This commit is contained in:
parent
3af471cd27
commit
ccfe164602
@ -2456,9 +2456,10 @@ void IrEmitterUnnested::EmitTileElementForReduction(
|
||||
|
||||
// Emits a kernel for the hlo instruction using the given tiling scheme.
|
||||
void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info,
|
||||
KernelSupportLibrary* ksl,
|
||||
llvm::Type* index_ty,
|
||||
TileGenerator emit_one_tile) {
|
||||
KernelSupportLibrary* ksl, llvm::Value* y,
|
||||
llvm::Value* x,
|
||||
TileElementGenerator tile_generator) {
|
||||
llvm::Type* index_ty = kernel_info->GetIndexType();
|
||||
KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
|
||||
absl::Span<const int64> dims_in_tile = mapping_scheme->GetDimensionsInTiles();
|
||||
absl::Span<const int64> dims_in_block =
|
||||
@ -2505,7 +2506,7 @@ void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info,
|
||||
mapping_scheme->GetDimensionsInElements();
|
||||
|
||||
// Emit the tile with a given tile_index, by calculating the tight bounds for
|
||||
// each dimension of the tile and then calling emit_one_tile.
|
||||
// each dimension of the tile and then calling tile_generator.
|
||||
auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) {
|
||||
std::vector<llvm::Value*> output_tile_bounds(3);
|
||||
for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot;
|
||||
@ -2523,7 +2524,8 @@ void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info,
|
||||
|
||||
IrArray::Index tile_origin =
|
||||
mapping_scheme->GetElementIndexForTileOrigin(tile_index);
|
||||
emit_one_tile(tile_origin, output_tile_bounds);
|
||||
tile_generator(y, x, tile_origin, "output", output_tile_bounds[1],
|
||||
output_tile_bounds[2], ksl);
|
||||
};
|
||||
|
||||
const IrArray::Index starting_block =
|
||||
@ -2574,6 +2576,7 @@ void IrEmitterUnnested::EmitKernel(
|
||||
? b_.getInt64Ty()
|
||||
: GetIndexTypeForKernel(unnested_hlo,
|
||||
launch_dimensions.launch_bound(), &b_);
|
||||
kernel_info->SetIndexType(index_ty);
|
||||
|
||||
// Calculate the starting element coordinate within a tile for the current
|
||||
// thread, (y, x) from thread_id.
|
||||
@ -2584,17 +2587,10 @@ void IrEmitterUnnested::EmitKernel(
|
||||
kernel_info->SetLaneId(
|
||||
mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x
|
||||
: nullptr);
|
||||
kernel_info->SetIndexType(index_ty);
|
||||
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
|
||||
|
||||
block_prologue_generator(unnested_hlo, kernel_info);
|
||||
EmitBlock(kernel_info, &ksl, index_ty,
|
||||
[&](const IrArray::Index& output_tile_origin,
|
||||
absl::Span<llvm::Value* const> output_tile_bounds) {
|
||||
tile_element_generator(y, x, output_tile_origin, "output",
|
||||
output_tile_bounds[1],
|
||||
output_tile_bounds[2], &ksl);
|
||||
});
|
||||
EmitBlock(kernel_info, &ksl, y, x, tile_element_generator);
|
||||
block_epilogue_generator(unnested_hlo, kernel_info);
|
||||
UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
|
||||
ir_emitter_context_->llvm_module());
|
||||
|
@ -52,13 +52,6 @@ namespace gpu {
|
||||
class IrEmitterUnnested : public IrEmitter,
|
||||
private ThunkEmitter::EmissionContext {
|
||||
public:
|
||||
// Parameter block_contains_multi_tiles indicates whether a tile block
|
||||
// consists of multiple tiles or not. If the tile block contains only one
|
||||
// tile, there is no need to use atomic operation to accumulate a local result
|
||||
// to a global result to implement reduction.
|
||||
using TileGenerator =
|
||||
std::function<void(const llvm_ir::IrArray::Index& output_tile_origin,
|
||||
absl::Span<llvm::Value* const> output_tile_bounds)>;
|
||||
// KernelCodegenInfo records the common information to support the code
|
||||
// generation for a kernel to process tensor elements by blocks. A block of
|
||||
// tensor elements may contain one or multiple tiles. The code generators that
|
||||
@ -251,7 +244,8 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
BlockEpilogueGenerator block_epilogue_generator);
|
||||
|
||||
void EmitBlock(KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl,
|
||||
llvm::Type* index_ty, TileGenerator emit_one_tile);
|
||||
llvm::Value* y, llvm::Value* x,
|
||||
TileElementGenerator tile_generator);
|
||||
|
||||
// Emits code to process a tensor element in a tile for the given kCopy HLO
|
||||
// that performs a 0-2-1 transpose.
|
||||
|
Loading…
Reference in New Issue
Block a user