[XLA] NFC: Remove unnecessary ParticipantImplOutput type from Rendezvous.

PiperOrigin-RevId: 346637995
Change-Id: I995b3d4fbc8dfaf89c029b56a22ba4092003c7ee
Chris Jones 2020-12-09 14:27:29 -08:00 committed by TensorFlower Gardener
parent 5eede885e0
commit aad6823650
3 changed files with 11 additions and 17 deletions
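
In effect, Rendezvous implementations no longer wrap their result together with an is_primary flag: RunCollectiveOp now returns the domain-specific output O directly, and subclasses that care about being primary ask InitializationBarrier() themselves. A minimal, self-contained sketch of the before/after shape (not the XLA sources; absl::StatusOr stands in for xla::StatusOr, and the barrier stand-in ignores the real thread synchronization):

```cpp
#include "absl/status/statusor.h"

// Simplified model of the Rendezvous base class, for illustration only.
template <typename I, typename O>
class RendezvousSketch {
 public:
  virtual ~RendezvousSketch() = default;

 protected:
  // Before this change, implementations returned their result wrapped
  // together with an is_primary flag:
  //
  //   struct ParticipantImplOutput { bool is_primary; O custom_output; };
  //   virtual absl::StatusOr<ParticipantImplOutput> RunCollectiveOp(
  //       const I& participant) = 0;
  //
  // After the change, the domain-specific output O is returned directly.
  virtual absl::StatusOr<O> RunCollectiveOp(const I& participant) = 0;

  // Stand-in for the real barrier: returns true for exactly the first caller
  // (the "primary" participant). The real implementation synchronizes threads.
  bool InitializationBarrier() {
    bool was_initialized = initialized_;
    initialized_ = true;
    return !was_initialized;
  }

 private:
  bool initialized_ = false;
};
```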


@@ -208,11 +208,6 @@ template <typename I, typename O,
           std::enable_if_t<std::is_base_of<ParticipantData, I>::value>>
 class Rendezvous {
  public:
-  struct ParticipantImplOutput {
-    bool is_primary;
-    O custom_output;
-  };
   virtual ~Rendezvous() {}
   explicit Rendezvous(const RendezvousKey& k) : key_(k) {}
@@ -246,8 +241,7 @@ class Rendezvous {
  protected:
   // Returns domain-specific output O and whether this replica is primary.
-  virtual StatusOr<ParticipantImplOutput> RunCollectiveOp(
-      const I& participant) = 0;
+  virtual StatusOr<O> RunCollectiveOp(const I& participant) = 0;
   // Initialize the rendezvous by the first ("primary") thread which reaches the
   // barrier. Returns whether this thread is primary.
@@ -300,8 +294,8 @@ class Rendezvous {
         participant.device_ordinal, participant.stream, key_.ToString());
     });
-    TF_ASSIGN_OR_RETURN(ParticipantImplOutput p, RunCollectiveOp(participant));
-    return std::make_pair(p.custom_output, returned_blocking_counter_);
+    TF_ASSIGN_OR_RETURN(O output, RunCollectiveOp(participant));
+    return std::make_pair(std::move(output), returned_blocking_counter_);
   }
   const RendezvousKey key_;
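
Under the new signature, a derived rendezvous looks roughly like the sketch below. Everything here is hypothetical (ExampleParticipantData, ExampleRendezvous, the assumed header path, and the assumption that ParticipantData requires a ToString() override); the real implementations follow in the next two files:

```cpp
// Hypothetical subclass, for illustration only; not part of this change.
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"  // path assumed

struct ExampleParticipantData : xla::ParticipantData {
  using xla::ParticipantData::ParticipantData;
  // Assumes ParticipantData declares a virtual ToString(); adjust if not.
  std::string ToString() const override { return "ExampleParticipantData"; }
};

class ExampleRendezvous
    : public xla::Rendezvous<ExampleParticipantData, std::nullptr_t> {
 public:
  explicit ExampleRendezvous(const xla::RendezvousKey& k)
      : xla::Rendezvous<ExampleParticipantData, std::nullptr_t>(k) {}

 protected:
  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
      const ExampleParticipantData& /*participant*/) override {
    // The primary bit is no longer reported back to the base class; it is
    // queried directly where one-time setup is needed.
    bool is_primary = InitializationBarrier();
    if (is_primary) {
      // One-time initialization for this rendezvous would go here.
    }
    return nullptr;  // No domain-specific output.
  }
};
```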


@@ -323,7 +323,7 @@ class CpuAllToAllRendezvous
       : xla::Rendezvous<AllToAllParticipantData, std::nullptr_t>(k) {}
  protected:
-  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
       const AllToAllParticipantData& /*participant*/) override {
     bool is_primary = InitializationBarrier();
@@ -373,7 +373,7 @@ class CpuAllToAllRendezvous
         }
       }
     }
-    return ParticipantImplOutput{is_primary, nullptr};
+    return nullptr;
   }
 };
@@ -384,7 +384,7 @@ class CpuCollectivePermuteRendezvous
       : xla::Rendezvous<CollectivePermuteParticipantData, std::nullptr_t>(k) {}
  protected:
-  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
       const CollectivePermuteParticipantData& /*participant*/) override {
     bool primary = InitializationBarrier();
@@ -415,7 +415,7 @@ class CpuCollectivePermuteRendezvous
         std::memset(p.destination_data.opaque(), 0, p.byte_size);
       }
     }
-    return ParticipantImplOutput{primary, /*custom_output=*/nullptr};
+    return nullptr;
   }
 };
@@ -426,7 +426,7 @@ class CpuAllReduceRendezvous
       : xla::Rendezvous<xla::AllReduceParticipantData, std::nullptr_t>(k) {}
  protected:
-  xla::StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  xla::StatusOr<std::nullptr_t> RunCollectiveOp(
       const xla::AllReduceParticipantData& participant) override {
     xla::PrimitiveType datatype = participant.buffers.front().primitive_type;
     bool primary = InitializationBarrier();
@@ -465,7 +465,7 @@ class CpuAllReduceRendezvous
         LOG(FATAL) << "Unexpected datatype;";
       }
     }
-    return ParticipantImplOutput{primary, /*custom_output=*/nullptr};
+    return nullptr;
   }
  private:


@@ -223,7 +223,7 @@ class NcclCliqueRendezvous
         local_participants_(local_participants),
         callback_(callback) {}
-  StatusOr<ParticipantImplOutput> RunCollectiveOp(
+  StatusOr<LockedNcclClique> RunCollectiveOp(
       const NcclCliqueParticipantData&) override {
     tensorflow::mutex_lock lock(mu_);
     bool primary = !initialized_;
@@ -238,7 +238,7 @@ class NcclCliqueRendezvous
     if (primary) {
       lock_ = std::make_shared<absl::MutexLock>(clique->mu());
     }
-    return ParticipantImplOutput{primary, LockedNcclClique{clique, lock_}};
+    return LockedNcclClique{clique, lock_};
   }
  private:
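
Reassembled from the two hunks above, the NCCL rendezvous now hands the LockedNcclClique straight back; the primary bit it computes locally is only used to decide which participant takes the clique lock. The middle of the method, which obtains `clique` and runs the callback, lies outside the hunks shown on this page and is elided:

```cpp
  StatusOr<LockedNcclClique> RunCollectiveOp(
      const NcclCliqueParticipantData&) override {
    tensorflow::mutex_lock lock(mu_);
    bool primary = !initialized_;
    // ... clique acquisition and callback handling elided (not shown in this diff) ...
    if (primary) {
      lock_ = std::make_shared<absl::MutexLock>(clique->mu());
    }
    return LockedNcclClique{clique, lock_};
  }
```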