[XLA GPU] [NFC] Remove the TileGenerator abstraction; TileElementGenerator is sufficient

PiperOrigin-RevId: 265970092
This commit is contained in:
George Karpenkov 2019-08-28 12:41:02 -07:00 committed by TensorFlower Gardener
parent 3af471cd27
commit ccfe164602
2 changed files with 11 additions and 21 deletions

View File

@ -2456,9 +2456,10 @@ void IrEmitterUnnested::EmitTileElementForReduction(
// Emits a kernel for the hlo instruction using the given tiling scheme.
void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info,
KernelSupportLibrary* ksl,
llvm::Type* index_ty,
TileGenerator emit_one_tile) {
KernelSupportLibrary* ksl, llvm::Value* y,
llvm::Value* x,
TileElementGenerator tile_generator) {
llvm::Type* index_ty = kernel_info->GetIndexType();
KernelMappingScheme* mapping_scheme = kernel_info->GetKernelMappingScheme();
absl::Span<const int64> dims_in_tile = mapping_scheme->GetDimensionsInTiles();
absl::Span<const int64> dims_in_block =
@ -2505,7 +2506,7 @@ void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info,
mapping_scheme->GetDimensionsInElements();
// Emit the tile with a given tile_index, by calculating the tight bounds for
// each dimension of the tile and then calling emit_one_tile.
// each dimension of the tile and then calling tile_generator.
auto emit_one_tile_for_tile_index = [&](const IrArray::Index& tile_index) {
std::vector<llvm::Value*> output_tile_bounds(3);
for (int i = KernelMappingScheme::DimY; i < KernelMappingScheme::DimTot;
@ -2523,7 +2524,8 @@ void IrEmitterUnnested::EmitBlock(KernelCodegenInfo* kernel_info,
IrArray::Index tile_origin =
mapping_scheme->GetElementIndexForTileOrigin(tile_index);
emit_one_tile(tile_origin, output_tile_bounds);
tile_generator(y, x, tile_origin, "output", output_tile_bounds[1],
output_tile_bounds[2], ksl);
};
const IrArray::Index starting_block =
@ -2574,6 +2576,7 @@ void IrEmitterUnnested::EmitKernel(
? b_.getInt64Ty()
: GetIndexTypeForKernel(unnested_hlo,
launch_dimensions.launch_bound(), &b_);
kernel_info->SetIndexType(index_ty);
// Calculate the starting element coordinate within a tile for the current
// thread, (y, x) from thread_id.
@ -2584,17 +2587,10 @@ void IrEmitterUnnested::EmitKernel(
kernel_info->SetLaneId(
mapping_scheme->GetNumberOfThreadsForDimensionX() == kWarpSize ? x
: nullptr);
kernel_info->SetIndexType(index_ty);
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
block_prologue_generator(unnested_hlo, kernel_info);
EmitBlock(kernel_info, &ksl, index_ty,
[&](const IrArray::Index& output_tile_origin,
absl::Span<llvm::Value* const> output_tile_bounds) {
tile_element_generator(y, x, output_tile_origin, "output",
output_tile_bounds[1],
output_tile_bounds[2], &ksl);
});
EmitBlock(kernel_info, &ksl, y, x, tile_element_generator);
block_epilogue_generator(unnested_hlo, kernel_info);
UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
ir_emitter_context_->llvm_module());

View File

@ -52,13 +52,6 @@ namespace gpu {
class IrEmitterUnnested : public IrEmitter,
private ThunkEmitter::EmissionContext {
public:
// Parameter block_contains_multi_tiles indicates whether a tile block
// consists of multiple tiles or not. If the tile block contains only one
// tile, there is no need to use atomic operation to accumulate a local result
// to a global result to implement reduction.
using TileGenerator =
std::function<void(const llvm_ir::IrArray::Index& output_tile_origin,
absl::Span<llvm::Value* const> output_tile_bounds)>;
// KernelCodegenInfo records the common information to support the code
// generation for a kernel to process tensor elements by blocks. A block of
// tensor elements may contain one or multiple tiles. The code generators that
@ -251,7 +244,8 @@ class IrEmitterUnnested : public IrEmitter,
BlockEpilogueGenerator block_epilogue_generator);
void EmitBlock(KernelCodegenInfo* kernel_info, KernelSupportLibrary* ksl,
llvm::Type* index_ty, TileGenerator emit_one_tile);
llvm::Value* y, llvm::Value* x,
TileElementGenerator tile_generator);
// Emits code to process a tensor element in a tile for the given kCopy HLO
// that performs a 0-2-1 transpose.