From 052263c1305c476d1faf470dc0bfdfc24f63d501 Mon Sep 17 00:00:00 2001
From: George Karpenkov
Date: Tue, 7 Jul 2020 15:34:41 -0700
Subject: [PATCH] [XLA:CPU] AllToAll support for XLA:CPU

A single primary thread performs all the copies once every participant
has arrived at the rendezvous.

PiperOrigin-RevId: 320074537
Change-Id: Iaa4e4a78b0f058ffdb11334a12e8b78126399e89
---
 .../compiler/xla/service/cpu/cpu_runtime.cc   | 149 ++++++++++++++++++
 .../compiler/xla/service/cpu/cpu_runtime.h    |   7 +
 .../compiler/xla/service/cpu/ir_emitter.cc    |  94 +++++++++--
 .../compiler/xla/service/cpu/ir_emitter.h     |   1 +
 .../xla/service/cpu/simple_orc_jit.cc         |   1 +
 .../compiler/xla/tests/collective_ops_test.cc | 114 ++++++++++++--
 6 files changed, 347 insertions(+), 19 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 3df9ef35bab..2231ecfa1e8 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -122,6 +122,7 @@ extern const char* const kTracingStartSymbolName =
 extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd";
 extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_";
 extern const char* const kAllReduceSymbolName = "__xla_cpu_runtime_AllReduce";
+extern const char* const kAllToAllSymbolName = "__xla_cpu_runtime_AllToAll";
 extern const char* const kCollectivePermuteSymbolName =
     "__xla_cpu_runtime_CollectivePermute";
 extern const char* const kReplicaIdSymbolName = "__xla_cpu_runtime_ReplicaId";
@@ -154,6 +155,34 @@ struct CollectivePermuteParticipantData : xla::ParticipantData {
   }
 };
 
+struct AllToAllParticipantData : xla::ParticipantData {
+  AllToAllParticipantData(const xla::RendezvousKey& rendezvous_key_p,
+                          xla::int64 device_ordinal_p, se::Stream* stream_p)
+      : ParticipantData(rendezvous_key_p, device_ordinal_p, stream_p) {}
+
+  std::vector<se::DeviceMemoryBase> source_buffers;
+  std::vector<se::DeviceMemoryBase> destination_buffers;
+  int replica_id;
+
+  // Replica IDs participating in AllToAll; concatenation happens in the order
+  // of appearance.
+  std::vector<xla::int64> replica_ids_to_copy_to;
+
+  std::string ToString() const override {
+    auto addr_formatter = [](std::string* out,
+                             const se::DeviceMemoryBase& mem) {
+      absl::StrAppend(out, absl::StrFormat("%p", mem.opaque()));
+    };
+    return absl::StrFormat(
+        "AllToAllParticipantData{replica_id=%d, "
+        "replica_ids_to_copy_to=[%s], source_buffers=[%s], "
+        "destination_buffers=[%s]}",
+        replica_id, absl::StrJoin(replica_ids_to_copy_to, ", "),
+        absl::StrJoin(source_buffers, ", ", addr_formatter),
+        absl::StrJoin(destination_buffers, ", ", addr_formatter));
+  }
+};
+
 // Inverses the encoding of a Shape protobuf into an LLVM global variable.
 xla::StatusOr<xla::Shape> DecodeSelfDescribingShapeConstant(
     const void* shape_ptr, xla::int32 size_bytes) {
@@ -286,6 +315,74 @@ __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(
 
 namespace {
 
+class CpuAllToAllRendezvous
+    : public xla::Rendezvous<AllToAllParticipantData, std::nullptr_t> {
+ public:
+  explicit CpuAllToAllRendezvous(const xla::RendezvousKey& k)
+      : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}
+
+ protected:
+  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+      const AllToAllParticipantData& /*participant*/) override {
+    bool is_primary = InitializationBarrier();
+
+    if (is_primary) {
+      tensorflow::mutex_lock lock(mu_);
+
+      CHECK(!participants_.empty());
+      CHECK(!participants_[0].source_buffers.empty());
+      int expected_buffer_size = participants_[0].source_buffers[0].size();
+
+      // Replica id -> position in participants_.
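+      // The primary thread then performs every copy itself: each
+      // participant's j-th source buffer is written into the j-th
+      // destination buffer of every replica it targets, with destinations
+      // ordered to match the replica order of the group.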
+      absl::flat_hash_map<int, int> replica_id_map;
+
+      for (int pos = 0; pos < participants_.size(); pos++) {
+        const AllToAllParticipantData& p = participants_[pos];
+        CHECK_EQ(p.source_buffers.size(), p.destination_buffers.size());
+        CHECK_EQ(p.source_buffers.size(), participants_.size());
+        for (int i = 0; i < p.source_buffers.size(); i++) {
+          CHECK_EQ(p.destination_buffers[i].size(), expected_buffer_size);
+          CHECK_EQ(p.source_buffers[i].size(), expected_buffer_size);
+        }
+        replica_id_map[p.replica_id] = pos;
+      }
+
+      for (AllToAllParticipantData& p : participants_) {
+        VLOG(3) << "Processing AllToAll participant data: " << p.ToString();
+        for (int j = 0; j < p.source_buffers.size(); j++) {
+          for (int i = 0; i < p.replica_ids_to_copy_to.size(); i++) {
+            int replica_id = p.replica_ids_to_copy_to[i];
+            int participant_num = xla::FindOrDie(replica_id_map, replica_id);
+            AllToAllParticipantData& other = participants_[participant_num];
+
+            // Sort by replica ordering.
+            std::vector<se::DeviceMemoryBase> destination_buffers =
+                other.destination_buffers;
+            absl::flat_hash_map<const void*, int> buffers_index;
+            for (int idx = 0; idx < destination_buffers.size(); idx++) {
+              buffers_index[destination_buffers[idx].opaque()] = idx;
+            }
+            absl::c_sort(
+                destination_buffers, [&](const se::DeviceMemoryBase& a,
+                                         const se::DeviceMemoryBase& b) {
+                  return p.replica_ids_to_copy_to[buffers_index[a.opaque()]] <
+                         p.replica_ids_to_copy_to[buffers_index[b.opaque()]];
+                });
+
+            std::memcpy(destination_buffers[j].opaque(),
+                        p.source_buffers[j].opaque(), expected_buffer_size);
+          }
+        }
+      }
+    }
+    return ParticipantImplOutput{is_primary, nullptr};
+  }
+};
+
 class CpuCollectivePermuteRendezvous
     : public xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t> {
  public:
@@ -486,6 +579,13 @@ GlobalCollectivePermuteRendezvousMap() {
   return m;
 }
 
+xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>&
+GlobalAllToAllRendezvousMap() {
+  static auto& m =
+      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllToAllRendezvous>;
+  return m;
+}
+
 int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) {
   if (run_options->stream()) {
     return run_options->stream()->parent()->device_ordinal();
@@ -524,6 +624,51 @@ xla::RendezvousKey GetRendezvousKey(
 
 }  // namespace
 
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllToAll(
+    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
+    xla::int64 op_id, const void* replica_groups_str,
+    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
+    xla::int64 buffer_size, void** source_buffers, void** destination_buffers) {
+  int device_ordinal = GetDeviceOrdinal(run_options);
+  xla::int32 replica_id = run_options->device_assignment()
+                              ->ReplicaIdForDeviceOrdinal(device_ordinal)
+                              .ValueOrDie();
+  absl::string_view replica_groups_serialized(
+      static_cast<const char*>(replica_groups_str), replica_groups_str_size);
+  std::vector<xla::ReplicaGroup> group =
+      xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
+  xla::RendezvousKey rendezvous_key =
+      GetRendezvousKey(run_options, group, channel_id_present, op_id);
+
+  AllToAllParticipantData participant(rendezvous_key, device_ordinal,
+                                      run_options->stream());
+  participant.replica_id = replica_id;
+  participant.replica_ids_to_copy_to =
+      xla::GetParticipatingReplicas(
+          xla::GlobalDeviceId(device_ordinal), group,
+          run_options->device_assignment()->replica_count(),
+          *run_options->device_assignment())
+          .ValueOrDie();
+  for (int i = 0; i < num_buffers; i++) {
+    participant.source_buffers.emplace_back(source_buffers[i], buffer_size);
+    participant.destination_buffers.emplace_back(destination_buffers[i],
+                                                 buffer_size);
+  }
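+  // Hand the participant off to the process-wide rendezvous for this
+  // RendezvousKey; SubmitParticipant blocks until every replica in the
+  // group has arrived and the primary thread has finished all copies.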
+  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
+    return absl::make_unique<CpuAllToAllRendezvous>(k);
+  };
+  TF_CHECK_OK(CpuAllToAllRendezvous::SubmitParticipant(
+                  [&] {
+                    return GlobalAllToAllRendezvousMap().GetOrCreateIfAbsent(
+                        rendezvous_key, make_cpu_rendezvous);
+                  },
+                  participant)
+                  .status());
+}
+
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
     const xla::ExecutableRunOptions* run_options,
     const void* replica_groups_str, xla::int32 replica_groups_str_size,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
index 492ce3f68b2..ee75b97e4dc 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -77,6 +77,7 @@ extern const char* const kCollectivePermuteSymbolName;
 extern const char* const kReplicaIdSymbolName;
 extern const char* const kTracingStartSymbolName;
 extern const char* const kTracingEndSymbolName;
+extern const char* const kAllToAllSymbolName;
 
 // All symbol names for XLA CPU runtime functions need to start with this
 // prefix.
@@ -181,6 +182,12 @@ extern void __xla_cpu_runtime_CollectivePermute(
     void* output_buffer, const void* source_target_pairs,
     xla::int32 source_target_pairs_size);
 
+extern void __xla_cpu_runtime_AllToAll(
+    const xla::ExecutableRunOptions* run_options, xla::int32 channel_id_present,
+    xla::int64 op_id, const void* replica_groups_str,
+    xla::int32 replica_groups_str_size, xla::int32 num_buffers,
+    xla::int64 buffer_size, void** source_buffers, void** destination_buffers);
+
 // Write the replica ID into the output buffer.
 extern void __xla_cpu_runtime_ReplicaId(
     const xla::ExecutableRunOptions* run_options, void* output_buffer);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 0cfab50d0a3..ae5e6433f17 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -359,7 +359,7 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   // to the output buffer of its corresponding operand. A GetTupleElement
   // instruction forwards a pointer to the tuple element buffer at the given
   // index.
-  auto operand = get_tuple_element->operand(0);
+  const HloInstruction* operand = get_tuple_element->operand(0);
   const Shape& shape = get_tuple_element->shape();
   emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
       shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
@@ -1432,6 +1432,86 @@ Status IrEmitter::HandleAllReduce(HloInstruction* crs) {
   return HandleAllReduceMultipleReplica(crs);
 }
 
+Status IrEmitter::HandleAllToAll(HloInstruction* instruction) {
+  auto* instr = Cast<HloAllToAllInstruction>(instruction);
+  TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instruction));
+  CHECK(!instr->split_dimension() && instr->shape().IsTuple())
+      << "Only tuple AllToAll is supported";
+
+  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+  llvm::Type* int32_type = b_.getInt32Ty();
+  llvm::Type* int64_type = b_.getInt64Ty();
+
+  // TODO(cheshire): 3 statements below should be a single line.
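+  // The runtime call takes the tuple operands as two parallel arrays of
+  // i8* pointers (sources and destinations), one entry per operand, each
+  // pointing at a buffer of buffer_size bytes.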
+  llvm::FunctionType* all_to_all_func_ty =
+      llvm::FunctionType::get(b_.getVoidTy(),
+                              {/*run_options=*/i8_ptr_type,
+                               /*channel_id_present=*/int32_type,
+                               /*op_id=*/int64_type,
+                               /*replica_groups=*/i8_ptr_type,
+                               /*replica_groups_size=*/int32_type,
+                               /*num_buffers=*/int32_type,
+                               /*buffer_size=*/int64_type,
+                               /*input_buffer=*/i8_ptr_type,
+                               /*output_buffer=*/i8_ptr_type},
+                              /*isVarArg=*/false);
+  auto all_to_all_func = llvm::dyn_cast<llvm::Function>(
+      module_
+          ->getOrInsertFunction(runtime::kAllToAllSymbolName,
+                                all_to_all_func_ty)
+          .getCallee());
+  all_to_all_func->setCallingConv(llvm::CallingConv::C);
+
+  std::string replica_groups =
+      ReplicaGroupsToString(instruction->replica_groups());
+  int32 replica_groups_size = replica_groups.size();
+  llvm::Value* replica_groups_v = b_.CreateGlobalStringPtr(replica_groups);
+
+  int64 buffer_size = -1;
+  std::vector<llvm::Value*> input_buffer_ptrs;
+  std::vector<llvm::Value*> output_buffer_ptrs;
+
+  for (int64 i = 0; i < instruction->operand_count(); i++) {
+    const HloInstruction* op = instruction->operand(i);
+    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_slice,
+                        assignment_.GetUniqueSlice(instruction, {i}));
+    const Shape& operand_shape = instruction->operand(i)->shape();
+    CHECK(operand_shape.IsArray())
+        << "Operands to all-to-all must be arrays: " << instruction->ToString();
+    output_buffer_ptrs.push_back(EmitBufferPointer(out_slice, operand_shape));
+    input_buffer_ptrs.push_back(GetEmittedValueFor(op));
+    CHECK(buffer_size == -1 || buffer_size == out_slice.size());
+    buffer_size = out_slice.size();
+  }
+
+  llvm::Value* input_buffers =
+      EncodeArrayFunctionArguments(input_buffer_ptrs, "input_buffers", &b_);
+  llvm::Value* output_buffers =
+      EncodeArrayFunctionArguments(output_buffer_ptrs, "output_buffers", &b_);
+
+  b_.CreateCall(
+      all_to_all_func,
+      {/*run_options=*/GetExecutableRunOptionsArgument(),
+       /*channel_id_present=*/
+       b_.getInt32(static_cast<int32>(instruction->channel_id().has_value())),
+       /*op_id=*/
+       b_.getInt64(instruction->channel_id().has_value()
+                       ? *instruction->channel_id()
+                       : instruction->GetModule()->unique_id()),
+       /*replica_groups=*/replica_groups_v,
+       /*replica_groups_size=*/b_.getInt32(replica_groups_size),
+       /*num_buffers=*/b_.getInt32(instruction->operand_count()),
+       /*buffer_size=*/b_.getInt64(buffer_size),
+       /*source_buffers=*/b_.CreateBitCast(input_buffers, i8_ptr_type),
+       /*destination_buffers=*/b_.CreateBitCast(output_buffers, i8_ptr_type)});
+
+  llvm_ir::EmitTuple(GetIrArrayFor(instruction), output_buffer_ptrs, &b_);
+  return Status::OK();
+}
+
 Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) {
   auto* instr = Cast<HloCollectivePermuteInstruction>(crs);
   std::string source_target_pairs = absl::StrJoin(
@@ -2017,10 +2094,6 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
   return DefaultAction(reduce);
 }
 
-Status IrEmitter::HandleAllToAll(HloInstruction*) {
-  return Unimplemented("AllToAll is not implemented on CPU.");
-}
-
 Status IrEmitter::HandleSend(HloInstruction* send) {
   // TODO(b/33942983): Support Send/Recv on CPU.
   return Unimplemented("Send is not implemented on CPU.");
@@ -2749,10 +2822,10 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
                                          element_alignment);
     target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction);
   } else {
-    auto* memcpy_instruction =
-        MemCpy(target, /*DstAlign=*/llvm::Align(element_alignment), source,
-               /*SrcAlign=*/llvm::Align(element_alignment),
-               element_count * primitive_type_size);
+    auto* memcpy_instruction = b_.CreateMemCpy(
+        target, /*DstAlign=*/llvm::Align(element_alignment), source,
+        /*SrcAlign=*/llvm::Align(element_alignment),
+        element_count * primitive_type_size);
 
     // The memcpy does the load and the store internally. The aliasing related
     // metadata has to reflect that.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index cef9b817503..0b19b0d67d5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
 #include "tensorflow/compiler/xla/service/name_uniquer.h"
 #include "tensorflow/compiler/xla/statusor.h"
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
index 4bdb601a3d1..631c6985b03 100644
--- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -241,6 +241,7 @@ bool RegisterKnownJITSymbols() {
   REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
   REGISTER_CPU_RUNTIME_SYMBOL(AllReduce);
   REGISTER_CPU_RUNTIME_SYMBOL(CollectivePermute);
+  REGISTER_CPU_RUNTIME_SYMBOL(AllToAll);
   REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId);
   REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
   REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index f5466c632ac..7459b3d3f1f 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -108,7 +108,7 @@ class CollectiveOpsTest : public HloTestBase {
   }
 
   template <typename T>
-  void TestAllOps() {
+  void TestAllOpsForReduce() {
     auto cast = [&](int value) { return static_cast<T>(value); };
     auto to_literal = [&](absl::Span<const T> values) {
       return LiteralUtil::CreateR1<T>(values);
     };
@@ -183,39 +183,39 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceSingleOutput_float32) {
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int8) {
-  TestAllOps<int8>();
+  TestAllOpsForReduce<int8>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint8) {
-  TestAllOps<uint8>();
+  TestAllOpsForReduce<uint8>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint32) {
-  TestAllOps<uint32>();
+  TestAllOpsForReduce<uint32>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int32) {
-  TestAllOps<int32>();
+  TestAllOpsForReduce<int32>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_int64) {
-  TestAllOps<int64>();
+  TestAllOpsForReduce<int64>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_uint64) {
-  TestAllOps<uint64>();
+  TestAllOpsForReduce<uint64>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_float32) {
-  TestAllOps<float>();
+  TestAllOpsForReduce<float>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_double) {
-  TestAllOps<double>();
+  TestAllOpsForReduce<double>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceTwoReplicasOneOperand_half) {
-  TestAllOps<Eigen::half>();
+  TestAllOpsForReduce<Eigen::half>();
 }
 
 XLA_TEST_F(CollectiveOpsTest, AllReduceAnd_Pred) {
@@ -593,6 +593,100 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) {
                                      results[3]));
 }
 
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_EmptyReplicaGroups)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a = f32[2] constant({10, 10})
+    b = f32[2] constant({20, 20})
+    c = f32[2] constant({30, 30})
+    d = f32[2] constant({40, 40})
+    all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={}
+    a_prime = f32[2] get-tuple-element(all2all), index=0
+    b_prime = f32[2] get-tuple-element(all2all), index=1
+    c_prime = f32[2] get-tuple-element(all2all), index=2
+    d_prime = f32[2] get-tuple-element(all2all), index=3
+    ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
+  }
+  )";
+  const int64 kNumReplicas = 4;
+  auto config = GetModuleConfigForTest(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
+                                            /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; i++) {
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
+        LiteralUtil::CreateR1<float>({10, 10, 20, 20, 30, 30, 40, 40}),
+        results[i], ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_OrderedReplicaGroups)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a = f32[2] constant({10, 10})
+    b = f32[2] constant({20, 20})
+    c = f32[2] constant({30, 30})
+    d = f32[2] constant({40, 40})
+    all2all = (f32[2], f32[2], f32[2], f32[2]) all-to-all(a, b, c, d), replica_groups={{3,2,1,0}}
+    a_prime = f32[2] get-tuple-element(all2all), index=0
+    b_prime = f32[2] get-tuple-element(all2all), index=1
+    c_prime = f32[2] get-tuple-element(all2all), index=2
+    d_prime = f32[2] get-tuple-element(all2all), index=3
+    ROOT out = f32[8] concatenate(a_prime, b_prime, c_prime, d_prime), dimensions={0}
+  }
+  )";
+  const int64 kNumReplicas = 4;
+  auto config = GetModuleConfigForTest(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
+                                            /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; i++) {
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
+        LiteralUtil::CreateR1<float>({40, 40, 30, 30, 20, 20, 10, 10}),
+        results[i], ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
+XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_GPU(AllToAll_TwoReplicaGroups)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a = f32[2] constant({10, 10})
+    b = f32[2] constant({20, 20})
+    all2all = (f32[2], f32[2]) all-to-all(a, b), replica_groups={{2,1},{3,0}}
+    a_prime = f32[2] get-tuple-element(all2all), index=0
+    b_prime = f32[2] get-tuple-element(all2all), index=1
+    ROOT out = f32[4] concatenate(a_prime, b_prime), dimensions={0}
+  }
+  )";
+  const int64 kNumReplicas = 4;
+  auto config = GetModuleConfigForTest(kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kModuleStr, config));
+
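+  // ExecuteReplicated spawns one thread per replica and returns one
+  // result literal per replica.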
+  TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
+                          ExecuteReplicated(std::move(module), {}, kNumReplicas,
+                                            /*use_threads=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (int i = 0; i < kNumReplicas; i++) {
+    EXPECT_TRUE(LiteralTestUtil::NearOrEqual(
+        LiteralUtil::CreateR1<float>({20, 20, 10, 10}), results[i],
+        ErrorSpec{1e-5, 1e-5}));
+  }
+}
+
 XLA_TEST_F(CollectiveOpsTest, AllReduce_TupleAllReduce) {
   std::string hlo_string = R"(
     HloModule test