From 134dcd13020ff079572ff33218a9bcd92132655c Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Mon, 6 Apr 2020 22:04:55 -0700
Subject: [PATCH] [XLA:CPU] CollectivePermute support on CPU

Dummy implementation: the primary thread performs all the work.

PiperOrigin-RevId: 305184516
Change-Id: Ib4af0b7fda920fe08b551cb0782884ba92947ba7
---
 .../xla/service/collective_ops_utils.h        |  35 ++++
 .../compiler/xla/service/cpu/cpu_runtime.cc   | 193 +++++++++++++-----
 .../compiler/xla/service/cpu/cpu_runtime.h    |   7 +
 .../compiler/xla/service/cpu/ir_emitter.cc    |  59 ++++++
 .../compiler/xla/service/cpu/ir_emitter.h     |   1 +
 .../xla/service/cpu/simple_orc_jit.cc         |   1 +
 .../compiler/xla/tests/collective_ops_test.cc |   2 +-
 7 files changed, 247 insertions(+), 51 deletions(-)

diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index 3dd62bc1afa..8ca0e87bc88 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -189,6 +189,30 @@ struct AllReduceParticipantData {
   }
 };
 
+struct CollectivePermuteParticipantData {
+  explicit CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key)
+      : rendezvous_key(rendezvous_key) {}
+
+  RendezvousKey rendezvous_key;
+
+  int64 device_ordinal;
+  int replica_id;
+  se::DeviceMemoryBase source_data;
+  se::DeviceMemoryBase destination_data;
+  int64 byte_size;
+  se::Stream* stream;
+  std::vector<int> replica_ids_to_copy_to;
+
+  string ToString() const {
+    return absl::StrFormat(
+        "CollectivePermuteParticipantData{replica_id=%d, "
+        "source_data=%p, destination_data=%p, byte_size=%d, "
+        "replica_ids_to_copy_to=[%s]}",
+        replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
+        absl::StrJoin(replica_ids_to_copy_to, ", "));
+  }
+};
+
 // The set of threads that want to do a collective op together all pick the same
 // Rendezvous object out of the global cache and call SubmitParticipant.
 //
@@ -243,6 +267,17 @@ class Rendezvous {
   virtual StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
       const I& participant) = 0;
 
+  // Initializes the rendezvous for the first ("primary") thread that reaches
+  // the barrier. Returns whether this thread is primary.
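+  // Callers typically invoke this at the top of SubmitParticipantImpl() so
+  // that exactly one participant thread performs the collective's work.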
+  bool InitializationBarrier() {
+    tensorflow::mutex_lock lock(mu_);
+    if (!initialized_) {
+      initialized_ = true;
+      return true;
+    }
+    return false;
+  }
+
   virtual void CleanupImpl(O handle, bool is_primary) {}
 
   tensorflow::mutex mu_;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index d6f64828c32..bd949aa24c7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -114,6 +114,8 @@ extern const char* const kTracingStartSymbolName =
 extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
+extern const char* const kCollectivePermuteSymbolName =
+    "__xla_cpu_runtime_CollectivePermute";
 extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
 
 }  // namespace runtime
@@ -254,6 +256,50 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
 
 namespace {
 
+class CpuCollectivePermuteRendezvous
+    : public xla::Rendezvous<xla::CollectivePermuteParticipantData,
+                             std::nullptr_t> {
+ public:
+  explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
+      : xla::Rendezvous<xla::CollectivePermuteParticipantData, std::nullptr_t>(
+            k) {}
+
+ protected:
+  xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
+      const xla::CollectivePermuteParticipantData& participant) override {
+    bool primary = InitializationBarrier();
+
+    // Perform all copies from the primary thread.
+    if (primary) {
+      tensorflow::mutex_lock lock(mu_);
+
+      std::map<int, int> replica_idx_to_participant_idx;
+      for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
+        replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
+      }
+
+      for (auto& p : participants_) {
+        for (int dest_replica : p.replica_ids_to_copy_to) {
+          auto& dest_p = participants_[xla::FindOrDie(
+              replica_idx_to_participant_idx, dest_replica)];
+          std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
+                      p.byte_size);
+
+          // Each replica may be copied into only once.
+          replica_idx_to_participant_idx.erase(dest_replica);
+        }
+      }
+
+      // Zero out untouched participants.
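+      // A replica that is not named as a destination by any source-target
+      // pair stays in the map, so its output buffer is filled with zeros
+      // rather than left with stale data.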
+      for (auto& replica_p : replica_idx_to_participant_idx) {
+        auto& p = participants_[replica_p.second];
+        std::memset(p.destination_data.opaque(), 0, p.byte_size);
+      }
+    }
+    return ParticipantImplOutput{primary, /*custom_output=*/nullptr};
+  }
+};
+
 class CpuAllReduceRendezvous
     : public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
  public:
@@ -264,14 +310,7 @@ class CpuAllReduceRendezvous
   xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
       const xla::AllReduceParticipantData& participant) override {
     xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
-    bool primary = [&] {
-      tensorflow::mutex_lock lock(mu_);
-      if (!initialized_) {
-        initialized_ = true;
-        return true;
-      }
-      return false;
-    }();
+    bool primary = InitializationBarrier();
 
     if (primary) {
       switch (datatype) {
@@ -406,12 +445,55 @@ class CpuAllReduceRendezvous
 };
 
 xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
-GlobalRendezvousMap() {
+GlobalAllReduceRendezvousMap() {
   static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
                                                 CpuAllReduceRendezvous>;
   return m;
 }
 
+xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
+GlobalCollectivePermuteRendezvousMap() {
+  static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
+                                                CpuCollectivePermuteRendezvous>;
+  return m;
+}
+
+int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
+  if (run_options->stream()) {
+    return run_options->stream()->parent()->device_ordinal();
+  } else {
+    return run_options->device_ordinal();
+  }
+}
+
+xla::RendezvousKey GetRendezvousKey(
+    const xla::ExecutableRunOptions* run_options,
+    std::vector<xla::ReplicaGroup> group, xla::int32 channel_id_present,
+    xla::int64 op_id) {
+  const xla::DeviceAssignment& device_assignment =
+      *run_options->device_assignment();
+  xla::int32 replica_count = device_assignment.replica_count();
+  int device_ordinal = GetDeviceOrdinal(run_options);
+  CHECK_EQ(device_assignment.computation_count(), 1);
+  std::vector<xla::int64> participating_replicas =
+      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
+                                    replica_count,
+                                    *run_options->device_assignment())
+          .ValueOrDie();
+  xla::RendezvousKey::CollectiveOpKind op_kind =
+      channel_id_present ? xla::RendezvousKey::kCrossModule
+                         : xla::RendezvousKey::kCrossReplica;
+  std::vector<xla::GlobalDeviceId> participating_devices;
+  participating_devices.reserve(participating_replicas.size());
+  for (xla::int64 replica : participating_replicas) {
+    participating_devices.push_back(
+        xla::GlobalDeviceId(device_assignment(replica, 0)));
+  }
+  return xla::RendezvousKey{
+      run_options->run_id(), std::move(participating_devices),
+      static_cast<int>(participating_replicas.size()), op_kind, op_id};
+}
+
 }  // namespace
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
@@ -420,42 +502,13 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
     xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
     const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
     void** input_buffers, void** output_buffers) {
+  int device_ordinal = GetDeviceOrdinal(run_options);
   absl::string_view replica_groups_serialized(
       static_cast<const char*>(replica_groups_str), replica_groups_str_size);
-
-  // FIXME(cheshire): avoid repetition w/__xla_cpu_runtime_ReplicaId.
-  int device_ordinal = [&] {
-    if (run_options->stream()) {
-      return run_options->stream()->parent()->device_ordinal();
-    } else {
-      return run_options->device_ordinal();
-    }
-  }();
-
   std::vector<xla::ReplicaGroup> group =
       xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
-  const xla::DeviceAssignment& device_assignment =
-      *run_options->device_assignment();
-  xla::int32 replica_count = device_assignment.replica_count();
-  CHECK_EQ(device_assignment.computation_count(), 1);
-  std::vector<xla::int64> participating_replicas =
-      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
-                                    replica_count,
-                                    *run_options->device_assignment())
-          .ValueOrDie();
-
-  xla::RendezvousKey::CollectiveOpKind op_kind =
-      channel_id_present ? xla::RendezvousKey::kCrossModule
-                         : xla::RendezvousKey::kCrossReplica;
-  std::vector<xla::GlobalDeviceId> participating_devices;
-  participating_devices.reserve(participating_replicas.size());
-  for (xla::int64 replica : participating_replicas) {
-    participating_devices.push_back(
-        xla::GlobalDeviceId(device_assignment(replica, 0)));
-  }
-  xla::RendezvousKey rendezvous_key(
-      run_options->run_id(), std::move(participating_devices),
-      participating_replicas.size(), op_kind, op_id);
+  xla::RendezvousKey rendezvous_key =
+      GetRendezvousKey(run_options, group, channel_id_present, op_id);
 
   auto shape_str = ShapeString(shape_ptr, shape_length);
   VLOG(2) << "All-reduce input/output shape : " << shape_str;
@@ -487,7 +540,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
 
   TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
                   [&] {
-                    return GlobalRendezvousMap().GetOrCreateIfAbsent(
+                    return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
                         rendezvous_key, make_cpu_rendezvous);
                   },
                   participant)
@@ -496,16 +549,56 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer) {
-  int device_ordinal = [&]() {
-    if (run_options->stream()) {
-      return run_options->stream()->parent()->device_ordinal();
-    } else {
-      return run_options->device_ordinal();
-    }
-  }();
-
+  int device_ordinal = GetDeviceOrdinal(run_options);
   xla::int32 replica_id = run_options->device_assignment()
                               ->ReplicaIdForDeviceOrdinal(device_ordinal)
                               .ValueOrDie();
   std::memcpy(output_buffer, &replica_id, 4);
 }
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
+    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
+    xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
+    void* output_buffer, const void* source_target_pairs,
+    xla::int32 source_target_pairs_size) {
+  int device_ordinal = GetDeviceOrdinal(run_options);
+  absl::string_view source_target_pairs_serialized(
+      static_cast<const char*>(source_target_pairs), source_target_pairs_size);
+  auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
+  xla::int32 replica_id = run_options->device_assignment()
+                              ->ReplicaIdForDeviceOrdinal(device_ordinal)
+                              .ValueOrDie();
+  std::vector<int> copy_to;
+  for (auto& p : pairs) {
+    std::vector<std::string> mapping = absl::StrSplit(p, '=');
+    CHECK_EQ(mapping.size(), 2);
+    int from = std::stoi(mapping[0]);
+    int to = std::stoi(mapping[1]);
+    if (from == replica_id) {
+      copy_to.push_back(to);
+    }
+  }
+  xla::RendezvousKey rendezvous_key =
+      GetRendezvousKey(run_options, {}, channel_id_present, op_id);
+
+  xla::CollectivePermuteParticipantData participant(rendezvous_key);
+  participant.replica_id = replica_id;
+  participant.device_ordinal = device_ordinal;
+  participant.stream = run_options->stream();
+  participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
+  participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
+  participant.replica_ids_to_copy_to = copy_to;
+  participant.byte_size = byte_size;
+
+  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
+    return absl::make_unique<CpuCollectivePermuteRendezvous>(k);
+  };
+  TF_CHECK_OK(
+      CpuCollectivePermuteRendezvous::SubmitParticipant(
+          [&] {
+            return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
+                rendezvous_key, make_cpu_rendezvous);
+          },
+          participant)
+          .status());
+}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 6af41dea484..14ea5448eef 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -69,6 +69,7 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
 extern const char* const kParallelForkJoinSymbolName;
 extern const char* const kKeyValueSortSymbolName;
 extern const char* const kAllReduceSymbolName;
+extern const char* const kCollectivePermuteSymbolName;
 extern const char* const kReplicaIdSymbolName;
 extern const char* const kTracingStartSymbolName;
 extern const char* const kTracingEndSymbolName;
@@ -170,6 +171,12 @@ extern void __xla_cpu_runtime_AllReduce(
     const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
     void** input_buffers, void** output_buffers);
 
+extern void __xla_cpu_runtime_CollectivePermute(
+    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
+    xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
+    void* output_buffer, const void* source_target_pairs,
+    xla::int32 source_target_pairs_size);
+
 // Write the replica ID into the output buffer.
 extern void __xla_cpu_runtime_ReplicaId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index cef45128ea0..f4549ac9f3b 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
 #include "absl/types/span.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
@@ -1540,6 +1541,64 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
   return HandleAllReduceMultipleReplica(crs);
 }
 
+Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
+  auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
+  std::string source_target_pairs = absl::StrJoin(
+      instr->source_target_pairs(), ",", absl::PairFormatter("="));
+  llvm::Value* source_target_pairs_v =
+      b_.CreateGlobalStringPtr(source_target_pairs);
+
+  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+  llvm::Type* int32_type = b_.getInt32Ty();
+  llvm::Type* int64_type = b_.getInt64Ty();
+  llvm::FunctionType* collective_permute_func_ty =
+      llvm::FunctionType::get(b_.getVoidTy(),
+                              {
+                                  /*run_options=*/i8_ptr_type,
+                                  /*channel_id_present=*/int32_type,
+                                  /*op_id=*/int64_type,
+                                  /*byte_size=*/int32_type,
+                                  /*input_buffer=*/i8_ptr_type,
+                                  /*output_buffer=*/i8_ptr_type,
+                                  /*source_target_pairs=*/i8_ptr_type,
+                                  /*source_target_pairs_size=*/int32_type,
+                              },
+                              /*isVarArg=*/false);
+
+  auto collective_permute_func = llvm::dyn_cast<llvm::Function>(
+      module_
+          ->getOrInsertFunction(runtime::kCollectivePermuteSymbolName,
+                                collective_permute_func_ty)
+          .getCallee());
+  collective_permute_func->setCallingConv(llvm::CallingConv::C);
+
+  Shape shape = crs->operand(0)->shape();
+
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
+                      assignment_.GetUniqueSlice(crs->operand(0), {}));
+  llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
+
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
+                      assignment_.GetUniqueSlice(crs, {}));
+  llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
+
+  Call(collective_permute_func,
+       {/*run_options=*/GetExecutableRunOptionsArgument(),
+        /*channel_id_present=*/
+        b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
+        /*op_id=*/
+        b_.getInt64(crs->channel_id().has_value()
+                        ? *crs->channel_id()
+                        : crs->GetModule()->unique_id()),
+        /*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
+        /*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
+        /*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
+        /*source_target_pairs=*/source_target_pairs_v,
+        /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())});
+
+  return Status::OK();
+}
+
 Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
   TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 95458ba05a4..cc5aa3f37fc 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -155,6 +155,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleFft(HloInstruction* fft) override;
   Status HandleAllReduce(HloInstruction* crs) override;
+  Status HandleCollectivePermute(HloInstruction* crs) override;
   Status HandleInfeed(HloInstruction* infeed) override;
   Status HandleOutfeed(HloInstruction* outfeed) override;
   Status HandleSort(HloInstruction* sort) override;
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index a43007aba73..153bd572eba 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -237,6 +237,7 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
   REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
+  REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
   REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
   REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index 380486357f7..f5466c632ac 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -560,7 +560,7 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
   }
 }
 
-XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
+XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
   const char* const kModuleStr = R"(
   HloModule test
   ENTRY test_computation {