[XLA:CPU] AllToAll support for XLA:CPU

A single master thread performs all the work

PiperOrigin-RevId: 320074537
Change-Id: Iaa4e4a78b0f058ffdb11334a12e8b78126399e89
This commit is contained in:
George Karpenkov 2020-07-07 15:34:41 -07:00 committed by TensorFlower Gardener
parent a1b927ce1d
commit 052263c130
6 changed files with 335 additions and 19 deletions

View File

@ -122,6 +122,7 @@ extern const char* const kTracingStartSymbolName =
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
"__xla_cpu_runtime_CollectivePermute";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
@ -154,6 +155,34 @@ struct CollectivePermuteParticipantData : xla::ParticipantData {
}
};
struct AllToAllParticipantData : xla::ParticipantData {
AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
xla::int64 device_ordinal_p, se::Stream* stream_p)
: ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}
std::vector<se::DeviceMemoryBase> source_buffers;
std::vector<se::DeviceMemoryBase> destination_buffers;
int replica_id;
// Replica ids participating in AllToAll, concatenation happens in the order
// of appearence.
std::vector<xla::int64> replica_ids_to_copy_to;
std::string ToString() const override {
auto addr_formatter = [](std::string* out,
const se::DeviceMemoryBase& mem) {
absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
};
return absl::StrFormat(
"AllToAllParticipantData{replica_id=%d, "
"replica_ids_to_copy_to=[%s], source_buffers=[%s], "
"destination_buffers=[%s]}",
replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
absl::StrJoin(source_buffers, ", ", addr_formatter),
absl::StrJoin(destination_buffers, ", ", addr_formatter));
}
};
// Inverses the encoding of a Shape protobuf into an LLVM global variable.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
const void* shape_ptr, xla::int32 size_bytes) {
@ -286,6 +315,70 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
namespace {
class CpuAllToAllRendezvous
: public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
public:
explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
: xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}
protected:
xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
const AllToAllParticipantData& /*participant*/) override {
bool is_primary = InitializationBarrier();
if (is_primary) {
tensorflow::mutex_lock lock(mu_);
CHECK(!participants_.empty());
CHECK(!participants_[0].source_buffers.empty());
int expected_buffer_size = participants_[0].source_buffers[0].size();
// Replica id -> position in participants_.
absl::flat_hash_map<int, int> replica_id_map;
for (int pos = 0; pos < participants_.size(); pos++) {
const AllToAllParticipantData& p = participants_[pos];
CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
CHECK_EQ(p.source_buffers.size(), participants_.size());
for (int i = 0; i < p.source_buffers.size(); i++) {
CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
}
replica_id_map[p.replica_id] = pos;
}
for (AllToAllParticipantData& p : participants_) {
VLOG(3) << "Processing AllToAll participant data: " << p.ToString();
for (int j = 0; j < p.source_buffers.size(); j++) {
for (int i = 0; i < p.replica_ids_to_copy_to.size(); i++) {
int replica_id = p.replica_ids_to_copy_to[i];
int participant_num = xla::FindOrDie(replica_id_map, replica_id);
AllToAllParticipantData& other = participants_[participant_num];
// Sort by replica ordering.
std::vector<se::DeviceMemoryBase> destination_buffers =
other.destination_buffers;
absl::flat_hash_map<const void*, int> buffers_index;
for (int idx = 0; idx < destination_buffers.size(); idx++) {
buffers_index[destination_buffers[idx].opaque()] = idx;
}
absl::c_sort(
destination_buffers, [&](const se::DeviceMemoryBase& a,
const se::DeviceMemoryBase& b) {
return p.replica_ids_to_copy_to[buffers_index[a.opaque()]] <
p.replica_ids_to_copy_to[buffers_index[b.opaque()]];
});
std::memcpy(destination_buffers[j].opaque(),
p.source_buffers[j].opaque(), expected_buffer_size);
}
}
}
}
return ParticipantImplOutput{is_primary, nullptr};
}
};
class CpuCollectivePermuteRendezvous
: public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
public:
@ -486,6 +579,13 @@ GlobalCollectivePermuteRendezvousMap() {
return m;
}
xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
static auto& m =
*new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
return m;
}
int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
if (run_options->stream()) {
return run_options->stream()->parent()->device_ordinal();
@ -524,6 +624,48 @@ xla::RendezvousKey GetRendezvousKey(
} // namespace
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
xla::int64 op_id, const void* replica_groups_str,
xla::int32 replica_groups_str_size, xla::int32 num_buffers,
xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
int device_ordinal = GetDeviceOrdinal(run_options);
xla::int32 replica_id = run_options->device_assignment()
->ReplicaIdForDeviceOrdinal(device_ordinal)
.ValueOrDie();
absl::string_view replica_groups_serialized(
static_cast<const char*>(replica_groups_str), replica_groups_str_size);
std::vector<xla::ReplicaGroup> group =
xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
xla::RendezvousKey rendezvous_key =
GetRendezvousKey(run_options, group, channel_id_present, op_id);
AllToAllParticipantData participant(rendezvous_key, device_ordinal,
run_options->stream());
participant.replica_id = replica_id;
participant.replica_ids_to_copy_to =
xla::GetParticipatingReplicas(
xla::GlobalDeviceId(device_ordinal), group,
run_options->device_assignment()->replica_count(),
*run_options->device_assignment())
.ValueOrDie();
for (int i = 0; i < num_buffers; i++) {
participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
participant.destination_buffers.emplace_back(destination_buffers[i],
buffer_size);
}
auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
return absl::make_unique<CpuAllToAllRendezvous>(k);
};
TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
[&] {
return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
rendezvous_key, make_cpu_rendezvous);
},
participant)
.status());
}
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
const xla::ExecutableRunOptions* run_options,
const void* replica_groups_str, xla::int32 replica_groups_str_size,

