diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index b045ed5b701..65f37f3021f 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -1213,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( tensorflow::TensorHandle* ret_handle; if (custom_device == nullptr) { status->status = tensorflow::TensorHandle::CreateLocalHandle( - t, device, context, &ret_handle); + std::move(t), device, device, context, &ret_handle); } else { status->status = tensorflow::TensorHandle::CreateLocalHandle( - t, custom_device, context, &ret_handle); + std::move(t), custom_device, context, &ret_handle); } if (!status->status.ok()) { return nullptr; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 4ec077e556e..6d7b00fa64e 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -140,6 +140,7 @@ tf_cuda_library( "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ + "@com_google_absl//absl/types:variant", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:traceme", diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index eaf88e4797f..01815b4dee2 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -564,8 +564,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, } } } - const DataTypeVector& output_dtypes = kernel->output_dtypes(); - const size_t num_outputs = static_cast<int>(output_dtypes.size()); + int num_outputs = kernel->num_outputs(); if (num_outputs > *num_retvals) { return errors::InvalidArgument("Expecting ", num_outputs, " outputs, but *num_retvals is ", @@ -579,21 +578,19 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, graph_collector = ctx.GetGraphCollector(); } - const bool async = executor.Async(); - for (int i = 0; i < num_outputs; ++i) { - TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( - async, - /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)), - /* op_device= */ kernel->device(), - /* resource_device= */ kernel->OutputResourceDevice(i), - output_dtypes[i], &ctx, &retvals[i])); - } - Status s; - if (async) { + if (executor.Async()) { + const DataTypeVector& output_dtypes = kernel->output_dtypes(); + for (int i = 0; i < num_outputs; ++i) { + TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( + /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)), + /* op_device= */ kernel->device(), + /* resource_device= */ kernel->OutputResourceDevice(i), + output_dtypes[i], &ctx, &retvals[i])); + } auto node = absl::make_unique<AsyncExecuteNode>( &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel), - graph_collector, output_dtypes, op->GetCancellationManager(), + graph_collector, op->GetCancellationManager(), absl::Span<TensorHandle*>(retvals, num_outputs)); // For async mode, execution order will make sure that all // input handles are ready before executing them. // TODO(b/137118203): Consider executing "cheap" kernels inline for // performance.
s = executor.AddOrExecute(std::move(node)); } else { + for (int i = 0; i < num_outputs; ++i) { + retvals[i] = nullptr; + } ExecuteNode node(&ctx, op->Inputs(), op->remote_func_params(), kernel, - graph_collector, output_dtypes, - op->GetCancellationManager(), {retvals, num_outputs}); + graph_collector, op->GetCancellationManager(), + {retvals, static_cast<size_t>(num_outputs)}); s = executor.SyncExecute(&node); } - // Since the operation failed, we need to Unref any outputs that were + // Since the operation failed, we need to Unref any outputs if they were // allocated. if (!s.ok()) { for (int i = 0; i < num_outputs; ++i) { - retvals[i]->Unref(); + if (retvals[i] != nullptr) { + retvals[i]->Unref(); + } } } @@ -733,12 +735,9 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, input, input_handle, input_device, *input_device_name, serialize_resource_dtype_and_shape)); if (!input_handle->resource_dtypes_and_shapes().empty()) { - auto tensor_handle_data = - absl::make_unique<UnshapedRemoteTensorHandleData>( - input_handle->op_id(), input_handle->output_num(), remote_task, - &ctx); - TF_RETURN_IF_ERROR(input->AddResourceShapeMirror( - std::move(tensor_handle_data), op_device)); + TF_RETURN_IF_ERROR( + input->AddResourceShapeMirror(op_device, input_handle->op_id(), + input_handle->output_num(), &ctx)); } } } @@ -1032,13 +1031,24 @@ Status EagerKernelExecute( } } DCHECK_EQ(retvals.size(), outputs.size()); - for (int i = 0; i < retvals.size(); ++i) { - DCHECK_EQ(kernel->device(), retvals[i]->op_device()); - DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)), - absl::get<Device*>(retvals[i]->device())); - TF_RETURN_IF_ERROR(retvals[i]->SetTensor( - std::move(outputs[i]), ctx->CanonicalDevice(kernel->OutputDevice(i)))); + for (int i = 0; i < retvals.size(); ++i) { + if (retvals[i] == nullptr) { + TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle( + std::move(outputs[i]), + /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)), + /* op_device= */ kernel->device(), + /* resource_device= */ kernel->OutputResourceDevice(i), ctx, + &retvals[i])); + } else { + DCHECK_EQ(kernel->device(), retvals[i]->op_device()); + DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)), + absl::get<Device*>(retvals[i]->device())); + + TF_RETURN_IF_ERROR( + retvals[i]->SetTensor(std::move(outputs[i]), + ctx->CanonicalDevice(kernel->OutputDevice(i)))); + } } return Status::OK(); } @@ -1069,7 +1079,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, *result = h; } else { TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( - true, d, dstd, h->resource_device(), h->dtype, ctx, result)); + d, dstd, h->resource_device(), h->dtype, ctx, result)); } Status s; @@ -1138,7 +1148,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, *result = h; } else { TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle( - true, /* d= */ d, /* op_device= */ device, + /* d= */ d, /* op_device= */ device, /*resource_device=*/nullptr, h->dtype, ctx, result)); } } else { @@ -1156,17 +1166,14 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, device->name()); } recv_op_id = ctx->RemoteMgr()->NextOpId(); - auto tensor_handle_data = - absl::make_unique<UnshapedRemoteTensorHandleData>(recv_op_id, 0, - remote_task, ctx); if (mirror) { - TF_RETURN_IF_ERROR( - h->AddUnshapedRemoteMirror(std::move(tensor_handle_data), device)); + TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0, + remote_task, ctx)); h->Ref(); *result = h; } else { 
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle( - std::move(tensor_handle_data), h->dtype, device, ctx, result)); + recv_op_id, 0, remote_task, h->dtype, device, ctx, result)); } } diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index d523dc20084..5ced006fb9e 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -32,7 +32,7 @@ Status ExecuteNodeArgs::Init( for (int i = 0; i < n_inputs; ++i) { TensorHandle* in = op_inputs_flat[i]; Device* d = kernel->InputDevice(i); - Status s = in->TensorValue(&tensor_args_flat[i], ctx->CanonicalDevice(d)); + Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]); if (!s.ok()) { #if !defined(IS_MOBILE_PLATFORM) uint64 context_view_id = ctx->GetContextViewId(); diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h index ed1bd956179..be6e4009896 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.h +++ b/tensorflow/core/common_runtime/eager/execute_node.h @@ -77,7 +77,7 @@ class ExecuteNode : public EagerNode { EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs, const absl::optional<EagerRemoteFunctionParams>& remote_func_params, const core::RefCountPtr<KernelAndDevice>& kernel, - GraphCollector* graph_collector, const DataTypeVector& output_dtypes, + GraphCollector* graph_collector, CancellationManager* cancellation_manager, absl::Span<TensorHandle*> retvals) : EagerNode(), @@ -130,7 +130,7 @@ class AsyncExecuteNode : public EagerNode { EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs, const absl::optional<EagerRemoteFunctionParams>& remote_func_params, core::RefCountPtr<KernelAndDevice> kernel, - GraphCollector* graph_collector, const DataTypeVector& output_dtypes, + GraphCollector* graph_collector, CancellationManager* cancellation_manager, absl::Span<TensorHandle*> retvals) : EagerNode(), diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index b8e688057f4..abbf840784b 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -67,6 +67,7 @@ class EagerKernelArgs : public FunctionArgsInterface { ~EagerKernelArgs() override{}; bool HasRemoteInputs() const override { return false; }; + TensorValue* MutableInput(int i) { return &tensor_args_[i]; } Status GetLocalArg(const int index, Tensor* val) const override; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 47a7125ced8..e7e2fb7b197 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -20,39 +20,31 @@ limitations under the License. 
#include <memory> #include <queue> #include <string> +#include <utility> #include <vector> -#include "absl/strings/substitute.h" #include "absl/types/variant.h" #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/rendezvous_mgr.h" -#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/errors.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h" #endif // IS_MOBILE_PLATFORM -#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/lib/traceme.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -64,7 +56,7 @@ const int32 kInvalidOutputNum = -1; } // namespace void TensorHandle::SetResourceHandleDtypeAndShape( - std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) { + std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes) { handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes); } @@ -86,250 +78,191 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes( profiler::TraceMe activity( "TensorHandle::GetResourceHandleDtypesAndShapes WaitReady", profiler::TraceMeLevel::kInfo); + auto& data = absl::get<LocalTensorHandleData>(data_); TF_RETURN_IF_ERROR( - WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes")); + data.WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes")); *result = handle_dtypes_and_shapes_; return Status::OK(); } -Status TensorHandle::CreateLocalHandle(const class Tensor& t, +Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t, TensorHandle** h) { // TODO(b/136608821): Move away from nullptr - return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr), + tensorflow::Tensor tensor = t; + return CreateLocalHandle(std::move(tensor), + /*d=*/nullptr, /*op_device=*/nullptr, /*ctx=*/nullptr, h); } -Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d, - EagerContext* ctx, TensorHandle** h) { - return CreateLocalHandle(t, d, d, ctx, h); -} - -Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d, +Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, EagerContext* ctx, TensorHandle** h) { - if (t.dtype() != DT_RESOURCE) { - *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t), - t.dtype(), d, op_device, ctx); + return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h); +} + +Status 
TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, + Device* resource_device, + EagerContext* ctx, TensorHandle** h) { + if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) { + *h = new TensorHandle(std::move(t), d, op_device, ctx); } else { - const ResourceHandle& resource_handle = t.flat<class ResourceHandle>()(0); - *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t), - resource_handle, d, op_device, ctx); + *h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx); } return Status::OK(); } -Status TensorHandle::CreateLocalHandle(const class Tensor& t, CustomDevice* d, +Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx, TensorHandle** h) { - *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t), t.dtype(), - d, ctx); + *h = new TensorHandle(std::move(t), d, ctx); return Status::OK(); } -TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t, - DataType dtype, Device* d, Device* op_device, - EagerContext* ctx) - : dtype(dtype), +TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, + Device* resource_device, EagerContext* ctx) + : dtype(t.dtype()), device_((!ctx || d == ctx->HostCPU()) ? nullptr : d), op_device_(op_device), - resource_device_(nullptr), -#if !defined(IS_MOBILE_PLATFORM) - remote_op_id_(kInvalidOpId), - remote_output_num_(kInvalidOutputNum), -#endif + resource_device_(resource_device), ctx_(ctx), - is_remote_(false), - is_async_(false), implicit_mirroring_(true), - is_ready_(true), - tensor_handle_data_(std::move(t)) { + data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_); + << " device: " << VariantDeviceDebugString(device_) + << " tensor: " << t.DeviceSafeDebugString(); } -TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t, - const ResourceHandle& resource_handle, Device* d, - Device* op_device, EagerContext* ctx) +TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, + EagerContext* ctx) : dtype(DT_RESOURCE), device_((!ctx || d == ctx->HostCPU()) ? 
nullptr : d), op_device_(op_device), - resource_device_(GetResourceDevice(resource_handle, ctx)), -#if !defined(IS_MOBILE_PLATFORM) - remote_op_id_(kInvalidOpId), - remote_output_num_(kInvalidOutputNum), -#endif + resource_device_( + GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)), ctx_(ctx), - is_remote_(false), - is_async_(false), implicit_mirroring_(true), - is_ready_(true), - handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()), - tensor_handle_data_(std::move(t)) { + handle_dtypes_and_shapes_( + t.flat<class ResourceHandle>()(0).dtypes_and_shapes()), + data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_); + << " device: " << VariantDeviceDebugString(device_) + << " tensor: " << t.DeviceSafeDebugString(); } -TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t, - DataType dtype, CustomDevice* d, EagerContext* ctx) - : dtype(dtype), +TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, + EagerContext* ctx) + : dtype(t.dtype()), device_(d), op_device_(nullptr), resource_device_(nullptr), -#if !defined(IS_MOBILE_PLATFORM) - remote_op_id_(kInvalidOpId), - remote_output_num_(kInvalidOutputNum), -#endif ctx_(ctx), - is_remote_(false), - is_async_(false), implicit_mirroring_(true), - is_ready_(true), - tensor_handle_data_(std::move(t)) { + data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) { // TODO(allenl): Figure out a better op_device story for custom devices, // since always setting it to CPU=nullptr doesn't make much sense. DVLOG(3) << "Creating Local TensorHandle: " << this - << " custom device: " << VariantDeviceDebugString(device_); + << " custom device: " << VariantDeviceDebugString(device_) + << " tensor: " << t.DeviceSafeDebugString(); } -Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d, - Device* op_device, +Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device, Device* resource_device, DataType dtype, EagerContext* ctx, TensorHandle** h) { - *h = new TensorHandle(absl::make_unique<EmptyLocalTensorHandleData>(), async, - d, op_device, resource_device, dtype, ctx); + *h = new TensorHandle(d, op_device, resource_device, dtype, ctx); return Status::OK(); } -TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, - bool async, Device* d, Device* op_device, +TensorHandle::TensorHandle(Device* d, Device* op_device, Device* resource_device, DataType dtype, EagerContext* ctx) : dtype(dtype), device_((d == ctx->HostCPU()) ? 
nullptr : d), op_device_(op_device), resource_device_(resource_device), -#if !defined(IS_MOBILE_PLATFORM) - remote_op_id_(kInvalidOpId), - remote_output_num_(kInvalidOutputNum), -#endif ctx_(ctx), - is_remote_(false), - is_async_(async), implicit_mirroring_(true), - is_ready_(!async), - tensor_handle_data_(std::move(t)) { + data_(absl::in_place_type<LocalTensorHandleData>) { DVLOG(3) << "Creating empty Local TensorHandle: " << this << " device: " << VariantDeviceDebugString(device_); } #if !defined(IS_MOBILE_PLATFORM) -Status TensorHandle::CreateRemoteHandle( - std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d, - Device* resource_device, EagerContext* ctx, TensorHandle** h) { - *h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx); +Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num, + const string& remote_task, + DataType dtype, Device* d, + EagerContext* ctx, + TensorHandle** h) { + *h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx); return Status::OK(); } -Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num, - const TensorShape& shape, - const string& remote_task, - DataType dtype, Device* d, - Device* resource_device, - EagerContext* ctx, TensorHandle** h) { - *h = new TensorHandle(absl::make_unique<RemoteTensorHandleData>( - op_id, output_num, shape, remote_task, ctx), - dtype, d, resource_device, ctx); - return Status::OK(); -} - -TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, - DataType dtype, Device* d, Device* resource_device, +TensorHandle::TensorHandle(int64 op_id, int32 output_num, + const string& remote_task, DataType dtype, Device* d, EagerContext* ctx) : dtype(dtype), device_(d), op_device_(d), - resource_device_(resource_device), - remote_op_id_(t->op_id()), - remote_output_num_(t->output_num()), + resource_device_(dtype == DT_RESOURCE ? d : nullptr), ctx_(ctx), - is_remote_(true), - is_async_(false), implicit_mirroring_(true), - is_ready_(true), - tensor_handle_data_(std::move(t)) { - DVLOG(3) << "Creating Remote TensorHandle: " << this + data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num, + remote_task, ctx) { + DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this << " device: " << VariantDeviceDebugString(device_); } -Status TensorHandle::CreateUnshapedRemoteHandle( - std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype, - Device* d, EagerContext* ctx, TensorHandle** h) { - *h = new TensorHandle(std::move(t), dtype, d, ctx); +Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num, + DataType dtype, Device* d, + EagerContext* ctx, + TensorHandle** h) { + *h = new TensorHandle(op_id, output_num, dtype, d, ctx); return Status::OK(); } -Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num, - const string& remote_task, - DataType dtype, Device* device, - EagerContext* ctx, - TensorHandle** h) { - *h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>( - op_id, output_num, remote_task, ctx), - dtype, device, ctx); - return Status::OK(); -} - -TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t, - DataType dtype, Device* device, EagerContext* ctx) +TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype, + Device* d, EagerContext* ctx) : dtype(dtype), - device_(device), - op_device_(device), - resource_device_(dtype == DT_RESOURCE ? 
device : nullptr), - remote_op_id_(t->op_id()), - remote_output_num_(t->output_num()), + device_(d), + op_device_(d), + resource_device_(dtype == DT_RESOURCE ? d : nullptr), ctx_(ctx), - is_remote_(true), - is_async_(true), implicit_mirroring_(true), - is_ready_(false), - tensor_handle_data_(std::move(t)) { - DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this + data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num, + ctx->GetContextViewId()) { + DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this << " device: " << VariantDeviceDebugString(device_); } #endif bool TensorHandle::IsReady() const { - // Avoid mutex acquisition for local sync handles - if (!is_async_ && !is_remote_) { - return true; - } - - tf_shared_lock l(mu_); - return is_ready_; + return absl::visit([](auto& data) { return data.IsReady(); }, data_); } -Status TensorHandle::WaitReady(const char* caller) const { - if (!IsReady()) { - profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"), - profiler::TraceMeLevel::kInfo); - tf_shared_lock l(mu_); - mu_.Await(Condition(&is_ready_)); - } - return is_poisoned_; +bool TensorHandle::IsRemote() const { +#if !defined(IS_MOBILE_PLATFORM) + return data_.index() == 1; +#else + return false; +#endif } Status TensorHandle::Tensor(const tensorflow::Tensor** t) const { DVLOG(3) << "Tensor on TensorHandle: " << this; - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor")); - return tensor_handle_data_->Tensor(t); + if (IsRemote()) { + return errors::Internal("Invalid Tensor call on remote handle: ", this); + } + + auto& data = absl::get<LocalTensorHandleData>(data_); + return data.Tensor(t); } Status TensorHandle::TensorFromDevice(const Device* d, @@ -337,12 +270,12 @@ Status TensorHandle::TensorFromDevice(const Device* d, DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d; if (d == absl::get<Device*>(device_)) { - if (is_remote_) { + if (IsRemote()) { return errors::Internal("Invalid Tensor call on remote handle: ", this); } - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice")); - return tensor_handle_data_->Tensor(t); + auto& data = absl::get<LocalTensorHandleData>(data_); + return data.Tensor(t); } tf_shared_lock l(mu_); @@ -352,25 +285,21 @@ Status TensorHandle::TensorFromDevice(const Device* d, " in Tensor call to handle: ", this); } - // Check if the handle is non-empty, else wait. auto& mirror = elem->second; - if (mirror.second == nullptr) { - TF_RETURN_IF_ERROR( - mirror.first->WaitReady("TensorHandle::TensorFromDevice")); - } - - return mirror.second->Tensor(t); + return mirror.Tensor(t); } -Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) { +Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) { + DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d; + if (d == absl::get<Device*>(device_)) { - if (is_remote_) { + if (IsRemote()) { return errors::Internal("Invalid TensorValue call on remote handle: ", this); } - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue")); - return tensor_handle_data_->TensorValue(t); + auto& data = absl::get<LocalTensorHandleData>(data_); + return data.TensorValue(t); } tf_shared_lock l(mu_); @@ -380,13 +309,8 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) { " in TensorValue call to handle: ", this); } - // Check if the handle is non-empty, else wait. 
auto& mirror = elem->second; - if (mirror.second == nullptr) { - TF_RETURN_IF_ERROR(mirror.first->WaitReady("TensorHandle::TensorValue")); - } - - return mirror.second->TensorValue(t); + return mirror.TensorValue(t); } TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU( @@ -405,8 +329,8 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) { DCHECK(fill); return Status::OK(); } else { - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Shape")); - return tensor_handle_data_->Shape(shape); + return absl::visit([shape](auto& data) { return data.Shape(shape); }, + data_); } } @@ -480,8 +404,8 @@ Status TensorHandle::NumDims(int* num_dims) const { *num_dims = inference_shape_.dims(); return Status::OK(); } else { - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumDims")); - return tensor_handle_data_->NumDims(num_dims); + return absl::visit( + [num_dims](auto& data) { return data.NumDims(num_dims); }, data_); } } @@ -492,8 +416,9 @@ Status TensorHandle::Dim(int dim_index, int64* dim) const { *dim = inference_shape_.dim_size(dim_index); return Status::OK(); } else { - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Dim")); - return tensor_handle_data_->Dim(dim_index, dim); + return absl::visit( + [dim_index, dim](auto& data) { return data.Dim(dim_index, dim); }, + data_); } } @@ -503,8 +428,9 @@ Status TensorHandle::NumElements(int64* num_elements) const { *num_elements = inference_shape_.num_elements(); return Status::OK(); } else { - TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumElements")); - return tensor_handle_data_->NumElements(num_elements); + return absl::visit( + [num_elements](auto& data) { return data.NumElements(num_elements); }, + data_); } } @@ -512,7 +438,8 @@ Status TensorHandle::Unprotect(const Device* d) { DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d; if (d == absl::get<Device*>(device_)) { - return tensor_handle_data_->Unprotect(); + auto& data = absl::get<LocalTensorHandleData>(data_); + return data.Unprotect(); } tf_shared_lock l(mu_); @@ -524,11 +451,7 @@ Status TensorHandle::Unprotect(const Device* d) { // Check if the handle is non-empty auto& mirror = elem->second; - if (mirror.second == nullptr) { - return errors::Internal("Attempted to unprotect an empty mirror"); - } - - return mirror.second->Unprotect(); + return mirror.Unprotect(); } bool TensorHandle::HasLocalMirror(const Device* d) const { @@ -551,8 +474,8 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) { return errors::Internal("Attempted to duplicate a local mirror."); } - local_mirrors_[d] = - std::make_pair(std::make_unique<EmptyLocalTensorHandleData>(), nullptr); + local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d), + std::forward_as_tuple()); return Status::OK(); } @@ -567,15 +490,8 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id, tf_shared_lock l(mu_); auto mirror = remote_mirrors_.find(d->name()); if (mirror != remote_mirrors_.end()) { - *op_id = mirror->second->op_id(); - *output_num = mirror->second->output_num(); - return Status::OK(); - } - - auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name()); - if (unshaped_mirror != unshaped_remote_mirrors_.end()) { - *op_id = unshaped_mirror->second->op_id(); - *output_num = unshaped_mirror->second->output_num(); + *op_id = mirror->second.op_id(); + *output_num = mirror->second.output_num(); return Status::OK(); } @@ -583,14 +499,14 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id, "Could not find remote mirror for specified device"); } - if 
(remote_op_id_ == kInvalidOpId || - remote_output_num_ == kInvalidOutputNum) { - return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_, - ", output_num:", remote_output_num_, - ") is not set."); + if (!IsRemote()) { + return errors::InvalidArgument("Primary device is not remote"); } - *op_id = remote_op_id_; - *output_num = remote_output_num_; + + auto& data = absl::get<RemoteTensorHandleData>(data_); + *op_id = data.op_id(); + *output_num = data.output_num(); + return Status::OK(); } @@ -603,16 +519,7 @@ bool TensorHandle::HasRemoteMirror(const Device* d, auto mirror = remote_mirrors_.find(d->name()); if (mirror != remote_mirrors_.end()) { // Check if mirror is stale - if (mirror->second->context_view_id() != context_view_id) { - return false; - } - return true; - } - - auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name()); - if (unshaped_mirror != unshaped_remote_mirrors_.end()) { - // Check if mirror is stale - if (unshaped_mirror->second->context_view_id() != context_view_id) { + if (mirror->second.context_view_id() != context_view_id) { return false; } return true; @@ -630,7 +537,7 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d, auto mirror = resource_shape_mirrors_.find(d->name()); if (mirror != resource_shape_mirrors_.end()) { // Check if mirror is stale - if (mirror->second->context_view_id() != context_view_id) { + if (mirror->second.context_view_id() != context_view_id) { return false; } return true; @@ -638,45 +545,39 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d, return false; } -Status TensorHandle::AddUnshapedRemoteMirror( - std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) { +Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64 op_id, + int output_num, + const string& remote_task, + EagerContext* ctx) { DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this - << " device: " << d << " " << d->name(); + << " device: " << d << " " << d->name() << " op_id: " << op_id + << " output_num: " << output_num; mutex_lock l(mu_); auto remote_mirror = remote_mirrors_.find(d->name()); if (remote_mirror != remote_mirrors_.end()) { - if (remote_mirror->second->context_view_id() == t->context_view_id()) { + if (remote_mirror->second.context_view_id() == ctx->GetContextViewId()) { return errors::Internal("Attempted to duplicate a remote mirror."); } // Remove stale mirror remote_mirrors_.erase(remote_mirror); } - auto unshaped_remote_mirror = unshaped_remote_mirrors_.find(d->name()); - if (unshaped_remote_mirror != unshaped_remote_mirrors_.end()) { - if (unshaped_remote_mirror->second->context_view_id() == - t->context_view_id()) { - return errors::Internal( - "Attempted to duplicate an unshaped remote mirror."); - } - // Remove stale mirror - unshaped_remote_mirrors_.erase(unshaped_remote_mirror); - } - - unshaped_remote_mirrors_[d->name()] = std::move(t); + remote_mirrors_.emplace( + std::piecewise_construct, std::forward_as_tuple(d->name()), + std::forward_as_tuple(op_id, output_num, remote_task, ctx)); return Status::OK(); } -Status TensorHandle::AddResourceShapeMirror( - std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) { +Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id, + int output_num, EagerContext* ctx) { DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this; mutex_lock l(mu_); auto mirror = resource_shape_mirrors_.find(d->name()); if (mirror != resource_shape_mirrors_.end()) { - if (mirror->second->context_view_id() == t->context_view_id())
{ + if (mirror->second.context_view_id() == ctx->GetContextViewId()) { return errors::Internal( "Attempted to duplicate a resource shape mirror."); } @@ -684,26 +585,9 @@ Status TensorHandle::AddResourceShapeMirror( resource_shape_mirrors_.erase(mirror); } - resource_shape_mirrors_[d->name()] = std::move(t); - - return Status::OK(); -} - -Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, - const Device* d) { - DVLOG(3) << "AddRemoteMirror on TensorHandle: " << this << " device: " << d; - - mutex_lock l(mu_); - auto mirror = remote_mirrors_.find(d->name()); - if (mirror != remote_mirrors_.end()) { - if (mirror->second->context_view_id() == t->context_view_id()) { - return errors::Internal("Attempted to duplicate a remote mirror."); - } - // Remove stale mirror - remote_mirrors_.erase(mirror); - } - - remote_mirrors_[d->name()] = std::move(t); + resource_shape_mirrors_.emplace( + std::piecewise_construct, std::forward_as_tuple(d->name()), + std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId())); return Status::OK(); } @@ -717,53 +601,24 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d, mutex_lock l(mu_); auto remote_mirror = remote_mirrors_.find(d->name()); if (remote_mirror != remote_mirrors_.end()) { - if (remote_mirror->second->context_view_id() == context_view_id) { - return errors::Internal( - "Attempted to set remote shape for existing mirror."); + auto& mirror = remote_mirror->second; + if (mirror.context_view_id() == context_view_id) { + return mirror.SetShape(shape); } remote_mirrors_.erase(remote_mirror); } - auto elem = unshaped_remote_mirrors_.find(d->name()); - if (elem == unshaped_remote_mirrors_.end()) { - return errors::Internal( - "Attempted to set remote shape for non-waiting mirror."); - } - - if (elem->second->context_view_id() != context_view_id) { - unshaped_remote_mirrors_.erase(elem); - return errors::Internal( - "Attempted to set remote shape for a stale mirror."); - } - - auto& data = elem->second; - data->ReleaseRemoteTensorHandle(); - remote_mirrors_[d->name()] = absl::make_unique<RemoteTensorHandleData>( - data->op_id(), data->output_num(), shape, data->remote_task(), - &data->ctx()); - unshaped_remote_mirrors_.erase(elem); - return Status::OK(); } - DCHECK(is_remote_) << "SetRemoteShape is only called on remote handles."; - DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles."; + DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles."; - UnshapedRemoteTensorHandleData* p = - reinterpret_cast<UnshapedRemoteTensorHandleData*>( - tensor_handle_data_.get()); - if (p->context_view_id() != context_view_id) { + auto& data = absl::get<RemoteTensorHandleData>(data_); + if (data.context_view_id() != context_view_id) { return errors::Internal("Attempted to set remote shape for an old handle."); } - p->ReleaseRemoteTensorHandle(); - tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>( - remote_op_id_, remote_output_num_, shape, p->remote_task(), ctx_); - is_poisoned_ = Status::OK(); - mutex_lock l(mu_); - is_ready_ = true; - - return Status::OK(); + return data.SetShape(shape); } void TensorHandle::PoisonRemote(Status status, const Device* d, @@ -772,18 +627,16 @@ void TensorHandle::PoisonRemote(Status status, const Device* d, << " " << d->name(); if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) { - DCHECK(!is_async_ || !IsReady()) - << "PoisonRemote can only be called on non-ready handle: " << this; + DCHECK(IsRemote()) << 
"Poison can only be on remote handles: " << this; - is_poisoned_ = status; - mutex_lock l(mu_); - is_ready_ = true; + auto& data = absl::get<RemoteTensorHandleData>(data_); + data.Poison(status); } else { tf_shared_lock l(mu_); - auto mirror = unshaped_remote_mirrors_.find(d->name()); - if (mirror != unshaped_remote_mirrors_.end()) { - if (mirror->second->context_view_id() == context_view_id) { - mirror->second->Poison(status); + auto mirror = remote_mirrors_.find(d->name()); + if (mirror != remote_mirrors_.end()) { + if (mirror->second.context_view_id() == context_view_id) { + mirror->second.Poison(status); } } } @@ -798,9 +651,9 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, } mutex_lock l(mu_); - auto elem = local_mirrors_.insert(std::make_pair( - d, std::make_pair(nullptr, - std::make_unique<LocalTensorHandleData>(tensor)))); + auto elem = + local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d), + std::forward_as_tuple(std::move(tensor))); if (!elem.second) { return errors::Internal("Attempted to set tensor for existing mirror."); } @@ -808,24 +661,18 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, return Status::OK(); } -Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) { +Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) { DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d; if (d == absl::get<Device*>(device_)) { - DCHECK(!is_remote_) << "SetTensor is not called on remote handles."; - DCHECK(!is_async_ || !IsReady()) - << "SetTensor is only called on non-ready handles."; + DCHECK(!IsRemote()) << "SetTensor is not called on remote handles."; - if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) { - auto& resource_handle = tensor.flat<class ResourceHandle>()(0); + if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) { + auto& resource_handle = t.flat<class ResourceHandle>()(0); handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes(); } - tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor); - if (is_async_) { - is_poisoned_ = Status::OK(); - mutex_lock l(mu_); - is_ready_ = true; - } + auto& data = absl::get<LocalTensorHandleData>(data_); + return data.SetTensor(std::move(t)); } else { tf_shared_lock l(mu_); auto elem = local_mirrors_.find(d); @@ -835,12 +682,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) { } auto& mirror = elem->second; - if (mirror.second != nullptr) { - return errors::Internal("Attempted to set tensor for existing mirror."); - } - - mirror.second = absl::make_unique<LocalTensorHandleData>(tensor); - mirror.first->SetReady(); + return mirror.SetTensor(std::move(t)); } return Status::OK(); @@ -850,12 +692,10 @@ void TensorHandle::Poison(Status status, const Device* d) { DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d; if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) { - DCHECK(!is_async_ || !IsReady()) - << "Poison can only be called on non-ready handle: " << this; + DCHECK(!IsRemote()) << "Poison can only be on local handles: " << this; - is_poisoned_ = status; - mutex_lock l(mu_); - is_ready_ = true; + auto& data = absl::get<LocalTensorHandleData>(data_); + data.Poison(status); } else { tf_shared_lock l(mu_); auto elem = local_mirrors_.find(d); @@ -864,9 +704,7 @@ void TensorHandle::Poison(Status status, const Device* d) { << " device: " << d; auto& mirror = elem->second; - DCHECK(mirror.second == nullptr) << "Attempted to 
poison existing mirror."; - - mirror.first->Poison(status); + mirror.Poison(status); } } @@ -977,8 +815,11 @@ string TensorHandle::DebugString() const { !VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull; // Consider supporting non-CPU tensors and CPU tensors with a device_ set to // non-NULL if needed. - strings::StrAppend(&out, ", Tensor: ", - is_cpu ? tensor_handle_data_->DebugString() : "?", "\n"); + strings::StrAppend( + &out, ", Tensor: ", + is_cpu ? absl::visit([](auto& data) { return data.DebugString(); }, data_) + : "?", + "\n"); return out; } diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 1e5b741cbb7..b3bf5ac22db 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -17,10 +17,10 @@ limitations under the License. #include <algorithm> #include <cstddef> -#include <map> #include <memory> #include <queue> #include <string> +#include <unordered_map> #include <vector> // clang-format off @@ -32,28 +32,20 @@ limitations under the License. #include "absl/types/variant.h" #include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" #include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h" #endif // IS_MOBILE_PLATFORM -#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { @@ -67,56 +59,45 @@ class TensorHandle : public core::RefCounted { using VariantDevice = absl::variant<Device*, CustomDevice*>; // TensorHandle for dtype != DT_RESOURCE - TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype, - Device* d, Device* op_device, EagerContext* ctx); + TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, + Device* resource_device, EagerContext* ctx); // TensorHandle for dtype == DT_RESOURCE - TensorHandle(std::unique_ptr<LocalTensorHandleData> t, - const ResourceHandle& resource_handle, Device* d, - Device* op_device, EagerContext* ctx); - TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype, - CustomDevice* d, EagerContext* ctx); - TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async, - Device* d, Device* op_device, Device* resource_device, + TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, + EagerContext* ctx); + TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx); + TensorHandle(Device* d, Device* op_device, Device* resource_device, DataType dtype, EagerContext* ctx); #if !defined(IS_MOBILE_PLATFORM) - TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, - Device* d, Device* 
resource_device, EagerContext* ctx); - TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t, + TensorHandle(int64 op_id, int32 output_num, const string& remote_task, DataType dtype, Device* device, EagerContext* ctx); + TensorHandle(int64 op_id, int32 output_num, DataType dtype, Device* device, + EagerContext* ctx); #endif // IS_MOBILE_PLATFORM public: // TensorHandle with no assigned device - static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h); - // TensorHandle with device == op_device - static Status CreateLocalHandle(const class Tensor& t, Device* d, - EagerContext* ctx, TensorHandle** h); - static Status CreateLocalHandle(const class Tensor& t, Device* d, + static Status CreateLocalHandle(const tensorflow::Tensor& t, + TensorHandle** h); + static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, EagerContext* ctx, TensorHandle** h); - static Status CreateLocalHandle(const class Tensor& t, CustomDevice* d, + static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d, + Device* op_device, Device* resource_device, EagerContext* ctx, TensorHandle** h); - static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device, + static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d, + EagerContext* ctx, TensorHandle** h); + static Status CreateEmptyLocalHandle(Device* d, Device* op_device, Device* resource_device, DataType dtype, EagerContext* ctx, TensorHandle** h); #if !defined(IS_MOBILE_PLATFORM) - static Status CreateRemoteHandle(int64 op_id, int output_num, - const TensorShape& shape, - const string& remote_task, DataType dtype, - Device* d, Device* resource_device, - EagerContext* ctx, TensorHandle** h); - static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t, - DataType dtype, Device* d, - Device* resource_device, EagerContext* ctx, - TensorHandle** h); static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num, const string& remote_task, - DataType dtype, Device* device, + DataType dtype, Device* d, EagerContext* ctx, TensorHandle** h); - static Status CreateUnshapedRemoteHandle( - std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype, - Device* device, EagerContext* ctx, TensorHandle** h); + static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num, + DataType dtype, Device* d, + EagerContext* ctx, TensorHandle** h); #endif // IS_MOBILE_PLATFORM ~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; } @@ -131,7 +112,7 @@ class TensorHandle : public core::RefCounted { // Return the TensorValue from the specified device which could be either the // default device or a local mirror. The device pointer should be nullptr if // requesting the HostCPU. 
- Status TensorValue(tensorflow::TensorValue* t, const Device* d); + Status TensorValue(const Device* d, tensorflow::TensorValue* t); VariantDevice device() const { return device_; } Device* op_device() const { return op_device_; } @@ -161,12 +142,10 @@ class TensorHandle : public core::RefCounted { bool HasRemoteMirror(const Device* d, uint64 context_view_id) const; bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const; - Status AddUnshapedRemoteMirror( - std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d); - Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t, - const Device* d); - Status AddResourceShapeMirror( - std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d); + Status AddUnshapedRemoteMirror(const Device* d, int64 op_id, int output_num, + const string& remote_task, EagerContext* ctx); + Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num, + EagerContext* ctx); // Return the op_id and output num if the handle refers to a remote tensor. Status RemoteAddress(const Device* d, int64* op_id, int32* output_num) const; @@ -212,14 +191,12 @@ class TensorHandle : public core::RefCounted { Status CopyInferenceShape(TensorHandle* other); // Warning: can return nullptr for CPU tensors. - // TODO(b/136608821): Move away from nullptr EagerContext* Context() { return ctx_; } // dtype for the handle. It must be the same as t.dtype() once the handle is // ready. const DataType dtype; - // TODO(b/136608821): Move away from nullptr bool OnHostCPU() const { return ( device_.index() == 0 && @@ -227,14 +204,14 @@ class TensorHandle : public core::RefCounted { (ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_)))); } - bool IsRemote() const { return is_remote_; } + bool IsRemote() const; void EnableImplicitMirroring() { implicit_mirroring_ = true; } bool ImplicitMirroring() const { return implicit_mirroring_; } string DebugString() const; void SetResourceHandleDtypeAndShape( - std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes); + std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes); // If this TensorHandle is 1) a local tensor, and 2) a resource handle, // return data types and shapes of the underlying resource. @@ -248,19 +225,6 @@ class TensorHandle : public core::RefCounted { // with a ready version of the tensor handle data. bool IsReady() const; - // If the contents of the Tensor pointed to by this handle is yet to be - // computed by a EagerNode, this function will block till that computation is - // done and the handle is "ready". - Status WaitReady(const char* caller) const; - - // TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0 - // This was expedient, but perhaps worth revisiting ('device_' should always - // be a valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? VariantDevice const device_; // Device in which the op producing this tensor was executed. Equals to @@ -275,47 +239,33 @@ class TensorHandle : public core::RefCounted { mutable mutex mu_; - // Map of local mirrors. In sync mode the EmptyLocalTensorHandleData is - // nullptr. In async mode, we use the EmptyLocalTensorHandleData to manage - // waiting clients. Once the EmptyLocalTensorHandleData is "ready" only the - // LocalTensorHandleData should be used. 
- std::map<const tensorflow::Device*, - std::pair<std::unique_ptr<EmptyLocalTensorHandleData>, - std::unique_ptr<LocalTensorHandleData>>> + // Map of local mirrors. This can include both ready and non-ready mirrors. + std::unordered_map<const tensorflow::Device*, LocalTensorHandleData> local_mirrors_ GUARDED_BY(mu_); #if !defined(IS_MOBILE_PLATFORM) // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica // variable is ready, since we could get the shape locally without remote copy // then. - std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>> - resource_shape_mirrors_ GUARDED_BY(mu_); - // TODO(gjn): Unshaped remote mirrors are not expected to be long-lived. - // Consider replacing the unshaped_remote_mirrors_ map with something more - // efficient. - std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>> - unshaped_remote_mirrors_ GUARDED_BY(mu_); + std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_ + GUARDED_BY(mu_); // TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be // a fixed size map. - std::map<string, std::unique_ptr<RemoteTensorHandleData>> remote_mirrors_ + std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_ GUARDED_BY(mu_); - - // IDs required when this class is representing a remote tensor handle. - int64 remote_op_id_; - int32 remote_output_num_; #endif // `ctx` is only guaranteed to be set if the handle is not "ready". This is // typically true when the handle was produced during async execution. // `ctx` object is not owned and should outlive this handle. + // + // TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of + // a TensorHandle does not outlive the EagerContext from which it came? EagerContext* const ctx_; // Does not need synchronization because it can be accessed only after // WaitReady() has returned. At that point, is_poisoned_ is immutable. Status is_poisoned_; - const bool is_remote_; - const bool is_async_; bool implicit_mirroring_; - bool is_ready_ GUARDED_BY(mu_); // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or // refers to a remote resource handle, we store data types and shapes for @@ -323,8 +273,12 @@ class TensorHandle : public core::RefCounted { std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_; // Does not need synchronization because it can be accessed only after - // WaitReady() has returned. At that point, tensor_handle_data_ is immutable. - std::unique_ptr<TensorHandleData> tensor_handle_data_; + // WaitReady() has returned. At that point, data_ is immutable. 
+#if !defined(IS_MOBILE_PLATFORM) + absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_; +#else + absl::variant<LocalTensorHandleData> data_; +#endif PartialTensorShape inference_shape_; }; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc index 690c14e2ffd..f0b06cf983c 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc @@ -23,12 +23,16 @@ namespace tensorflow { class Status; Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const { + TF_RETURN_IF_ERROR(WaitReady("Tensor")); + *t = &tensor_; return Status::OK(); } Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { + TF_RETURN_IF_ERROR(WaitReady("TensorValue")); + tensorflow::Tensor& tensor = tensor_; *t = tensorflow::TensorValue(&tensor); @@ -36,103 +40,96 @@ Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { } Status LocalTensorHandleData::Shape(TensorShape* shape) const { + TF_RETURN_IF_ERROR(WaitReady("Shape")); + *shape = tensor_.shape(); return Status::OK(); } Status LocalTensorHandleData::NumDims(int* num_dims) const { + TF_RETURN_IF_ERROR(WaitReady("NumDims")); + *num_dims = tensor_.dims(); return Status::OK(); } Status LocalTensorHandleData::Dim(int dim_index, int64* dim) const { + TF_RETURN_IF_ERROR(WaitReady("Dim")); + *dim = tensor_.dim_size(dim_index); return Status::OK(); } Status LocalTensorHandleData::NumElements(int64* num_elements) const { + TF_RETURN_IF_ERROR(WaitReady("NumElements")); + *num_elements = tensor_.NumElements(); return Status::OK(); } Status LocalTensorHandleData::Unprotect() { + if (!IsReady()) { + return errors::Internal("Cannot unprotect a non-ready tensor"); + } + forwarding_protection_tensor_ = tensorflow::Tensor(); return Status::OK(); } -Status EmptyLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const { - return errors::Unavailable( - "Unable to get a tensor for an empty handle. " - "Please wait until it is ready"); +Status LocalTensorHandleData::SetTensor(tensorflow::Tensor&& t) { + DCHECK(!IsReady()) << "SetTensor is only called on non-ready handles."; + + tensor_ = std::move(t); + // Create a copy of the original tensor to avoid forwarding + forwarding_protection_tensor_ = tensor_; + + auto& state = absl::get<BlockingControl>(ctrl_); + state.SetReady(); + + return Status::OK(); } -Status EmptyLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { - return errors::Unavailable( - "Unable to get a tensor for an empty handle. " - "Please wait until it is ready"); +string LocalTensorHandleData::DebugString() const { + if (IsReady()) { + return tensor_.DeviceSafeDebugString(); + } else { + return "LocalTensorHandleData"; + } } -Status EmptyLocalTensorHandleData::Shape(TensorShape* shape) const { - return errors::Unavailable( - "Unable to get shape information for an empty handle. " - "Please wait until it is ready"); -} - -Status EmptyLocalTensorHandleData::NumDims(int* num_dims) const { - return errors::Unavailable( - "Unable to get shape information for an empty handle. " - "Please wait until it is ready"); -} - -Status EmptyLocalTensorHandleData::Dim(int dim_index, int64* dim) const { - return errors::Unavailable( - "Unable to get shape information for an empty handle. 
" - "Please wait until it is ready"); -} - -Status EmptyLocalTensorHandleData::NumElements(int64* num_elements) const { - return errors::Unavailable( - "Unable to get shape information for an empty handle. " - "Please wait until it is ready"); -} - -Status EmptyLocalTensorHandleData::Unprotect() { - return errors::Unavailable("Unable to unprotect an empty handle."); -} - -bool EmptyLocalTensorHandleData::IsReady() const { - tf_shared_lock l(mu_); - return is_ready_; -} - -void EmptyLocalTensorHandleData::SetReady() { +void LocalTensorHandleData::BlockingControl::SetReady() { mutex_lock l(mu_); is_ready_ = true; } -Status EmptyLocalTensorHandleData::WaitReady(const char* caller) const { - if (!IsReady()) { - profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"), - profiler::TraceMeLevel::kInfo); - tf_shared_lock l(mu_); +Status LocalTensorHandleData::BlockingControl::WaitReady( + const char* caller) const { + tf_shared_lock l(mu_); + if (!is_ready_) { + profiler::TraceMe activity( + [caller] { return absl::StrCat(caller, " WaitReady"); }, + profiler::TraceMeLevel::kInfo); + DVLOG(3) << "WaitReady: " << caller << " " << this; mu_.Await(Condition(&is_ready_)); } + return is_poisoned_; } -void EmptyLocalTensorHandleData::Poison(Status status) { - is_poisoned_ = status; +void LocalTensorHandleData::BlockingControl::Poison(Status status) { mutex_lock l(mu_); + if (is_ready_) { + LOG(ERROR) << "Poison can only be called on non-ready handle: " << this; + return; + } + is_poisoned_ = status; is_ready_ = true; } -string EmptyLocalTensorHandleData::DebugString() const { - return "EmptyLocalTensorHandleData"; -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.h b/tensorflow/core/common_runtime/eager/tensor_handle_data.h index 3a791d94315..bcf38d5b695 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.h @@ -15,52 +15,50 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ +#include "absl/types/variant.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -class TensorHandleData { - public: - virtual ~TensorHandleData() {} - - // Different tensor handles support a set of these calls. In some cases these - // are resolved with a Tensor or TensorShape. Typically if the handle is not - // ready, none of these are supported operations. - virtual Status Tensor(const tensorflow::Tensor** t) const = 0; - virtual Status TensorValue(tensorflow::TensorValue* t) = 0; - virtual Status Shape(TensorShape* shape) const = 0; - virtual Status NumDims(int* num_dims) const = 0; - virtual Status Dim(int dim_index, int64* dim) const = 0; - virtual Status NumElements(int64* num_elements) const = 0; - // Allow the backing Tensor to be available for buffer reuse during op - // execution. - virtual Status Unprotect() = 0; - - virtual string DebugString() const = 0; -}; - // Local Tensor Handle: Handle to a Tensor present on the local host. 
-class LocalTensorHandleData : public TensorHandleData { +class LocalTensorHandleData { public: - explicit LocalTensorHandleData(const tensorflow::Tensor& t) - : tensor_(t), forwarding_protection_tensor_(t) {} - ~LocalTensorHandleData() override {} + LocalTensorHandleData() : ctrl_(absl::in_place_type<BlockingControl>) {} + explicit LocalTensorHandleData(tensorflow::Tensor&& t) + : tensor_(std::move(t)), + forwarding_protection_tensor_(tensor_), + ctrl_(absl::in_place_type<NonBlockingControl>) {} // A local tensor handle should be able to satisfy all of these requests. - Status Tensor(const tensorflow::Tensor** t) const override; - Status TensorValue(tensorflow::TensorValue* t) override; - Status Shape(TensorShape* shape) const override; - Status NumDims(int* num_dims) const override; - Status Dim(int dim_index, int64* dim) const override; - Status NumElements(int64* num_elements) const override; - Status Unprotect() override; + Status Tensor(const tensorflow::Tensor** t) const; + Status TensorValue(tensorflow::TensorValue* t); + Status Shape(TensorShape* shape) const; + Status NumDims(int* num_dims) const; + Status Dim(int dim_index, int64* dim) const; + Status NumElements(int64* num_elements) const; + Status Unprotect(); - string DebugString() const override { - return tensor_.DeviceSafeDebugString(); + bool IsReady() const { + return absl::visit([](auto& data) { return data.IsReady(); }, ctrl_); } + Status WaitReady(const char* caller) const { + return absl::visit([caller](auto& data) { return data.WaitReady(caller); }, + ctrl_); + } + void Poison(Status status) { + return absl::visit([status](auto& data) { data.Poison(status); }, ctrl_); + } + Status IsPoisoned() const { + return absl::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_); + } + + Status SetTensor(tensorflow::Tensor&& t); + + string DebugString() const; + private: tensorflow::Tensor tensor_; // TensorHandle has its own reference counting which is distinct from the @@ -70,37 +68,41 @@ class LocalTensorHandleData : public TensorHandleData { // forwarding_protection_tensor_ Tensor. When Unprotect() is called, we // release this Tensor to allow forwarding. tensorflow::Tensor forwarding_protection_tensor_; -}; -// Empty Local Tensor Handle: Once the execution is complete this is replaced by -// a local tensor handle. -class EmptyLocalTensorHandleData : public TensorHandleData { - public: - EmptyLocalTensorHandleData() {} - ~EmptyLocalTensorHandleData() override {} + // We distinguish between ready and empty tensors with the ctrl_ variant, + // which contains two implementations of the waiting logic. The + // NonBlockingControl is a simple no-op class whereas the BlockingControl + // actually uses a mutex. By using a variant we avoid the overhead of + // constructing and destructing the mutex for ready local tensors. + class NonBlockingControl { + public: + bool IsReady() const { return true; } + Status WaitReady(const char* caller) const { return Status::OK(); } + void Poison(Status status) {} + Status IsPoisoned() const { return Status::OK(); } + }; - // Empty tensor handles are not ready and hence cannot satisfy any of these - // requests. 
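As the ctrl_ comment above explains, a ready handle carries a NonBlockingControl and a pending one a BlockingControl, with absl::visit dispatching between them instead of a virtual call. A compilable toy sketch of that dispatch pattern, with hypothetical names (Handle, NonBlocking, Blocking) standing in for the real classes:

// Toy version of the variant-based dispatch; requires Abseil. All names
// here are hypothetical stand-ins for the classes in this header.
#include <iostream>
#include "absl/types/variant.h"

struct NonBlocking {
  bool IsReady() const { return true; }  // ready tensors never wait
};

struct Blocking {
  bool ready = false;  // a real implementation guards this with a mutex
  bool IsReady() const { return ready; }
};

class Handle {
 public:
  Handle() : ctrl_(absl::in_place_type<Blocking>) {}  // pending handle
  explicit Handle(int /*tag*/) : ctrl_(absl::in_place_type<NonBlocking>) {}  // ready handle

  bool IsReady() const {
    // Static dispatch over the variant; no virtual call, and no mutex is
    // ever constructed for the ready case.
    return absl::visit([](const auto& c) { return c.IsReady(); }, ctrl_);
  }

 private:
  absl::variant<NonBlocking, Blocking> ctrl_;
};

int main() {
  std::cout << Handle(0).IsReady() << " " << Handle().IsReady() << "\n";  // prints: 1 0
}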
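The BlockingControl implementation earlier in this patch (tensor_handle_data.cc) encodes a small readiness protocol: WaitReady() blocks until SetReady() or Poison() flips is_ready_, and Poison() also marks the handle ready so waiters wake up and observe the stored error status. A minimal standalone sketch of the same protocol, using std::condition_variable in place of TensorFlow's mutex and Condition (class and member names are illustrative, not part of the patch):

// Standalone sketch of the readiness protocol, not the TensorFlow code:
// std::condition_variable replaces tensorflow::mutex::Await(Condition(...)).
#include <condition_variable>
#include <mutex>
#include <string>

class BlockingControlSketch {
 public:
  // Blocks until SetReady() or Poison() has run; returns the poison message
  // (an empty string plays the role of Status::OK()).
  std::string WaitReady() const {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return is_ready_; });
    return is_poisoned_;
  }

  void SetReady() {
    {
      std::lock_guard<std::mutex> l(mu_);
      is_ready_ = true;
    }
    cv_.notify_all();
  }

  // Poisoning also marks the handle ready, so blocked readers wake up and
  // see the error instead of waiting forever.
  void Poison(std::string status) {
    {
      std::lock_guard<std::mutex> l(mu_);
      is_poisoned_ = std::move(status);
      is_ready_ = true;
    }
    cv_.notify_all();
  }

 private:
  mutable std::mutex mu_;
  mutable std::condition_variable cv_;
  bool is_ready_ = false;
  std::string is_poisoned_;
};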
- Status Tensor(const tensorflow::Tensor** t) const override; - Status TensorValue(tensorflow::TensorValue* t) override; - Status Shape(TensorShape* shape) const override; - Status NumDims(int* num_dims) const override; - Status Dim(int dim_index, int64* dim) const override; - Status NumElements(int64* num_elements) const override; - Status Unprotect() override; + class BlockingControl { + public: + bool IsReady() const { + tf_shared_lock l(mu_); + return is_ready_; + } + void SetReady(); + Status WaitReady(const char* caller) const; + void Poison(Status status); + Status IsPoisoned() const { + tf_shared_lock l(mu_); + return is_poisoned_; + } - bool IsReady() const; - void SetReady(); - Status WaitReady(const char* caller) const; - void Poison(Status status); - Status IsPoisoned() const { return is_poisoned_; } + private: + mutable mutex mu_; + bool is_ready_ GUARDED_BY(mu_); + Status is_poisoned_ GUARDED_BY(mu_); + }; - string DebugString() const override; - - private: - mutable mutex mu_; - bool is_ready_ GUARDED_BY(mu_); - Status is_poisoned_; + absl::variant<NonBlockingControl, BlockingControl> ctrl_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 155155a2763..6c62334281c 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -39,12 +39,13 @@ TEST(TensorHandle_ShapeTest, AsyncShape) { tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false, &device_mgr, false, nullptr, nullptr, nullptr); TensorHandle* sync_th; - EXPECT_TRUE( - TensorHandle::CreateLocalHandle(t, ctx->HostCPU(), ctx, &sync_th).ok()); + EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr, + ctx, &sync_th) + .ok()); TensorHandle* async_th; - EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(true, nullptr, nullptr, - nullptr, DataType::DT_UINT16, - ctx, &async_th) + EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr, + DataType::DT_UINT16, ctx, + &async_th) .ok()); EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok()); diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 831dfcd8aef..411e3d3afaa 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -190,8 +190,10 @@ cc_library( deps = [ ":destroy_tensor_handle_node", ":eager_client", + "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core/common_runtime/eager:tensor_handle_data", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/profiler/lib:traceme", ], ) diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 4f60b488144..c023d5ebe48 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -530,7 +530,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, } TensorHandle* tensor_handle = nullptr; - TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle)); + TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle( + std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle)); TensorHandle* copied_handle = nullptr; Device* device; TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName( diff --git 
a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index 7bb001ef853..32a88e0ba8d 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -101,12 +101,10 @@ Status RemoteCopyNode::RunLocalSend(EagerOperation* op) { core::RefCountPtr<KernelAndDevice> kernel; TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel)); - gtl::InlinedVector<TensorValue, 4> input_vector(1); - TF_RETURN_IF_ERROR(src_->TensorValue( - &input_vector[0], - ctx_->CanonicalDevice(absl::get<Device*>(op->Device())))); + EagerKernelArgs args(1); + Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device())); + TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0))); - EagerKernelArgs args(std::move(input_vector)); return kernel->Run(args, /*outputs=*/nullptr, /*cancellation_manager=*/nullptr, /*remote_func_params=*/absl::nullopt); diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index f2e6735c6b7..51c1e763021 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -162,16 +162,8 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, in.op_device().empty() ? in.device() : in.op_device(); TF_RETURN_IF_ERROR( parent_->FindDeviceFromName(device_name.c_str(), &device)); - string remote_task; - if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) { - return errors::InvalidArgument( - "Unable to find remote task corresponding to device ", device_name); - } - auto remote_handle_data = absl::make_unique<UnshapedRemoteTensorHandleData>( - in.op_id(), in.output_num(), remote_task, parent_); - remote_handle_data->ReleaseRemoteTensorHandle(); - TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle( - std::move(remote_handle_data), in.dtype(), device, parent_, out)); + TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle( + in.op_id(), in.output_num(), in.dtype(), device, parent_, out)); std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes; if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in), &dtypes_and_shapes) diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc index 3d2eb7f57cd..90213c978ed 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc @@ -71,14 +71,12 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) { Tensor t(DT_FLOAT, TensorShape({0})); TensorHandle* handle; - TF_ASSERT_OK( - TensorHandle::CreateLocalHandle(t, local_device_, ctx_, &handle)); + TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_, + local_device_, ctx_, &handle)); const uint64 op_id = 2; const int output_num = 3; - auto tensor_handle_data = absl::make_unique<RemoteTensorHandleData>( - op_id, output_num, t.shape(), /*remote_task=*/"", ctx_); - TF_ASSERT_OK( - handle->AddRemoteMirror(std::move(tensor_handle_data), remote_device_)); + TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id, + output_num, "", ctx_)); RemoteTensorHandle remote_handle; TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle( handle, &remote_handle, remote_device_, remote_device_->name())); @@ -90,14 +88,13 @@ TEST_F(RemoteMgrTest, 
SerializeLocalTensorHandleWithRemoteMirror) { TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) { RemoteMgr remote_mgr(false, ctx_); - Tensor t(DT_FLOAT, TensorShape({0})); const uint64 op_id = 3; const int output_num = 1; TensorHandle* handle; - TF_ASSERT_OK(TensorHandle::CreateRemoteHandle( - op_id, output_num, t.shape(), /*remote_task=*/"", DT_FLOAT, - remote_device_, /*resource_device=*/nullptr, ctx_, &handle)); + TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle( + op_id, output_num, + /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle)); RemoteTensorHandle remote_handle; TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle( handle, &remote_handle, remote_device_, remote_device_->name())); diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc index 704bef5a253..b5119406a91 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { @@ -84,66 +85,103 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task, } // namespace RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num, - const TensorShape& shape, - const string& remote_task, - EagerContext* ctx) - : op_id_(op_id), + uint64 context_view_id) + : is_ready_(false), + op_id_(op_id), output_num_(output_num), - shape_(shape), - remote_task_(remote_task), - context_id_(ctx->GetContextId()), - context_view_id_(ctx->GetContextViewId()), - ctx_(*ctx) { + context_view_id_(context_view_id), + ctx_(nullptr) { DCHECK(op_id_ >= 0 && output_num_ >= 0) << "Op ID and output num should be >= 0. Op ID: " << op_id << ", Output num: " << output_num; - ctx_.Ref(); +} + +RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num, + const string& remote_task, + EagerContext* ctx) + : is_ready_(false), + op_id_(op_id), + output_num_(output_num), + remote_task_(remote_task), + context_id_(ctx->GetContextId()), + context_view_id_(ctx->GetContextViewId()), + ctx_(ctx) { + DCHECK(op_id_ >= 0 && output_num_ >= 0) + << "Op ID and output num should be >= 0. Op ID: " << op_id + << ", Output num: " << output_num; + ctx_->Ref(); } RemoteTensorHandleData::~RemoteTensorHandleData() { - DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_, - output_num_, /*ready=*/true); - ctx_.Unref(); -} - -Status RemoteTensorHandleData::Tensor(const tensorflow::Tensor** t) const { - return errors::Unavailable( - "Unable to get a tensor for a remote device. Please copy the tensor " - "handle to a local device using TFE_TensorHandleCopyToDevice"); -} - -Status RemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) { - return errors::Unavailable( - "Unable to get a tensor for a remote device. 
Please copy the tensor " - "handle to a local device using TFE_TensorHandleCopyToDevice"); + if (ctx_) { + DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_, + output_num_, /*ready=*/true); + ctx_->Unref(); + } } Status RemoteTensorHandleData::Shape(TensorShape* shape) const { + TF_RETURN_IF_ERROR(WaitReady("Shape")); + + tf_shared_lock l(mu_); *shape = shape_; return Status::OK(); } Status RemoteTensorHandleData::NumDims(int* num_dims) const { + TF_RETURN_IF_ERROR(WaitReady("NumDims")); + + tf_shared_lock l(mu_); *num_dims = shape_.dims(); return Status::OK(); } Status RemoteTensorHandleData::Dim(int dim_index, int64* dim) const { + TF_RETURN_IF_ERROR(WaitReady("Dim")); + + tf_shared_lock l(mu_); *dim = shape_.dim_size(dim_index); return Status::OK(); } Status RemoteTensorHandleData::NumElements(int64* num_elements) const { + TF_RETURN_IF_ERROR(WaitReady("NumElements")); + + tf_shared_lock l(mu_); *num_elements = shape_.num_elements(); return Status::OK(); } -Status RemoteTensorHandleData::Unprotect() { - return errors::Unavailable("Unable to unprotect a remote handle."); +bool RemoteTensorHandleData::IsReady() const { + tf_shared_lock l(mu_); + return is_ready_; +} + +void RemoteTensorHandleData::Poison(Status status) { + mutex_lock l(mu_); + is_poisoned_ = status; +} + +Status RemoteTensorHandleData::IsPoisoned() const { + tf_shared_lock l(mu_); + return is_poisoned_; +} + +Status RemoteTensorHandleData::SetShape(const TensorShape& shape) { + mutex_lock l(mu_); + if (is_ready_) { + return errors::Internal("SetShape is only called on non-ready handles."); + } + + shape_ = shape; + is_poisoned_ = Status::OK(); + is_ready_ = true; + + return Status::OK(); } string RemoteTensorHandleData::DebugString() const { @@ -151,73 +189,20 @@ string RemoteTensorHandleData::DebugString() const { " output_num: ", output_num_); } -UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData( - int64 op_id, int32 output_num, const string& remote_task, EagerContext* ctx) - : op_id_(op_id), - output_num_(output_num), - delete_remote_tensor_(true), - remote_task_(remote_task), - context_id_(ctx->GetContextId()), - context_view_id_(ctx->GetContextViewId()), - ctx_(*ctx) { - DCHECK(op_id_ >= 0 && output_num_ >= 0) - << "Op ID and output num should be >= 0. Op ID: " << op_id - << ", Output num: " << output_num; - ctx_.Ref(); -} - -UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() { - if (delete_remote_tensor_) { - DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_, - output_num_, /*ready=*/false); +Status RemoteTensorHandleData::WaitReady(const char* caller) const { + if (ctx_ == nullptr) { + return errors::Internal("Cannot wait on lazy remote handle"); } - ctx_.Unref(); -} -Status UnshapedRemoteTensorHandleData::Tensor( - const tensorflow::Tensor** t) const { - return errors::Unavailable( - "Unable to get a tensor for a remote handle. Please copy the tensor " - "handle to a local device using TFE_TensorHandleCopyToDevice"); -} - -Status UnshapedRemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) { - return errors::Unavailable( - "Unable to get a tensor for a remote handle. Please copy the tensor " - "handle to a local device using TFE_TensorHandleCopyToDevice"); -} - -Status UnshapedRemoteTensorHandleData::Shape(TensorShape* shape) const { - return errors::Unavailable( - "Unable to get shape information for an async remote handle. 
Please wait " - "until it is ready"); -} - -Status UnshapedRemoteTensorHandleData::NumDims(int* num_dims) const { - return errors::Unavailable( - "Unable to get shape information for an async remote handle. Please wait " - "until it is ready"); -} - -Status UnshapedRemoteTensorHandleData::Dim(int dim_index, int64* dim) const { - return errors::Unavailable( - "Unable to get shape information for an async remote handle. Please wait " - "until it is ready"); -} - -Status UnshapedRemoteTensorHandleData::NumElements(int64* num_elements) const { - return errors::Unavailable( - "Unable to get shape information for an async remote handle. Please wait " - "until it is ready"); -} - -Status UnshapedRemoteTensorHandleData::Unprotect() { - return errors::Unavailable("Unable to unprotect a remote handle."); -} - -string UnshapedRemoteTensorHandleData::DebugString() const { - return strings::StrCat("UnshapedRemoteTensorHandleDat:", " op_id: ", op_id_, - " output_num: ", output_num_); + tf_shared_lock l(mu_); + if (!is_ready_) { + profiler::TraceMe activity( + [caller] { return absl::StrCat(caller, " WaitReady"); }, + profiler::TraceMeLevel::kInfo); + DVLOG(3) << "WaitReady: " << caller << " " << this; + mu_.Await(Condition(&is_ready_)); + } + return is_poisoned_; } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h index 34e3ec5f83d..6c3e060d934 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h @@ -15,97 +15,56 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_ -#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h" -#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { // Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only // the shape is known. -class RemoteTensorHandleData : public TensorHandleData { +class RemoteTensorHandleData { public: - RemoteTensorHandleData(int64 op_id, int output_num, const TensorShape& shape, - const string& remote_task, EagerContext* ctx); - ~RemoteTensorHandleData() override; + // Constructor for lazy remote handles + RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id); + // Constructor for unshaped remote handles + RemoteTensorHandleData(int64 op_id, int output_num, const string& remote_task, + EagerContext* ctx); + ~RemoteTensorHandleData(); // A remote tensor handle does not have a Tensor object, hence it can only // support the shape requests. 
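The two constructors declared above split remote handles into two flavors: lazy handles, built only from op_id/output_num/context_view_id, keep no EagerContext and so can neither block nor issue a destroy RPC, while unshaped handles ref the context and clean up their remote tensor on destruction. A rough standalone sketch of that ownership split, with hypothetical stand-in types rather than the TensorFlow classes:

// Hypothetical stand-ins (Ctx, RemoteDataSketch); illustrates only the
// ownership split, not the real RPC machinery.
#include <cstdio>

struct Ctx {
  void DestroyRemote(long op_id) { std::printf("destroy op %ld\n", op_id); }
};

class RemoteDataSketch {
 public:
  // ctx == nullptr models the "lazy" constructor: no ownership, no cleanup.
  RemoteDataSketch(long op_id, Ctx* ctx) : op_id_(op_id), ctx_(ctx) {}
  ~RemoteDataSketch() {
    if (ctx_ != nullptr) ctx_->DestroyRemote(op_id_);  // owned handle only
  }
  bool CanWait() const { return ctx_ != nullptr; }  // lazy handles cannot block

 private:
  long op_id_;
  Ctx* ctx_;
};

int main() {
  Ctx ctx;
  RemoteDataSketch owned(7, &ctx);    // issues "destroy op 7" at scope exit
  RemoteDataSketch lazy(8, nullptr);  // destructor is a no-op
}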
- Status Tensor(const tensorflow::Tensor** t) const override; - Status TensorValue(tensorflow::TensorValue* t) override; - Status Shape(TensorShape* shape) const override; - Status NumDims(int* num_dims) const override; - Status Dim(int dim_index, int64* dim) const override; - Status NumElements(int64* num_elements) const override; - Status Unprotect() override; - EagerContext& ctx() const { return ctx_; } + Status Shape(TensorShape* shape) const; + Status NumDims(int* num_dims) const; + Status Dim(int dim_index, int64* dim) const; + Status NumElements(int64* num_elements) const; - string DebugString() const override; + bool IsReady() const; + Status SetShape(const TensorShape& shape); + void Poison(Status status); + Status IsPoisoned() const; + + string DebugString() const; int64 op_id() const { return op_id_; } int32 output_num() const { return output_num_; } - uint64 context_id() const { return context_id_; } uint64 context_view_id() const { return context_view_id_; } private: + Status WaitReady(const char* caller) const; + + mutable mutex mu_; + bool is_ready_ GUARDED_BY(mu_); + Status is_poisoned_ GUARDED_BY(mu_); + TensorShape shape_ GUARDED_BY(mu_); + // IDs required when this class is representing a remote tensor handle. const int64 op_id_; const int32 output_num_; - const TensorShape shape_; string remote_task_; uint64 context_id_; uint64 context_view_id_; - EagerContext& ctx_; -}; - -// Async Remote Tensor Handle: A handle to a Tensor on a remote host. Once the -// shape has been computed this is replaced with a remote tensor handle. -class UnshapedRemoteTensorHandleData : public TensorHandleData { - public: - UnshapedRemoteTensorHandleData(int64 op_id, int32 output_num, - const string& remote_task, EagerContext* ctx); - ~UnshapedRemoteTensorHandleData() override; - - // Unshaped remote tensor handles are not ready and hence cannot satisfy any - // of these requests. - Status Tensor(const tensorflow::Tensor** t) const override; - Status TensorValue(tensorflow::TensorValue* t) override; - Status Shape(TensorShape* shape) const override; - Status NumDims(int* num_dims) const override; - Status Dim(int dim_index, int64* dim) const override; - Status NumElements(int64* num_elements) const override; - Status Unprotect() override; - - void Poison(Status status) { is_poisoned_ = status; } - Status IsPoisoned() const { return is_poisoned_; } - - string DebugString() const override; - - int64 op_id() const { return op_id_; } - int32 output_num() const { return output_num_; } - string remote_task() const { return remote_task_; } - uint64 context_id() const { return context_id_; } - uint64 context_view_id() const { return context_view_id_; } - EagerContext& ctx() const { return ctx_; } - - // When constructed, UnshapedRemoteTensorHandleData owns the remote - // TensorHandle and should delete it by issuing an RPC. Once the remote - // shape has been learned, the ownership is transferred to - // RemoteTensorHandleData. This method must be called to let `this` know - // that it no longer owns the remote handle. - // TODO(iga): Add a factory method here that will create a new - // RemoteTensorHandleData from this and transfer ownership in the process. - void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; } - - private: - Status is_poisoned_; - // IDs required when this class is representing a remote tensor handle. 
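SetShape(), declared above and implemented earlier in this file, is the one-shot transition from non-ready to ready: it stores the shape, clears any poison, and fails with an internal error if the handle was already ready. A simplified sketch of that state machine, using std::optional and bool in place of TensorShape and Status:

// Simplified one-shot shape slot; std::optional stands in for the
// mutex-guarded is_ready_/shape_ pair. Names are hypothetical.
#include <optional>
#include <utility>
#include <vector>

class ShapeSlotSketch {
 public:
  // Mirrors "SetShape is only called on non-ready handles": a second call
  // reports failure instead of overwriting the shape.
  bool SetShape(std::vector<long> dims) {
    if (shape_.has_value()) return false;
    shape_ = std::move(dims);
    return true;
  }

  bool IsReady() const { return shape_.has_value(); }

 private:
  std::optional<std::vector<long>> shape_;
};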
- const int64 op_id_; - const int32 output_num_; - bool delete_remote_tensor_; - string remote_task_; - uint64 context_id_; - uint64 context_view_id_; - EagerContext& ctx_; + EagerContext* ctx_; }; } // namespace tensorflow diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 281e9d61a4e..afa039e091b 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -91,11 +91,11 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) { Device* device = IsCPUDevice(call->device) ? nullptr : call->device; for (int64 i = 0; i < n; ++i) { PyObject* arg = nullptr; - const Tensor& t = call->ins[i]; if (call->eager) { TensorHandle* handle; + Tensor t = call->ins[i]; TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle( - t, ctx->CanonicalDevice(device), nullptr, ctx, &handle)); + std::move(t), ctx->CanonicalDevice(device), nullptr, ctx, &handle)); arg = EagerTensorFromHandle(new TFE_TensorHandle{ std::make_unique<tensorflow::TensorHandleInterface>(handle)}); if (arg == nullptr) { @@ -103,7 +103,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) { return errors::Internal("Unable to procure EagerTensor from Tensor."); } } else { - Status s = TensorToNdarray(t, &arg); + Status s = TensorToNdarray(call->ins[i], &arg); if (!s.ok()) { Py_DECREF(lst); return s; diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 75a7b840db9..e81102847ea 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -277,7 +277,8 @@ struct Converter { static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state, TFE_TensorHandle** h, const char** error) { - /* TODO(josh11b): Allocator & attributes? */ + // TODO(josh11b): Allocator & attributes + // TODO(gjn): Use optimized scalar constructors when possible. Tensor result(ConverterTraits<T>::kTypeEnum, TensorShape(state->inferred_shape)); if (state->inferred_shape.empty()) { /* Scalar case */ @@ -294,7 +295,7 @@ struct Converter { } tensorflow::TensorHandle* handle = nullptr; auto status = tensorflow::TensorHandle::CreateLocalHandle( - result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, + std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context, &handle); if (!status.ok()) { return status; @@ -610,8 +611,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) { auto cppstatus = tensorflow::NdarrayToTensor(obj, &t); if (cppstatus.ok()) { cppstatus = tensorflow::TensorHandle::CreateLocalHandle( - t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context, - &handle); + std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, + ctx->context, &handle); } if (!cppstatus.ok()) { PyErr_SetString(PyExc_ValueError, @@ -805,10 +806,10 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, case DT_INVALID: // Only occurs for empty tensors. { tensorflow::TensorHandle* h = nullptr; - Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype, - TensorShape(state.inferred_shape)); + Tensor t(requested_dtype == DT_INVALID ? 
DT_FLOAT : requested_dtype, + TensorShape(state.inferred_shape)); status = tensorflow::TensorHandle::CreateLocalHandle( - tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, + std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context, &h); if (!status.ok()) { PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
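The Python-binding call sites above all switch TensorHandle::CreateLocalHandle to consume its Tensor with std::move, matching what appears to be the new rvalue signature. A small sketch of the sink-style pattern this relies on; TensorSketch and MakeHandle are hypothetical, and since tensorflow::Tensor refcounts its buffer, the real saving is refcount churn rather than a deep copy:

// Hypothetical sink-style API: the callee takes ownership of the argument,
// so callers std::move instead of copying.
#include <utility>
#include <vector>

struct TensorSketch {
  std::vector<float> buf;
};

TensorSketch MakeHandle(TensorSketch&& t) { return std::move(t); }

int main() {
  TensorSketch t{std::vector<float>(1024)};
  TensorSketch h = MakeHandle(std::move(t));  // buffer moved, not copied
  (void)h;
}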