[XLA:GPU] Generalize the column reduction algorithm to handle tile widths greater than 1.

Tiles of width 1 result in poor memory bandwidth utilization for 16-bit inputs.

PiperOrigin-RevId: 205033124
Thomas Joerg, 2018-07-17 23:50:05 -07:00 (committed by TensorFlower Gardener)
parent f1de0ddd55
commit d9d029f510

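Before the diff itself, a minimal single-threaded C++ sketch of the generalized tiling scheme described above: each tile covers kTileHeight rows of kTileWidth adjacent columns, one partial result is kept per tile column, and partial tiles at the bottom and right edges fall back to per-element bounds checks. Except for kTileHeight and kTileWidth, the names below are illustrative; this is a reference model, not the emitter, which generates LLVM IR with one GPU thread per tile and atomic accumulation into the output.

```cpp
#include <cstdint>
#include <vector>

// Reference model of the generalized column reduction (sum reduction,
// init_value = 0). The real kernel assigns one GPU thread per tile and
// accumulates into `output` with atomic adds; here a sequential loop and
// plain += stand in for that.
std::vector<float> ColumnReduceTiled(const std::vector<float>& input,
                                     int64_t height, int64_t width) {
  constexpr int64_t kTileHeight = 128;
  constexpr int64_t kTileWidth = 2;
  const int64_t height_in_tiles = (height + kTileHeight - 1) / kTileHeight;
  const int64_t width_in_tiles = (width + kTileWidth - 1) / kTileWidth;
  std::vector<float> output(width, 0.0f);
  for (int64_t y_in_tiles = 0; y_in_tiles < height_in_tiles; ++y_in_tiles) {
    for (int64_t x_in_tiles = 0; x_in_tiles < width_in_tiles; ++x_in_tiles) {
      // One partial result per column of the tile, as in the new kernel.
      float partial_results[kTileWidth] = {};
      for (int64_t y_offset = 0; y_offset < kTileHeight; ++y_offset) {
        const int64_t y = y_in_tiles * kTileHeight + y_offset;
        if (y >= height) break;  // bottom edge of a partial tile
        for (int64_t x_offset = 0; x_offset < kTileWidth; ++x_offset) {
          const int64_t x = x_in_tiles * kTileWidth + x_offset;
          if (x >= width) break;  // right edge of a partial tile
          partial_results[x_offset] += input[y * width + x];
        }
      }
      for (int64_t x_offset = 0; x_offset < kTileWidth; ++x_offset) {
        const int64_t x = x_in_tiles * kTileWidth + x_offset;
        if (x < width) output[x] += partial_results[x_offset];  // atomic on GPU
      }
    }
  }
  return output;
}
```

With kTileWidth = 2, each thread reads two adjacent elements per row of its tile, so for 16b data each access is 32 bits wide; that is presumably what recovers the memory bandwidth the commit message refers to.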

@@ -984,8 +984,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
tensorflow::gtl::ArraySlice<
std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
extra_output_gens) {
// Divide the input matrix into tiles of size Kx1. For example, when the
// input matrix is 4x4 and K=2, the tiled matrix looks like
// Divide the input matrix into tiles of size KxL. For example, when the
// input matrix is 4x4, K=2, and L=1, the tiled matrix looks like
//
// 0123
// 0123
@@ -997,14 +997,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
//
// We choose 128 as the tile size based on empirical evidence. It's big enough
// to reduce the amount of atomic adds in the end, maximizing the memory
// bandwidth.
constexpr int64 kTileSize = 128;
// bandwidth. A tile width of 2 allows for high memory bandwidth utilization
// on 16b input data.
constexpr int64 kTileHeight = 128;
constexpr int64 kTileWidth = 2;
// If the height is not a multiple of the tile size, we pad the bottom of the
// If the height is not a multiple of kTileHeight, we pad the bottom of the
// input matrix.
const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
const int64 height_in_tiles = CeilOfRatio(height, kTileHeight);
// If width is not a multiple of kTileWidth, the rightmost thread will process
// fewer input elements.
const int64 width_in_tiles = CeilOfRatio(width, kTileWidth);
Shape tiled_input_shape =
ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(),
{height_in_tiles, width_in_tiles}, {1, 0});
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
tiled_input_shape, ir_emitter_context_->device_description());
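To make the shape arithmetic above concrete, a small worked example for a hypothetical 300x5 input (the numbers and the CeilDiv helper are for illustration only; the real code uses CeilOfRatio):

```cpp
#include <cstdint>

// Local stand-in for the ceiling division performed by CeilOfRatio above.
constexpr int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

constexpr int64_t kTileHeight = 128;
constexpr int64_t kTileWidth = 2;
// Hypothetical 300x5 input matrix:
constexpr int64_t height_in_tiles = CeilDiv(300, kTileHeight);  // 3
constexpr int64_t width_in_tiles = CeilDiv(5, kTileWidth);      // 3
static_assert(height_in_tiles == 3 && width_in_tiles == 3,
              "tiled_input_shape would be {3, 3}: 9 tiles, one thread each");
// The bottom row of tiles covers only 300 - 2 * 128 = 44 rows, and the
// rightmost column of tiles covers a single input column, which is exactly
// the case the "rightmost thread will process fewer input elements" comment
// refers to.
```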
@@ -1016,27 +1022,39 @@ Status IrEmitterUnnested::EmitColumnReduction(
};
// for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
// linear_index < height_in_tiles * width;
// linear_index < height_in_tiles * width_in_tiles;
// linear_index += blockDim.x * gridDim.x) {
// y_in_tiles = linear_index / width;
// x = linear_index % width;
// y_in_tiles = linear_index / width_in_tiles;
// x_in_tiles = linear_index % width_in_tiles;
//
// partial_result = init_value;
// if (height % kTileSize == 0 ||
// y_in_tiles * kTileSize + kTileSize <= height) {
// for (element_id_in_tile : range(kTileSize)) {
// y = y_in_tiles * kTileSize + element_id_in_tile;
// partial_result = Reducer(partial_result, input[y][x]);
// partial_results[kTileWidth] = init_values;
// tile_in_y_bounds = height % kTileHeight == 0 ||
// y_in_tiles * kTileHeight + kTileHeight <= height;
// tile_in_x_bounds = width % kTileWidth == 0 ||
// x_in_tiles * kTileWidth + kTileWidth <= width;
// // The implementation handles y and x bound checks separately.
// if (tile_in_y_bounds && tile_in_x_bounds) {
// for (y_offset : range(kTileHeight)) {
// y = y_in_tiles * kTileHeight + y_offset;
// for (x_offset : range(kTileWidth)) {
// x = x_in_tiles * kTileWidth + x_offset;
// partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
// }
// }
// } else {
// for (element_id_in_tile : range(kTileSize)) {
// y = y_in_tiles * kTileSize + element_id_in_tile;
// if (y < height) {
// partial_result = Reducer(partial_result, input[y][x]);
// for (y_offset : range(kTileHeight)) {
// y = y_in_tiles * kTileHeight + y_offset;
// for (x_offset : range(kTileWidth)) {
// x = x_in_tiles * kTileWidth + x_offset;
// if (y < height && x < width) {
// partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
// }
// }
// }
// }
// AtomicReducer(&output[x], partial_result);
// for (x_offset : range(kTileWidth)) {
// AtomicReducer(&output[x_in_tiles * kTileWidth + x_offset], partial_results[x_offset]);
// }
// }
auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
const int num_reduces = reducers.size();
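In the emitter body that follows, the per-reducer, per-tile-column partial results live in a single flat vector addressed as i * kTileWidth + x_offset. A short sketch of that convention (the helper name is hypothetical):

```cpp
#include <cstddef>

constexpr int kTileWidth = 2;

// partial_reduction_result_addresses is laid out so that reducer i's slot for
// tile column x_offset sits at i * kTileWidth + x_offset, i.e. the kTileWidth
// slots of each reducer are contiguous.
constexpr size_t PartialResultSlot(int reducer_index, int x_offset) {
  return static_cast<size_t>(reducer_index) * kTileWidth + x_offset;
}

// With 3 reducers the order is r0/x0, r0/x1, r1/x0, r1/x1, r2/x0, r2/x1.
static_assert(PartialResultSlot(0, 1) == 1, "");
static_assert(PartialResultSlot(2, 0) == 4, "");
```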
@@ -1045,41 +1063,48 @@ Status IrEmitterUnnested::EmitColumnReduction(
llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
std::vector<llvm::Value*> partial_reduction_result_addresses;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." + llvm::Twine(i));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* partial_reduction_result_address =
ir_builder_.CreateAlloca(
element_ir_type, /*ArraySize=*/nullptr,
"partial_reduction_result." +
llvm::Twine(i * kTileWidth + x_offset));
TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
init_value_gens[i](IrArray::Index(index_ty)));
ir_builder_.CreateStore(init_ir_value,
partial_reduction_result_address);
partial_reduction_result_addresses.push_back(
partial_reduction_result_address);
}
}
// Emit an inner for-loop that partially reduces the elements in the given
// tile.
llvm::Value* y_in_tiles = tile_index[0];
llvm::Value* x = tile_index[1];
llvm::Value* x_in_tiles = tile_index[1];
y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty);
x = ir_builder_.CreateZExtOrTrunc(x, index_ty);
x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);
auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
bool tile_in_x_bounds) -> Status {
std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
index_typed_constant(0),
index_typed_constant(kTileSize),
index_typed_constant(kTileHeight),
index_typed_constant(1), &ir_builder_);
// Emit the body of the partial reduction loop.
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
&ir_builder_);
llvm::Value* y = ir_builder_.CreateNSWAdd(
ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)),
ir_builder_.CreateNSWMul(y_in_tiles,
index_typed_constant(kTileHeight)),
tile_element_loop->GetIndVarValue());
// Unless we know the tile is entirely in bounds, we have to emit a
// y-in-bounds check before reading from the input.
if (!tile_in_bounds) {
// Unless we know that y is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_y_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
"y_in_bounds", &ir_builder_);
@@ -1088,8 +1113,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
// the partial reduction result.
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
}
llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
{
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* x = ir_builder_.CreateNSWAdd(
ir_builder_.CreateNSWMul(x_in_tiles,
index_typed_constant(kTileWidth)),
index_typed_constant(x_offset));
// Unless we know that x is in bounds, we have to emit a check before
// reading from the input.
if (!tile_in_x_bounds) {
llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
ir_builder_.CreateICmpULT(x, index_typed_constant(width)),
"x_in_bounds", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
}
llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
// {y,x} is an index to input_matrix_shape [height,width]. We need to
// convert that to an index to input_shape (the shape of the operand of
// "reduce"). This conversion is composed of a transposition from
@@ -1120,51 +1157,90 @@ Status IrEmitterUnnested::EmitColumnReduction(
ir_builder_.CreateStore(input_ir_value, input_address);
TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
*reducers[i],
{partial_reduction_result_addresses[i], input_address},
partial_reduction_result_addresses[i]));
{partial_reduction_result_addresses[i * kTileWidth + x_offset],
input_address},
partial_reduction_result_addresses[i * kTileWidth + x_offset]));
TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index,
extra_output_gens));
}
return EmitExtraOutputsForReduce(reduce, input_index,
extra_output_gens);
}
return Status::OK();
};
// y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
// immediately beyond the tile.
// y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
// that's immediately beyond the tile.
llvm::Value* y_end = ir_builder_.CreateNSWAdd(
index_typed_constant(kTileSize),
ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)));
llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
index_typed_constant(kTileHeight),
ir_builder_.CreateNSWMul(y_in_tiles,
index_typed_constant(kTileHeight)));
// x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
// that's immediately beyond the tile.
llvm::Value* x_end = ir_builder_.CreateNSWAdd(
index_typed_constant(kTileWidth),
ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
llvm::Value* tile_in_y_bounds = ir_builder_.CreateOr(
ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
ir_builder_.getInt1(height % kTileSize == 0));
// The tile is entirely in bound if "height" is a multiple of kTileSize or
ir_builder_.getInt1(height % kTileHeight == 0));
llvm::Value* tile_in_x_bounds = ir_builder_.CreateOr(
ir_builder_.CreateICmpULE(x_end, index_typed_constant(width)),
ir_builder_.getInt1(width % kTileWidth == 0));
// The tile is in y bounds if "height" is a multiple of kTileHeight or
// y_end <= height.
llvm_ir::LlvmIfData if_tile_in_bounds_data =
llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
llvm_ir::LlvmIfData if_tile_in_y_bounds_data = llvm_ir::EmitIfThenElse(
tile_in_y_bounds, "tile_in_y_bounds", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block,
&ir_builder_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
// The tile is in x bounds if "width" is a multiple of kTileWidth or
// x_end <= width.
llvm_ir::LlvmIfData if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
&ir_builder_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
/*tile_in_x_bounds=*/true));
llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
&ir_builder_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
/*tile_in_x_bounds=*/false));
llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block,
&ir_builder_);
if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
&ir_builder_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
/*tile_in_x_bounds=*/true));
llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
&ir_builder_);
TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
/*tile_in_x_bounds=*/false));
// After the if-then-else statement on tile_in_bounds, emit atomic
// operations to accumulate the partial reduction result to the output
// element.
llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
// After the nested if-then-else statement on tile_in_y_bounds and
// tile_in_x_bounds, emit atomic operations to accumulate the partial
// reduction result to the output element.
llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block,
&ir_builder_);
const HloInstruction* output =
reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
for (int i = 0; i != num_reduces; ++i) {
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
IrArray::Index(x,
ShapeUtil::GetSubshape(
output->shape(), reduce_output_shapes[i]),
&ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address, partial_reduction_result_addresses[i]));
for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
llvm::Value* x = ir_builder_.CreateNSWAdd(
ir_builder_.CreateNSWMul(x_in_tiles,
index_typed_constant(kTileWidth)),
index_typed_constant(x_offset));
llvm::Value* output_address =
GetIrArray(*output, *output, reduce_output_shapes[i])
.EmitArrayElementAddress(
IrArray::Index(
x,
ShapeUtil::GetSubshape(output->shape(),
reduce_output_shapes[i]),
&ir_builder_),
&ir_builder_, "output_element_address");
TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
*reducers[i], output_address,
partial_reduction_result_addresses[i * kTileWidth + x_offset]));
}
}
return Status::OK();
};
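The loop at the end accumulates each tile column's partial result into the output with an atomic operation (via EmitAtomicOperationForNestedComputation), because every tile in the same column of tiles targets the same output element. For intuition only, a host-side analogue of such a floating-point atomic add implemented with a compare-and-swap loop; the emitter generates the device-side equivalent, and nothing below is XLA code:

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// Illustrative CAS loop for atomically adding a float into a 32-bit cell,
// analogous to what the generated kernel does for each output column.
void AtomicAddFloat(std::atomic<uint32_t>* cell, float value) {
  uint32_t observed = cell->load();
  uint32_t desired;
  do {
    float current;
    std::memcpy(&current, &observed, sizeof(current));
    const float updated = current + value;
    std::memcpy(&desired, &updated, sizeof(desired));
    // compare_exchange_weak reloads `observed` on failure, so the loop
    // retries with the freshly read value.
  } while (!cell->compare_exchange_weak(observed, desired));
}
```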