XLA Parallel reduce.

Extend the XLA GPU codegen to generate parallel reductions when a fusion
computation contains multiple reduce instructions.

We see a ~3% end-to-end gain for NVIDIA JoC BERT.

For the `ManyParallelReductions` unit test with 128 reduce instructions, the
execution time of the fusion drops from 325us to 3.9us (~83x), as reported by nvprof below.

Before:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name

                   32.50%  325.54us         1  325.54us  325.54us  325.54us  fusion

After:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name

                    0.59%  3.9030us         1  3.9030us  3.9030us  3.9030us  fusion
Author: Trent Lo
Date:   2020-08-25 12:17:10 -07:00
Parent: d4617757bb
Commit: 8daab75490

22 changed files with 484 additions and 97 deletions


@@ -604,7 +604,6 @@ cc_library(
":flags",
":resource_operation_safety_analysis",
":shape_inference_helpers",
":union_find",
":xla_activity_listener",
":xla_cluster_util",
"//tensorflow/cc:cc_ops",
@@ -623,6 +622,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/cc:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:all_kernels",
"//tensorflow/core:core_cpu",
@@ -701,11 +701,6 @@ tf_cc_test(
],
)
cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
tf_cc_test(
name = "deadness_analysis_test",
size = "small",
@@ -886,7 +881,6 @@ cc_library(
":device_util",
":flags",
":resource_operation_safety_analysis",
":union_find",
":xla_activity_listener",
":xla_activity_proto_cc",
":xla_cluster_util",
@@ -895,6 +889,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",


@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
@@ -44,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"


@@ -26,11 +26,11 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"


@@ -32,12 +32,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"


@@ -818,9 +818,9 @@ cc_library(
":frontend_attributes_util",
":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -846,9 +846,9 @@ cc_library(
":functionalize_control_flow_util",
":functionalize_while",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
@@ -934,9 +934,9 @@ cc_library(
":functionalize_cond",
":functionalize_control_flow_util",
":tf2xla_util",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",


@@ -25,9 +25,10 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"


@@ -23,12 +23,12 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_while.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"


@@ -24,11 +24,11 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"


@@ -969,6 +969,11 @@ tf_cc_test(
],
)
cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
# -----------------------------------------------------------------------------
# This is a headers target that extra XLA devices can use to prevent circular
# dependencies. Devices that are compiled as separate shared objects can also
# use it to prevent linking of library code.


@@ -265,6 +265,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:union_find",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -1493,6 +1494,7 @@ cc_library(
hdrs = ["stream_executor_util.h"],
copts = tf_copts(),
deps = [
":launch_dimensions",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",


@@ -613,10 +613,13 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream,
LaunchDimensions dim =
CalculateLaunchDimensions(buffer_shape, gpu_device_info);
stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()),
se::BlockDim(dim.block_count()), *comparison_kernel,
lhs_typed, rhs_typed, static_cast<float>(kTolerance),
buffer_size, out_param.cref());
auto thread_counts = dim.thread_counts_per_block();
auto block_counts = dim.block_counts();
stream->ThenLaunch(
se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
se::BlockDim(block_counts.x, block_counts.y, block_counts.z),
*comparison_kernel, lhs_typed, rhs_typed, static_cast<float>(kTolerance),
buffer_size, out_param.cref());
uint64 result = -1;
CHECK_EQ(out_param->size(), sizeof(result));


@@ -90,6 +90,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -141,7 +142,7 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
llvm::LLVMContext& llvm_context = llvm_module->getContext();
llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
launch_dims.threads_per_block());
launch_dims.thread_counts_per_block().x);
// Our launch bounds are exact, so we can specify them as reqntidx rather than
// maxntidx.
nvvm_annotations_node->addOperand(llvm::MDNode::get(
@@ -2991,6 +2992,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId(
});
}
namespace {
// Obtains the corresponding index of the out_instr in the outputs of the
// `unnested_hlo`.
ShapeIndex CreateShapeIndexForOutputInstruction(
const HloInstruction& unnested_hlo, const HloInstruction& out_instr) {
if (!unnested_hlo.IsMultiOutputFusion()) {
return ShapeIndex({});
}
const auto& all_outputs = unnested_hlo.fused_expression_root()->operands();
for (size_t i = 0; i < all_outputs.size(); ++i) {
if (all_outputs[i] == &out_instr) {
return ShapeIndex({i});
}
}
CHECK(false) << " Fusion root does not contain output instruction; "
<< " fusion: " << unnested_hlo.ToString()
<< ", output instruction: " << out_instr.ToString();
}
} // namespace
void IrEmitterUnnested::EmitTileElementForReduction(
HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
absl::Span<HloInstruction* const> output_instructions,
@@ -2998,7 +3021,6 @@ void IrEmitterUnnested::EmitTileElementForReduction(
const ReductionCodegenInfo& reduction_info,
absl::Span<HloComputation* const> reducers, int64 x_iter_num) {
VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
bool returns_tuple = output_instructions.size() > 1;
int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;
InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
@@ -3015,7 +3037,8 @@ void IrEmitterUnnested::EmitTileElementForReduction(
for (int i = 0, e = output_instructions.size(); i != e; ++i) {
const HloInstruction* inst = output_instructions[i];
ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
ShapeIndex idx =
CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst);
if (IsReductionFromOrToContiguousDimensions(*inst)) {
input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
} else {
@@ -3748,16 +3771,131 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
reduction_dimensions.is_row_reduction);
}
Status IrEmitterUnnested::EmitIRForReduction(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions,
ReductionCodegenInfo* reduction_info, const Shape& input_shape) {
std::vector<HloInstruction*> reduce_instructions;
InlinedVector<ShapeIndex, 1> reduction_output_shape_indices;
InlinedVector<HloComputation*, 1> reducers;
for (size_t i = 0; i < output_instructions.size(); ++i) {
if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
continue;
}
HloInstruction* output_instruction = output_instructions[i];
reduce_instructions.push_back(output_instruction);
reduction_output_shape_indices.push_back(
CreateShapeIndexForOutputInstruction(*unnested_hlo,
*output_instruction));
reducers.push_back(output_instruction->to_apply());
}
CHECK(reduce_instructions.size() != 0)
<< " expect at least one reduce instruction.";
const KernelMappingScheme& mapping_scheme =
reduction_info->GetKernelMappingScheme();
LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
mapping_scheme.GetThreadsPerBlock());
llvm::Type* index_ty = GetIndexTypeForKernel(
unnested_hlo, launch_dimensions.launch_bound(), &b_);
EmitPrologueForReduction(unnested_hlo, reduction_info, reduce_instructions,
index_ty);
EmitElementFunction emit_reduction_tile = [&](
const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 x_iter_num) {
EmitTileElementForReduction(unnested_hlo, input_shape, output_instructions,
index, *reduction_info, reducers, x_iter_num);
};
TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
mapping_scheme, index_ty,
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name,
ksl, thread_id_info, tile_height, tile_width,
emit_reduction_tile);
});
EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info,
reduce_instructions, reduction_output_shape_indices,
reducers, tiling_kernel_info);
return Status::OK();
}
namespace {
// Returns whether the `instr` is either a constant, a scalar, or a
// broadcasted constant/scalar.
bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
(HloOpcode::kBroadcast == instr.opcode() &&
(instr.operand(0)->IsConstant() ||
ShapeUtil::IsScalar(instr.operand(0)->shape())));
}
// Divides output_instructions into groups. Generally, we'd like to group output
// instructions sharing the same predecessors to avoid recomputation. Different
// groups will be executed in parallel.
std::vector<std::vector<HloInstruction*>> DivideOutputInstructionsIntoGroups(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions) {
CHECK(!output_instructions.empty());
if (output_instructions.size() == 1) {
return {{output_instructions[0]}};
}
std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets(
output_instructions.size());
for (size_t i = 0; i < output_instructions.size(); ++i) {
disjoint_sets[i].Get() = output_instructions[i];
}
std::unique_ptr<HloReachabilityMap> reachability_map =
HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation());
for (auto* instr : unnested_hlo->fused_instructions()) {
std::vector<int64> reached_output_ids;
for (size_t oid = 0; oid < output_instructions.size(); ++oid) {
if (HloOpcode::kReduce == output_instructions[oid]->opcode() &&
(IsBroadcastedConstantOrScalar(*instr))) {
// Do not group output reduce instructions through broadcasted
// constants or scalars, as the recomputation should be acceptable.
VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString();
continue;
}
// Now group output instructions if they have common predecessors.
if (reachability_map->IsReachable(instr, output_instructions[oid])) {
VLOG(3) << "Reaching " << output_instructions[oid]->ToString()
<< " from " << instr->ToString();
reached_output_ids.push_back(oid);
}
}
for (size_t j = 1; j < reached_output_ids.size(); ++j) {
disjoint_sets[reached_output_ids[0]].Merge(
&disjoint_sets[reached_output_ids[j]]);
}
}
// Place output instructions in the same set into the same group.
absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>> groups;
for (size_t oid = 0; oid < output_instructions.size(); ++oid) {
groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid));
}
std::vector<std::vector<HloInstruction*>> ret;
absl::c_for_each(
groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
return ret;
}
} // namespace
Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions) {
bool returns_tuple = output_instructions.size() > 1;
VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();
std::vector<HloInstruction*> reduce_instructions;
InlinedVector<ShapeIndex, 1> reduction_output_shape_indices;
InlinedVector<HloComputation*, 1> reducers;
// Build an initializer thunk to initialize each reduction output.
std::vector<std::unique_ptr<Thunk>> thunks;
for (int i = 0; i < output_instructions.size(); ++i) {
@@ -3765,29 +3903,27 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
continue;
}
HloInstruction* output_instruction = output_instructions[i];
reduce_instructions.push_back(output_instruction);
ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
reduction_output_shape_indices.push_back(idx);
reducers.push_back(output_instruction->to_apply());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
BuildInitializerThunk(unnested_hlo, idx));
thunks.push_back(std::move(initializer_thunk));
}
const HloInstruction* first_reduce = reduce_instructions.at(0);
// Build a kernel thunk to compute all the outputs.
const HloInstruction* first_reduce = nullptr;
for (int i = 0; i < output_instructions.size(); ++i) {
if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
first_reduce = output_instructions[i];
break;
}
}
CHECK(first_reduce);
if (output_instructions.size() > 1) {
if (!AreFusedReductionOutputsConsistent(output_instructions,
first_reduce)) {
return InternalError("Inconsistent reduction fusion outputs");
}
}
// Build a kernel thunk to compute all the outputs.
std::unique_ptr<KernelThunk> kernel_thunk =
BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
const Shape& input_shape = first_reduce->operand(0)->shape();
// The layout of a reduction input is either set by LayoutAssignment for
// unnested kReduce or by InstructionFusion for fused kReduce.
@@ -3795,39 +3931,51 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
"doesn't set the input layout of "
<< first_reduce->ToString();
// Group output instructions. Each group will be executed in parallel.
auto instr_groups =
DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions);
VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ",
unnested_hlo->ToString());
auto kernel_thunk =
BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
for (size_t i = 0; i < instr_groups.size(); ++i) {
// Create a new ReductionCodegenInfo instance as it contains states for
// code generation per reduction group. For now, let's always use the very
// first reduce as representative to construct ReductionCodegenInfo, since
// all the reductions are required to have the same shape and layout as
// verified by AreFusedReductionOutputsConsistent(). We can loosen the
// constraint later when the need arises.
ReductionCodegenInfo reduction_info =
ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
auto emit_reduction_func = [&] {
EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info,
input_shape);
};
// Use raw block_id_y to select the i-th parallel reduction to run. Using
// block_id_y instead of block_id_x simplifies the index calculation
// for reduction code generation as the block_id_y is orthogonal to
// the indices used within the reductions.
llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
llvm_ir::AddRangeMetadata(0, instr_groups.size(),
llvm::cast<llvm::Instruction>(raw_block_id_y));
auto guarding_cond = b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i));
ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func);
}
ReductionCodegenInfo reduction_info =
ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
const KernelMappingScheme& mapping_scheme =
reduction_info.GetKernelMappingScheme();
LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
mapping_scheme.GetThreadsPerBlock());
// block_y_count is set to instr_groups.size(), so that each reduction group
// can be run in parallel by a different BlockIdy.
LaunchDimensions launch_dimensions(
{/*x=*/mapping_scheme.GetNumberOfBlocks(), /*y=*/instr_groups.size(),
/*z=*/1},
{/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1});
VLOG(3) << "Launch dimensions of " << unnested_hlo->name()
<< ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
<< " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
llvm::Type* index_ty = GetIndexTypeForKernel(
unnested_hlo, launch_dimensions.launch_bound(), &b_);
EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions,
index_ty);
EmitElementFunction emit_reduction_tile =
[&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 x_iter_num) {
EmitTileElementForReduction(unnested_hlo, input_shape,
output_instructions, index, reduction_info,
reducers, x_iter_num);
};
TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
mapping_scheme, index_ty,
[&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl) {
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
thread_id_info, tile_height, tile_width, emit_reduction_tile);
});
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
reduce_instructions, reduction_output_shape_indices,
reducers, tiling_kernel_info);
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
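
For orientation, the control structure that the updated emitter produces can be sketched as ordinary C++. This sketch is not part of the diff and is not the emitted LLVM IR; the function name and the two-group shape are illustrative only:

    // Sketch of the kernel body generated for a fusion whose outputs were
    // divided into two reduction groups. block_id_y stands for the value of
    // the kBlockIdy target intrinsic; the range metadata added above tells
    // LLVM it lies in [0, instr_groups.size()).
    void EmittedFusionKernelSketch(int block_id_y) {
      if (block_id_y == 0) {
        // "reduce-group-0": prologue, tiling loops, and epilogue emitted by
        // EmitIRForReduction for instr_groups[0].
      }
      if (block_id_y == 1) {
        // "reduce-group-1": the same code path, specialized for
        // instr_groups[1].
      }
      // Thread blocks with different block_id_y values run concurrently, so
      // the groups execute in parallel; the launch grid uses
      // y = instr_groups.size() blocks.
    }

These guarded regions are what the `reduce-group-0` / `reduce-group-1` FileCheck patterns in parallel_reduction_test.cc match against.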


@@ -519,6 +519,12 @@ class IrEmitterUnnested : public IrEmitter,
absl::Span<HloComputation* const> reducers,
const TilingKernelInfo& tiling_kernel_info);
// Emits code for reductions in the output_instructions.
Status EmitIRForReduction(
HloInstruction* unnested_hlo,
absl::Span<HloInstruction* const> output_instructions,
ReductionCodegenInfo* reduction_info, const Shape& input_shape);
// For each reducer, emits the shuffle-down loop to accumulate the partial
// result to the global result.
void EmitFullWarpShuffleDownLoopForAllReduces(


@@ -115,9 +115,8 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
return ExecuteKernelOnStream(*kernel, buffer_args,
launch_dimensions.threads_per_block(),
launch_dimensions.block_count(), params.stream);
return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions,
params.stream);
}
} // namespace gpu


@@ -26,8 +26,11 @@ namespace gpu {
std::ostream& operator<<(std::ostream& out,
const LaunchDimensions& launch_dims) {
out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(),
launch_dims.threads_per_block());
auto block_counts = launch_dims.block_counts();
auto thread_counts = launch_dims.thread_counts_per_block();
out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]",
block_counts.x, block_counts.y, block_counts.z,
thread_counts.x, thread_counts.y, thread_counts.z);
return out;
}


@@ -29,24 +29,37 @@ namespace gpu {
// number of threads per block.
class LaunchDimensions {
public:
struct Dim3D {
int64 x, y, z;
};
// The default constructor creates a launch dimension that indicates
// single-threaded execution.
LaunchDimensions() : block_count_(1), threads_per_block_(1) {}
LaunchDimensions()
: block_counts_({1, 1, 1}), thread_counts_per_block_({1, 1, 1}) {}
LaunchDimensions(int64 block_count, int64 threads_per_block)
: block_count_(block_count), threads_per_block_(threads_per_block) {}
LaunchDimensions(int64 block_x_count, int64 thread_x_count_per_block)
: block_counts_({block_x_count, 1, 1}),
thread_counts_per_block_({thread_x_count_per_block, 1, 1}) {}
bool IsSinglethreaded() const {
return block_count_ == 1 && threads_per_block_ == 1;
LaunchDimensions(const Dim3D& block_counts,
const Dim3D& thread_counts_per_block)
: block_counts_(block_counts),
thread_counts_per_block_(thread_counts_per_block) {}
Dim3D block_counts() const { return block_counts_; }
Dim3D thread_counts_per_block() const { return thread_counts_per_block_; }
int64 launch_bound() const {
return block_counts_.x * thread_counts_per_block_.x * block_counts_.y *
thread_counts_per_block_.y * block_counts_.z *
thread_counts_per_block_.z;
}
int64 block_count() const { return block_count_; }
int64 threads_per_block() const { return threads_per_block_; }
int64 launch_bound() const { return block_count() * threads_per_block(); }
private:
int64 block_count_;
int64 threads_per_block_;
Dim3D block_counts_;
Dim3D thread_counts_per_block_;
};
std::ostream& operator<<(std::ostream& out,

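As a reference for the new three-dimensional LaunchDimensions API, here is a minimal usage sketch. It is not part of the diff; the helper name and argument values are made up, and the include path is the one used by stream_executor_util.h below:

    #include <cstdint>

    #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"

    // Builds the launch shape the reduction emitter now uses: the y block
    // extent selects the reduction group, x covers the tiles of one group,
    // and each thread block stays one-dimensional.
    xla::gpu::LaunchDimensions MakeGroupedLaunchDims(int64_t blocks_per_group,
                                                     int64_t num_groups,
                                                     int64_t threads_per_block) {
      return xla::gpu::LaunchDimensions(
          {/*x=*/blocks_per_group, /*y=*/num_groups, /*z=*/1},
          {/*x=*/threads_per_block, /*y=*/1, /*z=*/1});
    }
    // launch_bound() on the result is the product of all six extents, i.e.
    // blocks_per_group * num_groups * threads_per_block here.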

@@ -75,7 +75,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
std::vector<llvm_ir::IrArray::Index> array_indices;
llvm::Value* block_id =
EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(),
llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x,
static_cast<llvm::Instruction*>(block_id));
block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id");
@@ -85,16 +85,17 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
// %ntid.x is currently specified as 1024.
llvm::Value* thread_id =
EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_);
llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(),
llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x,
static_cast<llvm::Instruction*>(thread_id));
thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id");
llvm::Value* linear_index_base = b_->CreateAdd(
b_->CreateMul(block_id,
llvm::ConstantInt::get(
index_type, launch_dimensions_.threads_per_block()),
"",
/*HasNUW=*/true, /*HasNSW=*/true),
b_->CreateMul(
block_id,
llvm::ConstantInt::get(
index_type, launch_dimensions_.thread_counts_per_block().x),
"",
/*HasNUW=*/true, /*HasNSW=*/true),
thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true);
// Add an @llvm.assume(linear_index < threads_per_block * num_blocks).
@@ -109,9 +110,9 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
llvm::Intrinsic::assume,
{b_->CreateICmpULT(
linear_index_base,
llvm::ConstantInt::get(index_type,
launch_dimensions_.threads_per_block() *
launch_dimensions_.block_count()),
llvm::ConstantInt::get(
index_type, launch_dimensions_.thread_counts_per_block().x *
launch_dimensions_.block_counts().x),
"linear_index_in_range")},
{}, b_);


@@ -209,16 +209,18 @@ StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
Status ExecuteKernelOnStream(const se::KernelBase& kernel,
absl::Span<const se::DeviceMemoryBase> args,
int64 threads_per_block, int64 block_count,
se::Stream* stream) {
const LaunchDimensions& dims, se::Stream* stream) {
static constexpr int kKernelArgsLimit = 1024;
auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
for (const se::DeviceMemoryBase& buf : args) {
kernel_args->add_device_memory_argument(buf);
}
return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block),
se::BlockDim(block_count), kernel,
*kernel_args);
auto thread_counts = dims.thread_counts_per_block();
auto block_counts = dims.block_counts();
return stream->parent()->Launch(
stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel,
*kernel_args);
}
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) {


@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -71,8 +72,7 @@ StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
// Runs loaded kernel on the stream with the provided arguments.
Status ExecuteKernelOnStream(const se::KernelBase& kernel,
absl::Span<const se::DeviceMemoryBase> args,
int64 threads_per_block, int64 block_count,
se::Stream* stream);
const LaunchDimensions& dims, se::Stream* stream);
// Create GpuAsmOpts out of HloModuleConfig.
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config);
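
A hypothetical call site for the updated ExecuteKernelOnStream signature. Nothing below is in the commit: the helper, its block-size choice, and the assumption that the kernel, argument, and stream objects come from elsewhere are all illustrative, following the pattern in kernel_thunk.cc above:

    #include <cstdint>

    #include "absl/types/span.h"
    #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
    #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"

    // Launches `kernel` over `element_count` elements with a fixed block size,
    // passing the geometry as a single LaunchDimensions value instead of the
    // old separate threads_per_block / block_count integers.
    xla::Status LaunchElementwiseKernel(
        const stream_executor::KernelBase& kernel,
        absl::Span<const stream_executor::DeviceMemoryBase> args,
        int64_t element_count, stream_executor::Stream* stream) {
      constexpr int64_t kThreadsPerBlock = 256;  // Illustrative choice.
      const int64_t block_count =
          (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;
      xla::gpu::LaunchDimensions dims(
          /*block_x_count=*/block_count,
          /*thread_x_count_per_block=*/kThreadsPerBlock);
      return xla::gpu::ExecuteKernelOnStream(kernel, args, dims, stream);
    }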


@@ -219,6 +219,28 @@ tf_cc_test(
],
)
tf_cc_test(
name = "parallel_reduction_test",
srcs = [
"parallel_reduction_test.cc",
],
tags = tf_cuda_tests_tags() + ["no_rocm"],
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:gpu_executable",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "gpu_copy_test",
srcs = ["gpu_copy_test.cc"],


@@ -0,0 +1,187 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
namespace xla {
namespace gpu {
namespace {
class ParallelReductionTest : public GpuCodegenTest {
DebugOptions GetDebugOptionsForTest() override {
DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
// The test contains a MOF fusion and the XLA optimizer passes
// don't like this.
debug_options.set_xla_disable_all_hlo_passes(true);
return debug_options;
}
};
TEST_F(ParallelReductionTest, TwoParallelReductions) {
const char* hlo_text = R"(
HloModule TwoParallelReductions
%add_f32 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
%fused_computation {
%param0 = f32[1024] parameter(0)
%param1 = f32[1024] parameter(1)
%constant0 = f32[] constant(0)
%reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
%reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2)
}
ENTRY %cluster {
%param0 = f32[1024] parameter(0)
%param1 = f32[1024] parameter(1)
ROOT %fusion = (f32[], f32[])
fusion(%param0, %param1), kind=kInput, calls=%fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_text));
CompileAndVerifyIr(std::move(hlo_module),
R"(
CHECK: reduce-group-0
CHECK: reduce-group-1
CHECK-NOT: reduce-group-2
)",
/*match_optimized_ir=*/false);
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
TEST_F(ParallelReductionTest, ManyParallelReductions) {
auto module = CreateNewVerifiedModule();
// Use a number that is not so large that compilation takes too long, yet
// large enough for the test to be meaningful.
const size_t num_reduces = 32;
HloComputation* reduce_computation;
{
auto embedded_builder = HloComputation::Builder("add");
auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "lhs"));
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {}), "rhs"));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
reduce_computation =
module->AddEmbeddedComputation(embedded_builder.Build());
}
Shape input_shape = ShapeUtil::MakeShape(F32, {1024});
Shape output_shape = ShapeUtil::MakeShape(F32, {});
HloComputation* fusion_computation;
{
auto fusion_builder = HloComputation::Builder("fusion_computation");
std::vector<HloInstruction*> outputs;
HloInstruction* constant = fusion_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
for (size_t i = 0; i < num_reduces; ++i) {
HloInstruction* param = fusion_builder.AddInstruction(
HloInstruction::CreateParameter(i, input_shape, "param"));
auto output = fusion_builder.AddInstruction(HloInstruction::CreateReduce(
output_shape, param, constant, {0}, reduce_computation));
outputs.push_back(output);
}
fusion_builder.AddInstruction(HloInstruction::CreateTuple(outputs));
fusion_computation = module->AddEmbeddedComputation(fusion_builder.Build());
}
HloComputation::Builder b(TestName());
std::vector<HloInstruction*> entry_params;
std::vector<Shape> output_shapes;
for (size_t i = 0; i < num_reduces; ++i) {
HloInstruction* param = b.AddInstruction(
HloInstruction::CreateParameter(i, input_shape, "param"));
entry_params.push_back(param);
output_shapes.push_back(output_shape);
}
b.AddInstruction(HloInstruction::CreateFusion(
ShapeUtil::MakeTupleShape(output_shapes),
HloInstruction::FusionKind::kInput, entry_params, fusion_computation));
module->AddEntryComputation(b.Build());
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
}
TEST_F(ParallelReductionTest, ThreeReductionGroups) {
const char* hlo_text = R"(
HloModule ThreeReductionGroups
%add_f32 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
%fused_computation {
%param0 = f32[1024,128] parameter(0)
%param1 = f32[1024,128] parameter(1)
%param2 = f32[1024,128] parameter(2)
%constant0 = f32[] constant(0)
// %mul0, %reduce0, and %reduce1 should go into a group.
%broadcast0 = f32[1024,128] broadcast(%constant0), dimensions={}
%mul0 = f32[1024,128] multiply(param0, broadcast0)
%reduce0 = f32[128] reduce(%mul0, %constant0), dimensions={0}, to_apply=%add_f32
%reduce1 = f32[128] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
// %reduce2 and %reduce3 should go into another group.
%reduce2 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
%reduce3 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
// %reduce4 and %mul2 should go into the other group, although broadcast0 is
// reused.
%mul1 = f32[1024,128] multiply(param2, broadcast0)
%reduce4 = f32[128] reduce(%mul1, %constant0), dimensions={0}, to_apply=%add_f32
%mul2 = f32[1024,128] multiply(param2, param2)
ROOT %tuple =
(f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128])
tuple(%mul2, %reduce0, %reduce4, %reduce3, %reduce2, %reduce1, %mul0)
}
ENTRY %cluster {
%param0 = f32[1024,128] parameter(0)
%param1 = f32[1024,128] parameter(1)
%param2 = f32[1024,128] parameter(2)
ROOT %fusion =
(f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128])
fusion(%param0, %param1, %param2), kind=kInput, calls=%fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
ParseAndReturnVerifiedModule(hlo_text));
CompileAndVerifyIr(std::move(hlo_module),
R"(
CHECK: reduce-group-0
CHECK: reduce-group-1
CHECK: reduce-group-2
CHECK-NOT: reduce-group-3
)",
/*match_optimized_ir=*/false);
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
} // namespace
} // namespace gpu
} // namespace xla