Internal change

PiperOrigin-RevId: 298470987
Change-Id: I190a37e62a7419de541a2ded4aadeb1194de7bb3

commit 6c212dc330, parent f97b7ba2bf
@@ -1213,10 +1213,10 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
   tensorflow::TensorHandle* ret_handle;
   if (custom_device == nullptr) {
     status->status = tensorflow::TensorHandle::CreateLocalHandle(
-        std::move(t), device, device, context, &ret_handle);
+        t, device, context, &ret_handle);
   } else {
     status->status = tensorflow::TensorHandle::CreateLocalHandle(
-        std::move(t), custom_device, context, &ret_handle);
+        t, custom_device, context, &ret_handle);
   }
   if (!status->status.ok()) {
     return nullptr;
@@ -140,7 +140,6 @@ tf_cuda_library(
             "//tensorflow/core:android_tensorflow_lib_lite",
         ],
        "//conditions:default": [
-            "@com_google_absl//absl/types:variant",
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
            "//tensorflow/core/profiler/lib:traceme",
@@ -564,7 +564,8 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
       }
     }
   }
-  int num_outputs = kernel->num_outputs();
+  const DataTypeVector& output_dtypes = kernel->output_dtypes();
+  const size_t num_outputs = static_cast<int>(output_dtypes.size());
   if (num_outputs > *num_retvals) {
     return errors::InvalidArgument("Expecting ", num_outputs,
                                    " outputs, but *num_retvals is ",
@@ -578,19 +579,21 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     graph_collector = ctx.GetGraphCollector();
   }
 
+  const bool async = executor.Async();
+  for (int i = 0; i < num_outputs; ++i) {
+    TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
+        async,
+        /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
+        /* op_device= */ kernel->device(),
+        /* resource_device= */ kernel->OutputResourceDevice(i),
+        output_dtypes[i], &ctx, &retvals[i]));
+  }
+
   Status s;
-  if (executor.Async()) {
-    const DataTypeVector& output_dtypes = kernel->output_dtypes();
-    for (int i = 0; i < num_outputs; ++i) {
-      TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-          /* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
-          /* op_device= */ kernel->device(),
-          /* resource_device= */ kernel->OutputResourceDevice(i),
-          output_dtypes[i], &ctx, &retvals[i]));
-    }
+  if (async) {
     auto node = absl::make_unique<AsyncExecuteNode>(
         &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
-        graph_collector, op->GetCancellationManager(),
+        graph_collector, output_dtypes, op->GetCancellationManager(),
         absl::Span<TensorHandle*>(retvals, num_outputs));
     // For async mode, execution order will make sure that all
     // input handles are ready before executing them.
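Note: this hunk hoists output-handle creation out of the async-only branch, so
empty handles are allocated for every declared output before execution and the
sync path reuses them instead of first nulling retvals. A minimal standalone
sketch of that allocate-first pattern follows; the Handle type and
AllocateOutputs name are illustrative stand-ins, not TensorFlow APIs.

    #include <memory>
    #include <vector>

    struct Handle { bool ready = false; };

    // Allocate a placeholder per declared output before running the kernel,
    // so the same retvals array serves both sync and async execution paths.
    std::vector<std::unique_ptr<Handle>> AllocateOutputs(int num_outputs) {
      std::vector<std::unique_ptr<Handle>> retvals(num_outputs);
      for (int i = 0; i < num_outputs; ++i) {
        retvals[i] = std::make_unique<Handle>();
      }
      return retvals;
    }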
@@ -598,21 +601,16 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
     // performance.
     s = executor.AddOrExecute(std::move(node));
   } else {
-    for (int i = 0; i < num_outputs; ++i) {
-      retvals[i] = nullptr;
-    }
     ExecuteNode node(&ctx, op->Inputs(), op->remote_func_params(), kernel,
-                     graph_collector, op->GetCancellationManager(),
-                     {retvals, static_cast<size_t>(num_outputs)});
+                     graph_collector, output_dtypes,
+                     op->GetCancellationManager(), {retvals, num_outputs});
     s = executor.SyncExecute(&node);
   }
-  // Since the operation failed, we need to Unref any outputs if they were
+  // Since the operation failed, we need to Unref any outputs that were
   // allocated.
   if (!s.ok()) {
     for (int i = 0; i < num_outputs; ++i) {
-      if (retvals[i] != nullptr) {
-        retvals[i]->Unref();
-      }
+      retvals[i]->Unref();
     }
   }
 
@@ -735,9 +733,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
             input, input_handle, input_device, *input_device_name,
             serialize_resource_dtype_and_shape));
         if (!input_handle->resource_dtypes_and_shapes().empty()) {
-          TF_RETURN_IF_ERROR(
-              input->AddResourceShapeMirror(op_device, input_handle->op_id(),
-                                            input_handle->output_num(), &ctx));
+          auto tensor_handle_data =
+              absl::make_unique<UnshapedRemoteTensorHandleData>(
+                  input_handle->op_id(), input_handle->output_num(),
+                  remote_task, &ctx);
+          TF_RETURN_IF_ERROR(input->AddResourceShapeMirror(
+              std::move(tensor_handle_data), op_device));
         }
       }
     }
@@ -1031,24 +1032,13 @@ Status EagerKernelExecute(
     }
   }
   DCHECK_EQ(retvals.size(), outputs.size());
 
   for (int i = 0; i < retvals.size(); ++i) {
-    if (retvals[i] == nullptr) {
-      TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
-          std::move(outputs[i]),
-          /* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
-          /* op_device= */ kernel->device(),
-          /* resource_device= */ kernel->OutputResourceDevice(i), ctx,
-          &retvals[i]));
-    } else {
-      DCHECK_EQ(kernel->device(), retvals[i]->op_device());
-      DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
-                absl::get<Device*>(retvals[i]->device()));
-
-      TF_RETURN_IF_ERROR(
-          retvals[i]->SetTensor(std::move(outputs[i]),
-                                ctx->CanonicalDevice(kernel->OutputDevice(i))));
-    }
+    DCHECK_EQ(kernel->device(), retvals[i]->op_device());
+    DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
+              absl::get<Device*>(retvals[i]->device()));
+    TF_RETURN_IF_ERROR(retvals[i]->SetTensor(
+        std::move(outputs[i]), ctx->CanonicalDevice(kernel->OutputDevice(i))));
   }
   return Status::OK();
 }
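Note: with outputs preallocated by EagerLocalExecute, EagerKernelExecute no
longer creates handles on the fly; it only verifies device placement and fills
each existing handle via SetTensor. A hedged sketch of that fill-only contract
follows; FillOutputs and the types are hypothetical simplifications, not the
TensorFlow signatures.

    #include <cassert>
    #include <cstddef>
    #include <utility>
    #include <vector>

    struct Tensor {};
    struct Handle {
      bool has_tensor = false;
      void SetTensor(Tensor&& t) { has_tensor = true; (void)t; }
    };

    // Every retval must already exist; this function only moves tensors in.
    void FillOutputs(std::vector<Handle*>& retvals, std::vector<Tensor>& outs) {
      assert(retvals.size() == outs.size());
      for (size_t i = 0; i < retvals.size(); ++i) {
        assert(retvals[i] != nullptr);  // allocated before execution
        retvals[i]->SetTensor(std::move(outs[i]));
      }
    }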
@@ -1079,7 +1069,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
     *result = h;
   } else {
     TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-        d, dstd, h->resource_device(), h->dtype, ctx, result));
+        true, d, dstd, h->resource_device(), h->dtype, ctx, result));
   }
 
   Status s;
@@ -1148,7 +1138,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
       *result = h;
     } else {
       TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
-          /* d= */ d, /* op_device= */ device,
+          true, /* d= */ d, /* op_device= */ device,
           /*resource_device=*/nullptr, h->dtype, ctx, result));
     }
   } else {
@@ -1166,14 +1156,17 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                                      device->name());
       }
       recv_op_id = ctx->RemoteMgr()->NextOpId();
+      auto tensor_handle_data =
+          absl::make_unique<UnshapedRemoteTensorHandleData>(recv_op_id, 0,
+                                                            remote_task, ctx);
       if (mirror) {
-        TF_RETURN_IF_ERROR(h->AddUnshapedRemoteMirror(device, recv_op_id, 0,
-                                                      remote_task, ctx));
+        TF_RETURN_IF_ERROR(
+            h->AddUnshapedRemoteMirror(std::move(tensor_handle_data), device));
         h->Ref();
         *result = h;
       } else {
         TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
-            recv_op_id, 0, remote_task, h->dtype, device, ctx, result));
+            std::move(tensor_handle_data), h->dtype, device, ctx, result));
       }
     }
 
@@ -32,7 +32,7 @@ Status ExecuteNodeArgs::Init(
   for (int i = 0; i < n_inputs; ++i) {
     TensorHandle* in = op_inputs_flat[i];
     Device* d = kernel->InputDevice(i);
-    Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]);
+    Status s = in->TensorValue(&tensor_args_flat[i], ctx->CanonicalDevice(d));
     if (!s.ok()) {
 #if !defined(IS_MOBILE_PLATFORM)
       uint64 context_view_id = ctx->GetContextViewId();
@@ -77,7 +77,7 @@ class ExecuteNode : public EagerNode {
       EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
       const core::RefCountPtr<KernelAndDevice>& kernel,
-      GraphCollector* graph_collector,
+      GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
       CancellationManager* cancellation_manager,
       absl::Span<TensorHandle*> retvals)
       : EagerNode(),
@@ -130,7 +130,7 @@ class AsyncExecuteNode : public EagerNode {
       EagerContext* ctx, const absl::InlinedVector<TensorHandle*, 4>& inputs,
       const absl::optional<EagerRemoteFunctionParams>& remote_func_params,
       core::RefCountPtr<KernelAndDevice> kernel,
-      GraphCollector* graph_collector,
+      GraphCollector* graph_collector, const DataTypeVector& output_dtypes,
       CancellationManager* cancellation_manager,
       absl::Span<TensorHandle*> retvals)
       : EagerNode(),
@@ -67,7 +67,6 @@ class EagerKernelArgs : public FunctionArgsInterface {
   ~EagerKernelArgs() override{};
 
   bool HasRemoteInputs() const override { return false; };
-  TensorValue* MutableInput(int i) { return &tensor_args_[i]; }
 
   Status GetLocalArg(const int index, Tensor* val) const override;
 
@@ -20,31 +20,39 @@ limitations under the License.
 #include <memory>
 #include <queue>
 #include <string>
-#include <utility>
 #include <vector>
 
+#include "absl/strings/substitute.h"
 #include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/platform/errors.h"
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
 #endif  // IS_MOBILE_PLATFORM
+#include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/resource_var.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
 
@@ -56,7 +64,7 @@ const int32 kInvalidOutputNum = -1;
 }  // namespace
 
 void TensorHandle::SetResourceHandleDtypeAndShape(
-    std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes) {
+    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
   handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
 }
 
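Note: the signature change above, from an rvalue reference to pass-by-value,
still avoids a copy when callers pass an rvalue, because the parameter is
moved into the member. A minimal standalone sketch of the idiom (not
TensorFlow code; Holder and Set are hypothetical names):

    #include <string>
    #include <utility>
    #include <vector>

    class Holder {
     public:
      // By-value parameter: callers with an rvalue get a move, callers with
      // an lvalue get exactly one copy, and the body is a single move.
      void Set(std::vector<std::string> values) {
        values_ = std::move(values);
      }

     private:
      std::vector<std::string> values_;
    };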
@@ -78,191 +86,250 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
     profiler::TraceMe activity(
         "TensorHandle::GetResourceHandleDtypesAndShapes WaitReady",
         profiler::TraceMeLevel::kInfo);
-    auto& data = absl::get<LocalTensorHandleData>(data_);
     TF_RETURN_IF_ERROR(
-        data.WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
+        WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
 
   *result = handle_dtypes_and_shapes_;
   return Status::OK();
 }
 
-Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
+Status TensorHandle::CreateLocalHandle(const class Tensor& t,
                                        TensorHandle** h) {
   // TODO(b/136608821): Move away from nullptr
-  tensorflow::Tensor tensor = t;
-  return CreateLocalHandle(std::move(tensor),
-                           /*d=*/nullptr,
+  return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr),
                            /*op_device=*/nullptr,
                            /*ctx=*/nullptr, h);
 }
 
-Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
-                                       Device* op_device, EagerContext* ctx,
-                                       TensorHandle** h) {
-  return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h);
+Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
+                                       EagerContext* ctx, TensorHandle** h) {
+  return CreateLocalHandle(t, d, d, ctx, h);
 }
 
-Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
-                                       Device* op_device,
-                                       Device* resource_device,
-                                       EagerContext* ctx, TensorHandle** h) {
-  if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
-    *h = new TensorHandle(std::move(t), d, op_device, ctx);
+Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
+                                       Device* op_device, EagerContext* ctx,
+                                       TensorHandle** h) {
+  if (t.dtype() != DT_RESOURCE) {
+    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
+                          t.dtype(), d, op_device, ctx);
   } else {
-    *h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
+    const ResourceHandle& resource_handle = t.flat<class ResourceHandle>()(0);
+    *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t),
+                          resource_handle, d, op_device, ctx);
   }
 
   return Status::OK();
 }
 
-Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
+Status TensorHandle::CreateLocalHandle(const class Tensor& t, CustomDevice* d,
                                        EagerContext* ctx, TensorHandle** h) {
-  *h = new TensorHandle(std::move(t), d, ctx);
+  *h = new TensorHandle(absl::make_unique<LocalTensorHandleData>(t), t.dtype(),
+                        d, ctx);
 
   return Status::OK();
 }
 
-TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
-                           Device* resource_device, EagerContext* ctx)
-    : dtype(t.dtype()),
+TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
+                           DataType dtype, Device* d, Device* op_device,
+                           EagerContext* ctx)
+    : dtype(dtype),
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
-      resource_device_(resource_device),
+      resource_device_(nullptr),
+#if !defined(IS_MOBILE_PLATFORM)
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
+#endif
       ctx_(ctx),
+      is_remote_(false),
+      is_async_(false),
       implicit_mirroring_(true),
-      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
+      is_ready_(true),
+      tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " device: " << VariantDeviceDebugString(device_)
-           << " tensor: " << t.DeviceSafeDebugString();
+           << " device: " << VariantDeviceDebugString(device_);
 }
 
-TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
-                           EagerContext* ctx)
+TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
+                           const ResourceHandle& resource_handle, Device* d,
+                           Device* op_device, EagerContext* ctx)
     : dtype(DT_RESOURCE),
       device_((!ctx || d == ctx->HostCPU()) ? nullptr : d),
       op_device_(op_device),
-      resource_device_(
-          GetResourceDevice(t.flat<class ResourceHandle>()(0), ctx)),
+      resource_device_(GetResourceDevice(resource_handle, ctx)),
+#if !defined(IS_MOBILE_PLATFORM)
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
+#endif
       ctx_(ctx),
+      is_remote_(false),
+      is_async_(false),
       implicit_mirroring_(true),
-      handle_dtypes_and_shapes_(
-          t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
-      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
+      is_ready_(true),
+      handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
+      tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " device: " << VariantDeviceDebugString(device_)
-           << " tensor: " << t.DeviceSafeDebugString();
+           << " device: " << VariantDeviceDebugString(device_);
 }
 
-TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
-                           EagerContext* ctx)
-    : dtype(t.dtype()),
+TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
+                           DataType dtype, CustomDevice* d, EagerContext* ctx)
+    : dtype(dtype),
       device_(d),
       op_device_(nullptr),
       resource_device_(nullptr),
+#if !defined(IS_MOBILE_PLATFORM)
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
+#endif
      ctx_(ctx),
+      is_remote_(false),
+      is_async_(false),
      implicit_mirroring_(true),
-      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
+      is_ready_(true),
+      tensor_handle_data_(std::move(t)) {
   // TODO(allenl): Figure out a better op_device story for custom devices,
   // since always setting it to CPU=nullptr doesn't make much sense.
   DVLOG(3) << "Creating Local TensorHandle: " << this
-           << " custom device: " << VariantDeviceDebugString(device_)
-           << " tensor: " << t.DeviceSafeDebugString();
+           << " custom device: " << VariantDeviceDebugString(device_);
 }
 
-Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
+Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
+                                            Device* op_device,
                                             Device* resource_device,
                                             DataType dtype, EagerContext* ctx,
                                             TensorHandle** h) {
-  *h = new TensorHandle(d, op_device, resource_device, dtype, ctx);
+  *h = new TensorHandle(absl::make_unique<EmptyLocalTensorHandleData>(), async,
+                        d, op_device, resource_device, dtype, ctx);
 
   return Status::OK();
 }
 
-TensorHandle::TensorHandle(Device* d, Device* op_device,
+TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
+                           bool async, Device* d, Device* op_device,
                            Device* resource_device, DataType dtype,
                            EagerContext* ctx)
     : dtype(dtype),
      device_((d == ctx->HostCPU()) ? nullptr : d),
      op_device_(op_device),
      resource_device_(resource_device),
+#if !defined(IS_MOBILE_PLATFORM)
+      remote_op_id_(kInvalidOpId),
+      remote_output_num_(kInvalidOutputNum),
+#endif
      ctx_(ctx),
+      is_remote_(false),
+      is_async_(async),
      implicit_mirroring_(true),
-      data_(absl::in_place_type<LocalTensorHandleData>) {
+      is_ready_(!async),
+      tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating empty Local TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
 
 #if !defined(IS_MOBILE_PLATFORM)
-Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
-                                                const string& remote_task,
-                                                DataType dtype, Device* d,
-                                                EagerContext* ctx,
-                                                TensorHandle** h) {
-  *h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
+Status TensorHandle::CreateRemoteHandle(
+    std::unique_ptr<RemoteTensorHandleData> t, DataType dtype, Device* d,
+    Device* resource_device, EagerContext* ctx, TensorHandle** h) {
+  *h = new TensorHandle(std::move(t), dtype, d, resource_device, ctx);
 
   return Status::OK();
 }
 
-TensorHandle::TensorHandle(int64 op_id, int32 output_num,
-                           const string& remote_task, DataType dtype, Device* d,
+Status TensorHandle::CreateRemoteHandle(int64 op_id, int output_num,
+                                        const TensorShape& shape,
+                                        const string& remote_task,
+                                        DataType dtype, Device* d,
+                                        Device* resource_device,
+                                        EagerContext* ctx, TensorHandle** h) {
+  *h = new TensorHandle(absl::make_unique<RemoteTensorHandleData>(
+                            op_id, output_num, shape, remote_task, ctx),
+                        dtype, d, resource_device, ctx);
+  return Status::OK();
+}
+
+TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
+                           DataType dtype, Device* d, Device* resource_device,
                            EagerContext* ctx)
     : dtype(dtype),
       device_(d),
       op_device_(d),
-      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
+      resource_device_(resource_device),
+      remote_op_id_(t->op_id()),
+      remote_output_num_(t->output_num()),
       ctx_(ctx),
+      is_remote_(true),
+      is_async_(false),
       implicit_mirroring_(true),
-      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
-            remote_task, ctx) {
-  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
+      is_ready_(true),
+      tensor_handle_data_(std::move(t)) {
+  DVLOG(3) << "Creating Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
 
-Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num,
-                                            DataType dtype, Device* d,
-                                            EagerContext* ctx,
-                                            TensorHandle** h) {
-  *h = new TensorHandle(op_id, output_num, dtype, d, ctx);
+Status TensorHandle::CreateUnshapedRemoteHandle(
+    std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
+    Device* d, EagerContext* ctx, TensorHandle** h) {
+  *h = new TensorHandle(std::move(t), dtype, d, ctx);
 
   return Status::OK();
 }
 
-TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype,
-                           Device* d, EagerContext* ctx)
+Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
+                                                const string& remote_task,
+                                                DataType dtype, Device* device,
+                                                EagerContext* ctx,
+                                                TensorHandle** h) {
+  *h = new TensorHandle(absl::make_unique<UnshapedRemoteTensorHandleData>(
+                            op_id, output_num, remote_task, ctx),
+                        dtype, device, ctx);
+  return Status::OK();
+}
+
+TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
+                           DataType dtype, Device* device, EagerContext* ctx)
     : dtype(dtype),
-      device_(d),
-      op_device_(d),
-      resource_device_(dtype == DT_RESOURCE ? d : nullptr),
+      device_(device),
+      op_device_(device),
+      resource_device_(dtype == DT_RESOURCE ? device : nullptr),
+      remote_op_id_(t->op_id()),
+      remote_output_num_(t->output_num()),
      ctx_(ctx),
+      is_remote_(true),
+      is_async_(true),
      implicit_mirroring_(true),
-      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
-            ctx->GetContextViewId()) {
-  DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
+      is_ready_(false),
+      tensor_handle_data_(std::move(t)) {
+  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
 }
 #endif
 
 bool TensorHandle::IsReady() const {
-  return absl::visit([](auto& data) { return data.IsReady(); }, data_);
+  // Avoid mutex acquisition for local sync handles
+  if (!is_async_ && !is_remote_) {
+    return true;
+  }
+
+  tf_shared_lock l(mu_);
+  return is_ready_;
 }
 
-bool TensorHandle::IsRemote() const {
-#if !defined(IS_MOBILE_PLATFORM)
-  return data_.index() == 1;
-#else
-  return false;
-#endif
+Status TensorHandle::WaitReady(const char* caller) const {
+  if (!IsReady()) {
+    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
+                               profiler::TraceMeLevel::kInfo);
+    tf_shared_lock l(mu_);
+    mu_.Await(Condition(&is_ready_));
+  }
+  return is_poisoned_;
 }
 
 Status TensorHandle::Tensor(const tensorflow::Tensor** t) const {
   DVLOG(3) << "Tensor on TensorHandle: " << this;
 
-  if (IsRemote()) {
-    return errors::Internal("Invalid Tensor call on remote handle: ", this);
-  }
-
-  auto& data = absl::get<LocalTensorHandleData>(data_);
-  return data.Tensor(t);
+  TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Tensor"));
+  return tensor_handle_data_->Tensor(t);
 }
 
 Status TensorHandle::TensorFromDevice(const Device* d,
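Note: the restored TensorHandle tracks readiness with an is_ready_ flag under
mu_, blocks in WaitReady via mu_.Await(Condition(&is_ready_)), and returns
is_poisoned_ so a poisoned handle surfaces its error to every waiter. Below is
a hedged, standalone sketch of that readiness/poison protocol using standard
C++ primitives (TensorFlow itself uses tensorflow::mutex and Condition, for
which std::condition_variable is the closest stand-in):

    #include <condition_variable>
    #include <mutex>
    #include <string>
    #include <utility>

    class ReadyFlag {
     public:
      // Blocks until SetReady() or Poison() runs; returns the poison status
      // (empty string meaning OK), mirroring WaitReady's is_poisoned_ return.
      std::string WaitReady() {
        std::unique_lock<std::mutex> l(mu_);
        cv_.wait(l, [this] { return is_ready_; });
        return is_poisoned_;
      }

      void SetReady() {
        {
          std::lock_guard<std::mutex> l(mu_);
          is_ready_ = true;
        }
        cv_.notify_all();
      }

      // Poisoning also marks the flag ready so waiters wake and see the error.
      void Poison(std::string status) {
        {
          std::lock_guard<std::mutex> l(mu_);
          is_poisoned_ = std::move(status);
          is_ready_ = true;
        }
        cv_.notify_all();
      }

     private:
      std::mutex mu_;
      std::condition_variable cv_;
      bool is_ready_ = false;
      std::string is_poisoned_;  // empty means OK
    };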
@@ -270,12 +337,12 @@ Status TensorHandle::TensorFromDevice(const Device* d,
   DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;
 
   if (d == absl::get<Device*>(device_)) {
-    if (IsRemote()) {
+    if (is_remote_) {
       return errors::Internal("Invalid Tensor call on remote handle: ", this);
     }
 
-    auto& data = absl::get<LocalTensorHandleData>(data_);
-    return data.Tensor(t);
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorFromDevice"));
+    return tensor_handle_data_->Tensor(t);
   }
 
   tf_shared_lock l(mu_);
@@ -285,21 +352,25 @@ Status TensorHandle::TensorFromDevice(const Device* d,
                            " in Tensor call to handle: ", this);
   }
 
+  // Check if the handle is non-empty, else wait.
   auto& mirror = elem->second;
-  return mirror.Tensor(t);
+  if (mirror.second == nullptr) {
+    TF_RETURN_IF_ERROR(
+        mirror.first->WaitReady("TensorHandle::TensorFromDevice"));
+  }
+
+  return mirror.second->Tensor(t);
 }
 
-Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
-  DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;
-
+Status TensorHandle::TensorValue(tensorflow::TensorValue* t, const Device* d) {
   if (d == absl::get<Device*>(device_)) {
-    if (IsRemote()) {
+    if (is_remote_) {
       return errors::Internal("Invalid TensorValue call on remote handle: ",
                               this);
     }
 
-    auto& data = absl::get<LocalTensorHandleData>(data_);
-    return data.TensorValue(t);
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::TensorValue"));
+    return tensor_handle_data_->TensorValue(t);
   }
 
   tf_shared_lock l(mu_);
@@ -309,8 +380,13 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
                            " in TensorValue call to handle: ", this);
   }
 
+  // Check if the handle is non-empty, else wait.
   auto& mirror = elem->second;
-  return mirror.TensorValue(t);
+  if (mirror.second == nullptr) {
+    TF_RETURN_IF_ERROR(mirror.first->WaitReady("TensorHandle::TensorValue"));
+  }
+
+  return mirror.second->TensorValue(t);
 }
 
 TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
@@ -329,8 +405,8 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
     DCHECK(fill);
     return Status::OK();
   } else {
-    return absl::visit([shape](auto& data) { return data.Shape(shape); },
-                       data_);
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Shape"));
+    return tensor_handle_data_->Shape(shape);
   }
 }
 
@@ -404,8 +480,8 @@ Status TensorHandle::NumDims(int* num_dims) const {
     *num_dims = inference_shape_.dims();
     return Status::OK();
   } else {
-    return absl::visit(
-        [num_dims](auto& data) { return data.NumDims(num_dims); }, data_);
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumDims"));
+    return tensor_handle_data_->NumDims(num_dims);
   }
 }
 
@@ -416,9 +492,8 @@ Status TensorHandle::Dim(int dim_index, int64* dim) const {
     *dim = inference_shape_.dim_size(dim_index);
     return Status::OK();
   } else {
-    return absl::visit(
-        [dim_index, dim](auto& data) { return data.Dim(dim_index, dim); },
-        data_);
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::Dim"));
+    return tensor_handle_data_->Dim(dim_index, dim);
   }
 }
 
@@ -428,9 +503,8 @@ Status TensorHandle::NumElements(int64* num_elements) const {
     *num_elements = inference_shape_.num_elements();
     return Status::OK();
   } else {
-    return absl::visit(
-        [num_elements](auto& data) { return data.NumElements(num_elements); },
-        data_);
+    TF_RETURN_IF_ERROR(WaitReady("TensorHandle::NumElements"));
+    return tensor_handle_data_->NumElements(num_elements);
   }
 }
 
@@ -438,8 +512,7 @@ Status TensorHandle::Unprotect(const Device* d) {
   DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
 
   if (d == absl::get<Device*>(device_)) {
-    auto& data = absl::get<LocalTensorHandleData>(data_);
-    return data.Unprotect();
+    return tensor_handle_data_->Unprotect();
   }
 
   tf_shared_lock l(mu_);
@@ -451,7 +524,11 @@ Status TensorHandle::Unprotect(const Device* d) {
 
   // Check if the handle is non-empty
   auto& mirror = elem->second;
-  return mirror.Unprotect();
+  if (mirror.second == nullptr) {
+    return errors::Internal("Attempted to unprotect an empty mirror");
+  }
+
+  return mirror.second->Unprotect();
 }
 
 bool TensorHandle::HasLocalMirror(const Device* d) const {
@@ -474,8 +551,8 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
     return errors::Internal("Attempted to duplicate a local mirror.");
   }
 
-  local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
-                         std::forward_as_tuple());
+  local_mirrors_[d] =
+      std::make_pair(std::make_unique<EmptyLocalTensorHandleData>(), nullptr);
 
   return Status::OK();
 }
@@ -490,8 +567,15 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
   tf_shared_lock l(mu_);
   auto mirror = remote_mirrors_.find(d->name());
   if (mirror != remote_mirrors_.end()) {
-    *op_id = mirror->second.op_id();
-    *output_num = mirror->second.output_num();
+    *op_id = mirror->second->op_id();
+    *output_num = mirror->second->output_num();
+    return Status::OK();
+  }
+
+  auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
+  if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
+    *op_id = unshaped_mirror->second->op_id();
+    *output_num = unshaped_mirror->second->output_num();
     return Status::OK();
   }
 
@@ -499,14 +583,14 @@ Status TensorHandle::RemoteAddress(const Device* d, int64* op_id,
         "Could not find remote mirror for specified device");
   }
 
-  if (!IsRemote()) {
-    return errors::InvalidArgument("Primary device is not remote");
+  if (remote_op_id_ == kInvalidOpId ||
+      remote_output_num_ == kInvalidOutputNum) {
+    return errors::InvalidArgument("Remote handle (op_id:", remote_op_id_,
+                                   ", output_num:", remote_output_num_,
+                                   ") is not set.");
   }
-
-  auto& data = absl::get<RemoteTensorHandleData>(data_);
-  *op_id = data.op_id();
-  *output_num = data.output_num();
+  *op_id = remote_op_id_;
+  *output_num = remote_output_num_;
 
   return Status::OK();
 }
 
@@ -519,7 +603,16 @@ bool TensorHandle::HasRemoteMirror(const Device* d,
   auto mirror = remote_mirrors_.find(d->name());
   if (mirror != remote_mirrors_.end()) {
     // Check if mirror is stale
-    if (mirror->second.context_view_id() != context_view_id) {
+    if (mirror->second->context_view_id() != context_view_id) {
+      return false;
+    }
+    return true;
+  }
+
+  auto unshaped_mirror = unshaped_remote_mirrors_.find(d->name());
+  if (unshaped_mirror != unshaped_remote_mirrors_.end()) {
+    // Check if mirror is stale
+    if (unshaped_mirror->second->context_view_id() != context_view_id) {
       return false;
     }
     return true;
@@ -537,7 +630,7 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
   auto mirror = resource_shape_mirrors_.find(d->name());
   if (mirror != resource_shape_mirrors_.end()) {
     // Check if mirror is stale
-    if (mirror->second.context_view_id() != context_view_id) {
+    if (mirror->second->context_view_id() != context_view_id) {
       return false;
     }
     return true;
@@ -545,39 +638,45 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d,
   return false;
 }
 
-Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64 op_id,
-                                             int output_num,
-                                             const string& remote_task,
-                                             EagerContext* ctx) {
+Status TensorHandle::AddUnshapedRemoteMirror(
+    std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
   DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this
-           << " device: " << d << " " << d->name() << " op_id: " << op_id
-           << " output_num: " << output_num;
+           << " device: " << d << " " << d->name();
 
   mutex_lock l(mu_);
   auto remote_mirror = remote_mirrors_.find(d->name());
   if (remote_mirror != remote_mirrors_.end()) {
-    if (remote_mirror->second.context_view_id() == ctx->GetContextId()) {
+    if (remote_mirror->second->context_view_id() == t->context_view_id()) {
       return errors::Internal("Attempted to duplicate a remote mirror.");
     }
     // Remove stale mirror
     remote_mirrors_.erase(remote_mirror);
   }
 
-  remote_mirrors_.emplace(
-      std::piecewise_construct, std::forward_as_tuple(d->name()),
-      std::forward_as_tuple(op_id, output_num, remote_task, ctx));
+  auto unshaped_remote_mirror = unshaped_remote_mirrors_.find(d->name());
+  if (unshaped_remote_mirror != unshaped_remote_mirrors_.end()) {
+    if (unshaped_remote_mirror->second->context_view_id() ==
+        t->context_view_id()) {
+      return errors::Internal(
+          "Attempted to duplicate an unshaped remote mirror.");
+    }
+    // Remove stale mirror
+    unshaped_remote_mirrors_.erase(unshaped_remote_mirror);
+  }
+
+  unshaped_remote_mirrors_[d->name()] = std::move(t);
 
   return Status::OK();
 }
 
-Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
-                                            int output_num, EagerContext* ctx) {
+Status TensorHandle::AddResourceShapeMirror(
+    std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d) {
   DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this;
 
   mutex_lock l(mu_);
   auto mirror = resource_shape_mirrors_.find(d->name());
   if (mirror != resource_shape_mirrors_.end()) {
-    if (mirror->second.context_view_id() == ctx->GetContextViewId()) {
+    if (mirror->second->context_view_id() == t->context_view_id()) {
       return errors::Internal(
           "Attempted to duplicate a resource shape mirror.");
     }
@@ -585,9 +684,26 @@ Status TensorHandle::AddResourceShapeMirror(const Device* d, int64 op_id,
     resource_shape_mirrors_.erase(mirror);
   }
 
-  resource_shape_mirrors_.emplace(
-      std::piecewise_construct, std::forward_as_tuple(d->name()),
-      std::forward_as_tuple(op_id, output_num, ctx->GetContextViewId()));
+  resource_shape_mirrors_[d->name()] = std::move(t);
+
+  return Status::OK();
+}
+
+Status TensorHandle::AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
+                                     const Device* d) {
+  DVLOG(3) << "AddRemoteMirror on TensorHandle: " << this << " device: " << d;
+
+  mutex_lock l(mu_);
+  auto mirror = remote_mirrors_.find(d->name());
+  if (mirror != remote_mirrors_.end()) {
+    if (mirror->second->context_view_id() == t->context_view_id()) {
+      return errors::Internal("Attempted to duplicate a remote mirror.");
+    }
+    // Remove stale mirror
+    remote_mirrors_.erase(mirror);
+  }
+
+  remote_mirrors_[d->name()] = std::move(t);
 
   return Status::OK();
 }
@@ -601,24 +717,53 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d,
     mutex_lock l(mu_);
     auto remote_mirror = remote_mirrors_.find(d->name());
     if (remote_mirror != remote_mirrors_.end()) {
-      auto& mirror = remote_mirror->second;
-      if (mirror.context_view_id() == context_view_id) {
-        return mirror.SetShape(shape);
+      if (remote_mirror->second->context_view_id() == context_view_id) {
+        return errors::Internal(
+            "Attempted to set remote shape for existing mirror.");
       }
       remote_mirrors_.erase(remote_mirror);
     }
 
+    auto elem = unshaped_remote_mirrors_.find(d->name());
+    if (elem == unshaped_remote_mirrors_.end()) {
+      return errors::Internal(
+          "Attempted to set remote shape for non-waiting mirror.");
+    }
+
+    if (elem->second->context_view_id() != context_view_id) {
+      unshaped_remote_mirrors_.erase(elem);
+      return errors::Internal(
+          "Attempted to set remote shape for a stale mirror.");
+    }
+
+    auto& data = elem->second;
+    data->ReleaseRemoteTensorHandle();
+    remote_mirrors_[d->name()] = absl::make_unique<RemoteTensorHandleData>(
+        data->op_id(), data->output_num(), shape, data->remote_task(),
+        &data->ctx());
+    unshaped_remote_mirrors_.erase(elem);
+
     return Status::OK();
   }
 
-  DCHECK(IsRemote()) << "SetRemoteShape is only called on remote handles.";
+  DCHECK(is_remote_) << "SetRemoteShape is only called on remote handles.";
+  DCHECK(!IsReady()) << "SetRemoteShape is only called on non-ready handles.";
 
-  auto& data = absl::get<RemoteTensorHandleData>(data_);
-  if (data.context_view_id() != context_view_id) {
+  UnshapedRemoteTensorHandleData* p =
+      reinterpret_cast<UnshapedRemoteTensorHandleData*>(
+          tensor_handle_data_.get());
+  if (p->context_view_id() != context_view_id) {
     return errors::Internal("Attempted to set remote shape for an old handle.");
   }
 
-  return data.SetShape(shape);
+  p->ReleaseRemoteTensorHandle();
+  tensor_handle_data_ = absl::make_unique<RemoteTensorHandleData>(
+      remote_op_id_, remote_output_num_, shape, p->remote_task(), ctx_);
+  is_poisoned_ = Status::OK();
+  mutex_lock l(mu_);
+  is_ready_ = true;
+
+  return Status::OK();
 }
 
 void TensorHandle::PoisonRemote(Status status, const Device* d,
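Note: SetRemoteShape now promotes a waiting UnshapedRemoteTensorHandleData
into a shaped RemoteTensorHandleData in place, then marks the handle ready.
A hedged, standalone sketch of that two-phase promotion follows; the types
and the Promote name are simplified stand-ins, not the TensorFlow classes.

    #include <memory>

    struct Shape { long long num_elements = 0; };

    struct Unshaped {
      long long op_id;
      int output_num;
      // In the real class, ReleaseRemoteTensorHandle() stops the destructor
      // from freeing the remote tensor the shaped replacement will own.
      void ReleaseRemoteTensorHandle() {}
    };

    struct Shaped {
      long long op_id;
      int output_num;
      Shape shape;
    };

    std::unique_ptr<Shaped> Promote(std::unique_ptr<Unshaped> u, Shape shape) {
      u->ReleaseRemoteTensorHandle();  // keep the remote tensor alive
      auto s = std::make_unique<Shaped>();
      s->op_id = u->op_id;  // carry the remote address over
      s->output_num = u->output_num;
      s->shape = shape;
      return s;  // caller installs this and then sets is_ready_ = true
    }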
@@ -627,16 +772,18 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
            << " " << d->name();
 
   if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
-    DCHECK(IsRemote()) << "Poison can only be on remote handles: " << this;
+    DCHECK(!is_async_ || !IsReady())
+        << "PoisonRemote can only be called on non-ready handle: " << this;
 
-    auto& data = absl::get<RemoteTensorHandleData>(data_);
-    data.Poison(status);
+    is_poisoned_ = status;
+    mutex_lock l(mu_);
+    is_ready_ = true;
   } else {
     tf_shared_lock l(mu_);
-    auto mirror = remote_mirrors_.find(d->name());
-    if (mirror != remote_mirrors_.end()) {
-      if (mirror->second.context_view_id() == context_view_id) {
-        mirror->second.Poison(status);
+    auto mirror = unshaped_remote_mirrors_.find(d->name());
+    if (mirror != unshaped_remote_mirrors_.end()) {
+      if (mirror->second->context_view_id() == context_view_id) {
+        mirror->second->Poison(status);
       }
     }
   }
@@ -651,9 +798,9 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
   }
 
   mutex_lock l(mu_);
-  auto elem =
-      local_mirrors_.emplace(std::piecewise_construct, std::forward_as_tuple(d),
-                             std::forward_as_tuple(std::move(tensor)));
+  auto elem = local_mirrors_.insert(std::make_pair(
+      d, std::make_pair(nullptr,
+                        std::make_unique<LocalTensorHandleData>(tensor))));
   if (!elem.second) {
     return errors::Internal("Attempted to set tensor for existing mirror.");
   }
@@ -661,18 +808,24 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
   return Status::OK();
 }
 
-Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
+Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
   DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;
 
   if (d == absl::get<Device*>(device_)) {
-    DCHECK(!IsRemote()) << "SetTensor is not called on remote handles.";
+    DCHECK(!is_remote_) << "SetTensor is not called on remote handles.";
+    DCHECK(!is_async_ || !IsReady())
+        << "SetTensor is only called on non-ready handles.";
 
-    if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
-      auto& resource_handle = t.flat<class ResourceHandle>()(0);
+    if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
+      auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
       handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
     }
-    auto& data = absl::get<LocalTensorHandleData>(data_);
-    return data.SetTensor(std::move(t));
+    tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
+    if (is_async_) {
+      is_poisoned_ = Status::OK();
+      mutex_lock l(mu_);
+      is_ready_ = true;
+    }
   } else {
     tf_shared_lock l(mu_);
     auto elem = local_mirrors_.find(d);
@@ -682,7 +835,12 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
     }
 
     auto& mirror = elem->second;
-    return mirror.SetTensor(std::move(t));
+    if (mirror.second != nullptr) {
+      return errors::Internal("Attempted to set tensor for existing mirror.");
+    }
+
+    mirror.second = absl::make_unique<LocalTensorHandleData>(tensor);
+    mirror.first->SetReady();
   }
 
   return Status::OK();
@@ -692,10 +850,12 @@ void TensorHandle::Poison(Status status, const Device* d) {
   DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;
 
   if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
-    DCHECK(!IsRemote()) << "Poison can only be on local handles: " << this;
+    DCHECK(!is_async_ || !IsReady())
+        << "Poison can only be called on non-ready handle: " << this;
 
-    auto& data = absl::get<LocalTensorHandleData>(data_);
-    data.Poison(status);
+    is_poisoned_ = status;
+    mutex_lock l(mu_);
+    is_ready_ = true;
   } else {
     tf_shared_lock l(mu_);
     auto elem = local_mirrors_.find(d);
@@ -704,7 +864,9 @@ void TensorHandle::Poison(Status status, const Device* d) {
         << " device: " << d;
 
     auto& mirror = elem->second;
-    mirror.Poison(status);
+    DCHECK(mirror.second == nullptr) << "Attempted to poison existing mirror.";
+
+    mirror.first->Poison(status);
   }
 }
 
@@ -815,11 +977,8 @@ string TensorHandle::DebugString() const {
       !VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull;
   // Consider supporting non-CPU tensors and CPU tensors with a device_ set to
   // non-NULL if needed.
-  strings::StrAppend(
-      &out, ", Tensor: ",
-      is_cpu ? absl::visit([](auto& data) { return data.DebugString(); }, data_)
-             : "?",
-      "\n");
+  strings::StrAppend(&out, ", Tensor: ",
+                     is_cpu ? tensor_handle_data_->DebugString() : "?", "\n");
   return out;
 }
 
@@ -17,10 +17,10 @@ limitations under the License.
 
 #include <algorithm>
 #include <cstddef>
+#include <map>
 #include <memory>
 #include <queue>
 #include <string>
-#include <unordered_map>
 #include <vector>
 
 // clang-format off
@ -32,20 +32,28 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/types/variant.h"
|
#include "absl/types/variant.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
|
||||||
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
|
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
#if !defined(IS_MOBILE_PLATFORM)
|
#if !defined(IS_MOBILE_PLATFORM)
|
||||||
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||||
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
|
#include "tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h"
|
||||||
#endif // IS_MOBILE_PLATFORM
|
#endif // IS_MOBILE_PLATFORM
|
||||||
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/platform/fingerprint.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/platform/notification.h"
|
||||||
#include "tensorflow/core/platform/thread_annotations.h"
|
#include "tensorflow/core/platform/thread_annotations.h"
|
||||||
|
#include "tensorflow/core/public/session_options.h"
|
||||||
|
#include "tensorflow/core/public/version.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@@ -59,45 +67,56 @@ class TensorHandle : public core::RefCounted {
   using VariantDevice = absl::variant<Device*, CustomDevice*>;

   // TensorHandle for dtype != DT_RESOURCE
-  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
-               Device* resource_device, EagerContext* ctx);
+  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
+               Device* d, Device* op_device, EagerContext* ctx);
   // TensorHandle for dtype == DT_RESOURCE
-  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
-               EagerContext* ctx);
-  TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx);
-  TensorHandle(Device* d, Device* op_device, Device* resource_device,
+  TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
+               const ResourceHandle& resource_handle, Device* d,
+               Device* op_device, EagerContext* ctx);
+  TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
+               CustomDevice* d, EagerContext* ctx);
+  TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
+               Device* d, Device* op_device, Device* resource_device,
                DataType dtype, EagerContext* ctx);

 #if !defined(IS_MOBILE_PLATFORM)
-  TensorHandle(int64 op_id, int32 output_num, const string& remote_task,
+  TensorHandle(std::unique_ptr<RemoteTensorHandleData> t, DataType dtype,
+               Device* d, Device* resource_device, EagerContext* ctx);
+  TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
                DataType dtype, Device* device, EagerContext* ctx);
-  TensorHandle(int64 op_id, int32 output_num, DataType dtype, Device* device,
-               EagerContext* ctx);
 #endif  // IS_MOBILE_PLATFORM

  public:
   // TensorHandle with no assigned device
-  static Status CreateLocalHandle(const tensorflow::Tensor& t,
-                                  TensorHandle** h);
-  static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
+  static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
+  // TensorHandle with device == op_device
+  static Status CreateLocalHandle(const class Tensor& t, Device* d,
+                                  EagerContext* ctx, TensorHandle** h);
+  static Status CreateLocalHandle(const class Tensor& t, Device* d,
                                   Device* op_device, EagerContext* ctx,
                                   TensorHandle** h);
-  static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
-                                  Device* op_device, Device* resource_device,
+  static Status CreateLocalHandle(const class Tensor& t, CustomDevice* d,
                                   EagerContext* ctx, TensorHandle** h);
-  static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
-                                  EagerContext* ctx, TensorHandle** h);
-  static Status CreateEmptyLocalHandle(Device* d, Device* op_device,
+  static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
                                        Device* resource_device, DataType dtype,
                                        EagerContext* ctx, TensorHandle** h);
 #if !defined(IS_MOBILE_PLATFORM)
+  static Status CreateRemoteHandle(int64 op_id, int output_num,
+                                   const TensorShape& shape,
+                                   const string& remote_task, DataType dtype,
+                                   Device* d, Device* resource_device,
+                                   EagerContext* ctx, TensorHandle** h);
+  static Status CreateRemoteHandle(std::unique_ptr<RemoteTensorHandleData> t,
+                                   DataType dtype, Device* d,
+                                   Device* resource_device, EagerContext* ctx,
+                                   TensorHandle** h);
   static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
                                            const string& remote_task,
-                                           DataType dtype, Device* d,
+                                           DataType dtype, Device* device,
                                            EagerContext* ctx, TensorHandle** h);
-  static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num,
-                                       DataType dtype, Device* d,
-                                       EagerContext* ctx, TensorHandle** h);
+  static Status CreateUnshapedRemoteHandle(
+      std::unique_ptr<UnshapedRemoteTensorHandleData> t, DataType dtype,
+      Device* device, EagerContext* ctx, TensorHandle** h);
 #endif  // IS_MOBILE_PLATFORM

   ~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }
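All of the Create* overloads above follow the same Status-plus-out-parameter factory convention: a private constructor, a static factory that validates inputs, and a raw result pointer handed back on success. A minimal sketch of that convention under assumed, simplified types (Status and Handle here are stand-ins, not the TensorFlow classes):

#include <iostream>
#include <string>

// Stand-in for tensorflow::Status: empty error string means OK.
struct Status {
  std::string error;
  bool ok() const { return error.empty(); }
};

class Handle {
 public:
  // Factory validates inputs; the result is returned via the out-parameter.
  static Status Create(int dtype, Handle** out) {
    if (dtype < 0) return {"invalid dtype"};
    *out = new Handle(dtype);  // Real code hands ownership to a refcount.
    return {};
  }

 private:
  explicit Handle(int dtype) : dtype_(dtype) {}
  int dtype_;
};

int main() {
  Handle* h = nullptr;
  Status s = Handle::Create(7, &h);
  std::cout << (s.ok() ? "created" : s.error) << "\n";
  delete h;
}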
@@ -112,7 +131,7 @@ class TensorHandle : public core::RefCounted {
   // Return the TensorValue from the specified device which could be either the
   // default device or a local mirror. The device pointer should be nullptr if
   // requesting the HostCPU.
-  Status TensorValue(const Device* d, tensorflow::TensorValue* t);
+  Status TensorValue(tensorflow::TensorValue* t, const Device* d);

   VariantDevice device() const { return device_; }
   Device* op_device() const { return op_device_; }
@@ -142,10 +161,12 @@ class TensorHandle : public core::RefCounted {
   bool HasRemoteMirror(const Device* d, uint64 context_view_id) const;
   bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const;

-  Status AddUnshapedRemoteMirror(const Device* d, int64 op_id, int output_num,
-                                 const string& remote_task, EagerContext* ctx);
-  Status AddResourceShapeMirror(const Device* d, int64 op_id, int output_num,
-                                EagerContext* ctx);
+  Status AddUnshapedRemoteMirror(
+      std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);
+  Status AddRemoteMirror(std::unique_ptr<RemoteTensorHandleData> t,
+                         const Device* d);
+  Status AddResourceShapeMirror(
+      std::unique_ptr<UnshapedRemoteTensorHandleData> t, const Device* d);

   // Return the op_id and output num if the handle refers to a remote tensor.
   Status RemoteAddress(const Device* d, int64* op_id, int32* output_num) const;
@@ -191,12 +212,14 @@ class TensorHandle : public core::RefCounted {
   Status CopyInferenceShape(TensorHandle* other);

   // Warning: can return nullptr for CPU tensors.
+  // TODO(b/136608821): Move away from nullptr
   EagerContext* Context() { return ctx_; }

   // dtype for the handle. It must be the same as t.dtype() once the handle is
   // ready.
   const DataType dtype;

+  // TODO(b/136608821): Move away from nullptr
   bool OnHostCPU() const {
     return (
         device_.index() == 0 &&
@@ -204,14 +227,14 @@ class TensorHandle : public core::RefCounted {
         (ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_))));
   }

-  bool IsRemote() const;
+  bool IsRemote() const { return is_remote_; }
   void EnableImplicitMirroring() { implicit_mirroring_ = true; }
   bool ImplicitMirroring() const { return implicit_mirroring_; }

   string DebugString() const;

   void SetResourceHandleDtypeAndShape(
-      std::vector<DtypeAndPartialTensorShape>&& dtypes_and_shapes);
+      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);

   // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
   // return data types and shapes of the underlying resource.
@@ -225,6 +248,19 @@ class TensorHandle : public core::RefCounted {
   // with a ready version of the tensor handle data.
   bool IsReady() const;

+  // If the contents of the Tensor pointed to by this handle is yet to be
+  // computed by a EagerNode, this function will block till that computation is
+  // done and the handle is "ready".
+  Status WaitReady(const char* caller) const;
+
+  // TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
+  // This was expedient, but perhaps worth revisiting ('device_' should always
+  // be a valid pointer?)
+  // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
+  // provided with the appropriate TFE_Context.
+  //
+  // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
+  // TFE_TensorHandle does not outlive the TFE_Context from which it came?
   VariantDevice const device_;

   // Device in which the op producing this tensor was executed. Equals to
@@ -239,33 +275,47 @@ class TensorHandle : public core::RefCounted {

   mutable mutex mu_;

-  // Map of local mirrors. This can include both ready and non-ready mirrors.
-  std::unordered_map<const tensorflow::Device*, LocalTensorHandleData>
+  // Map of local mirrors. In sync mode the EmptyLocalTensorHandleData is
+  // nullptr. In async mode, we use the EmptyLocalTensorHandleData to manage
+  // waiting clients. Once the EmptyLocalTensorHandleData is "ready" only the
+  // LocalTensorHandleData should be used.
+  std::map<const tensorflow::Device*,
+           std::pair<std::unique_ptr<EmptyLocalTensorHandleData>,
+                     std::unique_ptr<LocalTensorHandleData>>>
       local_mirrors_ GUARDED_BY(mu_);
 #if !defined(IS_MOBILE_PLATFORM)
   // TODO(yujingzhang): Remove resource_shape_mirrors_ once scalable per-replica
   // variable is ready, since we could get the shape locally without remote copy
   // then.
-  std::unordered_map<string, RemoteTensorHandleData> resource_shape_mirrors_
-      GUARDED_BY(mu_);
+  std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>>
+      resource_shape_mirrors_ GUARDED_BY(mu_);
+  // TODO(gjn): Unshaped remote mirrors are not expected to be long-lived.
+  // Consider replacing the unshaped_remote_mirrors_ map with something more
+  // efficient.
+  std::map<string, std::unique_ptr<UnshapedRemoteTensorHandleData>>
+      unshaped_remote_mirrors_ GUARDED_BY(mu_);
   // TODO(gjn): Is std::map the most optimal choice here? Perhaps this should be
   // a fixed size map.
-  std::unordered_map<string, RemoteTensorHandleData> remote_mirrors_
+  std::map<string, std::unique_ptr<RemoteTensorHandleData>> remote_mirrors_
       GUARDED_BY(mu_);

+  // IDs required when this class is representing a remote tensor handle.
+  int64 remote_op_id_;
+  int32 remote_output_num_;
 #endif

   // `ctx` is only guaranteed to be set if the handle is not "ready". This is
   // typically true when the handle was produced during async execution.
   // `ctx` object is not owned and should outlive this handle.
-  //
-  // TODO(b/150614042): Reference count EagerContext to ensure that 'device_' of
-  // a TensorHandle does not outlive the EagerContext from which it came?
   EagerContext* const ctx_;

   // Does not need synchronization because it can be accessed only after
   // WaitReady() has returned. At that point, is_poisoned_ is immutable.
   Status is_poisoned_;
+  const bool is_remote_;
+  const bool is_async_;
   bool implicit_mirroring_;
+  bool is_ready_ GUARDED_BY(mu_);

   // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
   // refers to a remote resource handle, we store data types and shapes for
@@ -273,12 +323,8 @@ class TensorHandle : public core::RefCounted {
   std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;

   // Does not need synchronization because it can be accessed only after
-  // WaitReady() has returned. At that point, data_ is immutable.
-#if !defined(IS_MOBILE_PLATFORM)
-  absl::variant<LocalTensorHandleData, RemoteTensorHandleData> data_;
-#else
-  absl::variant<LocalTensorHandleData> data_;
-#endif
+  // WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
+  std::unique_ptr<TensorHandleData> tensor_handle_data_;

   PartialTensorShape inference_shape_;
 };

@@ -23,16 +23,12 @@ namespace tensorflow {
 class Status;

 Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
-  TF_RETURN_IF_ERROR(WaitReady("Tensor"));
-
   *t = &tensor_;

   return Status::OK();
 }

 Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
-  TF_RETURN_IF_ERROR(WaitReady("TensorValue"));
-
   tensorflow::Tensor& tensor = tensor_;
   *t = tensorflow::TensorValue(&tensor);

@@ -40,96 +36,103 @@ Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
 }

 Status LocalTensorHandleData::Shape(TensorShape* shape) const {
-  TF_RETURN_IF_ERROR(WaitReady("Shape"));
-
   *shape = tensor_.shape();

   return Status::OK();
 }

 Status LocalTensorHandleData::NumDims(int* num_dims) const {
-  TF_RETURN_IF_ERROR(WaitReady("NumDims"));
-
   *num_dims = tensor_.dims();

   return Status::OK();
 }

 Status LocalTensorHandleData::Dim(int dim_index, int64* dim) const {
-  TF_RETURN_IF_ERROR(WaitReady("Dim"));
-
   *dim = tensor_.dim_size(dim_index);

   return Status::OK();
 }

 Status LocalTensorHandleData::NumElements(int64* num_elements) const {
-  TF_RETURN_IF_ERROR(WaitReady("NumElements"));
-
   *num_elements = tensor_.NumElements();

   return Status::OK();
 }

 Status LocalTensorHandleData::Unprotect() {
-  if (!IsReady()) {
-    return errors::Internal("Cannot unprotect a non-ready tensor");
-  }
-
   forwarding_protection_tensor_ = tensorflow::Tensor();

   return Status::OK();
 }

-Status LocalTensorHandleData::SetTensor(tensorflow::Tensor&& t) {
-  DCHECK(!IsReady()) << "SetTensor is only called on non-ready handles.";
-
-  tensor_ = std::move(t);
-  // Create copy of original tensor to avoid forwarding
-  forwarding_protection_tensor_ = tensor_;
-
-  auto& state = absl::get<BlockingControl>(ctrl_);
-  state.SetReady();
-
-  return Status::OK();
+Status EmptyLocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
+  return errors::Unavailable(
+      "Unable to get a tensor for an empty handle. "
+      "Please wait until it is ready");
 }

-string LocalTensorHandleData::DebugString() const {
-  if (IsReady()) {
-    return tensor_.DeviceSafeDebugString();
-  } else {
-    return "LocalTensorHandleData";
-  }
+Status EmptyLocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
+  return errors::Unavailable(
+      "Unable to get a tensor for an empty handle. "
+      "Please wait until it is ready");
 }

-void LocalTensorHandleData::BlockingControl::SetReady() {
+Status EmptyLocalTensorHandleData::Shape(TensorShape* shape) const {
+  return errors::Unavailable(
+      "Unable to get shape information for an empty handle. "
+      "Please wait until it is ready");
+}
+
+Status EmptyLocalTensorHandleData::NumDims(int* num_dims) const {
+  return errors::Unavailable(
+      "Unable to get shape information for an empty handle. "
+      "Please wait until it is ready");
+}
+
+Status EmptyLocalTensorHandleData::Dim(int dim_index, int64* dim) const {
+  return errors::Unavailable(
+      "Unable to get shape information for an empty handle. "
+      "Please wait until it is ready");
+}
+
+Status EmptyLocalTensorHandleData::NumElements(int64* num_elements) const {
+  return errors::Unavailable(
+      "Unable to get shape information for an empty handle. "
+      "Please wait until it is ready");
+}
+
+Status EmptyLocalTensorHandleData::Unprotect() {
+  return errors::Unavailable("Unable to unprotect an empty handle.");
+}
+
+bool EmptyLocalTensorHandleData::IsReady() const {
+  tf_shared_lock l(mu_);
+  return is_ready_;
+}
+
+void EmptyLocalTensorHandleData::SetReady() {
   mutex_lock l(mu_);
   is_ready_ = true;
 }

-Status LocalTensorHandleData::BlockingControl::WaitReady(
-    const char* caller) const {
-  tf_shared_lock l(mu_);
-  if (!is_ready_) {
-    profiler::TraceMe activity(
-        [caller] { return absl::StrCat(caller, " WaitReady"); },
-        profiler::TraceMeLevel::kInfo);
-    DVLOG(3) << "WaitReady: " << caller << " " << this;
+Status EmptyLocalTensorHandleData::WaitReady(const char* caller) const {
+  if (!IsReady()) {
+    profiler::TraceMe activity(absl::StrCat(caller, " WaitReady"),
+                               profiler::TraceMeLevel::kInfo);
+    tf_shared_lock l(mu_);
     mu_.Await(Condition(&is_ready_));
   }

   return is_poisoned_;
 }

-void LocalTensorHandleData::BlockingControl::Poison(Status status) {
-  mutex_lock l(mu_);
-  if (is_ready_) {
-    LOG(ERROR) << "Poison can only be called on non-ready handle: " << this;
-    return;
-  }
+void EmptyLocalTensorHandleData::Poison(Status status) {
   is_poisoned_ = status;
+  mutex_lock l(mu_);
   is_ready_ = true;
 }

+string EmptyLocalTensorHandleData::DebugString() const {
+  return "EmptyLocalTensorHandleData";
+}
+
 }  // namespace tensorflow
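The WaitReady/SetReady/Poison trio above is a classic readiness latch: a poisoned handle is also marked ready so waiters unblock and then observe the stored error. A sketch of the same latch using the standard library in place of tensorflow::mutex and mutex::Await (an error string stands in for Status; empty means OK):

#include <condition_variable>
#include <mutex>
#include <string>
#include <utility>

class ReadyLatch {
 public:
  void SetReady() {
    std::lock_guard<std::mutex> l(mu_);
    ready_ = true;
    cv_.notify_all();
  }
  void Poison(std::string error) {
    std::lock_guard<std::mutex> l(mu_);
    error_ = std::move(error);
    ready_ = true;  // Unblocks waiters, who then see the error.
    cv_.notify_all();
  }
  // Blocks until ready, then returns the poison status ("" == OK).
  std::string WaitReady() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return ready_; });
    return error_;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool ready_ = false;
  std::string error_;
};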
@@ -15,50 +15,52 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_

-#include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/status.h"

 namespace tensorflow {

-// Local Tensor Handle: Handle to a Tensor present on the local host.
-class LocalTensorHandleData {
+class TensorHandleData {
  public:
-  LocalTensorHandleData() : ctrl_(absl::in_place_type<BlockingControl>) {}
-  explicit LocalTensorHandleData(tensorflow::Tensor&& t)
-      : tensor_(std::move(t)),
-        forwarding_protection_tensor_(tensor_),
-        ctrl_(absl::in_place_type<NonBlockingControl>) {}
+  virtual ~TensorHandleData() {}
+
+  // Different tensor handles support a set of these calls. In some cases these
+  // are resolved with a Tensor or TensorShape. Typically if the handle is not
+  // ready, none of these are supported operations.
+  virtual Status Tensor(const tensorflow::Tensor** t) const = 0;
+  virtual Status TensorValue(tensorflow::TensorValue* t) = 0;
+  virtual Status Shape(TensorShape* shape) const = 0;
+  virtual Status NumDims(int* num_dims) const = 0;
+  virtual Status Dim(int dim_index, int64* dim) const = 0;
+  virtual Status NumElements(int64* num_elements) const = 0;
+  // Allow the backing Tensor to be available for buffer reuse during op
+  // execution.
+  virtual Status Unprotect() = 0;
+
+  virtual string DebugString() const = 0;
+};
+
+// Local Tensor Handle: Handle to a Tensor present on the local host.
+class LocalTensorHandleData : public TensorHandleData {
+ public:
+  explicit LocalTensorHandleData(const tensorflow::Tensor& t)
+      : tensor_(t), forwarding_protection_tensor_(t) {}
+  ~LocalTensorHandleData() override {}

   // A local tensor handle should be able to satisfy all of these requests.
-  Status Tensor(const tensorflow::Tensor** t) const;
-  Status TensorValue(tensorflow::TensorValue* t);
-  Status Shape(TensorShape* shape) const;
-  Status NumDims(int* num_dims) const;
-  Status Dim(int dim_index, int64* dim) const;
-  Status NumElements(int64* num_elements) const;
-  Status Unprotect();
+  Status Tensor(const tensorflow::Tensor** t) const override;
+  Status TensorValue(tensorflow::TensorValue* t) override;
+  Status Shape(TensorShape* shape) const override;
+  Status NumDims(int* num_dims) const override;
+  Status Dim(int dim_index, int64* dim) const override;
+  Status NumElements(int64* num_elements) const override;
+  Status Unprotect() override;

-  bool IsReady() const {
-    return absl::visit([](auto& data) { return data.IsReady(); }, ctrl_);
+  string DebugString() const override {
+    return tensor_.DeviceSafeDebugString();
   }

-  Status WaitReady(const char* caller) const {
-    return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
-                       ctrl_);
-  }
-  void Poison(Status status) {
-    return absl::visit([status](auto& data) { data.Poison(status); }, ctrl_);
-  }
-  Status IsPoisoned() const {
-    return absl::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_);
-  }
-
-  Status SetTensor(tensorflow::Tensor&& t);
-
-  string DebugString() const;
-
  private:
   tensorflow::Tensor tensor_;
   // TensorHandle has its own reference counting which is distinct from the
@@ -68,41 +70,37 @@ class LocalTensorHandleData {
   // forwarding_protection_tensor_ Tensor. When Unprotect() is called, we
   // release this Tensor to allow forwarding.
   tensorflow::Tensor forwarding_protection_tensor_;
+};

-  // We distinguish between ready and empty tensors with the ctrl_ variant.
-  // which contains 2 implementations of the waiting logic. The
-  // NonBlockingControl is a simple no-op class whereas the BlockingControl
-  // actually uses a mutex. By using a variant we avoid the overhead of
-  // constructing and destructing the mutex for ready local tensors.
-  class NonBlockingControl {
-   public:
-    bool IsReady() const { return true; }
-    Status WaitReady(const char* caller) const { return Status::OK(); }
-    void Poison(Status status) {}
-    Status IsPoisoned() const { return Status::OK(); }
-  };
+// Empty Local Tensor Handle: Once the execution is complete this is replaced by
+// a local tensor handle.
+class EmptyLocalTensorHandleData : public TensorHandleData {
+ public:
+  EmptyLocalTensorHandleData() {}
+  ~EmptyLocalTensorHandleData() override {}

-  class BlockingControl {
-   public:
-    bool IsReady() const {
-      tf_shared_lock l(mu_);
-      return is_ready_;
-    }
-    void SetReady();
-    Status WaitReady(const char* caller) const;
-    void Poison(Status status);
-    Status IsPoisoned() const {
-      tf_shared_lock l(mu_);
-      return is_poisoned_;
-    }
-
-   private:
-    mutable mutex mu_;
-    bool is_ready_ GUARDED_BY(mu_);
-    Status is_poisoned_ GUARDED_BY(mu_);
-  };
-
-  absl::variant<NonBlockingControl, BlockingControl> ctrl_;
+  // Empty tensor handles are not ready and hence cannot satisfy any of these
+  // requests.
+  Status Tensor(const tensorflow::Tensor** t) const override;
+  Status TensorValue(tensorflow::TensorValue* t) override;
+  Status Shape(TensorShape* shape) const override;
+  Status NumDims(int* num_dims) const override;
+  Status Dim(int dim_index, int64* dim) const override;
+  Status NumElements(int64* num_elements) const override;
+  Status Unprotect() override;
+
+  bool IsReady() const;
+  void SetReady();
+  Status WaitReady(const char* caller) const;
+  void Poison(Status status);
+  Status IsPoisoned() const { return is_poisoned_; }
+
+  string DebugString() const override;
+
+ private:
+  mutable mutex mu_;
+  bool is_ready_ GUARDED_BY(mu_);
+  Status is_poisoned_;
 };

 }  // namespace tensorflow
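The removed comment above explains the variant trick being rolled back: tensors that are born ready use a stateless no-op control, while pending tensors pay for a mutex-backed one, and variant dispatch avoids heap allocation and virtual calls. A minimal sketch of that idea with std::variant standing in for absl::variant (names are illustrative, and the waiting bodies are elided to keep the example small):

#include <variant>

struct NonBlockingControl {
  bool IsReady() const { return true; }  // Always ready; nothing to wait on.
};

struct BlockingControl {
  bool ready = false;  // Real code guards this with a mutex.
  bool IsReady() const { return ready; }
};

using Control = std::variant<NonBlockingControl, BlockingControl>;

bool IsReady(const Control& c) {
  // Dispatch without virtual functions or heap allocation.
  return std::visit([](const auto& ctrl) { return ctrl.IsReady(); }, c);
}

int main() {
  Control ready_tensor{NonBlockingControl{}};
  Control pending_tensor{BlockingControl{}};
  return IsReady(ready_tensor) && !IsReady(pending_tensor) ? 0 : 1;
}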
@@ -39,13 +39,12 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
       &device_mgr, false, nullptr, nullptr, nullptr);
   TensorHandle* sync_th;
-  EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr,
-                                              ctx, &sync_th)
-                  .ok());
+  EXPECT_TRUE(
+      TensorHandle::CreateLocalHandle(t, ctx->HostCPU(), ctx, &sync_th).ok());
   TensorHandle* async_th;
-  EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr,
-                                                   DataType::DT_UINT16, ctx,
-                                                   &async_th)
+  EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(true, nullptr, nullptr,
+                                                   nullptr, DataType::DT_UINT16,
+                                                   ctx, &async_th)
                   .ok());

   EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());
@@ -190,10 +190,8 @@ cc_library(
     deps = [
         ":destroy_tensor_handle_node",
         ":eager_client",
-        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
-        "//tensorflow/core/common_runtime/eager:context",
-        "//tensorflow/core/profiler/lib:traceme",
+        "//tensorflow/core/common_runtime/eager:tensor_handle_data",
    ],
 )

@@ -530,8 +530,7 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
   }

   TensorHandle* tensor_handle = nullptr;
-  TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
-      std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle));
+  TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(tensor, &tensor_handle));
   TensorHandle* copied_handle = nullptr;
   Device* device;
   TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
@@ -101,10 +101,12 @@ Status RemoteCopyNode::RunLocalSend(EagerOperation* op) {
   core::RefCountPtr<KernelAndDevice> kernel;
   TF_RETURN_IF_ERROR(CreateUncachedKernelAndDeviceOp(op, &kernel));

-  EagerKernelArgs args(1);
-  Device* d = ctx_->CanonicalDevice(absl::get<Device*>(op->Device()));
-  TF_RETURN_IF_ERROR(src_->TensorValue(d, args.MutableInput(0)));
+  gtl::InlinedVector<TensorValue, 4> input_vector(1);
+  TF_RETURN_IF_ERROR(src_->TensorValue(
+      &input_vector[0],
+      ctx_->CanonicalDevice(absl::get<Device*>(op->Device()))));
+
+  EagerKernelArgs args(std::move(input_vector));
   return kernel->Run(args, /*outputs=*/nullptr,
                      /*cancellation_manager=*/nullptr,
                      /*remote_func_params=*/absl::nullopt);
@@ -162,8 +162,16 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
       in.op_device().empty() ? in.device() : in.op_device();
   TF_RETURN_IF_ERROR(
       parent_->FindDeviceFromName(device_name.c_str(), &device));
-  TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
-      in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
+  string remote_task;
+  if (!DeviceNameUtils::GetTaskName(device->parsed_name(), &remote_task)) {
+    return errors::InvalidArgument(
+        "Unable to find remote task corresponding to device ", device_name);
+  }
+  auto remote_handle_data = absl::make_unique<UnshapedRemoteTensorHandleData>(
+      in.op_id(), in.output_num(), remote_task, parent_);
+  remote_handle_data->ReleaseRemoteTensorHandle();
+  TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
+      std::move(remote_handle_data), in.dtype(), device, parent_, out));
   std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
   if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
                                 &dtypes_and_shapes)
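The new deserialization path above derives the owning task from the device name before building the unshaped handle data. A simplified stand-in for that lookup follows; this is not the real DeviceNameUtils parser, it only assumes the well-known "/job:.../replica:.../task:.../device:..." layout:

#include <iostream>
#include <string>

// Reduce a fully qualified device name such as
// "/job:worker/replica:0/task:1/device:CPU:0" to its task prefix, which
// identifies the remote task that owns the handle.
bool GetTaskName(const std::string& device_name, std::string* task) {
  const std::string marker = "/device:";
  auto pos = device_name.find(marker);
  if (pos == std::string::npos || device_name.rfind("/job:", 0) != 0)
    return false;  // Not a fully qualified device name.
  *task = device_name.substr(0, pos);
  return true;
}

int main() {
  std::string task;
  if (GetTaskName("/job:worker/replica:0/task:1/device:CPU:0", &task))
    std::cout << task << "\n";  // Prints /job:worker/replica:0/task:1
}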
@@ -71,12 +71,14 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
   Tensor t(DT_FLOAT, TensorShape({0}));

   TensorHandle* handle;
-  TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_,
-                                               local_device_, ctx_, &handle));
+  TF_ASSERT_OK(
+      TensorHandle::CreateLocalHandle(t, local_device_, ctx_, &handle));
   const uint64 op_id = 2;
   const int output_num = 3;
-  TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
-                                               output_num, "", ctx_));
+  auto tensor_handle_data = absl::make_unique<RemoteTensorHandleData>(
+      op_id, output_num, t.shape(), /*remote_task=*/"", ctx_);
+  TF_ASSERT_OK(
+      handle->AddRemoteMirror(std::move(tensor_handle_data), remote_device_));
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, &remote_handle, remote_device_, remote_device_->name()));
@@ -88,13 +90,14 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {

 TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
   RemoteMgr remote_mgr(false, ctx_);
+  Tensor t(DT_FLOAT, TensorShape({0}));

   const uint64 op_id = 3;
   const int output_num = 1;
   TensorHandle* handle;
-  TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle(
-      op_id, output_num,
-      /*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle));
+  TF_ASSERT_OK(TensorHandle::CreateRemoteHandle(
+      op_id, output_num, t.shape(), /*remote_task=*/"", DT_FLOAT,
+      remote_device_, /*resource_device=*/nullptr, ctx_, &handle));
   RemoteTensorHandle remote_handle;
   TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
       handle, &remote_handle, remote_device_, remote_device_->name()));
@@ -19,7 +19,6 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/profiler/lib/traceme.h"

 namespace tensorflow {

@@ -85,103 +84,66 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task,
 }  // namespace

 RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
-                                               uint64 context_view_id)
-    : is_ready_(false),
-      op_id_(op_id),
-      output_num_(output_num),
-      context_view_id_(context_view_id),
-      ctx_(nullptr) {
-  DCHECK(op_id_ >= 0 && output_num_ >= 0)
-      << "Op ID and output num should be >= 0. Op ID: " << op_id
-      << ", Output num: " << output_num;
-}
-
-RemoteTensorHandleData::RemoteTensorHandleData(int64 op_id, int output_num,
+                                               const TensorShape& shape,
                                                const string& remote_task,
                                                EagerContext* ctx)
-    : is_ready_(false),
-      op_id_(op_id),
+    : op_id_(op_id),
       output_num_(output_num),
+      shape_(shape),
       remote_task_(remote_task),
       context_id_(ctx->GetContextId()),
       context_view_id_(ctx->GetContextViewId()),
-      ctx_(ctx) {
+      ctx_(*ctx) {
   DCHECK(op_id_ >= 0 && output_num_ >= 0)
       << "Op ID and output num should be >= 0. Op ID: " << op_id
      << ", Output num: " << output_num;
-  ctx_->Ref();
+  ctx_.Ref();
 }

 RemoteTensorHandleData::~RemoteTensorHandleData() {
-  if (ctx_) {
-    DestroyRemoteTensorHandle(ctx_, remote_task_, context_id_, op_id_,
-                              output_num_, /*ready=*/true);
-    ctx_->Unref();
-  }
+  DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
+                            output_num_, /*ready=*/true);
+  ctx_.Unref();
+}
+
+Status RemoteTensorHandleData::Tensor(const tensorflow::Tensor** t) const {
+  return errors::Unavailable(
+      "Unable to get a tensor for a remote device. Please copy the tensor "
+      "handle to a local device using TFE_TensorHandleCopyToDevice");
+}
+
+Status RemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
+  return errors::Unavailable(
+      "Unable to get a tensor for a remote device. Please copy the tensor "
+      "handle to a local device using TFE_TensorHandleCopyToDevice");
 }

 Status RemoteTensorHandleData::Shape(TensorShape* shape) const {
-  TF_RETURN_IF_ERROR(WaitReady("Shape"));
-
-  tf_shared_lock l(mu_);
   *shape = shape_;

   return Status::OK();
 }

 Status RemoteTensorHandleData::NumDims(int* num_dims) const {
-  TF_RETURN_IF_ERROR(WaitReady("NumDims"));
-
-  tf_shared_lock l(mu_);
   *num_dims = shape_.dims();

   return Status::OK();
 }

 Status RemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
-  TF_RETURN_IF_ERROR(WaitReady("Dim"));
-
-  tf_shared_lock l(mu_);
   *dim = shape_.dim_size(dim_index);

   return Status::OK();
 }

 Status RemoteTensorHandleData::NumElements(int64* num_elements) const {
-  TF_RETURN_IF_ERROR(WaitReady("NumElements"));
-
-  tf_shared_lock l(mu_);
   *num_elements = shape_.num_elements();

   return Status::OK();
 }

-bool RemoteTensorHandleData::IsReady() const {
-  tf_shared_lock l(mu_);
-  return is_ready_;
-}
-
-void RemoteTensorHandleData::Poison(Status status) {
-  mutex_lock l(mu_);
-  is_poisoned_ = status;
-}
-
-Status RemoteTensorHandleData::IsPoisoned() const {
-  tf_shared_lock l(mu_);
-  return is_poisoned_;
-}
-
-Status RemoteTensorHandleData::SetShape(const TensorShape& shape) {
-  mutex_lock l(mu_);
-  if (is_ready_) {
-    return errors::Internal("SetShape is only called on non-ready handles.");
-  }
-
-  shape_ = shape;
-  is_poisoned_ = Status::OK();
-  is_ready_ = true;
-
-  return Status::OK();
+Status RemoteTensorHandleData::Unprotect() {
+  return errors::Unavailable("Unable to unprotect a remote handle.");
 }

 string RemoteTensorHandleData::DebugString() const {
|
|||||||
" output_num: ", output_num_);
|
" output_num: ", output_num_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RemoteTensorHandleData::WaitReady(const char* caller) const {
|
UnshapedRemoteTensorHandleData::UnshapedRemoteTensorHandleData(
|
||||||
if (ctx_ == nullptr) {
|
int64 op_id, int32 output_num, const string& remote_task, EagerContext* ctx)
|
||||||
return errors::Internal("Cannot wait on lazy remote handle");
|
: op_id_(op_id),
|
||||||
}
|
output_num_(output_num),
|
||||||
|
delete_remote_tensor_(true),
|
||||||
|
remote_task_(remote_task),
|
||||||
|
context_id_(ctx->GetContextId()),
|
||||||
|
context_view_id_(ctx->GetContextViewId()),
|
||||||
|
ctx_(*ctx) {
|
||||||
|
DCHECK(op_id_ >= 0 && output_num_ >= 0)
|
||||||
|
<< "Op ID and output num should be >= 0. Op ID: " << op_id
|
||||||
|
<< ", Output num: " << output_num;
|
||||||
|
ctx_.Ref();
|
||||||
|
}
|
||||||
|
|
||||||
tf_shared_lock l(mu_);
|
UnshapedRemoteTensorHandleData::~UnshapedRemoteTensorHandleData() {
|
||||||
if (!is_ready_) {
|
if (delete_remote_tensor_) {
|
||||||
profiler::TraceMe activity(
|
DestroyRemoteTensorHandle(&ctx_, remote_task_, context_id_, op_id_,
|
||||||
[caller] { return absl::StrCat(caller, " WaitReady"); },
|
output_num_, /*ready=*/false);
|
||||||
profiler::TraceMeLevel::kInfo);
|
|
||||||
DVLOG(3) << "WaitReady: " << caller << " " << this;
|
|
||||||
mu_.Await(Condition(&is_ready_));
|
|
||||||
}
|
}
|
||||||
return is_poisoned_;
|
ctx_.Unref();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnshapedRemoteTensorHandleData::Tensor(
|
||||||
|
const tensorflow::Tensor** t) const {
|
||||||
|
return errors::Unavailable(
|
||||||
|
"Unable to get a tensor for a remote handle. Please copy the tensor "
|
||||||
|
"handle to a local device using TFE_TensorHandleCopyToDevice");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnshapedRemoteTensorHandleData::TensorValue(tensorflow::TensorValue* t) {
|
||||||
|
return errors::Unavailable(
|
||||||
|
"Unable to get a tensor for a remote handle. Please copy the tensor "
|
||||||
|
"handle to a local device using TFE_TensorHandleCopyToDevice");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnshapedRemoteTensorHandleData::Shape(TensorShape* shape) const {
|
||||||
|
return errors::Unavailable(
|
||||||
|
"Unable to get shape information for an async remote handle. Please wait "
|
||||||
|
"until it is ready");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnshapedRemoteTensorHandleData::NumDims(int* num_dims) const {
|
||||||
|
return errors::Unavailable(
|
||||||
|
"Unable to get shape information for an async remote handle. Please wait "
|
||||||
|
"until it is ready");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnshapedRemoteTensorHandleData::Dim(int dim_index, int64* dim) const {
|
||||||
|
return errors::Unavailable(
|
||||||
|
"Unable to get shape information for an async remote handle. Please wait "
|
||||||
|
"until it is ready");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnshapedRemoteTensorHandleData::NumElements(int64* num_elements) const {
|
||||||
|
return errors::Unavailable(
|
||||||
|
"Unable to get shape information for an async remote handle. Please wait "
|
||||||
|
"until it is ready");
|
||||||
|
}
|
||||||
|
|
||||||
|
Status UnshapedRemoteTensorHandleData::Unprotect() {
|
||||||
|
return errors::Unavailable("Unable to unprotect a remote handle.");
|
||||||
|
}
|
||||||
|
|
||||||
|
string UnshapedRemoteTensorHandleData::DebugString() const {
|
||||||
|
return strings::StrCat("UnshapedRemoteTensorHandleDat:", " op_id: ", op_id_,
|
||||||
|
" output_num: ", output_num_);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -15,56 +15,97 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
|
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
|
||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_TENSOR_HANDLE_DATA_H_
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
#include "tensorflow/core/common_runtime/eager/tensor_handle_data.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only
|
// Remote Tensor Handle: A handle to a Tensor on a remote host. Note that only
|
||||||
// the shape is known.
|
// the shape is known.
|
||||||
class RemoteTensorHandleData {
|
class RemoteTensorHandleData : public TensorHandleData {
|
||||||
public:
|
public:
|
||||||
// Constructor for lazy remote handles
|
RemoteTensorHandleData(int64 op_id, int output_num, const TensorShape& shape,
|
||||||
RemoteTensorHandleData(int64 op_id, int output_num, uint64 context_view_id);
|
const string& remote_task, EagerContext* ctx);
|
||||||
// Constructor for unshaped remote handles
|
~RemoteTensorHandleData() override;
|
||||||
RemoteTensorHandleData(int64 op_id, int output_num, const string& remote_task,
|
|
||||||
EagerContext* ctx);
|
|
||||||
~RemoteTensorHandleData();
|
|
||||||
|
|
||||||
// A remote tensor handle does not have a Tensor object, hence it can only
|
// A remote tensor handle does not have a Tensor object, hence it can only
|
||||||
// support the shape requests.
|
// support the shape requests.
|
||||||
Status Shape(TensorShape* shape) const;
|
Status Tensor(const tensorflow::Tensor** t) const override;
|
||||||
Status NumDims(int* num_dims) const;
|
Status TensorValue(tensorflow::TensorValue* t) override;
|
||||||
Status Dim(int dim_index, int64* dim) const;
|
Status Shape(TensorShape* shape) const override;
|
||||||
Status NumElements(int64* num_elements) const;
|
Status NumDims(int* num_dims) const override;
|
||||||
|
Status Dim(int dim_index, int64* dim) const override;
|
||||||
|
Status NumElements(int64* num_elements) const override;
|
||||||
|
Status Unprotect() override;
|
||||||
|
EagerContext& ctx() const { return ctx_; }
|
||||||
|
|
||||||
bool IsReady() const;
|
string DebugString() const override;
|
||||||
Status SetShape(const TensorShape& shape);
|
|
||||||
void Poison(Status status);
|
|
||||||
Status IsPoisoned() const;
|
|
||||||
|
|
||||||
string DebugString() const;
|
|
||||||
|
|
||||||
int64 op_id() const { return op_id_; }
|
int64 op_id() const { return op_id_; }
|
||||||
int32 output_num() const { return output_num_; }
|
int32 output_num() const { return output_num_; }
|
||||||
|
uint64 context_id() const { return context_id_; }
|
||||||
uint64 context_view_id() const { return context_view_id_; }
|
uint64 context_view_id() const { return context_view_id_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status WaitReady(const char* caller) const;
|
|
||||||
|
|
||||||
mutable mutex mu_;
|
|
||||||
bool is_ready_ GUARDED_BY(mu_);
|
|
||||||
Status is_poisoned_ GUARDED_BY(mu_);
|
|
||||||
TensorShape shape_ GUARDED_BY(mu_);
|
|
||||||
|
|
||||||
// IDs required when this class is representing a remote tensor handle.
|
// IDs required when this class is representing a remote tensor handle.
|
||||||
const int64 op_id_;
|
const int64 op_id_;
|
||||||
const int32 output_num_;
|
const int32 output_num_;
|
||||||
|
const TensorShape shape_;
|
||||||
string remote_task_;
|
string remote_task_;
|
||||||
uint64 context_id_;
|
uint64 context_id_;
|
||||||
uint64 context_view_id_;
|
uint64 context_view_id_;
|
||||||
EagerContext* ctx_;
|
EagerContext& ctx_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Async Remote Tensor Handle: A handle to a Tensor on a remote host. Once the
|
||||||
|
// shape has been computed this is replaced with a remote tensor handle.
|
||||||
|
class UnshapedRemoteTensorHandleData : public TensorHandleData {
|
||||||
|
public:
|
||||||
|
UnshapedRemoteTensorHandleData(int64 op_id, int32 output_num,
|
||||||
|
const string& remote_task, EagerContext* ctx);
|
||||||
|
~UnshapedRemoteTensorHandleData() override;
|
||||||
|
|
||||||
|
// Unshaped remote tensor handles are not ready and hence cannot satisfy any
|
||||||
|
// of these requests.
|
||||||
|
Status Tensor(const tensorflow::Tensor** t) const override;
|
||||||
|
Status TensorValue(tensorflow::TensorValue* t) override;
|
||||||
|
Status Shape(TensorShape* shape) const override;
|
||||||
|
Status NumDims(int* num_dims) const override;
|
||||||
|
Status Dim(int dim_index, int64* dim) const override;
|
||||||
|
Status NumElements(int64* num_elements) const override;
|
||||||
|
Status Unprotect() override;
|
||||||
|
|
||||||
|
void Poison(Status status) { is_poisoned_ = status; }
|
||||||
|
Status IsPoisoned() const { return is_poisoned_; }
|
||||||
|
|
||||||
|
string DebugString() const override;
|
||||||
|
|
||||||
|
int64 op_id() const { return op_id_; }
|
||||||
|
int32 output_num() const { return output_num_; }
|
||||||
|
string remote_task() const { return remote_task_; }
|
||||||
|
uint64 context_id() const { return context_id_; }
|
||||||
|
uint64 context_view_id() const { return context_view_id_; }
|
||||||
|
EagerContext& ctx() const { return ctx_; }
|
||||||
|
|
||||||
|
// When constructed, UnshapedRemoteTensorHandleData owns the remote
|
||||||
|
// TensorHandle and should delete it by issuing an RPC. Once the remote
|
||||||
|
// shape has been learned, the ownership is transferred to
|
||||||
|
// RemoteTensorHandleData. This method must be called to let `this` know
|
||||||
|
// that it no longer owns the remote handle.
|
||||||
|
// TODO(iga): Add a factory method here that will create a new
|
||||||
|
// RemoteTensorHandleData from this and transfer ownership in the process.
|
||||||
|
void ReleaseRemoteTensorHandle() { delete_remote_tensor_ = false; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status is_poisoned_;
|
||||||
|
// IDs required when this class is representing a remote tensor handle.
|
||||||
|
const int64 op_id_;
|
||||||
|
const int32 output_num_;
|
||||||
|
bool delete_remote_tensor_;
|
||||||
|
string remote_task_;
|
||||||
|
uint64 context_id_;
|
||||||
|
uint64 context_view_id_;
|
||||||
|
EagerContext& ctx_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -91,11 +91,11 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
Device* device = IsCPUDevice(call->device) ? nullptr : call->device;
for (int64 i = 0; i < n; ++i) {
PyObject* arg = nullptr;
const Tensor& t = call->ins[i];
if (call->eager) {
TensorHandle* handle;
Tensor t = call->ins[i];
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
std::move(t), ctx->CanonicalDevice(device), nullptr, ctx, &handle));
t, ctx->CanonicalDevice(device), nullptr, ctx, &handle));
arg = EagerTensorFromHandle(new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)});
if (arg == nullptr) {
@ -103,7 +103,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
return errors::Internal("Unable to procure EagerTensor from Tensor.");
}
} else {
Status s = TensorToNdarray(call->ins[i], &arg);
Status s = TensorToNdarray(t, &arg);
if (!s.ok()) {
Py_DECREF(lst);
return s;
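The hunk above hoists a single `const Tensor& t = call->ins[i];` and passes `t` to both CreateLocalHandle and TensorToNdarray, replacing the per-branch copy plus `std::move` in the eager path. Since tensorflow::Tensor copies share a refcounted buffer, the practical difference between the two conventions is a refcount bump rather than a data copy. A standalone toy illustrating the trade-off; every name here is hypothetical, not TensorFlow's API:

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Toy refcounted tensor: copying shares the buffer, like tensorflow::Tensor.
struct ToyTensor {
  std::shared_ptr<std::vector<float>> buf =
      std::make_shared<std::vector<float>>(1024);
};

// Sink-argument style (the old code in this hunk): caller moves in, buffer
// ownership is relinquished.
void CreateHandleByMove(ToyTensor t) {
  std::cout << "move path refcount: " << t.buf.use_count() << "\n";  // 1
}

// Const-reference style (the new code): callee copies if it must retain,
// which only bumps the shared buffer's refcount.
void CreateHandleByRef(const ToyTensor& t) {
  ToyTensor retained = t;
  std::cout << "copy path refcount: " << retained.buf.use_count() << "\n";  // 2
}

int main() {
  ToyTensor a;
  CreateHandleByRef(a);              // a remains valid afterwards
  CreateHandleByMove(std::move(a));  // a's buffer has been given away
}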
@ -277,8 +277,7 @@ struct Converter {

static Status Convert(TFE_Context* ctx, PyObject* obj, ConverterState* state,
TFE_TensorHandle** h, const char** error) {
// TODO(josh11b): Allocator & attributes
/* TODO(josh11b): Allocator & attributes? */
// TODO(gjn): Use optimized scalar constructors when possible.
Tensor result(ConverterTraits<T>::kTypeEnum,
TensorShape(state->inferred_shape));
if (state->inferred_shape.empty()) { /* Scalar case */
@ -295,7 +294,7 @@ struct Converter {
}
tensorflow::TensorHandle* handle = nullptr;
auto status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(result), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &handle);
if (!status.ok()) {
return status;
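In the Convert hunk above, the destination Tensor is allocated directly from the inferred shape, with an empty dims vector denoting a rank-0 scalar (the `/* Scalar case */` branch). A short sketch of that convention against the TensorFlow C++ Tensor API; MakeFromInferredShape is a hypothetical helper, and the snippet assumes a TensorFlow build environment:

#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

// An empty dims vector produces TensorShape({}): rank 0, exactly one element.
// A non-empty vector produces an ordinary shaped tensor.
tensorflow::Tensor MakeFromInferredShape(
    const std::vector<tensorflow::int64>& dims) {
  tensorflow::Tensor result(tensorflow::DT_FLOAT,
                            tensorflow::TensorShape(dims));
  if (dims.empty()) {
    result.scalar<float>()() = 0.0f;  // scalar case: write the single element
  }
  return result;
}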
@ -611,8 +610,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
if (cppstatus.ok()) {
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context,
ctx->context, &handle);
&handle);
}
if (!cppstatus.ok()) {
PyErr_SetString(PyExc_ValueError,
@ -806,10 +805,10 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
case DT_INVALID:  // Only occurs for empty tensors.
{
tensorflow::TensorHandle* h = nullptr;
Tensor t(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape));
status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &h);
if (!status.ok()) {
PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
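As the comment in the hunk above notes, DT_INVALID is only inferred for empty sequences, where there is no element to take a dtype from; the code then defaults to DT_FLOAT. A minimal sketch of that fallback (hypothetical helper name, assuming a TensorFlow build environment):

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

// Empty sequences carry no element to infer a dtype from, so an inferred
// DT_INVALID is resolved to DT_FLOAT before allocating the zero-element tensor.
tensorflow::Tensor MakeEmptyTensor(
    tensorflow::DataType requested_dtype,
    const tensorflow::TensorShape& inferred_shape) {
  const tensorflow::DataType dtype =
      requested_dtype == tensorflow::DT_INVALID ? tensorflow::DT_FLOAT
                                                : requested_dtype;
  return tensorflow::Tensor(dtype, inferred_shape);  // e.g. shape {0}: NumElements() == 0
}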