Roll-forward: Consolidate tensor handle data types

Fixed remote handles to be ready when they are poisoned

PiperOrigin-RevId: 298519389
Change-Id: Icd693a354639622705ff08d253acfbbd40013bc7
Gaurav Jain 2020-03-02 21:22:02 -08:00
parent 6c212dc330
commit ca879f6889
20 changed files with 519 additions and 779 deletions
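Note: the consolidation replaces TensorHandle's type-erased std::unique_ptr<TensorHandleData> member and its separate is_remote_/is_async_/is_ready_ flags with a single absl::variant<LocalTensorHandleData, RemoteTensorHandleData> named data_, so a handle is always in exactly one concrete state. The sketch below shows that pattern in isolation; it uses std::variant (interchangeable with absl::variant here) and invented types (LocalData, RemoteData, Handle), not the actual TensorFlow classes.

#include <cassert>
#include <string>
#include <utility>
#include <variant>

// Hypothetical stand-ins for LocalTensorHandleData / RemoteTensorHandleData.
struct LocalData {
  bool IsReady() const { return true; }
  std::string DebugString() const { return "local"; }
};

struct RemoteData {
  bool ready = false;
  bool IsReady() const { return ready; }
  std::string DebugString() const { return "remote"; }
};

class Handle {
 public:
  explicit Handle(LocalData d) : data_(std::move(d)) {}
  explicit Handle(RemoteData d) : data_(std::move(d)) {}

  // No separate is_remote_ flag: the active variant alternative encodes it.
  bool IsRemote() const { return data_.index() == 1; }

  // Dispatch to whichever data type is active, as the commit does with
  // absl::visit over data_.
  bool IsReady() const {
    return std::visit([](const auto& d) { return d.IsReady(); }, data_);
  }

 private:
  std::variant<LocalData, RemoteData> data_;
};

int main() {
  Handle local{LocalData{}};
  Handle remote{RemoteData{}};
  assert(local.IsReady() && !local.IsRemote());
  assert(remote.IsRemote() && !remote.IsReady());
}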

View File

@@ -1213,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
   tensorflow::TensorHandle* ret_handle;
   if (custom_device == nullptr) {
     status->status = tensorflow::TensorHandle::CreateLocalHandle(
-        t, device, context, &ret_handle);
+        std::move(t), device, device, context, &ret_handle);
   } else {
     status->status = tensorflow::TensorHandle::CreateLocalHandle(
-        t, custom_device, context, &ret_handle);
+        std::move(t), custom_device, context, &ret_handle);
   }
   if (!status->status.ok()) {
     return nullptr;

View File

@@ -140,6 +140,7 @@ tf_cuda_library(
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//conditions:default": [
+            "@com_google_absl//absl/types:variant",
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
            "//tensorflow/core/profiler/lib:traceme",

View File

@@ -564,8 +564,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
       }
     }
   }
-  const DataTypeVector& output_dtypes = kernel->output_dtypes();
-  const size_t num_outputs = static_cast<int>(output_dtypes.size());
+  int num_outputs = kernel->num_outputs();
   if (num_outputs > *num_retvals) {
     return errors::InvalidArgument("Expecting ", num_outputs,
                                    " outputs, but *num_retvals is ",
@@ -579,21 +578,19 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     graph_collector = ctx.GetGraphCollector();
   }

-  const bool async = executor.Async();
-  for (int i = 0; i < num_outputs; ++i) {
-    TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-        async,
-        /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
-        /* op_device= */ kernel->device(),
-        /* resource_device= */ kernel->OutputResourceDevice(i),
-        output_dtypes[i], &ctx, &retvals[i]));
-  }
-
   Status s;
-  if (async) {
+  if (executor.Async()) {
+    const DataTypeVector& output_dtypes = kernel->output_dtypes();
+    for (int i = 0; i < num_outputs; ++i) {
+      TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
+          /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
+          /* op_device= */ kernel->device(),
+          /* resource_device= */ kernel->OutputResourceDevice(i),
+          output_dtypes[i], &ctx, &retvals[i]));
+    }
     auto node = absl::make_unique<AsyncExecuteNode>(
         &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
-        graph_collector, output_dtypes, op->GetCancellationManager(),
+        graph_collector, op->GetCancellationManager(),
         absl::Span<TensorHandle*>(retvals, num_outputs));
     // For async mode, execution order will make sure that all
     // input handles are ready before executing them.
@@ -601,16 +598,21 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     // performance.
     s = executor.AddOrExecute(std::move(node));
   } else {
+    for (int i = 0; i < num_outputs; ++i) {
+      retvals[i] = nullptr;
+    }
     ExecuteNode node(&ctx, op->Inputs(), op->remote_func_params(), kernel,
-                     graph_collector, output_dtypes,
-                     op->GetCancellationManager(), {retvals, num_outputs});
+                     graph_collector, op->GetCancellationManager(),
+                     {retvals, static_cast<size_t>(num_outputs)});
     s = executor.SyncExecute(&node);
   }

-  // Since the operation failed, we need to Unref any outputs that were
+  // Since the operation failed, we need to Unref any outputs if they were
   // allocated.
   if (!s.ok()) {
     for (int i = 0; i < num_outputs; ++i) {
-      retvals[i]->Unref();
+      if (retvals[i] != nullptr) {
+        retvals[i]->Unref();
+      }
     }
   }
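Note: in the synchronous branch above, retvals are now pre-set to nullptr and the handles are materialized by EagerKernelExecute only after the kernel has produced its outputs (see the @@ -1032 hunk below), which is why the failure path can skip Unref for outputs that were never allocated. A rough sketch of that deferred-allocation pattern, with hypothetical Tensor/Handle types rather than the real ones:

#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a produced tensor and a ref-counted output handle.
struct Tensor {};
struct Handle {
  explicit Handle(Tensor t) : tensor(std::move(t)) {}
  Tensor tensor;
};

// Outputs start as nullptr and are created directly from the tensors the
// kernel produced, instead of pre-allocating empty handles and filling them
// with SetTensor afterwards.
std::vector<std::unique_ptr<Handle>> Run(std::vector<Tensor> outputs) {
  std::vector<std::unique_ptr<Handle>> retvals(outputs.size());  // all nullptr
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (retvals[i] == nullptr) {
      retvals[i] = std::make_unique<Handle>(std::move(outputs[i]));
    }
  }
  return retvals;
}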
@@ -733,12 +735,9 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
             input, input_handle, input_device, *input_device_name,
             serialize_resource_dtype_and_shape));
         if (!input_handle->resource_dtypes_and_shapes().empty()) {
-          auto tensor_handle_data =
-              absl::make_unique<UnshapedRemoteTensorHandleData>(
-                  input_handle->op_id(), input_handle->output_num(),
-                  remote_task, &ctx);
-          TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
-              std::move(tensor_handle_data), op_device));
+          TF_RETURN_IF_ERROR(
+              input->AddResourceShapeMirror(op_device, input_handle->op_id(),
+                                            input_handle->output_num(), &ctx));
         }
       }
     }
@@ -1032,13 +1031,24 @@ Status EagerKernelExecute(
     }
   }
   DCHECK_EQ(retvals.size(), outputs.size());
-  for (int i = 0; i < retvals.size(); ++i) {
-    DCHECK_EQ(kernel->device(), retvals[i]->op_device());
-    DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
-              absl::get<Device*>(retvals[i]->device()));
-    TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
-        std::move(outputs[i]), ctx->CanonicalDevice(kernel->OutputDevice(i))));
+  for (int i = 0; i < retvals.size(); ++i) {
+    if (retvals[i] == nullptr) {
+      TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
+          std::move(outputs[i]),
+          /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
+          /* op_device= */ kernel->device(),
+          /* resource_device= */ kernel->OutputResourceDevice(i), ctx,
+          &retvals[i]));
+    } else {
+      DCHECK_EQ(kernel->device(), retvals[i]->op_device());
+      DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
+                absl::get<Device*>(retvals[i]->device()));
+      TF_RETURN_IF_ERROR(
+          retvals[i]->SetTensor(std::move(outputs[i]),
+                                ctx->CanonicalDevice(kernel->OutputDevice(i))));
+    }
   }
   return Status::OK();
 }
@@ -1069,7 +1079,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     *result = h;
   } else {
     TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-        true, d, dstd, h->resource_device(), h->dtype, ctx, result));
+        d, dstd, h->resource_device(), h->dtype, ctx, result));
   }

   Status s;
@@ -1138,7 +1148,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
       *result = h;
     } else {
       TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-          true, /* d= */ d, /* op_device= */ device,
+          /* d= */ d, /* op_device= */ device,
           /*resource_device=*/nullptr, h->dtype, ctx, result));
     }
   } else {
@@ -1156,17 +1166,14 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                                    device->name());
     }
     recv_op_id = ctx->RemoteMgr()->NextOpId();
-    auto tensor_handle_data =
-        absl::make_unique<UnshapedRemoteTensorHandleData>(recv_op_id, 0,
-                                                          remote_task, ctx);
     if (mirror) {
-      TF_RETURN_IF_ERROR(
-          h->AddUnshapedRemoteMirror(std::move(tensor_handle_data), device));
+      TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
+                                                    remote_task, ctx));
       h->Ref();
       *result = h;
     } else {
       TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
-          std::move(tensor_handle_data), h->dtype, device, ctx, result));
+          recv_op_id, 0, remote_task, h->dtype, device, ctx, result));
     }
   }

View File

@@ -32,7 +32,7 @@ Status ExecuteNodeArgs::Init(
   for (int i = 0; i < n_inputs; ++i) {
     TensorHandle* in = op_inputs_flat[i];
     Device* d = kernel->InputDevice(i);
-    Status s = in->TensorValue(&tensor_args_flat[i], ctx->CanonicalDevice(d));
+    Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
    if (!s.ok()) {
 #if !defined(IS_MOBILE_PLATFORM)
       uint64 context_view_id = ctx->GetContextViewId();

View File

@@ -77,7 +77,7 @@ class ExecuteNode : public EagerNode {
       EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
       const core::RefCountPtr<KernelAndDevice>& kernel,
-      GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
+      GraphCollector* graph_collector,
       CancellationManager* cancellation_manager,
       absl::Span<TensorHandle*> retvals)
       : EagerNode(),
@@ -130,7 +130,7 @@ class AsyncExecuteNode : public EagerNode {
       EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
       core::RefCountPtr<KernelAndDevice> kernel,
-      GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
+      GraphCollector* graph_collector,
       CancellationManager* cancellation_manager,
       absl::Span<TensorHandle*> retvals)
       : EagerNode(),

View File

@@ -67,6 +67,7 @@ class EagerKernelArgs : public FunctionArgsInterface {
   ~EagerKernelArgs() override{};

   bool HasRemoteInputs() const override { return false; };
+  TensorValue* MutableInput(int i) { return &tensor_args_[i]; }

   Status GetLocalArg(const int index, Tensor* val) const override;

View File

@@ -20,39 +20,31 @@ limitations under the License.
 #include <memory>
 #include <queue>
 #include <string>
+#include <utility>
 #include <vector>

+#include "absl/strings/substitute.h"
 #include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
 #include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
-#include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/errors.h"
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
 #endif  // IS_MOBILE_PLATFORM
-#include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/resource_var.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"

 namespace tensorflow {
@@ -64,7 +56,7 @@ const int32 kInvalidOutputNum = -1;
 }  // namespace

 void TensorHandle::SetResourceHandleDtypeAndShape(
-    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
+    std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes) {
   handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
 }
@@ -86,250 +78,191 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
   profiler::TraceMe activity(
       "TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
       profiler::TraceMeLevel::kInfo);
+  auto& data = absl::get<LocalTensorHandleData>(data_);
   TF_RETURN_IF_ERROR(
-      WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
+      data.WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));

   *result = handle_dtypes_and_shapes_;
   return Status::OK();
 }

-Status TensorHandle::CreateLocalHandle(const class Tensor& t,
+Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
                                        TensorHandle** h) {
   // TODO(b/136608821): Move away from nullptr
-  return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr),
+  tensorflow::Tensor tensor = t;
+  return CreateLocalHandle(std::move(tensor),
+                           /*d=*/nullptr,
                            /*op_device=*/nullptr,
                            /*ctx=*/nullptr, h);
 }

-Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
-                                       EagerContext* ctx, TensorHandle** h) {
-  return CreateLocalHandle(t, d, d, ctx, h);
-}
-
-Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
+Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                        Device* op_device, EagerContext* ctx,
                                        TensorHandle** h) {
-  if (t.dtype() != DT_RESOURCE) {
-    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
-                          t.dtype(), d, op_device, ctx);
+  return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h);
+}
+
+Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
+                                       Device* op_device,
+                                       Device* resource_device,
+                                       EagerContext* ctx, TensorHandle** h) {
+  if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
+    *h = new TensorHandle(std::move(t), d, op_device, ctx);
   } else {
-    const ResourceHandle& resource_handle = t.flat<class ResourceHandle>()(0);
-    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
-                          resource_handle, d, op_device, ctx);
+    *h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
   }

   return Status::OK();
 }

-Status TensorHandle::CreateLocalHandle(const class Tensor& t, CustomDevice* d,
+Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
                                        EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t), t.dtype(),
-                        d, ctx);
+  *h = new TensorHandle(std::move(t), d, ctx);

   return Status::OK();
 }

-TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           DataType dtype, Device* d, Device* op_device,
-                           EagerContext* ctx)
-    : dtype(dtype),
+TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+                           Device* resource_device, EagerContext* ctx)
+    : dtype(t.dtype()),
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
-      resource_device_(nullptr),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
+      resource_device_(resource_device),
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      tensor_handle_data_(std::move(t)) {
+      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " device: " << VariantDeviceDebugString(device_);
+           << " device: " << VariantDeviceDebugString(device_)
+           << " tensor: " << t.DeviceSafeDebugString();
 }

-TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           const ResourceHandle& resource_handle, Device* d,
-                           Device* op_device, EagerContext* ctx)
+TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+                           EagerContext* ctx)
     : dtype(DT_RESOURCE),
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
-      resource_device_(GetResourceDevice(resource_handle, ctx)),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
+      resource_device_(
+          GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
-      tensor_handle_data_(std::move(t)) {
+      handle_dtypes_and_shapes_(
+          t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
+      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " device: " << VariantDeviceDebugString(device_);
+           << " device: " << VariantDeviceDebugString(device_)
+           << " tensor: " << t.DeviceSafeDebugString();
 }

-TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           DataType dtype, CustomDevice* d, EagerContext* ctx)
-    : dtype(dtype),
+TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
+                           EagerContext* ctx)
+    : dtype(t.dtype()),
       device_(d),
       op_device_(nullptr),
       resource_device_(nullptr),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      tensor_handle_data_(std::move(t)) {
+      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   // TODO(allenl): Figure out a better op_device story for custom devices,
   // since always setting it to CPU=nullptr doesn't make much sense.
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " custom device: " << VariantDeviceDebugString(device_);
+           << " custom device: " << VariantDeviceDebugString(device_)
+           << " tensor: " << t.DeviceSafeDebugString();
 }

-Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
-                                            Device* op_device,
+Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
                                             Device* resource_device,
                                             DataType dtype, EagerContext* ctx,
                                             TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<EmptyLocalTensorHandleData>(), async,
-                        d, op_device, resource_device, dtype, ctx);
+  *h = new TensorHandle(d, op_device, resource_device, dtype, ctx);

   return Status::OK();
 }

-TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
-                           bool async, Device* d, Device* op_device,
+TensorHandle::TensorHandle(Device* d, Device* op_device,
                            Device* resource_device, DataType dtype,
                            EagerContext* ctx)
     : dtype(dtype),
       device_((d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
       resource_device_(resource_device),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(async),
       implicit_mirroring_(true),
-      is_ready_(!async),
-      tensor_handle_data_(std::move(t)) {
+      data_(absl::in_place_type<LocalTensorHandleData>) {
   DVLOG(3) << "Creating empty Local TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }

 #if !defined(IS_MOBILE_PLATFORM)
-Status TensorHandle::CreateRemoteHandle(
-    std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d,
-    Device* resource_device, EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx);
-
-  return Status::OK();
-}
-
-Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
-                                        const TensorShape& shape,
-                                        const string& remote_task,
-                                        DataType dtype, Device* d,
-                                        Device* resource_device,
-                                        EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<RemoteTensorHandleData>(
-                            op_id, output_num, shape, remote_task, ctx),
-                        dtype, d, resource_device, ctx);
+Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
+                                                const string& remote_task,
+                                                DataType dtype, Device* d,
+                                                EagerContext* ctx,
+                                                TensorHandle** h) {
+  *h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);

   return Status::OK();
 }

-TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
-                           DataType dtype, Device* d, Device* resource_device,
+TensorHandle::TensorHandle(int64 op_id, int32 output_num,
+                           const string& remote_task, DataType dtype, Device* d,
                            EagerContext* ctx)
     : dtype(dtype),
       device_(d),
       op_device_(d),
-      resource_device_(resource_device),
-      remote_op_id_(t->op_id()),
-      remote_output_num_(t->output_num()),
+      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
       ctx_(ctx),
-      is_remote_(true),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      tensor_handle_data_(std::move(t)) {
-  DVLOG(3) << "Creating Remote TensorHandle: " << this
+      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
+            remote_task, ctx) {
+  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }

-Status TensorHandle::CreateUnshapedRemoteHandle(
-    std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
-    Device* d, EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(std::move(t), dtype, d, ctx);
+Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num,
+                                            DataType dtype, Device* d,
+                                            EagerContext* ctx,
+                                            TensorHandle** h) {
+  *h = new TensorHandle(op_id, output_num, dtype, d, ctx);

   return Status::OK();
 }

-Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
-                                                const string& remote_task,
-                                                DataType dtype, Device* device,
-                                                EagerContext* ctx,
-                                                TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>(
-                            op_id, output_num, remote_task, ctx),
-                        dtype, device, ctx);
-  return Status::OK();
-}
-
-TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
-                           DataType dtype, Device* device, EagerContext* ctx)
+TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype,
+                           Device* d, EagerContext* ctx)
     : dtype(dtype),
-      device_(device),
-      op_device_(device),
-      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
-      remote_op_id_(t->op_id()),
-      remote_output_num_(t->output_num()),
+      device_(d),
+      op_device_(d),
+      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
       ctx_(ctx),
-      is_remote_(true),
-      is_async_(true),
       implicit_mirroring_(true),
-      is_ready_(false),
-      tensor_handle_data_(std::move(t)) {
-  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
+      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
+            ctx->GetContextViewId()) {
+  DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
 #endif

 bool TensorHandle::IsReady() const {
-  // Avoid mutex acquisition for local sync handles
-  if (!is_async_ && !is_remote_) {
-    return true;
-  }
-
-  tf_shared_lock l(mu_);
-  return is_ready_;
+  return absl::visit([](auto& data) { return data.IsReady(); }, data_);
 }

-Status TensorHandle::WaitReady(const char* caller) const {
-  if (!IsReady()) {
-    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
-                               profiler::TraceMeLevel::kInfo);
-    tf_shared_lock l(mu_);
-    mu_.Await(Condition(&is_ready_));
-  }
-  return is_poisoned_;
+bool TensorHandle::IsRemote() const {
+#if !defined(IS_MOBILE_PLATFORM)
+  return data_.index() == 1;
+#else
+  return false;
+#endif
 }

 Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
   DVLOG(3) << "Tensor on TensorHandle: " << this;

-  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor"));
-  return tensor_handle_data_->Tensor(t);
+  if (IsRemote()) {
+    return errors::Internal("Invalid Tensor call on remote handle: ", this);
+  }
+
+  auto& data = absl::get<LocalTensorHandleData>(data_);
+  return data.Tensor(t);
 }

 Status TensorHandle::TensorFromDevice(const Device* d,
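Note: with the variant in place, ready-state queries dispatch through absl::visit and the constructors above build the active alternative in place with absl::in_place_type instead of moving a unique_ptr in. A self-contained illustration using the std::variant equivalents (std::in_place_type, std::get); all names here are invented for the sketch:

#include <string>
#include <utility>
#include <variant>

struct LocalData {
  explicit LocalData(std::string tensor) : tensor(std::move(tensor)) {}
  std::string tensor;
};

struct RemoteData {
  RemoteData(long op_id, int output_num)
      : op_id(op_id), output_num(output_num) {}
  long op_id;
  int output_num;
};

struct Handle {
  // Construct the chosen alternative directly inside the variant, the way the
  // TensorHandle constructors use absl::in_place_type in their init lists.
  explicit Handle(std::string t)
      : data(std::in_place_type<LocalData>, std::move(t)) {}
  Handle(long op_id, int output_num)
      : data(std::in_place_type<RemoteData>, op_id, output_num) {}

  std::variant<LocalData, RemoteData> data;
};

int main() {
  Handle local("tensor-bytes");
  Handle remote(42L, 0);
  // get<T> asserts the expected alternative, like absl::get in the diff.
  return std::get<RemoteData>(remote.data).output_num;
}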
@@ -337,12 +270,12 @@ Status TensorHandle::TensorFromDevice(const Device* d,
   DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;

   if (d == absl::get<Device*>(device_)) {
-    if (is_remote_) {
+    if (IsRemote()) {
       return errors::Internal("Invalid Tensor call on remote handle: ", this);
     }

-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice"));
-    return tensor_handle_data_->Tensor(t);
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.Tensor(t);
   }

   tf_shared_lock l(mu_);
@@ -352,25 +285,21 @@ Status TensorHandle::TensorFromDevice(const Device* d,
                            " in Tensor call to handle: ", this);
   }

-  // Check if the handle is non-empty, else wait.
   auto& mirror = elem->second;
-  if (mirror.second == nullptr) {
-    TF_RETURN_IF_ERROR(
-        mirror.first->WaitReady("TensorHandle::TensorFromDevice"));
-  }
-
-  return mirror.second->Tensor(t);
+  return mirror.Tensor(t);
 }

-Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
+Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
+  DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;
+
   if (d == absl::get<Device*>(device_)) {
-    if (is_remote_) {
+    if (IsRemote()) {
       return errors::Internal("Invalid TensorValue call on remote handle: ",
                               this);
     }

-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
-    return tensor_handle_data_->TensorValue(t);
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.TensorValue(t);
   }

   tf_shared_lock l(mu_);
@@ -380,13 +309,8 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
                            " in TensorValue call to handle: ", this);
   }

-  // Check if the handle is non-empty, else wait.
   auto& mirror = elem->second;
-  if (mirror.second == nullptr) {
-    TF_RETURN_IF_ERROR(mirror.first->WaitReady("TensorHandle::TensorValue"));
-  }
-
-  return mirror.second->TensorValue(t);
+  return mirror.TensorValue(t);
 }

 TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
@@ -405,8 +329,8 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
     DCHECK(fill);
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Shape"));
-    return tensor_handle_data_->Shape(shape);
+    return absl::visit([shape](auto& data) { return data.Shape(shape); },
+                       data_);
   }
 }
@@ -480,8 +404,8 @@ Status TensorHandle::NumDims(int* num_dims) const {
     *num_dims = inference_shape_.dims();
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumDims"));
-    return tensor_handle_data_->NumDims(num_dims);
+    return absl::visit(
+        [num_dims](auto& data) { return data.NumDims(num_dims); }, data_);
   }
 }
@@ -492,8 +416,9 @@ Status TensorHandle::Dim(int dim_index, int64* dim) const {
     *dim = inference_shape_.dim_size(dim_index);
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Dim"));
-    return tensor_handle_data_->Dim(dim_index, dim);
+    return absl::visit(
+        [dim_index, dim](auto& data) { return data.Dim(dim_index, dim); },
+        data_);
   }
 }
@@ -503,8 +428,9 @@ Status TensorHandle::NumElements(int64* num_elements) const {
     *num_elements = inference_shape_.num_elements();
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumElements"));
-    return tensor_handle_data_->NumElements(num_elements);
+    return absl::visit(
+        [num_elements](auto& data) { return data.NumElements(num_elements); },
+        data_);
   }
 }
@@ -512,7 +438,8 @@ Status TensorHandle::Unprotect(const Device* d) {
   DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;

   if (d == absl::get<Device*>(device_)) {
-    return tensor_handle_data_->Unprotect();
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.Unprotect();
   }

   tf_shared_lock l(mu_);
@@ -524,11 +451,7 @@ Status TensorHandle::Unprotect(const Device* d) {
   // Check if the handle is non-empty
   auto& mirror = elem->second;
-  if (mirror.second == nullptr) {
-    return errors::Internal("Attempted to unprotect an empty mirror");
-  }
-
-  return mirror.second->Unprotect();
+  return mirror.Unprotect();
 }

 bool TensorHandle::HasLocalMirror(const Device* d) const {
@@ -551,8 +474,8 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
     return errors::Internal("Attempted to duplicate a local mirror.");
   }

-  local_mirrors_[d] =
-      std::make_pair(std::make_unique<EmptyLocalTensorHandleData>(), nullptr);
+  local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
+                         std::forward_as_tuple());

   return Status::OK();
 }
@@ -567,15 +490,8 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
     tf_shared_lock l(mu_);
     auto mirror = remote_mirrors_.find(d->name());
     if (mirror != remote_mirrors_.end()) {
-      *op_id = mirror->second->op_id();
-      *output_num = mirror->second->output_num();
-      return Status::OK();
-    }
-
-    auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
-    if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
-      *op_id = unshaped_mirror->second->op_id();
-      *output_num = unshaped_mirror->second->output_num();
+      *op_id = mirror->second.op_id();
+      *output_num = mirror->second.output_num();
       return Status::OK();
     }
@@ -583,14 +499,14 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
         "Could not find remote mirror for specified device");
   }

-  if (remote_op_id_ == kInvalidOpId ||
-      remote_output_num_ == kInvalidOutputNum) {
-    return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_,
-                                   ", output_num:", remote_output_num_,
-                                   ") is not set.");
+  if (!IsRemote()) {
+    return errors::InvalidArgument("Primary device is not remote");
   }

-  *op_id = remote_op_id_;
-  *output_num = remote_output_num_;
+  auto& data = absl::get<RemoteTensorHandleData>(data_);
+  *op_id = data.op_id();
+  *output_num = data.output_num();

   return Status::OK();
 }
@@ -603,16 +519,7 @@ bool TensorHandle::HasRemoteMirror(const Device* d,
   auto mirror = remote_mirrors_.find(d->name());
   if (mirror != remote_mirrors_.end()) {
     // Check if mirror is stale
-    if (mirror->second->context_view_id() != context_view_id) {
-      return false;
-    }
-    return true;
-  }
-
-  auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
-  if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
-    // Check if mirror is stale
-    if (unshaped_mirror->second->context_view_id() != context_view_id) {
+    if (mirror->second.context_view_id() != context_view_id) {
       return false;
     }
     return true;
@@ -630,7 +537,7 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
   auto mirror = resource_shape_mirrors_.find(d->name());
   if (mirror != resource_shape_mirrors_.end()) {
     // Check if mirror is stale
-    if (mirror->second->context_view_id() != context_view_id) {
+    if (mirror->second.context_view_id() != context_view_id) {
       return false;
     }
     return true;
@@ -638,45 +545,39 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
   return false;
 }

-Status TensorHandle::AddUnshapedRemoteMirror(
-    std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
+Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64 op_id,
+                                             int output_num,
+                                             const string& remote_task,
+                                             EagerContext* ctx) {
   DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this
-           << " device: " << d << " " << d->name();
+           << " device: " << d << " " << d->name() << " op_id: " << op_id
+           << " output_num: " << output_num;

   mutex_lock l(mu_);
   auto remote_mirror = remote_mirrors_.find(d->name());
   if (remote_mirror != remote_mirrors_.end()) {
-    if (remote_mirror->second->context_view_id() == t->context_view_id()) {
+    if (remote_mirror->second.context_view_id() == ctx->GetContextId()) {
       return errors::Internal("Attempted to duplicate a remote mirror.");
     }
     // Remove stale mirror
     remote_mirrors_.erase(remote_mirror);
   }

-  auto unshaped_remote_mirror = unshaped_remote_mirrors_.find(d->name());
-  if (unshaped_remote_mirror != unshaped_remote_mirrors_.end()) {
-    if (unshaped_remote_mirror->second->context_view_id() ==
-        t->context_view_id()) {
-      return errors::Internal(
-          "Attempted to duplicate an unshaped remote mirror.");
-    }
-    // Remove stale mirror
-    unshaped_remote_mirrors_.erase(unshaped_remote_mirror);
-  }
-
-  unshaped_remote_mirrors_[d->name()] = std::move(t);
+  remote_mirrors_.emplace(
+      std::piecewise_construct, std::forward_as_tuple(d->name()),
+      std::forward_as_tuple(op_id, output_num, remote_task, ctx));

   return Status::OK();
 }

-Status TensorHandle::AddResourceShapeMirror(
-    std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
+Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
+                                            int output_num, EagerContext* ctx) {
   DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this;

   mutex_lock l(mu_);
   auto mirror = resource_shape_mirrors_.find(d->name());
   if (mirror != resource_shape_mirrors_.end()) {
-    if (mirror->second->context_view_id() == t->context_view_id()) {
+    if (mirror->second.context_view_id() == ctx->GetContextViewId()) {
       return errors::Internal(
           "Attempted to duplicate a resource shape mirror.");
     }
@@ -684,26 +585,9 @@ Status TensorHandle::AddResourceShapeMirror(
     resource_shape_mirrors_.erase(mirror);
   }

-  resource_shape_mirrors_[d->name()] = std::move(t);
-
-  return Status::OK();
-}
-
-Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
-                                     const Device* d) {
-  DVLOG(3) << "AddRemoteMirror on TensorHandle: " << this << " device: " << d;
-
-  mutex_lock l(mu_);
-  auto mirror = remote_mirrors_.find(d->name());
-  if (mirror != remote_mirrors_.end()) {
-    if (mirror->second->context_view_id() == t->context_view_id()) {
-      return errors::Internal("Attempted to duplicate a remote mirror.");
-    }
-    // Remove stale mirror
-    remote_mirrors_.erase(mirror);
-  }
-
-  remote_mirrors_[d->name()] = std::move(t);
+  resource_shape_mirrors_.emplace(
+      std::piecewise_construct, std::forward_as_tuple(d->name()),
+      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId()));

   return Status::OK();
 }
@@ -717,53 +601,24 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
     mutex_lock l(mu_);
     auto remote_mirror = remote_mirrors_.find(d->name());
     if (remote_mirror != remote_mirrors_.end()) {
-      if (remote_mirror->second->context_view_id() == context_view_id) {
-        return errors::Internal(
-            "Attempted to set remote shape for existing mirror.");
+      auto& mirror = remote_mirror->second;
+      if (mirror.context_view_id() == context_view_id) {
+        return mirror.SetShape(shape);
       }
       remote_mirrors_.erase(remote_mirror);
     }

-    auto elem = unshaped_remote_mirrors_.find(d->name());
-    if (elem == unshaped_remote_mirrors_.end()) {
-      return errors::Internal(
-          "Attempted to set remote shape for non-waiting mirror.");
-    }
-
-    if (elem->second->context_view_id() != context_view_id) {
-      unshaped_remote_mirrors_.erase(elem);
-      return errors::Internal(
-          "Attempted to set remote shape for a stale mirror.");
-    }
-
-    auto& data = elem->second;
-    data->ReleaseRemoteTensorHandle();
-    remote_mirrors_[d->name()] = absl::make_unique<RemoteTensorHandleData>(
-        data->op_id(), data->output_num(), shape, data->remote_task(),
-        &data->ctx());
-    unshaped_remote_mirrors_.erase(elem);
-
     return Status::OK();
   }

-  DCHECK(is_remote_) << "SetRemoteShape is only called on remote handles.";
-  DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";
+  DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";

-  UnshapedRemoteTensorHandleData* p =
-      reinterpret_cast<UnshapedRemoteTensorHandleData*>(
-          tensor_handle_data_.get());
-  if (p->context_view_id() != context_view_id) {
+  auto& data = absl::get<RemoteTensorHandleData>(data_);
+  if (data.context_view_id() != context_view_id) {
     return errors::Internal("Attempted to set remote shape for an old handle.");
   }

-  p->ReleaseRemoteTensorHandle();
-  tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>(
-      remote_op_id_, remote_output_num_, shape, p->remote_task(), ctx_);
-  is_poisoned_ = Status::OK();
-  mutex_lock l(mu_);
-  is_ready_ = true;
-
-  return Status::OK();
+  return data.SetShape(shape);
 }

 void TensorHandle::PoisonRemote(Status status, const Device* d,
@@ -772,18 +627,16 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
            << " " << d->name();

   if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
-    DCHECK(!is_async_ || !IsReady())
-        << "PoisonRemote can only be called on non-ready handle: " << this;
+    DCHECK(IsRemote()) << "Poison can only be on remote handles: " << this;

-    is_poisoned_ = status;
-    mutex_lock l(mu_);
-    is_ready_ = true;
+    auto& data = absl::get<RemoteTensorHandleData>(data_);
+    data.Poison(status);
   } else {
     tf_shared_lock l(mu_);
-    auto mirror = unshaped_remote_mirrors_.find(d->name());
-    if (mirror != unshaped_remote_mirrors_.end()) {
-      if (mirror->second->context_view_id() == context_view_id) {
-        mirror->second->Poison(status);
+    auto mirror = remote_mirrors_.find(d->name());
+    if (mirror != remote_mirrors_.end()) {
+      if (mirror->second.context_view_id() == context_view_id) {
+        mirror->second.Poison(status);
       }
     }
   }
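Note: this PoisonRemote hunk is the fix called out in the commit message ("remote handles to be ready when they are poisoned"): poisoning now delegates to RemoteTensorHandleData::Poison, which under the new design records the error status and makes the data count as ready, so waiters observe the failure instead of blocking. A sketch of that ready-or-poisoned contract, with hypothetical names standing in for the TensorFlow types:

#include <condition_variable>
#include <mutex>
#include <string>
#include <utility>

// Hypothetical status type standing in for tensorflow::Status.
struct Status {
  bool ok = true;
  std::string message;
};

// Handle data whose WaitReady() returns once a value is set OR the data is
// poisoned; a poisoned handle is "ready" with a bad status.
class WaitableData {
 public:
  void SetReady() {
    std::lock_guard<std::mutex> l(mu_);
    ready_ = true;
    cv_.notify_all();
  }
  void Poison(Status s) {
    std::lock_guard<std::mutex> l(mu_);
    status_ = std::move(s);
    ready_ = true;  // the fix: poisoned handles count as ready
    cv_.notify_all();
  }
  Status WaitReady() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return ready_; });
    return status_;  // OK if set normally, the poison status otherwise
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool ready_ = false;
  Status status_;
};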
@@ -798,9 +651,9 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
   }

   mutex_lock l(mu_);
-  auto elem = local_mirrors_.insert(std::make_pair(
-      d, std::make_pair(nullptr,
-                        std::make_unique<LocalTensorHandleData>(tensor))));
+  auto elem =
+      local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
+                             std::forward_as_tuple(std::move(tensor)));
   if (!elem.second) {
     return errors::Internal("Attempted to set tensor for existing mirror.");
   }
@@ -808,24 +661,18 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
   return Status::OK();
 }

-Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
+Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
   DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;

   if (d == absl::get<Device*>(device_)) {
-    DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
-    DCHECK(!is_async_ || !IsReady())
-        << "SetTensor is only called on non-ready handles.";
+    DCHECK(!IsRemote()) << "SetTensor is not called on remote handles.";

-    if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
-      auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
+    if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
+      auto& resource_handle = t.flat<class ResourceHandle>()(0);
       handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
     }

-    tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
-    if (is_async_) {
-      is_poisoned_ = Status::OK();
-      mutex_lock l(mu_);
-      is_ready_ = true;
-    }
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.SetTensor(std::move(t));
   } else {
     tf_shared_lock l(mu_);
     auto elem = local_mirrors_.find(d);
@@ -835,12 +682,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
     }

     auto& mirror = elem->second;
-    if (mirror.second != nullptr) {
-      return errors::Internal("Attempted to set tensor for existing mirror.");
-    }
-
-    mirror.second = absl::make_unique<LocalTensorHandleData>(tensor);
-    mirror.first->SetReady();
+    return mirror.SetTensor(std::move(t));
   }

   return Status::OK();
@@ -850,12 +692,10 @@ void TensorHandle::Poison(Status status, const Device* d) {
   DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;

   if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
-    DCHECK(!is_async_ || !IsReady())
-        << "Poison can only be called on non-ready handle: " << this;
+    DCHECK(!IsRemote()) << "Poison can only be on local handles: " << this;

-    is_poisoned_ = status;
-    mutex_lock l(mu_);
-    is_ready_ = true;
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    data.Poison(status);
   } else {
     tf_shared_lock l(mu_);
     auto elem = local_mirrors_.find(d);
@@ -864,9 +704,7 @@ void TensorHandle::Poison(Status status, const Device* d) {
                                            << " device: " << d;

     auto& mirror = elem->second;
-    DCHECK(mirror.second == nullptr) << "Attempted to poison existing mirror.";
-
-    mirror.first->Poison(status);
+    mirror.Poison(status);
   }
 }
@@ -977,8 +815,11 @@ string TensorHandle::DebugString() const {
       !VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull;
   // Consider supporting non-CPU tensors and CPU tensors with a device_ set to
   // non-NULL if needed.
-  strings::StrAppend(&out, ", Tensor: ",
-                     is_cpu ? tensor_handle_data_->DebugString() : "?", "\n");
+  strings::StrAppend(
+      &out, ", Tensor: ",
+      is_cpu ? absl::visit([](auto& data) { return data.DebugString(); }, data_)
+             : "?",
+      "\n");
   return out;
 }

View File

@@ -17,10 +17,10 @@ limitations under the License.

 #include <algorithm>
 #include <cstddef>
-#include <map>
 #include <memory>
 #include <queue>
 #include <string>
+#include <unordered_map>
 #include <vector>

 // clang-format off
@@ -32,28 +32,20 @@ limitations under the License.
 #include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
 #include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
 #endif  // IS_MOBILE_PLATFORM
-#include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"

 namespace tensorflow {
@@ -67,56 +59,45 @@ class TensorHandle : public core::RefCounted {
   using VariantDevice = absl::variant<Device*, CustomDevice*>;

   // TensorHandle for dtype != DT_RESOURCE
-  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
-               Device* d, Device* op_device, EagerContext* ctx);
+  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+               Device* resource_device, EagerContext* ctx);
   // TensorHandle for dtype == DT_RESOURCE
-  TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-               const ResourceHandle& resource_handle, Device* d,
-               Device* op_device, EagerContext* ctx);
-  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
-               CustomDevice* d, EagerContext* ctx);
-  TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
-               Device* d, Device* op_device, Device* resource_device,
+  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+               EagerContext* ctx);
+  TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx);
+  TensorHandle(Device* d, Device* op_device, Device* resource_device,
               DataType dtype, EagerContext* ctx);
 #if !defined(IS_MOBILE_PLATFORM)
-  TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, DataType dtype,
-               Device* d, Device* resource_device, EagerContext* ctx);
-  TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
+  TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
                DataType dtype, Device* device, EagerContext* ctx);
+  TensorHandle(int64 op_id, int32 output_num, DataType dtype, Device* device,
+               EagerContext* ctx);
 #endif  // IS_MOBILE_PLATFORM

  public:
   // TensorHandle with no assigned device
-  static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
-  // TensorHandle with device == op_device
-  static Status CreateLocalHandle(const class Tensor& t, Device* d,
-                                  EagerContext* ctx, TensorHandle** h);
-  static Status CreateLocalHandle(const class Tensor& t, Device* d,
+  static Status CreateLocalHandle(const tensorflow::Tensor& t,
+                                  TensorHandle** h);
+  static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                   Device* op_device, EagerContext* ctx,
                                   TensorHandle** h);
-  static Status CreateLocalHandle(const class Tensor& t, CustomDevice* d,
+  static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
+                                  Device* op_device, Device* resource_device,
                                   EagerContext* ctx, TensorHandle** h);
-  static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
+  static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
+                                  EagerContext* ctx, TensorHandle** h);
+  static Status CreateEmptyLocalHandle(Device* d, Device* op_device,
                                        Device* resource_device, DataType dtype,
                                        EagerContext* ctx, TensorHandle** h);
 #if !defined(IS_MOBILE_PLATFORM)
-  static Status CreateRemoteHandle(int64 op_id, int output_num,
-                                   const TensorShape& shape,
-                                   const string& remote_task, DataType dtype,
-                                   Device* d, Device* resource_device,
-                                   EagerContext* ctx, TensorHandle** h);
-  static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t,
-                                   DataType dtype, Device* d,
-                                   Device* resource_device, EagerContext* ctx,
-                                   TensorHandle** h);
   static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
                                            const string& remote_task,
-                                           DataType dtype, Device* device,
+                                           DataType dtype, Device* d,
                                            EagerContext* ctx, TensorHandle** h);
-  static Status CreateUnshapedRemoteHandle(
-      std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
-      Device* device, EagerContext* ctx, TensorHandle** h);
+  static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num,
+                                       DataType dtype, Device* d,
+                                       EagerContext* ctx, TensorHandle** h);
 #endif  // IS_MOBILE_PLATFORM

   ~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }
@@ -131,7 +112,7 @@ class TensorHandle : public core::RefCounted {
   // Return the TensorValue from the specified device which could be either the
   // default device or a local mirror. The device pointer should be nullptr if
   // requesting the HostCPU.
-  Status TensorValue(tensorflow::TensorValue* t, const Device* d);
+  Status TensorValue(const Device* d, tensorflow::TensorValue* t);

   VariantDevice device() const { return device_; }
   Device* op_device() const { return op_device_; }
@@ -161,12 +142,10 @@ class TensorHandle : public core::RefCounted {
   bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
   bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;

-  Status AddUnshapedRemoteMirror(
-      std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);
-  Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
-                         const Device* d);
-  Status AddResourceShapeMirror(
-      std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);
+  Status AddUnshapedRemoteMirror(const Device* d, int64 op_id, int output_num,
+                                 const string& remote_task, EagerContext* ctx);
+  Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
+                                EagerContext* ctx);

   // Return the op_id and output num if the handle refers to a remote tensor.
   Status RemoteAddress(const Device* d, int64* op_id, int32* output_num) const;
@@ -212,14 +191,12 @@ class TensorHandle : public core::RefCounted {
   Status CopyInferenceShape(TensorHandle* other);

   // Warning: can return nullptr for CPU tensors.
-  // TODO(b/136608821): Move away from nullptr
   EagerContext* Context() { return ctx_; }

   // dtype for the handle. It must be the same as t.dtype() once the handle is
   // ready.
   const DataType dtype;

-  // TODO(b/136608821): Move away from nullptr
   bool OnHostCPU() const {
     return (
         device_.index() == 0 &&
@@ -227,14 +204,14 @@ class TensorHandle : public core::RefCounted {
         (ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_))));
   }

-  bool IsRemote() const { return is_remote_; }
+  bool IsRemote() const;

   void EnableImplicitMirroring() { implicit_mirroring_ = true; }
   bool ImplicitMirroring() const { return implicit_mirroring_; }

   string DebugString() const;

   void SetResourceHandleDtypeAndShape(
-      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);
+      std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes);

   // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
   // return data types and shapes of the underlying resource.
@ -248,19 +225,6 @@ class TensorHandle : public core::RefCounted {
// with a ready version of the tensor handle data. // with a ready version of the tensor handle data.
bool IsReady() const; bool IsReady() const;
// If the contents of the Tensor pointed to by this handle is yet to be
// computed by a EagerNode, this function will block till that computation is
// done and the handle is "ready".
Status WaitReady(const char* caller) const;
// TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
// This was expedient, but perhaps worth revisiting ('device_' should always
// be a valid pointer?)
// This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
// provided with the appropriate TFE_Context.
//
// TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
// TFE_TensorHandle does not outlive the TFE_Context from which it came?
VariantDevice const device_; VariantDevice const device_;
// Device in which the op producing this tensor was executed. Equals to // Device in which the op producing this tensor was executed. Equals to
@ -275,47 +239,33 @@ class TensorHandle : public core::RefCounted {
mutable mutex mu_; mutable mutex mu_;
// Map of local mirrors. In sync mode the EmptyLocalTensorHandleData is // Map of local mirrors. This can include both ready and non-ready mirrors.
// nullptr. In async mode, we use the EmptyLocalTensorHandleData to manage std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
// waiting clients. Once the EmptyLocalTensorHandleData is "ready" only the
// LocalTensorHandleData should be used.
std::map<const tensorflow::Device*,
std::pair<std::unique_ptr<EmptyLocalTensorHandleData>,
std::unique_ptr<LocalTensorHandleData>>>
local_mirrors_ GUARDED_BY(mu_); local_mirrors_ GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM) #if !defined(IS_MOBILE_PLATFORM)
// TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
// variable is ready, since we could get the shape locally without remote copy // variable is ready, since we could get the shape locally without remote copy
// then. // then.
std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>> std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
resource_shape_mirrors_ GUARDED_BY(mu_); GUARDED_BY(mu_);
// TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
// Consider replacing the unshaped_remote_mirrors_ map with something more
// efficient.
std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>>
unshaped_remote_mirrors_ GUARDED_BY(mu_);
// TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be // TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be
// a fixed size map. // a fixed size map.
std::map<string, std::unique_ptr<RemoteTensorHandleData>> remote_mirrors_ std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
GUARDED_BY(mu_); GUARDED_BY(mu_);
// IDs required when this class is representing a remote tensor handle.
int64 remote_op_id_;
int32 remote_output_num_;
#endif #endif
// `ctx` is only guaranteed to be set if the handle is not "ready". This is // `ctx` is only guaranteed to be set if the handle is not "ready". This is
// typically true when the handle was produced during async execution. // typically true when the handle was produced during async execution.
// `ctx` object is not owned and should outlive this handle. // `ctx` object is not owned and should outlive this handle.
//
// TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
// a TensorHandle does not outlive the EagerContext from which it came?
EagerContext* const ctx_; EagerContext* const ctx_;
// Does not need synchronization because it can be accessed only after // Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, is_poisoned_ is immutable. // WaitReady() has returned. At that point, is_poisoned_ is immutable.
Status is_poisoned_; Status is_poisoned_;
const bool is_remote_;
const bool is_async_;
bool implicit_mirroring_; bool implicit_mirroring_;
bool is_ready_ GUARDED_BY(mu_);
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
// refers to a remote resource handle, we store data types and shapes for // refers to a remote resource handle, we store data types and shapes for
@ -323,8 +273,12 @@ class TensorHandle : public core::RefCounted {
std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_; std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
// Does not need synchronization because it can be accessed only after // Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, tensor_handle_data_ is immutable. // WaitReady() has returned. At that point, data_ is immutable.
std::unique_ptr<TensorHandleData> tensor_handle_data_; #if !defined(IS_MOBILE_PLATFORM)
absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_;
#else
absl::variant<LocalTensorHandleData> data_;
#endif
PartialTensorShape inference_shape_; PartialTensorShape inference_shape_;
}; };
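The end state of this header is the new `data_` member: one `absl::variant` over exactly two concrete payload types replaces the old `std::unique_ptr<TensorHandleData>`, its Empty/Unshaped subclasses, and the `is_remote_`/`is_ready_` flags that had to be kept in sync with it. Below is a minimal, self-contained sketch of the dispatch pattern this enables; `LocalData`, `RemoteData`, and the free `NumDims` are simplified stand-ins, not the real TensorHandle internals.

#include <iostream>

#include "absl/types/variant.h"

// Stand-ins for LocalTensorHandleData / RemoteTensorHandleData.
struct LocalData {
  int dims;
  int NumDims() const { return dims; }
};
struct RemoteData {
  int dims;
  int NumDims() const { return dims; }
};

// One visitor serves every alternative, so callers no longer branch on an
// is_remote_ flag or downcast through a virtual TensorHandleData interface.
int NumDims(const absl::variant<LocalData, RemoteData>& data) {
  return absl::visit([](const auto& d) { return d.NumDims(); }, data);
}

int main() {
  absl::variant<LocalData, RemoteData> data = LocalData{2};
  std::cout << NumDims(data) << "\n";  // 2
  data = RemoteData{3};  // switching alternatives is a plain assignment
  std::cout << NumDims(data) << "\n";  // 3
}

On mobile builds the variant simply omits the remote alternative, which is why the `#if !defined(IS_MOBILE_PLATFORM)` split lives on the declaration rather than in every method.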
@@ -23,12 +23,16 @@ namespace tensorflow {
 class Status;
 
 Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
+  TF_RETURN_IF_ERROR(WaitReady("Tensor"));
+
   *t = &tensor_;
 
   return Status::OK();
 }
 
 Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
+  TF_RETURN_IF_ERROR(WaitReady("TensorValue"));
+
   tensorflow::Tensor& tensor = tensor_;
   *t = tensorflow::TensorValue(&tensor);
 
@@ -36,103 +40,96 @@ Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
 }
 
 Status LocalTensorHandleData::Shape(TensorShape* shape) const {
+  TF_RETURN_IF_ERROR(WaitReady("Shape"));
+
   *shape = tensor_.shape();
 
   return Status::OK();
 }
 
 Status LocalTensorHandleData::NumDims(int* num_dims) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumDims"));
+
   *num_dims = tensor_.dims();
 
   return Status::OK();
 }
 
 Status LocalTensorHandleData::Dim(int dim_index, int64* dim) const {
+  TF_RETURN_IF_ERROR(WaitReady("Dim"));
+
   *dim = tensor_.dim_size(dim_index);
 
   return Status::OK();
 }
 
 Status LocalTensorHandleData::NumElements(int64* num_elements) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumElements"));
+
   *num_elements = tensor_.NumElements();
 
   return Status::OK();
 }
 
 Status LocalTensorHandleData::Unprotect() {
+  if (!IsReady()) {
+    return errors::Internal("Cannot unprotect a non-ready tensor");
+  }
+
   forwarding_protection_tensor_ = tensorflow::Tensor();
 
   return Status::OK();
 }
 
-Status EmptyLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
-  return errors::Unavailable(
-      "Unable to get a tensor for an empty handle. "
-      "Please wait until it is ready");
-}
+Status LocalTensorHandleData::SetTensor(tensorflow::Tensor&& t) {
+  DCHECK(!IsReady()) << "SetTensor is only called on non-ready handles.";
 
-Status EmptyLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
-  return errors::Unavailable(
-      "Unable to get a tensor for an empty handle. "
-      "Please wait until it is ready");
-}
+  tensor_ = std::move(t);
+  // Create copy of original tensor to avoid forwarding
+  forwarding_protection_tensor_ = tensor_;
 
-Status EmptyLocalTensorHandleData::Shape(TensorShape* shape) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
+  auto& state = absl::get<BlockingControl>(ctrl_);
+  state.SetReady();
 
-Status EmptyLocalTensorHandleData::NumDims(int* num_dims) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
+  return Status::OK();
+}
 
-Status EmptyLocalTensorHandleData::Dim(int dim_index, int64* dim) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::NumElements(int64* num_elements) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::Unprotect() {
-  return errors::Unavailable("Unable to unprotect an empty handle.");
-}
-
-bool EmptyLocalTensorHandleData::IsReady() const {
-  tf_shared_lock l(mu_);
-  return is_ready_;
-}
+string LocalTensorHandleData::DebugString() const {
+  if (IsReady()) {
+    return tensor_.DeviceSafeDebugString();
+  } else {
+    return "LocalTensorHandleData";
+  }
+}
 
-void EmptyLocalTensorHandleData::SetReady() {
+void LocalTensorHandleData::BlockingControl::SetReady() {
   mutex_lock l(mu_);
   is_ready_ = true;
 }
 
-Status EmptyLocalTensorHandleData::WaitReady(const char* caller) const {
-  if (!IsReady()) {
-    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
-                               profiler::TraceMeLevel::kInfo);
-    tf_shared_lock l(mu_);
+Status LocalTensorHandleData::BlockingControl::WaitReady(
+    const char* caller) const {
+  tf_shared_lock l(mu_);
+  if (!is_ready_) {
+    profiler::TraceMe activity(
+        [caller] { return absl::StrCat(caller, " WaitReady"); },
+        profiler::TraceMeLevel::kInfo);
+    DVLOG(3) << "WaitReady: " << caller << " " << this;
    mu_.Await(Condition(&is_ready_));
   }
+
   return is_poisoned_;
 }
 
-void EmptyLocalTensorHandleData::Poison(Status status) {
-  is_poisoned_ = status;
+void LocalTensorHandleData::BlockingControl::Poison(Status status) {
   mutex_lock l(mu_);
+  if (is_ready_) {
+    LOG(ERROR) << "Poison can only be called on non-ready handle: " << this;
+    return;
+  }
+  is_poisoned_ = status;
   is_ready_ = true;
 }
 
-string EmptyLocalTensorHandleData::DebugString() const {
-  return "EmptyLocalTensorHandleData";
-}
-
 }  // namespace tensorflow
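`SetTensor` above copies the freshly moved-in tensor into `forwarding_protection_tensor_` before flipping the handle to ready. That second reference is what `Unprotect()` later releases: while both references exist, an op cannot treat the buffer as exclusively owned and forward (reuse) it. A standalone sketch of the idea, with `std::shared_ptr` standing in for the refcounted Tensor buffer; the names are illustrative.

#include <iostream>
#include <memory>

int main() {
  auto buffer = std::make_shared<int>(42);  // plays the role of tensor_
  auto protection = buffer;                 // forwarding_protection_tensor_
  // An op may only forward (reuse) a buffer it holds the sole reference to.
  std::cout << "can forward: " << (buffer.use_count() == 1) << "\n";  // 0
  protection.reset();                       // Unprotect()
  std::cout << "can forward: " << (buffer.use_count() == 1) << "\n";  // 1
}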
@@ -15,52 +15,50 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
 
+#include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 
-class TensorHandleData {
- public:
-  virtual ~TensorHandleData() {}
-
-  // Different tensor handles support a set of these calls. In some cases these
-  // are resolved with a Tensor or TensorShape. Typically if the handle is not
-  // ready, none of these are supported operations.
-  virtual Status Tensor(const tensorflow::Tensor** t) const = 0;
-  virtual Status TensorValue(tensorflow::TensorValue* t) = 0;
-  virtual Status Shape(TensorShape* shape) const = 0;
-  virtual Status NumDims(int* num_dims) const = 0;
-  virtual Status Dim(int dim_index, int64* dim) const = 0;
-  virtual Status NumElements(int64* num_elements) const = 0;
-  // Allow the backing Tensor to be available for buffer reuse during op
-  // execution.
-  virtual Status Unprotect() = 0;
-
-  virtual string DebugString() const = 0;
-};
-
 // Local Tensor Handle: Handle to a Tensor present on the local host.
-class LocalTensorHandleData : public TensorHandleData {
+class LocalTensorHandleData {
  public:
-  explicit LocalTensorHandleData(const tensorflow::Tensor& t)
-      : tensor_(t), forwarding_protection_tensor_(t) {}
-  ~LocalTensorHandleData() override {}
+  LocalTensorHandleData() : ctrl_(absl::in_place_type<BlockingControl>) {}
+  explicit LocalTensorHandleData(tensorflow::Tensor&& t)
+      : tensor_(std::move(t)),
+        forwarding_protection_tensor_(tensor_),
+        ctrl_(absl::in_place_type<NonBlockingControl>) {}
 
   // A local tensor handle should be able to satisfy all of these requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
+  Status Tensor(const tensorflow::Tensor** t) const;
+  Status TensorValue(tensorflow::TensorValue* t);
+  Status Shape(TensorShape* shape) const;
+  Status NumDims(int* num_dims) const;
+  Status Dim(int dim_index, int64* dim) const;
+  Status NumElements(int64* num_elements) const;
+  Status Unprotect();
 
-  string DebugString() const override {
-    return tensor_.DeviceSafeDebugString();
+  bool IsReady() const {
+    return absl::visit([](auto& data) { return data.IsReady(); }, ctrl_);
   }
+  Status WaitReady(const char* caller) const {
+    return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
+                       ctrl_);
+  }
+  void Poison(Status status) {
+    return absl::visit([status](auto& data) { data.Poison(status); }, ctrl_);
+  }
+  Status IsPoisoned() const {
+    return absl::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_);
+  }
+
+  Status SetTensor(tensorflow::Tensor&& t);
+
+  string DebugString() const;
 
  private:
   tensorflow::Tensor tensor_;
   // TensorHandle has its own reference counting which is distinct from the
@@ -70,37 +68,41 @@ class LocalTensorHandleData : public TensorHandleData {
   // forwarding_protection_tensor_ Tensor. When Unprotect() is called, we
   // release this Tensor to allow forwarding.
   tensorflow::Tensor forwarding_protection_tensor_;
-};
 
-// Empty Local Tensor Handle: Once the execution is complete this is replaced by
-// a local tensor handle.
-class EmptyLocalTensorHandleData : public TensorHandleData {
- public:
-  EmptyLocalTensorHandleData() {}
-  ~EmptyLocalTensorHandleData() override {}
+  // We distinguish between ready and empty tensors with the ctrl_ variant.
+  // which contains 2 implementations of the waiting logic. The
+  // NonBlockingControl is a simple no-op class whereas the BlockingControl
+  // actually uses a mutex. By using a variant we avoid the overhead of
+  // constructing and destructing the mutex for ready local tensors.
+  class NonBlockingControl {
+   public:
+    bool IsReady() const { return true; }
+    Status WaitReady(const char* caller) const { return Status::OK(); }
+    void Poison(Status status) {}
+    Status IsPoisoned() const { return Status::OK(); }
+  };
 
-  // Empty tensor handles are not ready and hence cannot satisfy any of these
-  // requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
+  class BlockingControl {
+   public:
+    bool IsReady() const {
+      tf_shared_lock l(mu_);
+      return is_ready_;
+    }
+    void SetReady();
+    Status WaitReady(const char* caller) const;
+    void Poison(Status status);
+    Status IsPoisoned() const {
+      tf_shared_lock l(mu_);
+      return is_poisoned_;
+    }
 
-  bool IsReady() const;
-  void SetReady();
-  Status WaitReady(const char* caller) const;
-  void Poison(Status status);
-  Status IsPoisoned() const { return is_poisoned_; }
+   private:
+    mutable mutex mu_;
+    bool is_ready_ GUARDED_BY(mu_);
+    Status is_poisoned_ GUARDED_BY(mu_);
+  };
 
-  string DebugString() const override;
-
- private:
-  mutable mutex mu_;
-  bool is_ready_ GUARDED_BY(mu_);
-  Status is_poisoned_;
+  absl::variant<NonBlockingControl, BlockingControl> ctrl_;
 };
 
 }  // namespace tensorflow
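Note how the two constructors commit to a control implementation up front via `absl::in_place_type`: a handle built from a tensor is born ready and never constructs a mutex, while the default-constructed (empty) handle gets `BlockingControl`. A compilable sketch of that construction-time selection, with trimmed-down stand-ins for the two control classes:

#include <iostream>

#include "absl/types/variant.h"

struct NonBlocking {
  bool IsReady() const { return true; }  // ready by construction, no locking
};
struct Blocking {
  bool IsReady() const { return ready; }  // the real class guards this with a mutex
  bool ready = false;
};

struct Data {
  Data() : ctrl(absl::in_place_type<Blocking>) {}                 // empty handle
  explicit Data(int) : ctrl(absl::in_place_type<NonBlocking>) {}  // ready handle
  bool IsReady() const {
    return absl::visit([](const auto& c) { return c.IsReady(); }, ctrl);
  }
  absl::variant<NonBlocking, Blocking> ctrl;
};

int main() {
  std::cout << Data(7).IsReady() << "\n";  // 1: ready, mutex never constructed
  std::cout << Data().IsReady() << "\n";   // 0: still waiting on a producer
}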
@@ -39,12 +39,13 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
       &device_mgr, false, nullptr, nullptr, nullptr);
 
   TensorHandle* sync_th;
-  EXPECT_TRUE(
-      TensorHandle::CreateLocalHandle(t, ctx->HostCPU(), ctx, &sync_th).ok());
+  EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr,
+                                              ctx, &sync_th)
+                  .ok());
   TensorHandle* async_th;
-  EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(true, nullptr, nullptr,
-                                                   nullptr, DataType::DT_UINT16,
-                                                   ctx, &async_th)
+  EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr,
+                                                   DataType::DT_UINT16, ctx,
+                                                   &async_th)
                   .ok());
 
   EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());
@@ -190,8 +190,10 @@ cc_library(
     deps = [
         ":destroy_tensor_handle_node",
         ":eager_client",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core/common_runtime/eager:tensor_handle_data",
+        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/profiler/lib:traceme",
     ],
 )
@@ -530,7 +530,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
     }
 
     TensorHandle* tensor_handle = nullptr;
-    TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle));
+    TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
+        std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle));
     TensorHandle* copied_handle = nullptr;
     Device* device;
     TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
@@ -101,12 +101,10 @@ Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
   core::RefCountPtr<KernelAndDevice> kernel;
   TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
 
-  gtl::InlinedVector<TensorValue, 4> input_vector(1);
-  TF_RETURN_IF_ERROR(src_->TensorValue(
-      &input_vector[0],
-      ctx_->CanonicalDevice(absl::get<Device*>(op->Device()))));
-
-  EagerKernelArgs args(std::move(input_vector));
+  EagerKernelArgs args(1);
+  Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
+  TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
+
   return kernel->Run(args, /*outputs=*/nullptr,
                      /*cancellation_manager=*/nullptr,
                      /*remote_func_params=*/absl::nullopt);
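The rewrite sizes `EagerKernelArgs` up front and lets `TensorValue` (with its new out-parameter-last signature) write directly into slot 0, instead of filling a temporary `InlinedVector` and moving it into the args object. A simplified sketch of the fill-in-place pattern; `Args` and `ProduceInto` are illustrative stand-ins, not the real eager-runtime types:

#include <iostream>
#include <vector>

struct Args {
  explicit Args(int n) : values(n) {}         // preallocate n slots
  int* MutableInput(int i) { return &values[i]; }
  std::vector<int> values;
};

void ProduceInto(int* out) { *out = 42; }     // plays the role of TensorValue()

int main() {
  Args args(1);                               // EagerKernelArgs args(1);
  ProduceInto(args.MutableInput(0));          // producer writes in place
  std::cout << args.values[0] << "\n";        // 42
}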
@@ -162,16 +162,8 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
         in.op_device().empty() ? in.device() : in.op_device();
     TF_RETURN_IF_ERROR(
         parent_->FindDeviceFromName(device_name.c_str(), &device));
-    string remote_task;
-    if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
-      return errors::InvalidArgument(
-          "Unable to find remote task corresponding to device ", device_name);
-    }
-    auto remote_handle_data = absl::make_unique<UnshapedRemoteTensorHandleData>(
-        in.op_id(), in.output_num(), remote_task, parent_);
-    remote_handle_data->ReleaseRemoteTensorHandle();
-    TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
-        std::move(remote_handle_data), in.dtype(), device, parent_, out));
+    TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
+        in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
     std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
     if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                   &dtypes_and_shapes)
@@ -71,14 +71,12 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
   Tensor t(DT_FLOAT, TensorShape({0}));
 
   TensorHandle* handle;
-  TF_ASSERT_OK(
-      TensorHandle::CreateLocalHandle(t, local_device_, ctx_, &handle));
+  TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_,
+                                               local_device_, ctx_, &handle));
   const uint64 op_id = 2;
   const int output_num = 3;
-  auto tensor_handle_data = absl::make_unique<RemoteTensorHandleData>(
-      op_id, output_num, t.shape(), /*remote_task=*/"", ctx_);
-  TF_ASSERT_OK(
-      handle->AddRemoteMirror(std::move(tensor_handle_data), remote_device_));
+  TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
+                                               output_num, "", ctx_));
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, &remote_handle, remote_device_, remote_device_->name()));
@@ -90,14 +88,13 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
 TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
   RemoteMgr remote_mgr(false, ctx_);
 
-  Tensor t(DT_FLOAT, TensorShape({0}));
   const uint64 op_id = 3;
   const int output_num = 1;
   TensorHandle* handle;
-  TF_ASSERT_OK(TensorHandle::CreateRemoteHandle(
-      op_id, output_num, t.shape(), /*remote_task=*/"", DT_FLOAT,
-      remote_device_, /*resource_device=*/nullptr, ctx_, &handle));
+  TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle(
+      op_id, output_num,
+      /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle));
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, &remote_handle, remote_device_, remote_device_->name()));
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
 
 namespace tensorflow {
 
@@ -84,66 +85,104 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
 }  // namespace
 
 RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
-                                               const TensorShape& shape,
-                                               const string& remote_task,
-                                               EagerContext* ctx)
-    : op_id_(op_id),
+                                               uint64 context_view_id)
+    : is_ready_(false),
+      op_id_(op_id),
       output_num_(output_num),
-      shape_(shape),
-      remote_task_(remote_task),
-      context_id_(ctx->GetContextId()),
-      context_view_id_(ctx->GetContextViewId()),
-      ctx_(*ctx) {
+      context_view_id_(context_view_id),
+      ctx_(nullptr) {
   DCHECK(op_id_ >= 0 && output_num_ >= 0)
       << "Op ID and output num should be >= 0. Op ID: " << op_id
       << ", Output num: " << output_num;
-  ctx_.Ref();
+}
+
+RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
+                                               const string& remote_task,
+                                               EagerContext* ctx)
+    : is_ready_(false),
+      op_id_(op_id),
+      output_num_(output_num),
+      remote_task_(remote_task),
+      context_id_(ctx->GetContextId()),
+      context_view_id_(ctx->GetContextViewId()),
+      ctx_(ctx) {
+  DCHECK(op_id_ >= 0 && output_num_ >= 0)
+      << "Op ID and output num should be >= 0. Op ID: " << op_id
+      << ", Output num: " << output_num;
+  ctx_->Ref();
 }
 
 RemoteTensorHandleData::~RemoteTensorHandleData() {
-  DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
-                            output_num_, /*ready=*/true);
-  ctx_.Unref();
-}
-
-Status RemoteTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote device. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
-}
-
-Status RemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote device. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
+  if (ctx_) {
+    DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
+                              output_num_, /*ready=*/true);
+    ctx_->Unref();
+  }
 }
 
 Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
+  TF_RETURN_IF_ERROR(WaitReady("Shape"));
+
+  tf_shared_lock l(mu_);
   *shape = shape_;
 
   return Status::OK();
 }
 
 Status RemoteTensorHandleData::NumDims(int* num_dims) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumDims"));
+
+  tf_shared_lock l(mu_);
   *num_dims = shape_.dims();
 
   return Status::OK();
 }
 
 Status RemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
+  TF_RETURN_IF_ERROR(WaitReady("Dim"));
+
+  tf_shared_lock l(mu_);
   *dim = shape_.dim_size(dim_index);
 
   return Status::OK();
 }
 
 Status RemoteTensorHandleData::NumElements(int64* num_elements) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumElements"));
+
+  tf_shared_lock l(mu_);
   *num_elements = shape_.num_elements();
 
   return Status::OK();
 }
 
-Status RemoteTensorHandleData::Unprotect() {
-  return errors::Unavailable("Unable to unprotect a remote handle.");
+bool RemoteTensorHandleData::IsReady() const {
+  tf_shared_lock l(mu_);
+  return is_ready_;
+}
+
+void RemoteTensorHandleData::Poison(Status status) {
+  mutex_lock l(mu_);
+  is_poisoned_ = status;
+  is_ready_ = true;
+}
+
+Status RemoteTensorHandleData::IsPoisoned() const {
+  tf_shared_lock l(mu_);
+  return is_poisoned_;
+}
+
+Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
+  mutex_lock l(mu_);
+  if (is_ready_) {
+    return errors::Internal("SetShape is only called on non-ready handles.");
+  }
+
+  shape_ = shape;
+  is_poisoned_ = Status::OK();
+  is_ready_ = true;
+
+  return Status::OK();
 }
 
 string RemoteTensorHandleData::DebugString() const {
@@ -151,73 +190,20 @@ string RemoteTensorHandleData::DebugString() const {
                          " output_num: ", output_num_);
 }
 
-UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
-    int64 op_id, int32 output_num, const string& remote_task, EagerContext* ctx)
-    : op_id_(op_id),
-      output_num_(output_num),
-      delete_remote_tensor_(true),
-      remote_task_(remote_task),
-      context_id_(ctx->GetContextId()),
-      context_view_id_(ctx->GetContextViewId()),
-      ctx_(*ctx) {
-  DCHECK(op_id_ >= 0 && output_num_ >= 0)
-      << "Op ID and output num should be >= 0. Op ID: " << op_id
-      << ", Output num: " << output_num;
-  ctx_.Ref();
-}
-
-UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
-  if (delete_remote_tensor_) {
-    DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
-                              output_num_, /*ready=*/false);
+Status RemoteTensorHandleData::WaitReady(const char* caller) const {
+  if (ctx_ == nullptr) {
+    return errors::Internal("Cannot wait on lazy remote handle");
   }
-  ctx_.Unref();
-}
 
-Status UnshapedRemoteTensorHandleData::Tensor(
-    const tensorflow::Tensor** t) const {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote handle. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
-}
-
-Status UnshapedRemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote handle. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
-}
-
-Status UnshapedRemoteTensorHandleData::Shape(TensorShape* shape) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::NumDims(int* num_dims) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::NumElements(int64* num_elements) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::Unprotect() {
-  return errors::Unavailable("Unable to unprotect a remote handle.");
-}
-
-string UnshapedRemoteTensorHandleData::DebugString() const {
-  return strings::StrCat("UnshapedRemoteTensorHandleDat:", " op_id: ", op_id_,
-                         " output_num: ", output_num_);
+  tf_shared_lock l(mu_);
+  if (!is_ready_) {
+    profiler::TraceMe activity(
+        [caller] { return absl::StrCat(caller, " WaitReady"); },
+        profiler::TraceMeLevel::kInfo);
+    DVLOG(3) << "WaitReady: " << caller << " " << this;
+    mu_.Await(Condition(&is_ready_));
+  }
+
+  return is_poisoned_;
 }
 
 }  // namespace tensorflow
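This `Poison` implementation is the behavioral fix named in the commit description: poisoning now also sets `is_ready_`, so a thread parked in `WaitReady` on `mu_.Await(Condition(&is_ready_))` wakes up and observes the poisoned status instead of blocking forever. A standalone model of that wake-on-poison contract, with standard-library primitives standing in for `tensorflow::mutex` and `Status` (an empty string means OK):

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>

class ReadyState {
 public:
  // Blocks until ready, then returns the poison status; "" means OK.
  std::string WaitReady() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return is_ready_; });
    return poison_;
  }
  void SetReady() {
    {
      std::lock_guard<std::mutex> l(mu_);
      is_ready_ = true;
    }
    cv_.notify_all();
  }
  // Poisoning marks the state ready too, so waiters wake and see the error.
  void Poison(std::string status) {
    {
      std::lock_guard<std::mutex> l(mu_);
      poison_ = std::move(status);
      is_ready_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool is_ready_ = false;
  std::string poison_;
};

int main() {
  ReadyState state;
  std::thread producer([&] { state.Poison("remote op failed"); });
  std::cout << "status: " << state.WaitReady() << "\n";  // "remote op failed"
  producer.join();
}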
@@ -15,97 +15,56 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
 
-#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
-#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 
 // Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only
 // the shape is known.
-class RemoteTensorHandleData : public TensorHandleData {
+class RemoteTensorHandleData {
  public:
-  RemoteTensorHandleData(int64 op_id, int output_num, const TensorShape& shape,
-                         const string& remote_task, EagerContext* ctx);
-  ~RemoteTensorHandleData() override;
+  // Constructor for lazy remote handles
+  RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id);
+  // Constructor for unshaped remote handles
+  RemoteTensorHandleData(int64 op_id, int output_num, const string& remote_task,
+                         EagerContext* ctx);
+  ~RemoteTensorHandleData();
 
   // A remote tensor handle does not have a Tensor object, hence it can only
   // support the shape requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
-
-  EagerContext& ctx() const { return ctx_; }
-
-  string DebugString() const override;
+  Status Shape(TensorShape* shape) const;
+  Status NumDims(int* num_dims) const;
+  Status Dim(int dim_index, int64* dim) const;
+  Status NumElements(int64* num_elements) const;
+
+  bool IsReady() const;
+  Status SetShape(const TensorShape& shape);
+  void Poison(Status status);
+  Status IsPoisoned() const;
+
+  string DebugString() const;
 
   int64 op_id() const { return op_id_; }
   int32 output_num() const { return output_num_; }
-  uint64 context_id() const { return context_id_; }
   uint64 context_view_id() const { return context_view_id_; }
 
  private:
+  Status WaitReady(const char* caller) const;
+
+  mutable mutex mu_;
+  bool is_ready_ GUARDED_BY(mu_);
+  Status is_poisoned_ GUARDED_BY(mu_);
+  TensorShape shape_ GUARDED_BY(mu_);
+
   // IDs required when this class is representing a remote tensor handle.
   const int64 op_id_;
   const int32 output_num_;
-  const TensorShape shape_;
   string remote_task_;
   uint64 context_id_;
   uint64 context_view_id_;
-  EagerContext& ctx_;
-};
-
-// Async Remote Tensor Handle: A handle to a Tensor on a remote host. Once the
-// shape has been computed this is replaced with a remote tensor handle.
-class UnshapedRemoteTensorHandleData : public TensorHandleData {
- public:
-  UnshapedRemoteTensorHandleData(int64 op_id, int32 output_num,
-                                 const string& remote_task, EagerContext* ctx);
-  ~UnshapedRemoteTensorHandleData() override;
-
-  // Unshaped remote tensor handles are not ready and hence cannot satisfy any
-  // of these requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
-
-  void Poison(Status status) { is_poisoned_ = status; }
-  Status IsPoisoned() const { return is_poisoned_; }
-
-  string DebugString() const override;
-
-  int64 op_id() const { return op_id_; }
-  int32 output_num() const { return output_num_; }
-  string remote_task() const { return remote_task_; }
-  uint64 context_id() const { return context_id_; }
-  uint64 context_view_id() const { return context_view_id_; }
-  EagerContext& ctx() const { return ctx_; }
-
-  // When constructed, UnshapedRemoteTensorHandleData owns the remote
-  // TensorHandle and should delete it by issuing an RPC. Once the remote
-  // shape has been learned, the ownership is transferred to
-  // RemoteTensorHandleData. This method must be called to let `this` know
-  // that it no longer owns the remote handle.
-  // TODO(iga): Add a factory method here that will create a new
-  // RemoteTensorHandleData from this and transfer ownership in the process.
-  void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; }
-
- private:
-  Status is_poisoned_;
-  // IDs required when this class is representing a remote tensor handle.
-  const int64 op_id_;
-  const int32 output_num_;
-  bool delete_remote_tensor_;
-  string remote_task_;
-  uint64 context_id_;
-  uint64 context_view_id_;
-  EagerContext& ctx_;
 };
 
 }  // namespace tensorflow
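The two constructors encode two ownership regimes: a lazy handle (IDs plus a `context_view_id`) leaves `ctx_` null, so its destructor neither issues the destroy-tensor RPC nor unrefs a context, while an unshaped handle refs the `EagerContext` and tears the remote tensor down when it dies. A standalone model of just that difference; `Ctx` and the logging are simplified stand-ins for `EagerContext` refcounting and `DestroyRemoteTensorHandle`:

#include <iostream>

struct Ctx {
  int refs = 1;
};

class RemoteData {
 public:
  RemoteData(long op_id, int output_num, unsigned long /*context_view_id*/)
      : op_id_(op_id), output_num_(output_num), ctx_(nullptr) {}  // lazy
  RemoteData(long op_id, int output_num, Ctx* ctx)
      : op_id_(op_id), output_num_(output_num), ctx_(ctx) {       // owning
    ++ctx_->refs;
  }
  ~RemoteData() {
    if (ctx_) {  // lazy handles skip the destroy-tensor RPC and the unref
      std::cout << "destroy remote tensor " << op_id_ << ":" << output_num_
                << "\n";
      --ctx_->refs;
    }
  }

 private:
  long op_id_;
  int output_num_;
  Ctx* ctx_;
};

int main() {
  Ctx ctx;
  { RemoteData lazy(3, 1, 0UL); }    // no output, no reference traffic
  { RemoteData owned(3, 1, &ctx); }  // prints the destroy message
  std::cout << "ctx refs: " << ctx.refs << "\n";  // back to 1
}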
@@ -91,11 +91,11 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
   Device* device = IsCPUDevice(call->device) ? nullptr : call->device;
   for (int64 i = 0; i < n; ++i) {
     PyObject* arg = nullptr;
-    const Tensor& t = call->ins[i];
     if (call->eager) {
       TensorHandle* handle;
+      Tensor t = call->ins[i];
       TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
-          t, ctx->CanonicalDevice(device), nullptr, ctx, &handle));
+          std::move(t), ctx->CanonicalDevice(device), nullptr, ctx, &handle));
       arg = EagerTensorFromHandle(new TFE_TensorHandle{
           std::make_unique<tensorflow::TensorHandleInterface>(handle)});
       if (arg == nullptr) {
@@ -103,7 +103,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
         return errors::Internal("Unable to procure EagerTensor from Tensor.");
       }
     } else {
-      Status s = TensorToNdarray(t, &arg);
+      Status s = TensorToNdarray(call->ins[i], &arg);
       if (!s.ok()) {
         Py_DECREF(lst);
         return s;
@@ -277,7 +277,8 @@ struct Converter {
   static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state,
                         TFE_TensorHandle** h, const char** error) {
-    /* TODO(josh11b): Allocator & attributes? */
+    // TODO(josh11b): Allocator & attributes
+    // TODO(gjn): Use optimized scalar constructors when possible.
     Tensor result(ConverterTraits<T>::kTypeEnum,
                   TensorShape(state->inferred_shape));
     if (state->inferred_shape.empty()) { /* Scalar case */
@@ -294,7 +295,7 @@
       }
     }
 
     tensorflow::TensorHandle* handle = nullptr;
     auto status = tensorflow::TensorHandle::CreateLocalHandle(
-        result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+        std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
         ctx->context, &handle);
     if (!status.ok()) {
       return status;
@@ -610,8 +611,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
   auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
   if (cppstatus.ok()) {
     cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
-        t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context,
-        &handle);
+        std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+        ctx->context, &handle);
   }
   if (!cppstatus.ok()) {
     PyErr_SetString(PyExc_ValueError,
@@ -805,10 +806,10 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
     case DT_INVALID:  // Only occurs for empty tensors.
     {
       tensorflow::TensorHandle* h = nullptr;
-      Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
-                    TensorShape(state.inferred_shape));
+      Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
+               TensorShape(state.inferred_shape));
       status = tensorflow::TensorHandle::CreateLocalHandle(
-          tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+          std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
          ctx->context, &h);
       if (!status.ok()) {
        PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
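Every call site in these last hunks changed for the same reason: `CreateLocalHandle` now takes the `Tensor` by rvalue, so passing it with `std::move` hands over the existing buffer reference instead of minting another one (which is also why `MakeArgTuple` above makes a deliberate local copy first, keeping `call->ins[i]` intact for the non-eager branch). A tiny sketch of the hand-off, with `std::shared_ptr` playing the refcounted buffer:

#include <cassert>
#include <memory>
#include <utility>

using Buffer = std::shared_ptr<int>;  // stand-in for the refcounted Tensor

// Takes the buffer by rvalue, as CreateLocalHandle now takes the Tensor.
void CreateHandle(Buffer&& b, Buffer* handle) { *handle = std::move(b); }

int main() {
  Buffer t = std::make_shared<int>(7);
  Buffer handle;
  CreateHandle(std::move(t), &handle);  // transfers the reference, no new one
  assert(handle.use_count() == 1);      // the handle holds the sole reference
  return 0;
}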