diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 65248c285f5..e9656d34bd3 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -609,7 +609,6 @@ cc_library( ":flags", ":resource_operation_safety_analysis", ":shape_inference_helpers", - ":union_find", ":xla_activity_listener", ":xla_cluster_util", "//tensorflow/cc:cc_ops", @@ -628,6 +627,7 @@ cc_library( "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -705,11 +705,6 @@ tf_cc_test( ], ) -cc_library( - name = "union_find", - hdrs = ["union_find.h"], -) - tf_cc_test( name = "deadness_analysis_test", size = "small", @@ -891,7 +886,6 @@ cc_library( ":device_util", ":flags", ":resource_operation_safety_analysis", - ":union_find", ":xla_activity_listener", ":xla_activity_proto_cc", ":xla_cluster_util", @@ -900,6 +894,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index ca29bafb8eb..51c4d4ad2d8 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_activity.pb.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" @@ -44,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 8beb47543fd..a522f119243 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -26,11 +26,11 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 46d2354b779..212fd1b4f94 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -32,12 +32,12 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 50ebc035404..1a91f54afc9 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -819,9 +819,9 @@ cc_library( ":frontend_attributes_util", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -847,9 +847,9 @@ cc_library( ":functionalize_control_flow_util", ":functionalize_while", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -935,9 +935,9 @@ cc_library( ":functionalize_cond", ":functionalize_control_flow_util", ":tf2xla_util", - "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:union_find", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 54abccb4cfc..452b102fade 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -25,9 +25,10 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/graph_to_functiondef.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 596fa8e8e38..2a3e35e0ffd 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -23,12 +23,12 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h" #include "tensorflow/compiler/tf2xla/functionalize_while.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index dce5efe5557..79412c4abc8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/types/optional.h" -#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/frontend_attributes_util.h" #include "tensorflow/compiler/tf2xla/functionalize_cond.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 35fa6a617f0..598112e00df 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -969,6 +969,11 @@ tf_cc_test( ], ) +cc_library( + name = "union_find", + hdrs = ["union_find.h"], +) + # ----------------------------------------------------------------------------- # This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index fa93cea05d3..21ec9924ea6 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -265,6 +265,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:union_find", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1493,6 +1494,7 @@ cc_library( hdrs = ["stream_executor_util.h"], copts = tf_copts(), deps = [ + ":launch_dimensions", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 9b192aaa8e1..10a565308de 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -613,10 +613,13 @@ static StatusOr DeviceCompare(se::Stream* stream, LaunchDimensions dim = CalculateLaunchDimensions(buffer_shape, gpu_device_info); - stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()), - se::BlockDim(dim.block_count()), *comparison_kernel, - lhs_typed, rhs_typed, static_cast(kTolerance), - buffer_size, out_param.cref()); + LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block(); + LaunchDimensions::Dim3D block_counts = dim.block_counts(); + stream->ThenLaunch( + se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), + *comparison_kernel, lhs_typed, rhs_typed, static_cast(kTolerance), + buffer_size, out_param.cref()); uint64 result = -1; CHECK_EQ(out_param->size(), sizeof(result)); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 52ff167e2cc..1ce25ae8aca 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -90,6 +90,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/union_find.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -141,7 +142,7 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::LLVMContext& llvm_context = llvm_module->getContext(); llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), - launch_dims.threads_per_block()); + launch_dims.thread_counts_per_block().x); // Our launch bounds are exact, so we can specify them as reqntidx rather than // maxntidx. nvvm_annotations_node->addOperand(llvm::MDNode::get( @@ -2991,6 +2992,28 @@ void IrEmitterUnnested::EmitPrintfWithThreadId( }); } +namespace { + +// Obtains the corresponding index of the out_instr in the outputs of the +// `unnested_hlo`. +ShapeIndex CreateShapeIndexForOutputInstruction( + const HloInstruction& unnested_hlo, const HloInstruction& out_instr) { + if (!unnested_hlo.IsMultiOutputFusion()) { + return ShapeIndex({}); + } + const auto& all_outputs = unnested_hlo.fused_expression_root()->operands(); + for (size_t i = 0; i < all_outputs.size(); ++i) { + if (all_outputs[i] == &out_instr) { + return ShapeIndex({static_cast(i)}); + } + } + LOG(FATAL) << " Fusion root does not contain output instruction; " + << " fusion: " << unnested_hlo.ToString() + << ", output instruction: " << out_instr.ToString(); +} + +} // namespace + void IrEmitterUnnested::EmitTileElementForReduction( HloInstruction* unnested_hlo, const Shape& reduction_operand_shape, absl::Span output_instructions, @@ -2998,7 +3021,6 @@ void IrEmitterUnnested::EmitTileElementForReduction( const ReductionCodegenInfo& reduction_info, absl::Span reducers, int64 x_iter_num) { VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString(); - bool returns_tuple = output_instructions.size() > 1; int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num; InlinedVector input_gens; @@ -3015,7 +3037,8 @@ void IrEmitterUnnested::EmitTileElementForReduction( for (int i = 0, e = output_instructions.size(); i != e; ++i) { const HloInstruction* inst = output_instructions[i]; - ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); + ShapeIndex idx = + CreateShapeIndexForOutputInstruction(*unnested_hlo, *inst); if (IsReductionFromOrToContiguousDimensions(*inst)) { input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); } else { @@ -3748,71 +3771,41 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo( reduction_dimensions.is_row_reduction); } -Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( +void IrEmitterUnnested::EmitIRForReduction( HloInstruction* unnested_hlo, - absl::Span output_instructions) { - bool returns_tuple = output_instructions.size() > 1; - VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); - + absl::Span output_instructions, + ReductionCodegenInfo* reduction_info, const Shape& input_shape) { std::vector reduce_instructions; InlinedVector reduction_output_shape_indices; InlinedVector reducers; - - // Build an initializer thunk to initialize each reduction output. - std::vector> thunks; - for (int i = 0; i < output_instructions.size(); ++i) { + for (size_t i = 0; i < output_instructions.size(); ++i) { if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { continue; } HloInstruction* output_instruction = output_instructions[i]; reduce_instructions.push_back(output_instruction); - ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); - reduction_output_shape_indices.push_back(idx); + reduction_output_shape_indices.push_back( + CreateShapeIndexForOutputInstruction(*unnested_hlo, + *output_instruction)); reducers.push_back(output_instruction->to_apply()); - - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(unnested_hlo, idx)); - thunks.push_back(std::move(initializer_thunk)); } + CHECK(reduce_instructions.size() != 0) + << " expect at least one reduce instructions."; - const HloInstruction* first_reduce = reduce_instructions.at(0); - if (output_instructions.size() > 1) { - if (!AreFusedReductionOutputsConsistent(output_instructions, - first_reduce)) { - return InternalError("Inconsistent reduction fusion outputs"); - } - } - - // Build a kernel thunk to compute all the outputs. - std::unique_ptr kernel_thunk = - BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); - - const Shape& input_shape = first_reduce->operand(0)->shape(); - // The layout of a reduction input is either set by LayoutAssignment for - // unnested kReduce or by InstructionFusion for fused kReduce. - CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " - "doesn't set the input layout of " - << first_reduce->ToString(); - - ReductionCodegenInfo reduction_info = - ComputeReductionCodegenInfo(unnested_hlo, first_reduce); const KernelMappingScheme& mapping_scheme = - reduction_info.GetKernelMappingScheme(); + reduction_info->GetKernelMappingScheme(); LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), mapping_scheme.GetThreadsPerBlock()); - VLOG(3) << "Launch dimensions of " << unnested_hlo->name() - << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks() - << " - threads per block: " << mapping_scheme.GetThreadsPerBlock(); llvm::Type* index_ty = GetIndexTypeForKernel( unnested_hlo, launch_dimensions.launch_bound(), &b_); - EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions, + EmitPrologueForReduction(unnested_hlo, reduction_info, reduce_instructions, index_ty); EmitElementFunction emit_reduction_tile = [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, llvm::Value* x_loc, int64 x_iter_num) { EmitTileElementForReduction(unnested_hlo, input_shape, - output_instructions, index, reduction_info, + output_instructions, index, *reduction_info, reducers, x_iter_num); }; @@ -3821,19 +3814,180 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, KernelSupportLibrary* ksl) { - EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl, - thread_id_info, tile_height, tile_width, emit_reduction_tile); + EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name, + ksl, thread_id_info, tile_height, tile_width, + emit_reduction_tile); }); - EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info, + EmitEpilogueForReduction(index_ty, unnested_hlo, *reduction_info, reduce_instructions, reduction_output_shape_indices, reducers, tiling_kernel_info); +} +namespace { + +// Returns whether the `instr` is either a constant, a scalar, or a +// broadcasted constant/scalar. +bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) { + return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) || + (HloOpcode::kBroadcast == instr.opcode() && + (instr.operand(0)->IsConstant() || + ShapeUtil::IsScalar(instr.operand(0)->shape()))); +} + +// Divides output_instructions into groups. Different groups will be executed +// in parallel. Generally speaking, we'd like to run the reduce instructions +// in parallel without incurring too much recomputation overhead. The current +// heuristic is to place reduce instructions who share nothing or only +// (broadcasted) scalars/constants into different groups; otherwise, they are +// placed in the same group. Non-reduce instructions always go with the reduce +// instructions into the same group so long as they share any predecessors. +std::vector> DivideOutputInstructionsIntoGroups( + HloInstruction* unnested_hlo, + absl::Span output_instructions) { + CHECK(!output_instructions.empty()); + if (output_instructions.size() == 1) { + return {{output_instructions[0]}}; + } + + std::vector> disjoint_sets( + output_instructions.size()); + for (size_t i = 0; i < output_instructions.size(); ++i) { + disjoint_sets[i].Get() = output_instructions[i]; + } + + std::unique_ptr reachability_map = + HloReachabilityMap::Build(unnested_hlo->fused_instructions_computation()); + for (auto* instr : unnested_hlo->fused_instructions()) { + std::vector reached_output_ids; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + if (HloOpcode::kReduce == output_instructions[oid]->opcode() && + (IsBroadcastedConstantOrScalar(*instr))) { + // Do not group output reduce instructions through broadcasted + // constants or scalars, as the recomputation should be acceptable. + VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString(); + continue; + } + // Now group output instructions if they have common predecessors. + if (reachability_map->IsReachable(instr, output_instructions[oid])) { + VLOG(3) << "Reaching " << output_instructions[oid]->ToString() + << " from " << instr->ToString(); + reached_output_ids.push_back(oid); + } + } + for (size_t j = 1; j < reached_output_ids.size(); ++j) { + disjoint_sets[reached_output_ids[0]].Merge( + &disjoint_sets[reached_output_ids[j]]); + } + } + // Place output instructions in the same set into the same group. + absl::flat_hash_map> groups; + for (size_t oid = 0; oid < output_instructions.size(); ++oid) { + groups[disjoint_sets[oid].Get()].push_back(output_instructions.at(oid)); + } + + std::vector> ret; + absl::c_for_each( + groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); }); + return ret; +} + +} // namespace + +Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions( + HloInstruction* unnested_hlo, + absl::Span output_instructions) { + bool returns_tuple = output_instructions.size() > 1; + VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString(); + + // Build an initializer thunk to initialize each reduction output. + std::vector> thunks; + for (int i = 0; i < output_instructions.size(); ++i) { + if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + continue; + } + + ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({}); + TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, + BuildInitializerThunk(unnested_hlo, idx)); + thunks.push_back(std::move(initializer_thunk)); + } + + // Build a kernel thunk to compute all the outputs. + const HloInstruction* first_reduce = nullptr; + for (int i = 0; i < output_instructions.size(); ++i) { + if (IsReductionFromOrToContiguousDimensions(*output_instructions[i])) { + first_reduce = output_instructions[i]; + break; + } + } + CHECK(first_reduce); + if (output_instructions.size() > 1) { + if (!AreFusedReductionOutputsConsistent(output_instructions, + first_reduce)) { + return InternalError("Inconsistent reduction fusion outputs"); + } + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + // The layout of a reduction input is either set by LayoutAssignment for + // unnested kReduce or by InstructionFusion for fused kReduce. + CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion " + "doesn't set the input layout of " + << first_reduce->ToString(); + + // Group output instructions. Each group will be executed in parallel. + std::vector> instr_groups = + DivideOutputInstructionsIntoGroups(unnested_hlo, output_instructions); + VLOG(2) << StrCat("Generate in ", instr_groups.size(), " groups for ", + unnested_hlo->ToString()); + std::unique_ptr kernel_thunk = + BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false); + KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); + for (size_t i = 0; i < instr_groups.size(); ++i) { + // Create a new ReductionCodegenInfo instance as it contains states for + // code generation per reduction group. For now, let's always use the very + // first reduce as representative to construct ReductionCodegenInfo, since + // all the reductions are required to have the same shape and layout as + // verified by `AreFusedReductionOutputsConsistent()`. We can loosen the + // constraint later when the needs arise. + ReductionCodegenInfo reduction_info = + ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + auto emit_reduction_func = [&] { + EmitIRForReduction(unnested_hlo, instr_groups[i], &reduction_info, + input_shape); + }; + // Use raw block_id_y to select the i-th parallel reduction to run. Using + // block_id_y instead of block_id_x simplifies the index calculation + // for reduction code generation as the block_id_y is orthogonal to + // the indices used within the reductions. + llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic( + gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_); + llvm_ir::AddRangeMetadata(0, instr_groups.size(), + llvm::cast(raw_block_id_y)); + llvm::Value* guarding_cond = + b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)); + ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func); + } + ReductionCodegenInfo reduction_info = + ComputeReductionCodegenInfo(unnested_hlo, first_reduce); + const KernelMappingScheme& mapping_scheme = + reduction_info.GetKernelMappingScheme(); + // block_y_count is set to instr_groups.size(), so that each reduction group + // can be run in parallel by a different BlockIdy. + LaunchDimensions launch_dimensions( + {/*x=*/mapping_scheme.GetNumberOfBlocks(), + /*y=*/static_cast(instr_groups.size()), + /*z=*/1}, + {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1}); + VLOG(3) << "Launch dimensions of " << unnested_hlo->name() + << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks() + << " - threads per block: " << mapping_scheme.GetThreadsPerBlock(); UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), ir_emitter_context_->llvm_module()); thunks.push_back(std::move(kernel_thunk)); - auto sequential_thunk = absl::make_unique( - GetThunkInfo(unnested_hlo), std::move(thunks)); + std::unique_ptr sequential_thunk = + absl::make_unique(GetThunkInfo(unnested_hlo), + std::move(thunks)); AddThunkToThunkSequence(std::move(sequential_thunk)); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index c2955689f98..a637c865525 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -372,6 +372,16 @@ class IrEmitterUnnested : public IrEmitter, // } // ``` // + // Moreover, a heuristic is implemented to divide the reduce instructions + // into groups for parallelization (see `DivideOutputInstructionsIntoGroups` + // for details about the heuristic.) Reduce instructions in the same group + // will run sequentially while different groups will run in parallel. + // + // we use raw block_id_y to select the reduce groups for execution without + // complicating the index calculation in the code generation of the reduce + // instructions. In other words, a block_id_y is assigned to a group and so + // different groups can be run in parallel. + // // output_instructions: Output instructions in the computation: instruction // itself if it's not a fusion, fusion root if fusion is not multi-output, and // elements of the fusion multi-output tuple otherwise. @@ -518,6 +528,12 @@ class IrEmitterUnnested : public IrEmitter, absl::Span reducers, const TilingKernelInfo& tiling_kernel_info); + // Emits code for reductions in the output_instructions. + void EmitIRForReduction(HloInstruction* unnested_hlo, + absl::Span output_instructions, + ReductionCodegenInfo* reduction_info, + const Shape& input_shape); + // For each reducer, emits the shuffle-down loop to accumulate the partial // result to the global result. void EmitFullWarpShuffleDownLoopForAllReduces( diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 19fef37db7e..6c138258aa0 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -115,9 +115,8 @@ Status KernelThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - return ExecuteKernelOnStream(*kernel, buffer_args, - launch_dimensions.threads_per_block(), - launch_dimensions.block_count(), params.stream); + return ExecuteKernelOnStream(*kernel, buffer_args, launch_dimensions, + params.stream); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc index 3668a521ec7..c23e8112cb0 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc @@ -26,8 +26,11 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), - launch_dims.threads_per_block()); + LaunchDimensions::Dim3D block_counts = launch_dims.block_counts(); + LaunchDimensions::Dim3D thread_counts = launch_dims.thread_counts_per_block(); + out << absl::StrFormat("[block: {%d, %d, %d}, thread: {%d, %d, %d}]", + block_counts.x, block_counts.y, block_counts.z, + thread_counts.x, thread_counts.y, thread_counts.z); return out; } diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h index 1a5a9d618e4..dbe5a037e43 100644 --- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h +++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h @@ -29,24 +29,37 @@ namespace gpu { // number of threads per block. class LaunchDimensions { public: + struct Dim3D { + int64 x, y, z; + }; + // The default constructor creates a launch dimension that indicate // single-threaded execution. - LaunchDimensions() : block_count_(1), threads_per_block_(1) {} + LaunchDimensions() + : block_counts_({1, 1, 1}), thread_counts_per_block_({1, 1, 1}) {} - LaunchDimensions(int64 block_count, int64 threads_per_block) - : block_count_(block_count), threads_per_block_(threads_per_block) {} + LaunchDimensions(int64 block_x_count, int64 thread_x_count_per_block) + : block_counts_({block_x_count, 1, 1}), + thread_counts_per_block_({thread_x_count_per_block, 1, 1}) {} - bool IsSinglethreaded() const { - return block_count_ == 1 && threads_per_block_ == 1; + LaunchDimensions(const Dim3D& block_counts, + const Dim3D& thread_counts_per_block) + : block_counts_(block_counts), + thread_counts_per_block_(thread_counts_per_block) {} + + Dim3D block_counts() const { return block_counts_; } + + Dim3D thread_counts_per_block() const { return thread_counts_per_block_; } + + int64 launch_bound() const { + return block_counts_.x * thread_counts_per_block_.x * block_counts_.y * + thread_counts_per_block_.y * block_counts_.z * + thread_counts_per_block_.z; } - int64 block_count() const { return block_count_; } - int64 threads_per_block() const { return threads_per_block_; } - int64 launch_bound() const { return block_count() * threads_per_block(); } - private: - int64 block_count_; - int64 threads_per_block_; + Dim3D block_counts_; + Dim3D thread_counts_per_block_; }; std::ostream& operator<<(std::ostream& out, diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index f9937ba77de..6b7b31e8288 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -75,7 +75,7 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, std::vector array_indices; llvm::Value* block_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_counts().x, static_cast(block_id)); block_id = b_->CreateZExtOrTrunc(block_id, index_type, "block_id"); @@ -85,16 +85,17 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // %ntid.x is currently specified as 1024. llvm::Value* thread_id = EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b_); - llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(), + llvm_ir::AddRangeMetadata(0, launch_dimensions_.thread_counts_per_block().x, static_cast(thread_id)); thread_id = b_->CreateZExtOrTrunc(thread_id, index_type, "thread_id"); llvm::Value* linear_index_base = b_->CreateAdd( - b_->CreateMul(block_id, - llvm::ConstantInt::get( - index_type, launch_dimensions_.threads_per_block()), - "", - /*HasNUW=*/true, /*HasNSW=*/true), + b_->CreateMul( + block_id, + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x), + "", + /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). @@ -109,9 +110,9 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Intrinsic::assume, {b_->CreateICmpULT( linear_index_base, - llvm::ConstantInt::get(index_type, - launch_dimensions_.threads_per_block() * - launch_dimensions_.block_count()), + llvm::ConstantInt::get( + index_type, launch_dimensions_.thread_counts_per_block().x * + launch_dimensions_.block_counts().x), "linear_index_in_range")}, {}, b_); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index d7468a31377..8ea7c57c978 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -209,16 +209,18 @@ StatusOr> CreateKernel( Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span args, - int64 threads_per_block, int64 block_count, - se::Stream* stream) { + const LaunchDimensions& dims, se::Stream* stream) { static constexpr int kKernelArgsLimit = 1024; auto kernel_args = absl::make_unique>(); for (const se::DeviceMemoryBase& buf : args) { kernel_args->add_device_memory_argument(buf); } - return stream->parent()->Launch(stream, se::ThreadDim(threads_per_block), - se::BlockDim(block_count), kernel, - *kernel_args); + LaunchDimensions::Dim3D thread_counts = dims.thread_counts_per_block(); + LaunchDimensions::Dim3D block_counts = dims.block_counts(); + return stream->parent()->Launch( + stream, se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z), + se::BlockDim(block_counts.x, block_counts.y, block_counts.z), kernel, + *kernel_args); } se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config) { diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h index 0a5e0e93a51..6696d1957b3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -71,8 +72,7 @@ StatusOr> CreateKernel( // Runs loaded kernel on the stream with the provided arguments. Status ExecuteKernelOnStream(const se::KernelBase& kernel, absl::Span args, - int64 threads_per_block, int64 block_count, - se::Stream* stream); + const LaunchDimensions& dims, se::Stream* stream); // Create GpuAsmOpts out of HloModuleConfig. se::GpuAsmOpts PtxOptsFromConfig(const HloModuleConfig& hlo_module_config); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 02b4b807323..f6e3e965166 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -219,6 +219,28 @@ tf_cc_test( ], ) +tf_cc_test( + name = "parallel_reduction_test", + srcs = [ + "parallel_reduction_test.cc", + ], + tags = tf_cuda_tests_tags() + ["no_rocm"], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc new file mode 100644 index 00000000000..06e547dfe34 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/parallel_reduction_test.cc @@ -0,0 +1,190 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { + +namespace { + +class ParallelReductionTest : public GpuCodegenTest { + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + // The test contains a MOF fusion and the XLA optimizer passes + // don't like this. + debug_options.set_xla_disable_all_hlo_passes(true); + return debug_options; + } +}; + +TEST_F(ParallelReductionTest, TwoParallelReductions) { + const char* hlo_text = R"( +HloModule TwoParallelReductions + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + %constant0 = f32[] constant(0) + %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) +} + +ENTRY %cluster { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + ROOT %fusion = (f32[], f32[]) + fusion(%param0, %param1), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK-NOT: reduce-group-2 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ManyParallelReductions) { + std::unique_ptr module = CreateNewVerifiedModule(); + // Simply use a number not too large to avoid long compilation time + // and not too small for meaningful test. + const size_t num_reduces = 32; + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + HloInstruction* lhs = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + HloInstruction* rhs = + embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + Shape input_shape = ShapeUtil::MakeShape(F32, {1024}); + Shape output_shape = ShapeUtil::MakeShape(F32, {}); + HloComputation* fusion_computation; + { + auto fusion_builder = HloComputation::Builder("fusion_computation"); + std::vector outputs; + HloInstruction* constant = fusion_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = fusion_builder.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + HloInstruction* output = + fusion_builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, param, constant, {0}, reduce_computation)); + outputs.push_back(output); + } + fusion_builder.AddInstruction(HloInstruction::CreateTuple(outputs)); + fusion_computation = module->AddEmbeddedComputation(fusion_builder.Build()); + } + + HloComputation::Builder b(TestName()); + std::vector entry_params; + std::vector output_shapes; + for (size_t i = 0; i < num_reduces; ++i) { + HloInstruction* param = b.AddInstruction( + HloInstruction::CreateParameter(i, input_shape, "param")); + entry_params.push_back(param); + output_shapes.push_back(output_shape); + } + b.AddInstruction(HloInstruction::CreateFusion( + ShapeUtil::MakeTupleShape(output_shapes), + HloInstruction::FusionKind::kInput, entry_params, fusion_computation)); + module->AddEntryComputation(b.Build()); + + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +TEST_F(ParallelReductionTest, ThreeReductionGroups) { + const char* hlo_text = R"( +HloModule ThreeReductionGroups + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + %constant0 = f32[] constant(0) + // %mul0, %reduce0, and %reduce1 should go into a group. + %broadcast0 = f32[1024,128] broadcast(%constant0), dimensions={} + %mul0 = f32[1024,128] multiply(param0, broadcast0) + %reduce0 = f32[128] reduce(%mul0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce1 = f32[128] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce2 and %reduce3 should go into another group. + %reduce2 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce3 = f32[128] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + // %reduce4 and %mul2 should go into the other group, although broadcast0 is + // reused. + %mul1 = f32[1024,128] multiply(param2, broadcast0) + %reduce4 = f32[128] reduce(%mul1, %constant0), dimensions={0}, to_apply=%add_f32 + %mul2 = f32[1024,128] multiply(param2, param2) + ROOT %tuple = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + tuple(%mul2, %reduce0, %reduce4, %reduce3, %reduce2, %reduce1, %mul0) +} + +ENTRY %cluster { + %param0 = f32[1024,128] parameter(0) + %param1 = f32[1024,128] parameter(1) + %param2 = f32[1024,128] parameter(2) + ROOT %fusion = + (f32[1024, 128], f32[128], f32[128], f32[128], f32[128], f32[128], f32[1024, 128]) + fusion(%param0, %param1, %param2), kind=kInput, calls=%fused_computation +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(hlo_text)); + CompileAndVerifyIr(std::move(hlo_module), + R"( +CHECK: reduce-group-0 +CHECK: reduce-group-1 +CHECK: reduce-group-2 +CHECK-NOT: reduce-group-3 +)", + /*match_optimized_ir=*/false); + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/jit/union_find.h b/tensorflow/compiler/xla/union_find.h similarity index 100% rename from tensorflow/compiler/jit/union_find.h rename to tensorflow/compiler/xla/union_find.h