[XLA GPU] [NFC] Simplify and document getters of KernelMappingScheme

PiperOrigin-RevId: 265975853
Author:  George Karpenkov (committed by TensorFlower Gardener)
Date:    2019-08-28 13:07:46 -07:00
Parent:  35e5ec9e9e
Commit:  a1e5191049
4 changed files with 42 additions and 44 deletions


@@ -2627,9 +2627,9 @@ void IrEmitterUnnested::EmitHlo021Tile(
   constexpr int kNumRows = 4;
   KernelMappingScheme mapping_scheme(
       reduced_output_dims, /*tile_size_y=*/kWarpSize,
-      /*tile_size_x=*/kWarpSize, /*req_block_sizes=*/{1, 1, 1},
+      /*tile_size_x=*/kWarpSize, /*block_size_z=*/1,
       /*num_threads_y=*/kNumRows,
-      /*num_threads_x=*/kWarpSize, &b_);
+      /*num_threads_x=*/kWarpSize, /*is_dilated_x=*/false, &b_);
   KernelCodegenInfo kernel_info(&mapping_scheme);
 
   std::vector<IrArray> param_arrays;
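
A note on the constants in the call above: assuming kWarpSize is 32 (the warp width the XLA GPU backend targets), the scheme describes a 32x32 tile walked by a 4x32 thread block, so each thread covers 8 rows of its column. A minimal standalone sketch of that arithmetic (not XLA code; the names mirror the call site):

int main() {
  constexpr int kWarpSize = 32;           // assumed warp width
  constexpr int kNumRows = 4;             // threads in y, as at the call site
  constexpr int tile_size_y = kWarpSize;  // 32 rows per tile
  constexpr int tile_size_x = kWarpSize;  // 32 columns per tile
  // 4x32 threads cover a 32x32 tile, so each thread handles 8 rows.
  static_assert(tile_size_y / kNumRows == 8, "8 rows per thread");
  // The constructor DCHECKs exactly these divisibility properties.
  static_assert(tile_size_y % kNumRows == 0, "y tile divides across threads");
  static_assert(tile_size_x % kWarpSize == 0, "x tile divides across threads");
  return 0;
}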
@@ -3062,7 +3062,7 @@ bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo,
 }  // namespace
 
-std::tuple<KernelMappingScheme, bool>
+std::pair<KernelMappingScheme, bool>
 IrEmitterUnnested::ComputeMappingSchemeAndReductionKind(
     const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) {
   const Shape& input_shape = first_reduce->operand(0)->shape();
@@ -3121,12 +3121,10 @@ IrEmitterUnnested::ComputeMappingSchemeAndReductionKind(
     tile_size_y = kNumElementsPerPartialSum;
   }
-  DimensionVector req_block_sizes{block_size_z, 1, 1};
   llvm_ir::KernelMappingScheme mapping_scheme(
-      dims_in_elem, tile_size_y, tile_size_x, req_block_sizes, num_threads_y,
-      num_threads_x, &b_);
-  mapping_scheme.SetDilatedX(dilated_x);
-  return std::make_tuple(mapping_scheme, is_row_reduction);
+      dims_in_elem, tile_size_y, tile_size_x, block_size_z, num_threads_y,
+      num_threads_x, dilated_x, &b_);
+  return std::make_pair(mapping_scheme, is_row_reduction);
 }
 
 Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
@@ -3197,11 +3195,11 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
                    "doesn't set the input layout of "
                 << first_reduce->ToString();
 
-  bool is_row_reduction;
-  llvm_ir::KernelMappingScheme mapping_scheme;
-  std::tie(mapping_scheme, is_row_reduction) =
+  auto mapping_scheme_pair =
       ComputeMappingSchemeAndReductionKind(unnested_hlo, first_reduce);
-  ReductionCodegenInfo reduction_info(&mapping_scheme, is_row_reduction);
+  bool is_row_reduction = mapping_scheme_pair.second;
+  ReductionCodegenInfo reduction_info(&mapping_scheme_pair.first,
+                                      is_row_reduction);
   EmitElementFunction emit_reduction_tile =
       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
           llvm::Value* x_loc, int64 x_iter_num) {
@@ -3216,9 +3214,9 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
       [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
           const string& loop_name, llvm::Value* tile_height,
           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
-        EmitTiledElementalCodeWithBoundsCheck(&mapping_scheme, index, loop_name,
-                                              ksl, &b_, y, x, tile_height,
-                                              tile_width, emit_reduction_tile);
+        EmitTiledElementalCodeWithBoundsCheck(
+            &mapping_scheme_pair.first, index, loop_name, ksl, &b_, y, x,
+            tile_height, tile_width, emit_reduction_tile);
       },
       /*block_prologue_generator=*/
       [&](HloInstruction* hlo, KernelCodegenInfo* kernel_info) {

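One effect of the std::pair change above: the removed std::tie pattern had to default-construct a KernelMappingScheme before assigning into it, while binding the returned pair initializes everything in one step. A minimal sketch of the two call shapes, using hypothetical stand-ins (MappingScheme, Compute) rather than the XLA types:

#include <tuple>
#include <utility>

struct MappingScheme {
  MappingScheme() = default;  // only the std::tie shape needs this
  explicit MappingScheme(int dims) : dims(dims) {}
  int dims = 0;
};

std::pair<MappingScheme, bool> Compute() { return {MappingScheme{3}, true}; }

int main() {
  // Old shape: construct first, then assign through std::tie.
  MappingScheme scheme;
  bool is_row_reduction = false;
  std::tie(scheme, is_row_reduction) = Compute();

  // New shape: keep the pair and read .first/.second, as the commit does.
  auto mapping_scheme_pair = Compute();
  bool row = mapping_scheme_pair.second;
  (void)row;
  return 0;
}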

@@ -212,7 +212,7 @@ class IrEmitterUnnested : public IrEmitter,
   // and first_reduce are the same instruction. For a kInput fusion,
   // unnested_hlo is the fusion instruction while first_reduce is the first
   // reduce op.
-  std::tuple<llvm_ir::KernelMappingScheme, bool>
+  std::pair<llvm_ir::KernelMappingScheme, bool>
   ComputeMappingSchemeAndReductionKind(const HloInstruction* unnested_hlo,
                                        const HloInstruction* first_reduce);


@@ -103,29 +103,36 @@ absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
   return absl::nullopt;
 }
 
-KernelMappingScheme::KernelMappingScheme(
-    absl::Span<const int64> dims_in_elems, int64 tile_size_y, int64 tile_size_x,
-    absl::Span<const int64> req_block_sizes, int64 num_threads_y,
-    int64 num_threads_x, llvm::IRBuilder<>* b)
+KernelMappingScheme::KernelMappingScheme(absl::Span<const int64> dims_in_elems,
+                                         int64 tile_size_y, int64 tile_size_x,
+                                         int64 block_size_z,
+                                         int64 num_threads_y,
+                                         int64 num_threads_x, bool is_dilated_x,
+                                         llvm::IRBuilder<>* b)
     : b_(b),
-      dims_in_elems_{dims_in_elems.at(0), dims_in_elems.at(1),
-                     dims_in_elems.at(2)},
+      dims_in_elems_{dims_in_elems[0], dims_in_elems[1], dims_in_elems[2]},
       tile_sizes_{1, tile_size_y, tile_size_x},
-      dims_in_tiles_(ElementWiseCeilOfRatio(dims_in_elems_, tile_sizes_)),
-      block_sizes_{std::min(req_block_sizes.at(0), dims_in_tiles_.at(0)),
-                   std::min(req_block_sizes.at(1), dims_in_tiles_.at(1)),
-                   std::min(req_block_sizes.at(2), dims_in_tiles_.at(2))},
-      dims_in_blocks_(ElementWiseCeilOfRatio(dims_in_tiles_, block_sizes_)),
+      dims_in_tiles_{dims_in_elems[0],
+                     CeilOfRatio<int64>(dims_in_elems[1], tile_size_y),
+                     CeilOfRatio<int64>(dims_in_elems[2], tile_size_x)},
+      block_sizes_{block_size_z, 1, 1},
+      dims_in_blocks_{CeilOfRatio<int64>(dims_in_elems[0], block_sizes_[0]),
+                      dims_in_tiles_[1], dims_in_tiles_[2]},
       num_threads_x_(num_threads_x),
      num_threads_y_(num_threads_y),
-      dilated_x_(true) {
-  DCHECK_EQ(req_block_sizes.size(), 3);
+      dilated_x_(is_dilated_x) {
   DCHECK_EQ(tile_size_y % num_threads_y_, 0);
   DCHECK_EQ(tile_size_x % num_threads_x_, 0);
+  CHECK_EQ((dims_in_elems[0] % block_size_z), 0);
   VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]";
   VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]";
   VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",")
            << "]";
+  if (!dilated_x_) {
+    // dilated_x_=false is for the purpose of vectorization, which requires
+    // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_.
+    CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0);
+  }
 }
 
 IrArray::Index KernelMappingScheme::GetUnnormalizedIndex(
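
To make the new initializer-list arithmetic concrete: for a normalized shape of {4, 100, 70} elements (Z, Y, X) with 32x32 tiles and block_size_z = 2, the tile grid is {4, 4, 3} and the block grid is {2, 4, 3}. A standalone sketch with made-up numbers and a local stand-in for CeilOfRatio:

#include <array>
#include <cassert>
#include <cstdint>

int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const std::array<int64_t, 3> dims_in_elems = {4, 100, 70};  // Z, Y, X
  const int64_t tile_size_y = 32, tile_size_x = 32, block_size_z = 2;

  // Mirrors the member initializers: Z is never tiled (tile_sizes_ is
  // {1, y, x}), so its tile count equals its element count; Y and X round
  // up to whole tiles.
  const std::array<int64_t, 3> dims_in_tiles = {
      dims_in_elems[0], CeilOfRatio(dims_in_elems[1], tile_size_y),
      CeilOfRatio(dims_in_elems[2], tile_size_x)};
  assert((dims_in_tiles == std::array<int64_t, 3>{4, 4, 3}));

  // Blocks only group tiles along Z (block_sizes_ = {block_size_z, 1, 1}).
  const std::array<int64_t, 3> dims_in_blocks = {
      CeilOfRatio(dims_in_elems[0], block_size_z), dims_in_tiles[1],
      dims_in_tiles[2]};
  assert((dims_in_blocks == std::array<int64_t, 3>{2, 4, 3}));

  // Mirrors the constructor's CHECK: Z must divide evenly into blocks.
  assert(dims_in_elems[0] % block_size_z == 0);
  return 0;
}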


@@ -90,23 +90,24 @@ class KernelMappingScheme {
   enum { DimZ = 0, DimY, DimX, DimTot };
 
  public:
   KernelMappingScheme() {}
   // dims_in_elems: the normalized tensor dimensions.
-  // req_block_sizes: the requested block size in number of tiles for each
-  //   dimension. The actual block size is set to min(req_block_size,
-  //   dims_in_number_of_blocks).
   KernelMappingScheme(absl::Span<const int64> dims_in_elems, int64 tile_size_y,
-                      int64 tile_size_x,
-                      absl::Span<const int64> req_block_sizes,
+                      int64 tile_size_x, int64 block_size_z,
                       int64 num_threads_y, int64 num_threads_x,
-                      llvm::IRBuilder<>* b);
+                      bool is_dilated_x, llvm::IRBuilder<>* b);
 
+  // Number of elements in each dimension (Z/Y/X respectively).
   absl::Span<const int64> GetDimensionsInElements() const {
     return dims_in_elems_;
   }
 
+  // Ratio of elements in each dimension over tile sizes for Z/Y/X
+  // respectively.
   absl::Span<const int64> GetDimensionsInTiles() const {
     return dims_in_tiles_;
   }
 
+  // Ratio of dimensions per tile over block sizes.
   absl::Span<const int64> GetDimensionsInBlocks() const {
     return dims_in_blocks_;
   }
@@ -147,14 +148,6 @@ class KernelMappingScheme {
   }
 
   bool DilatedX() const { return dilated_x_; }
-  void SetDilatedX(bool v) {
-    dilated_x_ = v;
-    if (!dilated_x_) {
-      // dilated_x_=false is for the purpose of vectorization, which requires
-      // GetTileSizeForDimension(DimX) to be a multiplier of num_threads_x_.
-      CHECK_EQ(GetTileSizeForDimension(DimX) % num_threads_x_, 0);
-    }
-  }
 
   IrArray::Index EmitBlockIndex(llvm::Type* index_ty);
   // Returns the index for the first tile in the block with the given block