Use the executor to execute ops in sync mode as well. A few benefits of doing this:

1. Unifies the sync and async code paths a bit more.
2. Allows the executor to capture errors raised in AsyncEagerNode. We may need to find a way to propagate these errors to the Python layer in the future.
3. Allows EagerContext to wait for all pending AsyncEagerNodes during shutdown.

PiperOrigin-RevId: 268956697
parent c3e32b03e1
commit bfb4c9bcb2
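The core of the change is that EagerExecutor::Add becomes AddOrExecute: in async mode it still enqueues the node for the executor thread, while in sync mode it now runs the node inline and folds any error into the executor's sticky status_, so both modes share one error-capture path. Before reading the diff, here is a minimal, self-contained sketch of that dispatch pattern. It is not the real TensorFlow API: MiniNode, MiniExecutor, and the reduced Status type are illustrative assumptions, and the async worker thread is elided.

#include <iostream>
#include <memory>
#include <queue>
#include <string>
#include <utility>

// Simplified stand-in for tensorflow::Status (hypothetical, not TF code).
struct Status {
  bool ok = true;
  std::string msg;
};

// Stand-in for EagerNode: one schedulable unit of work.
struct MiniNode {
  virtual ~MiniNode() = default;
  virtual Status Run() = 0;
};

class MiniExecutor {
 public:
  explicit MiniExecutor(bool async) : async_(async) {}

  // Mirrors the shape of the new AddOrExecute():
  //  - async mode: enqueue the node for a worker thread (thread elided here);
  //  - sync mode: run the node inline and capture its error in status_,
  //    so sync and async failures flow through the same path.
  Status AddOrExecute(std::unique_ptr<MiniNode> node) {
    if (!status_.ok) return status_;  // executor already poisoned by an error
    if (async_) {
      queue_.push(std::move(node));   // a real impl would wake the run thread
      return Status{};
    }
    Status s = node->Run();           // inline execution in sync mode
    if (!s.ok) status_ = s;           // sticky error, like status_ in the diff
    return s;
  }

 private:
  bool async_;
  Status status_;  // first error encountered; poisons later ops
  std::queue<std::unique_ptr<MiniNode>> queue_;
};

int main() {
  struct FailingNode : MiniNode {
    Status Run() override { return {false, "boom"}; }
  };
  MiniExecutor sync_executor(/*async=*/false);
  Status s = sync_executor.AddOrExecute(std::make_unique<FailingNode>());
  std::cout << (s.ok ? "ok" : s.msg) << "\n";  // prints "boom"
}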
@@ -35,7 +35,7 @@ EagerExecutor::~EagerExecutor() {
 
 Status EagerExecutor::ShutDown() {
   {
-    std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
+    std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
     bool has_thread;
     Status status;
     {
@@ -47,7 +47,9 @@ Status EagerExecutor::ShutDown() {
         // thread_exited_notification_.WaitForNotification() below.
         state_ = ExecutorState::kShuttingDown;
       }
-      WaitForOrDestroyAllPendingNodes(&l, &items_to_destroy);
+      // It is OK to ignore the returned status here because it will be saved
+      // as the final status_.
+      WaitForAllPendingNodesLocked(&l).IgnoreError();
       state_ = ExecutorState::kShutDown;
       has_thread = thread_ != nullptr;
       status = status_;
@@ -68,36 +70,6 @@ Status EagerExecutor::ShutDown() {
   return status_;
 }
 
-void EagerExecutor::WaitForOrDestroyAllPendingNodes(
-    mutex_lock* lock,
-    std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy) {
-  if (state_ == ExecutorState::kShutDown) {
-    return;
-  }
-  if (thread_ == nullptr) {
-    Status status = status_;
-    if (status.ok()) {
-      status = errors::FailedPrecondition(
-          "Aborting eager nodes because EagerExecutor is being shut down "
-          "before it got a thread to run the nodes");
-      status_ = status;
-    }
-    while (!node_queue_.empty()) {
-      nodes_to_destroy->push_back(std::move(node_queue_.front()));
-      node_queue_.pop();
-    }
-    for (auto& it : unfinished_nodes_) {
-      nodes_to_destroy->push_back(absl::WrapUnique(it.second));
-    }
-    unfinished_nodes_.clear();
-    return;
-  }
-
-  // It is OK to ignore the returned status here because it will be saved
-  // as the final status_.
-  WaitForAllPendingNodesLocked(lock).IgnoreError();
-}
-
 bool EagerExecutor::Async() const {
   return thread_ != nullptr;
 }
@@ -113,15 +85,19 @@ const char* EagerExecutor::StateStringLocked() {
   }
 }
 
-Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
+Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
   Status status;
+  core::RefCountPtr<NodeItem> item(new NodeItem);
+  item->id = next_node_id_++;
+  item->node = std::move(node);
+  item->state = NodeState::kPENDING;
 
   // If we are unable to add the node to the queue, we must call Abort. However,
   // we want to do that outside of the scope of the lock since the Abort may
   // try to call EagerExecutor::Add()
   {
     tensorflow::mutex_lock l(node_queue_mutex_);
-    VLOG(3) << "Add node [id " << next_node_id_ << "]" << node->DebugString()
+    VLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
             << " with status: " << status_.ToString();
     if (state_ != ExecutorState::kActive) {
       status = errors::FailedPrecondition(
@@ -129,16 +105,11 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
           "Current state is '",
           StateStringLocked(), "'");
     } else {
-      DCHECK(thread_) << "EnableAsync should have been called before Add";
       status = status_;
-      if (status.ok()) {
-        auto item = absl::make_unique<NodeItem>();
-        item->id = next_node_id_++;
-        item->node = std::move(node);
+      if (status.ok() && Async()) {
         node_queue_.push(std::move(item));
-
-        // If there were no previous nodes pending, wake the run thread to start
-        // processing requests again.
+        // If there were no previous nodes pending, wake the run thread to
+        // start processing requests again.
        if (node_queue_.size() == 1) {
          nodes_pending_.notify_all();
        }
@@ -148,9 +119,17 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
     }
   }
 
-  // Node needs to be aborted since it was not added to the queue
-  node->Abort(status);
-  return status;
+  if (status.ok()) {
+    // Inline execution in sync mode.
+    DCHECK(!Async());
+    RunItem(std::move(item));
+    status = this->status();
+    return status;
+  } else {
+    // Node needs to be aborted since it was not added to the queue
+    item->node->Abort(status);
+    return status;
+  }
 }
 
 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
@@ -165,6 +144,8 @@ tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
   if (!status_.ok()) return status_;
   if (node_queue_.empty() && unfinished_nodes_.empty())
     return tensorflow::Status::OK();
+  // node_queue_ must be empty in sync mode.
+  DCHECK(Async() || node_queue_.empty());
   auto last_id = next_node_id_ - 1;
   VLOG(3) << "Wait for Node: [id " << last_id << "] ";
   node_done_notifications_.insert(std::make_pair(last_id, &cond));
@@ -191,25 +172,30 @@ tensorflow::Status EagerExecutor::status() const {
   return status_;
 }
 
-void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
+void EagerExecutor::NodeDone(core::RefCountPtr<NodeItem> item,
+                             const Status& status) {
   VLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
           << " with status: " << status.ToString();
-  std::unique_ptr<NodeItem> current_item;
-  std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
+  DCHECK(item->state != NodeState::kDONE);
+  std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
   {
     mutex_lock l(node_queue_mutex_);
+    auto previous_state = item->state;
+    item->state = NodeState::kDONE;
     if (!status_.ok()) return;
     bool need_notification = false;
-    if (!node_queue_.empty() && item == node_queue_.front().get()) {
-      need_notification = unfinished_nodes_.empty();
-      current_item = std::move(node_queue_.front());
-      node_queue_.pop();
+    if (previous_state == NodeState::kPENDING) {
+      if (Async()) {
+        DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
+        need_notification = unfinished_nodes_.empty();
+        node_queue_.pop();
+      } else {
+        need_notification = unfinished_nodes_.empty();
+      }
     } else {
-      DCHECK(!unfinished_nodes_.empty());
       need_notification = item->id == unfinished_nodes_.begin()->first;
-      auto erase_result = unfinished_nodes_.erase(item->id);
-      DCHECK_GT(erase_result, 0);
-      current_item = absl::WrapUnique(item);
+      auto result = unfinished_nodes_.erase(item->id);
+      DCHECK_GT(result, 0);
     }
     if (!status.ok()) {
       need_notification = true;
@@ -217,7 +203,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
       // We remove any pending ops so that we don't try to execute them if
       // ClearError is called.
       errors::AppendToMessage(&status_,
-                              ". Encountered when executing an operation using "
+                              "Encountered when executing an operation using "
                               "EagerExecutor. This error cancels all future "
                               "operations and poisons their output tensors.");
       while (!node_queue_.empty()) {
@@ -225,7 +211,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
        node_queue_.pop();
      }
      for (auto& it : unfinished_nodes_) {
-        items_to_destroy.push_back(absl::WrapUnique(it.second));
+        items_to_destroy.push_back(std::move(it.second));
      }
      unfinished_nodes_.clear();
    }
@@ -238,6 +224,8 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
   } else {
     upperbound_id = next_node_id_ - 1;
   }
+  VLOG(3) << "Notify node done: [id " << item->id << " to " << upperbound_id
+          << "] ";
   // Note that we notify all waiting threads in case an error has
   // occurred. These calling threads are responsible for checking status_
   // before proceeding.
@@ -266,7 +254,7 @@ void EagerExecutor::Run() {
   auto thread_exited_notifier =
       gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
   while (true) {
-    NodeItem* curr_item_raw;
+    core::RefCountPtr<NodeItem> curr_item;
     {
       tensorflow::mutex_lock l(node_queue_mutex_);
       while (node_queue_.empty() || !status_.ok()) {
@@ -280,30 +268,39 @@ void EagerExecutor::Run() {
       // will then contain a nullptr. This can be a problem in
       // WaitForAllPendingNodes where we get the top EagerNode pointer
       // and register a notification for its completion.
-      curr_item_raw = node_queue_.front().get();
+      curr_item.reset(node_queue_.front().get());
+      curr_item->Ref();
     }
-    VLOG(3) << "Running Node: [id " << curr_item_raw->id << "] "
-            << curr_item_raw->node->DebugString();
-    AsyncEagerNode* async_node_raw = curr_item_raw->node->AsAsync();
-    if (async_node_raw == nullptr) {
-      tensorflow::Status status = curr_item_raw->node->Run();
-      NodeDone(curr_item_raw, status);
-    } else {
-      async_node_raw->RunAsync([this, curr_item_raw](const Status& status) {
-        NodeDone(curr_item_raw, status);
-      });
-      {
-        tensorflow::mutex_lock l(node_queue_mutex_);
-        // If false, NodeDone has been called.
-        if (!node_queue_.empty() &&
-            curr_item_raw == node_queue_.front().get()) {
-          node_queue_.front().release();
-          node_queue_.pop();
-          unfinished_nodes_.emplace_hint(unfinished_nodes_.end(),
-                                         curr_item_raw->id, curr_item_raw);
-        }
+    RunItem(std::move(curr_item));
+  }
+}
+
+void EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item) {
+  VLOG(3) << "Running Node: [id " << item->id << "] "
+          << item->node->DebugString();
+  AsyncEagerNode* async_node = item->node->AsAsync();
+  if (async_node == nullptr) {
+    core::RefCountPtr<NodeItem> new_ref(item.get());
+    new_ref->Ref();
+    tensorflow::Status status = item->node->Run();
+    NodeDone(std::move(new_ref), status);
+  } else {
+    auto* new_ref = item.get();
+    new_ref->Ref();
+    async_node->RunAsync([this, new_ref](const Status& status) {
+      core::RefCountPtr<NodeItem> new_item(new_ref);
+      NodeDone(std::move(new_item), status);
+    });
   }
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (item->state == NodeState::kPENDING) {
+    item->state = NodeState::kSCHEDULED;
+    if (!node_queue_.empty() && item.get() == node_queue_.front().get()) {
+      node_queue_.pop();
    }
+    VLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
+    unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
+                                   std::move(item));
   }
 }
 
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
@@ -73,12 +75,8 @@ class AsyncEagerNode : public EagerNode {
 
   AsyncEagerNode* AsAsync() final { return this; }
 
-  // This is non-blocking. It returns the scheduling status.
-  // TODO(fishx): avoid calling this AsyncEagerNode::Run.
   Status Run() final {
-    std::shared_ptr<Status> status(new Status);
-    RunAsync([status](const Status& s) { status->Update(s); });
-    return *status;
+    return errors::Unimplemented("Don't call AsyncEagerNode::Run().");
   }
 };
 
@@ -105,10 +103,11 @@ class EagerExecutor {
 
   bool Async() const;
 
-  // Schedules `node` for execution. If an error occurs (e.g. EagerExecutor
-  // has already been shut down), the `node` is not added to this executor
-  // and its Abort() method is called.
-  Status Add(std::unique_ptr<EagerNode> node);
+  // - Async Mode: schedules `node` for execution.
+  // - Sync Mode: inline execute the `node` directly.
+  // If an error occurs (e.g. EagerExecutor has already been shut down), the
+  // `node` is not added to this executor and its Abort() method is called.
+  Status AddOrExecute(std::unique_ptr<EagerNode> node);
 
   // Blocks till all currently pending ops are done.
   // In particular, if EnableAsync() has not beed called, it will not return
@@ -139,16 +138,23 @@ class EagerExecutor {
     kShutDown,
   };
 
-  struct NodeItem {
+  enum class NodeState {
+    kPENDING,
+    kSCHEDULED,
+    kDONE,
+  };
+
+  struct NodeItem : core::RefCounted {
     // Unique id generated in EagerExecutor::Add(). If item1.id < item2.id, it
     // means item1.node is added before item2.node.
     uint64 id;
     std::unique_ptr<EagerNode> node;
+    NodeState state;
   };
 
   const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
-  void NodeDone(NodeItem* item, const Status& status);
+  void NodeDone(core::RefCountPtr<NodeItem> item, const Status& status);
 
   // Starts execution of pending EagerNodes. This function loops till
   // thread_done_ is set to true. If any errors are encontered, these are set
@@ -156,20 +162,13 @@ class EagerExecutor {
   // `status_` is not ok.
   void Run();
 
+  void RunItem(core::RefCountPtr<NodeItem> item);
+
   // The impl of WaitForAllPendingNodes
   // `lock` is the lock that holds node_queue_mutex_.
   Status WaitForAllPendingNodesLocked(mutex_lock* lock)
       EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
-  // If async has been enabled on this executor, just calls
-  // WaitForAllPendingNodes. Else sets the status_ to an error if it does not
-  // already contain one `lock` is the lock that holds node_queue_mutex_.
-  // Precondition: state_ != kActive.
-  void WaitForOrDestroyAllPendingNodes(
-      mutex_lock* lock,
-      std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy)
-      EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
-
   Status WaitImpl(bool wait_all, uint64 node_id);
 
   std::atomic<uint64> next_node_id_;
@@ -180,12 +179,12 @@ class EagerExecutor {
   condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
 
   // Queue of pending NodeItems. Ordered by NodeItem::id.
-  std::queue<std::unique_ptr<NodeItem>> node_queue_
+  std::queue<core::RefCountPtr<NodeItem>> node_queue_
       GUARDED_BY(node_queue_mutex_);
 
-  // Owned the NodeItem in it. Ordered by NodeItem::id.
-  std::map<uint64, NodeItem*, std::less<uint64>> unfinished_nodes_
-      GUARDED_BY(node_queue_mutex_);
+  // Ordered by NodeItem::id.
+  std::map<uint64, core::RefCountPtr<NodeItem>, std::less<uint64>>
+      unfinished_nodes_ GUARDED_BY(node_queue_mutex_);
 
   // `status_` is set based on any errors raised during execution of a
   // EagerNode. It remains set until ClearError is called.
@@ -362,6 +362,7 @@ Status GetDeviceForInput(const EagerContext* ctx, TensorHandle* tensor_handle,
     // Use the resource's actual device because it is the device that will
     // influence partitioning the multi-device function.
     const Tensor* tensor;
+    // TODO(fishx): Avoid blocking here.
     TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
     const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
     device_name = handle.device();
@@ -639,7 +640,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
   // input handles are ready before executing them.
   // TODO(b/137118203): Consider executing "cheap" kernels inline for
   // performance.
-  Status s = executor.Async() ? executor.Add(std::move(node)) : node->Run();
+  Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -765,14 +766,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   }
 
   auto& executor = op->Executor();
-  bool is_async = executor.Async();
   VLOG(4) << "Execute remote eager op: " << op->Name()
-          << " (is async?: " << is_async << ").";
+          << " (is async?: " << executor.Async() << ").";
 
   std::unique_ptr<EagerNode> node(
       new eager::RemoteExecuteNode(std::move(request), op_device, eager_client,
                                    op->Inputs(), {retvals, num_outputs}));
-  Status s = is_async ? executor.Add(std::move(node)) : node->Run();
+  Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -904,6 +904,12 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
 
   bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName());
 
+  if (!op->Executor().Async()) {
+    // In sync mode, always clear error to maintain the same behavior as before.
+    // TODO(b/141004939): Remove this.
+    op->Executor().ClearError();
+  }
+
   std::unique_ptr<tensorflow::EagerOperation> out_op;
   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
       EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
@@ -1044,7 +1050,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   // Note that `h` may not be currently ready. However execution order will
   // make sure that `h` is ready before the copy is actually done.
   std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, ctx));
-  Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
+  Status s = executor->AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -1065,6 +1071,12 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
 
   bool recver_is_local = device->IsLocal();
 
+  if (!executor->Async()) {
+    // In sync mode, always clear error to maintain the same behavior as before.
+    // TODO(b/141004939): Remove this.
+    executor->ClearError();
+  }
+
   if (sender_is_local && recver_is_local) {
     return LocalEagerCopyToDevice(h, ctx, executor, device, result);
   } else {
@@ -1105,7 +1117,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   }
   auto node = absl::make_unique<eager::RemoteCopyNode>(
       ctx, executor, h, result[0], device, recv_op_id);
-  Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
+  Status s = executor->AddOrExecute(std::move(node));
   if (!s.ok()) {
     result[0]->Unref();
   }
@@ -296,8 +296,7 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
                                        item.handle_to_decref());
       auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
           context, std::move(handle_to_decref));
-      s = executor.Async() ? context->Context()->Executor().Add(std::move(node))
-                           : node->Run();
+      s = context->Context()->Executor().AddOrExecute(std::move(node));
     } else {
       s = SendTensor(item.send_tensor(), context->Context());
     }
@@ -47,7 +47,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
                                                  eager_client, ready));
   auto& executor = ctx->Executor();
   if (executor.Async()) {
-    Status status = executor.Add(std::move(node));
+    Status status = executor.AddOrExecute(std::move(node));
     if (!status.ok()) {
       LOG(ERROR) << "Unable to destroy remote tensor handles: "
                  << status.error_message();
@@ -56,13 +56,13 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
     // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
     // to send out the destroy request in a new thread to avoid deadlock.
    auto* released_node = node.release();
-    (*ctx->runner())([released_node] {
-      Status status = released_node->Run();
+    (*ctx->runner())([ctx, released_node] {
+      Status status =
+          ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
       if (!status.ok()) {
        LOG(ERROR) << "Unable to destroy remote tensor handles: "
                   << status.error_message();
      }
-      delete released_node;
    });
   }
 }
@@ -523,7 +523,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
     if context.executing_eagerly():
       err_spec = errors.InvalidArgumentError, (
           'Each record default should be at '
-          'most rank 1.')
+          'most rank 1')
     else:
       err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
 
@@ -73,7 +73,7 @@ class DecodeCSVOpTest(test.TestCase):
     if context.executing_eagerly():
       err_spec = errors.InvalidArgumentError, (
          "Each record default should be at "
-          "most rank 1.")
+          "most rank 1")
     else:
       err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
     with self.assertRaisesWithPredicateMatch(*err_spec):
@@ -107,9 +107,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo")
     resource_variable_ops.assign_variable_op(handle, 1)
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Trying to read variable with wrong dtype. "
-                                 "Expected float got int32."):
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError,
+        "Trying to read variable with wrong dtype. "
+        "Expected float got int32"):
       _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
 
   def testEagerInitializedValue(self):
@@ -195,9 +196,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
         dtype=dtypes.int32, shape=[1], name="foo")
     resource_variable_ops.assign_variable_op(
         handle, constant_op.constant([1]))
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Trying to assign variable with wrong "
-                                 "dtype. Expected int32 got float."):
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError, "Trying to assign variable with wrong "
+        "dtype. Expected int32 got float"):
       resource_variable_ops.assign_variable_op(
           handle, constant_op.constant([1.], dtype=dtypes.float32))
 
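A design note on the NodeItem change above: once a node can complete on another thread (the RunAsync callback) while still being reachable from node_queue_, unfinished_nodes_, and the executor's local handle, single-ownership std::unique_ptr no longer fits, so the diff makes NodeItem intrusively ref-counted (core::RefCounted / core::RefCountPtr) and the last holder to drop its reference frees the item. The following is a minimal sketch of that handoff under an assumption: it uses std::shared_ptr as a stand-in for TensorFlow's intrusive ref counting, and the Item/unfinished/done_callback names are illustrative, not the real API.

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>

// Hypothetical stand-in for NodeItem; shared_ptr plays the RefCountPtr role.
struct Item {
  std::uint64_t id = 0;
  ~Item() { std::cout << "item " << id << " destroyed\n"; }
};

int main() {
  std::map<std::uint64_t, std::shared_ptr<Item>> unfinished;
  std::function<void()> done_callback;
  {
    auto item = std::make_shared<Item>();
    item->id = 1;
    unfinished.emplace(item->id, item);  // one reference held by the map
    done_callback = [item] {             // one reference held by the callback
      std::cout << "node " << item->id << " done\n";
    };
  }  // local reference dropped; the item stays alive
  done_callback();          // async completion can fire safely at any point
  unfinished.erase(1);      // map drops its reference
  done_callback = nullptr;  // last reference dropped: item destroyed here
}

Whichever of the map, the callback, or the worker releases last, the item is freed exactly once; that is the property the RefCountPtr<NodeItem> plumbing in RunItem and NodeDone is buying.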