Make TensorHandle wait on SetTensor/SetRemoteShape instead of on eager executor
Before this change, async TensorHandles were tightly coupled with EagerExecutor. They knew the id of the node that would produce the Tensor, called executor->WaitForNode(), and relied on the fact that once WaitForNode returns, SetTensor would have been called. As we are making RPCs async, TensorHandles can now be non-ready not just because of local async execution, but because of remote execution as well. We need a way to wait for remote tensor handles to be ready. Instead of introducing a new mechanism, distinct from executor->WaitForNode(), this change eliminates the dependence on the executor. We add a notification to async tensor handles. Instead of calling executor->WaitForNode, TensorHandle::WaitReady() now simply waits for this notification. The notification can be signaled either by calling SetTensor or SetRemoteShape. TensorHandle does not care about the details of the async mechanism that ends up calling SetTensor or SetRemoteShape. The dangerous aspect of this design is that SetTensor or SetRemoteShape must always be called, even when there is an error. Otherwise, we are likely to deadlock. However, this guarantee is required to support poisoning individual tensor handles instead of the whole executor, as it is done today. So, it is probably a worthwhile complexity to take on. This CL also adds a TensorHandle::Poison() method to mark the handle as poisoned. Later, when some operation tries to use poisoned handles, we will fail that operation. PiperOrigin-RevId: 253321251
This commit is contained in:
parent
0fb4381cfe
commit
81655b9ad3
@ -28,7 +28,7 @@ class CopyToDeviceNode : public EagerNode {
|
||||
CopyToDeviceNode(TensorHandle* src, Device* dstd, EagerContext* ctx)
|
||||
: EagerNode(ctx->NextId()), src_(src), dstd_(dstd), ctx_(ctx) {
|
||||
src_->Ref();
|
||||
status_ = TensorHandle::CreateAsyncLocalHandle(id, dstd_, dstd_, nullptr,
|
||||
status_ = TensorHandle::CreateAsyncLocalHandle(dstd_, dstd_, nullptr,
|
||||
src_->dtype, ctx, &dst_);
|
||||
if (status_.ok()) {
|
||||
dst_->Ref();
|
||||
@ -47,9 +47,13 @@ class CopyToDeviceNode : public EagerNode {
|
||||
return status_;
|
||||
}
|
||||
TensorHandle* temp = nullptr;
|
||||
TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &temp));
|
||||
Status status = src_->CopyToDevice(ctx_, dstd_, &temp);
|
||||
if (!status.ok()) {
|
||||
dst_->Poison(status);
|
||||
return status;
|
||||
}
|
||||
const Tensor* tensor = nullptr;
|
||||
Status status = temp->Tensor(&tensor);
|
||||
status = temp->Tensor(&tensor);
|
||||
// `temp` is a ready handle. So the following call should return OK.
|
||||
TF_DCHECK_OK(status) << status.error_message();
|
||||
DCHECK(tensor);
|
||||
@ -58,6 +62,8 @@ class CopyToDeviceNode : public EagerNode {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Called when this node will not run: poisons the destination handle `dst_`
// with `status` so that the handle becomes ready carrying that error.
// NOTE(review): dereferences dst_ unconditionally — presumably the
// constructor always succeeds in creating it; confirm.
void Abort(Status status) override { dst_->Poison(status); }
|
||||
|
||||
TensorHandle* dst() { return dst_; }
|
||||
|
||||
private:
|
||||
|
@ -133,11 +133,14 @@ void EagerExecutor::Run() {
|
||||
node_queue_.pop();
|
||||
if (!ok) {
|
||||
status_ = status;
|
||||
// TODO(agarwal): mark all affected handles as corrupted before clearing
|
||||
// this queue.
|
||||
// We remove any pending ops so that we don't try to execute them if
|
||||
// ClearError is called.
|
||||
errors::AppendToMessage(&status,
|
||||
". Encountered when executing an operation using "
|
||||
"EagerExecutor. This error cancels all future "
|
||||
"operations and poisons their output tensors.");
|
||||
for (int i = 0; i < node_queue_.size(); ++i) {
|
||||
node_queue_.front()->Abort(status);
|
||||
delete node_queue_.front();
|
||||
node_queue_.pop();
|
||||
}
|
||||
|
@ -50,6 +50,12 @@ class EagerNode {
|
||||
// execution is done.
|
||||
virtual Status Run() = 0;
|
||||
|
||||
// Called when this node will not be run due to some error contained in
|
||||
// `status`. `status` must not be OK.
|
||||
// For example, if the node would have computed some tensors in the Run(),
|
||||
// it should poison the corresponding tensor handles in this method.
|
||||
virtual void Abort(Status status) = 0;
|
||||
|
||||
// An id unique to the TFE_Context under which this node is created. Allocated
|
||||
// monotonically.
|
||||
const uint64 id;
|
||||
@ -78,6 +84,8 @@ class EagerExecutor {
|
||||
uint64 NextId();
|
||||
|
||||
// Schedules `node` for execution.
|
||||
// Takes ownership of `node`.
|
||||
// TODO(iga): take a unique_ptr instead.
|
||||
// Note that Add must be called in monotonically increasing order of node->id.
|
||||
void Add(EagerNode* node);
|
||||
|
||||
|
@ -641,7 +641,7 @@ Status EagerLocalExecute(EagerOperation* op,
|
||||
tensorflow::uint64 id = ctx->NextId();
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateAsyncLocalHandle(
|
||||
id, /* d= */ kernel->OutputDevice(i),
|
||||
/* d= */ kernel->OutputDevice(i),
|
||||
/* op_device= */ kernel->device(),
|
||||
/* resource_device= */ kernel->OutputResourceDevice(i),
|
||||
output_dtypes[i], ctx, &(*retvals)[i]));
|
||||
@ -817,8 +817,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
// to copy this tensor to this process, the remote end will know the
|
||||
// correct device of this handle.
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
|
||||
id, i, remote_node_id, eager_client, context_id, output_dtypes[i],
|
||||
op_device, output_dtypes[i] == DT_RESOURCE ? op_device : nullptr, ctx,
|
||||
id, i, eager_client, context_id, output_dtypes[i], op_device,
|
||||
output_dtypes[i] == DT_RESOURCE ? op_device : nullptr, ctx,
|
||||
&retvals[i]));
|
||||
}
|
||||
|
||||
@ -844,9 +844,13 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
[inputs](const gtl::InlinedVector<TensorHandle*, 2>& retvals,
|
||||
const Status& status,
|
||||
const eager::EnqueueResponse& response) {
|
||||
if (!status.ok()) return;
|
||||
for (int i = 0; i < retvals.size(); i++) {
|
||||
retvals[i]->SetRemoteShape(response.queue_response(0).shape(i));
|
||||
if (status.ok()) {
|
||||
retvals[i]->SetRemoteShape(
|
||||
response.queue_response(0).shape(i));
|
||||
} else {
|
||||
retvals[i]->Poison(status);
|
||||
}
|
||||
retvals[i]->Unref();
|
||||
}
|
||||
for (auto* handle : inputs) {
|
||||
|
@ -63,17 +63,28 @@ class ExecuteNode : public EagerNode {
|
||||
}
|
||||
}
|
||||
|
||||
tensorflow::Status Run() override {
|
||||
Status Run() override {
|
||||
const Status status = EagerKernelExecute(
|
||||
ctx_, inputs_, kernel_.get(), maybe_stats_.get(), maybe_step_stats_,
|
||||
graph_collector_, retvals_.begin(), retvals_.size());
|
||||
if (status.ok()) {
|
||||
// If status is ok, EagerKernelExecute would have called SetTensor on
|
||||
// all the output handles.
|
||||
return status;
|
||||
} else {
|
||||
return Status(status.code(),
|
||||
strings::StrCat("Got error, \"", status.error_message(),
|
||||
"\" while executing kernel ",
|
||||
kernel_->kernel()->def().DebugString()));
|
||||
Status s =
|
||||
Status(status.code(),
|
||||
strings::StrCat("Got error, \"", status.error_message(),
|
||||
"\" while executing kernel ",
|
||||
kernel_->kernel()->def().DebugString()));
|
||||
Abort(s);
|
||||
return s;
|
||||
}
|
||||
}
|
||||
|
||||
// Poisons every output handle in `retvals_` with `status`, which makes each
// handle ready while carrying that error.
// Run() also calls this itself when kernel execution fails, so the outputs
// are poisoned whether the node failed while running or was never run.
void Abort(Status status) override {
|
||||
for (auto handle : retvals_) {
|
||||
handle->Poison(status);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -138,9 +138,10 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
remote_output_num_(-1),
|
||||
#endif
|
||||
ctx_(ctx),
|
||||
is_ready_(true),
|
||||
is_remote_(false),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
@ -155,21 +156,20 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
remote_output_num_(-1),
|
||||
#endif
|
||||
ctx_(ctx),
|
||||
is_ready_(true),
|
||||
is_remote_(false),
|
||||
resource_handle_container_(resource_handle.container()),
|
||||
resource_handle_name_(resource_handle.name()),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateAsyncLocalHandle(uint64 node_id, Device* d,
|
||||
Device* op_device,
|
||||
Status TensorHandle::CreateAsyncLocalHandle(Device* d, Device* op_device,
|
||||
Device* resource_device,
|
||||
DataType dtype, EagerContext* ctx,
|
||||
TensorHandle** h) {
|
||||
*h = new TensorHandle(
|
||||
absl::make_unique<AsyncLocalTensorHandleData>(node_id, ctx), d, op_device,
|
||||
resource_device, dtype, ctx);
|
||||
*h = new TensorHandle(absl::make_unique<AsyncLocalTensorHandleData>(), d,
|
||||
op_device, resource_device, dtype, ctx);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -187,7 +187,6 @@ TensorHandle::TensorHandle(std::unique_ptr<AsyncLocalTensorHandleData> t,
|
||||
remote_output_num_(-1),
|
||||
#endif
|
||||
ctx_(ctx),
|
||||
is_ready_(false),
|
||||
is_remote_(false),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
}
|
||||
@ -203,7 +202,6 @@ Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
|
||||
absl::make_unique<RemoteTensorHandleData>(op_id, output_num, shape,
|
||||
eager_client, context_id, ctx),
|
||||
dtype, d, resource_device, ctx);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -217,22 +215,22 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
remote_op_id_(t->op_id()),
|
||||
remote_output_num_(t->output_num()),
|
||||
ctx_(ctx),
|
||||
is_ready_(true),
|
||||
is_remote_(true),
|
||||
tensor_handle_data_(std::move(t)) {}
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateUnshapedRemoteHandle(
|
||||
int64 op_id, int32 output_num, uint64 shape_node_id,
|
||||
eager::EagerClient* eager_client, uint64 context_id, DataType dtype,
|
||||
Device* d, Device* resource_device, EagerContext* ctx, TensorHandle** h) {
|
||||
int64 op_id, int32 output_num, eager::EagerClient* eager_client,
|
||||
uint64 context_id, DataType dtype, Device* d, Device* resource_device,
|
||||
EagerContext* ctx, TensorHandle** h) {
|
||||
DCHECK(dtype == DT_RESOURCE ? resource_device != nullptr
|
||||
: resource_device == nullptr);
|
||||
|
||||
*h = new TensorHandle(
|
||||
absl::make_unique<UnshapedRemoteTensorHandleData>(
|
||||
op_id, output_num, shape_node_id, eager_client, context_id, ctx),
|
||||
dtype, d, resource_device, ctx);
|
||||
|
||||
*h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>(
|
||||
op_id, output_num, eager_client, context_id, ctx),
|
||||
dtype, d, resource_device, ctx);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -248,7 +246,6 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
|
||||
remote_eager_client_(t->eager_client()),
|
||||
remote_context_id_(t->context_id()),
|
||||
ctx_(ctx),
|
||||
is_ready_(false),
|
||||
is_remote_(true),
|
||||
tensor_handle_data_(std::move(t)) {}
|
||||
#endif
|
||||
@ -263,59 +260,47 @@ TensorHandle::TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype)
|
||||
remote_output_num_(-1),
|
||||
#endif
|
||||
ctx_(nullptr),
|
||||
is_ready_(true),
|
||||
is_remote_(false),
|
||||
symbolic_tensor_(new OutputGraphNode(symbolic_tensor)) {
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
Status TensorHandle::WaitReady() {
|
||||
while (true) {
|
||||
{
|
||||
tf_shared_lock l(ready_mutex_);
|
||||
if (is_ready_) return Status::OK();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(tensor_handle_data_->WaitReady());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
is_ready_notification_.WaitForNotification();
|
||||
return is_poisoned_;
|
||||
}
|
||||
|
||||
// Retrieves the tensor backing this handle through `t`.
// First calls WaitReady() and returns its status unchanged if it is an
// error; otherwise forwards the request to the current tensor_handle_data_.
Status TensorHandle::Tensor(const tensorflow::Tensor** t) {
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
|
||||
return tensor_handle_data_->Tensor(t);
|
||||
}
|
||||
|
||||
// Fills `t` with the TensorValue for this handle.
// First calls WaitReady() and returns its status unchanged if it is an
// error; otherwise forwards the request to the current tensor_handle_data_.
Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
|
||||
return tensor_handle_data_->TensorValue(t);
|
||||
}
|
||||
|
||||
// Writes the shape of this handle's tensor into `shape`.
// First calls WaitReady() and returns its status unchanged if it is an
// error; otherwise forwards the request to the current tensor_handle_data_.
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
|
||||
return tensor_handle_data_->Shape(shape);
|
||||
}
|
||||
|
||||
// Writes the number of dimensions of this handle's tensor into `num_dims`,
// which must be non-null (checked in debug builds only).
// First calls WaitReady() and returns its status unchanged if it is an
// error; otherwise forwards the request to the current tensor_handle_data_.
Status TensorHandle::NumDims(int* num_dims) {
|
||||
DCHECK(num_dims != nullptr);
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
|
||||
return tensor_handle_data_->NumDims(num_dims);
|
||||
}
|
||||
|
||||
// Writes the size of dimension `dim_index` into `dim`, which must be non-null
// (checked in debug builds only).
// First calls WaitReady() and returns its status unchanged if it is an
// error; otherwise forwards the request to the current tensor_handle_data_.
Status TensorHandle::Dim(int dim_index, int64* dim) {
|
||||
DCHECK(dim != nullptr);
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
|
||||
return tensor_handle_data_->Dim(dim_index, dim);
|
||||
}
|
||||
|
||||
// Writes the element count of this handle's tensor into `num_elements`,
// which must be non-null (checked in debug builds only).
// First calls WaitReady() and returns its status unchanged if it is an
// error; otherwise forwards the request to the current tensor_handle_data_.
Status TensorHandle::NumElements(int64* num_elements) {
|
||||
DCHECK(num_elements != nullptr);
|
||||
TF_RETURN_IF_ERROR(WaitReady());
|
||||
|
||||
return tensor_handle_data_->NumElements(num_elements);
|
||||
}
|
||||
|
||||
@ -330,30 +315,41 @@ Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) const {
|
||||
*output_num = remote_output_num_;
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
void TensorHandle::SetTensor(const tensorflow::Tensor& tensor) {
|
||||
mutex_lock l(ready_mutex_);
|
||||
DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
|
||||
DCHECK(!is_ready_) << "SetTensor is only called on non-ready handles.";
|
||||
|
||||
tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
|
||||
is_ready_ = true;
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
void TensorHandle::SetRemoteShape(const TensorShape& shape) {
|
||||
mutex_lock l(ready_mutex_);
|
||||
DCHECK(is_remote_) << "SeRemoteShape is only called on remote handles.";
|
||||
DCHECK(!is_ready_) << "SetRemoteShape is only called on non-ready handles.";
|
||||
DCHECK(!is_ready_notification_.HasBeenNotified())
|
||||
<< "SetRemoteShape is only called on non-ready handles.";
|
||||
|
||||
UnshapedRemoteTensorHandleData* p =
|
||||
reinterpret_cast<UnshapedRemoteTensorHandleData*>(
|
||||
tensor_handle_data_.get());
|
||||
p->ReleaseRemoteTensorHandle();
|
||||
tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>(
|
||||
remote_op_id_, remote_output_num_, shape, remote_eager_client_,
|
||||
remote_context_id_, ctx_);
|
||||
is_ready_ = true;
|
||||
is_poisoned_ = Status::OK();
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Makes a non-ready local handle ready by installing `tensor` as its data.
// Debug builds check that the handle is not remote and has not already been
// notified. The backing data is replaced with a LocalTensorHandleData and
// the poison status is set to OK before the notification fires, so both are
// in place by the time a WaitReady() caller is released.
void TensorHandle::SetTensor(const tensorflow::Tensor& tensor) {
|
||||
DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
|
||||
DCHECK(!is_ready_notification_.HasBeenNotified())
|
||||
<< "SetTensor is only called on non-ready handles.";
|
||||
|
||||
tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
|
||||
is_poisoned_ = Status::OK();
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
// Marks a non-ready handle as failed with `status` and makes it ready.
// The error is stored in is_poisoned_ before the notification fires, so a
// caller released from WaitReady() receives that error as its return value.
// Debug builds check that the handle has not already been notified.
void TensorHandle::Poison(Status status) {
|
||||
DCHECK(!is_ready_notification_.HasBeenNotified())
|
||||
<< "Poison(status) can only be called on non-ready handles.";
|
||||
is_poisoned_ = status;
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
|
||||
TensorHandle** output) {
|
||||
const tensorflow::Tensor* src = nullptr;
|
||||
|
@ -47,6 +47,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/gtl/stl_util.h"
|
||||
#include "tensorflow/core/platform/fingerprint.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/public/session_options.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
@ -94,8 +95,7 @@ class TensorHandle : public core::RefCounted {
|
||||
static Status CreateLocalHandle(const class Tensor& t, Device* d,
|
||||
Device* op_device, EagerContext* ctx,
|
||||
TensorHandle** h);
|
||||
static Status CreateAsyncLocalHandle(uint64 node_id, Device* d,
|
||||
Device* op_device,
|
||||
static Status CreateAsyncLocalHandle(Device* d, Device* op_device,
|
||||
Device* resource_device, DataType dtype,
|
||||
EagerContext* ctx, TensorHandle** h);
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
@ -106,7 +106,6 @@ class TensorHandle : public core::RefCounted {
|
||||
Device* resource_device, EagerContext* ctx,
|
||||
TensorHandle** h);
|
||||
static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
|
||||
uint64 shape_node_id,
|
||||
eager::EagerClient* eager_client,
|
||||
uint64 context_id, DataType dtype,
|
||||
Device* d, Device* resource_device,
|
||||
@ -117,7 +116,7 @@ class TensorHandle : public core::RefCounted {
|
||||
TensorHandle(OutputGraphNode symbolic_tensor, DataType dtype);
|
||||
|
||||
~TensorHandle() override {
|
||||
VLOG(1) << "Deleting internal TensorHandle " << this;
|
||||
VLOG(3) << "Deleting internal TensorHandle " << this;
|
||||
}
|
||||
|
||||
Status Tensor(const tensorflow::Tensor** t);
|
||||
@ -142,20 +141,28 @@ class TensorHandle : public core::RefCounted {
|
||||
// transitions the tensor handle from a non-ready to a ready state by
|
||||
// replacing the backing data abstraction to allow for the shape to be
|
||||
// queried.
|
||||
// This method or Poison must be called exactly once for remote tensors that
|
||||
// were created without a known shape.
|
||||
void SetRemoteShape(const TensorShape& shape);
|
||||
#endif
|
||||
|
||||
// Note that this can be called at most once, and only on non-ready handles,
|
||||
// and makes them ready.
|
||||
// Sets the `tensor` for this async non-ready handle making it ready.
|
||||
// This method or Poison must be called exactly once for non-ready async
|
||||
// handles to make them ready.
|
||||
void SetTensor(const tensorflow::Tensor& tensor);
|
||||
|
||||
// Poisons this non-ready handle with an error `status`.
|
||||
// Poisoning means that the handle will become ready and methods trying
|
||||
// to access the actual tensor or shape will return this error `status`.
|
||||
// Exactly one of SetTensor, SetRemoteShape, or Poison methods must be called
|
||||
// on a non-ready tensor.
|
||||
void Poison(Status status);
|
||||
|
||||
Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
|
||||
TensorHandle** output);
|
||||
|
||||
// Warning: can return nullptr for CPU tensors.
|
||||
EagerContext* Context() {
|
||||
return ctx_;
|
||||
}
|
||||
EagerContext* Context() { return ctx_; }
|
||||
|
||||
// dtype for the handle. It must be the same as t.dtype() once the handle is
|
||||
// ready.
|
||||
@ -217,19 +224,20 @@ class TensorHandle : public core::RefCounted {
|
||||
// `ctx` object is not owned and should outlive this handle.
|
||||
EagerContext* const ctx_;
|
||||
|
||||
bool is_ready_ GUARDED_BY(ready_mutex_);
|
||||
bool is_remote_;
|
||||
// Explanation for NOLINT below: absl has clang-tidy macro to rename
|
||||
// 'tensorflow::Notification' to 'absl::Notification'. TF does not use
|
||||
// absl::Notification in open source now, so we can't follow clang-tidy
|
||||
tensorflow::Notification is_ready_notification_; // NOLINT
|
||||
// Does not need synchronization because it can be accessed only after
|
||||
// WaitReady() has returned. At that point, is_poisoned_ is immutable.
|
||||
Status is_poisoned_;
|
||||
const bool is_remote_;
|
||||
|
||||
// When non-NULL, this tensor handle instance represents a symbolic tensor
|
||||
// (corresponding to a graph node), whose concrete value is to be produced by
|
||||
// executing that graph node.
|
||||
std::unique_ptr<OutputGraphNode> symbolic_tensor_;
|
||||
|
||||
// A TensorHandle may be in a non-ready state because it is being backed by
|
||||
// an async node. We need this mutex to allow clients to block on the
|
||||
// TensorHandle until it is ready.
|
||||
mutable mutex ready_mutex_;
|
||||
|
||||
// If this TensorHandle is 1) a local tensor, and 2) a resource handle, we
|
||||
// store the container and name to be able to get the data type and shape
|
||||
// in a call to GetResourceVariableDtypeAndShape.
|
||||
@ -240,6 +248,8 @@ class TensorHandle : public core::RefCounted {
|
||||
// Further, it can be in a non-ready state. It would become ready with a call
|
||||
// to either SetTensor or SetRemoteShape which replaces the underlying data
|
||||
// with a ready version of the tensor handle data.
|
||||
// Does not need synchronization because it can be accessed only after
|
||||
// WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
|
||||
std::unique_ptr<TensorHandleData> tensor_handle_data_;
|
||||
};
|
||||
|
||||
|
@ -58,12 +58,6 @@ Status LocalTensorHandleData::NumElements(int64* num_elements) const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AsyncLocalTensorHandleData::AsyncLocalTensorHandleData(uint64 node_id,
|
||||
EagerContext* ctx)
|
||||
: node_id_(node_id), ctx_(ctx) {
|
||||
DCHECK_GT(node_id_, 0);
|
||||
}
|
||||
|
||||
Status AsyncLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
|
||||
return errors::Unavailable(
|
||||
"Unable to get a tensor for an async handle. "
|
||||
@ -100,16 +94,8 @@ Status AsyncLocalTensorHandleData::NumElements(int64* num_elements) const {
|
||||
"Please wait until it is ready");
|
||||
}
|
||||
|
||||
Status AsyncLocalTensorHandleData::WaitReady() {
|
||||
EagerExecutor* executor = nullptr;
|
||||
executor = ctx_->Executor();
|
||||
TF_RETURN_IF_ERROR(executor->WaitFor(node_id_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string AsyncLocalTensorHandleData::DebugString() const {
|
||||
return strings::StrCat("AsyncLocalTensorHandleData:", " node_id: ", node_id_);
|
||||
return "AsyncLocalTensorHandleData";
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -35,8 +35,6 @@ class TensorHandleData {
|
||||
virtual Status Dim(int dim_index, int64* dim) const = 0;
|
||||
virtual Status NumElements(int64* num_elements) const = 0;
|
||||
|
||||
virtual Status WaitReady() = 0;
|
||||
|
||||
virtual string DebugString() const = 0;
|
||||
};
|
||||
|
||||
@ -54,8 +52,6 @@ class LocalTensorHandleData : public TensorHandleData {
|
||||
Status Dim(int dim_index, int64* dim) const override;
|
||||
Status NumElements(int64* num_elements) const override;
|
||||
|
||||
Status WaitReady() override { return Status::OK(); };
|
||||
|
||||
string DebugString() const override { return tensor_.DebugString(); }
|
||||
|
||||
private:
|
||||
@ -67,7 +63,7 @@ class LocalTensorHandleData : public TensorHandleData {
|
||||
// tensor handle.
|
||||
class AsyncLocalTensorHandleData : public TensorHandleData {
|
||||
public:
|
||||
AsyncLocalTensorHandleData(uint64 node_id, EagerContext* ctx);
|
||||
AsyncLocalTensorHandleData() {}
|
||||
~AsyncLocalTensorHandleData() override {}
|
||||
|
||||
// Async tensor handles are not ready and hence cannot satisfy any of these
|
||||
@ -79,17 +75,7 @@ class AsyncLocalTensorHandleData : public TensorHandleData {
|
||||
Status Dim(int dim_index, int64* dim) const override;
|
||||
Status NumElements(int64* num_elements) const override;
|
||||
|
||||
// If the contents of the Tensor pointed to by this handle is yet to be
|
||||
// computed by an EagerNode, this function will block till that computation is
|
||||
// done and the handle is ready.
|
||||
Status WaitReady() override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
private:
|
||||
// Id for the EagerNode that will compute the value pointed to by this handle.
|
||||
const uint64 node_id_;
|
||||
EagerContext* const ctx_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -49,6 +49,7 @@ cc_library(
|
||||
"eager_service_impl.h",
|
||||
],
|
||||
deps = [
|
||||
":remote_tensor_handle",
|
||||
"//tensorflow:grpc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
@ -68,7 +69,6 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_wrapper",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
|
||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"@com_google_absl//absl/memory",
|
||||
|
@ -23,8 +23,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
|
||||
// EnqueueNode is an implementation of EagerNode which enqueues an operation
|
||||
// via RPC in a remote EagerService.
|
||||
// RemoteExecuteNode is an implementation of EagerNode which enqueues an
|
||||
// operation via RPC in a remote EagerService.
|
||||
class RemoteExecuteNode : public tensorflow::EagerNode {
|
||||
public:
|
||||
RemoteExecuteNode(
|
||||
@ -61,6 +61,7 @@ class RemoteExecuteNode : public tensorflow::EagerNode {
|
||||
|
||||
return status;
|
||||
}
|
||||
// No-op: this node does nothing with `status` when it is aborted.
// NOTE(review): the EagerNode::Abort contract asks a node to poison the
// tensor handles it would have produced in Run(). Confirm that the remote
// output handles are poisoned on some other path when this node never runs;
// otherwise callers waiting on those handles may block forever.
void Abort(Status status) override {}
|
||||
|
||||
private:
|
||||
std::unique_ptr<EnqueueRequest> request_;
|
||||
|
@ -126,15 +126,14 @@ string RemoteTensorHandleData::DebugString() const {
|
||||
}
|
||||
|
||||
UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
|
||||
int64 op_id, int32 output_num, uint64 shape_node_id,
|
||||
eager::EagerClient* eager_client, uint64 context_id, EagerContext* ctx)
|
||||
int64 op_id, int32 output_num, eager::EagerClient* eager_client,
|
||||
uint64 context_id, EagerContext* ctx)
|
||||
: op_id_(op_id),
|
||||
output_num_(output_num),
|
||||
shape_node_id_(shape_node_id),
|
||||
delete_remote_tensor_(true),
|
||||
eager_client_(eager_client),
|
||||
context_id_(context_id),
|
||||
ctx_(ctx) {
|
||||
DCHECK(shape_node_id > 0) << "Must provide a valid shape_node_id";
|
||||
DCHECK(op_id_ >= 0 && output_num_ >= 0)
|
||||
<< "Op ID and output num should be >= 0. Op ID: " << op_id
|
||||
<< ", Output num: " << output_num;
|
||||
@ -142,11 +141,7 @@ UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
|
||||
}
|
||||
|
||||
UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
|
||||
// Only if the ExecuteNode is still pending should we destroy the remote
|
||||
// tensor handle. Otherwise, we expect SetRemoteShape to have caused the
|
||||
// TensorHandle to point to a RemoteTensorHandleData.
|
||||
EagerExecutor* executor = ctx_->Executor();
|
||||
if (executor->IsQueued(shape_node_id_)) {
|
||||
if (delete_remote_tensor_) {
|
||||
DestoryRemoteTensorHandle(ctx_, eager_client_, context_id_, op_id_,
|
||||
output_num_);
|
||||
}
|
||||
@ -189,14 +184,9 @@ Status UnshapedRemoteTensorHandleData::NumElements(int64* num_elements) const {
|
||||
"until it is ready");
|
||||
}
|
||||
|
||||
Status UnshapedRemoteTensorHandleData::WaitReady() {
|
||||
return ctx_->Executor()->WaitFor(shape_node_id_);
|
||||
}
|
||||
|
||||
string UnshapedRemoteTensorHandleData::DebugString() const {
|
||||
return strings::StrCat("UnshapedRemoteTensorHandleDat:", " op_id: ", op_id_,
|
||||
" output_num: ", output_num_,
|
||||
" shape node_id: ", shape_node_id_);
|
||||
" output_num: ", output_num_);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -38,8 +38,6 @@ class RemoteTensorHandleData : public TensorHandleData {
|
||||
Status Dim(int dim_index, int64* dim) const override;
|
||||
Status NumElements(int64* num_elements) const override;
|
||||
|
||||
Status WaitReady() override { return Status::OK(); }
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
int64 op_id() const { return op_id_; }
|
||||
@ -60,7 +58,6 @@ class RemoteTensorHandleData : public TensorHandleData {
|
||||
class UnshapedRemoteTensorHandleData : public TensorHandleData {
|
||||
public:
|
||||
UnshapedRemoteTensorHandleData(int64 op_id, int32 output_num,
|
||||
uint64 shape_node_id,
|
||||
eager::EagerClient* eager_client,
|
||||
uint64 context_id, EagerContext* ctx);
|
||||
~UnshapedRemoteTensorHandleData() override;
|
||||
@ -74,11 +71,6 @@ class UnshapedRemoteTensorHandleData : public TensorHandleData {
|
||||
Status Dim(int dim_index, int64* dim) const override;
|
||||
Status NumElements(int64* num_elements) const override;
|
||||
|
||||
// If the remote TensorShape for this handle is yet to be computed by an
|
||||
// EagerNode, this function will block till that computation is done and the
|
||||
// handle is ready.
|
||||
Status WaitReady() override;
|
||||
|
||||
string DebugString() const override;
|
||||
|
||||
int64 op_id() const { return op_id_; }
|
||||
@ -86,11 +78,20 @@ class UnshapedRemoteTensorHandleData : public TensorHandleData {
|
||||
eager::EagerClient* eager_client() const { return eager_client_; }
|
||||
uint64 context_id() const { return context_id_; }
|
||||
|
||||
// When constructed, UnshapedRemoteTensorHandleData owns the remote
|
||||
// TensorHandle and should delete it by issuing an RPC. Once the remote
|
||||
// shape has been learned, the ownership is transferred to
|
||||
// RemoteTensorHandleData. This method must be called to let `this` know
|
||||
// that it no longer owns the remote handle.
|
||||
// TODO(iga): Add a factory method here that will create a new
|
||||
// RemoteTensorHandleData from this and transfer ownership in the process.
|
||||
void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; }
|
||||
|
||||
private:
|
||||
// IDs required when this class is representing a remote tensor handle.
|
||||
const int64 op_id_;
|
||||
const int32 output_num_;
|
||||
const uint64 shape_node_id_;
|
||||
bool delete_remote_tensor_;
|
||||
eager::EagerClient* eager_client_;
|
||||
uint64 context_id_;
|
||||
EagerContext* const ctx_;
|
||||
|
Loading…
Reference in New Issue
Block a user