View File

@ -77,6 +77,7 @@ extern const char* const kCollectivePermuteSymbolName;
extern const char* const kReplicaIdSymbolName;
extern const char* const kTracingStartSymbolName;
extern const char* const kTracingEndSymbolName;
extern const char* const kAllToAllSymbolName;
// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
@ -181,6 +182,12 @@ extern void __xla_cpu_runtime_CollectivePermute(
void* output_buffer, const void* source_target_pairs,
xla::int32 source_target_pairs_size);
extern void __xla_cpu_runtime_AllToAll(
const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
xla::int64 op_id, const void* replica_groups_str,
xla::int32 replica_groups_str_size, xla::int32 num_buffers,
xla::int64 buffer_size, void** source_buffers, void** destination_buffers);
// Write the replica ID into the output buffer.
extern void __xla_cpu_runtime_ReplicaId(
const xla::ExecutableRunOptions* run_options, void* output_buffer);

View File

@ -359,7 +359,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
// to the output buffer of its corresponding operand. A GetTupleElement
// instruction forwards a pointer to the tuple element buffer at the given
// index.
auto operand = get_tuple_element->operand(0);
const HloInstruction* operand = get_tuple_element->operand(0);
const Shape& shape = get_tuple_element->shape();
emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
@ -1432,6 +1432,83 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
return HandleAllReduceMultipleReplica(crs);
}
Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
auto* instr = Cast<HloAllToAllInstruction>(instruction);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
CHECK(!instr->split_dimension() && instr->shape().IsTuple())
<< "Only tuple AllToAll is supported";
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::Type* int32_type = b_.getInt32Ty();
llvm::Type* int64_type = b_.getInt64Ty();
// TODO(cheshire): 3 statements below should be a single line.
llvm::FunctionType* all_to_all_func_ty =
llvm::FunctionType::get(b_.getVoidTy(),
{/*run_options=*/i8_ptr_type,
/*channel_id_present=*/int32_type,
/*op_id=*/int64_type,
/*replica_groups=*/i8_ptr_type,
/*replica_groups_size=*/int32_type,
/*num_buffers=*/int32_type,
/*buffer_size=*/int64_type,
/*input_buffer=*/i8_ptr_type,
/*output_buffer=*/i8_ptr_type},
/*isVarArg=*/false);
auto all_to_all_func = llvm::dyn_cast<llvm::Function>(
module_
->getOrInsertFunction(runtime::kAllToAllSymbolName,
all_to_all_func_ty)
.getCallee());
all_to_all_func->setCallingConv(llvm::CallingConv::C);
std::string replica_groups =
ReplicaGroupsToString(instruction->replica_groups());
int32 replica_groups_size = replica_groups.size();
llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
int64 buffer_size = -1;
std::vector<llvm::Value*> input_buffer_ptrs;
std::vector<llvm::Value*> output_buffer_ptrs;
for (int64 i = 0; i < instruction->operand_count(); i++) {
const HloInstruction* op = instruction->operand(i);
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
assignment_.GetUniqueSlice(instruction, {i}));
const Shape& operand_shape = instruction->operand(i)->shape();
CHECK(operand_shape.IsArray())
<< "Operands to all-to-all must be arrays: " << instruction->ToString();
output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
input_buffer_ptrs.push_back(GetEmittedValueFor(op));
CHECK(buffer_size == -1 || buffer_size == out_slice.size());
buffer_size = out_slice.size();
}
llvm::Value* input_buffers =
EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
llvm::Value* output_buffers =
EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
b_.CreateCall(
all_to_all_func,
{/*run_options=*/GetExecutableRunOptionsArgument(),
/*channel_id_present=*/
b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
/*op_id=*/
b_.getInt64(instruction->channel_id().has_value()
? *instruction->channel_id()
: instruction->GetModule()->unique_id()),
/*replica_groups=*/replica_groups_v,
/*replica_groups_size=*/b_.getInt32(replica_groups_size),
/*num_buffers=*/b_.getInt32(instruction->operand_count()),
/*buffer_size=*/b_.getInt64(buffer_size),
/*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
/*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});
llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
return Status::OK();
}
Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
std::string source_target_pairs = absl::StrJoin(
@ -2017,10 +2094,6 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
return DefaultAction(reduce);
}
Status IrEmitter::HandleAllToAll(HloInstruction*) {
return Unimplemented("AllToAll is not implemented on CPU.");
}
Status IrEmitter::HandleSend(HloInstruction* send) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Send is not implemented on CPU.");
@ -2749,10 +2822,10 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
element_alignment);
target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
} else {
auto* memcpy_instruction =
MemCpy(target, /*DstAlign=*/llvm::Align(element_alignment), source,
/*SrcAlign=*/llvm::Align(element_alignment),
element_count * primitive_type_size);
auto* memcpy_instruction = b_.CreateMemCpy(
target, /*DstAlign=*/llvm::Align(element_alignment), source,
/*SrcAlign=*/llvm::Align(element_alignment),
element_count * primitive_type_size);
// The memcpy does the load and the store internally. The aliasing related
// metadata has to reflect that.

View File

@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"

View File

@ -241,6 +241,7 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
REGISTER_CPU_RUNTIME_SYMBOL(AllToAll);
REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);

