Use the executor to execute ops in sync mode as well. A few benefits of doing this:
1. Unifies the sync and async code paths a bit more.
2. Allows the executor to capture errors raised by an AsyncEagerNode. We may need to find a way to propagate this error to the Python layer in the future.
3. Allows EagerContext to wait for all pending AsyncEagerNodes during shutdown.
PiperOrigin-RevId: 268956697
This commit is contained in:
parent c3e32b03e1
commit bfb4c9bcb2

Changed paths:
tensorflow/core/common_runtime/eager
tensorflow/core/distributed_runtime/eager
tensorflow/python/data/experimental/kernel_tests
tensorflow/python/kernel_tests
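Before the diff, a toy model of the contract this change introduces may help. The sketch below is illustrative only (MiniExecutor, MiniNode, and MiniStatus are stand-in names, not TensorFlow types): queue the node in async mode, run it inline in sync mode, and let the executor capture the resulting error either way.

// Toy model of the AddOrExecute contract (illustrative only, not TF code).
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>

struct MiniStatus {
  bool ok = true;
  std::string msg;
};

struct MiniNode {
  std::function<MiniStatus()> run;  // stands in for EagerNode::Run()
};

class MiniExecutor {
 public:
  explicit MiniExecutor(bool async) : async_(async) {}

  // Mirrors the new EagerExecutor::AddOrExecute: reject work if already
  // poisoned, queue in async mode, otherwise run inline and record the error.
  MiniStatus AddOrExecute(std::unique_ptr<MiniNode> node) {
    if (!status_.ok) return status_;   // a previous error poisons new work
    if (async_) {
      queue_.push(std::move(node));    // a background thread would drain this
      return MiniStatus{};
    }
    MiniStatus s = node->run();        // sync mode: inline execution
    if (!s.ok) status_ = s;            // the executor owns the error now
    return s;
  }

  const MiniStatus& status() const { return status_; }

 private:
  bool async_;
  MiniStatus status_;
  std::queue<std::unique_ptr<MiniNode>> queue_;
};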
@@ -35,7 +35,7 @@ EagerExecutor::~EagerExecutor() {
 Status EagerExecutor::ShutDown() {
   {
-    std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
+    std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
     bool has_thread;
     Status status;
     {
@@ -47,7 +47,9 @@ Status EagerExecutor::ShutDown() {
       // thread_exited_notification_.WaitForNotification() below.
       state_ = ExecutorState::kShuttingDown;
     }
-    WaitForOrDestroyAllPendingNodes(&l, &items_to_destroy);
+    // It is OK to ignore the returned status here because it will be saved
+    // as the final status_.
+    WaitForAllPendingNodesLocked(&l).IgnoreError();
     state_ = ExecutorState::kShutDown;
     has_thread = thread_ != nullptr;
     status = status_;
@@ -68,36 +70,6 @@ Status EagerExecutor::ShutDown() {
   return status_;
 }
 
-void EagerExecutor::WaitForOrDestroyAllPendingNodes(
-    mutex_lock* lock,
-    std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy) {
-  if (state_ == ExecutorState::kShutDown) {
-    return;
-  }
-  if (thread_ == nullptr) {
-    Status status = status_;
-    if (status.ok()) {
-      status = errors::FailedPrecondition(
-          "Aborting eager nodes because EagerExecutor is being shut down "
-          "before it got a thread to run the nodes");
-      status_ = status;
-    }
-    while (!node_queue_.empty()) {
-      nodes_to_destroy->push_back(std::move(node_queue_.front()));
-      node_queue_.pop();
-    }
-    for (auto& it : unfinished_nodes_) {
-      nodes_to_destroy->push_back(absl::WrapUnique(it.second));
-    }
-    unfinished_nodes_.clear();
-    return;
-  }
-
-  // It is OK to ignore the returned status here because it will be saved
-  // as the final status_.
-  WaitForAllPendingNodesLocked(lock).IgnoreError();
-}
-
 bool EagerExecutor::Async() const {
   return thread_ != nullptr;
 }
@@ -113,15 +85,19 @@ const char* EagerExecutor::StateStringLocked() {
   }
 }
 
-Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
+Status EagerExecutor::AddOrExecute(std::unique_ptr<EagerNode> node) {
   Status status;
+  core::RefCountPtr<NodeItem> item(new NodeItem);
+  item->id = next_node_id_++;
+  item->node = std::move(node);
+  item->state = NodeState::kPENDING;
 
   // If we are unable to add the node to the queue, we must call Abort. However,
   // we want to do that outside of the scope of the lock since the Abort may
   // try to call EagerExecutor::Add()
   {
     tensorflow::mutex_lock l(node_queue_mutex_);
-    VLOG(3) << "Add node [id " << next_node_id_ << "]" << node->DebugString()
+    VLOG(3) << "Add node [id " << item->id << "]" << item->node->DebugString()
             << " with status: " << status_.ToString();
     if (state_ != ExecutorState::kActive) {
       status = errors::FailedPrecondition(
@@ -129,16 +105,11 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
           "Current state is '",
          StateStringLocked(), "'");
     } else {
-      DCHECK(thread_) << "EnableAsync should have been called before Add";
       status = status_;
-      if (status.ok()) {
-        auto item = absl::make_unique<NodeItem>();
-        item->id = next_node_id_++;
-        item->node = std::move(node);
+      if (status.ok() && Async()) {
         node_queue_.push(std::move(item));
 
-        // If there were no previous nodes pending, wake the run thread to start
-        // processing requests again.
+        // If there were no previous nodes pending, wake the run thread to
+        // start processing requests again.
         if (node_queue_.size() == 1) {
           nodes_pending_.notify_all();
         }
@@ -148,9 +119,17 @@ Status EagerExecutor::Add(std::unique_ptr<EagerNode> node) {
     }
   }
 
-  // Node needs to be aborted since it was not added to the queue
-  node->Abort(status);
-  return status;
+  if (status.ok()) {
+    // Inline execution in sync mode.
+    DCHECK(!Async());
+    RunItem(std::move(item));
+    status = this->status();
+    return status;
+  } else {
+    // Node needs to be aborted since it was not added to the queue
+    item->node->Abort(status);
+    return status;
+  }
 }
 
 tensorflow::Status EagerExecutor::WaitForAllPendingNodes() {
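Continuing the toy model sketched under the commit message (MiniExecutor and friends are illustrative stand-ins, not TF types), the sync-mode branch above is where the node's error comes straight back to the caller and then sticks in the executor:

#include <iostream>
#include <memory>

int main() {
  MiniExecutor sync_exec(/*async=*/false);
  auto node = std::make_unique<MiniNode>();
  node->run = [] { return MiniStatus{false, "op failed"}; };

  MiniStatus s = sync_exec.AddOrExecute(std::move(node));
  std::cout << (s.ok ? "ok" : s.msg) << "\n";   // prints "op failed"
  std::cout << sync_exec.status().ok << "\n";   // 0: the error is sticky
}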
@@ -165,6 +144,8 @@ tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked(
   if (!status_.ok()) return status_;
   if (node_queue_.empty() && unfinished_nodes_.empty())
     return tensorflow::Status::OK();
+  // node_queue_ must be empty in sync mode.
+  DCHECK(Async() || node_queue_.empty());
   auto last_id = next_node_id_ - 1;
   VLOG(3) << "Wait for Node: [id " << last_id << "] ";
   node_done_notifications_.insert(std::make_pair(last_id, &cond));
@@ -191,25 +172,30 @@ tensorflow::Status EagerExecutor::status() const {
   return status_;
 }
 
-void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
+void EagerExecutor::NodeDone(core::RefCountPtr<NodeItem> item,
+                             const Status& status) {
   VLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString()
           << " with status: " << status.ToString();
-  std::unique_ptr<NodeItem> current_item;
-  std::vector<std::unique_ptr<NodeItem>> items_to_destroy;
+  DCHECK(item->state != NodeState::kDONE);
+  std::vector<core::RefCountPtr<NodeItem>> items_to_destroy;
   {
     mutex_lock l(node_queue_mutex_);
+    auto previous_state = item->state;
+    item->state = NodeState::kDONE;
     if (!status_.ok()) return;
     bool need_notification = false;
-    if (!node_queue_.empty() && item == node_queue_.front().get()) {
-      need_notification = unfinished_nodes_.empty();
-      current_item = std::move(node_queue_.front());
-      node_queue_.pop();
+    if (previous_state == NodeState::kPENDING) {
+      if (Async()) {
+        DCHECK(!node_queue_.empty() && item.get() == node_queue_.front().get());
+        need_notification = unfinished_nodes_.empty();
+        node_queue_.pop();
+      } else {
+        need_notification = unfinished_nodes_.empty();
+      }
     } else {
      DCHECK(!unfinished_nodes_.empty());
      need_notification = item->id == unfinished_nodes_.begin()->first;
-      auto erase_result = unfinished_nodes_.erase(item->id);
-      DCHECK_GT(erase_result, 0);
-      current_item = absl::WrapUnique(item);
+      auto result = unfinished_nodes_.erase(item->id);
+      DCHECK_GT(result, 0);
     }
     if (!status.ok()) {
       need_notification = true;
@@ -217,7 +203,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
       // We remove any pending ops so that we don't try to execute them if
       // ClearError is called.
       errors::AppendToMessage(&status_,
-                              ". Encountered when executing an operation using "
+                              "Encountered when executing an operation using "
                               "EagerExecutor. This error cancels all future "
                               "operations and poisons their output tensors.");
       while (!node_queue_.empty()) {
@@ -225,7 +211,7 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
         node_queue_.pop();
       }
       for (auto& it : unfinished_nodes_) {
-        items_to_destroy.push_back(absl::WrapUnique(it.second));
+        items_to_destroy.push_back(std::move(it.second));
       }
       unfinished_nodes_.clear();
     }
@@ -238,6 +224,8 @@ void EagerExecutor::NodeDone(NodeItem* item, const Status& status) {
     } else {
       upperbound_id = next_node_id_ - 1;
     }
+    VLOG(3) << "Notify node done: [id " << item->id << " to " << upperbound_id
+            << "] ";
     // Note that we notify all waiting threads in case an error has
     // occurred. These calling threads are responsible for checking status_
     // before proceeding.
@@ -266,7 +254,7 @@ void EagerExecutor::Run() {
   auto thread_exited_notifier =
       gtl::MakeCleanup([this] { thread_exited_notification_.Notify(); });
   while (true) {
-    NodeItem* curr_item_raw;
+    core::RefCountPtr<NodeItem> curr_item;
     {
       tensorflow::mutex_lock l(node_queue_mutex_);
       while (node_queue_.empty() || !status_.ok()) {
@@ -280,30 +268,39 @@ void EagerExecutor::Run() {
       // will then contain a nullptr. This can be a problem in
       // WaitForAllPendingNodes where we get the top EagerNode pointer
      // and register a notification for its completion.
-      curr_item_raw = node_queue_.front().get();
+      curr_item.reset(node_queue_.front().get());
+      curr_item->Ref();
     }
-    VLOG(3) << "Running Node: [id " << curr_item_raw->id << "] "
-            << curr_item_raw->node->DebugString();
-    AsyncEagerNode* async_node_raw = curr_item_raw->node->AsAsync();
-    if (async_node_raw == nullptr) {
-      tensorflow::Status status = curr_item_raw->node->Run();
-      NodeDone(curr_item_raw, status);
-    } else {
-      async_node_raw->RunAsync([this, curr_item_raw](const Status& status) {
-        NodeDone(curr_item_raw, status);
-      });
-      {
-        tensorflow::mutex_lock l(node_queue_mutex_);
-        // If false, NodeDone has been called.
-        if (!node_queue_.empty() &&
-            curr_item_raw == node_queue_.front().get()) {
-          node_queue_.front().release();
-          node_queue_.pop();
-          unfinished_nodes_.emplace_hint(unfinished_nodes_.end(),
-                                         curr_item_raw->id, curr_item_raw);
-        }
-      }
-    }
+    RunItem(std::move(curr_item));
   }
 }
 
+void EagerExecutor::RunItem(core::RefCountPtr<NodeItem> item) {
+  VLOG(3) << "Running Node: [id " << item->id << "] "
+          << item->node->DebugString();
+  AsyncEagerNode* async_node = item->node->AsAsync();
+  if (async_node == nullptr) {
+    core::RefCountPtr<NodeItem> new_ref(item.get());
+    new_ref->Ref();
+    tensorflow::Status status = item->node->Run();
+    NodeDone(std::move(new_ref), status);
+  } else {
+    auto* new_ref = item.get();
+    new_ref->Ref();
+    async_node->RunAsync([this, new_ref](const Status& status) {
+      core::RefCountPtr<NodeItem> new_item(new_ref);
+      NodeDone(std::move(new_item), status);
+    });
+  }
+  tensorflow::mutex_lock l(node_queue_mutex_);
+  if (item->state == NodeState::kPENDING) {
+    item->state = NodeState::kSCHEDULED;
+    if (!node_queue_.empty() && item.get() == node_queue_.front().get()) {
+      node_queue_.pop();
+    }
+    VLOG(3) << "Add Node: [id " << item->id << "] to unfinished map.";
+    unfinished_nodes_.emplace_hint(unfinished_nodes_.end(), item->id,
+                                   std::move(item));
+  }
+}
@@ -27,6 +27,8 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"
@@ -73,12 +75,8 @@ class AsyncEagerNode : public EagerNode {
 
   AsyncEagerNode* AsAsync() final { return this; }
 
-  // This is non-blocking. It returns the scheduling status.
-  // TODO(fishx): avoid calling this AsyncEagerNode::Run.
   Status Run() final {
-    std::shared_ptr<Status> status(new Status);
-    RunAsync([status](const Status& s) { status->Update(s); });
-    return *status;
+    return errors::Unimplemented("Don't call AsyncEagerNode::Run().");
   }
 };
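The stubbed-out Run() above makes sense because the executor now drives async nodes through RunAsync and captures the error in the completion callback (benefit 2 in the commit message). A toy sketch of that shape, with illustrative names only:

#include <functional>
#include <string>

struct MiniStatus { bool ok = true; std::string msg; };
using DoneCallback = std::function<void(const MiniStatus&)>;

struct MiniAsyncNode {
  // The node reports completion (and any error) through `done` instead of
  // returning a status from a blocking Run().
  void RunAsync(DoneCallback done) {
    done(MiniStatus{false, "async op failed"});  // fail immediately for demo
  }
};

// Executor-side scheduling: the callback is where the error is captured,
// much like RunItem handing the status to NodeDone in the diff above.
inline void Schedule(MiniAsyncNode* node, MiniStatus* executor_status) {
  node->RunAsync([executor_status](const MiniStatus& s) {
    if (!s.ok) *executor_status = s;  // plays the role of status_
  });
}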
@@ -105,10 +103,11 @@ class EagerExecutor {
 
   bool Async() const;
 
-  // Schedules `node` for execution. If an error occurs (e.g. EagerExecutor
-  // has already been shut down), the `node` is not added to this executor
-  // and its Abort() method is called.
-  Status Add(std::unique_ptr<EagerNode> node);
+  // - Async Mode: schedules `node` for execution.
+  // - Sync Mode: inline execute the 'node' directly.
+  // If an error occurs (e.g. EagerExecutor has already been shut down), the
+  // `node` is not added to this executor and its Abort() method is called.
+  Status AddOrExecute(std::unique_ptr<EagerNode> node);
 
   // Blocks till all currently pending ops are done.
   // In particular, if EnableAsync() has not beed called, it will not return
@@ -139,16 +138,23 @@ class EagerExecutor {
     kShutDown,
   };
 
-  struct NodeItem {
+  enum class NodeState {
+    kPENDING,
+    kSCHEDULED,
+    kDONE,
+  };
+
+  struct NodeItem : core::RefCounted {
     // Unique id generated in EagerExecutor::Add(). If item1.id < item2.id, it
     // means item1.node is added before item2.node.
     uint64 id;
     std::unique_ptr<EagerNode> node;
+    NodeState state;
   };
 
   const char* StateStringLocked() EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
-  void NodeDone(NodeItem* item, const Status& status);
+  void NodeDone(core::RefCountPtr<NodeItem> item, const Status& status);
 
   // Starts execution of pending EagerNodes. This function loops till
   // thread_done_ is set to true. If any errors are encontered, these are set
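Making NodeItem reference-counted is what lets the queue, the unfinished-nodes map, and an in-flight completion callback share one item safely. The sketch below imitates tensorflow::core::RefCounted / core::RefCountPtr with toy stand-ins (single-threaded, non-atomic count, illustrative names only):

#include <memory>

class ToyRefCounted {
 public:
  void Ref() { ++count_; }
  void Unref() {
    if (--count_ == 0) delete this;
  }

 protected:
  virtual ~ToyRefCounted() = default;

 private:
  int count_ = 1;  // the creator holds the first reference
};

// A unique_ptr whose "deleter" is Unref(), like core::RefCountPtr.
struct Unrefer {
  void operator()(ToyRefCounted* p) const {
    if (p != nullptr) p->Unref();
  }
};
template <typename T>
using ToyRefCountPtr = std::unique_ptr<T, Unrefer>;

struct ToyItem : ToyRefCounted {
  int id = 0;
};

void Demo() {
  ToyRefCountPtr<ToyItem> owner(new ToyItem);          // refcount == 1
  owner->Ref();                                        // extra ref for a callback
  ToyRefCountPtr<ToyItem> callback_ref(owner.get());   // adopts, refcount == 2
  // Either holder may be destroyed first; the last Unref deletes the item,
  // mirroring how RunItem refs an item before handing it to NodeDone.
}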
@@ -156,20 +162,13 @@ class EagerExecutor {
   // `status_` is not ok.
   void Run();
 
+  void RunItem(core::RefCountPtr<NodeItem> item);
+
   // The impl of WaitForAllPendingNodes
   // `lock` is the lock that holds node_queue_mutex_.
   Status WaitForAllPendingNodesLocked(mutex_lock* lock)
       EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
 
-  // If async has been enabled on this executor, just calls
-  // WaitForAllPendingNodes. Else sets the status_ to an error if it does not
-  // already contain one `lock` is the lock that holds node_queue_mutex_.
-  // Precondition: state_ != kActive.
-  void WaitForOrDestroyAllPendingNodes(
-      mutex_lock* lock,
-      std::vector<std::unique_ptr<NodeItem>>* nodes_to_destroy)
-      EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_);
-
   Status WaitImpl(bool wait_all, uint64 node_id);
 
   std::atomic<uint64> next_node_id_;
@@ -180,12 +179,12 @@ class EagerExecutor {
   condition_variable nodes_pending_ GUARDED_BY(node_queue_mutex_);
 
   // Queue of pending NodeItems. Ordered by NodeItem::id.
-  std::queue<std::unique_ptr<NodeItem>> node_queue_
+  std::queue<core::RefCountPtr<NodeItem>> node_queue_
       GUARDED_BY(node_queue_mutex_);
 
-  // Owned the NodeItem in it. Ordered by NodeItem::id.
-  std::map<uint64, NodeItem*, std::less<uint64>> unfinished_nodes_
-      GUARDED_BY(node_queue_mutex_);
+  // Ordered by NodeItem::id.
+  std::map<uint64, core::RefCountPtr<NodeItem>, std::less<uint64>>
+      unfinished_nodes_ GUARDED_BY(node_queue_mutex_);
 
   // `status_` is set based on any errors raised during execution of a
   // EagerNode. It remains set until ClearError is called.
@@ -362,6 +362,7 @@ Status GetDeviceForInput(const EagerContext* ctx, TensorHandle* tensor_handle,
     // Use the resource's actual device because it is the device that will
     // influence partitioning the multi-device function.
     const Tensor* tensor;
+    // TODO(fishx): Avoid blocking here.
     TF_RETURN_IF_ERROR(tensor_handle->Tensor(&tensor));
     const ResourceHandle& handle = tensor->flat<ResourceHandle>()(0);
     device_name = handle.device();
@@ -639,7 +640,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
   // input handles are ready before executing them.
   // TODO(b/137118203): Consider executing "cheap" kernels inline for
   // performance.
-  Status s = executor.Async() ? executor.Add(std::move(node)) : node->Run();
+  Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -765,14 +766,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   }
 
   auto& executor = op->Executor();
-  bool is_async = executor.Async();
   VLOG(4) << "Execute remote eager op: " << op->Name()
-          << " (is async?: " << is_async << ").";
+          << " (is async?: " << executor.Async() << ").";
 
   std::unique_ptr<EagerNode> node(
       new eager::RemoteExecuteNode(std::move(request), op_device, eager_client,
                                    op->Inputs(), {retvals, num_outputs}));
-  Status s = is_async ? executor.Add(std::move(node)) : node->Run();
+  Status s = executor.AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -904,6 +904,12 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
 
   bool op_is_local = op->EagerContext()->IsLocalDeviceName(op->GetDeviceName());
 
+  if (!op->Executor().Async()) {
+    // In sync mode, always clear error to maintain the same behavior as before.
+    // TODO(b/141004939): Remove this.
+    op->Executor().ClearError();
+  }
+
   std::unique_ptr<tensorflow::EagerOperation> out_op;
   TF_RETURN_IF_ERROR(EagerOpRewriteRegistry::Global()->RunRewrite(
       EagerOpRewriteRegistry::PRE_EXECUTION, op, &out_op));
@@ -1044,7 +1050,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   // Note that `h` may not be currently ready. However execution order will
   // make sure that `h` is ready before the copy is actually done.
   std::unique_ptr<EagerNode> node(new CopyToDeviceNode(h, *result, dstd, ctx));
-  Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
+  Status s = executor->AddOrExecute(std::move(node));
   // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
@@ -1065,6 +1071,12 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
 
   bool recver_is_local = device->IsLocal();
 
+  if (!executor->Async()) {
+    // In sync mode, always clear error to maintain the same behavior as before.
+    // TODO(b/141004939): Remove this.
+    executor->ClearError();
+  }
+
   if (sender_is_local && recver_is_local) {
     return LocalEagerCopyToDevice(h, ctx, executor, device, result);
   } else {
@@ -1105,7 +1117,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
   }
   auto node = absl::make_unique<eager::RemoteCopyNode>(
       ctx, executor, h, result[0], device, recv_op_id);
-  Status s = executor->Async() ? executor->Add(std::move(node)) : node->Run();
+  Status s = executor->AddOrExecute(std::move(node));
   if (!s.ok()) {
     result[0]->Unref();
   }
@@ -296,8 +296,7 @@ Status EagerServiceImpl::Enqueue(const EnqueueRequest* request,
                                      item.handle_to_decref());
     auto node = absl::make_unique<ClientTensorHandleDeleteNode>(
         context, std::move(handle_to_decref));
-    s = executor.Async() ? context->Context()->Executor().Add(std::move(node))
-                         : node->Run();
+    s = context->Context()->Executor().AddOrExecute(std::move(node));
   } else {
     s = SendTensor(item.send_tensor(), context->Context());
   }
@@ -47,7 +47,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
                                             eager_client, ready));
   auto& executor = ctx->Executor();
   if (executor.Async()) {
-    Status status = executor.Add(std::move(node));
+    Status status = executor.AddOrExecute(std::move(node));
     if (!status.ok()) {
       LOG(ERROR) << "Unable to destroy remote tensor handles: "
                  << status.error_message();
@@ -56,13 +56,13 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
     // This thread may still hold tensorflow::StreamingRPCState::mu_. We need
     // to send out the destroy request in a new thread to avoid deadlock.
     auto* released_node = node.release();
-    (*ctx->runner())([released_node] {
-      Status status = released_node->Run();
+    (*ctx->runner())([ctx, released_node] {
+      Status status =
+          ctx->Executor().AddOrExecute(absl::WrapUnique(released_node));
       if (!status.ok()) {
         LOG(ERROR) << "Unable to destroy remote tensor handles: "
                    << status.error_message();
       }
-      delete released_node;
     });
   }
 }
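One wrinkle in the hunk above: because the calling thread may still hold a lock the destroy RPC needs (the StreamingRPCState mutex mentioned in the comment), the node is handed off to ctx->runner() and submitted to the executor from there. The toy below shows the same shape with a plain detached thread standing in for the runner; all names are illustrative, not TF APIs:

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <thread>

struct ToyStatus { bool ok = true; std::string msg; };

struct ToyNode {
  ToyStatus Run() { return ToyStatus{}; }  // pretend this issues the RPC
};

// Stand-in for (*ctx->runner())(fn): run the closure off the current thread so
// we never re-enter a lock the caller is still holding.
void RunLater(std::function<void()> fn) {
  std::thread(std::move(fn)).detach();
}

void DestroyHandleLater(std::unique_ptr<ToyNode> node) {
  // Release to a raw pointer so the closure can own it, as the diff does with
  // node.release() plus absl::WrapUnique on the other side.
  ToyNode* released = node.release();
  RunLater([released] {
    std::unique_ptr<ToyNode> owned(released);  // re-take ownership
    ToyStatus s = owned->Run();
    if (!s.ok) std::cerr << "destroy failed: " << s.msg << "\n";
  });
}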
@@ -523,7 +523,7 @@ class CsvDatasetTest(test_base.DatasetTestBase):
     if context.executing_eagerly():
       err_spec = errors.InvalidArgumentError, (
           'Each record default should be at '
-          'most rank 1.')
+          'most rank 1')
     else:
       err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
 
@@ -73,7 +73,7 @@ class DecodeCSVOpTest(test.TestCase):
     if context.executing_eagerly():
       err_spec = errors.InvalidArgumentError, (
           "Each record default should be at "
-          "most rank 1.")
+          "most rank 1")
     else:
       err_spec = ValueError, "Shape must be at most rank 1 but is rank 2"
     with self.assertRaisesWithPredicateMatch(*err_spec):
@@ -107,9 +107,10 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
     handle = resource_variable_ops.var_handle_op(
         dtype=dtypes.int32, shape=[1], name="foo")
     resource_variable_ops.assign_variable_op(handle, 1)
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Trying to read variable with wrong dtype. "
-                                 "Expected float got int32."):
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError,
+        "Trying to read variable with wrong dtype. "
+        "Expected float got int32"):
       _ = resource_variable_ops.read_variable_op(handle, dtype=dtypes.float32)
 
   def testEagerInitializedValue(self):
@@ -195,9 +196,9 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
         dtype=dtypes.int32, shape=[1], name="foo")
     resource_variable_ops.assign_variable_op(
         handle, constant_op.constant([1]))
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Trying to assign variable with wrong "
-                                 "dtype. Expected int32 got float."):
+    with self.assertRaisesRegexp(
+        errors.InvalidArgumentError, "Trying to assign variable with wrong "
+        "dtype. Expected int32 got float"):
       resource_variable_ops.assign_variable_op(
           handle, constant_op.constant([1.], dtype=dtypes.float32))