From 82f4f50f4fbfd74aceb741ff097d6c42688b5023 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Wed, 14 Oct 2020 15:28:50 -0700 Subject: [PATCH] Add check for events_ not containing the event we are waiting on because it has already completed PiperOrigin-RevId: 337185599 Change-Id: I9c73388c0a99c2abbc52aef4e7bf2c61656e8199 --- .../xla/python/tpu_driver/pod_tpu_driver.cc | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc index c462b78942e..a5a6cbabb82 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/pod_tpu_driver.cc @@ -736,24 +736,38 @@ class PodTpuDriver : public TpuDriver { auto done = [this, event_id]() { mu_.AssertHeld(); - if (events_.count(event_id) == 0) { - LOG(ERROR) << "Cannot find event id " << event_id - << " in WaitForEvent."; - } - return events_[event_id]->underlying_event != nullptr && - events_[event_id]->underlying_event.use_count() != 0; + // The event was either completed and erased from the map or we have + // an underlying event available to us. + return events_.count(event_id) == 0 || + (events_[event_id]->underlying_event != nullptr && + events_[event_id]->underlying_event.use_count() != 0); }; auto status = mu_.AwaitWithTimeout(absl::Condition(&done), duration); if (!status) { return absl::nullopt; } - underlying_event = events_[event_id]->underlying_event; + + if (events_.count(event_id) > 0) { + underlying_event = events_[event_id]->underlying_event; + } else { + underlying_event = nullptr; + } } // Wait for the underlying event without holding on to the event_lock_, or // else incoming events will not be processed. - return underlying_event->AwaitWithTimeout(duration); + if (underlying_event != nullptr) { + return underlying_event->AwaitWithTimeout(duration); + } else { + absl::MutexLock l(&mu_); + auto event_status = abnormal_event_status_.find(event_id); + if (event_status == abnormal_event_status_.end()) { + return Status::OK(); + } else { + return event_status->second; + } + } } void AddCallbackForEvent(int64_t event_id, std::function fn)