[XLA:CPU] AllToAll support for XLA:CPU
A single master thread performs all the work.

PiperOrigin-RevId: 320074537
Change-Id: Iaa4e4a78b0f058ffdb11334a12e8b78126399e89

Parent: a1b927ce1d
Commit: 052263c130
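This change implements all-to-all on the CPU backend with a rendezvous in which the primary participant performs every copy itself. The sketch below is a minimal, self-contained model of that single-threaded exchange, assuming the textbook all-to-all layout (receiver r's slot s holds sender s's r-th buffer); the Replica struct and AllToAllOnOneThread are illustrative names only, not the XLA types used in the diff.

#include <cstdio>
#include <cstring>
#include <vector>

// Each replica contributes one send buffer per peer and owns one receive
// buffer per peer.
struct Replica {
  std::vector<std::vector<char>> send;  // send[i] is destined for replica i
  std::vector<std::vector<char>> recv;  // recv[i] arrives from replica i
};

// One "master" thread walks every (sender, receiver) pair and performs all
// of the copies itself; no per-replica worker threads are involved.
void AllToAllOnOneThread(std::vector<Replica>& replicas) {
  const int n = static_cast<int>(replicas.size());
  for (int sender = 0; sender < n; ++sender) {
    for (int receiver = 0; receiver < n; ++receiver) {
      const std::vector<char>& src = replicas[sender].send[receiver];
      std::vector<char>& dst = replicas[receiver].recv[sender];
      dst.resize(src.size());
      std::memcpy(dst.data(), src.data(), src.size());
    }
  }
}

int main() {
  // Two replicas, one byte per buffer: replica 0 sends {'a','b'},
  // replica 1 sends {'c','d'}.
  std::vector<Replica> replicas(2);
  replicas[0].send = {{'a'}, {'b'}};
  replicas[1].send = {{'c'}, {'d'}};
  replicas[0].recv.resize(2);
  replicas[1].recv.resize(2);
  AllToAllOnOneThread(replicas);
  // Replica 0 receives {'a','c'}, replica 1 receives {'b','d'}.
  std::printf("%c%c %c%c\n", replicas[0].recv[0][0], replicas[0].recv[1][0],
              replicas[1].recv[0][0], replicas[1].recv[1][0]);
  return 0;
}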
@@ -122,6 +122,7 @@ extern const char* const kTracingStartSymbolName =
extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
extern const char* const kCollectivePermuteSymbolName =
    "__xla_cpu_runtime_CollectivePermute";
extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
@@ -154,6 +155,34 @@ struct CollectivePermuteParticipantData : xla::ParticipantData {
  }
};

struct AllToAllParticipantData : xla::ParticipantData {
  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
                          xla::int64 device_ordinal_p, se::Stream* stream_p)
      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}

  std::vector<se::DeviceMemoryBase> source_buffers;
  std::vector<se::DeviceMemoryBase> destination_buffers;
  int replica_id;

  // Replica ids participating in AllToAll, concatenation happens in the order
  // of appearance.
  std::vector<xla::int64> replica_ids_to_copy_to;

  std::string ToString() const override {
    auto addr_formatter = [](std::string* out,
                             const se::DeviceMemoryBase& mem) {
      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
    };
    return absl::StrFormat(
        "AllToAllParticipantData{replica_id=%d, "
        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
        "destination_buffers=[%s]}",
        replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
        absl::StrJoin(source_buffers, ", ", addr_formatter),
        absl::StrJoin(destination_buffers, ", ", addr_formatter));
  }
};

// Inverses the encoding of a Shape protobuf into an LLVM global variable.
xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
    const void* shape_ptr, xla::int32 size_bytes) {
@@ -286,6 +315,70 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(

namespace {

class CpuAllToAllRendezvous
    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
 public:
  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
      const AllToAllParticipantData& /*participant*/) override {
    bool is_primary = InitializationBarrier();

    if (is_primary) {
      tensorflow::mutex_lock lock(mu_);

      CHECK(!participants_.empty());
      CHECK(!participants_[0].source_buffers.empty());
      int expected_buffer_size = participants_[0].source_buffers[0].size();

      // Replica id -> position in participants_.
      absl::flat_hash_map<int, int> replica_id_map;

      for (int pos = 0; pos < participants_.size(); pos++) {
        const AllToAllParticipantData& p = participants_[pos];
        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
        CHECK_EQ(p.source_buffers.size(), participants_.size());
        for (int i = 0; i < p.source_buffers.size(); i++) {
          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
        }
        replica_id_map[p.replica_id] = pos;
      }

      for (AllToAllParticipantData& p : participants_) {
        VLOG(3) << "Processing AllToAll participant data: " << p.ToString();
        for (int j = 0; j < p.source_buffers.size(); j++) {
          for (int i = 0; i < p.replica_ids_to_copy_to.size(); i++) {
            int replica_id = p.replica_ids_to_copy_to[i];
            int participant_num = xla::FindOrDie(replica_id_map, replica_id);
            AllToAllParticipantData& other = participants_[participant_num];

            // Sort by replica ordering.
            std::vector<se::DeviceMemoryBase> destination_buffers =
                other.destination_buffers;
            absl::flat_hash_map<const void*, int> buffers_index;
            for (int idx = 0; idx < destination_buffers.size(); idx++) {
              buffers_index[destination_buffers[idx].opaque()] = idx;
            }
            absl::c_sort(
                destination_buffers, [&](const se::DeviceMemoryBase& a,
                                         const se::DeviceMemoryBase& b) {
                  return p.replica_ids_to_copy_to[buffers_index[a.opaque()]] <
                         p.replica_ids_to_copy_to[buffers_index[b.opaque()]];
                });

            std::memcpy(destination_buffers[j].opaque(),
                        p.source_buffers[j].opaque(), expected_buffer_size);
          }
        }
      }
    }
    return ParticipantImplOutput{is_primary, nullptr};
  }
};

class CpuCollectivePermuteRendezvous
    : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
 public:
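The "Sort by replica ordering" step in CpuAllToAllRendezvous::RunCollectiveOp above keys each destination buffer by the replica id at the same position in replica_ids_to_copy_to and reorders the buffers by ascending id. The following self-contained illustration of that reordering uses plain chars in place of se::DeviceMemoryBase; the group and buffer values are made up.

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Group {3,2,1,0}: destination buffer k is keyed by the k-th replica id
  // listed for the group, much like buffers_index in the rendezvous above.
  std::vector<int> replica_ids_to_copy_to = {3, 2, 1, 0};
  std::vector<char> destination_buffers = {'a', 'b', 'c', 'd'};

  std::vector<int> order(destination_buffers.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(), [&](int x, int y) {
    return replica_ids_to_copy_to[x] < replica_ids_to_copy_to[y];
  });

  // Prints "d c b a": position i now holds the buffer whose replica id is
  // the i-th smallest in the group.
  for (int i : order) std::printf("%c ", destination_buffers[i]);
  std::printf("\n");
  return 0;
}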
@@ -486,6 +579,13 @@ GlobalCollectivePermuteRendezvousMap() {
  return m;
}

xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
GlobalAllToAllRendezvousMap() {
  static auto& m =
      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
  return m;
}

int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
  if (run_options->stream()) {
    return run_options->stream()->parent()->device_ordinal();
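GlobalAllToAllRendezvousMap above uses the same never-destroyed function-local static as the other rendezvous maps in this file: the map is heap-allocated on first use and intentionally leaked so lookups remain valid during process teardown. A generic sketch of that pattern, with std::map and a counter as stand-ins rather than XLA types:

#include <cstdio>
#include <map>
#include <string>

// A function-local static bound to a deliberately leaked heap object: the
// map is constructed on first use (thread-safe since C++11) and never
// destroyed, so code running late in shutdown can still touch it.
std::map<std::string, int>& GlobalCounters() {
  static auto& m = *new std::map<std::string, int>;
  return m;
}

int main() {
  GlobalCounters()["all_to_all"]++;
  std::printf("%d\n", GlobalCounters()["all_to_all"]);  // prints 1
  return 0;
}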
@@ -524,6 +624,48 @@ xla::RendezvousKey GetRendezvousKey(

}  // namespace

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, const void* replica_groups_str,
    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
    xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
  int device_ordinal = GetDeviceOrdinal(run_options);
  xla::int32 replica_id = run_options->device_assignment()
                              ->ReplicaIdForDeviceOrdinal(device_ordinal)
                              .ValueOrDie();
  absl::string_view replica_groups_serialized(
      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
  std::vector<xla::ReplicaGroup> group =
      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
  xla::RendezvousKey rendezvous_key =
      GetRendezvousKey(run_options, group, channel_id_present, op_id);

  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
                                      run_options->stream());
  participant.replica_id = replica_id;
  participant.replica_ids_to_copy_to =
      xla::GetParticipatingReplicas(
          xla::GlobalDeviceId(device_ordinal), group,
          run_options->device_assignment()->replica_count(),
          *run_options->device_assignment())
          .ValueOrDie();
  for (int i = 0; i < num_buffers; i++) {
    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
    participant.destination_buffers.emplace_back(destination_buffers[i],
                                                 buffer_size);
  }
  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
    return absl::make_unique<CpuAllToAllRendezvous>(k);
  };
  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
                  [&] {
                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
                        rendezvous_key, make_cpu_rendezvous);
                  },
                  participant)
                  .status());
}

TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
    const xla::ExecutableRunOptions* run_options,
    const void* replica_groups_str, xla::int32 replica_groups_str_size,
@@ -77,6 +77,7 @@ extern const char* const kCollectivePermuteSymbolName;
extern const char* const kReplicaIdSymbolName;
extern const char* const kTracingStartSymbolName;
extern const char* const kTracingEndSymbolName;
extern const char* const kAllToAllSymbolName;

// All symbol names for XLA CPU runtime functions need to start with this
// prefix.
@@ -181,6 +182,12 @@ extern void __xla_cpu_runtime_CollectivePermute(
    void* output_buffer, const void* source_target_pairs,
    xla::int32 source_target_pairs_size);

extern void __xla_cpu_runtime_AllToAll(
    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
    xla::int64 op_id, const void* replica_groups_str,
    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
    xla::int64 buffer_size, void** source_buffers, void** destination_buffers);

// Write the replica ID into the output buffer.
extern void __xla_cpu_runtime_ReplicaId(
    const xla::ExecutableRunOptions* run_options, void* output_buffer);
@@ -359,7 +359,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  // to the output buffer of its corresponding operand. A GetTupleElement
  // instruction forwards a pointer to the tuple element buffer at the given
  // index.
-  auto operand = get_tuple_element->operand(0);
+  const HloInstruction* operand = get_tuple_element->operand(0);
  const Shape& shape = get_tuple_element->shape();
  emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
      shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
@@ -1432,6 +1432,83 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
  return HandleAllReduceMultipleReplica(crs);
}

Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
  auto* instr = Cast<HloAllToAllInstruction>(instruction);
  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
  CHECK(!instr->split_dimension() && instr->shape().IsTuple())
      << "Only tuple AllToAll is supported";

  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
  llvm::Type* int32_type = b_.getInt32Ty();
  llvm::Type* int64_type = b_.getInt64Ty();

  // TODO(cheshire): 3 statements below should be a single line.
  llvm::FunctionType* all_to_all_func_ty =
      llvm::FunctionType::get(b_.getVoidTy(),
                              {/*run_options=*/i8_ptr_type,
                               /*channel_id_present=*/int32_type,
                               /*op_id=*/int64_type,
                               /*replica_groups=*/i8_ptr_type,
                               /*replica_groups_size=*/int32_type,
                               /*num_buffers=*/int32_type,
                               /*buffer_size=*/int64_type,
                               /*input_buffer=*/i8_ptr_type,
                               /*output_buffer=*/i8_ptr_type},
                              /*isVarArg=*/false);
  auto all_to_all_func = llvm::dyn_cast<llvm::Function>(
      module_
          ->getOrInsertFunction(runtime::kAllToAllSymbolName,
                                all_to_all_func_ty)
          .getCallee());
  all_to_all_func->setCallingConv(llvm::CallingConv::C);

  std::string replica_groups =
      ReplicaGroupsToString(instruction->replica_groups());
  int32 replica_groups_size = replica_groups.size();
  llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);

  int64 buffer_size = -1;
  std::vector<llvm::Value*> input_buffer_ptrs;
  std::vector<llvm::Value*> output_buffer_ptrs;

  for (int64 i = 0; i < instruction->operand_count(); i++) {
    const HloInstruction* op = instruction->operand(i);
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
                        assignment_.GetUniqueSlice(instruction, {i}));
    const Shape& operand_shape = instruction->operand(i)->shape();
    CHECK(operand_shape.IsArray())
        << "Operands to all-to-all must be arrays: " << instruction->ToString();
    output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
    input_buffer_ptrs.push_back(GetEmittedValueFor(op));
    CHECK(buffer_size == -1 || buffer_size == out_slice.size());
    buffer_size = out_slice.size();
  }

  llvm::Value* input_buffers =
      EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
  llvm::Value* output_buffers =
      EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);

  b_.CreateCall(
      all_to_all_func,
      {/*run_options=*/GetExecutableRunOptionsArgument(),
       /*channel_id_present=*/
       b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
       /*op_id=*/
       b_.getInt64(instruction->channel_id().has_value()
                       ? *instruction->channel_id()
                       : instruction->GetModule()->unique_id()),
       /*replica_groups=*/replica_groups_v,
       /*replica_groups_size=*/b_.getInt32(replica_groups_size),
       /*num_buffers=*/b_.getInt32(instruction->operand_count()),
       /*buffer_size=*/b_.getInt64(buffer_size),
       /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
       /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});

  llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
  return Status::OK();
}

Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
  auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
  std::string source_target_pairs = absl::StrJoin(
@@ -2017,10 +2094,6 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
  return DefaultAction(reduce);
}

-Status IrEmitter::HandleAllToAll(HloInstruction*) {
-  return Unimplemented("AllToAll is not implemented on CPU.");
-}
-
Status IrEmitter::HandleSend(HloInstruction* send) {
  // TODO(b/33942983): Support Send/Recv on CPU.
  return Unimplemented("Send is not implemented on CPU.");
@@ -2749,10 +2822,10 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
                                         element_alignment);
    target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
  } else {
-    auto* memcpy_instruction =
-        MemCpy(target, /*DstAlign=*/llvm::Align(element_alignment), source,
-               /*SrcAlign=*/llvm::Align(element_alignment),
-               element_count * primitive_type_size);
+    auto* memcpy_instruction = b_.CreateMemCpy(
+        target, /*DstAlign=*/llvm::Align(element_alignment), source,
+        /*SrcAlign=*/llvm::Align(element_alignment),
+        element_count * primitive_type_size);

    // The memcpy does the load and the store internally. The aliasing related
    // metadata has to reflect that.
@@ -45,6 +45,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -241,6 +241,7 @@ bool RegisterKnownJITSymbols() {
  REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
  REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
  REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
  REGISTER_CPU_RUNTIME_SYMBOL(AllToAll);
  REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
  REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
  REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
@@ -108,7 +108,7 @@ class CollectiveOpsTest : public HloTestBase {
  }

  template <typename LiteralType>
-  void TestAllOps() {
+  void TestAllOpsForReduce() {
    auto cast = [&](int value) { return static_cast<LiteralType>(value); };
    auto to_literal = [&](absl::Span<const LiteralType> values) {
      return LiteralUtil::CreateR1<LiteralType>(values);
@@ -183,39 +183,39 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
-  TestAllOps<int8>();
+  TestAllOpsForReduce<int8>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
-  TestAllOps<uint8>();
+  TestAllOpsForReduce<uint8>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
-  TestAllOps<uint32>();
+  TestAllOpsForReduce<uint32>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
-  TestAllOps<int32>();
+  TestAllOpsForReduce<int32>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
-  TestAllOps<int64>();
+  TestAllOpsForReduce<int64>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
-  TestAllOps<uint64>();
+  TestAllOpsForReduce<uint64>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
-  TestAllOps<float>();
+  TestAllOpsForReduce<float>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
-  TestAllOps<double>();
+  TestAllOpsForReduce<double>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
-  TestAllOps<Eigen::half>();
+  TestAllOpsForReduce<Eigen::half>();
}

XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
@@ -593,6 +593,98 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
                                       results[3]));
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_EmptyReplicaGroups)) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    a = f32[2] constant({10, 10})
    b = f32[2] constant({20, 20})
    c = f32[2] constant({30, 30})
    d = f32[2] constant({40, 40})
    all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={}
    a_prime = f32[2] get-tuple-element(all2all), index=0
    b_prime = f32[2] get-tuple-element(all2all), index=1
    c_prime = f32[2] get-tuple-element(all2all), index=2
    d_prime = f32[2] get-tuple-element(all2all), index=3
    ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
  }
  )";
  const int64 kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int i = 0; i < kNumReplicas; i++) {
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
        LiteralUtil::CreateR1<float>({10, 10, 20, 20, 30, 30, 40, 40}),
        results[i], ErrorSpec{1e-5, 1e-5}));
  }
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_OrderedReplicaGroups)) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    a = f32[2] constant({10, 10})
    b = f32[2] constant({20, 20})
    c = f32[2] constant({30, 30})
    d = f32[2] constant({40, 40})
    all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={{3,2,1,0}}
    a_prime = f32[2] get-tuple-element(all2all), index=0
    b_prime = f32[2] get-tuple-element(all2all), index=1
    c_prime = f32[2] get-tuple-element(all2all), index=2
    d_prime = f32[2] get-tuple-element(all2all), index=3
    ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
  }
  )";
  const int64 kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int i = 0; i < kNumReplicas; i++) {
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
        LiteralUtil::CreateR1<float>({40, 40, 30, 30, 20, 20, 10, 10}),
        results[i], ErrorSpec{1e-5, 1e-5}));
  }
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_TwoReplicaGroups)) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY test_computation {
    a = f32[2] constant({10, 10})
    b = f32[2] constant({20, 20})
    all2all = (f32[2], f32[2]) all-to-all(a, b), replica_groups={{2,1},{3,0}}
    a_prime = f32[2] get-tuple-element(all2all), index=0
    b_prime = f32[2] get-tuple-element(all2all), index=1
    ROOT out = f32[4] concatenate(a_prime, b_prime), dimensions={0}
  }
  )";
  const int64 kNumReplicas = 4;
  auto config = GetModuleConfigForTest(kNumReplicas);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr, config));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
                                            /*use_threads=*/true));
  ASSERT_EQ(results.size(), kNumReplicas);
  for (int i = 0; i < kNumReplicas; i++) {
    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
        LiteralUtil::CreateR1<float>({20, 20, 10, 10}), results[i],
        ErrorSpec{1e-5, 1e-5}));
  }
}

XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
  std::string hlo_string = R"(
    HloModule test