Fix race in //tensorflow/c/eager:c_api_remote_test
There were a couple of race conditions in the code:
- Update item->state outside the lock before RunAsync, since the callback could be executed right away.
- Put the node onto the unfinished list before RunAsync, since the callback could be executed right away.
- Since we are going to acquire the lock in NodeDone anyway, move the status check inside the lock.

Further cleaned up NodeDone to only pop node_queue_ if indicated by the caller. We also optimize NodeDone to return without acquiring the lock in sync mode.

PiperOrigin-RevId: 282437248
Change-Id: Ia30b4846108c51eae630587534ec70daf2d08e5a
Parent: 8cfec1b86a
Commit: 2a16758ee4
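For context, here is a minimal self-contained sketch of the ordering bug described in the commit message. Everything in it is illustrative (simplified types, no TensorFlow APIs): the point is only that a RunAsync-style call may invoke its completion callback synchronously, before it returns, so any bookkeeping done after the call races with the callback. Doing the state update and the unfinished-list insertion before RunAsync closes that window.

// Illustrative sketch only -- not the TensorFlow sources.
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>

enum class NodeState { kPENDING, kSCHEDULED, kDONE };

struct Item {
  int id = 0;
  NodeState state = NodeState::kPENDING;
};

std::mutex mu;
std::map<int, std::shared_ptr<Item>> unfinished;

// Stand-in for AsyncEagerNode::RunAsync: it may invoke the callback
// synchronously, before it returns to the caller.
void RunAsync(std::function<void()> done) { done(); }

// Completion handler: marks the item done and drops it from the map.
void NodeDone(const std::shared_ptr<Item>& item) {
  std::lock_guard<std::mutex> l(mu);
  item->state = NodeState::kDONE;
  // If the scheduler had not inserted the item yet (the buggy order),
  // this erase is a no-op and a stale entry gets inserted afterwards.
  unfinished.erase(item->id);
}

int main() {
  auto item = std::make_shared<Item>();
  item->id = 1;

  // Fixed ordering: update the state and put the node onto the unfinished
  // map BEFORE RunAsync, since the callback could be executed right away.
  {
    std::lock_guard<std::mutex> l(mu);
    item->state = NodeState::kSCHEDULED;
    unfinished.emplace(item->id, item);
  }
  RunAsync([item] { NodeDone(item); });

  std::lock_guard<std::mutex> l(mu);
  std::cout << "unfinished entries left: " << unfinished.size() << "\n";  // 0
}

With the buggy order (RunAsync first), NodeDone's erase can run before the emplace, leaving a stale entry in the map that nothing will ever remove.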
@@ -131,7 +131,7 @@ Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
   if (!Async()) {
     status = this->status();
     if (status.ok()) {
-      status = RunItem(std::move(item));
+      status = RunItem(std::move(item), false);
     }
     return status;
   } else {
@@ -204,32 +204,41 @@ void EagerExecutor::ClearError() {
 }
 
 void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
-                             const Status& status) {
+                             const Status& status, bool from_queue) {
   DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
            << " with status: " << status.ToString();
   DCHECK(item->state != NodeState::kDONE);
-  auto previous_state = item->state;
   item->state = NodeState::kDONE;
-  if (!ok()) return;
+
+  bool async = item->node->AsAsync() != nullptr;
+  // If executing synchronously we don't need to notify if status is OK since
+  // the node was never added to the unfinished_nodes_ list and nobody should
+  // ever be waiting for it.
+  if (status.ok() && !from_queue && !async) {
+    return;
+  }
 
   std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
   {
     mutex_lock l(node_queue_mutex_);
-    bool need_notification = false;
-    if (previous_state == NodeState::kPENDING) {
-      if (Async()) {
-        DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
-        need_notification = unfinished_nodes_.empty();
-        node_queue_.pop();
-      } else {
-        need_notification = unfinished_nodes_.empty();
-      }
-    } else {
+    if (!status_.ok()) return;
+
+    bool need_notification = from_queue;
+    if (from_queue) {
+      // Since this was from the async queue, pop it from the front of the queue.
+      DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
+      node_queue_.pop();
+    } else if (async) {
       // If it is an Async node then we will find the node in the unfinished
       // nodes list. However we only notify if we are at the front of the list
      // since we don't want to notify any waiters of earlier nodes.
       need_notification = item->id == unfinished_nodes_.begin()->first;
       auto result = unfinished_nodes_.erase(item->id);
       DCHECK_GT(result, 0);
     }
+
     if (!status.ok()) {
       // Since we received an error, broadcast to any waiters.
       need_notification = true;
       status_ = status;
       ok_ = false;
@@ -254,6 +263,7 @@ void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
       NotifyWaiters(item->id);
     }
   }
+
   for (auto& item : items_to_destroy) {
     item->node->Abort(status);
   }
@@ -312,41 +322,56 @@ void EagerExecutor::Run() {
       curr_item.reset(node_queue_.front().get());
       curr_item->Ref();
     }
-    Status status = RunItem(std::move(curr_item));
+    Status status = RunItem(std::move(curr_item), true);
     if (!status.ok()) {
       VLOG(1) << "Failed to run item: " << status;
     }
   }
 }
 
-Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item) {
+Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
+                              bool from_queue) {
   DVLOG(3) << "Running Node: [id " << item->id << "] "
            << item->node->DebugString();
   AsyncEagerNode* async_node = item->node->AsAsync();
   if (async_node == nullptr) {
     tensorflow::Status status = item->node->Run();
-    NodeDone(item, status);
+    NodeDone(item, status, from_queue);
     return status;
-  } else {
-    auto* new_ref = item.get();
-    new_ref->Ref();
-    async_node->RunAsync([this, new_ref](const Status& status) {
-      core::RefCountPtr<NodeItem> new_item(new_ref);
-      NodeDone(new_item, status);
-    });
   }
 
-  tensorflow::mutex_lock l(node_queue_mutex_);
-  if (item->state == NodeState::kPENDING) {
-    item->state = NodeState::kSCHEDULED;
-    if (!node_queue_.empty() && item.get() == node_queue_.front().get()) {
-      node_queue_.pop();
-    }
-    DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
-    unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
-                                   std::move(item));
-  }
+  item->state = NodeState::kSCHEDULED;
+  auto async_ref = item.get();
+  async_ref->Ref();
+
+  TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
+
+  async_node->RunAsync([this, async_ref](const Status& status) {
+    core::RefCountPtr<NodeItem> async_item(async_ref);
+    NodeDone(async_item, status, false);
+  });
 
   // Return the status of the executor in case we are in an error state.
   return status();
 }
 
+Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
+                                       bool from_queue) {
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (!status_.ok()) {
+    return status_;
+  }
+
+  if (from_queue) {
+    DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
+    node_queue_.pop();
+  }
+
+  DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
+  unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
+                                 std::move(item));
+
+  return Status::OK();
+}
+
 }  // namespace tensorflow
|
@@ -167,7 +167,8 @@ class EagerExecutor {
 
   const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
-  void NodeDone(const core::RefCountPtr<NodeItem>& item, const Status& status);
+  void NodeDone(const core::RefCountPtr<NodeItem>& item, const Status& status,
+                bool from_queue);
   void NotifyWaiters(uint64 id) EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
   // Starts execution of pending EagerNodes. This function loops till
@@ -176,7 +177,8 @@ class EagerExecutor {
   // `status_` is not ok.
   void Run();
 
-  Status RunItem(core::RefCountPtr<NodeItem> item);
+  Status RunItem(core::RefCountPtr<NodeItem> item, bool from_queue);
+  Status MoveToUnfinished(core::RefCountPtr<NodeItem> item, bool from_queue);
 
   // The impl of WaitForAllPendingNodes
   // `lock` is the lock that holds node_queue_mutex_.
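Finally, a hedged sketch of the sync-mode optimization mentioned in the commit message, again with illustrative names rather than the real TF types: a synchronous node that completes successfully was never published to unfinished_nodes_ and never sits in node_queue_, so nobody can be waiting on it and NodeDone can return without touching the mutex at all.

// Illustrative sketch only -- not the TensorFlow sources.
#include <iostream>
#include <map>
#include <mutex>

enum class NodeState { kPENDING, kSCHEDULED, kDONE };

struct Node {
  int id = 0;
  NodeState state = NodeState::kPENDING;
  bool is_async = false;
};

std::mutex node_queue_mutex;
std::map<int, Node*> unfinished_nodes;  // only async nodes are ever added here

void NodeDone(Node* node, bool status_ok, bool from_queue) {
  node->state = NodeState::kDONE;
  // Fast path: a sync node that succeeded was never added to
  // unfinished_nodes and is not on the queue, so skip the lock entirely.
  if (status_ok && !from_queue && !node->is_async) return;
  std::lock_guard<std::mutex> l(node_queue_mutex);
  unfinished_nodes.erase(node->id);
  // ...error propagation and waiter notification would go here...
}

int main() {
  Node sync_node;
  sync_node.id = 1;
  NodeDone(&sync_node, /*status_ok=*/true, /*from_queue=*/false);  // lock-free
  std::cout << (sync_node.state == NodeState::kDONE) << "\n";  // prints 1
}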