Consolidate tensor handle data types

The separation between ready and non-ready TensorHandleData classes
caused a lot of mirroring logic to get quite messy, especially when
considering the waiting logic, which was duplicated in the main
TensorHandle class and the various TensorHandleData classes. In
addition, having these classes expose the same API was a bit cumbersome
since most APIs were not supported by the various types.

We instead keep simply two different types, a local and remote one.
Further, we move all the waiting logic out of the TensorHandle and have
it in the TensorHandleData. Since we no longer need to swap out the
tensor_handle_data_ pointer when moving to a ready state, we can replace
it with a variant and save ourselves a heap allocation.

The LocalTensorHandleData is optimized such that if a tensor is provided
it does not require mutex synchronization, for any of the member
operations.

The RemoteTensorHandleData no longer needs to support the delicate dance
of calling ReleaseRemoteTensorHandle(). However, for lazy remotes we set
the EagerContext to nullptr to indicate no deletion upon class
destruction is needed.

Along the way we also do the following performance optimizations:
* Change tensor handle construct to use move semantics. This avoids
  unnecessary Ref counts on the TensorBuffer.
* In sync, do not allocate empty TensorHandle and later call SetTensor.
  Instead, we allocate the return TensorHandles once the kernel has
  executed. This avoid the overhead synchronization when there is no
  need for it.
* Switch mirror maps to store tensor data direct vs using a unique_ptr.
  Also switch to unordered_map.

Additional clean-ups:
* Make TensorHandle APIs consistently take Device pointer as first
  argument.
* Remove CreateLocalTensorHandle function which could be confused with
  the one used for CustomDevice.
* Remove many unused includes.

PiperOrigin-RevId: 298423283
Change-Id: I838736396e9ef81b2de665d6d9a3ad2062070b0c
This commit is contained in:
Gaurav Jain 2020-03-02 12:58:06 -08:00 committed by TensorFlower Gardener
parent 21d0fa9af9
commit f97b7ba2bf
20 changed files with 518 additions and 779 deletions

View File

@ -1213,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorHandle* ret_handle;
if (custom_device == nullptr) {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
std::move(t), device, device, context, &ret_handle);
} else {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, custom_device, context, &ret_handle);
std::move(t), custom_device, context, &ret_handle);
}
if (!status->status.ok()) {
return nullptr;

View File

@ -140,6 +140,7 @@ tf_cuda_library(
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"@com_google_absl//absl/types:variant",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme",

View File

@ -564,8 +564,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
}
}
}
const DataTypeVector& output_dtypes = kernel->output_dtypes();
const size_t num_outputs = static_cast<int>(output_dtypes.size());
int num_outputs = kernel->num_outputs();
if (num_outputs > *num_retvals) {
return errors::InvalidArgument("Expecting ", num_outputs,
" outputs, but *num_retvals is ",
@ -579,21 +578,19 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
graph_collector = ctx.GetGraphCollector();
}
const bool async = executor.Async();
for (int i = 0; i < num_outputs; ++i) {
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
async,
/* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
/* op_device= */ kernel->device(),
/* resource_device= */ kernel->OutputResourceDevice(i),
output_dtypes[i], &ctx, &retvals[i]));
}
Status s;
if (async) {
if (executor.Async()) {
const DataTypeVector& output_dtypes = kernel->output_dtypes();
for (int i = 0; i < num_outputs; ++i) {
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
/* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
/* op_device= */ kernel->device(),
/* resource_device= */ kernel->OutputResourceDevice(i),
output_dtypes[i], &ctx, &retvals[i]));
}
auto node = absl::make_unique<AsyncExecuteNode>(
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
graph_collector, output_dtypes, op->GetCancellationManager(),
graph_collector, op->GetCancellationManager(),
absl::Span<TensorHandle*>(retvals, num_outputs));
// For async mode, execution order will make sure that all
// input handles are ready before executing them.
@ -601,16 +598,21 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
// performance.
s = executor.AddOrExecute(std::move(node));
} else {
for (int i = 0; i < num_outputs; ++i) {
retvals[i] = nullptr;
}
ExecuteNode node(&ctx, op->Inputs(), op->remote_func_params(), kernel,
graph_collector, output_dtypes,
op->GetCancellationManager(), {retvals, num_outputs});
graph_collector, op->GetCancellationManager(),
{retvals, static_cast<size_t>(num_outputs)});
s = executor.SyncExecute(&node);
}
// Since the operation failed, we need to Unref any outputs that were
// Since the operation failed, we need to Unref any outputs if they were
// allocated.
if (!s.ok()) {
for (int i = 0; i < num_outputs; ++i) {
retvals[i]->Unref();
if (retvals[i] != nullptr) {
retvals[i]->Unref();
}
}
}
@ -733,12 +735,9 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
input, input_handle, input_device, *input_device_name,
serialize_resource_dtype_and_shape));
if (!input_handle->resource_dtypes_and_shapes().empty()) {
auto tensor_handle_data =
absl::make_unique<UnshapedRemoteTensorHandleData>(
input_handle->op_id(), input_handle->output_num(), remote_task,
&ctx);
TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
std::move(tensor_handle_data), op_device));
TF_RETURN_IF_ERROR(
input->AddResourceShapeMirror(op_device, input_handle->op_id(),
input_handle->output_num(), &ctx));
}
}
}
@ -1032,13 +1031,24 @@ Status EagerKernelExecute(
}
}
DCHECK_EQ(retvals.size(), outputs.size());
for (int i = 0; i < retvals.size(); ++i) {
DCHECK_EQ(kernel->device(), retvals[i]->op_device());
DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
absl::get<Device*>(retvals[i]->device()));
TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
std::move(outputs[i]), ctx->CanonicalDevice(kernel->OutputDevice(i))));
for (int i = 0; i < retvals.size(); ++i) {
if (retvals[i] == nullptr) {
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
std::move(outputs[i]),
/* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
/* op_device= */ kernel->device(),
/* resource_device= */ kernel->OutputResourceDevice(i), ctx,
&retvals[i]));
} else {
DCHECK_EQ(kernel->device(), retvals[i]->op_device());
DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
absl::get<Device*>(retvals[i]->device()));
TF_RETURN_IF_ERROR(
retvals[i]->SetTensor(std::move(outputs[i]),
ctx->CanonicalDevice(kernel->OutputDevice(i))));
}
}
return Status::OK();
}
@ -1069,7 +1079,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
*result = h;
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
true, d, dstd, h->resource_device(), h->dtype, ctx, result));
d, dstd, h->resource_device(), h->dtype, ctx, result));
}
Status s;
@ -1138,7 +1148,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
*result = h;
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
true, /* d= */ d, /* op_device= */ device,
/* d= */ d, /* op_device= */ device,
/*resource_device=*/nullptr, h->dtype, ctx, result));
}
} else {
@ -1156,17 +1166,14 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
device->name());
}
recv_op_id = ctx->RemoteMgr()->NextOpId();
auto tensor_handle_data =
absl::make_unique<UnshapedRemoteTensorHandleData>(recv_op_id, 0,
remote_task, ctx);
if (mirror) {
TF_RETURN_IF_ERROR(
h->AddUnshapedRemoteMirror(std::move(tensor_handle_data), device));
TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
remote_task, ctx));
h->Ref();
*result = h;
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
std::move(tensor_handle_data), h->dtype, device, ctx, result));
recv_op_id, 0, remote_task, h->dtype, device, ctx, result));
}
}

View File

