From 8daab754903fc4226b6aeaf4f395b9d72b101e3f Mon Sep 17 00:00:00 2001 From: Trent Lo <trentl@nvidia.com> Date: Tue, 25 Aug 2020 12:17:10 -0700 Subject: [PATCH] XLA Parallel reduce. Extend the XLA codegen to generate parallel reductions when there are multiple reduce instructions in a fusion computation. We see ~3% e2e gain for NVIDIA JoC BERT. For `ManyParallelReductions` with 128 reduce instructions in the unittest, the execution time is reduced from 325us to 3.9us (83X), reported by nvprof as below. Before: Type Time(%) Time Calls Avg Min Max Name 32.50% 325.54us 1 325.54us 325.54us 325.54us fusion After: Type Time(%) Time Calls Avg Min Max Name 0.59% 3.9030us 1 3.9030us 3.9030us 3.9030us fusion --- tensorflow/compiler/jit/BUILD | 9 +- .../compiler/jit/compilability_check_util.cc | 2 +- .../compiler/jit/compilability_check_util.h | 2 +- .../compiler/jit/mark_for_compilation_pass.cc | 2 +- tensorflow/compiler/tf2xla/BUILD | 6 +- .../compiler/tf2xla/functionalize_cond.cc | 3 +- .../tf2xla/functionalize_control_flow.cc | 2 +- .../compiler/tf2xla/functionalize_while.cc | 2 +- tensorflow/compiler/xla/BUILD | 5 + tensorflow/compiler/xla/service/gpu/BUILD | 2 + .../xla/service/gpu/buffer_comparator.cc | 11 +- .../xla/service/gpu/ir_emitter_unnested.cc | 236 ++++++++++++++---- .../xla/service/gpu/ir_emitter_unnested.h | 6 + .../compiler/xla/service/gpu/kernel_thunk.cc | 5 +- .../xla/service/gpu/launch_dimensions.cc | 7 +- .../xla/service/gpu/launch_dimensions.h | 35 ++- .../xla/service/gpu/parallel_loop_emitter.cc | 21 +- .../xla/service/gpu/stream_executor_util.cc | 12 +- .../xla/service/gpu/stream_executor_util.h | 4 +- .../compiler/xla/service/gpu/tests/BUILD | 22 ++ .../gpu/tests/parallel_reduction_test.cc | 187 ++++++++++++++ tensorflow/compiler/{jit => xla}/union_find.h | 0 22 files changed, 484 insertions(+), 97 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc rename tensorflow/compiler/{jit => xla}/union_find.h (100%) diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 35c6a8b0357..17e342c8be1 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -604,7 +604,6 @@ cc_library( ":flags", ":resource_operation_safety_analysis", ":shape_inference_helpers", - ":union_find", ":xla_activity_listener", ":xla_cluster_util", "//tensorflow/cc:cc_ops", @@ -623,6 +622,7 @@ cc_library( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/core:all_kernels", "//tensorflow/core:core_cpu", @@ -701,11 +701,6 @@ tf_cc_test( ], ) -cc_library( - name = "union_find", - hdrs = ["union_find.h"], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -886,7 +881,6 @@ cc_library( ":device_util", ":flags", ":resource_operation_safety_analysis", - ":union_find", ":xla_activity_listener", ":xla_activity_proto_cc", ":xla_cluster_util", @@ -895,6 +889,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 6d4bc51f1b2..cab5fb9e54a 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" @@ -44,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index f1ef67bfb3d..7ae964ac229 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -26,11 +26,11 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 03ac7b0a59a..42d142b66f7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -32,12 +32,12 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e9bcbcc6d83..3af17eb7cca 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -818,9 +818,9 @@ cc_library( ":frontend_attributes_util", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -846,9 +846,9 @@ cc_library( ":functionalize_control_flow_util", ":functionalize_while", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -934,9 +934,9 @@ cc_library( ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 54abccb4cfc..452b102fade 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -25,9 +25,10 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/graph_to_functiondef.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 596fa8e8e38..2a3e35e0ffd 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -23,12 +23,12 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index dce5efe5557..79412c4abc8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 35fa6a617f0..598112e00df 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -969,6 +969,11 @@ tf_cc_test( ], ) +cc_library( + name = "union_find", + hdrs = ["union_find.h"], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index ce761d8e0ae..ef44c7b21d6 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -265,6 +265,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1493,6 +1494,7 @@ cc_library( hdrs = ["stream_executor_util.h"], copts = tf_copts(), deps = [ + ":launch_dimensions", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 9b192aaa8e1..fa764a42a2f 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -613,10 +613,13 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream, LaunchDimensions dim = CalculateLaunchDimensions(buffer_shape, gpu_device_info); - stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()), - se::BlockDim(dim.block_count()), *comparison_kernel, - lhs_typed, rhs_typed, static_cast<float>(kTolerance), - buffer_size, out_param.cref()); + auto thread_counts = dim.thread_counts_per_block(); + auto block_counts = dim.block_counts(); + stream->ThenLaunch( + se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), + *comparison_kernel, lhs_typed, rhs_typed, static_cast<float>(kTolerance), + buffer_size, out_param.cref()); uint64 result = -1; CHECK_EQ(out_param->size(), sizeof(result)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f88c70b1a33..5233a8deddc 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -90,6 +90,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -141,7 +142,7 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::LLVMContext& llvm_context = llvm_module->getContext(); llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), - launch_dims.threads_per_block()); + launch_dims.thread_counts_per_block().x); // Our launch bounds are exact, so we can specify them as reqntidx rather than // maxntidx. nvvm_annotations_node->addOperand(llvm::MDNode::get( @@ -2991,6 +2992,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId( }); } +namespace { + +// Obtains the corresponding index of the out_instr in the outputs of the +// `unnested_hlo`. +ShapeIndex CreateShapeIndexForOutputInstruction( + const HloInstruction& unnested_hlo, const HloInstruction& out_instr) { + if (!unnested_hlo.IsMultiOutputFusion()) { + return ShapeIndex({}); + } + const auto& all_outputs = unnested_hlo.fused_expression_root()->operands(); + for (size_t i = 0; i < all_outputs.size(); ++i) { + if (all_outputs[i] == &out_instr) { + return ShapeIndex({i}); + } + } + CHECK(false) << " Fusion root does not contain output instruction; " + << " fusion: " << unnested_hlo.ToString() + << ", output instruction: " << out_instr.ToString(); +} + +} // namespace + void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const Shape& reduction_operand_shape, absl::Span<HloInstruction* const> output_instructions, @@ -2998,7 +3021,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( const ReductionCodegenInfo& reduction_info, absl::Span<HloComputation* const> reducers, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); - bool returns_tuple = output_instructions.size() > 1; int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num; InlinedVector<llvm_ir::ElementGenerator, 1> input_gens; @@ -3015,7 +3037,8 @@ void IrEmitterUnnested::EmitTileElementForReduction( for (int i = 0, e = output_instructions.size(); i != e; ++i) { const HloInstruction* inst = output_instructions[i]; - ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); + ShapeIndex idx = + CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst); if (IsReductionFromOrToContiguousDimensions(*inst)) { input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); } else { @@ -3748,16 +3771,131 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( reduction_dimensions.is_row_reduction); } +Status IrEmitterUnnested::EmitIRForReduction( + HloInstruction* unnested_hlo, + absl::Span<HloInstruction* const> output_instructions, + ReductionCodegenInfo* reduction_info, const Shape& input_shape) { + std::vector<HloInstruction*> reduce_instructions; + InlinedVector<ShapeIndex, 1> reduction_output_shape_indices; + InlinedVector<HloComputation*, 1> reducers; + for (size_t i = 0; i < output_instructions.size(); ++i) { + if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + continue; + } + + HloInstruction* output_instruction = output_instructions[i]; + reduce_instructions.push_back(output_instruction); + reduction_output_shape_indices.push_back( + CreateShapeIndexForOutputInstruction(*unnested_hlo, + *output_instruction)); + reducers.push_back(output_instruction->to_apply()); + } + CHECK(reduce_instructions.size() != 0) + << " expect at least one reduce instructions."; + + const KernelMappingScheme& mapping_scheme = + reduction_info->GetKernelMappingScheme(); + LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), + mapping_scheme.GetThreadsPerBlock()); + llvm::Type* index_ty = GetIndexTypeForKernel( + unnested_hlo, launch_dimensions.launch_bound(), &b_); + EmitPrologueForReduction(unnested_hlo, reduction_info, reduce_instructions, + index_ty); + EmitElementFunction emit_reduction_tile = [&]( + const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, + llvm::Value* x_loc, int64 x_iter_num) { + EmitTileElementForReduction(unnested_hlo, input_shape, output_instructions, + index, *reduction_info, reducers, x_iter_num); + }; + + TilingKernelInfo tiling_kernel_info = EmitTilingKernel( + mapping_scheme, index_ty, + [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, + const string& loop_name, llvm::Value* tile_height, + llvm::Value* tile_width, KernelSupportLibrary* ksl) { + EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name, + ksl, thread_id_info, tile_height, tile_width, + emit_reduction_tile); + }); + EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info, + reduce_instructions, reduction_output_shape_indices, + reducers, tiling_kernel_info); + + return Status::OK(); +} + +namespace { + +// Returns whether the `instr` is either a constant, a scalar, or a +// broadcasted constant/scalar. +bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) { + return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) || + HloOpcode::kBroadcast == instr.opcode() && + (instr.operand(0)->IsConstant() || + ShapeUtil::IsScalar(instr.operand(0)->shape())); +} + +// Divides output_instructions into groups. Generally, we'd like to group output +// instructions sharing same predecessors to avoid recomputation. Different +// groups will be executed in parallel. +std::vector<std::vector<HloInstruction*>> DivideOutputInstructionsIntoGroups( + HloInstruction* unnested_hlo, + absl::Span<HloInstruction* const> output_instructions) { + CHECK(!output_instructions.empty()); + if (output_instructions.size() == 1) { + return {{output_instructions[0]}}; + } + + std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets( + output_instructions.size()); + for (size_t i = 0; i < output_instructions.size(); ++i) { + disjoint_sets[i].Get() = output_instructions[i]; + } + + std::unique_ptr<HloReachabilityMap> reachability_map = + HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation()); + for (auto* instr : unnested_hlo->fused_instructions()) { + std::vector<int64> reached_output_ids; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + if (HloOpcode::kReduce == output_instructions[oid]->opcode() && + (IsBroadcastedConstantOrScalar(*instr))) { + // Do not group output reduce instructions through broadcasted + // constants or scalars, as the recomputation should be acceptable. + VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString(); + continue; + } + // Now group output instructions if they have common predecessors. + if (reachability_map->IsReachable(instr, output_instructions[oid])) { + VLOG(3) << "Reaching " << output_instructions[oid]->ToString() + << " from " << instr->ToString(); + reached_output_ids.push_back(oid); + } + } + for (size_t j = 1; j < reached_output_ids.size(); ++j) { + disjoint_sets[reached_output_ids[0]].Merge( + &disjoint_sets[reached_output_ids[j]]); + } + } + // Place output instructions in the same set into the same group. + absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>> groups; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid)); + } + + std::vector<std::vector<HloInstruction*>> ret; + absl::c_for_each( + groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); + return ret; +} + +} // namespace + Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( HloInstruction* unnested_hlo, absl::Span<HloInstruction* const> output_instructions) { bool returns_tuple = output_instructions.size() > 1; VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); - std::vector<HloInstruction*> reduce_instructions; - InlinedVector<ShapeIndex, 1> reduction_output_shape_indices; - InlinedVector<HloComputation*, 1> reducers; - // Build an initializer thunk to initialize each reduction output. std::vector<std::unique_ptr<Thunk>> thunks; for (int i = 0; i < output_instructions.size(); ++i) { @@ -3765,29 +3903,27 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( continue; } - HloInstruction* output_instruction = output_instructions[i]; - reduce_instructions.push_back(output_instruction); ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); - reduction_output_shape_indices.push_back(idx); - reducers.push_back(output_instruction->to_apply()); - TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk, BuildInitializerThunk(unnested_hlo, idx)); thunks.push_back(std::move(initializer_thunk)); } - const HloInstruction* first_reduce = reduce_instructions.at(0); + // Build a kernel thunk to compute all the outputs. + const HloInstruction* first_reduce = nullptr; + for (int i = 0; i < output_instructions.size(); ++i) { + if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + first_reduce = output_instructions[i]; + break; + } + } + CHECK(first_reduce); if (output_instructions.size() > 1) { if (!AreFusedReductionOutputsConsistent(output_instructions, first_reduce)) { return InternalError("Inconsistent reduction fusion outputs"); } } - - // Build a kernel thunk to compute all the outputs. - std::unique_ptr<KernelThunk> kernel_thunk = - BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); - const Shape& input_shape = first_reduce->operand(0)->shape(); // The layout of a reduction input is either set by LayoutAssignment for // unnested kReduce or by InstructionFusion for fused kReduce. @@ -3795,39 +3931,51 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( "doesn't set the input layout of " << first_reduce->ToString(); + // Group output instructions. Each group will be executed in parallel. + auto instr_groups = + DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions); + VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ", + unnested_hlo->ToString()); + auto kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + for (size_t i = 0; i < instr_groups.size(); ++i) { + // Create a new ReductionCodegenInfo instance as it contains states for + // code generation per reduction group. For now, let's always use the very + // first reduce as representative to construct ReductionCodegenInfo, since + // all the reductions are required to have the same shape and layout as + // verified by AreFusedReductionOutputsConsistent(). We can loosen the + // constraint later when the needs arise. + ReductionCodegenInfo reduction_info = + ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + auto emit_reduction_func = [&] { + EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info, + input_shape); + }; + // Use raw block_id_y to select the i-th parallel reduction to run. Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. + llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); + llvm_ir::AddRangeMetadata(0, instr_groups.size(), + llvm::cast<llvm::Instruction>(raw_block_id_y)); + auto guarding_cond = b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)); + ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func); + } ReductionCodegenInfo reduction_info = ComputeReductionCodegenInfo(unnested_hlo, first_reduce); const KernelMappingScheme& mapping_scheme = reduction_info.GetKernelMappingScheme(); - LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), - mapping_scheme.GetThreadsPerBlock()); + // block_y_count is set to instr_groups.size(), so that each reduction group + // can be run in parallel by a different BlockIdy. + LaunchDimensions launch_dimensions( + {/*x=*/mapping_scheme.GetNumberOfBlocks(), /*y=*/instr_groups.size(), + /*z=*/1}, + {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1}); VLOG(3) << "Launch dimensions of " << unnested_hlo->name() << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks() << " - threads per block: " << mapping_scheme.GetThreadsPerBlock(); - llvm::Type* index_ty = GetIndexTypeForKernel( - unnested_hlo, launch_dimensions.launch_bound(), &b_); - EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions, - index_ty); - EmitElementFunction emit_reduction_tile = - [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, - llvm::Value* x_loc, int64 x_iter_num) { - EmitTileElementForReduction(unnested_hlo, input_shape, - output_instructions, index, reduction_info, - reducers, x_iter_num); - }; - - TilingKernelInfo tiling_kernel_info = EmitTilingKernel( - mapping_scheme, index_ty, - [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, - const string& loop_name, llvm::Value* tile_height, - llvm::Value* tile_width, KernelSupportLibrary* ksl) { - EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl, - thread_id_info, tile_height, tile_width, emit_reduction_tile); - }); - EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info, - reduce_instructions, reduction_output_shape_indices, - reducers, tiling_kernel_info); - UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b9146dd8fae..cb7c6db9445 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -519,6 +519,12 @@ class IrEmitterUnnested : public IrEmitter, absl::Span<HloComputation* const> reducers, const TilingKernelInfo& tiling_kernel_info); + // Emits code for reductions in the output_instructions. + Status EmitIRForReduction( + HloInstruction* unnested_hlo, + absl::Span<HloInstruction* const> output_instructions, + ReductionCodegenInfo* reduction_info, const Shape& input_shape); + // For each reducer, emits the shuffle-down loop to accumulate the partial // result to the global result. void EmitFullWarpShuffleDownLoopForAllReduces( diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 19fef37db7e..6c138258aa0 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -115,9 +115,8 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - return ExecuteKernelOnStream(*kernel, buffer_args, - launch_dimensions.threads_per_block(), - launch_dimensions.block_count(), params.stream); + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + params.stream); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index 3668a521ec7..3779372f2c5 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -26,8 +26,11 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), - launch_dims.threads_per_block()); + auto block_counts = launch_dims.block_counts(); + auto thread_counts = launch_dims.thread_counts_per_block(); + out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]", + block_counts.x, block_counts.y, block_counts.z, + thread_counts.x, thread_counts.y, thread_counts.z); return out; } diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index 1a5a9d618e4..dbe5a037e43 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -29,24 +29,37 @@ namespace gpu { // number of threads per block. class LaunchDimensions { public: + struct Dim3D { + int64 x, y, z; + }; + // The default constructor creates a launch dimension that indicate // single-threaded execution. - LaunchDimensions() : block_count_(1), threads_per_block_(1) {} + LaunchDimensions() + : block_counts_({1, 1, 1}), thread_counts_per_block_({1, 1, 1}) {} - LaunchDimensions(int64 block_count, int64 threads_per_block) - : block_count_(block_count), threads_per_block_(threads_per_block) {} + LaunchDimensions(int64 block_x_count, int64 thread_x_count_per_block) + : block_counts_({block_x_count, 1, 1}), + thread_counts_per_block_({thread_x_count_per_block, 1, 1}) {} - bool IsSinglethreaded() const { - return block_count_ == 1 && threads_per_block_ == 1; + LaunchDimensions(const Dim3D& block_counts, + const Dim3D& thread_counts_per_block) + : block_counts_(block_counts), + thread_counts_per_block_(thread_counts_per_block) {} + + Dim3D block_counts() const { return block_counts_; } + + Dim3D thread_counts_per_block() const { return thread_counts_per_block_; } + + int64 launch_bound() const { + return block_counts_.x * thread_counts_per_block_.x * block_counts_.y * + thread_counts_per_block_.y * block_counts_.z * + thread_counts_per_block_.z; } - int64 block_count() const { return block_count_; } - int64 threads_per_block() const { return threads_per_block_; } - int64 launch_bound() const { return block_count() * threads_per_block(); } - private: - int64 block_count_; - int64 threads_per_block_; + Dim3D block_counts_; + Dim3D thread_counts_per_block_; }; std::ostream& operator<<(std::ostream& out, diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index f9937ba77de..6b7b31e8288 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -75,7 +75,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::vector<llvm_ir::IrArray::Index> array_indices; llvm::Value* block_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x, static_cast<llvm::Instruction*>(block_id)); block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); @@ -85,16 +85,17 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // %ntid.x is currently specified as 1024. llvm::Value* thread_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x, static_cast<llvm::Instruction*>(thread_id)); thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); llvm::Value* linear_index_base = b_->CreateAdd( - b_->CreateMul(block_id, - llvm::ConstantInt::get( - index_type, launch_dimensions_.threads_per_block()), - "", - /*HasNUW=*/true, /*HasNSW=*/true), + b_->CreateMul( + block_id, + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x), + "", + /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). @@ -109,9 +110,9 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Intrinsic::assume, {b_->CreateICmpULT( linear_index_base, - llvm::ConstantInt::get(index_type, - launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x * + launch_dimensions_.block_counts().x), "linear_index_in_range")}, {}, b_); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index d7468a31377..718cc8a3697 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -209,16 +209,18 @@ StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel( Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span<const se::DeviceMemoryBase> args, - int64 threads_per_block, int64 block_count, - se::Stream* stream) { + const LaunchDimensions& dims, se::Stream* stream) { static constexpr int kKernelArgsLimit = 1024; auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>(); for (const se::DeviceMemoryBase& buf : args) { kernel_args->add_device_memory_argument(buf); } - return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block), - se::BlockDim(block_count), kernel, - *kernel_args); + auto thread_counts = dims.thread_counts_per_block(); + auto block_counts = dims.block_counts(); + return stream->parent()->Launch( + stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel, + *kernel_args); } se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) { diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 0a5e0e93a51..6696d1957b3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -71,8 +72,7 @@ StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel( // Runs loaded kernel on the stream with the provided arguments. Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span<const se::DeviceMemoryBase> args, - int64 threads_per_block, int64 block_count, - se::Stream* stream); + const LaunchDimensions& dims, se::Stream* stream); // Create GpuAsmOpts out of HloModuleConfig. se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 809b277317f..848cb77059d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -219,6 +219,28 @@ tf_cc_test( ], ) +tf_cc_test( + name = "parallel_reduction_test", + srcs = [ + "parallel_reduction_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc new file mode 100644 index 00000000000..0b4c84c17a1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc @@ -0,0 +1,187 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { + +namespace { + +class ParallelReductionTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test contains a MOF fusion and the XLA optimizer passes + // don't like this. + debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(ParallelReductionTest, TwoParallelReductions) { + const char* hlo_text = R"( +HloModule TwoParallelReductions + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + %constant0 = f32[] constant(0) + %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) +} + +ENTRY %cluster { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + ROOT %fusion = (f32[], f32[]) + fusion(%param0, %param1), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK-NOT: reduce-group-2 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ManyParallelReductions) { + auto module = CreateNewVerifiedModule(); + // Simply use a number not too large to avoid long compilation time + // and not too small for meaningful test. + const size_t num_reduces = 32; + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + Shape input_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloComputation* fusion_computation; + { + auto fusion_builder = HloComputation::Builder("fusion_computation"); + std::vector<HloInstruction*> outputs; + HloInstruction* constant = fusion_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))); + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + auto output = fusion_builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, param, constant, {0}, reduce_computation)); + outputs.push_back(output); + } + fusion_builder.AddInstruction(HloInstruction::CreateTuple(outputs)); + fusion_computation = module->AddEmbeddedComputation(fusion_builder.Build()); + } + + HloComputation::Builder b(TestName()); + std::vector<HloInstruction*> entry_params; + std::vector<Shape> output_shapes; + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = b.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + entry_params.push_back(param); + output_shapes.push_back(output_shape); + } + b.AddInstruction(HloInstruction::CreateFusion( + ShapeUtil::MakeTupleShape(output_shapes), + HloInstruction::FusionKind::kInput, entry_params, fusion_computation)); + module->AddEntryComputation(b.Build()); + + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ThreeReductionGroups) { + const char* hlo_text = R"( +HloModule ThreeReductionGroups + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + %constant0 = f32[] constant(0) + // %mul0, %reduce0, and %reduce1 should go into a group. + %broadcast0 = f32[1024,128] broadcast(%constant0), dimensions={} + %mul0 = f32[1024,128] multiply(param0, broadcast0) + %reduce0 = f32[128] reduce(%mul0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce1 = f32[128] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce2 and %reduce3 should go into another group. + %reduce2 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce3 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce4 and %mul2 should go into the other group, although broadcast0 is + // reused. + %mul1 = f32[1024,128] multiply(param2, broadcast0) + %reduce4 = f32[128] reduce(%mul1, %constant0), dimensions={0}, to_apply=%add_f32 + %mul2 = f32[1024,128] multiply(param2, param2) + ROOT %tuple = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + tuple(%mul2, %reduce0, %reduce4, %reduce3, %reduce2, %reduce1, %mul0) +} + +ENTRY %cluster { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + ROOT %fusion = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + fusion(%param0, %param1, %param2), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK: reduce-group-2 +CHECK-NOT: reduce-group-3 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/xla/union_find.h similarity index 100% rename from tensorflow/compiler/jit/union_find.h rename to tensorflow/compiler/xla/union_find.h