From d9d029f510dbbc92329bafcd6bf2fbd0d273a675 Mon Sep 17 00:00:00 2001
From: Thomas Joerg
Date: Tue, 17 Jul 2018 23:50:05 -0700
Subject: [PATCH] [XLA:GPU] Generalize the column reduction algorithm to
 handle tile widths greater than 1.

Tiles of width 1 result in poor memory bandwidth for 16b inputs.

PiperOrigin-RevId: 205033124
---
 .../xla/service/gpu/ir_emitter_unnested.cc | 218 ++++++++++++------
 1 file changed, 147 insertions(+), 71 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 75bbbbe8efc..f2597da4b9d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -984,8 +984,8 @@ Status IrEmitterUnnested::EmitColumnReduction(
     tensorflow::gtl::ArraySlice<
         std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
         extra_output_gens) {
-  // Divide the input matrix into tiles of size Kx1. For example, when the
-  // input matrix is 4x4 and K=2, the tiled matrix looks like
+  // Divide the input matrix into tiles of size KxL. For example, when the
+  // input matrix is 4x4, K=2, and L=1, the tiled matrix looks like
   //
   //   0123
   //   0123
@@ -997,14 +997,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
   //
   // We choose 128 as the tile size based on empirical evidence. It's big enough
   // to reduce the amount of atomic adds in the end, maximizing the memory
-  // bandwidth.
-  constexpr int64 kTileSize = 128;
+  // bandwidth. A tile width of 2 allows for high memory bandwidth utilization
+  // on 16b input data.
+  constexpr int64 kTileHeight = 128;
+  constexpr int64 kTileWidth = 2;
 
-  // If the height is not a multiple of the tile size, we pad the bottom of the
+  // If the height is not a multiple of kTileHeight, we pad the bottom of the
   // input matrix.
-  const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
-  Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
-      reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
+  const int64 height_in_tiles = CeilOfRatio(height, kTileHeight);
+  // If width is not a multiple of kTileWidth, the rightmost thread will
+  // process fewer input elements.
+  const int64 width_in_tiles = CeilOfRatio(width, kTileWidth);
+  Shape tiled_input_shape =
+      ShapeUtil::MakeShapeWithLayout(reduce->shape().element_type(),
+                                     {height_in_tiles, width_in_tiles}, {1, 0});
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
       tiled_input_shape, ir_emitter_context_->device_description());
 
@@ -1016,27 +1022,39 @@ Status IrEmitterUnnested::EmitColumnReduction(
   };
 
   // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
-  //      linear_index < height_in_tiles * width;
+  //      linear_index < height_in_tiles * width_in_tiles;
   //      linear_index += blockDim.x * gridDim.x) {
-  //   y_in_tiles = linear_index / width;
-  //   x = linear_index % width;
+  //   y_in_tiles = linear_index / width_in_tiles;
+  //   x_in_tiles = linear_index % width_in_tiles;
   //
-  //   partial_result = init_value;
-  //   if (height % kTileSize == 0 ||
-  //       y_in_tiles * kTileSize + kTileSize <= height) {
-  //     for (element_id_in_tile : range(kTileSize)) {
-  //       y = y_in_tiles * kTileSize + element_id_in_tile;
-  //       partial_result = Reducer(partial_result, input[y][x]);
+  //   partial_results[kTileWidth] = init_values;
+  //   tile_in_y_bounds = height % kTileHeight == 0 ||
+  //       y_in_tiles * kTileHeight + kTileHeight <= height;
+  //   tile_in_x_bounds = width % kTileWidth == 0 ||
+  //       x_in_tiles * kTileWidth + kTileWidth <= width;
+  //   // The implementation handles y and x bound checks separately.
+  //   if (tile_in_y_bounds && tile_in_x_bounds) {
+  //     for (y_offset : range(kTileHeight)) {
+  //       y = y_in_tiles * kTileHeight + y_offset;
+  //       for (x_offset : range(kTileWidth)) {
+  //         x = x_in_tiles * kTileWidth + x_offset;
+  //         partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
+  //       }
   //     }
   //   } else {
-  //     for (element_id_in_tile : range(kTileSize)) {
-  //       y = y_in_tiles * kTileSize + element_id_in_tile;
-  //       if (y < height) {
-  //         partial_result = Reducer(partial_result, input[y][x]);
+  //     for (y_offset : range(kTileHeight)) {
+  //       y = y_in_tiles * kTileHeight + y_offset;
+  //       for (x_offset : range(kTileWidth)) {
+  //         x = x_in_tiles * kTileWidth + x_offset;
+  //         if (y < height && x < width) {
+  //           partial_results[x_offset] = Reducer(partial_results[x_offset], input[y][x]);
+  //         }
   //       }
   //     }
   //   }
-  //   AtomicReducer(&output[x], partial_result);
+  //   for (x_offset : range(kTileWidth)) {
+  //     AtomicReducer(&output[x_in_tiles * kTileWidth + x_offset], partial_results[x_offset]);
+  //   }
   // }
   auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status {
     const int num_reduces = reducers.size();
@@ -1045,41 +1063,48 @@ Status IrEmitterUnnested::EmitColumnReduction(
         llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_);
     std::vector<llvm::Value*> partial_reduction_result_addresses;
     for (int i = 0; i != num_reduces; ++i) {
-      llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
-          element_ir_type, /*ArraySize=*/nullptr,
-          "partial_reduction_result." + llvm::Twine(i));
-      TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
-                          init_value_gens[i](IrArray::Index(index_ty)));
-      ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
-      partial_reduction_result_addresses.push_back(
-          partial_reduction_result_address);
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* partial_reduction_result_address =
+            ir_builder_.CreateAlloca(
+                element_ir_type, /*ArraySize=*/nullptr,
+                "partial_reduction_result." +
+                    llvm::Twine(i * kTileWidth + x_offset));
+        TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value,
+                            init_value_gens[i](IrArray::Index(index_ty)));
+        ir_builder_.CreateStore(init_ir_value,
+                                partial_reduction_result_address);
+        partial_reduction_result_addresses.push_back(
+            partial_reduction_result_address);
+      }
     }
 
     // Emit an inner for-loop that partially reduces the elements in the given
     // tile.
     llvm::Value* y_in_tiles = tile_index[0];
-    llvm::Value* x = tile_index[1];
+    llvm::Value* x_in_tiles = tile_index[1];
     y_in_tiles = ir_builder_.CreateZExtOrTrunc(y_in_tiles, index_ty);
-    x = ir_builder_.CreateZExtOrTrunc(x, index_ty);
+    x_in_tiles = ir_builder_.CreateZExtOrTrunc(x_in_tiles, index_ty);
 
-    auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
+    auto emit_tile_element_loop = [=](bool tile_in_y_bounds,
+                                      bool tile_in_x_bounds) -> Status {
      std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
          llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
                                        index_typed_constant(0),
-                                        index_typed_constant(kTileSize),
+                                        index_typed_constant(kTileHeight),
                                        index_typed_constant(1), &ir_builder_);
 
      // Emit the body of the partial reduction loop.
      llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
                                     &ir_builder_);
      llvm::Value* y = ir_builder_.CreateNSWAdd(
-          ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)),
+          ir_builder_.CreateNSWMul(y_in_tiles,
+                                   index_typed_constant(kTileHeight)),
          tile_element_loop->GetIndVarValue());
 
-      // Unless we know the tile is entirely in bounds, we have to emit a
-      // y-in-bounds check before reading from the input.
-      if (!tile_in_bounds) {
+      // Unless we know that y is in bounds, we have to emit a check before
+      // reading from the input.
+      if (!tile_in_y_bounds) {
        llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
            ir_builder_.CreateICmpULT(y, index_typed_constant(height)),
            "y_in_bounds", &ir_builder_);
@@ -1088,8 +1113,20 @@ Status IrEmitterUnnested::EmitColumnReduction(
      // the partial reduction result.
      llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
    }
-      llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
-      {
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* x = ir_builder_.CreateNSWAdd(
+            ir_builder_.CreateNSWMul(x_in_tiles,
+                                     index_typed_constant(kTileWidth)),
+            index_typed_constant(x_offset));
+        // Unless we know that x is in bounds, we have to emit a check before
+        // reading from the input.
+        if (!tile_in_x_bounds) {
+          llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+              ir_builder_.CreateICmpULT(x, index_typed_constant(width)),
+              "x_in_bounds", &ir_builder_);
+          llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+        }
+        llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
        // {y,x} is an index to input_matrix_shape [height,width]. We need to
        // convert that to an index to input_shape (the shape of the operand of
        // "reduce"). This conversion is composed of a transposition from
@@ -1120,51 +1157,90 @@ Status IrEmitterUnnested::EmitColumnReduction(
          ir_builder_.CreateStore(input_ir_value, input_address);
          TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
              *reducers[i],
-              {partial_reduction_result_addresses[i], input_address},
-              partial_reduction_result_addresses[i]));
+              {partial_reduction_result_addresses[i * kTileWidth + x_offset],
+               input_address},
+              partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+          TF_RETURN_IF_ERROR(EmitExtraOutputsForReduce(reduce, input_index,
+                                                       extra_output_gens));
        }
-        return EmitExtraOutputsForReduce(reduce, input_index,
-                                         extra_output_gens);
      }
+      return Status::OK();
    };
 
-    // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
-    // immediately beyond the tile.
+    // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location
+    // that's immediately beyond the tile.
    llvm::Value* y_end = ir_builder_.CreateNSWAdd(
-        index_typed_constant(kTileSize),
-        ir_builder_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileSize)));
-    llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
+        index_typed_constant(kTileHeight),
+        ir_builder_.CreateNSWMul(y_in_tiles,
+                                 index_typed_constant(kTileHeight)));
+    // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location
+    // that's immediately beyond the tile.
+    llvm::Value* x_end = ir_builder_.CreateNSWAdd(
+        index_typed_constant(kTileWidth),
+        ir_builder_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)));
+    llvm::Value* tile_in_y_bounds = ir_builder_.CreateOr(
        ir_builder_.CreateICmpULE(y_end, index_typed_constant(height)),
-        ir_builder_.getInt1(height % kTileSize == 0));
+        ir_builder_.getInt1(height % kTileHeight == 0));
+    llvm::Value* tile_in_x_bounds = ir_builder_.CreateOr(
+        ir_builder_.CreateICmpULE(x_end, index_typed_constant(width)),
+        ir_builder_.getInt1(width % kTileWidth == 0));
-    // The tile is entirely in bound if "height" is a multiple of kTileSize or
+    // The tile is in y bounds if "height" is a multiple of kTileHeight or
    // y_end <= height.
-    llvm_ir::LlvmIfData if_tile_in_bounds_data =
-        llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
+    llvm_ir::LlvmIfData if_tile_in_y_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_y_bounds, "tile_in_y_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.true_block,
                                   &ir_builder_);
-    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
+    // The tile is in x bounds if "width" is a multiple of kTileWidth or
+    // x_end <= width.
+    llvm_ir::LlvmIfData if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
                                   &ir_builder_);
-    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+                                              /*tile_in_x_bounds=*/true));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/true,
+                                              /*tile_in_x_bounds=*/false));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.false_block,
+                                   &ir_builder_);
+    if_tile_in_x_bounds_data = llvm_ir::EmitIfThenElse(
+        tile_in_x_bounds, "tile_in_x_bounds", &ir_builder_);
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.true_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+                                              /*tile_in_x_bounds=*/true));
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_x_bounds_data.false_block,
+                                   &ir_builder_);
+    TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_y_bounds=*/false,
+                                              /*tile_in_x_bounds=*/false));
 
-    // After the if-then-else statement on tile_in_bounds, emit atomic
-    // operations to accumulate the partial reduction result to the output
-    // element.
-    llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
+    // After the nested if-then-else statement on tile_in_y_bounds and
+    // tile_in_x_bounds, emit atomic operations to accumulate the partial
+    // reduction result to the output element.
+    llvm_ir::SetToFirstInsertPoint(if_tile_in_y_bounds_data.after_block,
                                   &ir_builder_);
    const HloInstruction* output =
        reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce;
    for (int i = 0; i != num_reduces; ++i) {
-      llvm::Value* output_address =
-          GetIrArray(*output, *output, reduce_output_shapes[i])
-              .EmitArrayElementAddress(
-                  IrArray::Index(x,
-                                 ShapeUtil::GetSubshape(
-                                     output->shape(), reduce_output_shapes[i]),
-                                 &ir_builder_),
-                  &ir_builder_, "output_element_address");
-      TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
-          *reducers[i], output_address, partial_reduction_result_addresses[i]));
+      for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) {
+        llvm::Value* x = ir_builder_.CreateNSWAdd(
+            ir_builder_.CreateNSWMul(x_in_tiles,
+                                     index_typed_constant(kTileWidth)),
+            index_typed_constant(x_offset));
+        llvm::Value* output_address =
+            GetIrArray(*output, *output, reduce_output_shapes[i])
+                .EmitArrayElementAddress(
+                    IrArray::Index(
+                        x,
+                        ShapeUtil::GetSubshape(output->shape(),
+                                               reduce_output_shapes[i]),
+                        &ir_builder_),
+                    &ir_builder_, "output_element_address");
+        TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation(
+            *reducers[i], output_address,
+            partial_reduction_result_addresses[i * kTileWidth + x_offset]));
+      }
    }
    return Status::OK();
  };
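
For readers who want the new control structure in one place, here is a minimal, self-contained C++ sketch of the tiled column reduction that the emitted IR implements. This is not XLA code: ColumnReduceTiled, the float element type, and the sum reducer are illustrative assumptions, the kernel's grid-stride loop is modeled as an ordinary sequential loop over linear_index, and the four bodies the emitter specializes on (tile_in_y_bounds, tile_in_x_bounds) are collapsed into a single guarded body for brevity.

#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int64_t kTileHeight = 128;
constexpr int64_t kTileWidth = 2;

// Reduces each column of a row-major height x width matrix into output[width],
// visiting the input in kTileHeight x kTileWidth tiles as described in the
// pseudocode comment of the patch.
void ColumnReduceTiled(const std::vector<float>& input, int64_t height,
                       int64_t width, std::vector<float>* output) {
  const int64_t height_in_tiles = (height + kTileHeight - 1) / kTileHeight;
  const int64_t width_in_tiles = (width + kTileWidth - 1) / kTileWidth;
  // One iteration of this loop corresponds to one GPU thread of the kernel.
  for (int64_t linear_index = 0;
       linear_index < height_in_tiles * width_in_tiles; ++linear_index) {
    const int64_t y_in_tiles = linear_index / width_in_tiles;
    const int64_t x_in_tiles = linear_index % width_in_tiles;
    float partial_results[kTileWidth] = {};  // init_values for a sum reducer.
    const bool tile_in_y_bounds =
        height % kTileHeight == 0 ||
        y_in_tiles * kTileHeight + kTileHeight <= height;
    const bool tile_in_x_bounds =
        width % kTileWidth == 0 ||
        x_in_tiles * kTileWidth + kTileWidth <= width;
    for (int64_t y_offset = 0; y_offset < kTileHeight; ++y_offset) {
      const int64_t y = y_in_tiles * kTileHeight + y_offset;
      for (int64_t x_offset = 0; x_offset < kTileWidth; ++x_offset) {
        const int64_t x = x_in_tiles * kTileWidth + x_offset;
        // The emitter specializes this body so fully-in-bounds tiles carry no
        // checks; here both cases share one guarded body.
        if ((tile_in_y_bounds || y < height) &&
            (tile_in_x_bounds || x < width)) {
          partial_results[x_offset] += input[y * width + x];
        }
      }
    }
    for (int64_t x_offset = 0; x_offset < kTileWidth; ++x_offset) {
      const int64_t x = x_in_tiles * kTileWidth + x_offset;
      if (x < width) {
        // On the GPU this is the AtomicReducer step into the output buffer.
        (*output)[x] += partial_results[x_offset];
      }
    }
  }
}

int main() {
  const int64_t height = 300, width = 5;
  std::vector<float> input(height * width, 1.0f);
  std::vector<float> output(width, 0.0f);
  ColumnReduceTiled(input, height, width, &output);
  std::printf("column sum = %.0f (expected %lld)\n", output[0],
              static_cast<long long>(height));
  return 0;
}

The reason the patch emits four specialized loop bodies instead of the single guarded body above is that the common, fully-in-bounds tiles then run without any per-element comparisons; only tiles on the bottom and right edges of the matrix pay for the checks.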
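
As a worked example of the new bound-check arithmetic, the snippet below enumerates the tiles of a hypothetical 300x5 reduction (the shape is made up purely for illustration). With kTileHeight = 128 and kTileWidth = 2 there are CeilOfRatio(300, 128) = 3 row tiles and CeilOfRatio(5, 2) = 3 column tiles; only the bottom row of tiles fails tile_in_y_bounds and only the rightmost column of tiles fails tile_in_x_bounds, so only those edge tiles take the checked code paths.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t kTileHeight = 128, kTileWidth = 2;
  constexpr int64_t height = 300, width = 5;  // hypothetical input shape
  const int64_t height_in_tiles = (height + kTileHeight - 1) / kTileHeight;  // 3
  const int64_t width_in_tiles = (width + kTileWidth - 1) / kTileWidth;      // 3
  for (int64_t y_in_tiles = 0; y_in_tiles < height_in_tiles; ++y_in_tiles) {
    for (int64_t x_in_tiles = 0; x_in_tiles < width_in_tiles; ++x_in_tiles) {
      const bool tile_in_y_bounds =
          height % kTileHeight == 0 ||
          y_in_tiles * kTileHeight + kTileHeight <= height;
      const bool tile_in_x_bounds =
          width % kTileWidth == 0 ||
          x_in_tiles * kTileWidth + kTileWidth <= width;
      std::printf("tile (%lld, %lld): tile_in_y_bounds=%d tile_in_x_bounds=%d\n",
                  static_cast<long long>(y_in_tiles),
                  static_cast<long long>(x_in_tiles),
                  tile_in_y_bounds ? 1 : 0, tile_in_x_bounds ? 1 : 0);
    }
  }
  return 0;
}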