Atomically update status_ in permuter.

PiperOrigin-RevId: 327527922
Change-Id: I588fec50b41354d02a85373bdcc55cf2fe3a96f4
This commit is contained in:
A. Unique TensorFlower 2020-08-19 16:17:55 -07:00 committed by TensorFlower Gardener
parent 311c8d233b
commit 9ddea770cd
2 changed files with 14 additions and 19 deletions

View File

@ -39,17 +39,14 @@ namespace tensorflow {
// Default-constructs a Permuter with no context, params, or completion
// callback attached yet, and with the completion counter starting at zero.
Permuter::Permuter()
    : col_ctx_(nullptr),
      col_params_(nullptr),
      done_(nullptr),
      counter_(0) {}
// Increments counter_ while holding mu_counter_ and reports whether this
// particular call is the one that brought the counter to exactly 2.
bool Permuter::CheckCounter() {
  mutex_lock guard(mu_counter_);
  counter_ += 1;
  return counter_ == 2;
}
StatusCallback Permuter::HalfDone() {
StatusCallback Permuter::CheckCounterAndCallDone() {
return [this](const Status& s) {
mu_.lock();
status_.Update(s);
if (CheckCounter()) done_(status_);
int counter = ++counter_;
Status status = status_;
mu_.unlock();
if (counter == 2) done_(status);
};
}
@ -71,11 +68,11 @@ void Permuter::Run(StatusCallback done) {
done_ = std::move(done);
DispatchSend(col_params_->default_rank,
col_params_->instance.permutation[col_params_->default_rank],
col_ctx_->input, HalfDone());
col_ctx_->input, CheckCounterAndCallDone());
for (int i = 0; i < col_params_->instance.permutation.size(); ++i) {
if (col_params_->default_rank == col_params_->instance.permutation[i]) {
DispatchRecv(i, col_params_->instance.permutation[i], col_ctx_->output,
HalfDone());
CheckCounterAndCallDone());
}
}
}

View File

@ -67,9 +67,9 @@ class Permuter : public CollectiveImplementationInterface {
std::shared_ptr<CollectiveContext> col_ctx_;
const CollectiveParams* col_params_; // Not owned
StatusCallback done_;
Status status_;
mutex mu_counter_;
int counter_ TF_GUARDED_BY(mu_counter_);
mutex mu_;
Status status_ TF_GUARDED_BY(mu_);
int counter_ TF_GUARDED_BY(mu_);
void DispatchSend(int src_rank, int target_rank, const Tensor* tensor,
const StatusCallback& done);
@ -77,12 +77,10 @@ class Permuter : public CollectiveImplementationInterface {
void DispatchRecv(int src_rank, int target_rank, Tensor* tensor,
const StatusCallback& done);
// Checks if counter_ reaches 2.
// Atomically increments counter_ by one for sending, one for receiving.
// The purpose of this check is to ensure that done_ is called only once.
bool CheckCounter();
StatusCallback HalfDone();
// Returns a callback that merges its status into status_ and invokes done_
// once counter_ reaches 2.
// The purpose of checking counter_ is to ensure that done_ is called once.
StatusCallback CheckCounterAndCallDone();
};
} // namespace tensorflow