Consolidate tensor handle data types

The separation between ready and non-ready TensorHandleData classes
made the mirroring logic messy, especially the waiting logic, which was
duplicated between the main TensorHandle class and the various
TensorHandleData classes. In addition, having all of these classes
expose the same API was cumbersome, since most of the operations were
unsupported by most of the types.

We instead keep just two types, a local one and a remote one. Further,
we move all the waiting logic out of TensorHandle and into the
TensorHandleData classes. Since we no longer need to swap out the
tensor_handle_data_ pointer when a handle becomes ready, we can replace
it with a variant and save ourselves a heap allocation.
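
A rough sketch of that layout is below. The stub structs stand in for
the real data classes; only the variant member and the index-based
IsRemote() check mirror this change, and this is not the actual
TensorHandle definition.

  #include <cstdint>

  #include "absl/types/variant.h"

  struct LocalTensorHandleData { /* tensor + readiness state */ };
  struct RemoteTensorHandleData { int64_t op_id; int32_t output_num; };

  class TensorHandleSketch {
   public:
    // With a variant member there is no pointer to swap when the handle
    // becomes ready, and no separate heap allocation for the data.
    bool IsRemote() const { return data_.index() == 1; }

   private:
    absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_;
  };

Operations that are valid for both alternatives dispatch via
absl::visit, as the updated Shape()/NumDims() accessors in this change
do.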

The LocalTensorHandleData is optimized so that if a tensor is provided
at construction time, none of the member operations require mutex
synchronization.
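
A minimal sketch of that fast path, with illustrative member names (the
real class tracks more state, including the poison status):

  #include "tensorflow/core/framework/tensor.h"
  #include "tensorflow/core/platform/mutex.h"

  class LocalTensorHandleDataSketch {
   public:
    // Tensor provided up front: the data is ready from birth, so
    // readiness queries never need to acquire the mutex.
    explicit LocalTensorHandleDataSketch(tensorflow::Tensor&& t)
        : tensor_(std::move(t)), born_ready_(true), is_ready_(true) {}

    // Empty handle to be filled later: readiness must be guarded.
    LocalTensorHandleDataSketch() : born_ready_(false), is_ready_(false) {}

    bool IsReady() const {
      if (born_ready_) return true;  // Lock-free fast path.
      tensorflow::tf_shared_lock l(mu_);
      return is_ready_;
    }

    void SetTensor(tensorflow::Tensor&& t) {
      tensor_ = std::move(t);
      tensorflow::mutex_lock l(mu_);
      is_ready_ = true;
    }

   private:
    tensorflow::Tensor tensor_;
    const bool born_ready_;
    mutable tensorflow::mutex mu_;
    bool is_ready_;  // Guarded by mu_ in the empty case.
  };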

The RemoteTensorHandleData no longer needs to support the delicate
dance of calling ReleaseRemoteTensorHandle(). However, for lazy remote
handles we set the EagerContext pointer to nullptr to indicate that no
deletion is needed when the class is destroyed, as sketched below.
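
In this sketch, EnqueueRemoteDelete() is a hypothetical stand-in for
the actual remote-deletion path, not a real API:

  #include <cstdint>

  class EagerContext;  // tensorflow/core/common_runtime/eager/context.h

  // Hypothetical stand-in for the call that frees a remote tensor.
  void EnqueueRemoteDelete(EagerContext* ctx, int64_t op_id,
                           int32_t output_num);

  class RemoteDataSketch {
   public:
    // Unshaped remote handle: owns the remote tensor, so keep the
    // context in order to delete op_id:output_num on destruction.
    RemoteDataSketch(int64_t op_id, int32_t output_num, EagerContext* ctx)
        : op_id_(op_id), output_num_(output_num), ctx_(ctx) {}

    // Lazy remote handle: refers to a tensor owned elsewhere; the null
    // context marks "nothing to delete".
    RemoteDataSketch(int64_t op_id, int32_t output_num)
        : op_id_(op_id), output_num_(output_num), ctx_(nullptr) {}

    ~RemoteDataSketch() {
      if (ctx_ != nullptr) {
        EnqueueRemoteDelete(ctx_, op_id_, output_num_);
      }
    }

   private:
    const int64_t op_id_;
    const int32_t output_num_;
    EagerContext* const ctx_;
  };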

Along the way we also make the following performance optimizations (a
self-contained sketch of the mirror map change follows this list):
* Change TensorHandle construction to use move semantics. This avoids
  unnecessary ref-count bumps on the TensorBuffer.
* In sync execution, do not allocate empty TensorHandles and later call
  SetTensor on them. Instead, allocate the returned TensorHandles once
  the kernel has executed. This avoids the synchronization overhead
  when there is no need for it.
* Switch the mirror maps to store tensor handle data directly instead
  of behind a unique_ptr. Also switch them to unordered_map.
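
Here is the promised sketch of the mirror map change. Tensor, Device,
and LocalTensorHandleData below are simplified stubs, not the
TensorFlow classes:

  #include <map>
  #include <memory>
  #include <tuple>
  #include <unordered_map>
  #include <utility>

  struct Tensor {};  // stand-in for tensorflow::Tensor
  struct Device {};  // stand-in for tensorflow::Device
  struct LocalTensorHandleData {
    explicit LocalTensorHandleData(Tensor&& t) : tensor(std::move(t)) {}
    Tensor tensor;
  };

  void MirrorMapSketch(const Device* d, Tensor&& tensor) {
    // Before: a std::map owning the data behind a unique_ptr.
    // Insertion heap-allocates, and constructing from a copy bumps the
    // TensorBuffer refcount in the real code.
    std::map<const Device*, std::unique_ptr<LocalTensorHandleData>> before;
    before[d] = std::make_unique<LocalTensorHandleData>(Tensor(tensor));

    // After: an unordered_map storing the data by value, constructed
    // in place from a moved tensor, as the new AddLocalMirror does.
    std::unordered_map<const Device*, LocalTensorHandleData> after;
    after.emplace(std::piecewise_construct, std::forward_as_tuple(d),
                  std::forward_as_tuple(std::move(tensor)));
  }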

Additional clean-ups:
* Make TensorHandle APIs consistently take Device pointer as first
  argument.
* Remove the CreateLocalTensorHandle function, which could be confused
  with the one used for CustomDevice.
* Remove many unused includes.

PiperOrigin-RevId: 298423283
Change-Id: I838736396e9ef81b2de665d6d9a3ad2062070b0c
Author: Gaurav Jain, 2020-03-02 12:58:06 -08:00 (committed by TensorFlower Gardener)
Parent: 21d0fa9af9
Commit: f97b7ba2bf
20 changed files with 518 additions and 779 deletions

@@ -1213,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
   tensorflow::TensorHandle* ret_handle;
   if (custom_device == nullptr) {
     status->status = tensorflow::TensorHandle::CreateLocalHandle(
-        t, device, context, &ret_handle);
+        std::move(t), device, device, context, &ret_handle);
   } else {
     status->status = tensorflow::TensorHandle::CreateLocalHandle(
-        t, custom_device, context, &ret_handle);
+        std::move(t), custom_device, context, &ret_handle);
   }
   if (!status->status.ok()) {
     return nullptr;

@@ -140,6 +140,7 @@ tf_cuda_library(
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
         "//conditions:default": [
+            "@com_google_absl//absl/types:variant",
             "//tensorflow/core:framework",
             "//tensorflow/core:lib",
             "//tensorflow/core/profiler/lib:traceme",

@@ -564,8 +564,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
       }
     }
   }
-  const DataTypeVector& output_dtypes = kernel->output_dtypes();
-  const size_t num_outputs = static_cast<int>(output_dtypes.size());
+  int num_outputs = kernel->num_outputs();
   if (num_outputs > *num_retvals) {
     return errors::InvalidArgument("Expecting ", num_outputs,
                                    " outputs, but *num_retvals is ",
@@ -579,21 +578,19 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     graph_collector = ctx.GetGraphCollector();
   }
-  const bool async = executor.Async();
+  Status s;
+  if (executor.Async()) {
+    const DataTypeVector& output_dtypes = kernel->output_dtypes();
     for (int i = 0; i < num_outputs; ++i) {
       TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-          async,
           /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
           /* op_device= */ kernel->device(),
           /* resource_device= */ kernel->OutputResourceDevice(i),
           output_dtypes[i], &ctx, &retvals[i]));
     }
-  Status s;
-  if (async) {
     auto node = absl::make_unique<AsyncExecuteNode>(
         &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
-        graph_collector, output_dtypes, op->GetCancellationManager(),
+        graph_collector, op->GetCancellationManager(),
         absl::Span<TensorHandle*>(retvals, num_outputs));
     // For async mode, execution order will make sure that all
     // input handles are ready before executing them.
@@ -601,18 +598,23 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     // performance.
     s = executor.AddOrExecute(std::move(node));
   } else {
+    for (int i = 0; i < num_outputs; ++i) {
+      retvals[i] = nullptr;
+    }
     ExecuteNode node(&ctx, op->Inputs(), op->remote_func_params(), kernel,
-                     graph_collector, output_dtypes,
-                     op->GetCancellationManager(), {retvals, num_outputs});
+                     graph_collector, op->GetCancellationManager(),
+                     {retvals, static_cast<size_t>(num_outputs)});
     s = executor.SyncExecute(&node);
   }
-  // Since the operation failed, we need to Unref any outputs that were
+  // Since the operation failed, we need to Unref any outputs if they were
   // allocated.
   if (!s.ok()) {
     for (int i = 0; i < num_outputs; ++i) {
+      if (retvals[i] != nullptr) {
         retvals[i]->Unref();
       }
     }
+  }
   return s;
 }
@@ -733,12 +735,9 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
             input, input_handle, input_device, *input_device_name,
             serialize_resource_dtype_and_shape));
         if (!input_handle->resource_dtypes_and_shapes().empty()) {
-          auto tensor_handle_data =
-              absl::make_unique<UnshapedRemoteTensorHandleData>(
-                  input_handle->op_id(), input_handle->output_num(),
-                  remote_task, &ctx);
-          TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
-              std::move(tensor_handle_data), op_device));
+          TF_RETURN_IF_ERROR(
+              input->AddResourceShapeMirror(op_device, input_handle->op_id(),
+                                            input_handle->output_num(), &ctx));
         }
       }
     }
@@ -1032,13 +1031,24 @@ Status EagerKernelExecute(
     }
   }
   DCHECK_EQ(retvals.size(), outputs.size());
   for (int i = 0; i < retvals.size(); ++i) {
+    if (retvals[i] == nullptr) {
+      TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
+          std::move(outputs[i]),
+          /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
+          /* op_device= */ kernel->device(),
+          /* resource_device= */ kernel->OutputResourceDevice(i), ctx,
+          &retvals[i]));
+    } else {
       DCHECK_EQ(kernel->device(), retvals[i]->op_device());
       DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
                 absl::get<Device*>(retvals[i]->device()));
-      TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
-          std::move(outputs[i]), ctx->CanonicalDevice(kernel->OutputDevice(i))));
+      TF_RETURN_IF_ERROR(
+          retvals[i]->SetTensor(std::move(outputs[i]),
+                                ctx->CanonicalDevice(kernel->OutputDevice(i))));
+    }
   }
   return Status::OK();
 }
@@ -1069,7 +1079,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     *result = h;
   } else {
     TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-        true, d, dstd, h->resource_device(), h->dtype, ctx, result));
+        d, dstd, h->resource_device(), h->dtype, ctx, result));
   }
   Status s;
@@ -1138,7 +1148,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
       *result = h;
     } else {
       TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-          true, /* d= */ d, /* op_device= */ device,
+          /* d= */ d, /* op_device= */ device,
           /*resource_device=*/nullptr, h->dtype, ctx, result));
     }
   } else {
@@ -1156,17 +1166,14 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                                      device->name());
     }
     recv_op_id = ctx->RemoteMgr()->NextOpId();
-    auto tensor_handle_data =
-        absl::make_unique<UnshapedRemoteTensorHandleData>(recv_op_id, 0,
-                                                          remote_task, ctx);
     if (mirror) {
-      TF_RETURN_IF_ERROR(
-          h->AddUnshapedRemoteMirror(std::move(tensor_handle_data), device));
+      TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
+                                                    remote_task, ctx));
       h->Ref();
       *result = h;
     } else {
       TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
-          std::move(tensor_handle_data), h->dtype, device, ctx, result));
+          recv_op_id, 0, remote_task, h->dtype, device, ctx, result));
     }
   }

@@ -32,7 +32,7 @@ Status ExecuteNodeArgs::Init(
   for (int i = 0; i < n_inputs; ++i) {
     TensorHandle* in = op_inputs_flat[i];
     Device* d = kernel->InputDevice(i);
-    Status s = in->TensorValue(&tensor_args_flat[i], ctx->CanonicalDevice(d));
+    Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
     if (!s.ok()) {
 #if !defined(IS_MOBILE_PLATFORM)
       uint64 context_view_id = ctx->GetContextViewId();

@@ -77,7 +77,7 @@ class ExecuteNode : public EagerNode {
       EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
       const core::RefCountPtr<KernelAndDevice>& kernel,
-      GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
+      GraphCollector* graph_collector,
       CancellationManager* cancellation_manager,
       absl::Span<TensorHandle*> retvals)
       : EagerNode(),
@@ -130,7 +130,7 @@ class AsyncExecuteNode : public EagerNode {
       EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
       core::RefCountPtr<KernelAndDevice> kernel,
-      GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
+      GraphCollector* graph_collector,
       CancellationManager* cancellation_manager,
       absl::Span<TensorHandle*> retvals)
       : EagerNode(),

@@ -67,6 +67,7 @@ class EagerKernelArgs : public FunctionArgsInterface {
   ~EagerKernelArgs() override{};
   bool HasRemoteInputs() const override { return false; };
+  TensorValue* MutableInput(int i) { return &tensor_args_[i]; }
   Status GetLocalArg(const int index, Tensor* val) const override;

@@ -20,39 +20,31 @@ limitations under the License.
 #include <memory>
 #include <queue>
 #include <string>
+#include <utility>
 #include <vector>
-#include "absl/strings/substitute.h"
 #include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
 #include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
-#include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/errors.h"
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
 #endif  // IS_MOBILE_PLATFORM
-#include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/resource_var.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"
 namespace tensorflow {
@@ -64,7 +56,7 @@ const int32 kInvalidOutputNum = -1;
 }  // namespace
 void TensorHandle::SetResourceHandleDtypeAndShape(
-    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
+    std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes) {
   handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
 }
@@ -86,250 +78,191 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
   profiler::TraceMe activity(
       "TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
       profiler::TraceMeLevel::kInfo);
+  auto& data = absl::get<LocalTensorHandleData>(data_);
   TF_RETURN_IF_ERROR(
-      WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
+      data.WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
   *result = handle_dtypes_and_shapes_;
   return Status::OK();
 }
-Status TensorHandle::CreateLocalHandle(const class Tensor& t,
+Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
                                        TensorHandle** h) {
   // TODO(b/136608821): Move away from nullptr
-  return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr),
+  tensorflow::Tensor tensor = t;
+  return CreateLocalHandle(std::move(tensor),
+                           /*d=*/nullptr,
                            /*op_device=*/nullptr,
                            /*ctx=*/nullptr, h);
 }
-Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
-                                       EagerContext* ctx, TensorHandle** h) {
-  return CreateLocalHandle(t, d, d, ctx, h);
-}
-
-Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
+Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                        Device* op_device, EagerContext* ctx,
                                        TensorHandle** h) {
-  if (t.dtype() != DT_RESOURCE) {
-    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
-                          t.dtype(), d, op_device, ctx);
-  } else {
-    const ResourceHandle& resource_handle = t.flat<class ResourceHandle>()(0);
-    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
-                          resource_handle, d, op_device, ctx);
-  }
-
-  return Status::OK();
+  return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h);
 }
-Status TensorHandle::CreateLocalHandle(const class Tensor& t, CustomDevice* d,
-                                       EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t), t.dtype(),
-                        d, ctx);
+Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
+                                       Device* op_device,
+                                       Device* resource_device,
+                                       EagerContext* ctx, TensorHandle** h) {
+  if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
+    *h = new TensorHandle(std::move(t), d, op_device, ctx);
+  } else {
+    *h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
+  }
   return Status::OK();
 }
-TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           DataType dtype, Device* d, Device* op_device,
-                           EagerContext* ctx)
-    : dtype(dtype),
+Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
+                                       EagerContext* ctx, TensorHandle** h) {
+  *h = new TensorHandle(std::move(t), d, ctx);
+
+  return Status::OK();
+}
+
+TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+                           Device* resource_device, EagerContext* ctx)
+    : dtype(t.dtype()),
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
-      resource_device_(nullptr),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
+      resource_device_(resource_device),
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      tensor_handle_data_(std::move(t)) {
+      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " device: " << VariantDeviceDebugString(device_);
+           << " device: " << VariantDeviceDebugString(device_)
+           << " tensor: " << t.DeviceSafeDebugString();
 }
-TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           const ResourceHandle& resource_handle, Device* d,
-                           Device* op_device, EagerContext* ctx)
+TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+                           EagerContext* ctx)
     : dtype(DT_RESOURCE),
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
-      resource_device_(GetResourceDevice(resource_handle, ctx)),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
+      resource_device_(
+          GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
-      tensor_handle_data_(std::move(t)) {
+      handle_dtypes_and_shapes_(
+          t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
+      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " device: " << VariantDeviceDebugString(device_);
+           << " device: " << VariantDeviceDebugString(device_)
+           << " tensor: " << t.DeviceSafeDebugString();
 }
-TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           DataType dtype, CustomDevice* d, EagerContext* ctx)
-    : dtype(dtype),
+TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
+                           EagerContext* ctx)
+    : dtype(t.dtype()),
       device_(d),
       op_device_(nullptr),
       resource_device_(nullptr),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      tensor_handle_data_(std::move(t)) {
+      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
   // TODO(allenl): Figure out a better op_device story for custom devices,
   // since always setting it to CPU=nullptr doesn't make much sense.
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " custom device: " << VariantDeviceDebugString(device_);
+           << " custom device: " << VariantDeviceDebugString(device_)
+           << " tensor: " << t.DeviceSafeDebugString();
 }
-Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
-                                            Device* op_device,
+Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
                                             Device* resource_device,
                                             DataType dtype, EagerContext* ctx,
                                             TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<EmptyLocalTensorHandleData>(), async,
-                        d, op_device, resource_device, dtype, ctx);
+  *h = new TensorHandle(d, op_device, resource_device, dtype, ctx);
   return Status::OK();
 }
-TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
-                           bool async, Device* d, Device* op_device,
+TensorHandle::TensorHandle(Device* d, Device* op_device,
                            Device* resource_device, DataType dtype,
                            EagerContext* ctx)
     : dtype(dtype),
       device_((d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
       resource_device_(resource_device),
-#if !defined(IS_MOBILE_PLATFORM)
-      remote_op_id_(kInvalidOpId),
-      remote_output_num_(kInvalidOutputNum),
-#endif
       ctx_(ctx),
-      is_remote_(false),
-      is_async_(async),
       implicit_mirroring_(true),
-      is_ready_(!async),
-      tensor_handle_data_(std::move(t)) {
+      data_(absl::in_place_type<LocalTensorHandleData>) {
   DVLOG(3) << "Creating empty Local TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
 #if !defined(IS_MOBILE_PLATFORM)
-Status TensorHandle::CreateRemoteHandle(
-    std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d,
-    Device* resource_device, EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx);
-
-  return Status::OK();
-}
-
-Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
-                                        const TensorShape& shape,
-                                        const string& remote_task,
-                                        DataType dtype, Device* d,
-                                        Device* resource_device,
-                                        EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<RemoteTensorHandleData>(
-                            op_id, output_num, shape, remote_task, ctx),
-                        dtype, d, resource_device, ctx);
+Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
+                                                const string& remote_task,
+                                                DataType dtype, Device* d,
+                                                EagerContext* ctx,
+                                                TensorHandle** h) {
+  *h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
   return Status::OK();
 }
-TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
-                           DataType dtype, Device* d, Device* resource_device,
-                           EagerContext* ctx)
+TensorHandle::TensorHandle(int64 op_id, int32 output_num,
+                           const string& remote_task, DataType dtype, Device* d,
+                           EagerContext* ctx)
     : dtype(dtype),
       device_(d),
       op_device_(d),
-      resource_device_(resource_device),
-      remote_op_id_(t->op_id()),
-      remote_output_num_(t->output_num()),
+      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
       ctx_(ctx),
-      is_remote_(true),
-      is_async_(false),
       implicit_mirroring_(true),
-      is_ready_(true),
-      tensor_handle_data_(std::move(t)) {
-  DVLOG(3) << "Creating Remote TensorHandle: " << this
+      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
+            remote_task, ctx) {
+  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
-Status TensorHandle::CreateUnshapedRemoteHandle(
-    std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
-    Device* d, EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(std::move(t), dtype, d, ctx);
-
-  return Status::OK();
-}
-
-Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
-                                                const string& remote_task,
-                                                DataType dtype, Device* device,
-                                                EagerContext* ctx,
-                                                TensorHandle** h) {
-  *h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>(
-                            op_id, output_num, remote_task, ctx),
-                        dtype, device, ctx);
+Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num,
+                                            DataType dtype, Device* d,
+                                            EagerContext* ctx,
+                                            TensorHandle** h) {
+  *h = new TensorHandle(op_id, output_num, dtype, d, ctx);
   return Status::OK();
 }
-TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
-                           DataType dtype, Device* device, EagerContext* ctx)
+TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype,
+                           Device* d, EagerContext* ctx)
     : dtype(dtype),
-      device_(device),
-      op_device_(device),
-      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
-      remote_op_id_(t->op_id()),
-      remote_output_num_(t->output_num()),
+      device_(d),
+      op_device_(d),
+      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
      ctx_(ctx),
-      is_remote_(true),
-      is_async_(true),
      implicit_mirroring_(true),
-      is_ready_(false),
-      tensor_handle_data_(std::move(t)) {
-  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
+      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
+            ctx->GetContextViewId()) {
+  DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
 #endif
 bool TensorHandle::IsReady() const {
-  // Avoid mutex acquisition for local sync handles
-  if (!is_async_ && !is_remote_) {
-    return true;
-  }
-
-  tf_shared_lock l(mu_);
-  return is_ready_;
+  return absl::visit([](auto& data) { return data.IsReady(); }, data_);
 }
-Status TensorHandle::WaitReady(const char* caller) const {
-  if (!IsReady()) {
-    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
-                               profiler::TraceMeLevel::kInfo);
-    tf_shared_lock l(mu_);
-    mu_.Await(Condition(&is_ready_));
-  }
-  return is_poisoned_;
+bool TensorHandle::IsRemote() const {
+#if !defined(IS_MOBILE_PLATFORM)
+  return data_.index() == 1;
+#else
+  return false;
+#endif
 }
 Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
   DVLOG(3) << "Tensor on TensorHandle: " << this;
-  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor"));
-  return tensor_handle_data_->Tensor(t);
+  if (IsRemote()) {
+    return errors::Internal("Invalid Tensor call on remote handle: ", this);
+  }
+
+  auto& data = absl::get<LocalTensorHandleData>(data_);
+  return data.Tensor(t);
 }
 Status TensorHandle::TensorFromDevice(const Device* d,
@@ -337,12 +270,12 @@ Status TensorHandle::TensorFromDevice(const Device* d,
   DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;
   if (d == absl::get<Device*>(device_)) {
-    if (is_remote_) {
+    if (IsRemote()) {
       return errors::Internal("Invalid Tensor call on remote handle: ", this);
     }
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice"));
-    return tensor_handle_data_->Tensor(t);
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.Tensor(t);
   }
   tf_shared_lock l(mu_);
@@ -352,25 +285,21 @@ Status TensorHandle::TensorFromDevice(const Device* d,
                            " in Tensor call to handle: ", this);
   }
-  // Check if the handle is non-empty, else wait.
   auto& mirror = elem->second;
-  if (mirror.second == nullptr) {
-    TF_RETURN_IF_ERROR(
-        mirror.first->WaitReady("TensorHandle::TensorFromDevice"));
-  }
-
-  return mirror.second->Tensor(t);
+  return mirror.Tensor(t);
 }
-Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
+Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
+  DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;
+
   if (d == absl::get<Device*>(device_)) {
-    if (is_remote_) {
+    if (IsRemote()) {
       return errors::Internal("Invalid TensorValue call on remote handle: ",
                               this);
     }
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
-    return tensor_handle_data_->TensorValue(t);
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.TensorValue(t);
   }
   tf_shared_lock l(mu_);
@@ -380,13 +309,8 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
                            " in TensorValue call to handle: ", this);
   }
-  // Check if the handle is non-empty, else wait.
   auto& mirror = elem->second;
-  if (mirror.second == nullptr) {
-    TF_RETURN_IF_ERROR(mirror.first->WaitReady("TensorHandle::TensorValue"));
-  }
-
-  return mirror.second->TensorValue(t);
+  return mirror.TensorValue(t);
 }
 TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
@@ -405,8 +329,8 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
     DCHECK(fill);
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Shape"));
-    return tensor_handle_data_->Shape(shape);
+    return absl::visit([shape](auto& data) { return data.Shape(shape); },
+                       data_);
   }
 }
@@ -480,8 +404,8 @@ Status TensorHandle::NumDims(int* num_dims) const {
     *num_dims = inference_shape_.dims();
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumDims"));
-    return tensor_handle_data_->NumDims(num_dims);
+    return absl::visit(
+        [num_dims](auto& data) { return data.NumDims(num_dims); }, data_);
   }
 }
@@ -492,8 +416,9 @@ Status TensorHandle::Dim(int dim_index, int64* dim) const {
     *dim = inference_shape_.dim_size(dim_index);
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Dim"));
-    return tensor_handle_data_->Dim(dim_index, dim);
+    return absl::visit(
+        [dim_index, dim](auto& data) { return data.Dim(dim_index, dim); },
+        data_);
   }
 }
@@ -503,8 +428,9 @@ Status TensorHandle::NumElements(int64* num_elements) const {
     *num_elements = inference_shape_.num_elements();
     return Status::OK();
   } else {
-    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumElements"));
-    return tensor_handle_data_->NumElements(num_elements);
+    return absl::visit(
+        [num_elements](auto& data) { return data.NumElements(num_elements); },
+        data_);
   }
 }
@@ -512,7 +438,8 @@ Status TensorHandle::Unprotect(const Device* d) {
   DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
   if (d == absl::get<Device*>(device_)) {
-    return tensor_handle_data_->Unprotect();
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.Unprotect();
   }
   tf_shared_lock l(mu_);
@@ -524,11 +451,7 @@ Status TensorHandle::Unprotect(const Device* d) {
   // Check if the handle is non-empty
   auto& mirror = elem->second;
-  if (mirror.second == nullptr) {
-    return errors::Internal("Attempted to unprotect an empty mirror");
-  }
-
-  return mirror.second->Unprotect();
+  return mirror.Unprotect();
 }
 bool TensorHandle::HasLocalMirror(const Device* d) const {
@@ -551,8 +474,8 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
     return errors::Internal("Attempted to duplicate a local mirror.");
   }
-  local_mirrors_[d] =
-      std::make_pair(std::make_unique<EmptyLocalTensorHandleData>(), nullptr);
+  local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
+                         std::forward_as_tuple());
   return Status::OK();
 }
@@ -567,15 +490,8 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
     tf_shared_lock l(mu_);
     auto mirror = remote_mirrors_.find(d->name());
     if (mirror != remote_mirrors_.end()) {
-      *op_id = mirror->second->op_id();
-      *output_num = mirror->second->output_num();
-      return Status::OK();
-    }
-
-    auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
-    if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
-      *op_id = unshaped_mirror->second->op_id();
-      *output_num = unshaped_mirror->second->output_num();
+      *op_id = mirror->second.op_id();
+      *output_num = mirror->second.output_num();
       return Status::OK();
     }
@@ -583,14 +499,14 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
         "Could not find remote mirror for specified device");
   }
-  if (remote_op_id_ == kInvalidOpId ||
-      remote_output_num_ == kInvalidOutputNum) {
-    return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_,
-                                   ", output_num:", remote_output_num_,
-                                   ") is not set.");
+  if (!IsRemote()) {
+    return errors::InvalidArgument("Primary device is not remote");
   }
-  *op_id = remote_op_id_;
-  *output_num = remote_output_num_;
+  auto& data = absl::get<RemoteTensorHandleData>(data_);
+  *op_id = data.op_id();
+  *output_num = data.output_num();
   return Status::OK();
 }
@@ -603,16 +519,7 @@ bool TensorHandle::HasRemoteMirror(const Device* d,
   auto mirror = remote_mirrors_.find(d->name());
   if (mirror != remote_mirrors_.end()) {
     // Check if mirror is stale
-    if (mirror->second->context_view_id() != context_view_id) {
-      return false;
-    }
-    return true;
-  }
-
-  auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
-  if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
-    // Check if mirror is stale
-    if (unshaped_mirror->second->context_view_id() != context_view_id) {
+    if (mirror->second.context_view_id() != context_view_id) {
       return false;
     }
     return true;
@@ -630,7 +537,7 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
   auto mirror = resource_shape_mirrors_.find(d->name());
   if (mirror != resource_shape_mirrors_.end()) {
     // Check if mirror is stale
-    if (mirror->second->context_view_id() != context_view_id) {
+    if (mirror->second.context_view_id() != context_view_id) {
      return false;
    }
    return true;
@@ -638,45 +545,39 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
   return false;
 }
-Status TensorHandle::AddUnshapedRemoteMirror(
-    std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
+Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64 op_id,
+                                             int output_num,
+                                             const string& remote_task,
+                                             EagerContext* ctx) {
   DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this
-           << " device: " << d << " " << d->name();
+           << " device: " << d << " " << d->name() << " op_id: " << op_id
+           << " output_num: " << output_num;
   mutex_lock l(mu_);
   auto remote_mirror = remote_mirrors_.find(d->name());
   if (remote_mirror != remote_mirrors_.end()) {
-    if (remote_mirror->second->context_view_id() == t->context_view_id()) {
+    if (remote_mirror->second.context_view_id() == ctx->GetContextId()) {
       return errors::Internal("Attempted to duplicate a remote mirror.");
     }
     // Remove stale mirror
     remote_mirrors_.erase(remote_mirror);
   }
-  auto unshaped_remote_mirror = unshaped_remote_mirrors_.find(d->name());
-  if (unshaped_remote_mirror != unshaped_remote_mirrors_.end()) {
-    if (unshaped_remote_mirror->second->context_view_id() ==
-        t->context_view_id()) {
-      return errors::Internal(
-          "Attempted to duplicate an unshaped remote mirror.");
-    }
-    // Remove stale mirror
-    unshaped_remote_mirrors_.erase(unshaped_remote_mirror);
-  }
-
-  unshaped_remote_mirrors_[d->name()] = std::move(t);
+  remote_mirrors_.emplace(
+      std::piecewise_construct, std::forward_as_tuple(d->name()),
+      std::forward_as_tuple(op_id, output_num, remote_task, ctx));
   return Status::OK();
 }
-Status TensorHandle::AddResourceShapeMirror(
-    std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
+Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
+                                            int output_num, EagerContext* ctx) {
   DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this;
   mutex_lock l(mu_);
   auto mirror = resource_shape_mirrors_.find(d->name());
   if (mirror != resource_shape_mirrors_.end()) {
-    if (mirror->second->context_view_id() == t->context_view_id()) {
+    if (mirror->second.context_view_id() == ctx->GetContextViewId()) {
       return errors::Internal(
           "Attempted to duplicate a resource shape mirror.");
     }
@@ -684,26 +585,9 @@ Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
     resource_shape_mirrors_.erase(mirror);
   }
-  resource_shape_mirrors_[d->name()] = std::move(t);
-
-  return Status::OK();
-}
-
-Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
-                                     const Device* d) {
-  DVLOG(3) << "AddRemoteMirror on TensorHandle: " << this << " device: " << d;
-
-  mutex_lock l(mu_);
-  auto mirror = remote_mirrors_.find(d->name());
-  if (mirror != remote_mirrors_.end()) {
-    if (mirror->second->context_view_id() == t->context_view_id()) {
-      return errors::Internal("Attempted to duplicate a remote mirror.");
-    }
-    // Remove stale mirror
-    remote_mirrors_.erase(mirror);
-  }
-
-  remote_mirrors_[d->name()] = std::move(t);
+  resource_shape_mirrors_.emplace(
+      std::piecewise_construct, std::forward_as_tuple(d->name()),
+      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId()));
   return Status::OK();
 }
@@ -717,53 +601,24 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
     mutex_lock l(mu_);
     auto remote_mirror = remote_mirrors_.find(d->name());
     if (remote_mirror != remote_mirrors_.end()) {
-      if (remote_mirror->second->context_view_id() == context_view_id) {
-        return errors::Internal(
-            "Attempted to set remote shape for existing mirror.");
+      auto& mirror = remote_mirror->second;
+      if (mirror.context_view_id() == context_view_id) {
+        return mirror.SetShape(shape);
       }
       remote_mirrors_.erase(remote_mirror);
     }
-    auto elem = unshaped_remote_mirrors_.find(d->name());
-    if (elem == unshaped_remote_mirrors_.end()) {
-      return errors::Internal(
-          "Attempted to set remote shape for non-waiting mirror.");
-    }
-
-    if (elem->second->context_view_id() != context_view_id) {
-      unshaped_remote_mirrors_.erase(elem);
-      return errors::Internal(
-          "Attempted to set remote shape for a stale mirror.");
-    }
-
-    auto& data = elem->second;
-    data->ReleaseRemoteTensorHandle();
-    remote_mirrors_[d->name()] = absl::make_unique<RemoteTensorHandleData>(
-        data->op_id(), data->output_num(), shape, data->remote_task(),
-        &data->ctx());
-    unshaped_remote_mirrors_.erase(elem);
-
     return Status::OK();
   }
-  DCHECK(is_remote_) << "SetRemoteShape is only called on remote handles.";
-  DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";
+  DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";
-  UnshapedRemoteTensorHandleData* p =
-      reinterpret_cast<UnshapedRemoteTensorHandleData*>(
-          tensor_handle_data_.get());
-  if (p->context_view_id() != context_view_id) {
+  auto& data = absl::get<RemoteTensorHandleData>(data_);
+  if (data.context_view_id() != context_view_id) {
     return errors::Internal("Attempted to set remote shape for an old handle.");
   }
-  p->ReleaseRemoteTensorHandle();
-  tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>(
-      remote_op_id_, remote_output_num_, shape, p->remote_task(), ctx_);
-  is_poisoned_ = Status::OK();
-  mutex_lock l(mu_);
-  is_ready_ = true;
-
-  return Status::OK();
+  return data.SetShape(shape);
 }
 void TensorHandle::PoisonRemote(Status status, const Device* d,
@@ -772,18 +627,16 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
            << " " << d->name();
   if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
-    DCHECK(!is_async_ || !IsReady())
-        << "PoisonRemote can only be called on non-ready handle: " << this;
+    DCHECK(IsRemote()) << "Poison can only be on remote handles: " << this;
-    is_poisoned_ = status;
-    mutex_lock l(mu_);
-    is_ready_ = true;
+    auto& data = absl::get<RemoteTensorHandleData>(data_);
+    data.Poison(status);
   } else {
     tf_shared_lock l(mu_);
-    auto mirror = unshaped_remote_mirrors_.find(d->name());
-    if (mirror != unshaped_remote_mirrors_.end()) {
-      if (mirror->second->context_view_id() == context_view_id) {
-        mirror->second->Poison(status);
+    auto mirror = remote_mirrors_.find(d->name());
+    if (mirror != remote_mirrors_.end()) {
+      if (mirror->second.context_view_id() == context_view_id) {
+        mirror->second.Poison(status);
       }
     }
   }
@@ -798,9 +651,9 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
   }
   mutex_lock l(mu_);
-  auto elem = local_mirrors_.insert(std::make_pair(
-      d, std::make_pair(nullptr,
-                        std::make_unique<LocalTensorHandleData>(tensor))));
+  auto elem =
+      local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
+                             std::forward_as_tuple(std::move(tensor)));
   if (!elem.second) {
     return errors::Internal("Attempted to set tensor for existing mirror.");
   }
@@ -808,24 +661,18 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
   return Status::OK();
 }
-Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
+Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
   DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;
   if (d == absl::get<Device*>(device_)) {
-    DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
-    DCHECK(!is_async_ || !IsReady())
-        << "SetTensor is only called on non-ready handles.";
+    DCHECK(!IsRemote()) << "SetTensor is not called on remote handles.";
-    if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
-      auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
+    if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
+      auto& resource_handle = t.flat<class ResourceHandle>()(0);
       handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
     }
-    tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
-    if (is_async_) {
-      is_poisoned_ = Status::OK();
-      mutex_lock l(mu_);
-      is_ready_ = true;
-    }
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    return data.SetTensor(std::move(t));
   } else {
     tf_shared_lock l(mu_);
     auto elem = local_mirrors_.find(d);
@@ -835,12 +682,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
     }
     auto& mirror = elem->second;
-    if (mirror.second != nullptr) {
-      return errors::Internal("Attempted to set tensor for existing mirror.");
-    }
-
-    mirror.second = absl::make_unique<LocalTensorHandleData>(tensor);
-    mirror.first->SetReady();
+    return mirror.SetTensor(std::move(t));
   }
   return Status::OK();
@@ -850,12 +692,10 @@ void TensorHandle::Poison(Status status, const Device* d) {
   DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;
   if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
-    DCHECK(!is_async_ || !IsReady())
-        << "Poison can only be called on non-ready handle: " << this;
+    DCHECK(!IsRemote()) << "Poison can only be on local handles: " << this;
-    is_poisoned_ = status;
-    mutex_lock l(mu_);
-    is_ready_ = true;
+    auto& data = absl::get<LocalTensorHandleData>(data_);
+    data.Poison(status);
   } else {
     tf_shared_lock l(mu_);
     auto elem = local_mirrors_.find(d);
@@ -864,9 +704,7 @@ void TensorHandle::Poison(Status status, const Device* d) {
            << " device: " << d;
     auto& mirror = elem->second;
-    DCHECK(mirror.second == nullptr) << "Attempted to poison existing mirror.";
-
-    mirror.first->Poison(status);
+    mirror.Poison(status);
   }
 }
@@ -977,8 +815,11 @@ string TensorHandle::DebugString() const {
       !VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull;
   // Consider supporting non-CPU tensors and CPU tensors with a device_ set to
   // non-NULL if needed.
-  strings::StrAppend(&out, ", Tensor: ",
-                     is_cpu ? tensor_handle_data_->DebugString() : "?", "\n");
+  strings::StrAppend(
+      &out, ", Tensor: ",
+      is_cpu ? absl::visit([](auto& data) { return data.DebugString(); }, data_)
+             : "?",
+      "\n");
   return out;
 }

@@ -17,10 +17,10 @@ limitations under the License.
 #include <algorithm>
 #include <cstddef>
-#include <map>
 #include <memory>
 #include <queue>
 #include <string>
+#include <unordered_map>
 #include <vector>
 // clang-format off
@@ -32,28 +32,20 @@ limitations under the License.
 #include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/device.h"
-#include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
 #include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
 #endif  // IS_MOBILE_PLATFORM
-#include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/notification.h"
 #include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/public/session_options.h"
-#include "tensorflow/core/public/version.h"
 namespace tensorflow {
@@ -67,56 +59,45 @@ class TensorHandle : public core::RefCounted {
   using VariantDevice = absl::variant<Device*, CustomDevice*>;
   // TensorHandle for dtype != DT_RESOURCE
-  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
-               Device* d, Device* op_device, EagerContext* ctx);
+  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+               Device* resource_device, EagerContext* ctx);
   // TensorHandle for dtype == DT_RESOURCE
-  TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-               const ResourceHandle& resource_handle, Device* d,
-               Device* op_device, EagerContext* ctx);
-  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
-               CustomDevice* d, EagerContext* ctx);
-  TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
-               Device* d, Device* op_device, Device* resource_device,
+  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
+               EagerContext* ctx);
+  TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx);
+  TensorHandle(Device* d, Device* op_device, Device* resource_device,
                DataType dtype, EagerContext* ctx);
 #if !defined(IS_MOBILE_PLATFORM)
-  TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, DataType dtype,
-               Device* d, Device* resource_device, EagerContext* ctx);
-  TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
+  TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
                DataType dtype, Device* device, EagerContext* ctx);
+  TensorHandle(int64 op_id, int32 output_num, DataType dtype, Device* device,
+               EagerContext* ctx);
 #endif  // IS_MOBILE_PLATFORM
 public:
   // TensorHandle with no assigned device
-  static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
-  // TensorHandle with device == op_device
-  static Status CreateLocalHandle(const class Tensor& t, Device* d,
-                                  EagerContext* ctx, TensorHandle** h);
-  static Status CreateLocalHandle(const class Tensor& t, Device* d,
+  static Status CreateLocalHandle(const tensorflow::Tensor& t,
+                                  TensorHandle** h);
+  static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
                                   Device* op_device, EagerContext* ctx,
                                   TensorHandle** h);
-  static Status CreateLocalHandle(const class Tensor& t, CustomDevice* d,
+  static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
+                                  Device* op_device, Device* resource_device,
                                   EagerContext* ctx, TensorHandle** h);
-  static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
+  static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
+                                  EagerContext* ctx, TensorHandle** h);
+  static Status CreateEmptyLocalHandle(Device* d, Device* op_device,
                                        Device* resource_device, DataType dtype,
                                        EagerContext* ctx, TensorHandle** h);
 #if !defined(IS_MOBILE_PLATFORM)
-  static Status CreateRemoteHandle(int64 op_id, int output_num,
-                                   const TensorShape& shape,
-                                   const string& remote_task, DataType dtype,
-                                   Device* d, Device* resource_device,
-                                   EagerContext* ctx, TensorHandle** h);
-  static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t,
-                                   DataType dtype, Device* d,
-                                   Device* resource_device, EagerContext* ctx,
-                                   TensorHandle** h);
   static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
                                            const string& remote_task,
-                                           DataType dtype, Device* device,
+                                           DataType dtype, Device* d,
                                            EagerContext* ctx, TensorHandle** h);
-  static Status CreateUnshapedRemoteHandle(
-      std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
-      Device* device, EagerContext* ctx, TensorHandle** h);
+  static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num,
+                                       DataType dtype, Device* d,
+                                       EagerContext* ctx, TensorHandle** h);
 #endif  // IS_MOBILE_PLATFORM
   ~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }
@@ -131,7 +112,7 @@ class TensorHandle : public core::RefCounted {
   // Return the TensorValue from the specified device which could be either the
   // default device or a local mirror. The device pointer should be nullptr if
   // requesting the HostCPU.
-  Status TensorValue(tensorflow::TensorValue* t, const Device* d);
+  Status TensorValue(const Device* d, tensorflow::TensorValue* t);
   VariantDevice device() const { return device_; }
   Device* op_device() const { return op_device_; }
@@ -161,12 +142,10 @@ class TensorHandle : public core::RefCounted {
   bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
   bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;
-  Status AddUnshapedRemoteMirror(
-      std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);
-  Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
-                         const Device* d);
-  Status AddResourceShapeMirror(
-      std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);
+  Status AddUnshapedRemoteMirror(const Device* d, int64 op_id, int output_num,
+                                 const string& remote_task, EagerContext* ctx);
+  Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
+                                EagerContext* ctx);
   // Return the op_id and output num if the handle refers to a remote tensor.
   Status RemoteAddress(const Device* d, int64* op_id, int32* output_num) const;
@@ -212,14 +191,12 @@ class TensorHandle : public core::RefCounted {
   Status CopyInferenceShape(TensorHandle* other);
   // Warning: can return nullptr for CPU tensors.
-  // TODO(b/136608821): Move away from nullptr
   EagerContext* Context() { return ctx_; }
   // dtype for the handle. It must be the same as t.dtype() once the handle is
   // ready.
   const DataType dtype;
-  // TODO(b/136608821): Move away from nullptr
   bool OnHostCPU() const {
     return (
         device_.index() == 0 &&
@@ -227,14 +204,14 @@ class TensorHandle : public core::RefCounted {
         (ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_))));
   }
-  bool IsRemote() const { return is_remote_; }
+  bool IsRemote() const;
   void EnableImplicitMirroring() { implicit_mirroring_ = true; }
   bool ImplicitMirroring() const { return implicit_mirroring_; }
   string DebugString() const;
   void SetResourceHandleDtypeAndShape(
-      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);
+      std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes);
   // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
   // return data types and shapes of the underlying resource.
@@ -248,19 +225,6 @@ class TensorHandle : public core::RefCounted {
   // with a ready version of the tensor handle data.
   bool IsReady() const;
-  // If the contents of the Tensor pointed to by this handle is yet to be
-  // computed by a EagerNode, this function will block till that computation is
-  // done and the handle is "ready".
-  Status WaitReady(const char* caller) const;
-
-  // TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
-  // This was expedient, but perhaps worth revisiting ('device_' should always
-  // be a valid pointer?)
-  // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
-  // provided with the appropriate TFE_Context.
-  //
-  // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
-  // TFE_TensorHandle does not outlive the TFE_Context from which it came?
   VariantDevice const device_;
   // Device in which the op producing this tensor was executed. Equals to
@ -275,47 +239,33 @@ class TensorHandle : public core::RefCounted {
mutable mutex mu_; mutable mutex mu_;
// Map of local mirrors. In sync mode the EmptyLocalTensorHandleData is // Map of local mirrors. This can include both ready and non-ready mirrors.
// nullptr. In async mode, we use the EmptyLocalTensorHandleData to manage std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
// waiting clients. Once the EmptyLocalTensorHandleData is "ready" only the
// LocalTensorHandleData should be used.
std::map<const tensorflow::Device*,
std::pair<std::unique_ptr<EmptyLocalTensorHandleData>,
std::unique_ptr<LocalTensorHandleData>>>
local_mirrors_ GUARDED_BY(mu_); local_mirrors_ GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM) #if !defined(IS_MOBILE_PLATFORM)
// TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
// variable is ready, since we could get the shape locally without remote copy // variable is ready, since we could get the shape locally without remote copy
// then. // then.
std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>> std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
resource_shape_mirrors_ GUARDED_BY(mu_); GUARDED_BY(mu_);
// TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
// Consider replacing the unshaped_remote_mirrors_ map with something more
// efficient.
std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>>
unshaped_remote_mirrors_ GUARDED_BY(mu_);
// TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be // TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be
// a fixed size map. // a fixed size map.
std::map<string, std::unique_ptr<RemoteTensorHandleData>> remote_mirrors_ std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
GUARDED_BY(mu_); GUARDED_BY(mu_);
// IDs required when this class is representing a remote tensor handle.
int64 remote_op_id_;
int32 remote_output_num_;
#endif #endif
// `ctx` is only guaranteed to be set if the handle is not "ready". This is // `ctx` is only guaranteed to be set if the handle is not "ready". This is
// typically true when the handle was produced during async execution. // typically true when the handle was produced during async execution.
// `ctx` object is not owned and should outlive this handle. // `ctx` object is not owned and should outlive this handle.
//
// TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
// a TensorHandle does not outlive the EagerContext from which it came?
EagerContext* const ctx_; EagerContext* const ctx_;
// Does not need synchronization because it can be accessed only after // Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, is_poisoned_ is immutable. // WaitReady() has returned. At that point, is_poisoned_ is immutable.
Status is_poisoned_; Status is_poisoned_;
const bool is_remote_;
const bool is_async_;
bool implicit_mirroring_; bool implicit_mirroring_;
bool is_ready_ GUARDED_BY(mu_);
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
// refers to a remote resource handle, we store data types and shapes for // refers to a remote resource handle, we store data types and shapes for
@ -323,8 +273,12 @@ class TensorHandle : public core::RefCounted {
std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_; std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
// Does not need synchronization because it can be accessed only after // Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, tensor_handle_data_ is immutable. // WaitReady() has returned. At that point, data_ is immutable.
std::unique_ptr<TensorHandleData> tensor_handle_data_; #if !defined(IS_MOBILE_PLATFORM)
absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_;
#else
absl::variant<LocalTensorHandleData> data_;
#endif
PartialTensorShape inference_shape_; PartialTensorShape inference_shape_;
}; };
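Editor's note: the data_ variant above is the heart of the consolidation; the old heap-allocated TensorHandleData pointer becomes an in-place variant dispatched with absl::visit. A minimal standalone sketch of that pattern (illustrative names only, std::variant standing in for absl::variant; this is not TensorFlow code):

#include <iostream>
#include <string>
#include <variant>

struct LocalData {
  std::string DebugString() const { return "LocalData"; }
};

struct RemoteData {
  std::string DebugString() const { return "RemoteData"; }
};

class Handle {
 public:
  explicit Handle(LocalData d) : data_(std::move(d)) {}
  explicit Handle(RemoteData d) : data_(std::move(d)) {}

  // One visit replaces the virtual call through the old base-class pointer.
  std::string DebugString() const {
    return std::visit([](const auto& d) { return d.DebugString(); }, data_);
  }

 private:
  std::variant<LocalData, RemoteData> data_;  // stored inline, no new/delete
};

int main() {
  Handle local{LocalData{}};
  Handle remote{RemoteData{}};
  std::cout << local.DebugString() << " " << remote.DebugString() << "\n";
}

Because the variant lives inside the handle, constructing a TensorHandle no longer costs a separate heap allocation for its data, and switching a handle from "empty" to "ready" mutates the data in place instead of swapping a pointer.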
View File
@@ -23,12 +23,16 @@ namespace tensorflow {
 class Status;

 Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
+  TF_RETURN_IF_ERROR(WaitReady("Tensor"));
+
   *t = &tensor_;

   return Status::OK();
 }

 Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
+  TF_RETURN_IF_ERROR(WaitReady("TensorValue"));
+
   tensorflow::Tensor& tensor = tensor_;
   *t = tensorflow::TensorValue(&tensor);
@@ -36,103 +40,96 @@ Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
 }

 Status LocalTensorHandleData::Shape(TensorShape* shape) const {
+  TF_RETURN_IF_ERROR(WaitReady("Shape"));
+
   *shape = tensor_.shape();

   return Status::OK();
 }

 Status LocalTensorHandleData::NumDims(int* num_dims) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumDims"));
+
   *num_dims = tensor_.dims();

   return Status::OK();
 }

 Status LocalTensorHandleData::Dim(int dim_index, int64* dim) const {
+  TF_RETURN_IF_ERROR(WaitReady("Dim"));
+
   *dim = tensor_.dim_size(dim_index);

   return Status::OK();
 }

 Status LocalTensorHandleData::NumElements(int64* num_elements) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumElements"));
+
   *num_elements = tensor_.NumElements();

   return Status::OK();
 }

 Status LocalTensorHandleData::Unprotect() {
+  if (!IsReady()) {
+    return errors::Internal("Cannot unprotect a non-ready tensor");
+  }
+
   forwarding_protection_tensor_ = tensorflow::Tensor();

   return Status::OK();
 }

-Status EmptyLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
-  return errors::Unavailable(
-      "Unable to get a tensor for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
-  return errors::Unavailable(
-      "Unable to get a tensor for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::Shape(TensorShape* shape) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::NumDims(int* num_dims) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::Dim(int dim_index, int64* dim) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::NumElements(int64* num_elements) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an empty handle. "
-      "Please wait until it is ready");
-}
-
-Status EmptyLocalTensorHandleData::Unprotect() {
-  return errors::Unavailable("Unable to unprotect an empty handle.");
-}
-
-bool EmptyLocalTensorHandleData::IsReady() const {
-  tf_shared_lock l(mu_);
-  return is_ready_;
-}
-
-void EmptyLocalTensorHandleData::SetReady() {
+Status LocalTensorHandleData::SetTensor(tensorflow::Tensor&& t) {
+  DCHECK(!IsReady()) << "SetTensor is only called on non-ready handles.";
+
+  tensor_ = std::move(t);
+  // Create copy of original tensor to avoid forwarding
+  forwarding_protection_tensor_ = tensor_;
+
+  auto& state = absl::get<BlockingControl>(ctrl_);
+  state.SetReady();
+
+  return Status::OK();
+}
+
+string LocalTensorHandleData::DebugString() const {
+  if (IsReady()) {
+    return tensor_.DeviceSafeDebugString();
+  } else {
+    return "LocalTensorHandleData";
+  }
+}
+
+void LocalTensorHandleData::BlockingControl::SetReady() {
   mutex_lock l(mu_);
   is_ready_ = true;
 }

-Status EmptyLocalTensorHandleData::WaitReady(const char* caller) const {
-  if (!IsReady()) {
-    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
-                               profiler::TraceMeLevel::kInfo);
-    tf_shared_lock l(mu_);
+Status LocalTensorHandleData::BlockingControl::WaitReady(
+    const char* caller) const {
+  tf_shared_lock l(mu_);
+  if (!is_ready_) {
+    profiler::TraceMe activity(
+        [caller] { return absl::StrCat(caller, " WaitReady"); },
+        profiler::TraceMeLevel::kInfo);
+    DVLOG(3) << "WaitReady: " << caller << " " << this;
     mu_.Await(Condition(&is_ready_));
   }

   return is_poisoned_;
 }

-void EmptyLocalTensorHandleData::Poison(Status status) {
-  is_poisoned_ = status;
+void LocalTensorHandleData::BlockingControl::Poison(Status status) {
   mutex_lock l(mu_);
+  if (is_ready_) {
+    LOG(ERROR) << "Poison can only be called on non-ready handle: " << this;
+    return;
+  }
+  is_poisoned_ = status;
   is_ready_ = true;
 }

-string EmptyLocalTensorHandleData::DebugString() const {
-  return "EmptyLocalTensorHandleData";
-}
-
 } // namespace tensorflow
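Editor's note: the BlockingControl methods above implement a set-once latch; SetReady() and Poison() flip is_ready_ under the mutex, and WaitReady() blocks until the flag is set and then reports the recorded poison status. A self-contained sketch of the same latch, using std::condition_variable in place of TensorFlow's mutex::Await (names and the string-based error are illustrative assumptions, not the real API):

#include <condition_variable>
#include <mutex>
#include <string>

class BlockingControl {
 public:
  void SetReady() {
    std::lock_guard<std::mutex> l(mu_);
    is_ready_ = true;
    cv_.notify_all();
  }

  // Record a failure; like the real Poison(), it also marks the control
  // ready so that waiters wake up and observe the error.
  void Poison(std::string error) {
    std::lock_guard<std::mutex> l(mu_);
    error_ = std::move(error);
    is_ready_ = true;
    cv_.notify_all();
  }

  // Returns the empty string on success, the poison message on failure.
  std::string WaitReady() const {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return is_ready_; });
    return error_;
  }

 private:
  mutable std::mutex mu_;
  mutable std::condition_variable cv_;
  bool is_ready_ = false;
  std::string error_;
};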
View File
@@ -15,52 +15,50 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_

+#include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"

 namespace tensorflow {

-class TensorHandleData {
- public:
-  virtual ~TensorHandleData() {}
-
-  // Different tensor handles support a set of these calls. In some cases these
-  // are resolved with a Tensor or TensorShape. Typically if the handle is not
-  // ready, none of these are supported operations.
-  virtual Status Tensor(const tensorflow::Tensor** t) const = 0;
-  virtual Status TensorValue(tensorflow::TensorValue* t) = 0;
-  virtual Status Shape(TensorShape* shape) const = 0;
-  virtual Status NumDims(int* num_dims) const = 0;
-  virtual Status Dim(int dim_index, int64* dim) const = 0;
-  virtual Status NumElements(int64* num_elements) const = 0;
-  // Allow the backing Tensor to be available for buffer reuse during op
-  // execution.
-  virtual Status Unprotect() = 0;
-
-  virtual string DebugString() const = 0;
-};
-
 // Local Tensor Handle: Handle to a Tensor present on the local host.
-class LocalTensorHandleData : public TensorHandleData {
+class LocalTensorHandleData {
  public:
-  explicit LocalTensorHandleData(const tensorflow::Tensor& t)
-      : tensor_(t), forwarding_protection_tensor_(t) {}
-  ~LocalTensorHandleData() override {}
+  LocalTensorHandleData() : ctrl_(absl::in_place_type<BlockingControl>) {}
+  explicit LocalTensorHandleData(tensorflow::Tensor&& t)
+      : tensor_(std::move(t)),
+        forwarding_protection_tensor_(tensor_),
+        ctrl_(absl::in_place_type<NonBlockingControl>) {}

   // A local tensor handle should be able to satisfy all of these requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
+  Status Tensor(const tensorflow::Tensor** t) const;
+  Status TensorValue(tensorflow::TensorValue* t);
+  Status Shape(TensorShape* shape) const;
+  Status NumDims(int* num_dims) const;
+  Status Dim(int dim_index, int64* dim) const;
+  Status NumElements(int64* num_elements) const;
+  Status Unprotect();

-  string DebugString() const override {
-    return tensor_.DeviceSafeDebugString();
-  }
+  bool IsReady() const {
+    return absl::visit([](auto& data) { return data.IsReady(); }, ctrl_);
+  }
+  Status WaitReady(const char* caller) const {
+    return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
+                       ctrl_);
+  }
+  void Poison(Status status) {
+    return absl::visit([status](auto& data) { data.Poison(status); }, ctrl_);
+  }
+  Status IsPoisoned() const {
+    return absl::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_);
+  }
+
+  Status SetTensor(tensorflow::Tensor&& t);
+
+  string DebugString() const;

  private:
   tensorflow::Tensor tensor_;
   // TensorHandle has its own reference counting which is distinct from the
@@ -70,37 +68,41 @@ class LocalTensorHandleData : public TensorHandleData {
   // forwarding_protection_tensor_ Tensor. When Unprotect() is called, we
   // release this Tensor to allow forwarding.
   tensorflow::Tensor forwarding_protection_tensor_;
-};

-// Empty Local Tensor Handle: Once the execution is complete this is replaced by
-// a local tensor handle.
-class EmptyLocalTensorHandleData : public TensorHandleData {
- public:
-  EmptyLocalTensorHandleData() {}
-  ~EmptyLocalTensorHandleData() override {}
-
-  // Empty tensor handles are not ready and hence cannot satisfy any of these
-  // requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
-
-  bool IsReady() const;
-  void SetReady();
-  Status WaitReady(const char* caller) const;
-  void Poison(Status status);
-  Status IsPoisoned() const { return is_poisoned_; }
-
-  string DebugString() const override;
-
- private:
-  mutable mutex mu_;
-  bool is_ready_ GUARDED_BY(mu_);
-  Status is_poisoned_;
+  // We distinguish between ready and empty tensors with the ctrl_ variant,
+  // which contains 2 implementations of the waiting logic. The
+  // NonBlockingControl is a simple no-op class whereas the BlockingControl
+  // actually uses a mutex. By using a variant we avoid the overhead of
+  // constructing and destructing the mutex for ready local tensors.
+  class NonBlockingControl {
+   public:
+    bool IsReady() const { return true; }
+    Status WaitReady(const char* caller) const { return Status::OK(); }
+    void Poison(Status status) {}
+    Status IsPoisoned() const { return Status::OK(); }
+  };
+
+  class BlockingControl {
+   public:
+    bool IsReady() const {
+      tf_shared_lock l(mu_);
+      return is_ready_;
+    }
+    void SetReady();
+    Status WaitReady(const char* caller) const;
+    void Poison(Status status);
+    Status IsPoisoned() const {
+      tf_shared_lock l(mu_);
+      return is_poisoned_;
+    }

+   private:
+    mutable mutex mu_;
+    bool is_ready_ GUARDED_BY(mu_);
+    Status is_poisoned_ GUARDED_BY(mu_);
+  };
+
+  absl::variant<NonBlockingControl, BlockingControl> ctrl_;
 };

 } // namespace tensorflow
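Editor's note: the two constructors above pick the waiting strategy at construction time via absl::in_place_type, so a handle built from a tensor never pays for the mutex. A compilable sketch of that selection, assuming stand-in class names and std::variant (an int stands in for the tensor payload):

#include <variant>

struct NonBlockingControl {
  bool IsReady() const { return true; }  // ready tensors have nothing to wait for
};

struct BlockingControl {
  bool ready = false;
  bool IsReady() const { return ready; }
};

class Data {
 public:
  // Non-ready handle: pays for the blocking control (and, in the real
  // implementation, its mutex).
  Data() : ctrl_(std::in_place_type<BlockingControl>) {}
  // Ready handle: the no-op control makes readiness checks mutex-free.
  explicit Data(int /*tensor*/)
      : ctrl_(std::in_place_type<NonBlockingControl>) {}

  bool IsReady() const {
    return std::visit([](const auto& c) { return c.IsReady(); }, ctrl_);
  }

 private:
  std::variant<NonBlockingControl, BlockingControl> ctrl_;
};

int main() {
  Data pending;               // blocking control, starts non-ready
  Data ready(/*tensor=*/42);  // non-blocking control, always ready
  return (ready.IsReady() && !pending.IsReady()) ? 0 : 1;
}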
View File
@@ -39,12 +39,13 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
       &device_mgr, false, nullptr, nullptr, nullptr);

   TensorHandle* sync_th;
-  EXPECT_TRUE(
-      TensorHandle::CreateLocalHandle(t, ctx->HostCPU(), ctx, &sync_th).ok());
+  EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr,
+                                              ctx, &sync_th)
+                  .ok());
   TensorHandle* async_th;
-  EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(true, nullptr, nullptr,
-                                                   nullptr, DataType::DT_UINT16,
-                                                   ctx, &async_th)
+  EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr,
+                                                   DataType::DT_UINT16, ctx,
+                                                   &async_th)
                   .ok());

   EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());
View File
@@ -190,8 +190,10 @@ cc_library(
     deps = [
         ":destroy_tensor_handle_node",
         ":eager_client",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core/common_runtime/eager:tensor_handle_data",
+        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/profiler/lib:traceme",
     ],
 )
View File
@@ -530,7 +530,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
   }

   TensorHandle* tensor_handle = nullptr;
-  TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle));
+  TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
+      std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle));
   TensorHandle* copied_handle = nullptr;
   Device* device;
   TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
View File
@@ -101,12 +101,10 @@ Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
   core::RefCountPtr<KernelAndDevice> kernel;
   TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

-  gtl::InlinedVector<TensorValue, 4> input_vector(1);
-  TF_RETURN_IF_ERROR(src_->TensorValue(
-      &input_vector[0],
-      ctx_->CanonicalDevice(absl::get<Device*>(op->Device()))));
-
-  EagerKernelArgs args(std::move(input_vector));
+  EagerKernelArgs args(1);
+  Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
+  TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
+
   return kernel->Run(args, /*outputs=*/nullptr,
                      /*cancellation_manager=*/nullptr,
                      /*remote_func_params=*/absl::nullopt);
View File
@@ -162,16 +162,8 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
       in.op_device().empty() ? in.device() : in.op_device();
   TF_RETURN_IF_ERROR(
       parent_->FindDeviceFromName(device_name.c_str(), &device));
-  string remote_task;
-  if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
-    return errors::InvalidArgument(
-        "Unable to find remote task corresponding to device ", device_name);
-  }
-  auto remote_handle_data = absl::make_unique<UnshapedRemoteTensorHandleData>(
-      in.op_id(), in.output_num(), remote_task, parent_);
-  remote_handle_data->ReleaseRemoteTensorHandle();
-  TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
-      std::move(remote_handle_data), in.dtype(), device, parent_, out));
+  TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
+      in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
   std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
   if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                 &dtypes_and_shapes)
View File
@@ -71,14 +71,12 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
   Tensor t(DT_FLOAT, TensorShape({0}));

   TensorHandle* handle;
-  TF_ASSERT_OK(
-      TensorHandle::CreateLocalHandle(t, local_device_, ctx_, &handle));
+  TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_,
+                                               local_device_, ctx_, &handle));
   const uint64 op_id = 2;
   const int output_num = 3;
-  auto tensor_handle_data = absl::make_unique<RemoteTensorHandleData>(
-      op_id, output_num, t.shape(), /*remote_task=*/"", ctx_);
-  TF_ASSERT_OK(
-      handle->AddRemoteMirror(std::move(tensor_handle_data), remote_device_));
+  TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
+                                               output_num, "", ctx_));
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, &remote_handle, remote_device_, remote_device_->name()));
@@ -90,14 +88,13 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
 TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
   RemoteMgr remote_mgr(false, ctx_);

-  Tensor t(DT_FLOAT, TensorShape({0}));
-
   const uint64 op_id = 3;
   const int output_num = 1;
   TensorHandle* handle;
-  TF_ASSERT_OK(TensorHandle::CreateRemoteHandle(
-      op_id, output_num, t.shape(), /*remote_task=*/"", DT_FLOAT,
-      remote_device_, /*resource_device=*/nullptr, ctx_, &handle));
+  TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle(
+      op_id, output_num,
+      /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle));
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, &remote_handle, remote_device_, remote_device_->name()));
View File
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/profiler/lib/traceme.h"

 namespace tensorflow {
@@ -84,66 +85,103 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
 }  // namespace

 RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
-                                               const TensorShape& shape,
-                                               const string& remote_task,
-                                               EagerContext* ctx)
-    : op_id_(op_id),
+                                               uint64 context_view_id)
+    : is_ready_(false),
+      op_id_(op_id),
       output_num_(output_num),
-      shape_(shape),
-      remote_task_(remote_task),
-      context_id_(ctx->GetContextId()),
-      context_view_id_(ctx->GetContextViewId()),
-      ctx_(*ctx) {
+      context_view_id_(context_view_id),
+      ctx_(nullptr) {
   DCHECK(op_id_ >= 0 && output_num_ >= 0)
       << "Op ID and output num should be >= 0. Op ID: " << op_id
       << ", Output num: " << output_num;
-  ctx_.Ref();
+}
+
+RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
+                                               const string& remote_task,
+                                               EagerContext* ctx)
+    : is_ready_(false),
+      op_id_(op_id),
+      output_num_(output_num),
+      remote_task_(remote_task),
+      context_id_(ctx->GetContextId()),
+      context_view_id_(ctx->GetContextViewId()),
+      ctx_(ctx) {
+  DCHECK(op_id_ >= 0 && output_num_ >= 0)
+      << "Op ID and output num should be >= 0. Op ID: " << op_id
+      << ", Output num: " << output_num;
+  ctx_->Ref();
 }

 RemoteTensorHandleData::~RemoteTensorHandleData() {
-  DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
-                            output_num_, /*ready=*/true);
-  ctx_.Unref();
-}
-
-Status RemoteTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote device. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
-}
-
-Status RemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote device. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
+  if (ctx_) {
+    DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
+                              output_num_, /*ready=*/true);
+    ctx_->Unref();
+  }
 }

 Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
+  TF_RETURN_IF_ERROR(WaitReady("Shape"));
+
+  tf_shared_lock l(mu_);
   *shape = shape_;

   return Status::OK();
 }

 Status RemoteTensorHandleData::NumDims(int* num_dims) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumDims"));
+
+  tf_shared_lock l(mu_);
   *num_dims = shape_.dims();

   return Status::OK();
 }

 Status RemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
+  TF_RETURN_IF_ERROR(WaitReady("Dim"));
+
+  tf_shared_lock l(mu_);
   *dim = shape_.dim_size(dim_index);

   return Status::OK();
 }

 Status RemoteTensorHandleData::NumElements(int64* num_elements) const {
+  TF_RETURN_IF_ERROR(WaitReady("NumElements"));
+
+  tf_shared_lock l(mu_);
   *num_elements = shape_.num_elements();

   return Status::OK();
 }

-Status RemoteTensorHandleData::Unprotect() {
-  return errors::Unavailable("Unable to unprotect a remote handle.");
+bool RemoteTensorHandleData::IsReady() const {
+  tf_shared_lock l(mu_);
+  return is_ready_;
+}
+
+void RemoteTensorHandleData::Poison(Status status) {
+  mutex_lock l(mu_);
+  is_poisoned_ = status;
+}
+
+Status RemoteTensorHandleData::IsPoisoned() const {
+  tf_shared_lock l(mu_);
+  return is_poisoned_;
+}
+
+Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
+  mutex_lock l(mu_);
+  if (is_ready_) {
+    return errors::Internal("SetShape is only called on non-ready handles.");
+  }
+
+  shape_ = shape;
+  is_poisoned_ = Status::OK();
+  is_ready_ = true;
+
+  return Status::OK();
 }

 string RemoteTensorHandleData::DebugString() const {
@@ -151,73 +189,20 @@ string RemoteTensorHandleData::DebugString() const {
                          " output_num: ", output_num_);
 }

-UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
-    int64 op_id, int32 output_num, const string& remote_task, EagerContext* ctx)
-    : op_id_(op_id),
-      output_num_(output_num),
-      delete_remote_tensor_(true),
-      remote_task_(remote_task),
-      context_id_(ctx->GetContextId()),
-      context_view_id_(ctx->GetContextViewId()),
-      ctx_(*ctx) {
-  DCHECK(op_id_ >= 0 && output_num_ >= 0)
-      << "Op ID and output num should be >= 0. Op ID: " << op_id
-      << ", Output num: " << output_num;
-  ctx_.Ref();
-}
-
-UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
-  if (delete_remote_tensor_) {
-    DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
-                              output_num_, /*ready=*/false);
-  }
-  ctx_.Unref();
-}
-
-Status UnshapedRemoteTensorHandleData::Tensor(
-    const tensorflow::Tensor** t) const {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote handle. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
-}
-
-Status UnshapedRemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
-  return errors::Unavailable(
-      "Unable to get a tensor for a remote handle. Please copy the tensor "
-      "handle to a local device using TFE_TensorHandleCopyToDevice");
-}
-
-Status UnshapedRemoteTensorHandleData::Shape(TensorShape* shape) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::NumDims(int* num_dims) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::NumElements(int64* num_elements) const {
-  return errors::Unavailable(
-      "Unable to get shape information for an async remote handle. Please wait "
-      "until it is ready");
-}
-
-Status UnshapedRemoteTensorHandleData::Unprotect() {
-  return errors::Unavailable("Unable to unprotect a remote handle.");
-}
-
-string UnshapedRemoteTensorHandleData::DebugString() const {
-  return strings::StrCat("UnshapedRemoteTensorHandleDat:", " op_id: ", op_id_,
-                         " output_num: ", output_num_);
-}
+Status RemoteTensorHandleData::WaitReady(const char* caller) const {
+  if (ctx_ == nullptr) {
+    return errors::Internal("Cannot wait on lazy remote handle");
+  }
+
+  tf_shared_lock l(mu_);
+  if (!is_ready_) {
+    profiler::TraceMe activity(
+        [caller] { return absl::StrCat(caller, " WaitReady"); },
+        profiler::TraceMeLevel::kInfo);
+    DVLOG(3) << "WaitReady: " << caller << " " << this;
+    mu_.Await(Condition(&is_ready_));
+  }

+  return is_poisoned_;
+}

 } // namespace tensorflow
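Editor's note: the destructor above encodes the new lifetime rule for lazy remotes: ctx_ == nullptr means no deletion RPC and no Unref() on destruction, replacing the old ReleaseRemoteTensorHandle() dance. A standalone sketch of that guard with stubbed-out names (nothing here is the real TensorFlow API):

#include <iostream>

struct EagerContextStub {
  void Ref() { ++refs; }
  void Unref() { --refs; }
  int refs = 0;
};

void DestroyRemoteTensorHandleStub(EagerContextStub*, long op_id) {
  std::cout << "issuing deletion RPC for op " << op_id << "\n";
}

class RemoteData {
 public:
  // Lazy handle: no context, so destruction is a no-op.
  explicit RemoteData(long op_id) : op_id_(op_id), ctx_(nullptr) {}
  // Owning handle: refs the context and deletes the remote tensor later.
  RemoteData(long op_id, EagerContextStub* ctx) : op_id_(op_id), ctx_(ctx) {
    ctx_->Ref();
  }
  ~RemoteData() {
    if (ctx_) {  // nullptr marks "no deletion upon destruction"
      DestroyRemoteTensorHandleStub(ctx_, op_id_);
      ctx_->Unref();
    }
  }

 private:
  long op_id_;
  EagerContextStub* ctx_;
};

int main() {
  EagerContextStub ctx;
  RemoteData lazy(1);         // prints nothing when destroyed
  RemoteData owned(2, &ctx);  // prints the deletion message when destroyed
}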
View File
@@ -15,97 +15,56 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_

-#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
-#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"

 namespace tensorflow {

 // Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only
 // the shape is known.
-class RemoteTensorHandleData : public TensorHandleData {
+class RemoteTensorHandleData {
  public:
-  RemoteTensorHandleData(int64 op_id, int output_num, const TensorShape& shape,
-                         const string& remote_task, EagerContext* ctx);
-  ~RemoteTensorHandleData() override;
+  // Constructor for lazy remote handles
+  RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id);
+  // Constructor for unshaped remote handles
+  RemoteTensorHandleData(int64 op_id, int output_num, const string& remote_task,
+                         EagerContext* ctx);
+  ~RemoteTensorHandleData();

   // A remote tensor handle does not have a Tensor object, hence it can only
   // support the shape requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
-
-  EagerContext& ctx() const { return ctx_; }
-
-  string DebugString() const override;
+  Status Shape(TensorShape* shape) const;
+  Status NumDims(int* num_dims) const;
+  Status Dim(int dim_index, int64* dim) const;
+  Status NumElements(int64* num_elements) const;
+
+  bool IsReady() const;
+  Status SetShape(const TensorShape& shape);
+  void Poison(Status status);
+  Status IsPoisoned() const;
+
+  string DebugString() const;

   int64 op_id() const { return op_id_; }
   int32 output_num() const { return output_num_; }
-  uint64 context_id() const { return context_id_; }
   uint64 context_view_id() const { return context_view_id_; }

  private:
+  Status WaitReady(const char* caller) const;
+
+  mutable mutex mu_;
+  bool is_ready_ GUARDED_BY(mu_);
+  Status is_poisoned_ GUARDED_BY(mu_);
+  TensorShape shape_ GUARDED_BY(mu_);
+
   // IDs required when this class is representing a remote tensor handle.
   const int64 op_id_;
   const int32 output_num_;
-  const TensorShape shape_;
   string remote_task_;
   uint64 context_id_;
   uint64 context_view_id_;
-  EagerContext& ctx_;
-};
-
-// Async Remote Tensor Handle: A handle to a Tensor on a remote host. Once the
-// shape has been computed this is replaced with a remote tensor handle.
-class UnshapedRemoteTensorHandleData : public TensorHandleData {
- public:
-  UnshapedRemoteTensorHandleData(int64 op_id, int32 output_num,
-                                 const string& remote_task, EagerContext* ctx);
-  ~UnshapedRemoteTensorHandleData() override;
-
-  // Unshaped remote tensor handles are not ready and hence cannot satisfy any
-  // of these requests.
-  Status Tensor(const tensorflow::Tensor** t) const override;
-  Status TensorValue(tensorflow::TensorValue* t) override;
-  Status Shape(TensorShape* shape) const override;
-  Status NumDims(int* num_dims) const override;
-  Status Dim(int dim_index, int64* dim) const override;
-  Status NumElements(int64* num_elements) const override;
-  Status Unprotect() override;
-
-  void Poison(Status status) { is_poisoned_ = status; }
-  Status IsPoisoned() const { return is_poisoned_; }
-
-  string DebugString() const override;
-
-  int64 op_id() const { return op_id_; }
-  int32 output_num() const { return output_num_; }
-  string remote_task() const { return remote_task_; }
-  uint64 context_id() const { return context_id_; }
-  uint64 context_view_id() const { return context_view_id_; }
-  EagerContext& ctx() const { return ctx_; }
-
-  // When constructed, UnshapedRemoteTensorHandleData owns the remote
-  // TensorHandle and should delete it by issuing an RPC. Once the remote
-  // shape has been learned, the ownership is transferred to
-  // RemoteTensorHandleData. This method must be called to let `this` know
-  // that it no longer owns the remote handle.
-  // TODO(iga): Add a factory method here that will create a new
-  // RemoteTensorHandleData from this and transfer ownership in the process.
-  void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; }
-
- private:
-  Status is_poisoned_;
-  // IDs required when this class is representing a remote tensor handle.
-  const int64 op_id_;
-  const int32 output_num_;
-  bool delete_remote_tensor_;
-  string remote_task_;
-  uint64 context_id_;
-  uint64 context_view_id_;
-  EagerContext& ctx_;
+  EagerContext* ctx_;
 };

 } // namespace tensorflow
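Editor's note: the header above splits the remote handle lifecycle into construction (lazy or unshaped) and a single SetShape() call that makes the handle ready, after which the shape queries succeed. A tiny sketch of that set-once shape slot, with invented names, showing the intended calling discipline:

#include <cassert>
#include <optional>
#include <vector>

class RemoteShapeSlot {
 public:
  bool IsReady() const { return shape_.has_value(); }

  // Called exactly once, by the code that learns the remote shape.
  void SetShape(std::vector<long> dims) {
    assert(!IsReady() && "SetShape is only called on non-ready handles");
    shape_ = std::move(dims);
  }

  // Shape queries are only valid after the slot became ready; the real
  // implementation blocks in WaitReady() instead of asserting.
  int NumDims() const {
    assert(IsReady());
    return static_cast<int>(shape_->size());
  }

 private:
  std::optional<std::vector<long>> shape_;
};

int main() {
  RemoteShapeSlot slot;
  slot.SetShape({2, 3});
  assert(slot.NumDims() == 2);
}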
View File
@@ -91,11 +91,11 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
   Device* device = IsCPUDevice(call->device) ? nullptr : call->device;
   for (int64 i = 0; i < n; ++i) {
     PyObject* arg = nullptr;
-    const Tensor& t = call->ins[i];
     if (call->eager) {
       TensorHandle* handle;
+      Tensor t = call->ins[i];
       TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
-          t, ctx->CanonicalDevice(device), nullptr, ctx, &handle));
+          std::move(t), ctx->CanonicalDevice(device), nullptr, ctx, &handle));
       arg = EagerTensorFromHandle(new TFE_TensorHandle{
           std::make_unique<tensorflow::TensorHandleInterface>(handle)});
       if (arg == nullptr) {
@@ -103,7 +103,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
         return errors::Internal("Unable to procure EagerTensor from Tensor.");
       }
     } else {
-      Status s = TensorToNdarray(t, &arg);
+      Status s = TensorToNdarray(call->ins[i], &arg);
       if (!s.ok()) {
         Py_DECREF(lst);
         return s;
View File
@@ -277,7 +277,8 @@ struct Converter {
   static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state,
                         TFE_TensorHandle** h, const char** error) {
-    /* TODO(josh11b): Allocator & attributes? */
+    // TODO(josh11b): Allocator & attributes
+    // TODO(gjn): Use optimized scalar constructors when possible.
     Tensor result(ConverterTraits<T>::kTypeEnum,
                   TensorShape(state->inferred_shape));
     if (state->inferred_shape.empty()) { /* Scalar case */
@@ -294,7 +295,7 @@
     }
     tensorflow::TensorHandle* handle = nullptr;
     auto status = tensorflow::TensorHandle::CreateLocalHandle(
-        result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+        std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
         ctx->context, &handle);
     if (!status.ok()) {
       return status;
@@ -610,8 +611,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
   auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
   if (cppstatus.ok()) {
     cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
-        t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context,
-        &handle);
+        std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+        ctx->context, &handle);
   }
   if (!cppstatus.ok()) {
     PyErr_SetString(PyExc_ValueError,
@@ -805,10 +806,10 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
     case DT_INVALID:  // Only occurs for empty tensors.
     {
       tensorflow::TensorHandle* h = nullptr;
-      Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
-                    TensorShape(state.inferred_shape));
+      Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
+               TensorShape(state.inferred_shape));
       status = tensorflow::TensorHandle::CreateLocalHandle(
-          tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+          std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
           ctx->context, &h);
       if (!status.ok()) {
         PyErr_SetString(PyExc_ValueError, status.error_message().c_str());