View File

@ -108,7 +108,7 @@ class CollectiveOpsTest : public HloTestBase {
}
template <typename LiteralType>
void TestAllOps() {
void TestAllOpsForReduce() {
auto cast = [&](int value) { return static_cast<LiteralType>(value); };
auto to_literal = [&](absl::Span<const LiteralType> values) {
return LiteralUtil::CreateR1<LiteralType>(values);
@ -183,39 +183,39 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
TestAllOps<int8>();
TestAllOpsForReduce<int8>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
TestAllOps<uint8>();
TestAllOpsForReduce<uint8>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
TestAllOps<uint32>();
TestAllOpsForReduce<uint32>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
TestAllOps<int32>();
TestAllOpsForReduce<int32>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
TestAllOps<int64>();
TestAllOpsForReduce<int64>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
TestAllOps<uint64>();
TestAllOpsForReduce<uint64>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
TestAllOps<float>();
TestAllOpsForReduce<float>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
TestAllOps<double>();
TestAllOpsForReduce<double>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
TestAllOps<Eigen::half>();
TestAllOpsForReduce<Eigen::half>();
}
XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
@ -593,6 +593,98 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
results[3]));
}
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_EmptyReplicaGroups)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a = f32[2] constant({10, 10})
b = f32[2] constant({20, 20})
c = f32[2] constant({30, 30})
d = f32[2] constant({40, 40})
all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={}
a_prime = f32[2] get-tuple-element(all2all), index=0
b_prime = f32[2] get-tuple-element(all2all), index=1
c_prime = f32[2] get-tuple-element(all2all), index=2
d_prime = f32[2] get-tuple-element(all2all), index=3
ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
}
)";
const int64 kNumReplicas = 4;
auto config = GetModuleConfigForTest(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (int i = 0; i < kNumReplicas; i++) {
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
LiteralUtil::CreateR1<float>({10, 10, 20, 20, 30, 30, 40, 40}),
results[i], ErrorSpec{1e-5, 1e-5}));
}
}
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_OrderedReplicaGroups)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a = f32[2] constant({10, 10})
b = f32[2] constant({20, 20})
c = f32[2] constant({30, 30})
d = f32[2] constant({40, 40})
all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={{3,2,1,0}}
a_prime = f32[2] get-tuple-element(all2all), index=0
b_prime = f32[2] get-tuple-element(all2all), index=1
c_prime = f32[2] get-tuple-element(all2all), index=2
d_prime = f32[2] get-tuple-element(all2all), index=3
ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
}
)";
const int64 kNumReplicas = 4;
auto config = GetModuleConfigForTest(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (int i = 0; i < kNumReplicas; i++) {
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
LiteralUtil::CreateR1<float>({40, 40, 30, 30, 20, 20, 10, 10}),
results[i], ErrorSpec{1e-5, 1e-5}));
}
}
XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_TwoReplicaGroups)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a = f32[2] constant({10, 10})
b = f32[2] constant({20, 20})
all2all = (f32[2], f32[2]) all-to-all(a, b), replica_groups={{2,1},{3,0}}
a_prime = f32[2] get-tuple-element(all2all), index=0
b_prime = f32[2] get-tuple-element(all2all), index=1
ROOT out = f32[4] concatenate(a_prime, b_prime), dimensions={0}
}
)";
const int64 kNumReplicas = 4;
auto config = GetModuleConfigForTest(kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
ExecuteReplicated(std::move(module), {}, kNumReplicas,
/*use_threads=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (int i = 0; i < kNumReplicas; i++) {
EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
LiteralUtil::CreateR1<float>({20, 20, 10, 10}), results[i],
ErrorSpec{1e-5, 1e-5}));
}
}
XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
std::string hlo_string = R"(
HloModule test