diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h index d7f63cb2663..1552397d4bc 100644 --- a/tensorflow/compiler/xla/service/collective_ops_utils.h +++ b/tensorflow/compiler/xla/service/collective_ops_utils.h @@ -208,11 +208,6 @@ template ::value>> class Rendezvous { public: - struct ParticipantImplOutput { - bool is_primary; - O custom_output; - }; - virtual ~Rendezvous() {} explicit Rendezvous(const RendezvousKey& k) : key_(k) {} @@ -246,8 +241,7 @@ class Rendezvous { protected: // Returns domain-specific output O and whether this replica is primary. - virtual StatusOr RunCollectiveOp( - const I& participant) = 0; + virtual StatusOr RunCollectiveOp(const I& participant) = 0; // Initialize the rendezvous by the first ("primary") thread which reaches the // barrier. Returns whether this thread is primary. @@ -300,8 +294,8 @@ class Rendezvous { participant.device_ordinal, participant.stream, key_.ToString()); }); - TF_ASSIGN_OR_RETURN(ParticipantImplOutput p, RunCollectiveOp(participant)); - return std::make_pair(p.custom_output, returned_blocking_counter_); + TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant)); + return std::make_pair(std::move(output), returned_blocking_counter_); } const RendezvousKey key_; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index d4d78f5ac12..437dbd5cbb3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -323,7 +323,7 @@ class CpuAllToAllRendezvous : xla::Rendezvous(k) {} protected: - xla::StatusOr RunCollectiveOp( + xla::StatusOr RunCollectiveOp( const AllToAllParticipantData& /*participant*/) override { bool is_primary = InitializationBarrier(); @@ -373,7 +373,7 @@ class CpuAllToAllRendezvous } } } - return ParticipantImplOutput{is_primary, nullptr}; + return nullptr; } }; @@ -384,7 +384,7 @@ class CpuCollectivePermuteRendezvous : xla::Rendezvous(k) {} protected: - xla::StatusOr RunCollectiveOp( + xla::StatusOr RunCollectiveOp( const CollectivePermuteParticipantData& /*participant*/) override { bool primary = InitializationBarrier(); @@ -415,7 +415,7 @@ class CpuCollectivePermuteRendezvous std::memset(p.destination_data.opaque(), 0, p.byte_size); } } - return ParticipantImplOutput{primary, /*custom_output=*/nullptr}; + return nullptr; } }; @@ -426,7 +426,7 @@ class CpuAllReduceRendezvous : xla::Rendezvous(k) {} protected: - xla::StatusOr RunCollectiveOp( + xla::StatusOr RunCollectiveOp( const xla::AllReduceParticipantData& participant) override { xla::PrimitiveType datatype = participant.buffers.front().primitive_type; bool primary = InitializationBarrier(); @@ -465,7 +465,7 @@ class CpuAllReduceRendezvous LOG(FATAL) << "Unexpected datatype;"; } } - return ParticipantImplOutput{primary, /*custom_output=*/nullptr}; + return nullptr; } private: diff --git a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc index 81f54ba08b6..6b6a9351384 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_utils.cc @@ -223,7 +223,7 @@ class NcclCliqueRendezvous local_participants_(local_participants), callback_(callback) {} - StatusOr RunCollectiveOp( + StatusOr RunCollectiveOp( const NcclCliqueParticipantData&) override { tensorflow::mutex_lock lock(mu_); bool primary = !initialized_; @@ -238,7 +238,7 @@ class NcclCliqueRendezvous if (primary) { lock_ = std::make_shared(clique->mu()); } - return ParticipantImplOutput{primary, LockedNcclClique{clique, lock_}}; + return LockedNcclClique{clique, lock_}; } private: