[XLA:GPU] Generalize the column reduction algorithm to handle tile widths greater than 1.
Tiles of width 1 result in poor memory bandwidth for 16b inputs.

PiperOrigin-RevId: 205033124
parent f1de0ddd55
commit d9d029f510
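Why width-1 tiles hurt 16b inputs: with one fp16 column per thread, each per-row load is 2 bytes, so even a fully coalesced warp request covers only 64 bytes; pairing two adjacent columns per thread turns that into a single 4-byte load and a full 128-byte request per warp. The standalone CUDA sketch below illustrates the two access patterns with addition as the reducer; the kernel names and the explicit __half2 loads are illustrative only, not what the emitter produces.

    #include <cuda_fp16.h>

    // One column per thread: per-row loads are 2 bytes each.
    __global__ void col_sum_width1(const __half* in, float* out, int height,
                                   int width) {
      int x = blockIdx.x * blockDim.x + threadIdx.x;
      if (x >= width) return;
      float acc = 0.f;
      for (int y = 0; y < height; ++y) {
        acc += __half2float(in[y * width + x]);
      }
      out[x] = acc;
    }

    // Two adjacent columns per thread: per-row loads are one 4-byte __half2.
    // Assumes an even width and 4-byte-aligned rows.
    __global__ void col_sum_width2(const __half* in, float* out, int height,
                                   int width) {
      int x2 = blockIdx.x * blockDim.x + threadIdx.x;  // index of column pair
      if (x2 >= width / 2) return;
      float acc0 = 0.f, acc1 = 0.f;
      for (int y = 0; y < height; ++y) {
        __half2 v = reinterpret_cast<const __half2*>(in + y * width)[x2];
        acc0 += __low2float(v);
        acc1 += __high2float(v);
      }
      out[2 * x2] = acc0;
      out[2 * x2 + 1] = acc1;
    }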
@@ -984,8 +984,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
     tensorflow::gtl::ArraySlice<
         std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
         extra_output_gens) {
-  // Divide the input matrix into tiles of size Kx1. For example, when the
-  // input matrix is 4x4 and K=2, the tiled matrix looks like
+  // Divide the input matrix into tiles of size KxL. For example, when the
+  // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like
   //
   //   0123
   //   0123
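For comparison, with K=2 and L=2 the same 4x4 matrix would be tiled as (each element labeled with its tile number):

    0011
    0011
    2233
    2233

that is, element (y, x) belongs to tile (y / K) * CeilOfRatio(width, L) + (x / L).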
@@ -997,14 +997,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
   //
   // We choose 128 as the tile size based on empirical evidence. It's big enough
   // to reduce the amount of atomic adds in the end, maximizing the memory
-  // bandwidth.
-  constexpr int64 kTileSize = 128;
+  // bandwidth. A tile width of 2 allows for high memory bandwidth utilization
+  // on 16b input data.
+  constexpr int64 kTileHeight = 128;
+  constexpr int64 kTileWidth = 2;

-  // If the height is not a multiple of the tile size, we pad the bottom of the
+  // If the height is not a multiple of kTileHeight, we pad the bottom of the
   // input matrix.
-  const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
-  Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
-      reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
+  const int64 height_in_tiles = CeilOfRatio(height, kTileHeight);
+  // If width is not a multiple of kTileWidth the rightmost thread will process
+  // fewer input elements.
+  const int64 width_in_tiles = CeilOfRatio(width, kTileWidth);
+  Shape tiled_input_shape =
+      ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(),
+                                     {height_in_tiles, width_in_tiles}, {1, 0});
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       tiled_input_shape, ir_emitter_context_->device_description());
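To make the new launch arithmetic concrete (the values here are illustrative, not from the patch): a 1000x300 input yields a logical {8, 150} tiled shape, i.e. 1200 tiles with one thread each; only the bottom row of tiles needs the y bounds check, and since 300 % 2 == 0 no x bounds check is needed at all.

    // Standalone restatement; CeilOfRatio(a, b) is (a + b - 1) / b for
    // positive integers.
    int64_t height_in_tiles = (1000 + 128 - 1) / 128;  // == 8
    int64_t width_in_tiles = (300 + 2 - 1) / 2;        // == 150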
@@ -1016,27 +1022,39 @@ Status IrEmitterUnnested::EmitColumnReduction(
   };

   // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
-  //      linear_index < height_in_tiles * width;
+  //      linear_index < height_in_tiles * width_in_tiles;
   //      linear_index += blockDim.x * gridDim.x) {
-  //   y_in_tiles = linear_index / width;
-  //   x = linear_index % width;
+  //   y_in_tiles = linear_index / width_in_tiles;
+  //   x_in_tiles = linear_index % width_in_tiles;
   //
-  //   partial_result = init_value;
-  //   if (height % kTileSize == 0 ||
-  //       y_in_tiles * kTileSize + kTileSize <= height) {
-  //     for (element_id_in_tile : range(kTileSize)) {
-  //       y = y_in_tiles * kTileSize + element_id_in_tile;
-  //       partial_result = Reducer(partial_result, input[y][x]);
+  //   partial_results[kTileWidth] = init_values;
+  //   tile_in_y_bounds = height % kTileHeight == 0 ||
+  //       y_in_tiles * kTileHeight + kTileHeight <= height;
+  //   tile_in_x_bounds = width % kTileWidth == 0 ||
+  //       x_in_tiles * kTileWidth + kTileWidth <= width;
+  //   // The implementation handles y and x bound checks separately.
+  //   if (tile_in_y_bounds && tile_in_x_bounds) {
+  //     for (y_offset : range(kTileHeight)) {
+  //       y = y_in_tiles * kTileHeight + y_offset;
+  //       for (x_offset : range(kTileWidth)) {
+  //         x = x_in_tiles * kTileWidth + x_offset;
+  //         partial_results[x_offset] = Reducer(partial_results[x_offset],
+  //                                             input[y][x]);
+  //       }
   //     }
   //   } else {
-  //     for (element_id_in_tile : range(kTileSize)) {
-  //       y = y_in_tiles * kTileSize + element_id_in_tile;
-  //       if (y < height) {
-  //         partial_result = Reducer(partial_result, input[y][x]);
+  //     for (y_offset : range(kTileHeight)) {
+  //       y = y_in_tiles * kTileHeight + y_offset;
+  //       for (x_offset : range(kTileWidth)) {
+  //         x = x_in_tiles * kTileWidth + x_offset;
+  //         if (y < height && x < width) {
+  //           partial_results[x_offset] = Reducer(partial_results[x_offset],
+  //                                               input[y][x]);
+  //         }
   //       }
   //     }
   //   }
-  //   AtomicReducer(&output[x], partial_result);
+  //   for (x_offset : range(kTileWidth)) {
+  //     AtomicReducer(&output[x_in_tiles * kTileWidth + x_offset],
+  //                   partial_results[x_offset]);
+  //   }
   // }
   auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
     const int num_reduces = reducers.size();
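The new pseudocode maps almost line for line onto a plain CUDA kernel. The sketch below is my own illustration of that structure, assuming a single f32 reduce with addition as the reducer and a zero-initialized output; it is not the IR the emitter produces. Note the hoisted tile_in_y_bounds/tile_in_x_bounds tests, the kTileWidth partial accumulators, and one atomic per tile column at the end (guarded by x < width here to keep the sketch memory-safe for odd widths).

    constexpr int kTileHeight = 128;
    constexpr int kTileWidth = 2;

    __global__ void column_reduce_tiled(const float* input, float* output,
                                        int height, int width) {
      int height_in_tiles = (height + kTileHeight - 1) / kTileHeight;
      int width_in_tiles = (width + kTileWidth - 1) / kTileWidth;
      // Grid-stride loop over tiles, as in the pseudocode.
      for (int linear_index = threadIdx.x + blockIdx.x * blockDim.x;
           linear_index < height_in_tiles * width_in_tiles;
           linear_index += blockDim.x * gridDim.x) {
        int y_in_tiles = linear_index / width_in_tiles;
        int x_in_tiles = linear_index % width_in_tiles;

        float partial_results[kTileWidth];
        for (int i = 0; i < kTileWidth; ++i) partial_results[i] = 0.f;

        bool tile_in_y_bounds =
            height % kTileHeight == 0 ||
            y_in_tiles * kTileHeight + kTileHeight <= height;
        bool tile_in_x_bounds =
            width % kTileWidth == 0 ||
            x_in_tiles * kTileWidth + kTileWidth <= width;
        if (tile_in_y_bounds && tile_in_x_bounds) {
          // Fast path: no per-element bound checks.
          for (int y_offset = 0; y_offset < kTileHeight; ++y_offset) {
            int y = y_in_tiles * kTileHeight + y_offset;
            for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
              int x = x_in_tiles * kTileWidth + x_offset;
              partial_results[x_offset] += input[y * width + x];
            }
          }
        } else {
          // Slow path, taken only by the bottom and rightmost tiles.
          for (int y_offset = 0; y_offset < kTileHeight; ++y_offset) {
            int y = y_in_tiles * kTileHeight + y_offset;
            for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
              int x = x_in_tiles * kTileWidth + x_offset;
              if (y < height && x < width) {
                partial_results[x_offset] += input[y * width + x];
              }
            }
          }
        }
        // One atomic per tile column accumulates into the output row.
        for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
          int x = x_in_tiles * kTileWidth + x_offset;
          if (x < width) {
            atomicAdd(&output[x], partial_results[x_offset]);
          }
        }
      }
    }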
@@ -1045,41 +1063,48 @@ Status IrEmitterUnnested::EmitColumnReduction(
         llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
     std::vector<llvm::Value*> partial_reduction_result_addresses;
     for (int i = 0; i != num_reduces; ++i) {
-      llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
-          element_ir_type, /*ArraySize=*/nullptr,
-          "partial_reduction_result." + llvm::Twine(i));
-      TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
-                          init_value_gens[i](IrArray::Index(index_ty)));
-      ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
-      partial_reduction_result_addresses.push_back(
-          partial_reduction_result_address);
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* partial_reduction_result_address =
+            ir_builder_.CreateAlloca(
+                element_ir_type, /*ArraySize=*/nullptr,
+                "partial_reduction_result." +
+                    llvm::Twine(i * kTileWidth + x_offset));
+        TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+                            init_value_gens[i](IrArray::Index(index_ty)));
+        ir_builder_.CreateStore(init_ir_value,
+                                partial_reduction_result_address);
+        partial_reduction_result_addresses.push_back(
+            partial_reduction_result_address);
+      }
     }

     // Emit an inner for-loop that partially reduces the elements in the given
     // tile.
     llvm::Value* y_in_tiles = tile_index[0];
-    llvm::Value* x = tile_index[1];
+    llvm::Value* x_in_tiles = tile_index[1];

     y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty);
-    x = ir_builder_.CreateZExtOrTrunc(x, index_ty);
+    x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);

-    auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
+    auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
+                                      bool tile_in_x_bounds) -> Status {
       std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
           llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
                                         index_typed_constant(0),
-                                        index_typed_constant(kTileSize),
+                                        index_typed_constant(kTileHeight),
                                         index_typed_constant(1), &ir_builder_);

       // Emit the body of the partial reduction loop.
       llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
                                      &ir_builder_);
       llvm::Value* y = ir_builder_.CreateNSWAdd(
-          ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)),
+          ir_builder_.CreateNSWMul(y_in_tiles,
+                                   index_typed_constant(kTileHeight)),
          tile_element_loop->GetIndVarValue());

-      // Unless we know the tile is entirely in bounds, we have to emit a
-      // y-in-bounds check before reading from the input.
-      if (!tile_in_bounds) {
+      // Unless we know that y is in bounds, we have to emit a check before
+      // reading from the input.
+      if (!tile_in_y_bounds) {
         llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
             ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
             "y_in_bounds", &ir_builder_);
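The alloca bookkeeping above flattens a conceptual [num_reduces][kTileWidth] array of accumulators into a single vector, indexed as i * kTileWidth + x_offset throughout the rest of the patch. A host-side toy analogue of that layout (types and names are mine, not the emitter's):

    #include <vector>

    // One accumulator slot per (reducer, tile column) pair, laid out so that
    // slot(i, x_offset) == i * kTileWidth + x_offset, matching
    // partial_reduction_result_addresses in the patch.
    std::vector<float> MakeAccumulators(int num_reduces, int kTileWidth,
                                        const std::vector<float>& init_values) {
      std::vector<float> acc(num_reduces * kTileWidth);
      for (int i = 0; i != num_reduces; ++i) {
        for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
          acc[i * kTileWidth + x_offset] = init_values[i];
        }
      }
      return acc;
    }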
@@ -1088,8 +1113,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
         // the partial reduction result.
         llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
       }
-      llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
-      {
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* x = ir_builder_.CreateNSWAdd(
+            ir_builder_.CreateNSWMul(x_in_tiles,
+                                     index_typed_constant(kTileWidth)),
+            index_typed_constant(x_offset));
+        // Unless we know that x is in bounds, we have to emit a check before
+        // reading from the input.
+        if (!tile_in_x_bounds) {
+          llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+              ir_builder_.CreateICmpULT(x, index_typed_constant(width)),
+              "x_in_bounds", &ir_builder_);
+          llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+        }
+        llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
         // {y,x} is an index to input_matrix_shape [height,width]. We need to
         // convert that to an index to input_shape (the shape of the operand of
         // "reduce"). This conversion is composed of a transposition from
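Because tile_in_y_bounds and tile_in_x_bounds are ordinary bool parameters evaluated at emission time, the emitter effectively stamps out four specializations of the tile loop, and the common fully-in-bounds version contains no per-element branches. A C++ template is a rough analogue of this (illustrative only, with addition as the reducer):

    constexpr int kTileHeight = 128;
    constexpr int kTileWidth = 2;

    // Rough analogue of the four loop specializations: the bounds flags are
    // known when the body is generated, so the checks vanish on the fast path.
    template <bool kTileInYBounds, bool kTileInXBounds>
    __device__ void ReduceTile(const float* input, float* partial_results,
                               int y_in_tiles, int x_in_tiles, int height,
                               int width) {
      for (int y_offset = 0; y_offset < kTileHeight; ++y_offset) {
        int y = y_in_tiles * kTileHeight + y_offset;
        if (!kTileInYBounds && y >= height) break;  // only in the slow variants
        for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
          int x = x_in_tiles * kTileWidth + x_offset;
          if (!kTileInXBounds && x >= width) continue;
          partial_results[x_offset] += input[y * width + x];
        }
      }
    }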
@@ -1120,51 +1157,90 @@ Status IrEmitterUnnested::EmitColumnReduction(
         ir_builder_.CreateStore(input_ir_value, input_address);
         TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
             *reducers[i],
-            {partial_reduction_result_addresses[i], input_address},
-            partial_reduction_result_addresses[i]));
+            {partial_reduction_result_addresses[i * kTileWidth + x_offset],
+             input_address},
+            partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+        TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index,
+                                                     extra_output_gens));
+      }
       }
-      return EmitExtraOutputsForReduce(reduce, input_index,
-                                       extra_output_gens);
+      return Status::OK();
     };

-    // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
-    // immediately beyond the tile.
+    // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
+    // that's immediately beyond the tile.
     llvm::Value* y_end = ir_builder_.CreateNSWAdd(
-        index_typed_constant(kTileSize),
-        ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)));
-    llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
+        index_typed_constant(kTileHeight),
+        ir_builder_.CreateNSWMul(y_in_tiles,
+                                 index_typed_constant(kTileHeight)));
+    // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
+    // that's immediately beyond the tile.
+    llvm::Value* x_end = ir_builder_.CreateNSWAdd(
+        index_typed_constant(kTileWidth),
+        ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+    llvm::Value* tile_in_y_bounds = ir_builder_.CreateOr(
         ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
-        ir_builder_.getInt1(height % kTileSize == 0));
-    // The tile is entirely in bound if "height" is a multiple of kTileSize or
+        ir_builder_.getInt1(height % kTileHeight == 0));
+    llvm::Value* tile_in_x_bounds = ir_builder_.CreateOr(
+        ir_builder_.CreateICmpULE(x_end, index_typed_constant(width)),
+        ir_builder_.getInt1(width % kTileWidth == 0));
+    // The tile is in y bounds if "height" is a multiple of kTileHeight or
     // y_end <= height.
-    llvm_ir::LlvmIfData if_tile_in_bounds_data =
-        llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
+    llvm_ir::LlvmIfData if_tile_in_y_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_y_bounds, "tile_in_y_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block,
                                    &ir_builder_);
-    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
+    // The tile is in x bounds if "width" is a multiple of kTileWidth or
+    // x_end <= width.
+    llvm_ir::LlvmIfData if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
                                    &ir_builder_);
-    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+                                              /*tile_in_x_bounds=*/true));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+                                              /*tile_in_x_bounds=*/false));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block,
+                                   &ir_builder_);
+    if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+                                              /*tile_in_x_bounds=*/true));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+                                              /*tile_in_x_bounds=*/false));

-    // After the if-then-else statement on tile_in_bounds, emit atomic
-    // operations to accumulate the partial reduction result to the output
-    // element.
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
+    // After the nested if-then-else statement on tile_in_y_bounds and
+    // tile_in_x_bounds, emit atomic operations to accumulate the partial
+    // reduction result to the output element.
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block,
                                    &ir_builder_);
     const HloInstruction* output =
         reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
     for (int i = 0; i != num_reduces; ++i) {
-      llvm::Value* output_address =
-          GetIrArray(*output, *output, reduce_output_shapes[i])
-              .EmitArrayElementAddress(
-                  IrArray::Index(x,
-                                 ShapeUtil::GetSubshape(
-                                     output->shape(), reduce_output_shapes[i]),
-                                 &ir_builder_),
-                  &ir_builder_, "output_element_address");
-      TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
-          *reducers[i], output_address, partial_reduction_result_addresses[i]));
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* x = ir_builder_.CreateNSWAdd(
+            ir_builder_.CreateNSWMul(x_in_tiles,
+                                     index_typed_constant(kTileWidth)),
+            index_typed_constant(x_offset));
+        llvm::Value* output_address =
+            GetIrArray(*output, *output, reduce_output_shapes[i])
+                .EmitArrayElementAddress(
+                    IrArray::Index(
+                        x,
+                        ShapeUtil::GetSubshape(output->shape(),
+                                               reduce_output_shapes[i]),
+                        &ir_builder_),
+                    &ir_builder_, "output_element_address");
+        TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+            *reducers[i], output_address,
+            partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+      }
     }
     return Status::OK();
   };
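EmitAtomicOperationForNestedComputation is now invoked once per (reducer, tile column) pair. For f32 addition this can lower to a hardware atomic; for a reducer with no native atomic instruction, the usual fallback is a compare-and-swap loop. A hedged sketch of that standard CUDA idiom for a float max reducer (not XLA's emitted code):

    // CAS-loop fallback for an atomic reduction with no native hardware
    // instruction (here: max over float).
    __device__ void AtomicMaxF(float* address, float value) {
      int* address_as_int = reinterpret_cast<int*>(address);
      int observed = *address_as_int;
      while (true) {
        int assumed = observed;
        float combined = fmaxf(__int_as_float(assumed), value);
        observed = atomicCAS(address_as_int, assumed, __float_as_int(combined));
        if (observed == assumed) break;  // no other thread intervened
      }
    }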