Add check for events_ not containing the event we are waiting on because it has already completed

PiperOrigin-RevId: 337185599
Change-Id: I9c73388c0a99c2abbc52aef4e7bf2c61656e8199
This commit is contained in:
Frank Chen 2020-10-14 15:28:50 -07:00 committed by TensorFlower Gardener
parent 465aeca042
commit 82f4f50f4f

View File

@ -736,24 +736,38 @@ class PodTpuDriver : public TpuDriver {
auto done = [this, event_id]() { auto done = [this, event_id]() {
mu_.AssertHeld(); mu_.AssertHeld();
if (events_.count(event_id) == 0) { // The event was either completed and erased from the map or we have
LOG(ERROR) << "Cannot find event id " << event_id // an underlying event available to us.
<< " in WaitForEvent."; return events_.count(event_id) == 0 ||
} (events_[event_id]->underlying_event != nullptr &&
return events_[event_id]->underlying_event != nullptr && events_[event_id]->underlying_event.use_count() != 0);
events_[event_id]->underlying_event.use_count() != 0;
}; };
auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration); auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration);
if (!status) { if (!status) {
return absl::nullopt; return absl::nullopt;
} }
underlying_event = events_[event_id]->underlying_event;
if (events_.count(event_id) > 0) {
underlying_event = events_[event_id]->underlying_event;
} else {
underlying_event = nullptr;
}
} }
// Wait for the underlying event without holding on to the event_lock_, or // Wait for the underlying event without holding on to the event_lock_, or
// else incoming events will not be processed. // else incoming events will not be processed.
return underlying_event->AwaitWithTimeout(duration); if (underlying_event != nullptr) {
return underlying_event->AwaitWithTimeout(duration);
} else {
absl::MutexLock l(&mu_);
auto event_status = abnormal_event_status_.find(event_id);
if (event_status == abnormal_event_status_.end()) {
return Status::OK();
} else {
return event_status->second;
}
}
} }
void AddCallbackForEvent(int64_t event_id, std::function<void(Status)> fn) void AddCallbackForEvent(int64_t event_id, std::function<void(Status)> fn)