Atomically update status_ in permuter.
PiperOrigin-RevId: 327527922
Change-Id: I588fec50b41354d02a85373bdcc55cf2fe3a96f4
This commit is contained in:
parent
311c8d233b
commit
9ddea770cd
@ -39,17 +39,14 @@ namespace tensorflow {
|
||||
Permuter::Permuter()
|
||||
: col_ctx_(nullptr), col_params_(nullptr), done_(nullptr), counter_(0) {}
|
||||
|
||||
bool Permuter::CheckCounter() {
|
||||
mutex_lock lock(mu_counter_);
|
||||
++counter_;
|
||||
if (counter_ == 2) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
StatusCallback Permuter::HalfDone() {
|
||||
StatusCallback Permuter::CheckCounterAndCallDone() {
|
||||
return [this](const Status& s) {
|
||||
mu_.lock();
|
||||
status_.Update(s);
|
||||
if (CheckCounter()) done_(status_);
|
||||
int counter = ++counter_;
|
||||
Status status = status_;
|
||||
mu_.unlock();
|
||||
if (counter == 2) done_(status);
|
||||
};
|
||||
}
|
||||
|
||||
@ -71,11 +68,11 @@ void Permuter::Run(StatusCallback done) {
|
||||
done_ = std::move(done);
|
||||
DispatchSend(col_params_->default_rank,
|
||||
col_params_->instance.permutation[col_params_->default_rank],
|
||||
col_ctx_->input, HalfDone());
|
||||
col_ctx_->input, CheckCounterAndCallDone());
|
||||
for (int i = 0; i < col_params_->instance.permutation.size(); ++i) {
|
||||
if (col_params_->default_rank == col_params_->instance.permutation[i]) {
|
||||
DispatchRecv(i, col_params_->instance.permutation[i], col_ctx_->output,
|
||||
HalfDone());
|
||||
CheckCounterAndCallDone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -67,9 +67,9 @@ class Permuter : public CollectiveImplementationInterface {
|
||||
std::shared_ptr<CollectiveContext> col_ctx_;
|
||||
const CollectiveParams* col_params_; // Not owned
|
||||
StatusCallback done_;
|
||||
Status status_;
|
||||
mutex mu_counter_;
|
||||
int counter_ TF_GUARDED_BY(mu_counter_);
|
||||
mutex mu_;
|
||||
Status status_ TF_GUARDED_BY(mu_);
|
||||
int counter_ TF_GUARDED_BY(mu_);
|
||||
|
||||
void DispatchSend(int src_rank, int target_rank, const Tensor* tensor,
|
||||
const StatusCallback& done);
|
||||
@ -77,12 +77,10 @@ class Permuter : public CollectiveImplementationInterface {
|
||||
void DispatchRecv(int src_rank, int target_rank, Tensor* tensor,
|
||||
const StatusCallback& done);
|
||||
|
||||
// Checks if counter_ reaches 2.
|
||||
// Atomically increments counter_ by one for sending, one for receiving.
|
||||
// The purpose of this check is to ensure that done_ is called only once.
|
||||
bool CheckCounter();
|
||||
|
||||
StatusCallback HalfDone();
|
||||
// Invokes done when counter_ reaches 2.
|
||||
// The purpose of checking counter_ is to ensure that done_ is called once.
|
||||
StatusCallback CheckCounterAndCallDone();
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user