@ -32,7 +32,7 @@ Status ExecuteNodeArgs::Init(
for (int i = 0; i < n_inputs; ++i) {
TensorHandle* in = op_inputs_flat[i];
Device* d = kernel->InputDevice(i);
Status s = in->TensorValue(&tensor_args_flat[i], ctx->CanonicalDevice(d));
Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
if (!s.ok()) {
#if !defined(IS_MOBILE_PLATFORM)
uint64 context_view_id = ctx->GetContextViewId();

View File

@ -77,7 +77,7 @@ class ExecuteNode : public EagerNode {
EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
const core::RefCountPtr<KernelAndDevice>& kernel,
GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
GraphCollector* graph_collector,
CancellationManager* cancellation_manager,
absl::Span<TensorHandle*> retvals)
: EagerNode(),
@ -130,7 +130,7 @@ class AsyncExecuteNode : public EagerNode {
EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
core::RefCountPtr<KernelAndDevice> kernel,
GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
GraphCollector* graph_collector,
CancellationManager* cancellation_manager,
absl::Span<TensorHandle*> retvals)
: EagerNode(),

View File

@ -67,6 +67,7 @@ class EagerKernelArgs : public FunctionArgsInterface {
~EagerKernelArgs() override{};
bool HasRemoteInputs() const override { return false; };
TensorValue* MutableInput(int i) { return &tensor_args_[i]; }
Status GetLocalArg(const int index, Tensor* val) const override;

View File

@ -20,39 +20,31 @@ limitations under the License.
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/substitute.h"
#include "absl/types/variant.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@ -64,7 +56,7 @@ const int32 kInvalidOutputNum = -1;
} // namespace
void TensorHandle::SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes) {
handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
}
@ -86,250 +78,191 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
profiler::TraceMe activity(
"TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
profiler::TraceMeLevel::kInfo);
auto& data = absl::get<LocalTensorHandleData>(data_);
TF_RETURN_IF_ERROR(
WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
data.WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
*result = handle_dtypes_and_shapes_;
return Status::OK();
}
Status TensorHandle::CreateLocalHandle(const class Tensor& t,
Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
TensorHandle** h) {
// TODO(b/136608821): Move away from nullptr
return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr),
tensorflow::Tensor tensor = t;
return CreateLocalHandle(std::move(tensor),
/*d=*/nullptr,
/*op_device=*/nullptr,
/*ctx=*/nullptr, h);
}
Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
EagerContext* ctx, TensorHandle** h) {
return CreateLocalHandle(t, d, d, ctx, h);
}
Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device, EagerContext* ctx,
TensorHandle** h) {
if (t.dtype() != DT_RESOURCE) {
*h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
t.dtype(), d, op_device, ctx);
return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h);
}
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device,
Device* resource_device,
EagerContext* ctx, TensorHandle** h) {
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
*h = new TensorHandle(std::move(t), d, op_device, ctx);
} else {
const ResourceHandle& resource_handle = t.flat<class ResourceHandle>()(0);
*h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
resource_handle, d, op_device, ctx);
*h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
}
return Status::OK();
}
Status TensorHandle::CreateLocalHandle(const class Tensor& t, CustomDevice* d,
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
EagerContext* ctx, TensorHandle** h) {
*h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t), t.dtype(),
d, ctx);
*h = new TensorHandle(std::move(t), d, ctx);
return Status::OK();
}
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
DataType dtype, Device* d, Device* op_device,
EagerContext* ctx)
: dtype(dtype),
TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
Device* resource_device, EagerContext* ctx)
: dtype(t.dtype()),
device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
op_device_(op_device),
resource_device_(nullptr),
#if !defined(IS_MOBILE_PLATFORM)
remote_op_id_(kInvalidOpId),
remote_output_num_(kInvalidOutputNum),
#endif
resource_device_(resource_device),
ctx_(ctx),
is_remote_(false),
is_async_(false),
implicit_mirroring_(true),
is_ready_(true),
tensor_handle_data_(std::move(t)) {
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
<< " device: " << VariantDeviceDebugString(device_)
<< " tensor: " << t.DeviceSafeDebugString();
}
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
const ResourceHandle& resource_handle, Device* d,
Device* op_device, EagerContext* ctx)
TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
EagerContext* ctx)
: dtype(DT_RESOURCE),
device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
op_device_(op_device),
resource_device_(GetResourceDevice(resource_handle, ctx)),
#if !defined(IS_MOBILE_PLATFORM)
remote_op_id_(kInvalidOpId),
remote_output_num_(kInvalidOutputNum),
#endif
resource_device_(
GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
ctx_(ctx),
is_remote_(false),
is_async_(false),
implicit_mirroring_(true),
is_ready_(true),
handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
tensor_handle_data_(std::move(t)) {
handle_dtypes_and_shapes_(
t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
<< " device: " << VariantDeviceDebugString(device_)
<< " tensor: " << t.DeviceSafeDebugString();
}
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
DataType dtype, CustomDevice* d, EagerContext* ctx)
: dtype(dtype),
TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
EagerContext* ctx)
: dtype(t.dtype()),
device_(d),
op_device_(nullptr),
resource_device_(nullptr),
#if !defined(IS_MOBILE_PLATFORM)
remote_op_id_(kInvalidOpId),
remote_output_num_(kInvalidOutputNum),
#endif
ctx_(ctx),
is_remote_(false),
is_async_(false),
implicit_mirroring_(true),
is_ready_(true),
tensor_handle_data_(std::move(t)) {
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
// TODO(allenl): Figure out a better op_device story for custom devices,
// since always setting it to CPU=nullptr doesn't make much sense.
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " custom device: " << VariantDeviceDebugString(device_);
<< " custom device: " << VariantDeviceDebugString(device_)
<< " tensor: " << t.DeviceSafeDebugString();
}
Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
Device* op_device,
Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device,
DataType dtype, EagerContext* ctx,
TensorHandle** h) {
*h = new TensorHandle(absl::make_unique<EmptyLocalTensorHandleData>(), async,
d, op_device, resource_device, dtype, ctx);
*h = new TensorHandle(d, op_device, resource_device, dtype, ctx);
return Status::OK();
}
TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
bool async, Device* d, Device* op_device,
TensorHandle::TensorHandle(Device* d, Device* op_device,
Device* resource_device, DataType dtype,
EagerContext* ctx)
: dtype(dtype),
device_((d == ctx->HostCPU()) ? nullptr : d),
op_device_(op_device),
resource_device_(resource_device),
#if !defined(IS_MOBILE_PLATFORM)
remote_op_id_(kInvalidOpId),
remote_output_num_(kInvalidOutputNum),
#endif
ctx_(ctx),
is_remote_(false),
is_async_(async),
implicit_mirroring_(true),
is_ready_(!async),
tensor_handle_data_(std::move(t)) {
data_(absl::in_place_type<LocalTensorHandleData>) {
DVLOG(3) << "Creating empty Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
}
#if !defined(IS_MOBILE_PLATFORM)
Status TensorHandle::CreateRemoteHandle(
std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d,
Device* resource_device, EagerContext* ctx, TensorHandle** h) {
*h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx);
Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
const string& remote_task,
DataType dtype, Device* d,
EagerContext* ctx,
TensorHandle** h) {
*h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
return Status::OK();
}
Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
const TensorShape& shape,
const string& remote_task,
DataType dtype, Device* d,
Device* resource_device,
EagerContext* ctx, TensorHandle** h) {
*h = new TensorHandle(absl::make_unique<RemoteTensorHandleData>(
op_id, output_num, shape, remote_task, ctx),
dtype, d, resource_device, ctx);
return Status::OK();
}
TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
DataType dtype, Device* d, Device* resource_device,
TensorHandle::TensorHandle(int64 op_id, int32 output_num,
const string& remote_task, DataType dtype, Device* d,
EagerContext* ctx)
: dtype(dtype),
device_(d),
op_device_(d),
resource_device_(resource_device),
remote_op_id_(t->op_id()),
remote_output_num_(t->output_num()),
resource_device_(dtype == DT_RESOURCE ? d : nullptr),
ctx_(ctx),
is_remote_(true),
is_async_(false),
implicit_mirroring_(true),
is_ready_(true),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Remote TensorHandle: " << this
data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
remote_task, ctx) {
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
}
Status TensorHandle::CreateUnshapedRemoteHandle(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
Device* d, EagerContext* ctx, TensorHandle** h) {
*h = new TensorHandle(std::move(t), dtype, d, ctx);
Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num,
DataType dtype, Device* d,
EagerContext* ctx,
TensorHandle** h) {
*h = new TensorHandle(op_id, output_num, dtype, d, ctx);
return Status::OK();
}
Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
const string& remote_task,
DataType dtype, Device* device,
EagerContext* ctx,
TensorHandle** h) {
*h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>(
op_id, output_num, remote_task, ctx),
dtype, device, ctx);
return Status::OK();
}
TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
DataType dtype, Device* device, EagerContext* ctx)
TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype,
Device* d, EagerContext* ctx)
: dtype(dtype),
device_(device),
op_device_(device),
resource_device_(dtype == DT_RESOURCE ? device : nullptr),
remote_op_id_(t->op_id()),
remote_output_num_(t->output_num()),
device_(d),
op_device_(d),
resource_device_(dtype == DT_RESOURCE ? d : nullptr),
ctx_(ctx),
is_remote_(true),
is_async_(true),
implicit_mirroring_(true),
is_ready_(false),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
ctx->GetContextViewId()) {
DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
}
#endif
bool TensorHandle::IsReady() const {
// Avoid mutex acquisition for local sync handles
if (!is_async_ && !is_remote_) {
return true;
}
tf_shared_lock l(mu_);
return is_ready_;
return absl::visit([](auto& data) { return data.IsReady(); }, data_);
}
Status TensorHandle::WaitReady(const char* caller) const {
if (!IsReady()) {
profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
profiler::TraceMeLevel::kInfo);
tf_shared_lock l(mu_);
mu_.Await(Condition(&is_ready_));
}
return is_poisoned_;
bool TensorHandle::IsRemote() const {
#if !defined(IS_MOBILE_PLATFORM)
return data_.index() == 1;
#else
return false;
#endif
}
Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
DVLOG(3) << "Tensor on TensorHandle: " << this;
TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor"));
return tensor_handle_data_->Tensor(t);
if (IsRemote()) {
return errors::Internal("Invalid Tensor call on remote handle: ", this);
}
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.Tensor(t);
}
Status TensorHandle::TensorFromDevice(const Device* d,
@ -337,12 +270,12 @@ Status TensorHandle::TensorFromDevice(const Device* d,
DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
if (is_remote_) {
if (IsRemote()) {
return errors::Internal("Invalid Tensor call on remote handle: ", this);
}
TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice"));
return tensor_handle_data_->Tensor(t);
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.Tensor(t);
}
tf_shared_lock l(mu_);
@ -352,25 +285,21 @@ Status TensorHandle::TensorFromDevice(const Device* d,
" in Tensor call to handle: ", this);
}
// Check if the handle is non-empty, else wait.
auto& mirror = elem->second;
if (mirror.second == nullptr) {
TF_RETURN_IF_ERROR(
mirror.first->WaitReady("TensorHandle::TensorFromDevice"));
}
return mirror.second->Tensor(t);
return mirror.Tensor(t);
}
Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
if (is_remote_) {
if (IsRemote()) {
return errors::Internal("Invalid TensorValue call on remote handle: ",
this);
}
TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
return tensor_handle_data_->TensorValue(t);
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.TensorValue(t);
}
tf_shared_lock l(mu_);
@ -380,13 +309,8 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
" in TensorValue call to handle: ", this);
}
// Check if the handle is non-empty, else wait.
auto& mirror = elem->second;
if (mirror.second == nullptr) {
TF_RETURN_IF_ERROR(mirror.first->WaitReady("TensorHandle::TensorValue"));
}
return mirror.second->TensorValue(t);
return mirror.TensorValue(t);
}
TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
@ -405,8 +329,8 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
DCHECK(fill);
return Status::OK();
} else {
TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Shape"));
return tensor_handle_data_->Shape(shape);
return absl::visit([shape](auto& data) { return data.Shape(shape); },
data_);
}
}
@ -480,8 +404,8 @@ Status TensorHandle::NumDims(int* num_dims) const {
*num_dims = inference_shape_.dims();
return Status::OK();
} else {
TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumDims"));
return tensor_handle_data_->NumDims(num_dims);
return absl::visit(
[num_dims](auto& data) { return data.NumDims(num_dims); }, data_);
}
}
@ -492,8 +416,9 @@ Status TensorHandle::Dim(int dim_index, int64* dim) const {
*dim = inference_shape_.dim_size(dim_index);
return Status::OK();
} else {
TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Dim"));
return tensor_handle_data_->Dim(dim_index, dim);
return absl::visit(
[dim_index, dim](auto& data) { return data.Dim(dim_index, dim); },
data_);
}
}
@ -503,8 +428,9 @@ Status TensorHandle::NumElements(int64* num_elements) const {
*num_elements = inference_shape_.num_elements();
return Status::OK();
} else {
TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumElements"));
return tensor_handle_data_->NumElements(num_elements);
return absl::visit(
[num_elements](auto& data) { return data.NumElements(num_elements); },
data_);
}
}
@ -512,7 +438,8 @@ Status TensorHandle::Unprotect(const Device* d) {
DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
return tensor_handle_data_->Unprotect();
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.Unprotect();
}
tf_shared_lock l(mu_);
@ -524,11 +451,7 @@ Status TensorHandle::Unprotect(const Device* d) {
// Check if the handle is non-empty
auto& mirror = elem->second;
if (mirror.second == nullptr) {
return errors::Internal("Attempted to unprotect an empty mirror");
}
return mirror.second->Unprotect();
return mirror.Unprotect();
}
bool TensorHandle::HasLocalMirror(const Device* d) const {
@ -551,8 +474,8 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
return errors::Internal("Attempted to duplicate a local mirror.");
}
local_mirrors_[d] =
std::make_pair(std::make_unique<EmptyLocalTensorHandleData>(), nullptr);
local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
std::forward_as_tuple());
return Status::OK();
}
@ -567,15 +490,8 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
tf_shared_lock l(mu_);
auto mirror = remote_mirrors_.find(d->name());
if (mirror != remote_mirrors_.end()) {
*op_id = mirror->second->op_id();
*output_num = mirror->second->output_num();
return Status::OK();
}
auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
*op_id = unshaped_mirror->second->op_id();
*output_num = unshaped_mirror->second->output_num();
*op_id = mirror->second.op_id();
*output_num = mirror->second.output_num();
return Status::OK();
}
@ -583,14 +499,14 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
"Could not find remote mirror for specified device");
}
if (remote_op_id_ == kInvalidOpId ||
remote_output_num_ == kInvalidOutputNum) {
return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_,
", output_num:", remote_output_num_,
") is not set.");
if (!IsRemote()) {
return errors::InvalidArgument("Primary device is not remote");
}
*op_id = remote_op_id_;
*output_num = remote_output_num_;
auto& data = absl::get<RemoteTensorHandleData>(data_);
*op_id = data.op_id();
*output_num = data.output_num();
return Status::OK();
}
@ -603,16 +519,7 @@ bool TensorHandle::HasRemoteMirror(const Device* d,
auto mirror = remote_mirrors_.find(d->name());
if (mirror != remote_mirrors_.end()) {
// Check if mirror is stale
if (mirror->second->context_view_id() != context_view_id) {
return false;
}
return true;
}
auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
// Check if mirror is stale
if (unshaped_mirror->second->context_view_id() != context_view_id) {
if (mirror->second.context_view_id() != context_view_id) {
return false;
}
return true;
@ -630,7 +537,7 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
auto mirror = resource_shape_mirrors_.find(d->name());
if (mirror != resource_shape_mirrors_.end()) {
// Check if mirror is stale
if (mirror->second->context_view_id() != context_view_id) {
if (mirror->second.context_view_id() != context_view_id) {
return false;
}
return true;
@ -638,45 +545,39 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
return false;
}
Status TensorHandle::AddUnshapedRemoteMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64 op_id,
int output_num,
const string& remote_task,
EagerContext* ctx) {
DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this
<< " device: " << d << " " << d->name();
<< " device: " << d << " " << d->name() << " op_id: " << op_id
<< " output_num: " << output_num;
mutex_lock l(mu_);
auto remote_mirror = remote_mirrors_.find(d->name());
if (remote_mirror != remote_mirrors_.end()) {
if (remote_mirror->second->context_view_id() == t->context_view_id()) {
if (remote_mirror->second.context_view_id() == ctx->GetContextId()) {
return errors::Internal("Attempted to duplicate a remote mirror.");
}
// Remove stale mirror
remote_mirrors_.erase(remote_mirror);
}
auto unshaped_remote_mirror = unshaped_remote_mirrors_.find(d->name());
if (unshaped_remote_mirror != unshaped_remote_mirrors_.end()) {
if (unshaped_remote_mirror->second->context_view_id() ==
t->context_view_id()) {
return errors::Internal(
"Attempted to duplicate an unshaped remote mirror.");
}
// Remove stale mirror
unshaped_remote_mirrors_.erase(unshaped_remote_mirror);
}
unshaped_remote_mirrors_[d->name()] = std::move(t);
remote_mirrors_.emplace(
std::piecewise_construct, std::forward_as_tuple(d->name()),
std::forward_as_tuple(op_id, output_num, remote_task, ctx));
return Status::OK();
}
Status TensorHandle::AddResourceShapeMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
int output_num, EagerContext* ctx) {
DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this;
mutex_lock l(mu_);
auto mirror = resource_shape_mirrors_.find(d->name());
if (mirror != resource_shape_mirrors_.end()) {
if (mirror->second->context_view_id() == t->context_view_id()) {
if (mirror->second.context_view_id() == ctx->GetContextViewId()) {
return errors::Internal(
"Attempted to duplicate a resource shape mirror.");
}
@ -684,26 +585,9 @@ Status TensorHandle::AddResourceShapeMirror(
resource_shape_mirrors_.erase(mirror);
}
resource_shape_mirrors_[d->name()] = std::move(t);
return Status::OK();
}
Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
const Device* d) {
DVLOG(3) << "AddRemoteMirror on TensorHandle: " << this << " device: " << d;
mutex_lock l(mu_);
auto mirror = remote_mirrors_.find(d->name());
if (mirror != remote_mirrors_.end()) {
if (mirror->second->context_view_id() == t->context_view_id()) {
return errors::Internal("Attempted to duplicate a remote mirror.");
}
// Remove stale mirror
remote_mirrors_.erase(mirror);
}
remote_mirrors_[d->name()] = std::move(t);
resource_shape_mirrors_.emplace(
std::piecewise_construct, std::forward_as_tuple(d->name()),
std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId()));
return Status::OK();
}
@ -717,53 +601,24 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
mutex_lock l(mu_);
auto remote_mirror = remote_mirrors_.find(d->name());
if (remote_mirror != remote_mirrors_.end()) {
if (remote_mirror->second->context_view_id() == context_view_id) {
return errors::Internal(
"Attempted to set remote shape for existing mirror.");
auto& mirror = remote_mirror->second;
if (mirror.context_view_id() == context_view_id) {
return mirror.SetShape(shape);
}
remote_mirrors_.erase(remote_mirror);
}
auto elem = unshaped_remote_mirrors_.find(d->name());
if (elem == unshaped_remote_mirrors_.end()) {
return errors::Internal(
"Attempted to set remote shape for non-waiting mirror.");
}
if (elem->second->context_view_id() != context_view_id) {
unshaped_remote_mirrors_.erase(elem);
return errors::Internal(
"Attempted to set remote shape for a stale mirror.");
}
auto& data = elem->second;
data->ReleaseRemoteTensorHandle();
remote_mirrors_[d->name()] = absl::make_unique<RemoteTensorHandleData>(
data->op_id(), data->output_num(), shape, data->remote_task(),
&data->ctx());
unshaped_remote_mirrors_.erase(elem);
return Status::OK();
}
DCHECK(is_remote_) << "SetRemoteShape is only called on remote handles.";
DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";
DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";
UnshapedRemoteTensorHandleData* p =
reinterpret_cast<UnshapedRemoteTensorHandleData*>(
tensor_handle_data_.get());
if (p->context_view_id() != context_view_id) {
auto& data = absl::get<RemoteTensorHandleData>(data_);
if (data.context_view_id() != context_view_id) {
return errors::Internal("Attempted to set remote shape for an old handle.");
}
p->ReleaseRemoteTensorHandle();
tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>(
remote_op_id_, remote_output_num_, shape, p->remote_task(), ctx_);
is_poisoned_ = Status::OK();
mutex_lock l(mu_);
is_ready_ = true;
return Status::OK();
return data.SetShape(shape);
}
void TensorHandle::PoisonRemote(Status status, const Device* d,
@ -772,18 +627,16 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
<< " " << d->name();
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
DCHECK(!is_async_ || !IsReady())
<< "PoisonRemote can only be called on non-ready handle: " << this;
DCHECK(IsRemote()) << "Poison can only be on remote handles: " << this;
is_poisoned_ = status;
mutex_lock l(mu_);
is_ready_ = true;
auto& data = absl::get<RemoteTensorHandleData>(data_);
data.Poison(status);
} else {
tf_shared_lock l(mu_);
auto mirror = unshaped_remote_mirrors_.find(d->name());
if (mirror != unshaped_remote_mirrors_.end()) {
if (mirror->second->context_view_id() == context_view_id) {
mirror->second->Poison(status);
auto mirror = remote_mirrors_.find(d->name());
if (mirror != remote_mirrors_.end()) {
if (mirror->second.context_view_id() == context_view_id) {
mirror->second.Poison(status);
}
}
}
@ -798,9 +651,9 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
}
mutex_lock l(mu_);
auto elem = local_mirrors_.insert(std::make_pair(
d, std::make_pair(nullptr,
std::make_unique<LocalTensorHandleData>(tensor))));
auto elem =
local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
std::forward_as_tuple(std::move(tensor)));
if (!elem.second) {
return errors::Internal("Attempted to set tensor for existing mirror.");
}
@ -808,24 +661,18 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
return Status::OK();
}
Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
DCHECK(!is_async_ || !IsReady())
<< "SetTensor is only called on non-ready handles.";
DCHECK(!IsRemote()) << "SetTensor is not called on remote handles.";
if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
auto& resource_handle = t.flat<class ResourceHandle>()(0);
handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
}
tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
if (is_async_) {
is_poisoned_ = Status::OK();
mutex_lock l(mu_);
is_ready_ = true;
}
auto& data = absl::get<LocalTensorHandleData>(data_);
return data.SetTensor(std::move(t));
} else {
tf_shared_lock l(mu_);
auto elem = local_mirrors_.find(d);
@ -835,12 +682,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
}
auto& mirror = elem->second;
if (mirror.second != nullptr) {
return errors::Internal("Attempted to set tensor for existing mirror.");
}
mirror.second = absl::make_unique<LocalTensorHandleData>(tensor);
mirror.first->SetReady();
return mirror.SetTensor(std::move(t));
}
return Status::OK();
@ -850,12 +692,10 @@ void TensorHandle::Poison(Status status, const Device* d) {
DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
DCHECK(!is_async_ || !IsReady())
<< "Poison can only be called on non-ready handle: " << this;
DCHECK(!IsRemote()) << "Poison can only be on local handles: " << this;
is_poisoned_ = status;
mutex_lock l(mu_);
is_ready_ = true;
auto& data = absl::get<LocalTensorHandleData>(data_);
data.Poison(status);
} else {
tf_shared_lock l(mu_);
auto elem = local_mirrors_.find(d);
@ -864,9 +704,7 @@ void TensorHandle::Poison(Status status, const Device* d) {
<< " device: " << d;
auto& mirror = elem->second;
DCHECK(mirror.second == nullptr) << "Attempted to poison existing mirror.";
mirror.first->Poison(status);
mirror.Poison(status);
}
}
@ -977,8 +815,11 @@ string TensorHandle::DebugString() const {
!VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull;
// Consider supporting non-CPU tensors and CPU tensors with a device_ set to
// non-NULL if needed.
strings::StrAppend(&out, ", Tensor: ",
is_cpu ? tensor_handle_data_->DebugString() : "?", "\n");
strings::StrAppend(
&out, ", Tensor: ",
is_cpu ? absl::visit([](auto& data) { return data.DebugString(); }, data_)
: "?",
"\n");
return out;
}

View File

@ -17,10 +17,10 @@ limitations under the License.
#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
// clang-format off
@ -32,28 +32,20 @@ limitations under the License.
#include "absl/types/variant.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
#endif // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
@ -67,56 +59,45 @@ class TensorHandle : public core::RefCounted {
using VariantDevice = absl::variant<Device*, CustomDevice*>;
// TensorHandle for dtype != DT_RESOURCE
TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
Device* d, Device* op_device, EagerContext* ctx);
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
Device* resource_device, EagerContext* ctx);
// TensorHandle for dtype == DT_RESOURCE
TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
const ResourceHandle& resource_handle, Device* d,
Device* op_device, EagerContext* ctx);
TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
CustomDevice* d, EagerContext* ctx);
TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
Device* d, Device* op_device, Device* resource_device,
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
EagerContext* ctx);
TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx);
TensorHandle(Device* d, Device* op_device, Device* resource_device,
DataType dtype, EagerContext* ctx);
#if !defined(IS_MOBILE_PLATFORM)
TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, DataType dtype,
Device* d, Device* resource_device, EagerContext* ctx);
TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
DataType dtype, Device* device, EagerContext* ctx);
TensorHandle(int64 op_id, int32 output_num, DataType dtype, Device* device,
EagerContext* ctx);
#endif // IS_MOBILE_PLATFORM
public:
// TensorHandle with no assigned device
static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
// TensorHandle with device == op_device
static Status CreateLocalHandle(const class Tensor& t, Device* d,
EagerContext* ctx, TensorHandle** h);
static Status CreateLocalHandle(const class Tensor& t, Device* d,
static Status CreateLocalHandle(const tensorflow::Tensor& t,
TensorHandle** h);
static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device, EagerContext* ctx,
TensorHandle** h);
static Status CreateLocalHandle(const class Tensor& t, CustomDevice* d,
static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device, Device* resource_device,
EagerContext* ctx, TensorHandle** h);
static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
EagerContext* ctx, TensorHandle** h);
static Status CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device, DataType dtype,
EagerContext* ctx, TensorHandle** h);
#if !defined(IS_MOBILE_PLATFORM)
static Status CreateRemoteHandle(int64 op_id, int output_num,
const TensorShape& shape,
const string& remote_task, DataType dtype,
Device* d, Device* resource_device,
EagerContext* ctx, TensorHandle** h);
static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t,
DataType dtype, Device* d,
Device* resource_device, EagerContext* ctx,
TensorHandle** h);
static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
const string& remote_task,
DataType dtype, Device* device,
DataType dtype, Device* d,
EagerContext* ctx, TensorHandle** h);
static Status CreateUnshapedRemoteHandle(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
Device* device, EagerContext* ctx, TensorHandle** h);
static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num,
DataType dtype, Device* d,
EagerContext* ctx, TensorHandle** h);
#endif // IS_MOBILE_PLATFORM
~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }
@ -131,7 +112,7 @@ class TensorHandle : public core::RefCounted {
// Return the TensorValue from the specified device which could be either the
// default device or a local mirror. The device pointer should be nullptr if
// requesting the HostCPU.
Status TensorValue(tensorflow::TensorValue* t, const Device* d);
Status TensorValue(const Device* d, tensorflow::TensorValue* t);
VariantDevice device() const { return device_; }
Device* op_device() const { return op_device_; }
@ -161,12 +142,10 @@ class TensorHandle : public core::RefCounted {
bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;
Status AddUnshapedRemoteMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);
Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
const Device* d);
Status AddResourceShapeMirror(
std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);
Status AddUnshapedRemoteMirror(const Device* d, int64 op_id, int output_num,
const string& remote_task, EagerContext* ctx);
Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
EagerContext* ctx);
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(const Device* d, int64* op_id, int32* output_num) const;
@ -212,14 +191,12 @@ class TensorHandle : public core::RefCounted {
Status CopyInferenceShape(TensorHandle* other);
// Warning: can return nullptr for CPU tensors.
// TODO(b/136608821): Move away from nullptr
EagerContext* Context() { return ctx_; }
// dtype for the handle. It must be the same as t.dtype() once the handle is
// ready.
const DataType dtype;
// TODO(b/136608821): Move away from nullptr
bool OnHostCPU() const {
return (
device_.index() == 0 &&
@ -227,14 +204,14 @@ class TensorHandle : public core::RefCounted {
(ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_))));
}
bool IsRemote() const { return is_remote_; }
bool IsRemote() const;
void EnableImplicitMirroring() { implicit_mirroring_ = true; }
bool ImplicitMirroring() const { return implicit_mirroring_; }
string DebugString() const;
void SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);
std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes);
// If this TensorHandle is 1) a local tensor, and 2) a resource handle,
// return data types and shapes of the underlying resource.
@ -248,19 +225,6 @@ class TensorHandle : public core::RefCounted {
// with a ready version of the tensor handle data.
bool IsReady() const;
// If the contents of the Tensor pointed to by this handle is yet to be
// computed by a EagerNode, this function will block till that computation is
// done and the handle is "ready".
Status WaitReady(const char* caller) const;
// TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
// This was expedient, but perhaps worth revisiting ('device_' should always
// be a valid pointer?)
// This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
// provided with the appropriate TFE_Context.
//
// TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
// TFE_TensorHandle does not outlive the TFE_Context from which it came?
VariantDevice const device_;
// Device in which the op producing this tensor was executed. Equals to
@ -275,47 +239,33 @@ class TensorHandle : public core::RefCounted {
mutable mutex mu_;
// Map of local mirrors. In sync mode the EmptyLocalTensorHandleData is
// nullptr. In async mode, we use the EmptyLocalTensorHandleData to manage
// waiting clients. Once the EmptyLocalTensorHandleData is "ready" only the
// LocalTensorHandleData should be used.
std::map<const tensorflow::Device*,
std::pair<std::unique_ptr<EmptyLocalTensorHandleData>,
std::unique_ptr<LocalTensorHandleData>>>
// Map of local mirrors. This can include both ready and non-ready mirrors.
std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
local_mirrors_ GUARDED_BY(mu_);
#if !defined(IS_MOBILE_PLATFORM)
// TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
// variable is ready, since we could get the shape locally without remote copy
// then.
std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>>
resource_shape_mirrors_ GUARDED_BY(mu_);
// TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
// Consider replacing the unshaped_remote_mirrors_ map with something more
// efficient.
std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>>
unshaped_remote_mirrors_ GUARDED_BY(mu_);
std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
GUARDED_BY(mu_);
// TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be
// a fixed size map.
std::map<string, std::unique_ptr<RemoteTensorHandleData>> remote_mirrors_
std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
GUARDED_BY(mu_);
// IDs required when this class is representing a remote tensor handle.
int64 remote_op_id_;
int32 remote_output_num_;
#endif
// `ctx` is only guaranteed to be set if the handle is not "ready". This is
// typically true when the handle was produced during async execution.
// `ctx` object is not owned and should outlive this handle.
//
// TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
// a TensorHandle does not outlive the EagerContext from which it came?
EagerContext* const ctx_;
// Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, is_poisoned_ is immutable.
Status is_poisoned_;
const bool is_remote_;
const bool is_async_;
bool implicit_mirroring_;
bool is_ready_ GUARDED_BY(mu_);
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
// refers to a remote resource handle, we store data types and shapes for
@ -323,8 +273,12 @@ class TensorHandle : public core::RefCounted {
std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
// Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
std::unique_ptr<TensorHandleData> tensor_handle_data_;
// WaitReady() has returned. At that point, data_ is immutable.
#if !defined(IS_MOBILE_PLATFORM)
absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_;
#else
absl::variant<LocalTensorHandleData> data_;
#endif
PartialTensorShape inference_shape_;
};

View File

@ -23,12 +23,16 @@ namespace tensorflow {
class Status;
Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
TF_RETURN_IF_ERROR(WaitReady("Tensor"));
*t = &tensor_;
return Status::OK();
}
Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
TF_RETURN_IF_ERROR(WaitReady("TensorValue"));
tensorflow::Tensor& tensor = tensor_;
*t = tensorflow::TensorValue(&tensor);
@ -36,103 +40,96 @@ Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
}
Status LocalTensorHandleData::Shape(TensorShape* shape) const {
TF_RETURN_IF_ERROR(WaitReady("Shape"));
*shape = tensor_.shape();
return Status::OK();
}
Status LocalTensorHandleData::NumDims(int* num_dims) const {
TF_RETURN_IF_ERROR(WaitReady("NumDims"));
*num_dims = tensor_.dims();
return Status::OK();
}
Status LocalTensorHandleData::Dim(int dim_index, int64* dim) const {
TF_RETURN_IF_ERROR(WaitReady("Dim"));
*dim = tensor_.dim_size(dim_index);
return Status::OK();
}
Status LocalTensorHandleData::NumElements(int64* num_elements) const {
TF_RETURN_IF_ERROR(WaitReady("NumElements"));
*num_elements = tensor_.NumElements();
return Status::OK();
}
Status LocalTensorHandleData::Unprotect() {
if (!IsReady()) {
return errors::Internal("Cannot unprotect a non-ready tensor");
}
forwarding_protection_tensor_ = tensorflow::Tensor();
return Status::OK();
}
Status EmptyLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
return errors::Unavailable(
"Unable to get a tensor for an empty handle. "
"Please wait until it is ready");
Status LocalTensorHandleData::SetTensor(tensorflow::Tensor&& t) {
DCHECK(!IsReady()) << "SetTensor is only called on non-ready handles.";
tensor_ = std::move(t);
// Create copy of original tensor to avoid forwarding
forwarding_protection_tensor_ = tensor_;
auto& state = absl::get<BlockingControl>(ctrl_);
state.SetReady();
return Status::OK();
}
Status EmptyLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
return errors::Unavailable(
"Unable to get a tensor for an empty handle. "
"Please wait until it is ready");
string LocalTensorHandleData::DebugString() const {
if (IsReady()) {
return tensor_.DeviceSafeDebugString();
} else {
return "LocalTensorHandleData";
}
}
Status EmptyLocalTensorHandleData::Shape(TensorShape* shape) const {
return errors::Unavailable(
"Unable to get shape information for an empty handle. "
"Please wait until it is ready");
}
Status EmptyLocalTensorHandleData::NumDims(int* num_dims) const {
return errors::Unavailable(
"Unable to get shape information for an empty handle. "
"Please wait until it is ready");
}
Status EmptyLocalTensorHandleData::Dim(int dim_index, int64* dim) const {
return errors::Unavailable(
"Unable to get shape information for an empty handle. "
"Please wait until it is ready");
}
Status EmptyLocalTensorHandleData::NumElements(int64* num_elements) const {
return errors::Unavailable(
"Unable to get shape information for an empty handle. "
"Please wait until it is ready");
}
Status EmptyLocalTensorHandleData::Unprotect() {
return errors::Unavailable("Unable to unprotect an empty handle.");
}
bool EmptyLocalTensorHandleData::IsReady() const {
tf_shared_lock l(mu_);
return is_ready_;
}
void EmptyLocalTensorHandleData::SetReady() {
void LocalTensorHandleData::BlockingControl::SetReady() {
mutex_lock l(mu_);
is_ready_ = true;
}
Status EmptyLocalTensorHandleData::WaitReady(const char* caller) const {
if (!IsReady()) {
profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
profiler::TraceMeLevel::kInfo);
tf_shared_lock l(mu_);
Status LocalTensorHandleData::BlockingControl::WaitReady(
const char* caller) const {
tf_shared_lock l(mu_);
if (!is_ready_) {
profiler::TraceMe activity(
[caller] { return absl::StrCat(caller, " WaitReady"); },
profiler::TraceMeLevel::kInfo);
DVLOG(3) << "WaitReady: " << caller << " " << this;
mu_.Await(Condition(&is_ready_));
}
return is_poisoned_;
}
void EmptyLocalTensorHandleData::Poison(Status status) {
is_poisoned_ = status;
void LocalTensorHandleData::BlockingControl::Poison(Status status) {
mutex_lock l(mu_);
if (is_ready_) {
LOG(ERROR) << "Poison can only be called on non-ready handle: " << this;
return;
}
is_poisoned_ = status;
is_ready_ = true;
}
string EmptyLocalTensorHandleData::DebugString() const {
return "EmptyLocalTensorHandleData";
}
} // namespace tensorflow

View File

@ -15,52 +15,50 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
#include "absl/types/variant.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class TensorHandleData {
public:
virtual ~TensorHandleData() {}
// Different tensor handles support a set of these calls. In some cases these
// are resolved with a Tensor or TensorShape. Typically if the handle is not
// ready, none of these are supported operations.
virtual Status Tensor(const tensorflow::Tensor** t) const = 0;
virtual Status TensorValue(tensorflow::TensorValue* t) = 0;
virtual Status Shape(TensorShape* shape) const = 0;
virtual Status NumDims(int* num_dims) const = 0;
virtual Status Dim(int dim_index, int64* dim) const = 0;
virtual Status NumElements(int64* num_elements) const = 0;
// Allow the backing Tensor to be available for buffer reuse during op
// execution.
virtual Status Unprotect() = 0;
virtual string DebugString() const = 0;
};
// Local Tensor Handle: Handle to a Tensor present on the local host.
class LocalTensorHandleData : public TensorHandleData {
class LocalTensorHandleData {
public:
explicit LocalTensorHandleData(const tensorflow::Tensor& t)
: tensor_(t), forwarding_protection_tensor_(t) {}
~LocalTensorHandleData() override {}
LocalTensorHandleData() : ctrl_(absl::in_place_type<BlockingControl>) {}
explicit LocalTensorHandleData(tensorflow::Tensor&& t)
: tensor_(std::move(t)),
forwarding_protection_tensor_(tensor_),
ctrl_(absl::in_place_type<NonBlockingControl>) {}
// A local tensor handle should be able to satisfy all of these requests.
Status Tensor(const tensorflow::Tensor** t) const override;
Status TensorValue(tensorflow::TensorValue* t) override;
Status Shape(TensorShape* shape) const override;
Status NumDims(int* num_dims) const override;
Status Dim(int dim_index, int64* dim) const override;
Status NumElements(int64* num_elements) const override;
Status Unprotect() override;
Status Tensor(const tensorflow::Tensor** t) const;
Status TensorValue(tensorflow::TensorValue* t);
Status Shape(TensorShape* shape) const;
Status NumDims(int* num_dims) const;
Status Dim(int dim_index, int64* dim) const;
Status NumElements(int64* num_elements) const;
Status Unprotect();
string DebugString() const override {
return tensor_.DeviceSafeDebugString();
bool IsReady() const {
return absl::visit([](auto& data) { return data.IsReady(); }, ctrl_);
}
Status WaitReady(const char* caller) const {
return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
ctrl_);
}
void Poison(Status status) {
return absl::visit([status](auto& data) { data.Poison(status); }, ctrl_);
}
Status IsPoisoned() const {
return absl::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_);
}
Status SetTensor(tensorflow::Tensor&& t);
string DebugString() const;
private:
tensorflow::Tensor tensor_;
// TensorHandle has its own reference counting which is distinct from the
@ -70,37 +68,41 @@ class LocalTensorHandleData : public TensorHandleData {
// forwarding_protection_tensor_ Tensor. When Unprotect() is called, we
// release this Tensor to allow forwarding.
tensorflow::Tensor forwarding_protection_tensor_;
};
// Empty Local Tensor Handle: Once the execution is complete this is replaced by
// a local tensor handle.
class EmptyLocalTensorHandleData : public TensorHandleData {
public:
EmptyLocalTensorHandleData() {}
~EmptyLocalTensorHandleData() override {}
// We distinguish between ready and empty tensors with the ctrl_ variant.
// which contains 2 implementations of the waiting logic. The
// NonBlockingControl is a simple no-op class whereas the BlockingControl
// actually uses a mutex. By using a variant we avoid the overhead of
// constructing and destructing the mutex for ready local tensors.
class NonBlockingControl {
public:
bool IsReady() const { return true; }
Status WaitReady(const char* caller) const { return Status::OK(); }
void Poison(Status status) {}
Status IsPoisoned() const { return Status::OK(); }
};
// Empty tensor handles are not ready and hence cannot satisfy any of these
// requests.
Status Tensor(const tensorflow::Tensor** t) const override;
Status TensorValue(tensorflow::TensorValue* t) override;
Status Shape(TensorShape* shape) const override;
Status NumDims(int* num_dims) const override;
Status Dim(int dim_index, int64* dim) const override;
Status NumElements(int64* num_elements) const override;
Status Unprotect() override;
class BlockingControl {
public:
bool IsReady() const {
tf_shared_lock l(mu_);
return is_ready_;
}
void SetReady();
Status WaitReady(const char* caller) const;
void Poison(Status status);
Status IsPoisoned() const {
tf_shared_lock l(mu_);
return is_poisoned_;
}
bool IsReady() const;
void SetReady();
Status WaitReady(const char* caller) const;
void Poison(Status status);
Status IsPoisoned() const { return is_poisoned_; }
private:
mutable mutex mu_;
bool is_ready_ GUARDED_BY(mu_);
Status is_poisoned_ GUARDED_BY(mu_);
};
string DebugString() const override;
private:
mutable mutex mu_;
bool is_ready_ GUARDED_BY(mu_);
Status is_poisoned_;
absl::variant<NonBlockingControl, BlockingControl> ctrl_;
};
} // namespace tensorflow

View File

@ -39,12 +39,13 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
&device_mgr, false, nullptr, nullptr, nullptr);
TensorHandle* sync_th;
EXPECT_TRUE(
TensorHandle::CreateLocalHandle(t, ctx->HostCPU(), ctx, &sync_th).ok());
EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr,
ctx, &sync_th)
.ok());
TensorHandle* async_th;
EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(true, nullptr, nullptr,
nullptr, DataType::DT_UINT16,
ctx, &async_th)
EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr,
DataType::DT_UINT16, ctx,
&async_th)
.ok());
EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());

View File

@ -190,8 +190,10 @@ cc_library(
deps = [
":destroy_tensor_handle_node",
":eager_client",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:tensor_handle_data",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/profiler/lib:traceme",
],
)

View File

@ -530,7 +530,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
}
TensorHandle* tensor_handle = nullptr;
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle));
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle));
TensorHandle* copied_handle = nullptr;
Device* device;
TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(

View File

@ -101,12 +101,10 @@ Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
core::RefCountPtr<KernelAndDevice> kernel;
TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));
gtl::InlinedVector<TensorValue, 4> input_vector(1);
TF_RETURN_IF_ERROR(src_->TensorValue(
&input_vector[0],
ctx_->CanonicalDevice(absl::get<Device*>(op->Device()))));
EagerKernelArgs args(1);
Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
EagerKernelArgs args(std::move(input_vector));
return kernel->Run(args, /*outputs=*/nullptr,
/*cancellation_manager=*/nullptr,
/*remote_func_params=*/absl::nullopt);

View File

@ -162,16 +162,8 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
in.op_device().empty() ? in.device() : in.op_device();
TF_RETURN_IF_ERROR(
parent_->FindDeviceFromName(device_name.c_str(), &device));
string remote_task;
if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
return errors::InvalidArgument(
"Unable to find remote task corresponding to device ", device_name);
}
auto remote_handle_data = absl::make_unique<UnshapedRemoteTensorHandleData>(
in.op_id(), in.output_num(), remote_task, parent_);
remote_handle_data->ReleaseRemoteTensorHandle();
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
std::move(remote_handle_data), in.dtype(), device, parent_, out));
TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
&dtypes_and_shapes)

View File

@ -71,14 +71,12 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
Tensor t(DT_FLOAT, TensorShape({0}));
TensorHandle* handle;
TF_ASSERT_OK(
TensorHandle::CreateLocalHandle(t, local_device_, ctx_, &handle));
TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_,
local_device_, ctx_, &handle));
const uint64 op_id = 2;
const int output_num = 3;
auto tensor_handle_data = absl::make_unique<RemoteTensorHandleData>(
op_id, output_num, t.shape(), /*remote_task=*/"", ctx_);
TF_ASSERT_OK(
handle->AddRemoteMirror(std::move(tensor_handle_data), remote_device_));
TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
output_num, "", ctx_));
RemoteTensorHandle remote_handle;
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
handle, &remote_handle, remote_device_, remote_device_->name()));
@ -90,14 +88,13 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
RemoteMgr remote_mgr(false, ctx_);
Tensor t(DT_FLOAT, TensorShape({0}));
const uint64 op_id = 3;
const int output_num = 1;
TensorHandle* handle;
TF_ASSERT_OK(TensorHandle::CreateRemoteHandle(
op_id, output_num, t.shape(), /*remote_task=*/"", DT_FLOAT,
remote_device_, /*resource_device=*/nullptr, ctx_, &handle));
TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle(
op_id, output_num,
/*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle));
RemoteTensorHandle remote_handle;
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
handle, &remote_handle, remote_device_, remote_device_->name()));

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
@ -84,66 +85,103 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
} // namespace
RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
const TensorShape& shape,
const string& remote_task,
EagerContext* ctx)
: op_id_(op_id),
uint64 context_view_id)
: is_ready_(false),
op_id_(op_id),
output_num_(output_num),
shape_(shape),
remote_task_(remote_task),
context_id_(ctx->GetContextId()),
context_view_id_(ctx->GetContextViewId()),
ctx_(*ctx) {
context_view_id_(context_view_id),
ctx_(nullptr) {
DCHECK(op_id_ >= 0 && output_num_ >= 0)
<< "Op ID and output num should be >= 0. Op ID: " << op_id
<< ", Output num: " << output_num;
ctx_.Ref();
}
RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
const string& remote_task,
EagerContext* ctx)
: is_ready_(false),
op_id_(op_id),
output_num_(output_num),
remote_task_(remote_task),
context_id_(ctx->GetContextId()),
context_view_id_(ctx->GetContextViewId()),
ctx_(ctx) {
DCHECK(op_id_ >= 0 && output_num_ >= 0)
<< "Op ID and output num should be >= 0. Op ID: " << op_id
<< ", Output num: " << output_num;
ctx_->Ref();
}
RemoteTensorHandleData::~RemoteTensorHandleData() {
DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
output_num_, /*ready=*/true);
ctx_.Unref();
}
Status RemoteTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
return errors::Unavailable(
"Unable to get a tensor for a remote device. Please copy the tensor "
"handle to a local device using TFE_TensorHandleCopyToDevice");
}
Status RemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
return errors::Unavailable(
"Unable to get a tensor for a remote device. Please copy the tensor "
"handle to a local device using TFE_TensorHandleCopyToDevice");
if (ctx_) {
DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
output_num_, /*ready=*/true);
ctx_->Unref();
}
}
Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
TF_RETURN_IF_ERROR(WaitReady("Shape"));
tf_shared_lock l(mu_);
*shape = shape_;
return Status::OK();
}
Status RemoteTensorHandleData::NumDims(int* num_dims) const {
TF_RETURN_IF_ERROR(WaitReady("NumDims"));
tf_shared_lock l(mu_);
*num_dims = shape_.dims();
return Status::OK();
}
Status RemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
TF_RETURN_IF_ERROR(WaitReady("Dim"));
tf_shared_lock l(mu_);
*dim = shape_.dim_size(dim_index);
return Status::OK();
}
Status RemoteTensorHandleData::NumElements(int64* num_elements) const {
TF_RETURN_IF_ERROR(WaitReady("NumElements"));
tf_shared_lock l(mu_);
*num_elements = shape_.num_elements();
return Status::OK();
}
Status RemoteTensorHandleData::Unprotect() {
return errors::Unavailable("Unable to unprotect a remote handle.");
bool RemoteTensorHandleData::IsReady() const {
tf_shared_lock l(mu_);
return is_ready_;
}
void RemoteTensorHandleData::Poison(Status status) {
mutex_lock l(mu_);
is_poisoned_ = status;
}
Status RemoteTensorHandleData::IsPoisoned() const {
tf_shared_lock l(mu_);
return is_poisoned_;
}
Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
mutex_lock l(mu_);
if (is_ready_) {
return errors::Internal("SetShape is only called on non-ready handles.");
}
shape_ = shape;
is_poisoned_ = Status::OK();
is_ready_ = true;
return Status::OK();
}
string RemoteTensorHandleData::DebugString() const {
@ -151,73 +189,20 @@ string RemoteTensorHandleData::DebugString() const {
" output_num: ", output_num_);
}
UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
int64 op_id, int32 output_num, const string& remote_task, EagerContext* ctx)
: op_id_(op_id),
output_num_(output_num),
delete_remote_tensor_(true),
remote_task_(remote_task),
context_id_(ctx->GetContextId()),
context_view_id_(ctx->GetContextViewId()),
ctx_(*ctx) {
DCHECK(op_id_ >= 0 && output_num_ >= 0)
<< "Op ID and output num should be >= 0. Op ID: " << op_id
<< ", Output num: " << output_num;
ctx_.Ref();
}
UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
if (delete_remote_tensor_) {
DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
output_num_, /*ready=*/false);
Status RemoteTensorHandleData::WaitReady(const char* caller) const {
if (ctx_ == nullptr) {
return errors::Internal("Cannot wait on lazy remote handle");
}
ctx_.Unref();
}
Status UnshapedRemoteTensorHandleData::Tensor(
const tensorflow::Tensor** t) const {
return errors::Unavailable(
"Unable to get a tensor for a remote handle. Please copy the tensor "
"handle to a local device using TFE_TensorHandleCopyToDevice");
}
Status UnshapedRemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
return errors::Unavailable(
"Unable to get a tensor for a remote handle. Please copy the tensor "
"handle to a local device using TFE_TensorHandleCopyToDevice");
}
Status UnshapedRemoteTensorHandleData::Shape(TensorShape* shape) const {
return errors::Unavailable(
"Unable to get shape information for an async remote handle. Please wait "
"until it is ready");
}
Status UnshapedRemoteTensorHandleData::NumDims(int* num_dims) const {
return errors::Unavailable(
"Unable to get shape information for an async remote handle. Please wait "
"until it is ready");
}
Status UnshapedRemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
return errors::Unavailable(
"Unable to get shape information for an async remote handle. Please wait "
"until it is ready");
}
Status UnshapedRemoteTensorHandleData::NumElements(int64* num_elements) const {
return errors::Unavailable(
"Unable to get shape information for an async remote handle. Please wait "
"until it is ready");
}
Status UnshapedRemoteTensorHandleData::Unprotect() {
return errors::Unavailable("Unable to unprotect a remote handle.");
}
string UnshapedRemoteTensorHandleData::DebugString() const {
return strings::StrCat("UnshapedRemoteTensorHandleDat:", " op_id: ", op_id_,
" output_num: ", output_num_);
tf_shared_lock l(mu_);
if (!is_ready_) {
profiler::TraceMe activity(
[caller] { return absl::StrCat(caller, " WaitReady"); },
profiler::TraceMeLevel::kInfo);
DVLOG(3) << "WaitReady: " << caller << " " << this;
mu_.Await(Condition(&is_ready_));
}
return is_poisoned_;
}
} // namespace tensorflow

View File

@ -15,97 +15,56 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only
// the shape is known.
class RemoteTensorHandleData : public TensorHandleData {
class RemoteTensorHandleData {
public:
RemoteTensorHandleData(int64 op_id, int output_num, const TensorShape& shape,
const string& remote_task, EagerContext* ctx);
~RemoteTensorHandleData() override;
// Constructor for lazy remote handles
RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id);
// Constructor for unshaped remote handles
RemoteTensorHandleData(int64 op_id, int output_num, const string& remote_task,
EagerContext* ctx);
~RemoteTensorHandleData();
// A remote tensor handle does not have a Tensor object, hence it can only
// support the shape requests.
Status Tensor(const tensorflow::Tensor** t) const override;
Status TensorValue(tensorflow::TensorValue* t) override;
Status Shape(TensorShape* shape) const override;
Status NumDims(int* num_dims) const override;
Status Dim(int dim_index, int64* dim) const override;
Status NumElements(int64* num_elements) const override;
Status Unprotect() override;
EagerContext& ctx() const { return ctx_; }
Status Shape(TensorShape* shape) const;
Status NumDims(int* num_dims) const;
Status Dim(int dim_index, int64* dim) const;
Status NumElements(int64* num_elements) const;
string DebugString() const override;
bool IsReady() const;
Status SetShape(const TensorShape& shape);
void Poison(Status status);
Status IsPoisoned() const;
string DebugString() const;
int64 op_id() const { return op_id_; }
int32 output_num() const { return output_num_; }
uint64 context_id() const { return context_id_; }
uint64 context_view_id() const { return context_view_id_; }
private:
Status WaitReady(const char* caller) const;
mutable mutex mu_;
bool is_ready_ GUARDED_BY(mu_);
Status is_poisoned_ GUARDED_BY(mu_);
TensorShape shape_ GUARDED_BY(mu_);
// IDs required when this class is representing a remote tensor handle.
const int64 op_id_;
const int32 output_num_;
const TensorShape shape_;
string remote_task_;
uint64 context_id_;
uint64 context_view_id_;
EagerContext& ctx_;
};
// Async Remote Tensor Handle: A handle to a Tensor on a remote host. Once the
// shape has been computed this is replaced with a remote tensor handle.
class UnshapedRemoteTensorHandleData : public TensorHandleData {
public:
UnshapedRemoteTensorHandleData(int64 op_id, int32 output_num,
const string& remote_task, EagerContext* ctx);
~UnshapedRemoteTensorHandleData() override;
// Unshaped remote tensor handles are not ready and hence cannot satisfy any
// of these requests.
Status Tensor(const tensorflow::Tensor** t) const override;
Status TensorValue(tensorflow::TensorValue* t) override;
Status Shape(TensorShape* shape) const override;
Status NumDims(int* num_dims) const override;
Status Dim(int dim_index, int64* dim) const override;
Status NumElements(int64* num_elements) const override;
Status Unprotect() override;
void Poison(Status status) { is_poisoned_ = status; }
Status IsPoisoned() const { return is_poisoned_; }
string DebugString() const override;
int64 op_id() const { return op_id_; }
int32 output_num() const { return output_num_; }
string remote_task() const { return remote_task_; }
uint64 context_id() const { return context_id_; }
uint64 context_view_id() const { return context_view_id_; }
EagerContext& ctx() const { return ctx_; }
// When constructed, UnshapedRemoteTensorHandleData owns the remote
// TensorHandle and should delete it by issuing an RPC. Once the remote
// shape has been learned, the ownership is transferred to
// RemoteTensorHandleData. This method must be called to let `this` know
// that it no longer owns the remote handle.
// TODO(iga): Add a factory method here that will create a new
// RemoteTensorHandleData from this and transfer ownership in the process.
void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; }
private:
Status is_poisoned_;
// IDs required when this class is representing a remote tensor handle.
const int64 op_id_;
const int32 output_num_;
bool delete_remote_tensor_;
string remote_task_;
uint64 context_id_;
uint64 context_view_id_;
EagerContext& ctx_;
EagerContext* ctx_;
};
} // namespace tensorflow

View File

@ -91,11 +91,11 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
Device* device = IsCPUDevice(call->device) ? nullptr : call->device;
for (int64 i = 0; i < n; ++i) {
PyObject* arg = nullptr;
const Tensor& t = call->ins[i];
if (call->eager) {
TensorHandle* handle;
Tensor t = call->ins[i];
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
t, ctx->CanonicalDevice(device), nullptr, ctx, &handle));
std::move(t), ctx->CanonicalDevice(device), nullptr, ctx, &handle));
arg = EagerTensorFromHandle(new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)});
if (arg == nullptr) {
@ -103,7 +103,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
return errors::Internal("Unable to procure EagerTensor from Tensor.");
}
} else {
Status s = TensorToNdarray(t, &arg);
Status s = TensorToNdarray(call->ins[i], &arg);
if (!s.ok()) {
Py_DECREF(lst);
return s;

View File

@ -277,7 +277,8 @@ struct Converter {
static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state,
TFE_TensorHandle** h, const char** error) {
/* TODO(josh11b): Allocator & attributes? */
// TODO(josh11b): Allocator & attributes
// TODO(gjn): Use optimized scalar constructors when possible.
Tensor result(ConverterTraits<T>::kTypeEnum,
TensorShape(state->inferred_shape));
if (state->inferred_shape.empty()) { /* Scalar case */
@ -294,7 +295,7 @@ struct Converter {
}
tensorflow::TensorHandle* handle = nullptr;
auto status = tensorflow::TensorHandle::CreateLocalHandle(
result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &handle);
if (!status.ok()) {
return status;
@ -610,8 +611,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
if (cppstatus.ok()) {
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context,
&handle);
std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &handle);
}
if (!cppstatus.ok()) {
PyErr_SetString(PyExc_ValueError,
@ -805,10 +806,10 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
case DT_INVALID: // Only occurs for empty tensors.
{
tensorflow::TensorHandle* h = nullptr;
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape));
Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape));
status = tensorflow::TensorHandle::CreateLocalHandle(
tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &h);
if (!status.ok()) {
PyErr_SetString(PyExc_ValueError, status.error_message().c_str());