[XLA GPU] [NFC] Simplify and document getters of KernelMappingScheme
PiperOrigin-RevId: 265975853
This commit is contained in:
parent
35e5ec9e9e
commit
a1e5191049
tensorflow/compiler/xla/service
@ -2627,9 +2627,9 @@ void IrEmitterUnnested::EmitHlo021Tile(
|
||||
constexpr int kNumRows = 4;
|
||||
KernelMappingScheme mapping_scheme(
|
||||
reduced_output_dims, /*tile_size_y=*/kWarpSize,
|
||||
/*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1},
|
||||
/*tile_size_x=*/kWarpSize, /*block_size_z=*/1,
|
||||
/*num_threads_y=*/kNumRows,
|
||||
/*num_threads_x=*/kWarpSize, &b_);
|
||||
/*num_threads_x=*/kWarpSize, /*is_dilated_x=*/false, &b_);
|
||||
KernelCodegenInfo kernel_info(&mapping_scheme);
|
||||
|
||||
std::vector<IrArray> param_arrays;
|
||||
@ -3062,7 +3062,7 @@ bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo,
|
||||
|
||||
} // namespace
|
||||
|
||||
std::tuple<KernelMappingScheme, bool>
|
||||
std::pair<KernelMappingScheme, bool>
|
||||
IrEmitterUnnested::ComputeMappingSchemeAndReductionKind(
|
||||
const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) {
|
||||
const Shape& input_shape = first_reduce->operand(0)->shape();
|
||||
@ -3121,12 +3121,10 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind(
|
||||
tile_size_y = kNumElementsPerPartialSum;
|
||||
}
|
||||
|
||||
DimensionVector req_block_sizes{block_size_z, 1, 1};
|
||||
llvm_ir::KernelMappingScheme mapping_scheme(
|
||||
dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y,
|
||||
num_threads_x, &b_);
|
||||
mapping_scheme.SetDilatedX(dilated_x);
|
||||
return std::make_tuple(mapping_scheme, is_row_reduction);
|
||||
dims_in_elem, tile_size_y, tile_size_x, block_size_z, num_threads_y,
|
||||
num_threads_x, dilated_x, &b_);
|
||||
return std::make_pair(mapping_scheme, is_row_reduction);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
@ -3197,11 +3195,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
"doesn't set the input layout of "
|
||||
<< first_reduce->ToString();
|
||||
|
||||
bool is_row_reduction;
|
||||
llvm_ir::KernelMappingScheme mapping_scheme;
|
||||
std::tie(mapping_scheme, is_row_reduction) =
|
||||
auto mapping_scheme_pair =
|
||||
ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce);
|
||||
ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction);
|
||||
bool is_row_reduction = mapping_scheme_pair.second;
|
||||
ReductionCodegenInfo reduction_info(&mapping_scheme_pair.first,
|
||||
is_row_reduction);
|
||||
EmitElementFunction emit_reduction_tile =
|
||||
[&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
|
||||
llvm::Value* x_loc, int64 x_iter_num) {
|
||||
@ -3216,9 +3214,9 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
[&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
|
||||
EmitTiledElementalCodeWithBoundsCheck(&mapping_scheme, index, loop_name,
|
||||
ksl, &b_, y, x, tile_height,
|
||||
tile_width, emit_reduction_tile);
|
||||
EmitTiledElementalCodeWithBoundsCheck(
|
||||
&mapping_scheme_pair.first, index, loop_name, ksl, &b_, y, x,
|
||||
tile_height, tile_width, emit_reduction_tile);
|
||||
},
|
||||
/*block_prologue_generator=*/
|
||||
[&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {
|
||||
|
@ -212,7 +212,7 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
// and first_reduce are the same instruction. For a kInput fusion,
|
||||
// unnested_hlo is the fusion instruction while first_reduce is the first
|
||||
// reduce op.
|
||||
std::tuple<llvm_ir::KernelMappingScheme, bool>
|
||||
std::pair<llvm_ir::KernelMappingScheme, bool>
|
||||
ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo,
|
||||
const HloInstruction* first_reduce);
|
||||
|
||||
|
@ -103,29 +103,36 @@ absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
KernelMappingScheme::KernelMappingScheme(
|
||||
absl::Span<const int64> dims_in_elems, int64 tile_size_y, int64 tile_size_x,
|
||||
absl::Span<const int64> req_block_sizes, int64 num_threads_y,
|
||||
int64 num_threads_x, llvm::IRBuilder<>* b)
|
||||
KernelMappingScheme::KernelMappingScheme(absl::Span<const int64> dims_in_elems,
|
||||
int64 tile_size_y, int64 tile_size_x,
|
||||
int64 block_size_z,
|
||||
int64 num_threads_y,
|
||||
int64 num_threads_x, bool is_dilated_x,
|
||||
llvm::IRBuilder<>* b)
|
||||
: b_(b),
|
||||
dims_in_elems_{dims_in_elems.at(0), dims_in_elems.at(1),
|
||||
dims_in_elems.at(2)},
|
||||
dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]},
|
||||
tile_sizes_{1, tile_size_y, tile_size_x},
|
||||
dims_in_tiles_(ElementWiseCeilOfRatio(dims_in_elems_, tile_sizes_)),
|
||||
block_sizes_{std::min(req_block_sizes.at(0), dims_in_tiles_.at(0)),
|
||||
std::min(req_block_sizes.at(1), dims_in_tiles_.at(1)),
|
||||
std::min(req_block_sizes.at(2), dims_in_tiles_.at(2))},
|
||||
dims_in_blocks_(ElementWiseCeilOfRatio(dims_in_tiles_, block_sizes_)),
|
||||
dims_in_tiles_{dims_in_elems[0],
|
||||
CeilOfRatio<int64>(dims_in_elems[1], tile_size_y),
|
||||
CeilOfRatio<int64>(dims_in_elems[2], tile_size_x)},
|
||||
block_sizes_{block_size_z, 1, 1},
|
||||
dims_in_blocks_{CeilOfRatio<int64>(dims_in_elems[0], block_sizes_[0]),
|
||||
dims_in_tiles_[1], dims_in_tiles_[2]},
|
||||
num_threads_x_(num_threads_x),
|
||||
num_threads_y_(num_threads_y),
|
||||
dilated_x_(true) {
|
||||
DCHECK_EQ(req_block_sizes.size(), 3);
|
||||
dilated_x_(is_dilated_x) {
|
||||
DCHECK_EQ(tile_size_y % num_threads_y_, 0);
|
||||
DCHECK_EQ(tile_size_x % num_threads_x_, 0);
|
||||
CHECK_EQ((dims_in_elems[0] % block_size_z), 0);
|
||||
VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]";
|
||||
VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]";
|
||||
VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",")
|
||||
<< "]";
|
||||
if (!dilated_x_) {
|
||||
// dilated_x_=false is for the purpose of vectorization, which requires
|
||||
// GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_.
|
||||
CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
IrArray::Index KernelMappingScheme::GetUnnormalizedIndex(
|
||||
|
@ -90,23 +90,24 @@ class KernelMappingScheme {
|
||||
enum { DimZ = 0, DimY, DimX, DimTot };
|
||||
|
||||
public:
|
||||
KernelMappingScheme() {}
|
||||
// dims_in_elems: the normalized tensor dimensions.
|
||||
// req_block_sizes: the requested block size in number of tiles for each
|
||||
// dimension. The actual block size is set to min(req_block_size,
|
||||
// dims_in_number_of_blocks).
|
||||
KernelMappingScheme(absl::Span<const int64> dims_in_elems, int64 tile_size_y,
|
||||
int64 tile_size_x,
|
||||
absl::Span<const int64> req_block_sizes,
|
||||
int64 tile_size_x, int64 block_size_z,
|
||||
int64 num_threads_y, int64 num_threads_x,
|
||||
llvm::IRBuilder<>* b);
|
||||
bool is_dilated_x, llvm::IRBuilder<>* b);
|
||||
|
||||
// Number of elements in each dimension (Z/Y/X respectively).
|
||||
absl::Span<const int64> GetDimensionsInElements() const {
|
||||
return dims_in_elems_;
|
||||
}
|
||||
|
||||
// Ratio of elements in each dimension over tile sizes for Z/Y/X
|
||||
// respectively.
|
||||
absl::Span<const int64> GetDimensionsInTiles() const {
|
||||
return dims_in_tiles_;
|
||||
}
|
||||
|
||||
// Ratio of dimensions per tile over block sizes.
|
||||
absl::Span<const int64> GetDimensionsInBlocks() const {
|
||||
return dims_in_blocks_;
|
||||
}
|
||||
@ -147,14 +148,6 @@ class KernelMappingScheme {
|
||||
}
|
||||
|
||||
bool DilatedX() const { return dilated_x_; }
|
||||
void SetDilatedX(bool v) {
|
||||
dilated_x_ = v;
|
||||
if (!dilated_x_) {
|
||||
// dilated_x_=false is for the purpose of vectorization, which requires
|
||||
// GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_.
|
||||
CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0);
|
||||
}
|
||||
}
|
||||
|
||||
IrArray::Index EmitBlockIndex(llvm::Type* index_ty);
|
||||
// Returns the index for the first tile in the block with the given block
|
||||
|
Loading…
Reference in New Issue
Block a user