Improve error message after collective ops is aborted

1. Protect BaseCollectiveExecutor::StartAbort() with a mutex so each sub component is aborted with the same status.
2. Clarify in the error message that the error could be from a previous operation.
3. Finish the callback with the aborted status if applicable. It's more likely the original error instead of an artifact of the abortion.
4. Mark the status as derived, so that if the abortion is triggered by some other op's error, the original one gets surfaced.

PiperOrigin-RevId: 338187200
Change-Id: I5c49de3c18579e4cf0da6371b48bdc87941580f1
This commit is contained in:
Ran Chen 2020-10-20 20:07:34 -07:00 committed by TensorFlower Gardener
parent b8a452090a
commit a1a8c27579
2 changed files with 46 additions and 9 deletions

View File

@ -216,11 +216,43 @@ BaseCollectiveExecutor::~BaseCollectiveExecutor() {}
void BaseCollectiveExecutor::StartAbort(const Status& s) {
VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
cem_->GetParamResolver()->StartAbort(s);
remote_access_->StartAbort(s);
if (cem_->GetNcclCommunicator() != nullptr) {
cem_->GetNcclCommunicator()->StartAbort(s);
Status status;
{
mutex_lock l(status_mu_);
if (!status_.ok()) {
LOG(WARNING)
<< "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
<< s;
return;
}
status_ = StatusGroup::MakeDerived(Status(
s.code(),
absl::StrCat(
"Collective ops is aborted by: ", s.error_message(),
"\nThe error could be from a previous operation. Restart your "
"program to reset.")));
status = status_;
}
cem_->GetParamResolver()->StartAbort(status);
remote_access_->StartAbort(status);
if (cem_->GetNcclCommunicator() != nullptr) {
cem_->GetNcclCommunicator()->StartAbort(status);
}
}
Status BaseCollectiveExecutor::GetStatus(const Status& s) {
if (s.ok()) return s;
mutex_lock l(status_mu_);
// If the collective executor is already aborted, use the aborted status
// which is more likely the actual error instead of an artifact of an
// abortion.
if (!status_.ok()) {
VLOG(2) << "Overriding status with collective ops executor status. "
"Original status: "
<< s;
return status_;
}
return s;
}
void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
@ -229,10 +261,10 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
StatusCallback done) {
// See CompleteParamsAsync() how done() and the timeout callback interacts.
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_safe = [done, is_callback_called](const Status& s) {
auto done_safe = [this, done, is_callback_called](const Status& s) {
bool called = is_callback_called->exchange(true);
if (!called) {
done(s);
done(GetStatus(s));
}
};
auto timeout_microseconds = static_cast<int64>(
@ -309,10 +341,10 @@ void BaseCollectiveExecutor::CompleteParamsAsync(
// timeout callback executes, done_safe will become a no-op and the timeout
// callback is responsible for invoking done() at the end.
const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
auto done_safe = [done, is_callback_called](const Status& s) {
auto done_safe = [this, is_callback_called, done](const Status& s) {
bool called = is_callback_called->exchange(true);
if (!called) {
done(s);
done(GetStatus(s));
}
};
auto timeout_microseconds =

View File

@ -108,7 +108,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
~BaseCollectiveExecutor() override;
void StartAbort(const Status& s) override;
void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_);
void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
const string& exec_key, StatusCallback done) override;
@ -148,6 +148,8 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
// collective instance key -> number of local devices for which NCCL ops have
// been launched.
std::unordered_map<int32, int32> launched_ TF_GUARDED_BY(launch_mu_);
mutex status_mu_;
Status status_ TF_GUARDED_BY(status_mu_);
private:
Status CreateCollective(const CollectiveParams& col_params,
@ -155,6 +157,9 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
// Check if all ops on which this collective depends on have launched.
bool CheckDependencies(const CollectiveParams& col_params)
TF_EXCLUSIVE_LOCKS_REQUIRED(launch_mu_);
// Tries to return the status that is the original error. It returns the
// aborted status if the collective executor is aborted.
Status GetStatus(const Status& s) TF_LOCKS_EXCLUDED(status_mu_);
};
} // namespace tensorflow