Use the executor to execute ops in sync mode as well. A few benefits of doing this:

1. Unify the sync and async code paths a little more (see the sketch after this list).
2. Allow the executor to capture errors raised by an AsyncEagerNode. We may need to find a way to propagate these errors to the Python layer in the future.
3. Allow EagerContext to wait for all pending AsyncEagerNodes during shutdown.
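To show the shape of the change, here is a minimal, hedged sketch of the AddOrExecute dispatch this commit introduces. Executor, Node, Status, and PrintNode below are simplified stand-ins, not the real TensorFlow classes; the actual implementation additionally ref-counts NodeItems, tracks per-node state, and guards the queue with node_queue_mutex_.

```cpp
// Minimal sketch of the AddOrExecute dispatch; all types are stand-ins.
#include <iostream>
#include <memory>
#include <queue>
#include <string>
#include <utility>

struct Status {
  bool ok = true;
  std::string msg;
};

struct Node {
  virtual ~Node() = default;
  virtual Status Run() = 0;             // Synchronous execution path.
  virtual void Abort(const Status&) {}  // Called if the node is never run.
};

class Executor {
 public:
  explicit Executor(bool async) : async_(async) {}

  bool Async() const { return async_; }

  // Async mode: enqueue the node for a background thread.
  // Sync mode: run the node inline and surface its status immediately.
  Status AddOrExecute(std::unique_ptr<Node> node) {
    if (!status_.ok) {
      node->Abort(status_);  // Never added, so abort it for cleanup.
      return status_;
    }
    if (async_) {
      queue_.push(std::move(node));  // Real code also wakes the run thread.
      return Status{};
    }
    Status s = node->Run();  // Inline execution in sync mode.
    if (!s.ok) status_ = s;  // The executor now captures sync errors too.
    return s;
  }

  const Status& status() const { return status_; }

 private:
  bool async_;
  Status status_;
  std::queue<std::unique_ptr<Node>> queue_;
};

struct PrintNode : Node {
  Status Run() override {
    std::cout << "running node\n";
    return Status{};
  }
};

int main() {
  Executor sync_executor(/*async=*/false);
  sync_executor.AddOrExecute(std::make_unique<PrintNode>());  // Runs inline.
}
```

Routing both modes through the executor is what lets it record errors from sync execution (and from AsyncEagerNode callbacks) in status_, and wait for any still-pending nodes at shutdown.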

PiperOrigin-RevId: 268956697
Authored by Xiao Yu on 2019-09-13 12:22:46 -07:00; committed by TensorFlower Gardener
parent c3e32b03e1
commit bfb4c9bcb2
8 changed files with 133 additions and 125 deletions


@ -35,7 +35,7 @@ EagerExecutor::~EagerExecutor() {
Status EagerExecutor::ShutDown() {
{
std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
bool has_thread;
Status status;
{
@ -47,7 +47,9 @@ Status EagerExecutor::ShutDown() {
// thread_exited_notification_.WaitForNotification() below.
state_ = ExecutorState::kShuttingDown;
}
WaitForOrDestroyAllPendingNodes(&l, &items_to_destroy);
// It is OK to ignore the returned status here because it will be saved
// as the final status_.
WaitForAllPendingNodesLocked(&l).IgnoreError();
state_ = ExecutorState::kShutDown;
has_thread = thread_ != nullptr;
status = status_;
@ -68,36 +70,6 @@ Status EagerExecutor::ShutDown() {
return status_;
}
void EagerExecutor::WaitForOrDestroyAllPendingNodes(
mutex_lock* lock,
std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy) {
if (state_ == ExecutorState::kShutDown) {
return;
}
if (thread_ == nullptr) {
Status status = status_;
if (status.ok()) {
status = errors::FailedPrecondition(
"Aborting eager nodes because EagerExecutor is being shut down "
"before it got a thread to run the nodes");
status_ = status;
}
while (!node_queue_.empty()) {
nodes_to_destroy->push_back(std::move(node_queue_.front()));
node_queue_.pop();
}
for (auto& it : unfinished_nodes_) {
nodes_to_destroy->push_back(absl::WrapUnique(it.second));
}
unfinished_nodes_.clear();
return;
}
// It is OK to ignore the returned status here because it will be saved
// as the final status_.
WaitForAllPendingNodesLocked(lock).IgnoreError();
}
bool EagerExecutor::Async() const {
return thread_ != nullptr;
}
@ -113,15 +85,19 @@ const char* EagerExecutor::StateStringLocked() {
}
}
Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
Status status;
core::RefCountPtr<NodeItem> item(new NodeItem);
item->id = next_node_id_++;
item->node = std::move(node);
item->state = NodeState::kPENDING;
// If we are unable to add the node to the queue, we must call Abort. However,
// we want to do that outside of the scope of the lock since the Abort may
// try to call EagerExecutor::Add()
{
tensorflow::mutex_lock l(node_queue_mutex_);
VLOG(3) << "Add node [id " << next_node_id_ << "]" << node->DebugString()
VLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
<< " with status: " << status_.ToString();
if (state_ != ExecutorState::kActive) {
status = errors::FailedPrecondition(
@ -129,16 +105,11 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
"Current state is '",
StateStringLocked(), "'");
} else {
DCHECK(thread_) << "EnableAsync should have been called before Add";
status = status_;
if (status.ok()) {
auto item = absl::make_unique<NodeItem>();
item->id = next_node_id_++;
item->node = std::move(node);
if (status.ok() && Async()) {
node_queue_.push(std::move(item));
// If there were no previous nodes pending, wake the run thread to start
// processing requests again.
// If there were no previous nodes pending, wake the run thread to
// start processing requests again.
if (node_queue_.size() == 1) {
nodes_pending_.notify_all();
}
@ -148,9 +119,17 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
}
}
// Node needs to be aborted since it was not added to the queue
node->Abort(status);
return status;
if (status.ok()) {
// Inline execution in sync mode.
DCHECK(!Async());
RunItem(std::move(item));
status = this->status();
return status;
} else {
// Node needs to be aborted since it was not added to the queue
item->node->Abort(status);
return status;
}
}
tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
@ -165,6 +144,8 @@ tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
if (!status_.ok()) return status_;
if (node_queue_.empty() && unfinished_nodes_.empty())
return tensorflow::Status::OK();
// node_queue_ must be empty in sync mode.
DCHECK(Async() || node_queue_.empty());
auto last_id = next_node_id_ - 1;
VLOG(3) << "Wait for Node: [id " << last_id << "] ";
node_done_notifications_.insert(std::make_pair(last_id, &cond));
@ -191,25 +172,30 @@ tensorflow::Status EagerExecutor::status() const {
return status_;
}
void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
void EagerExecutor::NodeDone(core::RefCountPtr<NodeItem> item,
const Status& status) {
VLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
<< " with status: " << status.ToString();
std::unique_ptr<NodeItem> current_item;
std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
DCHECK(item->state != NodeState::kDONE);
std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
{
mutex_lock l(node_queue_mutex_);
auto previous_state = item->state;
item->state = NodeState::kDONE;
if (!status_.ok()) return;
bool need_notification = false;
if (!node_queue_.empty() && item == node_queue_.front().get()) {
need_notification = unfinished_nodes_.empty();
current_item = std::move(node_queue_.front());
node_queue_.pop();
if (previous_state == NodeState::kPENDING) {
if (Async()) {
DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
need_notification = unfinished_nodes_.empty();
node_queue_.pop();
} else {
need_notification = unfinished_nodes_.empty();
}
} else {
DCHECK(!unfinished_nodes_.empty());
need_notification = item->id == unfinished_nodes_.begin()->first;
auto erase_result = unfinished_nodes_.erase(item->id);
DCHECK_GT(erase_result, 0);
current_item = absl::WrapUnique(item);
auto result = unfinished_nodes_.erase(item->id);
DCHECK_GT(result, 0);
}
if (!status.ok()) {
need_notification = true;
@ -217,7 +203,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
// We remove any pending ops so that we don't try to execute them if
// ClearError is called.
errors::AppendToMessage(&status_,
". Encountered when executing an operation using "
"Encountered when executing an operation using "
"EagerExecutor. This error cancels all future "
"operations and poisons their output tensors.");
while (!node_queue_.empty()) {
@ -225,7 +211,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
node_queue_.pop();
}
for (auto& it : unfinished_nodes_) {
items_to_destroy.push_back(absl::WrapUnique(it.second));
items_to_destroy.push_back(std::move(it.second));
}
unfinished_nodes_.clear();
}
@ -238,6 +224,8 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
} else {
upperbound_id = next_node_id_ - 1;
}
VLOG(3) << "Notify node done: [id " << item->id << " to " << upperbound_id
<< "] ";
// Note that we notify all waiting threads in case an error has
// occurred. These calling threads are responsible for checking status_
// before proceeding.
@ -266,7 +254,7 @@ void EagerExecutor::Run() {
auto thread_exited_notifier =
gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
while (true) {
NodeItem* curr_item_raw;
core::RefCountPtr<NodeItem> curr_item;
{
tensorflow::mutex_lock l(node_queue_mutex_);
while (node_queue_.empty() || !status_.ok()) {
@ -280,30 +268,39 @@ void EagerExecutor::Run() {
// will then contain a nullptr. This can be a problem in
// WaitForAllPendingNodes where we get the top EagerNode pointer
// and register a notification for its completion.
curr_item_raw = node_queue_.front().get();
curr_item.reset(node_queue_.front().get());
curr_item->Ref();
}
VLOG(3) << "Running Node: [id " << curr_item_raw->id << "] "
<< curr_item_raw->node->DebugString();
AsyncEagerNode* async_node_raw = curr_item_raw->node->AsAsync();
if (async_node_raw == nullptr) {
tensorflow::Status status = curr_item_raw->node->Run();
NodeDone(curr_item_raw, status);
} else {
async_node_raw->RunAsync([this, curr_item_raw](const Status& status) {
NodeDone(curr_item_raw, status);
});
{
tensorflow::mutex_lock l(node_queue_mutex_);
// If false, NodeDone has been called.
if (!node_queue_.empty() &&
curr_item_raw == node_queue_.front().get()) {
node_queue_.front().release();
node_queue_.pop();
unfinished_nodes_.emplace_hint(unfinished_nodes_.end(),
curr_item_raw->id, curr_item_raw);
}
}
RunItem(std::move(curr_item));
}
}
void EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item) {
VLOG(3) << "Running Node: [id " << item->id << "] "
<< item->node->DebugString();
AsyncEagerNode* async_node = item->node->AsAsync();
if (async_node == nullptr) {
core::RefCountPtr<NodeItem> new_ref(item.get());
new_ref->Ref();
tensorflow::Status status = item->node->Run();
NodeDone(std::move(new_ref), status);
} else {
auto* new_ref = item.get();
new_ref->Ref();
async_node->RunAsync([this, new_ref](const Status& status) {
core::RefCountPtr<NodeItem> new_item(new_ref);
NodeDone(std::move(new_item), status);
});
}
tensorflow::mutex_lock l(node_queue_mutex_);
if (item->state == NodeState::kPENDING) {
item->state = NodeState::kSCHEDULED;
if (!node_queue_.empty() && item.get() == node_queue_.front().get()) {
node_queue_.pop();
}
VLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
std::move(item));
}
}
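A note on the ownership hand-off in Run()/RunItem()/NodeDone() above: a core::RefCountPtr adopts a raw pointer without taking a new reference and Unrefs it on destruction, which is why the code calls Ref() explicitly before creating a second owner for NodeDone or the RunAsync callback. Below is a minimal sketch of that intrusive ref-counting idiom with simplified RefCounted/RefCountPtr/Item stand-ins; it is illustrative only, not TensorFlow's actual core/lib/core/refcount.h.

```cpp
// Simplified intrusive ref-counting, illustrating the "Ref() before handing
// out a second owner" idiom from RunItem/NodeDone. Not the real refcount.h.
#include <atomic>

class RefCounted {
 public:
  RefCounted() : count_(1) {}  // Starts with one reference, held by the creator.
  void Ref() { count_.fetch_add(1); }
  void Unref() {
    if (count_.fetch_sub(1) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  std::atomic<int> count_;
};

template <typename T>
class RefCountPtr {
 public:
  RefCountPtr() = default;
  // Adopts `p` WITHOUT taking a new reference; the caller must have already
  // called Ref() if another owner keeps using the object.
  explicit RefCountPtr(T* p) : ptr_(p) {}
  RefCountPtr(RefCountPtr&& other) noexcept : ptr_(other.ptr_) { other.ptr_ = nullptr; }
  RefCountPtr(const RefCountPtr&) = delete;  // Copy omitted for brevity.
  ~RefCountPtr() {
    if (ptr_ != nullptr) ptr_->Unref();
  }
  void reset(T* p) {  // Also adopts without Ref()-ing, like the constructor.
    if (ptr_ != nullptr) ptr_->Unref();
    ptr_ = p;
  }
  T* get() const { return ptr_; }
  T* operator->() const { return ptr_; }

 private:
  T* ptr_ = nullptr;
};

struct Item : public RefCounted {
  int id = 0;
};

int main() {
  RefCountPtr<Item> owner(new Item);      // refcount == 1
  owner->Ref();                           // take an extra reference first...
  RefCountPtr<Item> second(owner.get());  // ...then create the second owner
  // Each owner releases one reference when it goes out of scope, so the Item
  // is deleted exactly once.
}
```

In the commit, this is what allows a NodeItem to be owned simultaneously by the queue or unfinished_nodes_ map and by the in-flight NodeDone call or RunAsync callback.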


@ -27,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@ -73,12 +75,8 @@ class AsyncEagerNode : public EagerNode {
AsyncEagerNode* AsAsync() final { return this; }
// This is non-blocking. It returns the scheduling status.
// TODO(fishx): avoid calling this AsyncEagerNode::Run.
Status Run() final {
std::shared_ptr<Status> status(new Status);
RunAsync([status](const Status& s) { status->Update(s); });
return *status;
return errors::Unimplemented("Don't call AsyncEagerNode::Run().");
}
};
@ -105,10 +103,11 @@ class EagerExecutor {
bool Async() const;
// Schedules `node` for execution. If an error occurs (e.g. EagerExecutor
// has already been shut down), the `node` is not added to this executor
// and its Abort() method is called.
Status Add(std::unique_ptr<EagerNode> node);
// - Async Mode: schedules `node` for execution.
// - Sync Mode: executes the `node` inline.
// If an error occurs (e.g. EagerExecutor has already been shut down), the
// `node` is not added to this executor and its Abort() method is called.
Status AddOrExecute(std::unique_ptr<EagerNode> node);
// Blocks till all currently pending ops are done.
// In particular, if EnableAsync() has not been called, it will not return
@ -139,16 +138,23 @@ class EagerExecutor {
kShutDown,
};
struct NodeItem {
enum class NodeState {
kPENDING,
kSCHEDULED,
kDONE,
};
struct NodeItem : core::RefCounted {
// Unique id generated in EagerExecutor::Add(). If item1.id < item2.id, it
// means item1.node is added before item2.node.
uint64 id;
std::unique_ptr<EagerNode> node;
NodeState state;
};
const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
void NodeDone(NodeItem* item, const Status& status);
void NodeDone(core::RefCountPtr<NodeItem> item, const Status& status);
// Starts execution of pending EagerNodes. This function loops till
// thread_done_ is set to true. If any errors are encountered, these are set
@ -156,20 +162,13 @@ class EagerExecutor {
// `status_` is not ok.
void Run();
void RunItem(core::RefCountPtr<NodeItem> item);
// The impl of WaitForAllPendingNodes
// `lock` is the lock that holds node_queue_mutex_.
Status WaitForAllPendingNodesLocked(mutex_lock* lock)
EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
// If async has been enabled on this executor, just calls
// WaitForAllPendingNodes. Else sets the status_ to an error if it does not
// already contain one. `lock` is the lock that holds node_queue_mutex_.
// Precondition: state_ != kActive.
void WaitForOrDestroyAllPendingNodes(
mutex_lock* lock,
std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy)
EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
Status WaitImpl(bool wait_all, uint64 node_id);
std::atomic<uint64> next_node_id_;
@ -180,12 +179,12 @@ class EagerExecutor {
condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
// Queue of pending NodeItems. Ordered by NodeItem::id.
std::queue<std::unique_ptr<NodeItem>> node_queue_
std::queue<core::RefCountPtr<NodeItem>> node_queue_
GUARDED_BY(node_queue_mutex_);
// Owned the NodeItem in it. Ordered by NodeItem::id.
std::map<uint64, NodeItem*, std::less<uint64>> unfinished_nodes_
GUARDED_BY(node_queue_mutex_);
// Ordered by NodeItem::id.
std::map<uint64, core::RefCountPtr<NodeItem>, std::less<uint64>>
unfinished_nodes_ GUARDED_BY(node_queue_mutex_);
// `status_` is set based on any errors raised during execution of a
// EagerNode. It remains set until ClearError is called.


@ -362,6 +362,7 @@ Status GetDeviceForInput(const EagerContext* ctx, TensorHandle* tensor_handle,
// Use the resource's actual device because it is the device that will
// influence partitioning the multi-device function.
const Tensor* tensor;
// TODO(fishx): Avoid blocking here.
TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
device_name = handle.device();
@ -639,7 +640,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
// input handles are ready before executing them.
// TODO(b/137118203): Consider executing "cheap" kernels inline for
// performance.
Status s = executor.Async() ? executor.Add(std::move(node)) : node->Run();
Status s = executor.AddOrExecute(std::move(node));
// Since the operation failed, we need to Unref any outputs that were
// allocated.
if (!s.ok()) {
@ -765,14 +766,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
}
auto& executor = op->Executor();
bool is_async = executor.Async();
VLOG(4) << "Execute remote eager op: " << op->Name()
<< " (is async?: " << is_async << ").";
<< " (is async?: " << executor.Async() << ").";
std::unique_ptr<EagerNode> node(
new eager::RemoteExecuteNode(std::move(request), op_device, eager_client,
op->Inputs(), {retvals, num_outputs}));
Status s = is_async ? executor.Add(std::move(node)) : node->Run();
Status s = executor.AddOrExecute(std::move(node));
// Since the operation failed, we need to Unref any outputs that were
// allocated.
if (!s.ok()) {
@ -904,6 +904,12 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName());
if (!op->Executor().Async()) {
// In sync mode, always clear error to maintain the same behavior as before.
// TODO(b/141004939): Remove this.
op->Executor().ClearError();
}
std::unique_ptr<tensorflow::EagerOperation> out_op;
TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
@ -1044,7 +1050,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
// Note that `h` may not be currently ready. However execution order will
// make sure that `h` is ready before the copy is actually done.
std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, ctx));
Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
Status s = executor->AddOrExecute(std::move(node));
// Since the operation failed, we need to Unref any outputs that were
// allocated.
if (!s.ok()) {
@ -1065,6 +1071,12 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
bool recver_is_local = device->IsLocal();
if (!executor->Async()) {
// In sync mode, always clear error to maintain the same behavior as before.
// TODO(b/141004939): Remove this.
executor->ClearError();
}
if (sender_is_local && recver_is_local) {
return LocalEagerCopyToDevice(h, ctx, executor, device, result);
} else {
@ -1105,7 +1117,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
}
auto node = absl::make_unique<eager::RemoteCopyNode>(
ctx, executor, h, result[0], device, recv_op_id);
Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
Status s = executor->AddOrExecute(std::move(node));
if (!s.ok()) {
result[0]->Unref();
}


@ -296,8 +296,7 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
item.handle_to_decref());
auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
context, std::move(handle_to_decref));
s = executor.Async() ? context->Context()->Executor().Add(std::move(node))
: node->Run();
s = context->Context()->Executor().AddOrExecute(std::move(node));
} else {
s = SendTensor(item.send_tensor(), context->Context());
}


@ -47,7 +47,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
eager_client, ready));
auto& executor = ctx->Executor();
if (executor.Async()) {
Status status = executor.Add(std::move(node));
Status status = executor.AddOrExecute(std::move(node));
if (!status.ok()) {
LOG(ERROR) << "Unable to destroy remote tensor handles: "
<< status.error_message();
@ -56,13 +56,13 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
// This thread may still hold tensorflow::StreamingRPCState::mu_. We need
// to send out the destroy request in a new thread to avoid deadlock.
auto* released_node = node.release();
(*ctx->runner())([released_node] {
Status status = released_node->Run();
(*ctx->runner())([ctx, released_node] {
Status status =
ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
if (!status.ok()) {
LOG(ERROR) << "Unable to destroy remote tensor handles: "
<< status.error_message();
}
delete released_node;
});
}
}
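The comment above explains why the destroy request is posted to ctx->runner() rather than run inline: the calling thread may still hold tensorflow::StreamingRPCState::mu_, and executing the node directly could re-enter that lock and deadlock. Here is a small, generic sketch of that defer-to-a-runner pattern; Runner, DeleteNode, ScheduleDelete, and rpc_state_mu are hypothetical stand-ins, not TensorFlow APIs.

```cpp
// Generic sketch of "defer work to a runner thread to avoid re-entering a
// lock the caller may already hold". All names here are hypothetical.
#include <chrono>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

using Runner = std::function<void(std::function<void()>)>;

std::mutex rpc_state_mu;  // Analogue of StreamingRPCState::mu_.

struct DeleteNode {
  void Run() {
    // Internally needs rpc_state_mu; running this inline while the caller
    // already holds that mutex would deadlock.
    std::lock_guard<std::mutex> lock(rpc_state_mu);
    std::cout << "destroy request sent\n";
  }
};

void ScheduleDelete(const Runner& runner, std::unique_ptr<DeleteNode> node) {
  // Hand the work to another thread instead of running it on this one.
  auto* released = node.release();
  runner([released] {
    released->Run();
    delete released;
  });
}

int main() {
  Runner runner = [](std::function<void()> fn) {
    std::thread(std::move(fn)).detach();
  };
  {
    std::lock_guard<std::mutex> lock(rpc_state_mu);  // Caller holds the lock.
    ScheduleDelete(runner, std::make_unique<DeleteNode>());
  }  // Lock released here; the runner thread can now make progress.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
```

The trade-off is that the destroy request becomes fire-and-forget: its status can only be logged, as the LOG(ERROR) in the diff does.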


@ -523,7 +523,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
if context.executing_eagerly():
err_spec = errors.InvalidArgumentError, (
'Each record default should be at '
'most rank 1.')
'most rank 1')
else:
err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'


@ -73,7 +73,7 @@ class DecodeCSVOpTest(test.TestCase):
if context.executing_eagerly():
err_spec = errors.InvalidArgumentError, (
"Each record default should be at "
"most rank 1.")
"most rank 1")
else:
err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
with self.assertRaisesWithPredicateMatch(*err_spec):


@ -107,9 +107,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
handle = resource_variable_ops.var_handle_op(
dtype=dtypes.int32, shape=[1], name="foo")
resource_variable_ops.assign_variable_op(handle, 1)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Trying to read variable with wrong dtype. "
"Expected float got int32."):
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Trying to read variable with wrong dtype. "
"Expected float got int32"):
_ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
def testEagerInitializedValue(self):
@ -195,9 +196,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
dtype=dtypes.int32, shape=[1], name="foo")
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([1]))
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"Trying to assign variable with wrong "
"dtype. Expected int32 got float."):
with self.assertRaisesRegexp(
errors.InvalidArgumentError, "Trying to assign variable with wrong "
"dtype. Expected int32 got float"):
resource_variable_ops.assign_variable_op(
handle, constant_op.constant([1.], dtype=dtypes.float32))