Fix race in //tensorflow/c/eager:c_api_remote_test

There were a couple of race conditions in the code:
- Update item->state outside lock before RunAsync since the callback
  could be executed right away.
- Put node onto unfinished list before RunAsync since the callback could
  be executed right away.
- Since we're going to acquire the lock in NodeDone anyway, moved the
  status check into the locked section.

Further cleaned up NodeDone to only do the node_queue_ pop if indicated
by the caller.

We also optimize NodeDone to return without acquiring the lock in sync
mode.

PiperOrigin-RevId: 282437248
Change-Id: Ia30b4846108c51eae630587534ec70daf2d08e5a
This commit is contained in:
Gaurav Jain 2019-11-25 14:45:01 -08:00 committed by TensorFlower Gardener
parent 8cfec1b86a
commit 2a16758ee4
2 changed files with 63 additions and 36 deletions

View File

@@ -131,7 +131,7 @@ Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
if (!Async()) {
status = this->status();
if (status.ok()) {
status = RunItem(std::move(item));
status = RunItem(std::move(item), false);
}
return status;
} else {
@@ -204,32 +204,41 @@ void EagerExecutor::ClearError() {
}
void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
const Status& status) {
const Status& status, bool from_queue) {
DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
<< " with status: " << status.ToString();
DCHECK(item->state != NodeState::kDONE);
auto previous_state = item->state;
item->state = NodeState::kDONE;
if (!ok()) return;
bool async = item->node->AsAsync() != nullptr;
// If executing synchronously we don't need to notify if status is OK since
// the node was never added to the unfinished_nodes_ list and nobody should
// ever be waiting for it.
if (status.ok() && !from_queue && !async) {
return;
}
std::forward_list<core::RefCountPtr<NodeItem>> items_to_destroy;
{
mutex_lock l(node_queue_mutex_);
bool need_notification = false;
if (previous_state == NodeState::kPENDING) {
if (Async()) {
DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
need_notification = unfinished_nodes_.empty();
node_queue_.pop();
} else {
need_notification = unfinished_nodes_.empty();
}
} else {
if (!status_.ok()) return;
bool need_notification = from_queue;
if (from_queue) {
// Since this was from the async queue, pop it from the front of the queue.
DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
node_queue_.pop();
} else if (async) {
// If it is an Async node then we will find the node in the unfinished
// nodes list. However we only notify if we are at the front of the list
// since we don't want to notify any waiters of earlier nodes.
need_notification = item->id == unfinished_nodes_.begin()->first;
auto result = unfinished_nodes_.erase(item->id);
DCHECK_GT(result, 0);
}
if (!status.ok()) {
// Since we received an error, broadcast to any waiters.
need_notification = true;
status_ = status;
ok_ = false;
@@ -254,6 +263,7 @@ void EagerExecutor::NodeDone(const core::RefCountPtr<NodeItem>& item,
NotifyWaiters(item->id);
}
}
for (auto& item : items_to_destroy) {
item->node->Abort(status);
}
@@ -312,41 +322,56 @@ void EagerExecutor::Run() {
curr_item.reset(node_queue_.front().get());
curr_item->Ref();
}
Status status = RunItem(std::move(curr_item));
Status status = RunItem(std::move(curr_item), true);
if (!status.ok()) {
VLOG(1) << "Failed to run item: " << status;
}
}
}
Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item) {
Status EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item,
bool from_queue) {
DVLOG(3) << "Running Node: [id " << item->id << "] "
<< item->node->DebugString();
AsyncEagerNode* async_node = item->node->AsAsync();
if (async_node == nullptr) {
tensorflow::Status status = item->node->Run();
NodeDone(item, status);
NodeDone(item, status, from_queue);
return status;
} else {
auto* new_ref = item.get();
new_ref->Ref();
async_node->RunAsync([this, new_ref](const Status& status) {
core::RefCountPtr<NodeItem> new_item(new_ref);
NodeDone(new_item, status);
});
}
tensorflow::mutex_lock l(node_queue_mutex_);
if (item->state == NodeState::kPENDING) {
item->state = NodeState::kSCHEDULED;
if (!node_queue_.empty() && item.get() == node_queue_.front().get()) {
node_queue_.pop();
}
DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
std::move(item));
}
item->state = NodeState::kSCHEDULED;
auto async_ref = item.get();
async_ref->Ref();
TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue));
async_node->RunAsync([this, async_ref](const Status& status) {
core::RefCountPtr<NodeItem> async_item(async_ref);
NodeDone(async_item, status, false);
});
// Return the status of the executor in case we are in an error state.
return status();
}
Status EagerExecutor::MoveToUnfinished(core::RefCountPtr<NodeItem> item,
bool from_queue) {
tensorflow::mutex_lock l(node_queue_mutex_);
if (!status_.ok()) {
return status_;
}
if (from_queue) {
DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
node_queue_.pop();
}
DVLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
std::move(item));
return Status::OK();
}
} // namespace tensorflow

View File

@@ -167,7 +167,8 @@ class EagerExecutor {
const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
void NodeDone(const core::RefCountPtr<NodeItem>& item, const Status& status);
void NodeDone(const core::RefCountPtr<NodeItem>& item, const Status& status,
bool from_queue);
void NotifyWaiters(uint64 id) EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
// Starts execution of pending EagerNodes. This function loops till
@@ -176,7 +177,8 @@ class EagerExecutor {
// `status_` is not ok.
void Run();
Status RunItem(core::RefCountPtr<NodeItem> item);
Status RunItem(core::RefCountPtr<NodeItem> item, bool from_queue);
Status MoveToUnfinished(core::RefCountPtr<NodeItem> item, bool from_queue);
// The impl of WaitForAllPendingNodes
// `lock` is the lock that holds node_queue_mutex_.