Merge pull request #42717 from trentlo:parallel-reduce
PiperOrigin-RevId: 329867466
Change-Id: I899ad5926aa2379a302435cd894457f30efb7d15
commit b2f737e6ae
@@ -609,7 +609,6 @@ cc_library(
        ":flags",
        ":resource_operation_safety_analysis",
        ":shape_inference_helpers",
        ":union_find",
        ":xla_activity_listener",
        ":xla_cluster_util",
        "//tensorflow/cc:cc_ops",
@@ -628,6 +627,7 @@ cc_library(
        "//tensorflow/compiler/tf2xla/cc:xla_ops",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
@@ -705,11 +705,6 @@ tf_cc_test(
    ],
)

cc_library(
    name = "union_find",
    hdrs = ["union_find.h"],
)

tf_cc_test(
    name = "deadness_analysis_test",
    size = "small",
@@ -891,7 +886,6 @@ cc_library(
        ":device_util",
        ":flags",
        ":resource_operation_safety_analysis",
        ":union_find",
        ":xla_activity_listener",
        ":xla_activity_proto_cc",
        ":xla_cluster_util",
@@ -900,6 +894,7 @@ cc_library(
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
@@ -36,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
@@ -44,6 +43,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"

@@ -26,11 +26,11 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"

@@ -32,12 +32,12 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
@@ -819,9 +819,9 @@ cc_library(
        ":frontend_attributes_util",
        ":functionalize_control_flow_util",
        ":tf2xla_util",
        "//tensorflow/compiler/jit:union_find",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
@@ -847,9 +847,9 @@ cc_library(
        ":functionalize_control_flow_util",
        ":functionalize_while",
        ":tf2xla_util",
        "//tensorflow/compiler/jit:union_find",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
@@ -935,9 +935,9 @@ cc_library(
        ":functionalize_cond",
        ":functionalize_control_flow_util",
        ":tf2xla_util",
        "//tensorflow/compiler/jit:union_find",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
@@ -25,9 +25,10 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"

@@ -23,12 +23,12 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_while.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"

@@ -24,11 +24,11 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/frontend_attributes_util.h"
#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
@@ -969,6 +969,11 @@ tf_cc_test(
    ],
)

cc_library(
    name = "union_find",
    hdrs = ["union_find.h"],
)

# -----------------------------------------------------------------------------

# This is a headers target that extra XLA devices can use to prevent circular
# dependencies. Devices that are compiled as separate shared objects can also
# use it to prevent linking of library code.

@@ -265,6 +265,7 @@ cc_library(
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:union_find",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:window_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -1493,6 +1494,7 @@ cc_library(
    hdrs = ["stream_executor_util.h"],
    copts = tf_copts(),
    deps = [
        ":launch_dimensions",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
@@ -613,10 +613,13 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream,
  LaunchDimensions dim =
      CalculateLaunchDimensions(buffer_shape, gpu_device_info);

  stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()),
                     se::BlockDim(dim.block_count()), *comparison_kernel,
                     lhs_typed, rhs_typed, static_cast<float>(kTolerance),
                     buffer_size, out_param.cref());
  LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block();
  LaunchDimensions::Dim3D block_counts = dim.block_counts();
  stream->ThenLaunch(
      se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
      se::BlockDim(block_counts.x, block_counts.y, block_counts.z),
      *comparison_kernel, lhs_typed, rhs_typed, static_cast<float>(kTolerance),
      buffer_size, out_param.cref());

  uint64 result = -1;
  CHECK_EQ(out_param->size(), sizeof(result));
@@ -90,6 +90,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/union_find.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -141,7 +142,7 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
  llvm::LLVMContext& llvm_context = llvm_module->getContext();
  llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
      llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
      launch_dims.threads_per_block());
      launch_dims.thread_counts_per_block().x);
  // Our launch bounds are exact, so we can specify them as reqntidx rather than
  // maxntidx.
  nvvm_annotations_node->addOperand(llvm::MDNode::get(
@@ -2991,6 +2992,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId(
  });
}

namespace {

// Obtains the corresponding index of the out_instr in the outputs of the
// `unnested_hlo`.
ShapeIndex CreateShapeIndexForOutputInstruction(
    const HloInstruction& unnested_hlo, const HloInstruction& out_instr) {
  if (!unnested_hlo.IsMultiOutputFusion()) {
    return ShapeIndex({});
  }
  const auto& all_outputs = unnested_hlo.fused_expression_root()->operands();
  for (size_t i = 0; i < all_outputs.size(); ++i) {
    if (all_outputs[i] == &out_instr) {
      return ShapeIndex({static_cast<int64>(i)});
    }
  }
  LOG(FATAL) << " Fusion root does not contain output instruction; "
             << " fusion: " << unnested_hlo.ToString()
             << ", output instruction: " << out_instr.ToString();
}

}  // namespace

void IrEmitterUnnested::EmitTileElementForReduction(
    HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
    absl::Span<HloInstruction* const> output_instructions,
@@ -2998,7 +3021,6 @@ void IrEmitterUnnested::EmitTileElementForReduction(
    const ReductionCodegenInfo& reduction_info,
    absl::Span<HloComputation* const> reducers, int64 x_iter_num) {
  VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
  bool returns_tuple = output_instructions.size() > 1;
  int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;

  InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
@@ -3015,7 +3037,8 @@ void IrEmitterUnnested::EmitTileElementForReduction(

  for (int i = 0, e = output_instructions.size(); i != e; ++i) {
    const HloInstruction* inst = output_instructions[i];
    ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
    ShapeIndex idx =
        CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst);
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
    } else {
@@ -3748,71 +3771,41 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
      reduction_dimensions.is_row_reduction);
}

Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
void IrEmitterUnnested::EmitIRForReduction(
    HloInstruction* unnested_hlo,
    absl::Span<HloInstruction* const> output_instructions) {
  bool returns_tuple = output_instructions.size() > 1;
  VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();

    absl::Span<HloInstruction* const> output_instructions,
    ReductionCodegenInfo* reduction_info, const Shape& input_shape) {
  std::vector<HloInstruction*> reduce_instructions;
  InlinedVector<ShapeIndex, 1> reduction_output_shape_indices;
  InlinedVector<HloComputation*, 1> reducers;

  // Build an initializer thunk to initialize each reduction output.
  std::vector<std::unique_ptr<Thunk>> thunks;
  for (int i = 0; i < output_instructions.size(); ++i) {
  for (size_t i = 0; i < output_instructions.size(); ++i) {
    if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
      continue;
    }

    HloInstruction* output_instruction = output_instructions[i];
    reduce_instructions.push_back(output_instruction);
    ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
    reduction_output_shape_indices.push_back(idx);
    reduction_output_shape_indices.push_back(
        CreateShapeIndexForOutputInstruction(*unnested_hlo,
                                             *output_instruction));
    reducers.push_back(output_instruction->to_apply());

    TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
                        BuildInitializerThunk(unnested_hlo, idx));
    thunks.push_back(std::move(initializer_thunk));
  }
  CHECK(reduce_instructions.size() != 0)
      << " expected at least one reduce instruction.";

  const HloInstruction* first_reduce = reduce_instructions.at(0);
  if (output_instructions.size() > 1) {
    if (!AreFusedReductionOutputsConsistent(output_instructions,
                                            first_reduce)) {
      return InternalError("Inconsistent reduction fusion outputs");
    }
  }

  // Build a kernel thunk to compute all the outputs.
  std::unique_ptr<KernelThunk> kernel_thunk =
      BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);

  const Shape& input_shape = first_reduce->operand(0)->shape();
  // The layout of a reduction input is either set by LayoutAssignment for
  // unnested kReduce or by InstructionFusion for fused kReduce.
  CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
                                     "doesn't set the input layout of "
                                  << first_reduce->ToString();

  ReductionCodegenInfo reduction_info =
      ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
  const KernelMappingScheme& mapping_scheme =
      reduction_info.GetKernelMappingScheme();
      reduction_info->GetKernelMappingScheme();
  LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
                                     mapping_scheme.GetThreadsPerBlock());
  VLOG(3) << "Launch dimensions of " << unnested_hlo->name()
          << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
          << " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
  llvm::Type* index_ty = GetIndexTypeForKernel(
      unnested_hlo, launch_dimensions.launch_bound(), &b_);
  EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions,
  EmitPrologueForReduction(unnested_hlo, reduction_info, reduce_instructions,
                           index_ty);
  EmitElementFunction emit_reduction_tile =
      [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
          llvm::Value* x_loc, int64 x_iter_num) {
        EmitTileElementForReduction(unnested_hlo, input_shape,
                                    output_instructions, index, reduction_info,
                                    output_instructions, index, *reduction_info,
                                    reducers, x_iter_num);
      };
@@ -3821,19 +3814,180 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
      [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
          const string& loop_name, llvm::Value* tile_height,
          llvm::Value* tile_width, KernelSupportLibrary* ksl) {
        EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
                 thread_id_info, tile_height, tile_width, emit_reduction_tile);
        EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name,
                 ksl, thread_id_info, tile_height, tile_width,
                 emit_reduction_tile);
      });
  EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
  EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info,
                           reduce_instructions, reduction_output_shape_indices,
                           reducers, tiling_kernel_info);
}

namespace {

// Returns whether the `instr` is either a constant, a scalar, or a
// broadcasted constant/scalar.
bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
  return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
         (HloOpcode::kBroadcast == instr.opcode() &&
          (instr.operand(0)->IsConstant() ||
           ShapeUtil::IsScalar(instr.operand(0)->shape())));
}

// Divides output_instructions into groups. Different groups will be executed
// in parallel. Generally speaking, we'd like to run the reduce instructions
// in parallel without incurring too much recomputation overhead. The current
// heuristic is to place reduce instructions that share nothing or only
// (broadcasted) scalars/constants into different groups; otherwise, they are
// placed in the same group. Non-reduce instructions always go with the reduce
// instructions into the same group so long as they share any predecessors.
std::vector<std::vector<HloInstruction*>> DivideOutputInstructionsIntoGroups(
    HloInstruction* unnested_hlo,
    absl::Span<HloInstruction* const> output_instructions) {
  CHECK(!output_instructions.empty());
  if (output_instructions.size() == 1) {
    return {{output_instructions[0]}};
  }

  std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets(
      output_instructions.size());
  for (size_t i = 0; i < output_instructions.size(); ++i) {
    disjoint_sets[i].Get() = output_instructions[i];
  }

  std::unique_ptr<HloReachabilityMap> reachability_map =
      HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation());
  for (auto* instr : unnested_hlo->fused_instructions()) {
    std::vector<int64> reached_output_ids;
    for (size_t oid = 0; oid < output_instructions.size(); ++oid) {
      if (HloOpcode::kReduce == output_instructions[oid]->opcode() &&
          (IsBroadcastedConstantOrScalar(*instr))) {
        // Do not group output reduce instructions through broadcasted
        // constants or scalars, as the recomputation should be acceptable.
        VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString();
        continue;
      }
      // Now group output instructions if they have common predecessors.
      if (reachability_map->IsReachable(instr, output_instructions[oid])) {
        VLOG(3) << "Reaching " << output_instructions[oid]->ToString()
                << " from " << instr->ToString();
        reached_output_ids.push_back(oid);
      }
    }
    for (size_t j = 1; j < reached_output_ids.size(); ++j) {
      disjoint_sets[reached_output_ids[0]].Merge(
          &disjoint_sets[reached_output_ids[j]]);
    }
  }
  // Place output instructions in the same set into the same group.
  absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>> groups;
  for (size_t oid = 0; oid < output_instructions.size(); ++oid) {
    groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid));
  }

  std::vector<std::vector<HloInstruction*>> ret;
  absl::c_for_each(
      groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
  return ret;
}

}  // namespace

Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
    HloInstruction* unnested_hlo,
    absl::Span<HloInstruction* const> output_instructions) {
  bool returns_tuple = output_instructions.size() > 1;
  VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();

  // Build an initializer thunk to initialize each reduction output.
  std::vector<std::unique_ptr<Thunk>> thunks;
  for (int i = 0; i < output_instructions.size(); ++i) {
    if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
      continue;
    }

    ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
                        BuildInitializerThunk(unnested_hlo, idx));
    thunks.push_back(std::move(initializer_thunk));
  }

  // Build a kernel thunk to compute all the outputs.
  const HloInstruction* first_reduce = nullptr;
  for (int i = 0; i < output_instructions.size(); ++i) {
    if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
      first_reduce = output_instructions[i];
      break;
    }
  }
  CHECK(first_reduce);
  if (output_instructions.size() > 1) {
    if (!AreFusedReductionOutputsConsistent(output_instructions,
                                            first_reduce)) {
      return InternalError("Inconsistent reduction fusion outputs");
    }
  }
  const Shape& input_shape = first_reduce->operand(0)->shape();
  // The layout of a reduction input is either set by LayoutAssignment for
  // unnested kReduce or by InstructionFusion for fused kReduce.
  CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
                                     "doesn't set the input layout of "
                                  << first_reduce->ToString();

  // Group output instructions. Each group will be executed in parallel.
  std::vector<std::vector<HloInstruction*>> instr_groups =
      DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions);
  VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ",
                    unnested_hlo->ToString());
  std::unique_ptr<KernelThunk> kernel_thunk =
      BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
  KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
  for (size_t i = 0; i < instr_groups.size(); ++i) {
    // Create a new ReductionCodegenInfo instance as it contains states for
    // code generation per reduction group. For now, let's always use the very
    // first reduce as representative to construct ReductionCodegenInfo, since
    // all the reductions are required to have the same shape and layout as
    // verified by `AreFusedReductionOutputsConsistent()`. We can loosen the
    // constraint later when the need arises.
    ReductionCodegenInfo reduction_info =
        ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
    auto emit_reduction_func = [&] {
      EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info,
                         input_shape);
    };
    // Use raw block_id_y to select the i-th parallel reduction to run. Using
    // block_id_y instead of block_id_x simplifies the index calculation
    // for reduction code generation as the block_id_y is orthogonal to
    // the indices used within the reductions.
    llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
        gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
    llvm_ir::AddRangeMetadata(0, instr_groups.size(),
                              llvm::cast<llvm::Instruction>(raw_block_id_y));
    llvm::Value* guarding_cond =
        b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i));
    ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func);
  }
  ReductionCodegenInfo reduction_info =
      ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
  const KernelMappingScheme& mapping_scheme =
      reduction_info.GetKernelMappingScheme();
  // block_y_count is set to instr_groups.size(), so that each reduction group
  // can be run in parallel by a different BlockIdy.
  LaunchDimensions launch_dimensions(
      {/*x=*/mapping_scheme.GetNumberOfBlocks(),
       /*y=*/static_cast<int64>(instr_groups.size()),
       /*z=*/1},
      {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1});
  VLOG(3) << "Launch dimensions of " << unnested_hlo->name()
          << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
          << " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                         ir_emitter_context_->llvm_module());

  thunks.push_back(std::move(kernel_thunk));
  auto sequential_thunk = absl::make_unique<SequentialThunk>(
      GetThunkInfo(unnested_hlo), std::move(thunks));
  std::unique_ptr<SequentialThunk> sequential_thunk =
      absl::make_unique<SequentialThunk>(GetThunkInfo(unnested_hlo),
                                         std::move(thunks));
  AddThunkToThunkSequence(std::move(sequential_thunk));

  return Status::OK();
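
The grouping step above (`DivideOutputInstructionsIntoGroups`) can be illustrated outside of XLA with a small self-contained C++ sketch: outputs that transitively share a non-trivial producer are merged into one group via union-find, while nodes treated as "cheap" (broadcasted constants or scalars) are ignored so that sharing only them keeps outputs in separate groups. All names and types below are illustrative stand-ins, not part of the TensorFlow/XLA API.

// Minimal standalone sketch of the grouping heuristic (assumed toy graph).
#include <iostream>
#include <map>
#include <numeric>
#include <set>
#include <string>
#include <vector>

// Tiny union-find over output indices.
struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Merge(int a, int b) { parent[Find(a)] = Find(b); }
};

int main() {
  // Toy fusion: producer -> set of output indices it (transitively) feeds.
  // "p0" feeds outputs 0 and 1; "p1" feeds output 2; a broadcasted constant
  // feeds everything but is treated as cheap to recompute.
  std::map<std::string, std::set<int>> reaches = {
      {"p0", {0, 1}}, {"p1", {2}}, {"broadcast_const", {0, 1, 2}}};
  std::set<std::string> cheap = {"broadcast_const"};

  const int num_outputs = 3;
  UnionFind uf(num_outputs);
  for (const auto& [node, outs] : reaches) {
    if (cheap.count(node)) continue;  // sharing only constants/scalars is OK
    int first = -1;
    for (int out : outs) {
      if (first < 0) first = out;
      else uf.Merge(first, out);
    }
  }

  // Collect groups: outputs 0 and 1 end up together, output 2 stays alone.
  std::map<int, std::vector<int>> groups;
  for (int i = 0; i < num_outputs; ++i) groups[uf.Find(i)].push_back(i);
  for (const auto& [root, members] : groups) {
    std::cout << "group:";
    for (int m : members) std::cout << " " << m;
    std::cout << "\n";
  }
}
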
@@ -372,6 +372,16 @@ class IrEmitterUnnested : public IrEmitter,
  // }
  // ```
  //
  // Moreover, a heuristic is implemented to divide the reduce instructions
  // into groups for parallelization (see `DivideOutputInstructionsIntoGroups`
  // for details about the heuristic). Reduce instructions in the same group
  // will run sequentially while different groups will run in parallel.
  //
  // We use raw block_id_y to select the reduce groups for execution without
  // complicating the index calculation in the code generation of the reduce
  // instructions. In other words, a block_id_y is assigned to a group and so
  // different groups can be run in parallel.
  //
  // output_instructions: Output instructions in the computation: instruction
  // itself if it's not a fusion, fusion root if fusion is not multi-output, and
  // elements of the fusion multi-output tuple otherwise.
@@ -518,6 +528,12 @@ class IrEmitterUnnested : public IrEmitter,
                              absl::Span<HloComputation* const> reducers,
                              const TilingKernelInfo& tiling_kernel_info);

  // Emits code for reductions in the output_instructions.
  void EmitIRForReduction(HloInstruction* unnested_hlo,
                          absl::Span<HloInstruction* const> output_instructions,
                          ReductionCodegenInfo* reduction_info,
                          const Shape& input_shape);

  // For each reducer, emits the shuffle-down loop to accumulate the partial
  // result to the global result.
  void EmitFullWarpShuffleDownLoopForAllReduces(
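
The class comment above says that each reduce group is selected purely by block_id_y. The host-side C++ sketch below (illustrative only; not generated code and not part of the XLA API) mimics the emitted `reduce-group-<i>` guards: the kernel is launched with a y grid size equal to the number of groups, and each block runs exactly the one group whose index matches its block_id_y.

// Host-side illustration of the dispatch pattern; names are hypothetical.
#include <cstdio>
#include <functional>
#include <vector>

void SimulateKernel(int block_count_x, int num_groups,
                    const std::vector<std::function<void(int)>>& groups) {
  // The real kernel is launched with grid = {block_count_x, num_groups, 1};
  // here we just loop over the same (block_id_x, block_id_y) space.
  for (int block_id_y = 0; block_id_y < num_groups; ++block_id_y) {
    for (int block_id_x = 0; block_id_x < block_count_x; ++block_id_x) {
      // Equivalent of the emitted `reduce-group-<i>` guards: only the
      // group whose index equals block_id_y does work in this block.
      for (int i = 0; i < num_groups; ++i) {
        if (block_id_y == i) groups[i](block_id_x);
      }
    }
  }
}

int main() {
  std::vector<std::function<void(int)>> groups = {
      [](int bx) { std::printf("group 0 on block_id_x=%d\n", bx); },
      [](int bx) { std::printf("group 1 on block_id_x=%d\n", bx); },
  };
  SimulateKernel(/*block_count_x=*/2, /*num_groups=*/2, groups);
}
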
@@ -115,9 +115,8 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) {

  auto op_profiler =
      params.profiler->MakeScopedInstructionProfiler(profile_index());
  return ExecuteKernelOnStream(*kernel, buffer_args,
                               launch_dimensions.threads_per_block(),
                               launch_dimensions.block_count(), params.stream);
  return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions,
                               params.stream);
}

}  // namespace gpu
@@ -26,8 +26,11 @@ namespace gpu {

std::ostream& operator<<(std::ostream& out,
                         const LaunchDimensions& launch_dims) {
  out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(),
                         launch_dims.threads_per_block());
  LaunchDimensions::Dim3D block_counts = launch_dims.block_counts();
  LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block();
  out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]",
                         block_counts.x, block_counts.y, block_counts.z,
                         thread_counts.x, thread_counts.y, thread_counts.z);
  return out;
}

@@ -29,24 +29,37 @@ namespace gpu {
// number of threads per block.
class LaunchDimensions {
 public:
  struct Dim3D {
    int64 x, y, z;
  };

  // The default constructor creates a launch dimension that indicates
  // single-threaded execution.
  LaunchDimensions() : block_count_(1), threads_per_block_(1) {}
  LaunchDimensions()
      : block_counts_({1, 1, 1}), thread_counts_per_block_({1, 1, 1}) {}

  LaunchDimensions(int64 block_count, int64 threads_per_block)
      : block_count_(block_count), threads_per_block_(threads_per_block) {}
  LaunchDimensions(int64 block_x_count, int64 thread_x_count_per_block)
      : block_counts_({block_x_count, 1, 1}),
        thread_counts_per_block_({thread_x_count_per_block, 1, 1}) {}

  bool IsSinglethreaded() const {
    return block_count_ == 1 && threads_per_block_ == 1;
  LaunchDimensions(const Dim3D& block_counts,
                   const Dim3D& thread_counts_per_block)
      : block_counts_(block_counts),
        thread_counts_per_block_(thread_counts_per_block) {}

  Dim3D block_counts() const { return block_counts_; }

  Dim3D thread_counts_per_block() const { return thread_counts_per_block_; }

  int64 launch_bound() const {
    return block_counts_.x * thread_counts_per_block_.x * block_counts_.y *
           thread_counts_per_block_.y * block_counts_.z *
           thread_counts_per_block_.z;
  }

  int64 block_count() const { return block_count_; }
  int64 threads_per_block() const { return threads_per_block_; }
  int64 launch_bound() const { return block_count() * threads_per_block(); }

 private:
  int64 block_count_;
  int64 threads_per_block_;
  Dim3D block_counts_;
  Dim3D thread_counts_per_block_;
};

std::ostream& operator<<(std::ostream& out,
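
For reference, a minimal standalone mirror of the 3D launch-dimension idea above (simplified, assumed types; not the real LaunchDimensions class): the x extents come from the kernel mapping scheme, the y block count carries the number of parallel reduce groups, and the launch bound is the product of all six extents.

// Illustrative sketch only; all numbers below are made-up examples.
#include <cstdint>
#include <cstdio>

struct Dim3D { int64_t x, y, z; };

struct LaunchDims {
  Dim3D blocks;
  Dim3D threads;
  int64_t launch_bound() const {
    return blocks.x * blocks.y * blocks.z * threads.x * threads.y * threads.z;
  }
};

int main() {
  const int64_t blocks_per_group = 120;   // e.g. from the kernel mapping scheme
  const int64_t threads_per_block = 256;  // e.g. from the kernel mapping scheme
  const int64_t num_groups = 3;           // number of parallel reduce groups

  LaunchDims dims{/*blocks=*/{blocks_per_group, num_groups, 1},
                  /*threads=*/{threads_per_block, 1, 1}};
  std::printf("launch bound = %lld\n",
              static_cast<long long>(dims.launch_bound()));
}
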
@@ -75,7 +75,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
  std::vector<llvm_ir::IrArray::Index> array_indices;
  llvm::Value* block_id =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(),
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x,
                            static_cast<llvm::Instruction*>(block_id));
  block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id");

@@ -85,16 +85,17 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
  // %ntid.x is currently specified as 1024.
  llvm::Value* thread_id =
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(),
  llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x,
                            static_cast<llvm::Instruction*>(thread_id));
  thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id");

  llvm::Value* linear_index_base = b_->CreateAdd(
      b_->CreateMul(block_id,
                    llvm::ConstantInt::get(
                        index_type, launch_dimensions_.threads_per_block()),
                    "",
                    /*HasNUW=*/true, /*HasNSW=*/true),
      b_->CreateMul(
          block_id,
          llvm::ConstantInt::get(
              index_type, launch_dimensions_.thread_counts_per_block().x),
          "",
          /*HasNUW=*/true, /*HasNSW=*/true),
      thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true);

  // Add an @llvm.assume(linear_index < threads_per_block * num_blocks).
@@ -109,9 +110,9 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
      llvm::Intrinsic::assume,
      {b_->CreateICmpULT(
          linear_index_base,
          llvm::ConstantInt::get(index_type,
                                 launch_dimensions_.threads_per_block() *
                                     launch_dimensions_.block_count()),
          llvm::ConstantInt::get(
              index_type, launch_dimensions_.thread_counts_per_block().x *
                              launch_dimensions_.block_counts().x),
          "linear_index_in_range")},
      {}, b_);

@@ -209,16 +209,18 @@ StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(

Status ExecuteKernelOnStream(const se::KernelBase& kernel,
                             absl::Span<const se::DeviceMemoryBase> args,
                             int64 threads_per_block, int64 block_count,
                             se::Stream* stream) {
                             const LaunchDimensions& dims, se::Stream* stream) {
  static constexpr int kKernelArgsLimit = 1024;
  auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>();
  for (const se::DeviceMemoryBase& buf : args) {
    kernel_args->add_device_memory_argument(buf);
  }
  return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block),
                                  se::BlockDim(block_count), kernel,
                                  *kernel_args);
  LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block();
  LaunchDimensions::Dim3D block_counts = dims.block_counts();
  return stream->parent()->Launch(
      stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
      se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel,
      *kernel_args);
}

se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) {
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -71,8 +72,7 @@ StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
// Runs loaded kernel on the stream with the provided arguments.
Status ExecuteKernelOnStream(const se::KernelBase& kernel,
                             absl::Span<const se::DeviceMemoryBase> args,
                             int64 threads_per_block, int64 block_count,
                             se::Stream* stream);
                             const LaunchDimensions& dims, se::Stream* stream);

// Create GpuAsmOpts out of HloModuleConfig.
se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config);
@@ -219,6 +219,28 @@ tf_cc_test(
    ],
)

tf_cc_test(
    name = "parallel_reduction_test",
    srcs = [
        "parallel_reduction_test.cc",
    ],
    tags = tf_cuda_tests_tags() + ["no_rocm"],
    deps = [
        ":gpu_codegen_test",
        "//tensorflow/compiler/xla/service:gpu_plugin",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/service/gpu:gpu_executable",
        "//tensorflow/compiler/xla/tests:filecheck",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "gpu_copy_test",
    srcs = ["gpu_copy_test.cc"],
@@ -0,0 +1,190 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"

namespace xla {
namespace gpu {

namespace {

class ParallelReductionTest : public GpuCodegenTest {
  DebugOptions GetDebugOptionsForTest() override {
    DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest();
    // The test contains a MOF fusion and the XLA optimizer passes
    // don't like this.
    debug_options.set_xla_disable_all_hlo_passes(true);
    return debug_options;
  }
};

TEST_F(ParallelReductionTest, TwoParallelReductions) {
  const char* hlo_text = R"(
HloModule TwoParallelReductions

%add_f32 {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

%fused_computation {
  %param0 = f32[1024] parameter(0)
  %param1 = f32[1024] parameter(1)
  %constant0 = f32[] constant(0)
  %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
  %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
  ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2)
}

ENTRY %cluster {
  %param0 = f32[1024] parameter(0)
  %param1 = f32[1024] parameter(1)
  ROOT %fusion = (f32[], f32[])
      fusion(%param0, %param1), kind=kInput, calls=%fused_computation
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  CompileAndVerifyIr(std::move(hlo_module),
                     R"(
CHECK: reduce-group-0
CHECK: reduce-group-1
CHECK-NOT: reduce-group-2
)",
                     /*match_optimized_ir=*/false);
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

TEST_F(ParallelReductionTest, ManyParallelReductions) {
  std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
  // Simply use a number not too large to avoid long compilation time
  // and not too small for a meaningful test.
  const size_t num_reduces = 32;

  HloComputation* reduce_computation;
  {
    auto embedded_builder = HloComputation::Builder("add");
    HloInstruction* lhs =
        embedded_builder.AddInstruction(HloInstruction::CreateParameter(
            0, ShapeUtil::MakeShape(F32, {}), "lhs"));
    HloInstruction* rhs =
        embedded_builder.AddInstruction(HloInstruction::CreateParameter(
            1, ShapeUtil::MakeShape(F32, {}), "rhs"));
    embedded_builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
    reduce_computation =
        module->AddEmbeddedComputation(embedded_builder.Build());
  }

  Shape input_shape = ShapeUtil::MakeShape(F32, {1024});
  Shape output_shape = ShapeUtil::MakeShape(F32, {});
  HloComputation* fusion_computation;
  {
    auto fusion_builder = HloComputation::Builder("fusion_computation");
    std::vector<HloInstruction*> outputs;
    HloInstruction* constant = fusion_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
    for (size_t i = 0; i < num_reduces; ++i) {
      HloInstruction* param = fusion_builder.AddInstruction(
          HloInstruction::CreateParameter(i, input_shape, "param"));
      HloInstruction* output =
          fusion_builder.AddInstruction(HloInstruction::CreateReduce(
              output_shape, param, constant, {0}, reduce_computation));
      outputs.push_back(output);
    }
    fusion_builder.AddInstruction(HloInstruction::CreateTuple(outputs));
    fusion_computation = module->AddEmbeddedComputation(fusion_builder.Build());
  }

  HloComputation::Builder b(TestName());
  std::vector<HloInstruction*> entry_params;
  std::vector<Shape> output_shapes;
  for (size_t i = 0; i < num_reduces; ++i) {
    HloInstruction* param = b.AddInstruction(
        HloInstruction::CreateParameter(i, input_shape, "param"));
    entry_params.push_back(param);
    output_shapes.push_back(output_shape);
  }
  b.AddInstruction(HloInstruction::CreateFusion(
      ShapeUtil::MakeTupleShape(output_shapes),
      HloInstruction::FusionKind::kInput, entry_params, fusion_computation));
  module->AddEntryComputation(b.Build());

  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
}

TEST_F(ParallelReductionTest, ThreeReductionGroups) {
  const char* hlo_text = R"(
HloModule ThreeReductionGroups

%add_f32 {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  ROOT %add = f32[] add(%x, %y)
}

%fused_computation {
  %param0 = f32[1024,128] parameter(0)
  %param1 = f32[1024,128] parameter(1)
  %param2 = f32[1024,128] parameter(2)
  %constant0 = f32[] constant(0)
  // %mul0, %reduce0, and %reduce1 should go into a group.
  %broadcast0 = f32[1024,128] broadcast(%constant0), dimensions={}
  %mul0 = f32[1024,128] multiply(param0, broadcast0)
  %reduce0 = f32[128] reduce(%mul0, %constant0), dimensions={0}, to_apply=%add_f32
  %reduce1 = f32[128] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32
  // %reduce2 and %reduce3 should go into another group.
  %reduce2 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
  %reduce3 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32
  // %reduce4 and %mul2 should go into the other group, although broadcast0 is
  // reused.
  %mul1 = f32[1024,128] multiply(param2, broadcast0)
  %reduce4 = f32[128] reduce(%mul1, %constant0), dimensions={0}, to_apply=%add_f32
  %mul2 = f32[1024,128] multiply(param2, param2)
  ROOT %tuple =
      (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128])
      tuple(%mul2, %reduce0, %reduce4, %reduce3, %reduce2, %reduce1, %mul0)
}

ENTRY %cluster {
  %param0 = f32[1024,128] parameter(0)
  %param1 = f32[1024,128] parameter(1)
  %param2 = f32[1024,128] parameter(2)
  ROOT %fusion =
      (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128])
      fusion(%param0, %param1, %param2), kind=kInput, calls=%fused_computation
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));
  CompileAndVerifyIr(std::move(hlo_module),
                     R"(
CHECK: reduce-group-0
CHECK: reduce-group-1
CHECK: reduce-group-2
CHECK-NOT: reduce-group-3
)",
                     /*match_optimized_ir=*/false);
  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

}  // namespace
}  // namespace gpu
}  // namespace xla