[XLA:GPU] Generalize the column reduction algorithm to handle tile widths greater than 1.

Tiles of width 1 result in poor memory bandwidth utilization for 16b inputs.
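
(Back-of-the-envelope: with 16b elements and a tile width of 1, each thread loads one 2-byte element per step, so a 32-thread warp touches only 64 contiguous bytes per memory transaction. With a tile width of 2, each thread covers 4 contiguous bytes and a warp can fill a full 128-byte cache line. The exact transaction sizes depend on the GPU generation; these figures are illustrative.)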

PiperOrigin-RevId: 205033124
Thomas Joerg 2018-07-17 23:50:05 -07:00 committed by TensorFlower Gardener
parent f1de0ddd55
commit d9d029f510


@@ -984,8 +984,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
     tensorflow::gtl::ArraySlice<
         std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
         extra_output_gens) {
-  // Divide the input matrix into tiles of size Kx1. For example, when the
-  // input matrix is 4x4 and K=2, the tiled matrix looks like
+  // Divide the input matrix into tiles of size KxL. For example, when the
+  // input matrix is 4x4, K=2, and L=1 the tiled matrix looks like
   //
   //   0123
   //   0123
@@ -997,14 +997,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
   //
   // We choose 128 as the tile size based on empirical evidence. It's big enough
   // to reduce the amount of atomic adds in the end, maximizing the memory
-  // bandwidth.
-  constexpr int64 kTileSize = 128;
+  // bandwidth. A tile width of 2 allows for high memory bandwidth utilization
+  // on 16b input data.
+  constexpr int64 kTileHeight = 128;
+  constexpr int64 kTileWidth = 2;

-  // If the height is not a multiple of the tile size, we pad the bottom of the
+  // If the height is not a multiple of kTileHeight, we pad the bottom of the
   // input matrix.
-  const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
-  Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
-      reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
+  const int64 height_in_tiles = CeilOfRatio(height, kTileHeight);
+  // If width is not a multiple of kTileWidth the rightmost thread will process
+  // fewer input elements.
+  const int64 width_in_tiles = CeilOfRatio(width, kTileWidth);
+  Shape tiled_input_shape =
+      ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(),
+                                     {height_in_tiles, width_in_tiles}, {1, 0});

   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       tiled_input_shape, ir_emitter_context_->device_description());
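
The tiling arithmetic above can be checked in isolation. A minimal standalone
sketch (plain C++; CeilOfRatio is re-implemented locally and the matrix
dimensions are made up for illustration):

    #include <cstdint>
    #include <cstdio>

    // Local stand-in for tensorflow's CeilOfRatio.
    int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int main() {
      constexpr int64_t kTileHeight = 128;
      constexpr int64_t kTileWidth = 2;
      const int64_t height = 1000, width = 301;  // hypothetical input matrix
      const int64_t height_in_tiles = CeilOfRatio(height, kTileHeight);  // 8
      const int64_t width_in_tiles = CeilOfRatio(width, kTileWidth);     // 151
      // One logical thread is launched per element of the tiled input shape
      // {height_in_tiles, width_in_tiles}; the bottom tile row and the
      // rightmost tile column handle the partial leftovers.
      std::printf("%lld x %lld tiles\n", (long long)height_in_tiles,
                  (long long)width_in_tiles);
      return 0;
    }
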
@@ -1016,27 +1022,39 @@ Status IrEmitterUnnested::EmitColumnReduction(
   };

   // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
-  //      linear_index < height_in_tiles * width;
+  //      linear_index < height_in_tiles * width_in_tiles;
   //      linear_index += blockDim.x * gridDim.x) {
-  //   y_in_tiles = linear_index / width;
-  //   x = linear_index % width;
+  //   y_in_tiles = linear_index / width_in_tiles;
+  //   x_in_tiles = linear_index % width_in_tiles;
   //
-  //   partial_result = init_value;
-  //   if (height % kTileSize == 0 ||
-  //       y_in_tiles * kTileSize + kTileSize <= height) {
-  //     for (element_id_in_tile : range(kTileSize)) {
-  //       y = y_in_tiles * kTileSize + element_id_in_tile;
-  //       partial_result = Reducer(partial_result, input[y][x]);
+  //   partial_results[kTileWidth] = init_values;
+  //   tile_in_y_bounds = height % kTileHeight == 0 ||
+  //       y_in_tiles * kTileHeight + kTileHeight <= height;
+  //   tile_in_x_bounds = width % kTileWidth == 0 ||
+  //       x_in_tiles * kTileWidth + kTileWidth <= width;
+  //   // The implementation handles y and x bound checks separately.
+  //   if (tile_in_y_bounds && tile_in_x_bounds) {
+  //     for (y_offset : range(kTileHeight)) {
+  //       y = y_in_tiles * kTileHeight + y_offset;
+  //       for (x_offset : range(kTileWidth)) {
+  //         x = x_in_tiles * kTileWidth + x_offset;
+  //         partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
+  //       }
   //     }
   //   } else {
-  //     for (element_id_in_tile : range(kTileSize)) {
-  //       y = y_in_tiles * kTileSize + element_id_in_tile;
-  //       if (y < height) {
-  //         partial_result = Reducer(partial_result, input[y][x]);
+  //     for (y_offset : range(kTileHeight)) {
+  //       y = y_in_tiles * kTileHeight + y_offset;
+  //       for (x_offset : range(kTileWidth)) {
+  //         x = x_in_tiles * kTileWidth + x_offset;
+  //         if (y < height && x < width) {
+  //           partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
+  //         }
   //       }
   //     }
   //   }
-  //   AtomicReducer(&output[x], partial_result);
+  //   for (x_offset : range(kTileWidth)) {
+  //     AtomicReducer(&output[x + x_offset], partial_results[x_offset]);
+  //   }
   // }
   auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
     const int num_reduces = reducers.size();
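
The pseudocode above can be modeled on the CPU. A hedged sketch (plain C++;
float addition stands in for an arbitrary Reducer, and a plain += stands in
for AtomicReducer, which is only safe here because the sketch processes tiles
sequentially):

    #include <cstdint>
    #include <vector>

    constexpr int64_t kTileHeight = 128;
    constexpr int64_t kTileWidth = 2;

    // CPU model of the work one GPU thread does: reduce one
    // kTileHeight x kTileWidth tile of the row-major height x width `input`
    // into kTileWidth columns of `output`.
    void ReduceOneTile(int64_t linear_index, int64_t height, int64_t width,
                       const std::vector<float>& input,
                       std::vector<float>& output) {
      const int64_t width_in_tiles = (width + kTileWidth - 1) / kTileWidth;
      const int64_t y_in_tiles = linear_index / width_in_tiles;
      const int64_t x_in_tiles = linear_index % width_in_tiles;
      float partial_results[kTileWidth] = {};  // init_values for a sum
      const bool tile_in_y_bounds =
          height % kTileHeight == 0 ||
          y_in_tiles * kTileHeight + kTileHeight <= height;
      const bool tile_in_x_bounds =
          width % kTileWidth == 0 ||
          x_in_tiles * kTileWidth + kTileWidth <= width;
      for (int64_t y_offset = 0; y_offset < kTileHeight; ++y_offset) {
        const int64_t y = y_in_tiles * kTileHeight + y_offset;
        for (int64_t x_offset = 0; x_offset < kTileWidth; ++x_offset) {
          const int64_t x = x_in_tiles * kTileWidth + x_offset;
          // The emitted kernel specializes on the two flags so interior
          // tiles pay for no per-element checks; the sketch folds them in.
          if ((tile_in_y_bounds && tile_in_x_bounds) ||
              (y < height && x < width)) {
            partial_results[x_offset] += input[y * width + x];  // Reducer
          }
        }
      }
      for (int64_t x_offset = 0; x_offset < kTileWidth; ++x_offset) {
        const int64_t x = x_in_tiles * kTileWidth + x_offset;
        if (x < width) {
          output[x] += partial_results[x_offset];  // AtomicReducer on the GPU
        }
      }
    }

Running ReduceOneTile for every linear_index in
[0, height_in_tiles * width_in_tiles) reproduces an ordinary column sum. On
the GPU, tiles of the same column run on different threads, which is why the
final accumulation must be atomic.
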
@@ -1045,41 +1063,48 @@ Status IrEmitterUnnested::EmitColumnReduction(
         llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
     std::vector<llvm::Value*> partial_reduction_result_addresses;
     for (int i = 0; i != num_reduces; ++i) {
-      llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
-          element_ir_type, /*ArraySize=*/nullptr,
-          "partial_reduction_result." + llvm::Twine(i));
-      TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
-                          init_value_gens[i](IrArray::Index(index_ty)));
-      ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
-      partial_reduction_result_addresses.push_back(
-          partial_reduction_result_address);
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* partial_reduction_result_address =
+            ir_builder_.CreateAlloca(
+                element_ir_type, /*ArraySize=*/nullptr,
+                "partial_reduction_result." +
+                    llvm::Twine(i * kTileWidth + x_offset));
+        TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+                            init_value_gens[i](IrArray::Index(index_ty)));
+        ir_builder_.CreateStore(init_ir_value,
+                                partial_reduction_result_address);
+        partial_reduction_result_addresses.push_back(
+            partial_reduction_result_address);
+      }
     }

     // Emit an inner for-loop that partially reduces the elements in the given
     // tile.
     llvm::Value* y_in_tiles = tile_index[0];
-    llvm::Value* x = tile_index[1];
+    llvm::Value* x_in_tiles = tile_index[1];
     y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty);
-    x = ir_builder_.CreateZExtOrTrunc(x, index_ty);
+    x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);

-    auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
+    auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
+                                      bool tile_in_x_bounds) -> Status {
       std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
           llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
                                         index_typed_constant(0),
-                                        index_typed_constant(kTileSize),
+                                        index_typed_constant(kTileHeight),
                                         index_typed_constant(1), &ir_builder_);

       // Emit the body of the partial reduction loop.
       llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
                                      &ir_builder_);
       llvm::Value* y = ir_builder_.CreateNSWAdd(
-          ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)),
+          ir_builder_.CreateNSWMul(y_in_tiles,
+                                   index_typed_constant(kTileHeight)),
           tile_element_loop->GetIndVarValue());

-      // Unless we know the tile is entirely in bounds, we have to emit a
-      // y-in-bounds check before reading from the input.
-      if (!tile_in_bounds) {
+      // Unless we know that y is in bounds, we have to emit a check before
+      // reading from the input.
+      if (!tile_in_y_bounds) {
         llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
             ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
             "y_in_bounds", &ir_builder_);
@@ -1088,8 +1113,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
         // the partial reduction result.
         llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
       }
-      llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
-      {
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* x = ir_builder_.CreateNSWAdd(
+            ir_builder_.CreateNSWMul(x_in_tiles,
+                                     index_typed_constant(kTileWidth)),
+            index_typed_constant(x_offset));
+        // Unless we know that x is in bounds, we have to emit a check before
+        // reading from the input.
+        if (!tile_in_x_bounds) {
+          llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+              ir_builder_.CreateICmpULT(x, index_typed_constant(width)),
+              "x_in_bounds", &ir_builder_);
+          llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+        }
+        llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
         // {y,x} is an index to input_matrix_shape [height,width]. We need to
         // convert that to an index to input_shape (the shape of the operand of
         // "reduce"). This conversion is composed of a transposition from
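
Note the bookkeeping this change introduces: there is now one partial result
per (reducer, tile column) pair, kept in a single flat vector and addressed
as i * kTileWidth + x_offset both in the allocas above and in the reducer
calls and atomic accumulation below. A tiny illustration of that layout
(values hypothetical):

    #include <cstdio>

    int main() {
      constexpr int kTileWidth = 2;
      const int num_reduces = 3;  // e.g. a variadic reduce with 3 outputs
      // Slots are reducer-major: [r0c0, r0c1, r1c0, r1c1, r2c0, r2c1],
      // mirroring partial_reduction_result_addresses.
      for (int i = 0; i != num_reduces; ++i) {
        for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
          std::printf("reducer %d, tile column %d -> slot %d\n", i, x_offset,
                      i * kTileWidth + x_offset);
        }
      }
      return 0;
    }
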
@@ -1120,51 +1157,90 @@ Status IrEmitterUnnested::EmitColumnReduction(
           ir_builder_.CreateStore(input_ir_value, input_address);
           TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
               *reducers[i],
-              {partial_reduction_result_addresses[i], input_address},
-              partial_reduction_result_addresses[i]));
+              {partial_reduction_result_addresses[i * kTileWidth + x_offset],
+               input_address},
+              partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+          TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index,
+                                                       extra_output_gens));
         }
-        return EmitExtraOutputsForReduce(reduce, input_index,
-                                         extra_output_gens);
       }
+      return Status::OK();
     };

-    // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
-    // immediately beyond the tile.
+    // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
+    // that's immediately beyond the tile.
     llvm::Value* y_end = ir_builder_.CreateNSWAdd(
-        index_typed_constant(kTileSize),
-        ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)));
+        index_typed_constant(kTileHeight),
+        ir_builder_.CreateNSWMul(y_in_tiles,
+                                 index_typed_constant(kTileHeight)));
+    // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
+    // that's immediately beyond the tile.
+    llvm::Value* x_end = ir_builder_.CreateNSWAdd(
+        index_typed_constant(kTileWidth),
+        ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));

-    llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
+    llvm::Value* tile_in_y_bounds = ir_builder_.CreateOr(
         ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
-        ir_builder_.getInt1(height % kTileSize == 0));
+        ir_builder_.getInt1(height % kTileHeight == 0));
+    llvm::Value* tile_in_x_bounds = ir_builder_.CreateOr(
+        ir_builder_.CreateICmpULE(x_end, index_typed_constant(width)),
+        ir_builder_.getInt1(width % kTileWidth == 0));

-    // The tile is entirely in bound if "height" is a multiple of kTileSize or
+    // The tile is in y bounds if "height" is a multiple of kTileHeight or
     // y_end <= height.
-    llvm_ir::LlvmIfData if_tile_in_bounds_data =
-        llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
-                                   &ir_builder_);
-    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
-                                   &ir_builder_);
-    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
+    llvm_ir::LlvmIfData if_tile_in_y_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_y_bounds, "tile_in_y_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block,
+                                   &ir_builder_);
+    // The tile is in x bounds if "width" is a multiple of kTileWidth or
+    // x_end <= width.
+    llvm_ir::LlvmIfData if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+                                              /*tile_in_x_bounds=*/true));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+                                              /*tile_in_x_bounds=*/false));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block,
+                                   &ir_builder_);
+    if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+                                              /*tile_in_x_bounds=*/true));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+                                              /*tile_in_x_bounds=*/false));

-    // After the if-then-else statement on tile_in_bounds, emit atomic
-    // operations to accumulate the partial reduction result to the output
-    // element.
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
+    // After the nested if-then-else statement on tile_in_y_bounds and
+    // tile_in_x_bounds, emit atomic operations to accumulate the partial
+    // reduction result to the output element.
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block,
                                    &ir_builder_);
     const HloInstruction* output =
         reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
     for (int i = 0; i != num_reduces; ++i) {
-      llvm::Value* output_address =
-          GetIrArray(*output, *output, reduce_output_shapes[i])
-              .EmitArrayElementAddress(
-                  IrArray::Index(
-                      x,
-                      ShapeUtil::GetSubshape(
-                          output->shape(), reduce_output_shapes[i]),
-                      &ir_builder_),
-                  &ir_builder_, "output_element_address");
-      TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
-          *reducers[i], output_address, partial_reduction_result_addresses[i]));
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* x = ir_builder_.CreateNSWAdd(
+            ir_builder_.CreateNSWMul(x_in_tiles,
+                                     index_typed_constant(kTileWidth)),
+            index_typed_constant(x_offset));
+        llvm::Value* output_address =
+            GetIrArray(*output, *output, reduce_output_shapes[i])
+                .EmitArrayElementAddress(
+                    IrArray::Index(
+                        x,
+                        ShapeUtil::GetSubshape(output->shape(),
+                                               reduce_output_shapes[i]),
+                        &ir_builder_),
+                    &ir_builder_, "output_element_address");
+        TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+            *reducers[i], output_address,
+            partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+      }
     }
     return Status::OK();
   };
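
The nested if-then-else emits four specialized copies of the tile loop, so
the common interior tile (both flags true) runs without any per-element
bounds checks while edge tiles pay for exactly the checks they need. A C++
analogue of that specialization, with the flags as compile-time template
parameters (names hypothetical):

    #include <cstdio>

    // Each instantiation corresponds to one branch of the nested
    // if-then-else; inside it the bounds flags are constants, so the
    // <true, true> copy contains no per-element checks at all.
    template <bool kTileInYBounds, bool kTileInXBounds>
    void TileLoopVariant() {
      std::printf("y %s, x %s\n", kTileInYBounds ? "in bounds" : "checked",
                  kTileInXBounds ? "in bounds" : "checked");
    }

    void RunTileLoop(bool tile_in_y_bounds, bool tile_in_x_bounds) {
      if (tile_in_y_bounds) {
        if (tile_in_x_bounds) {
          TileLoopVariant<true, true>();
        } else {
          TileLoopVariant<true, false>();
        }
      } else {
        if (tile_in_x_bounds) {
          TileLoopVariant<false, true>();
        } else {
          TileLoopVariant<false, false>();
        }
      }
    }

    int main() {
      RunTileLoop(true, false);  // e.g. a tile in the rightmost column
      return 0;
    }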