[XLA:CPU] CollectivePermute support on CPU

Dummy implementation: the main thread performs all of the work.

PiperOrigin-RevId: 305184516
Change-Id: Ib4af0b7fda920fe08b551cb0782884ba92947ba7
George Karpenkov 2020-04-06 22:04:55 -07:00 committed by TensorFlower Gardener
parent 19a41461e2
commit 134dcd1302
7 changed files with 247 additions and 51 deletions

@@ -189,6 +189,30 @@ struct AllReduceParticipantData {
}
};
struct CollectivePermuteParticipantData {
explicit CollectivePermuteParticipantData(const RendezvousKey& rendezvous_key)
: rendezvous_key(rendezvous_key) {}
RendezvousKey rendezvous_key;
int64 device_ordinal;
int replica_id;
se::DeviceMemoryBase source_data;
se::DeviceMemoryBase destination_data;
int64 byte_size;
se::Stream* stream;
std::vector<int> replica_ids_to_copy_to;
string ToString() const {
return absl::StrFormat(
"CollectivePermuteParticipantData{replica_id=%d, "
"source_data=%p, destination_data=%p, byte_size=%d, "
"replica_ids_to_copy_to=[%s]}",
replica_id, source_data.opaque(), destination_data.opaque(), byte_size,
absl::StrJoin(replica_ids_to_copy_to, ", "));
}
};
// The set of threads that want to do a collective op together all pick the same
// Rendezvous object out of the global cache and call SubmitParticipant.
//
@@ -243,6 +267,17 @@ class Rendezvous {
virtual StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
const I& participant) = 0;
// Initializes the rendezvous when invoked by the first ("primary") thread to
// reach the barrier. Returns whether the calling thread is primary.
bool InitializationBarrier() {
tensorflow::mutex_lock lock(mu_);
if (!initialized_) {
initialized_ = true;
return true;
}
return false;
}
virtual void CleanupImpl(O handle, bool is_primary) {}
tensorflow::mutex mu_;
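
For reference, a minimal standalone sketch (plain C++ with standard-library types, not part of this change) of the primary-election pattern that InitializationBarrier factors out: of all threads reaching the barrier, exactly one observes initialized_ == false and becomes primary.

#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

class ToyBarrier {
 public:
  // Returns true for exactly one caller: the first to take the lock.
  bool InitializationBarrier() {
    std::lock_guard<std::mutex> lock(mu_);
    if (!initialized_) {
      initialized_ = true;
      return true;
    }
    return false;
  }

 private:
  std::mutex mu_;
  bool initialized_ = false;
};

int main() {
  ToyBarrier barrier;
  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    threads.emplace_back([&barrier, i] {
      if (barrier.InitializationBarrier()) {
        std::printf("thread %d is primary\n", i);  // printed exactly once
      }
    });
  }
  for (std::thread& t : threads) t.join();
}
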

@@ -114,6 +114,8 @@ extern const char* const kTracingStartSymbolName =
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kCollectivePermuteSymbolName =
"__xla_cpu_runtime_CollectivePermute";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
} // namespace runtime
@@ -254,6 +256,50 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
namespace {
class CpuCollectivePermuteRendezvous
: public xla::Rendezvous<xla::CollectivePermuteParticipantData,
std::nullptr_t> {
public:
explicit CpuCollectivePermuteRendezvous(const xla::RendezvousKey& k)
: xla::Rendezvous<xla::CollectivePermuteParticipantData, std::nullptr_t>(
k) {}
protected:
xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
const xla::CollectivePermuteParticipantData& participant) override {
bool primary = InitializationBarrier();
// Perform all copies from the primary thread.
if (primary) {
tensorflow::mutex_lock lock(mu_);
std::map<int, int> replica_idx_to_participant_idx;
for (int p_idx = 0; p_idx < participants_.size(); p_idx++) {
replica_idx_to_participant_idx[participants_[p_idx].replica_id] = p_idx;
}
for (auto& p : participants_) {
for (int dest_replica : p.replica_ids_to_copy_to) {
auto& dest_p = participants_[xla::FindOrDie(
replica_idx_to_participant_idx, dest_replica)];
std::memcpy(dest_p.destination_data.opaque(), p.source_data.opaque(),
p.byte_size);
// Each replica's destination may be written to at most once.
replica_idx_to_participant_idx.erase(dest_replica);
}
}
// Zero out untouched participants.
for (auto& replica_p : replica_idx_to_participant_idx) {
auto& p = participants_[replica_p.second];
std::memset(p.destination_data.opaque(), 0, p.byte_size);
}
}
return ParticipantImplOutput{primary, /*custom_output=*/nullptr};
}
};
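
To make the copy/zero logic above concrete, here is a single-threaded model (hypothetical toy types, not XLA code) of what the primary thread does for source_target_pairs {0=1, 1=2} across three replicas: replica 1's destination receives replica 0's source, replica 2's receives replica 1's, and replica 0, which nobody targets, is zeroed.

#include <cstdio>
#include <cstring>
#include <map>
#include <vector>

struct Participant {
  int replica_id;
  std::vector<int> replica_ids_to_copy_to;
  char source[4];
  char destination[4];
};

int main() {
  // Three replicas; pairs {0=1, 1=2} mean 0 feeds 1 and 1 feeds 2.
  std::vector<Participant> participants = {
      {0, {1}, "aaa", ""}, {1, {2}, "bbb", ""}, {2, {}, "ccc", ""}};

  // Map replica id -> participant index; entries are erased once written,
  // so whatever remains afterwards was never targeted and gets zeroed.
  std::map<int, int> replica_to_index;
  for (int i = 0; i < static_cast<int>(participants.size()); ++i) {
    replica_to_index[participants[i].replica_id] = i;
  }
  for (const Participant& p : participants) {
    for (int dest : p.replica_ids_to_copy_to) {
      std::memcpy(participants[replica_to_index[dest]].destination, p.source,
                  sizeof(p.source));
      replica_to_index.erase(dest);  // each destination written at most once
    }
  }
  for (const auto& entry : replica_to_index) {
    std::memset(participants[entry.second].destination, 0,
                sizeof(participants[entry.second].destination));
  }
  for (const Participant& p : participants) {
    std::printf("replica %d destination: \"%s\"\n", p.replica_id,
                p.destination);  // 0 -> "", 1 -> "aaa", 2 -> "bbb"
  }
}
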
class CpuAllReduceRendezvous
: public xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t> {
public:
@@ -264,14 +310,7 @@ class CpuAllReduceRendezvous
xla::StatusOr<ParticipantImplOutput> SubmitParticipantImpl(
const xla::AllReduceParticipantData& participant) override {
xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
- bool primary = [&] {
- tensorflow::mutex_lock lock(mu_);
- if (!initialized_) {
- initialized_ = true;
- return true;
- }
- return false;
- }();
+ bool primary = InitializationBarrier();
if (primary) {
switch (datatype) {
@@ -406,12 +445,55 @@ class CpuAllReduceRendezvous
};
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
- GlobalRendezvousMap() {
+ GlobalAllReduceRendezvousMap() {
static auto& m =
*new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
return m;
}
xla::RefcountingHashMap<xla::RendezvousKey, CpuCollectivePermuteRendezvous>&
GlobalCollectivePermuteRendezvousMap() {
static auto& m = *new xla::RefcountingHashMap<xla::RendezvousKey,
CpuCollectivePermuteRendezvous>;
return m;
}
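
Both globals follow the same leaky-singleton, get-or-create pattern. A simplified sketch (hypothetical template; the real xla::RefcountingHashMap additionally drops an entry once the last reference to it goes away) of the contract the callers below rely on: all threads that compute the same key receive the same shared object, and the first thread to arrive constructs it via the supplied factory.

#include <map>
#include <memory>
#include <mutex>

template <typename K, typename V>
class GetOrCreateMap {
 public:
  template <typename Factory>
  std::shared_ptr<V> GetOrCreateIfAbsent(const K& key, Factory&& factory) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(key);
    if (it == map_.end()) {
      // First arrival for this key constructs the shared object.
      it = map_.emplace(key, factory(key)).first;
    }
    return it->second;
  }

 private:
  std::mutex mu_;
  std::map<K, std::shared_ptr<V>> map_;
};
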
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
if (run_options->stream()) {
return run_options->stream()->parent()->device_ordinal();
} else {
return run_options->device_ordinal();
}
}
xla::RendezvousKey GetRendezvousKey(
const xla::ExecutableRunOptions* run_options,
std::vector<xla::ReplicaGroup> group, xla::int32 channel_id_present,
xla::int64 op_id) {
const xla::DeviceAssignment& device_assignment =
*run_options->device_assignment();
xla::int32 replica_count = device_assignment.replica_count();
int device_ordinal = GetDeviceOrdinal(run_options);
CHECK_EQ(device_assignment.computation_count(), 1);
std::vector<xla::int64> participating_replicas =
xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
replica_count,
*run_options->device_assignment())
.ValueOrDie();
xla::RendezvousKey::CollectiveOpKind op_kind =
channel_id_present ? xla::RendezvousKey::kCrossModule
: xla::RendezvousKey::kCrossReplica;
std::vector<xla::GlobalDeviceId> participating_devices;
participating_devices.reserve(participating_replicas.size());
for (xla::int64 replica : participating_replicas) {
participating_devices.push_back(
xla::GlobalDeviceId(device_assignment(replica, 0)));
}
return xla::RendezvousKey{
run_options->run_id(), std::move(participating_devices),
static_cast<int>(participating_replicas.size()), op_kind, op_id};
}
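
What makes this work is that every participating replica computes an identical key for the same dynamic instance of the op. A simplified stand-in (hypothetical field set, mirroring the arguments assembled above) showing the equality the global map keys on:

#include <tuple>
#include <vector>

struct MiniRendezvousKey {
  long long run_id;
  std::vector<int> participating_devices;
  int num_local_participants;
  int collective_op_kind;  // cross-module vs. cross-replica
  long long op_id;
};

inline bool operator==(const MiniRendezvousKey& a, const MiniRendezvousKey& b) {
  // Compare all fields; std::tie compares tuples element by element.
  return std::tie(a.run_id, a.participating_devices, a.num_local_participants,
                  a.collective_op_kind, a.op_id) ==
         std::tie(b.run_id, b.participating_devices, b.num_local_participants,
                  b.collective_op_kind, b.op_id);
}

If any participant computed a different key (say, from a different replica group), the map would hand it a fresh rendezvous and it would block forever waiting for peers that never arrive.
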
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
@@ -420,42 +502,13 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
xla::int32 channel_id_present, xla::int64 op_id, xla::int32 reduction_kind,
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
void** input_buffers, void** output_buffers) {
+ int device_ordinal = GetDeviceOrdinal(run_options);
absl::string_view replica_groups_serialized(
static_cast<const char*>(replica_groups_str), replica_groups_str_size);
- // FIXME(cheshire): avoid repetition w/__xla_cpu_runtime_ReplicaId.
- int device_ordinal = [&] {
- if (run_options->stream()) {
- return run_options->stream()->parent()->device_ordinal();
- } else {
- return run_options->device_ordinal();
- }
- }();
std::vector<xla::ReplicaGroup> group =
xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
- const xla::DeviceAssignment& device_assignment =
- *run_options->device_assignment();
- xla::int32 replica_count = device_assignment.replica_count();
- CHECK_EQ(device_assignment.computation_count(), 1);
- std::vector<xla::int64> participating_replicas =
- xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
- replica_count,
- *run_options->device_assignment())
- .ValueOrDie();
- xla::RendezvousKey::CollectiveOpKind op_kind =
- channel_id_present ? xla::RendezvousKey::kCrossModule
- : xla::RendezvousKey::kCrossReplica;
- std::vector<xla::GlobalDeviceId> participating_devices;
- participating_devices.reserve(participating_replicas.size());
- for (xla::int64 replica : participating_replicas) {
- participating_devices.push_back(
- xla::GlobalDeviceId(device_assignment(replica, 0)));
- }
- xla::RendezvousKey rendezvous_key(
- run_options->run_id(), std::move(participating_devices),
- participating_replicas.size(), op_kind, op_id);
+ xla::RendezvousKey rendezvous_key =
+ GetRendezvousKey(run_options, group, channel_id_present, op_id);
auto shape_str = ShapeString(shape_ptr, shape_length);
VLOG(2) << "All-reduce input/output shape : " << shape_str;
@@ -487,7 +540,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
[&] {
- return GlobalRendezvousMap().GetOrCreateIfAbsent(
+ return GlobalAllReduceRendezvousMap().GetOrCreateIfAbsent(
rendezvous_key, make_cpu_rendezvous);
},
participant)
@@ -496,16 +549,56 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
const xla::ExecutableRunOptions* run_options, void* output_buffer) {
- int device_ordinal = [&]() {
- if (run_options->stream()) {
- return run_options->stream()->parent()->device_ordinal();
- } else {
- return run_options->device_ordinal();
- }
- }();
+ int device_ordinal = GetDeviceOrdinal(run_options);
xla::int32 replica_id = run_options->device_assignment()
->ReplicaIdForDeviceOrdinal(device_ordinal)
.ValueOrDie();
std::memcpy(output_buffer, &replica_id, 4);
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_CollectivePermute(
const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
void* output_buffer, const void* source_target_pairs,
xla::int32 source_target_pairs_size) {
int device_ordinal = GetDeviceOrdinal(run_options);
absl::string_view source_target_pairs_serialized(
static_cast<const char*>(source_target_pairs), source_target_pairs_size);
auto pairs = absl::StrSplit(source_target_pairs_serialized, ',');
xla::int32 replica_id = run_options->device_assignment()
->ReplicaIdForDeviceOrdinal(device_ordinal)
.ValueOrDie();
std::vector<int> copy_to;
for (auto& p : pairs) {
std::vector<std::string> mapping = absl::StrSplit(p, '=');
CHECK_EQ(mapping.size(), 2);
int from = std::stoi(mapping[0]);
int to = std::stoi(mapping[1]);
if (from == replica_id) {
copy_to.push_back(to);
}
}
xla::RendezvousKey rendezvous_key =
GetRendezvousKey(run_options, {}, channel_id_present, op_id);
xla::CollectivePermuteParticipantData participant(rendezvous_key);
participant.replica_id = replica_id;
participant.device_ordinal = device_ordinal;
participant.stream = run_options->stream();
participant.source_data = se::DeviceMemoryBase(input_buffer, byte_size);
participant.destination_data = se::DeviceMemoryBase(output_buffer, byte_size);
participant.replica_ids_to_copy_to = copy_to;
participant.byte_size = byte_size;
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
return absl::make_unique<CpuCollectivePermuteRendezvous>(k);
};
TF_CHECK_OK(
CpuCollectivePermuteRendezvous::SubmitParticipant(
[&] {
return GlobalCollectivePermuteRendezvousMap().GetOrCreateIfAbsent(
rendezvous_key, make_cpu_rendezvous);
},
participant)
.status());
}
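
The source_target_pairs wire format consumed above is just "src=dst" entries joined by commas. A standalone sketch of the consumer side (CopyTargetsFor is a hypothetical helper name, extracted here for illustration):

#include <cstdio>
#include <string>
#include <vector>

#include "absl/strings/str_split.h"

// Returns the target replicas for `replica_id` given a serialized pair list
// such as "0=1,1=2,2=0".
std::vector<int> CopyTargetsFor(int replica_id, const std::string& serialized) {
  std::vector<int> copy_to;
  for (absl::string_view pair : absl::StrSplit(serialized, ',')) {
    std::vector<std::string> mapping = absl::StrSplit(pair, '=');
    if (mapping.size() == 2 && std::stoi(mapping[0]) == replica_id) {
      copy_to.push_back(std::stoi(mapping[1]));
    }
  }
  return copy_to;
}

int main() {
  for (int target : CopyTargetsFor(1, "0=1,1=2,2=0")) {
    std::printf("replica 1 copies to replica %d\n", target);  // prints 2
  }
}
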

@@ -69,6 +69,7 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName;
extern const char* const kParallelForkJoinSymbolName;
extern const char* const kKeyValueSortSymbolName;
extern const char* const kAllReduceSymbolName;
extern const char* const kCollectivePermuteSymbolName;
extern const char* const kReplicaIdSymbolName;
extern const char* const kTracingStartSymbolName;
extern const char* const kTracingEndSymbolName;
@@ -170,6 +171,12 @@ extern void __xla_cpu_runtime_AllReduce(
const void* shape_ptr, xla::int32 shape_length, xla::int32 num_buffers,
void** input_buffers, void** output_buffers);
extern void __xla_cpu_runtime_CollectivePermute(
const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
xla::int64 op_id, xla::int32 byte_size, void* input_buffer,
void* output_buffer, const void* source_target_pairs,
xla::int32 source_target_pairs_size);
// Write the replica ID into the output buffer.
extern void __xla_cpu_runtime_ReplicaId(
const xla::ExecutableRunOptions* run_options, void* output_buffer);

@@ -29,6 +29,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
@@ -1540,6 +1541,64 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
return HandleAllReduceMultipleReplica(crs);
}
Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
std::string source_target_pairs = absl::StrJoin(
instr->source_target_pairs(), ",", absl::PairFormatter("="));
llvm::Value* source_target_pairs_v =
b_.CreateGlobalStringPtr(source_target_pairs);
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::Type* int32_type = b_.getInt32Ty();
llvm::Type* int64_type = b_.getInt64Ty();
llvm::FunctionType* collective_permute_func_ty =
llvm::FunctionType::get(b_.getVoidTy(),
{
/*run_options=*/i8_ptr_type,
/*channel_id_present=*/int32_type,
/*op_id=*/int64_type,
/*byte_size=*/int32_type,
/*input_buffer=*/i8_ptr_type,
/*output_buffer=*/i8_ptr_type,
/*source_target_pairs=*/i8_ptr_type,
/*source_target_pairs_size=*/int32_type,
},
/*isVarArg=*/false);
auto collective_permute_func = llvm::dyn_cast<llvm::Function>(
module_
->getOrInsertFunction(runtime::kCollectivePermuteSymbolName,
collective_permute_func_ty)
.getCallee());
collective_permute_func->setCallingConv(llvm::CallingConv::C);
Shape shape = crs->operand(0)->shape();
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_slice,
assignment_.GetUniqueSlice(crs->operand(0), {}));
llvm::Value* input_buffer = EmitBufferPointer(input_slice, shape);
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
assignment_.GetUniqueSlice(crs, {}));
llvm::Value* output_buffer = EmitBufferPointer(output_slice, shape);
Call(collective_permute_func,
{/*run_options=*/GetExecutableRunOptionsArgument(),
/*channel_id_present=*/
b_.getInt32(static_cast<int32>(crs->channel_id().has_value())),
/*op_id=*/
b_.getInt64(crs->channel_id().has_value()
? *crs->channel_id()
: crs->GetModule()->unique_id()),
/*byte_size=*/b_.getInt32(ShapeUtil::ByteSizeOf(shape)),
/*input_buffer=*/b_.CreateBitCast(input_buffer, i8_ptr_type),
/*output_buffer=*/b_.CreateBitCast(output_buffer, i8_ptr_type),
/*source_target_pairs=*/source_target_pairs_v,
/*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())});
return Status::OK();
}
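
On the producer side, absl::StrJoin with absl::PairFormatter is what flattens the instruction's source_target_pairs into the comma-separated "src=dst" string the runtime parses back. A minimal standalone round-trip sketch of just the serialization step:

#include <cstdint>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"

int main() {
  std::vector<std::pair<int64_t, int64_t>> source_target_pairs = {{0, 1},
                                                                  {1, 2}};
  // PairFormatter("=") renders each pair as "first=second".
  std::string serialized =
      absl::StrJoin(source_target_pairs, ",", absl::PairFormatter("="));
  std::printf("%s\n", serialized.c_str());  // prints: 0=1,1=2
}
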
Status IrEmitter::HandleReplicaId(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());

@@ -155,6 +155,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleCollectivePermute(HloInstruction* crs) override;
Status HandleInfeed(HloInstruction* infeed) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleSort(HloInstruction* sort) override;

@@ -237,6 +237,7 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);

@@ -560,7 +560,7 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) {
}
}
- XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_Simple)) {
+ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {