[XLA:CPU] CollectivePermute support on CPU
Dummy implementation, main thread performs all the work. PiperOrigin-RevId: 305184516 Change-Id: Ib4af0b7fda920fe08b551cb0782884ba92947ba7
This commit is contained in:
parent
19a41461e2
commit
134dcd1302
@ -189,6 +189,30 @@ struct AllReduceParticipantData {
|
||||
}
|
||||
};
|
||||
|
||||
struct CollectivePermuteParticipantData {
|
||||
explicit CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key)
|
||||
: rendezvous_key(rendezvous_key) {}
|
||||
|
||||
RendezvousKey rendezvous_key;
|
||||
|
||||
int64 device_ordinal;
|
||||
int replica_id;
|
||||
se::DeviceMemoryBase source_data;
|
||||
se::DeviceMemoryBase destination_data;
|
||||
int64 byte_size;
|
||||
se::Stream* stream;
|
||||
std::vector<int> replica_ids_to_copy_to;
|
||||
|
||||
string ToString() const {
|
||||
return absl::StrFormat(
|
||||
"CollectivePermuteParticipantData{replica_id=%d, "
|
||||
"source_data=%p, destination_data=%p, byte_size=%d, "
|
||||
"replica_ids_to_copy_to=[%s]}",
|
||||
replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
|
||||
absl::StrJoin(replica_ids_to_copy_to, ", "));
|
||||
}
|
||||
};
|
||||
|
||||
// The set of threads that want to do a collective op together all pick the same
|
||||
// Rendezvous object out of the global cache and call SubmitParticipant.
|
||||
//
|
||||
@ -243,6 +267,17 @@ class Rendezvous {
|
||||
virtual StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
|
||||
const I& participant) = 0;
|
||||
|
||||
// Initialize the rendezvous by the first ("primary") thread which reaches the
|
||||
// barrier. Returns whether this thread is primary.
|
||||
bool InitializationBarrier() {
|
||||
tensorflow::mutex_lock lock(mu_);
|
||||
if (!initialized_) {
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual void CleanupImpl(O handle, bool is_primary) {}
|
||||
|
||||
tensorflow::mutex mu_;
|
||||
|
||||
@ -114,6 +114,8 @@ extern const char* const kTracingStartSymbolName =
|
||||
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
|
||||
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
|
||||
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
|
||||
extern const char* const kCollectivePermuteSymbolName =
|
||||
"__xla_cpu_runtime_CollectivePermute";
|
||||
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
|
||||
|
||||
} // namespace runtime
|
||||
@ -254,6 +256,50 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
|
||||
|
||||
namespace {
|
||||
|
||||
class CpuCollectivePermuteRendezvous
|
||||
: public xla::Rendezvous<xla::CollectivePermuteParticipantData,
|
||||
std::nullptr_t> {
|
||||
public:
|
||||
explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
|
||||
: xla::Rendezvous<xla::CollectivePermuteParticipantData, std::nullptr_t>(
|
||||
k) {}
|
||||
|
||||
protected:
|
||||
xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
|
||||
const xla::CollectivePermuteParticipantData& participant) override {
|
||||
bool primary = InitializationBarrier();
|
||||
|
||||
// Perform all copies from the primary thread.
|
||||
if (primary) {
|
||||
tensorflow::mutex_lock lock(mu_);
|
||||
|
||||
std::map<int, int> replica_idx_to_participant_idx;
|
||||
for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
|
||||
replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
|
||||
}
|
||||
|
||||
for (auto& p : participants_) {
|
||||
for (int dest_replica : p.replica_ids_to_copy_to) {
|
||||
auto& dest_p = participants_[xla::FindOrDie(
|
||||
replica_idx_to_participant_idx, dest_replica)];
|
||||
std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
|
||||
p.byte_size);
|
||||
|
||||
// Each replica may be copied into only once.
|
||||
replica_idx_to_participant_idx.erase(dest_replica);
|
||||
}
|
||||
}
|
||||
|
||||
// Zero out untouched participants.
|
||||
for (auto& replica_p : replica_idx_to_participant_idx) {
|
||||
auto& p = participants_[replica_p.second];
|
||||
std::memset(p.destination_data.opaque(), 0, p.byte_size);
|
||||
}
|
||||
}
|
||||
return ParticipantImplOutput{primary, /*custom_output=*/nullptr};
|
||||
}
|
||||
};
|
||||
|
||||
class CpuAllReduceRendezvous
|
||||
: public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
|
||||
public:
|
||||
@ -264,14 +310,7 @@ class CpuAllReduceRendezvous
|
||||
xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
|
||||
const xla::AllReduceParticipantData& participant) override {
|
||||
xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
|
||||
bool primary = [&] {
|
||||
tensorflow::mutex_lock lock(mu_);
|
||||
if (!initialized_) {
|
||||
initialized_ = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
bool primary = InitializationBarrier();
|
||||
|
||||
if (primary) {
|
||||
switch (datatype) {
|
||||
@ -406,12 +445,55 @@ class CpuAllReduceRendezvous
|
||||
};
|
||||
|
||||
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
|
||||
GlobalRendezvousMap() {
|
||||
GlobalAllReduceRendezvousMap() {
|
||||
static auto& m =
|
||||
*new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
|
||||
return m;
|
||||
}
|
||||
|
||||
xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
|
||||
GlobalCollectivePermuteRendezvousMap() {
|
||||
static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
|
||||
CpuCollectivePermuteRendezvous>;
|
||||
return m;
|
||||
}
|
||||
|
||||
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
|
||||
if (run_options->stream()) {
|
||||
return run_options->stream()->parent()->device_ordinal();
|
||||
} else {
|
||||
return run_options->device_ordinal();
|
||||
}
|
||||
}
|
||||
|
||||
xla::RendezvousKey GetRendezvousKey(
|
||||
const xla::ExecutableRunOptions* run_options,
|
||||
std::vector<xla::ReplicaGroup> group, xla::int32 channel_id_present,
|
||||
xla::int64 op_id) {
|
||||
const xla::DeviceAssignment& device_assignment =
|
||||
*run_options->device_assignment();
|
||||
xla::int32 replica_count = device_assignment.replica_count();
|
||||
int device_ordinal = GetDeviceOrdinal(run_options);
|
||||
CHECK_EQ(device_assignment.computation_count(), 1);
|
||||
std::vector<xla::int64> participating_replicas =
|
||||
xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
|
||||
replica_count,
|
||||
*run_options->device_assignment())
|
||||
.ValueOrDie();
|
||||
xla::RendezvousKey::CollectiveOpKind op_kind =
|
||||
channel_id_present ? xla::RendezvousKey::kCrossModule
|
||||
: xla::RendezvousKey::kCrossReplica;
|
||||
std::vector<xla::GlobalDeviceId> participating_devices;
|
||||
participating_devices.reserve(participating_replicas.size());
|
||||
for (xla::int64 replica : participating_replicas) {
|
||||
participating_devices.push_back(
|
||||
xla::GlobalDeviceId(device_assignment(replica, 0)));
|
||||
}
|
||||
return xla::RendezvousKey{
|
||||
run_options->run_id(), std::move(participating_devices),
|
||||
static_cast<int>(participating_replicas.size()), op_kind, op_id};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
||||
@ -420,42 +502,13 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
||||
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
|
||||
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
|
||||
void** input_buffers, void** output_buffers) {
|
||||
int device_ordinal = GetDeviceOrdinal(run_options);
|
||||
absl::string_view replica_groups_serialized(
|
||||
static_cast<const char*>(replica_groups_str), replica_groups_str_size);
|
||||
|
||||
// FIXME(cheshire): avoid repetition w/__xla_cpu_runtime_ReplicaId.
|
||||
int device_ordinal = [&] {
|
||||
if (run_options->stream()) {
|
||||
return run_options->stream()->parent()->device_ordinal();
|
||||
} else {
|
||||
return run_options->device_ordinal();
|
||||
}
|
||||
}();
|
||||
|
||||
std::vector<xla::ReplicaGroup> group =
|
||||
xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
|
||||
const xla::DeviceAssignment& device_assignment =
|
||||
*run_options->device_assignment();
|
||||
xla::int32 replica_count = device_assignment.replica_count();
|
||||
CHECK_EQ(device_assignment.computation_count(), 1);
|
||||
std::vector<xla::int64> participating_replicas =
|
||||
xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
|
||||
replica_count,
|
||||
*run_options->device_assignment())
|
||||
.ValueOrDie();
|
||||
|
||||
xla::RendezvousKey::CollectiveOpKind op_kind =
|
||||
channel_id_present ? xla::RendezvousKey::kCrossModule
|
||||
: xla::RendezvousKey::kCrossReplica;
|
||||
std::vector<xla::GlobalDeviceId> participating_devices;
|
||||
participating_devices.reserve(participating_replicas.size());
|
||||
for (xla::int64 replica : participating_replicas) {
|
||||
participating_devices.push_back(
|
||||
xla::GlobalDeviceId(device_assignment(replica, 0)));
|
||||
}
|
||||
xla::RendezvousKey rendezvous_key(
|
||||
run_options->run_id(), std::move(participating_devices),
|
||||
participating_replicas.size(), op_kind, op_id);
|
||||
xla::RendezvousKey rendezvous_key =
|
||||
GetRendezvousKey(run_options, group, channel_id_present, op_id);
|
||||
auto shape_str = ShapeString(shape_ptr, shape_length);
|
||||
VLOG(2) << "All-reduce input/output shape : " << shape_str;
|
||||
|
||||
@ -487,7 +540,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
||||
|
||||
TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
|
||||
[&] {
|
||||
return GlobalRendezvousMap().GetOrCreateIfAbsent(
|
||||
return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
|
||||
rendezvous_key, make_cpu_rendezvous);
|
||||
},
|
||||
participant)
|
||||
@ -496,16 +549,56 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
|
||||
const xla::ExecutableRunOptions* run_options, void* output_buffer) {
|
||||
int device_ordinal = [&]() {
|
||||
if (run_options->stream()) {
|
||||
return run_options->stream()->parent()->device_ordinal();
|
||||
} else {
|
||||
return run_options->device_ordinal();
|
||||
}
|
||||
}();
|
||||
|
||||
int device_ordinal = GetDeviceOrdinal(run_options);
|
||||
xla::int32 replica_id = run_options->device_assignment()
|
||||
->ReplicaIdForDeviceOrdinal(device_ordinal)
|
||||
.ValueOrDie();
|
||||
std::memcpy(output_buffer, &replica_id, 4);
|
||||
}
|
||||
|
||||
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
|
||||
const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
|
||||
xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
|
||||
void* output_buffer, const void* source_target_pairs,
|
||||
xla::int32 source_target_pairs_size) {
|
||||
int device_ordinal = GetDeviceOrdinal(run_options);
|
||||
absl::string_view source_target_pairs_serialized(
|
||||
static_cast<const char*>(source_target_pairs), source_target_pairs_size);
|
||||
auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
|
||||
xla::int32 replica_id = run_options->device_assignment()
|
||||
->ReplicaIdForDeviceOrdinal(device_ordinal)
|
||||
.ValueOrDie();
|
||||
std::vector<int> copy_to;
|
||||
for (auto& p : pairs) {
|
||||
std::vector<std::string> mapping = absl::StrSplit(p, '=');
|
||||
CHECK_EQ(mapping.size(), 2);
|
||||
int from = std::stoi(mapping[0]);
|
||||
int to = std::stoi(mapping[1]);
|
||||
if (from == replica_id) {
|
||||
copy_to.push_back(to);
|
||||
}
|
||||
}
|
||||
xla::RendezvousKey rendezvous_key =
|
||||
GetRendezvousKey(run_options, {}, channel_id_present, op_id);
|
||||
|
||||
xla::CollectivePermuteParticipantData participant(rendezvous_key);
|
||||
participant.replica_id = replica_id;
|
||||
participant.device_ordinal = device_ordinal;
|
||||
participant.stream = run_options->stream();
|
||||
participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
|
||||
participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
|
||||
participant.replica_ids_to_copy_to = copy_to;
|
||||
participant.byte_size = byte_size;
|
||||
|
||||
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
|
||||
return absl::make_unique<CpuCollectivePermuteRendezvous>(k);
|
||||
};
|
||||
TF_CHECK_OK(
|
||||
CpuCollectivePermuteRendezvous::SubmitParticipant(
|
||||
[&] {
|
||||
return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
|
||||
rendezvous_key, make_cpu_rendezvous);
|
||||
},
|
||||
participant)
|
||||
.status());
|
||||
}
|
||||
|
||||
@ -69,6 +69,7 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
|
||||
extern const char* const kParallelForkJoinSymbolName;
|
||||
extern const char* const kKeyValueSortSymbolName;
|
||||
extern const char* const kAllReduceSymbolName;
|
||||
extern const char* const kCollectivePermuteSymbolName;
|
||||
extern const char* const kReplicaIdSymbolName;
|
||||
extern const char* const kTracingStartSymbolName;
|
||||
extern const char* const kTracingEndSymbolName;
|
||||
@ -170,6 +171,12 @@ extern void __xla_cpu_runtime_AllReduce(
|
||||
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
|
||||
void** input_buffers, void** output_buffers);
|
||||
|
||||
extern void __xla_cpu_runtime_CollectivePermute(
|
||||
const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
|
||||
xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
|
||||
void* output_buffer, const void* source_target_pairs,
|
||||
xla::int32 source_target_pairs_size);
|
||||
|
||||
// Write the replica ID into the output buffer.
|
||||
extern void __xla_cpu_runtime_ReplicaId(
|
||||
const xla::ExecutableRunOptions* run_options, void* output_buffer);
|
||||
|
||||
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/CodeGen/TargetRegisterInfo.h"
|
||||
#include "llvm/CodeGen/TargetSubtargetInfo.h"
|
||||
@ -1540,6 +1541,64 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
|
||||
return HandleAllReduceMultipleReplica(crs);
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
|
||||
auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
|
||||
std::string source_target_pairs = absl::StrJoin(
|
||||
instr->source_target_pairs(), ",", absl::PairFormatter("="));
|
||||
llvm::Value* source_target_pairs_v =
|
||||
b_.CreateGlobalStringPtr(source_target_pairs);
|
||||
|
||||
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
|
||||
llvm::Type* int32_type = b_.getInt32Ty();
|
||||
llvm::Type* int64_type = b_.getInt64Ty();
|
||||
llvm::FunctionType* collective_permute_func_ty =
|
||||
llvm::FunctionType::get(b_.getVoidTy(),
|
||||
{
|
||||
/*run_options=*/i8_ptr_type,
|
||||
/*channel_id_present=*/int32_type,
|
||||
/*op_id=*/int64_type,
|
||||
/*byte_size=*/int32_type,
|
||||
/*input_buffer=*/i8_ptr_type,
|
||||
/*output_buffer=*/i8_ptr_type,
|
||||
/*source_target_pairs=*/i8_ptr_type,
|
||||
/*source_target_pairs_size=*/int32_type,
|
||||
},
|
||||
/*isVarArg=*/false);
|
||||
|
||||
auto collective_permute_func = llvm::dyn_cast<llvm::Function>(
|
||||
module_
|
||||
->getOrInsertFunction(runtime::kCollectivePermuteSymbolName,
|
||||
collective_permute_func_ty)
|
||||
.getCallee());
|
||||
collective_permute_func->setCallingConv(llvm::CallingConv::C);
|
||||
|
||||
Shape shape = crs->operand(0)->shape();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
|
||||
assignment_.GetUniqueSlice(crs->operand(0), {}));
|
||||
llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
|
||||
assignment_.GetUniqueSlice(crs, {}));
|
||||
llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
|
||||
|
||||
Call(collective_permute_func,
|
||||
{/*run_options=*/GetExecutableRunOptionsArgument(),
|
||||
/*channel_id_present=*/
|
||||
b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
|
||||
/*op_id=*/
|
||||
b_.getInt64(crs->channel_id().has_value()
|
||||
? *crs->channel_id()
|
||||
: crs->GetModule()->unique_id()),
|
||||
/*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
|
||||
/*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
|
||||
/*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
|
||||
/*source_target_pairs=*/source_target_pairs_v,
|
||||
/*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())});
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
|
||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
|
||||
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
|
||||
|
||||
@ -155,6 +155,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
||||
Status HandleConvolution(HloInstruction* convolution) override;
|
||||
Status HandleFft(HloInstruction* fft) override;
|
||||
Status HandleAllReduce(HloInstruction* crs) override;
|
||||
Status HandleCollectivePermute(HloInstruction* crs) override;
|
||||
Status HandleInfeed(HloInstruction* infeed) override;
|
||||
Status HandleOutfeed(HloInstruction* outfeed) override;
|
||||
Status HandleSort(HloInstruction* sort) override;
|
||||
|
||||
@ -237,6 +237,7 @@ bool RegisterKnownJITSymbols() {
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
|
||||
|
||||
@ -560,7 +560,7 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
|
||||
}
|
||||
}
|
||||
|
||||
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
|
||||
XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
|
||||
const char* const kModuleStr = R"(
|
||||
HloModule test
|
||||
ENTRY test_computation {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user