Replace Notification with simple mutex
This avoids the overheads of the Notification object such as mutex acquisition during destruction. PiperOrigin-RevId: 283445581 Change-Id: Ic30ea13186096c23ec775eac13412c6ffe6c9a0a
This commit is contained in:
parent
513f16d55d
commit
250d9bc96b
@ -132,10 +132,9 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
is_async_(false),
|
||||
is_ready_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
@ -152,11 +151,10 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
is_async_(false),
|
||||
is_ready_(true),
|
||||
handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
|
||||
@ -185,12 +183,10 @@ TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(false),
|
||||
is_async_(async),
|
||||
is_ready_(!async),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Async Local TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
if (!async) {
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
@ -227,11 +223,10 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(true),
|
||||
is_async_(false),
|
||||
is_ready_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Remote TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
// Notify immediately since this handle is already ready.
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateUnshapedRemoteHandle(
|
||||
@ -264,21 +259,29 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
|
||||
ctx_(ctx),
|
||||
is_remote_(true),
|
||||
is_async_(true),
|
||||
is_ready_(false),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool TensorHandle::IsReady() {
|
||||
return !is_async_ || is_ready_notification_.HasBeenNotified();
|
||||
bool TensorHandle::IsReady() const {
|
||||
// Avoid mutex acquisition for local sync handles
|
||||
if (!is_async_ && !is_remote_) {
|
||||
return true;
|
||||
}
|
||||
|
||||
tf_shared_lock l(mu_);
|
||||
return is_ready_;
|
||||
}
|
||||
|
||||
Status TensorHandle::WaitReady(const char* caller) {
|
||||
if (!IsReady()) {
|
||||
profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
is_ready_notification_.WaitForNotification();
|
||||
tf_shared_lock l(mu_);
|
||||
mu_.Await(Condition(&is_ready_));
|
||||
}
|
||||
return is_poisoned_;
|
||||
}
|
||||
@ -537,8 +540,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
|
||||
}
|
||||
|
||||
DCHECK(is_remote_) << "SeRemoteShape is only called on remote handles.";
|
||||
DCHECK(!is_ready_notification_.HasBeenNotified())
|
||||
<< "SetRemoteShape is only called on non-ready handles.";
|
||||
DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";
|
||||
|
||||
UnshapedRemoteTensorHandleData* p =
|
||||
reinterpret_cast<UnshapedRemoteTensorHandleData*>(
|
||||
@ -548,7 +550,8 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
|
||||
remote_op_id_, remote_output_num_, shape, remote_task_,
|
||||
remote_context_id_, ctx_);
|
||||
is_poisoned_ = Status::OK();
|
||||
is_ready_notification_.Notify();
|
||||
mutex_lock l(mu_);
|
||||
is_ready_ = true;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@ -556,7 +559,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
|
||||
|
||||
Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
|
||||
DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
|
||||
DCHECK(!is_async_ || !is_ready_notification_.HasBeenNotified())
|
||||
DCHECK(!is_async_ || !IsReady())
|
||||
<< "SetTensor is only called on non-ready handles.";
|
||||
|
||||
DVLOG(3) << "SetTensor on TensorHandle: " << this;
|
||||
@ -568,21 +571,22 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor) {
|
||||
tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
|
||||
if (is_async_) {
|
||||
is_poisoned_ = Status::OK();
|
||||
is_ready_notification_.Notify();
|
||||
mutex_lock l(mu_);
|
||||
is_ready_ = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void TensorHandle::Poison(Status status) {
|
||||
DCHECK(!is_async_ || !is_ready_notification_.HasBeenNotified())
|
||||
DCHECK(!is_async_ || !IsReady())
|
||||
<< "Poison(status) can only be called on non-ready handle: " << this;
|
||||
|
||||
DVLOG(3) << "Poison on TensorHandle: " << this;
|
||||
|
||||
is_poisoned_ = status;
|
||||
if (is_async_ || is_remote_) {
|
||||
is_ready_notification_.Notify();
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
is_ready_ = true;
|
||||
}
|
||||
|
||||
Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
|
||||
|
@ -167,8 +167,6 @@ class TensorHandle : public core::RefCounted {
|
||||
// on a non-ready tensor.
|
||||
void Poison(Status status);
|
||||
|
||||
bool IsReady();
|
||||
|
||||
Status CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
|
||||
tensorflow::Tensor* output);
|
||||
|
||||
@ -207,6 +205,12 @@ class TensorHandle : public core::RefCounted {
|
||||
std::vector<DtypeAndPartialTensorShape>* result);
|
||||
|
||||
private:
|
||||
// The TensorHandleData can either represent a local or remote tensor handle.
|
||||
// Further, it can be in a non-ready state. It would become ready with a call
|
||||
// to either SetTensor or SetRemoteShape which replaces the underlying data
|
||||
// with a ready version of the tensor handle data.
|
||||
bool IsReady() const;
|
||||
|
||||
// If the contents of the Tensor pointed to by this handle is yet to be
|
||||
// computed by a EagerNode, this function will block till that computation is
|
||||
// done and the handle is "ready".
|
||||
@ -232,9 +236,9 @@ class TensorHandle : public core::RefCounted {
|
||||
// backing the resource. Else resource_device_ is nullptr.
|
||||
tensorflow::Device* const resource_device_;
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
mutable mutex mu_;
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
// TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
|
||||
// variable is ready, since we could get the shape locally without remote copy
|
||||
// then.
|
||||
@ -263,25 +267,18 @@ class TensorHandle : public core::RefCounted {
|
||||
// `ctx` object is not owned and should outlive this handle.
|
||||
EagerContext* const ctx_;
|
||||
|
||||
// Explanation for NOLINT below: absl has clang-tidy macro to rename
|
||||
// 'tensorflow::Notification' to 'absl::Notification'. TF does not use
|
||||
// absl::Notification in open source now, so we can't follow clang-tidy
|
||||
tensorflow::Notification is_ready_notification_; // NOLINT
|
||||
// Does not need synchronization because it can be accessed only after
|
||||
// WaitReady() has returned. At that point, is_poisoned_ is immutable.
|
||||
Status is_poisoned_;
|
||||
const bool is_remote_;
|
||||
const bool is_async_;
|
||||
bool is_ready_ GUARDED_BY(mu_);
|
||||
|
||||
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
|
||||
// refers to a remote resource handle, we store data types and shapes for
|
||||
// the underlying resource.
|
||||
std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
|
||||
|
||||
// The TensorHandleData can either represent a local or remote tensor handle.
|
||||
// Further, it can be in a non-ready state. It would become ready with a call
|
||||
// to either SetTensor or SetRemoteShape which replaces the underlying data
|
||||
// with a ready version of the tensor handle data.
|
||||
// Does not need synchronization because it can be accessed only after
|
||||
// WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
|
||||
std::unique_ptr<TensorHandleData> tensor_handle_data_;
|
||||
|
@ -39,7 +39,6 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
|
||||
.ok());
|
||||
|
||||
EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());
|
||||
EXPECT_FALSE(async_th->IsReady());
|
||||
|
||||
TensorShape sync_shape;
|
||||
TensorShape async_shape;
|
||||
|
Loading…
x
Reference in New Issue
Block a user