diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc
index 2e88c76d9fe..ff61986c29d 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.cc
+++ b/tensorflow/core/common_runtime/eager/eager_executor.cc
@@ -35,7 +35,7 @@ EagerExecutor::~EagerExecutor() {
 
 Status EagerExecutor::ShutDown() {
   {
-    std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
+    std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
     bool has_thread;
     Status status;
     {
@@ -47,7 +47,9 @@ Status EagerExecutor::ShutDown() {
         // thread_exited_notification_.WaitForNotification() below.
         state_ = ExecutorState::kShuttingDown;
       }
-      WaitForOrDestroyAllPendingNodes(&l, &items_to_destroy);
+      // It is OK to ignore the returned status here because it will be saved
+      // as the final status_.
+      WaitForAllPendingNodesLocked(&l).IgnoreError();
       state_ = ExecutorState::kShutDown;
       has_thread = thread_ != nullptr;
       status = status_;
@@ -68,36 +70,6 @@ Status EagerExecutor::ShutDown() {
   return status_;
 }
 
-void EagerExecutor::WaitForOrDestroyAllPendingNodes(
-    mutex_lock* lock,
-    std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy) {
-  if (state_ == ExecutorState::kShutDown) {
-    return;
-  }
-  if (thread_ == nullptr) {
-    Status status = status_;
-    if (status.ok()) {
-      status = errors::FailedPrecondition(
-          "Aborting eager nodes because EagerExecutor is being shut down "
-          "before it got a thread to run the nodes");
-      status_ = status;
-    }
-    while (!node_queue_.empty()) {
-      nodes_to_destroy->push_back(std::move(node_queue_.front()));
-      node_queue_.pop();
-    }
-    for (auto& it : unfinished_nodes_) {
-      nodes_to_destroy->push_back(absl::WrapUnique(it.second));
-    }
-    unfinished_nodes_.clear();
-    return;
-  }
-
-  // It is OK to ignore the returned status here because it will be saved
-  // as the final status_.
-  WaitForAllPendingNodesLocked(lock).IgnoreError();
-}
-
 bool EagerExecutor::Async() const {
   return thread_ != nullptr;
 }
@@ -113,15 +85,19 @@ const char* EagerExecutor::StateStringLocked() {
   }
 }
 
-Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
+Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
   Status status;
+  core::RefCountPtr<NodeItem> item(new NodeItem);
+  item->id = next_node_id_++;
+  item->node = std::move(node);
+  item->state = NodeState::kPENDING;
   // If we are unable to add the node to the queue, we must call Abort. However,
   // we want to do that outside of the scope of the lock since the Abort may
   // try to call EagerExecutor::Add()
   {
     tensorflow::mutex_lock l(node_queue_mutex_);
-    VLOG(3) << "Add node [id " << next_node_id_ << "]" << node->DebugString()
+    VLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
             << " with status: " << status_.ToString();
     if (state_ != ExecutorState::kActive) {
       status = errors::FailedPrecondition(
@@ -129,16 +105,11 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
           "Current state is '", StateStringLocked(), "'");
     } else {
-      DCHECK(thread_) << "EnableAsync should have been called before Add";
       status = status_;
-      if (status.ok()) {
-        auto item = absl::make_unique<NodeItem>();
-        item->id = next_node_id_++;
-        item->node = std::move(node);
+      if (status.ok() && Async()) {
         node_queue_.push(std::move(item));
-
-        // If there were no previous nodes pending, wake the run thread to start
-        // processing requests again.
+        // If there were no previous nodes pending, wake the run thread to
+        // start processing requests again.
         if (node_queue_.size() == 1) {
           nodes_pending_.notify_all();
         }
@@ -148,9 +119,17 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
     }
   }
 
-  // Node needs to be aborted since it was not added to the queue
-  node->Abort(status);
-  return status;
+  if (status.ok()) {
+    // Inline execution in sync mode.
+    DCHECK(!Async());
+    RunItem(std::move(item));
+    status = this->status();
+    return status;
+  } else {
+    // Node needs to be aborted since it was not added to the queue
+    item->node->Abort(status);
+    return status;
+  }
 }
 
 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
@@ -165,6 +144,8 @@ tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
   if (!status_.ok()) return status_;
   if (node_queue_.empty() && unfinished_nodes_.empty())
     return tensorflow::Status::OK();
+  // node_queue_ must be empty in sync mode.
+  DCHECK(Async() || node_queue_.empty());
   auto last_id = next_node_id_ - 1;
   VLOG(3) << "Wait for Node: [id " << last_id << "] ";
   node_done_notifications_.insert(std::make_pair(last_id, &cond));
@@ -191,25 +172,30 @@ tensorflow::Status EagerExecutor::status() const {
   return status_;
 }
 
-void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
+void EagerExecutor::NodeDone(core::RefCountPtr<NodeItem> item,
+                             const Status& status) {
   VLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
           << " with status: " << status.ToString();
-  std::unique_ptr<NodeItem> current_item;
-  std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
+  DCHECK(item->state != NodeState::kDONE);
+  std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
   {
     mutex_lock l(node_queue_mutex_);
+    auto previous_state = item->state;
+    item->state = NodeState::kDONE;
    if (!status_.ok()) return;
     bool need_notification = false;
-    if (!node_queue_.empty() && item == node_queue_.front().get()) {
-      need_notification = unfinished_nodes_.empty();
-      current_item = std::move(node_queue_.front());
-      node_queue_.pop();
+    if (previous_state == NodeState::kPENDING) {
+      if (Async()) {
+        DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
+        need_notification = unfinished_nodes_.empty();
+        node_queue_.pop();
+      } else {
+        need_notification = unfinished_nodes_.empty();
+      }
     } else {
-      DCHECK(!unfinished_nodes_.empty());
       need_notification = item->id == unfinished_nodes_.begin()->first;
-      auto erase_result = unfinished_nodes_.erase(item->id);
-      DCHECK_GT(erase_result, 0);
-      current_item = absl::WrapUnique(item);
+      auto result = unfinished_nodes_.erase(item->id);
+      DCHECK_GT(result, 0);
     }
     if (!status.ok()) {
       need_notification = true;
@@ -217,7 +203,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
       // We remove any pending ops so that we don't try to execute them if
       // ClearError is called.
       errors::AppendToMessage(&status_,
-                              ". Encountered when executing an operation using "
+                              "Encountered when executing an operation using "
                              "EagerExecutor. This error cancels all future "
                              "operations and poisons their output tensors.");
       while (!node_queue_.empty()) {
@@ -225,7 +211,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
         node_queue_.pop();
       }
       for (auto& it : unfinished_nodes_) {
-        items_to_destroy.push_back(absl::WrapUnique(it.second));
+        items_to_destroy.push_back(std::move(it.second));
       }
       unfinished_nodes_.clear();
     }
@@ -238,6 +224,8 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
      } else {
        upperbound_id = next_node_id_ - 1;
      }
+      VLOG(3) << "Notify node done: [id " << item->id << " to " << upperbound_id
+              << "] ";
      // Note that we notify all waiting threads in case an error has
      // occurred. These calling threads are responsible for checking status_
      // before proceeding.
@@ -266,7 +254,7 @@ void EagerExecutor::Run() {
   auto thread_exited_notifier =
       gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
   while (true) {
-    NodeItem* curr_item_raw;
+    core::RefCountPtr<NodeItem> curr_item;
     {
       tensorflow::mutex_lock l(node_queue_mutex_);
       while (node_queue_.empty() || !status_.ok()) {
@@ -280,30 +268,39 @@ void EagerExecutor::Run() {
       // will then contain a nullptr. This can be a problem in
       // WaitForAllPendingNodes where we get the top EagerNode pointer
       // and register a notification for its completion.
-      curr_item_raw = node_queue_.front().get();
+      curr_item.reset(node_queue_.front().get());
+      curr_item->Ref();
     }
-    VLOG(3) << "Running Node: [id " << curr_item_raw->id << "] "
-            << curr_item_raw->node->DebugString();
-    AsyncEagerNode* async_node_raw = curr_item_raw->node->AsAsync();
-    if (async_node_raw == nullptr) {
-      tensorflow::Status status = curr_item_raw->node->Run();
-      NodeDone(curr_item_raw, status);
-    } else {
-      async_node_raw->RunAsync([this, curr_item_raw](const Status& status) {
-        NodeDone(curr_item_raw, status);
-      });
-      {
-        tensorflow::mutex_lock l(node_queue_mutex_);
-        // If false, NodeDone has been called.
-        if (!node_queue_.empty() &&
-            curr_item_raw == node_queue_.front().get()) {
-          node_queue_.front().release();
-          node_queue_.pop();
-          unfinished_nodes_.emplace_hint(unfinished_nodes_.end(),
-                                         curr_item_raw->id, curr_item_raw);
-        }
-      }
+    RunItem(std::move(curr_item));
+  }
+}
+
+void EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item) {
+  VLOG(3) << "Running Node: [id " << item->id << "] "
+          << item->node->DebugString();
+  AsyncEagerNode* async_node = item->node->AsAsync();
+  if (async_node == nullptr) {
+    core::RefCountPtr<NodeItem> new_ref(item.get());
+    new_ref->Ref();
+    tensorflow::Status status = item->node->Run();
+    NodeDone(std::move(new_ref), status);
+  } else {
+    auto* new_ref = item.get();
+    new_ref->Ref();
+    async_node->RunAsync([this, new_ref](const Status& status) {
+      core::RefCountPtr<NodeItem> new_item(new_ref);
+      NodeDone(std::move(new_item), status);
+    });
+  }
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (item->state == NodeState::kPENDING) {
+    item->state = NodeState::kSCHEDULED;
+    if (!node_queue_.empty() && item.get() == node_queue_.front().get()) {
+      node_queue_.pop();
     }
+    VLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
+    unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
+                                   std::move(item));
   }
 }
diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h
index 3280a3d2193..534b826f4bc 100644
--- a/tensorflow/core/common_runtime/eager/eager_executor.h
+++ b/tensorflow/core/common_runtime/eager/eager_executor.h
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
@@ -73,12 +75,8 @@ class AsyncEagerNode : public EagerNode {
 
   AsyncEagerNode* AsAsync() final { return this; }
 
-  // This is non-blocking. It returns the scheduling status.
-  // TODO(fishx): avoid calling this AsyncEagerNode::Run.
   Status Run() final {
-    std::shared_ptr<Status> status(new Status);
-    RunAsync([status](const Status& s) { status->Update(s); });
-    return *status;
+    return errors::Unimplemented("Don't call AsyncEagerNode::Run().");
   }
 };
 
@@ -105,10 +103,11 @@ class EagerExecutor {
 
   bool Async() const;
 
-  // Schedules `node` for execution. If an error occurs (e.g. EagerExecutor
-  // has already been shut down), the `node` is not added to this executor
-  // and its Abort() method is called.
-  Status Add(std::unique_ptr<EagerNode> node);
+  // - Async Mode: schedules `node` for execution.
+  // - Sync Mode: inline execute the 'node' directly.
+  // If an error occurs (e.g. EagerExecutor has already been shut down), the
+  // `node` is not added to this executor and its Abort() method is called.
+  Status AddOrExecute(std::unique_ptr<EagerNode> node);
 
   // Blocks till all currently pending ops are done.
   // In particular, if EnableAsync() has not beed called, it will not return
@@ -139,16 +138,23 @@ class EagerExecutor {
     kShutDown,
   };
 
-  struct NodeItem {
+  enum class NodeState {
+    kPENDING,
+    kSCHEDULED,
+    kDONE,
+  };
+
+  struct NodeItem : core::RefCounted {
     // Unique id generated in EagerExecutor::Add(). If item1.id < item2.id, it
     // means item1.node is added before item2.node.
     uint64 id;
     std::unique_ptr<EagerNode> node;
+    NodeState state;
   };
 
   const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
-  void NodeDone(NodeItem* item, const Status& status);
+  void NodeDone(core::RefCountPtr<NodeItem> item, const Status& status);
 
   // Starts execution of pending EagerNodes. This function loops till
   // thread_done_ is set to true. If any errors are encontered, these are set
   // inside `status_`. The loop blocks anytime there are no pending nodes, or
   // `status_` is not ok.
   void Run();
 
+  void RunItem(core::RefCountPtr<NodeItem> item);
+
   // The impl of WaitForAllPendingNodes
   // `lock` is the lock that holds node_queue_mutex_.
   Status WaitForAllPendingNodesLocked(mutex_lock* lock)
       EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
-  // If async has been enabled on this executor, just calls
-  // WaitForAllPendingNodes. Else sets the status_ to an error if it does not
-  // already contain one `lock` is the lock that holds node_queue_mutex_.
-  // Precondition: state_ != kActive.
-  void WaitForOrDestroyAllPendingNodes(
-      mutex_lock* lock,
-      std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy)
-      EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
-
   Status WaitImpl(bool wait_all, uint64 node_id);
 
   std::atomic<uint64> next_node_id_;
 
@@ -180,12 +179,12 @@ class EagerExecutor {
   condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
 
   // Queue of pending NodeItems. Ordered by NodeItem::id.
-  std::queue<std::unique_ptr<NodeItem>> node_queue_
+  std::queue<core::RefCountPtr<NodeItem>> node_queue_
       GUARDED_BY(node_queue_mutex_);
 
-  // Owned the NodeItem in it. Ordered by NodeItem::id.
-  std::map<uint64, NodeItem*, std::less<uint64>> unfinished_nodes_
-      GUARDED_BY(node_queue_mutex_);
+  // Ordered by NodeItem::id.
+  std::map<uint64, core::RefCountPtr<NodeItem>, std::less<uint64>>
+      unfinished_nodes_ GUARDED_BY(node_queue_mutex_);
 
   // `status_` is set based on any errors raised during execution of a
   // EagerNode. It remains set until ClearError is called.
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 54acda6bd76..ef4671de56b 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -362,6 +362,7 @@ Status GetDeviceForInput(const EagerContext* ctx, TensorHandle* tensor_handle,
     // Use the resource's actual device because it is the device that will
     // influence partitioning the multi-device function.
     const Tensor* tensor;
+    // TODO(fishx): Avoid blocking here.
     TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
     const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
     device_name = handle.device();
@@ -639,7 +640,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
   // input handles are ready before executing them.
   // TODO(b/137118203): Consider executing "cheap" kernels inline for
   // performance.
-  Status s = executor.Async() ? executor.Add(std::move(node)) : node->Run();
+  Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -765,14 +766,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   }
 
   auto& executor = op->Executor();
-  bool is_async = executor.Async();
   VLOG(4) << "Execute remote eager op: " << op->Name()
-          << " (is async?: " << is_async << ").";
+          << " (is async?: " << executor.Async() << ").";
 
   std::unique_ptr<EagerNode> node(
       new eager::RemoteExecuteNode(std::move(request), op_device, eager_client,
                                    op->Inputs(), {retvals, num_outputs}));
-  Status s = is_async ? executor.Add(std::move(node)) : node->Run();
+  Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -904,6 +904,12 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
   bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName());
 
+  if (!op->Executor().Async()) {
+    // In sync mode, always clear error to maintain the same behavior as before.
+    // TODO(b/141004939): Remove this.
+    op->Executor().ClearError();
+  }
+
   std::unique_ptr<EagerOperation> out_op;
   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
       EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
@@ -1044,7 +1050,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   // Note that `h` may not be currently ready. However execution order will
   // make sure that `h` is ready before the copy is actually done.
   std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, ctx));
-  Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
+  Status s = executor->AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -1065,6 +1071,12 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   bool recver_is_local = device->IsLocal();
 
+  if (!executor->Async()) {
+    // In sync mode, always clear error to maintain the same behavior as before.
+    // TODO(b/141004939): Remove this.
+    executor->ClearError();
+  }
+
   if (sender_is_local && recver_is_local) {
     return LocalEagerCopyToDevice(h, ctx, executor, device, result);
   } else {
@@ -1105,7 +1117,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     }
     auto node = absl::make_unique<eager::RemoteCopyNode>(
         ctx, executor, h, result[0], device, recv_op_id);
-    Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
+    Status s = executor->AddOrExecute(std::move(node));
     if (!s.ok()) {
       result[0]->Unref();
     }
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 26e33634404..0b3c8b5d449 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -296,8 +296,7 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
           item.handle_to_decref());
       auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
          context, std::move(handle_to_decref));
-      s = executor.Async() ? context->Context()->Executor().Add(std::move(node))
-                           : node->Run();
+      s = context->Context()->Executor().AddOrExecute(std::move(node));
     } else {
       s = SendTensor(item.send_tensor(), context->Context());
     }
diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
index 1814f7b28a1..163db36ce34 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc
@@ -47,7 +47,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
       eager_client, ready));
   auto& executor = ctx->Executor();
   if (executor.Async()) {
-    Status status = executor.Add(std::move(node));
+    Status status = executor.AddOrExecute(std::move(node));
     if (!status.ok()) {
       LOG(ERROR) << "Unable to destroy remote tensor handles: "
                  << status.error_message();
@@ -56,13 +56,13 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
     // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
     // to send out the destroy request in a new thread to avoid deadlock.
     auto* released_node = node.release();
-    (*ctx->runner())([released_node] {
-      Status status = released_node->Run();
+    (*ctx->runner())([ctx, released_node] {
+      Status status =
+          ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
       if (!status.ok()) {
         LOG(ERROR) << "Unable to destroy remote tensor handles: "
                    << status.error_message();
       }
-      delete released_node;
     });
   }
 }
diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
index e523f36639d..4b349ebd811 100644
--- a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
@@ -523,7 +523,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
     if context.executing_eagerly():
       err_spec = errors.InvalidArgumentError, (
           'Each record default should be at '
-          'most rank 1.')
+          'most rank 1')
     else:
       err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
 
diff --git a/tensorflow/python/kernel_tests/decode_csv_op_test.py b/tensorflow/python/kernel_tests/decode_csv_op_test.py
index 6c7a9de6e05..3ad3e93df8e 100644
--- a/tensorflow/python/kernel_tests/decode_csv_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_csv_op_test.py
@@ -73,7 +73,7 @@ class DecodeCSVOpTest(test.TestCase):
     if context.executing_eagerly():
       err_spec = errors.InvalidArgumentError, (
          "Each record default should be at "
-          "most rank 1.")
+          "most rank 1")
     else:
       err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
     with self.assertRaisesWithPredicateMatch(*err_spec):
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index e763da72eb9..b36b252bd81 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -107,9 +107,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
       handle = resource_variable_ops.var_handle_op(
          dtype=dtypes.int32, shape=[1], name="foo")
       resource_variable_ops.assign_variable_op(handle, 1)
-      with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                   "Trying to read variable with wrong dtype. "
-                                   "Expected float got int32."):
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          "Trying to read variable with wrong dtype. "
" + "Expected float got int32"): _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32) def testEagerInitializedValue(self): @@ -195,9 +196,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase, dtype=dtypes.int32, shape=[1], name="foo") resource_variable_ops.assign_variable_op( handle, constant_op.constant([1])) - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Trying to assign variable with wrong " - "dtype. Expected int32 got float."): + with self.assertRaisesRegexp( + errors.InvalidArgumentError, "Trying to assign variable with wrong " + "dtype. Expected int32 got float"): resource_variable_ops.assign_variable_op( handle, constant_op.constant([1.], dtype=dtypes.float32))