Experimental API for custom devices in TFE.

Custom devices are an experimental hook into eager op execution, allowing experimentation outside the TensorFlow codebase. These devices do not work in traced code at the moment.

PiperOrigin-RevId: 293615055
Change-Id: I031da213e964caa7d4e11e0f491a3985d034b175
Allen Lavoie 2020-02-06 09:58:02 -08:00 committed by TensorFlower Gardener
parent 91a3741164
commit a4064a389e
20 changed files with 618 additions and 69 deletions
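
The new surface area is the TFE_CustomDevice callback struct and TFE_RegisterCustomDevice, both added in the c_api_experimental.h diff below. A minimal sketch of how a client wires them up, modeled on the logging device test in this commit; the `My*` names are hypothetical and the callbacks are stubbed out:

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

// Caller-defined state; the runtime hands it back to every callback as
// `device_info` and never inspects it.
struct MyDeviceInfo {
  TFE_Context* ctx;
};

TFE_TensorHandle* MyCopyToDevice(TFE_TensorHandle* tensor, TF_Status* status,
                                 void* device_info) {
  // Move or wrap `tensor` onto the custom device, returning a new handle.
  TF_SetStatus(status, TF_UNIMPLEMENTED, "copy-on not implemented");
  return nullptr;
}

TFE_TensorHandle* MyCopyFromDevice(TFE_TensorHandle* tensor,
                                   const char* target_device_name,
                                   TF_Status* status, void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED, "copy-off not implemented");
  return nullptr;
}

void MyExecute(int num_inputs, TFE_TensorHandle** inputs,
               const char* operation_name, int* num_outputs,
               TFE_TensorHandle** outputs, TF_Status* s, void* device_info) {
  // Typically unwraps inputs, runs the named op on an underlying physical
  // device, and wraps the outputs; see LoggingDeviceExecute below.
  TF_SetStatus(s, TF_UNIMPLEMENTED, "execute not implemented");
}

void MyDeleteDevice(void* device_info) {
  delete reinterpret_cast<MyDeviceInfo*>(device_info);
}

void RegisterMyDevice(TFE_Context* ctx) {
  TFE_CustomDevice device;
  device.copy_tensor_to_device = &MyCopyToDevice;
  device.copy_tensor_from_device = &MyCopyFromDevice;
  device.execute = &MyExecute;
  device.delete_device = &MyDeleteDevice;
  TFE_RegisterCustomDevice(ctx, device,
                           "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                           new MyDeviceInfo{ctx});
}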

View File

@@ -2,6 +2,7 @@
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"tf_copts",
"tf_cuda_cc_test",
"tf_cuda_library",
@@ -289,6 +290,27 @@ tf_cuda_cc_test(
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
srcs = [
"custom_device_test.cc",
],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "tape",
hdrs = ["tape.h"],
@@ -301,7 +323,10 @@ cc_library(
filegroup(
name = "headers",
srcs = ["c_api.h"],
srcs = [
"c_api.h",
"c_api_experimental.h",
],
visibility = ["//tensorflow:__subpackages__"],
)

View File

@@ -103,7 +103,12 @@ const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
return op_def;
}
bool IsCPU(const tensorflow::Device* d) {
bool IsCPU(
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
if (VariantDeviceIsCustom(variant)) {
return false;
}
tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
@@ -1009,6 +1014,9 @@ const char* tensorflow::TensorHandleInterface::DeviceName(
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<CustomDevice*>(handle_->device())->name().c_str();
}
tensorflow::Device* d = handle_->op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
@@ -1029,9 +1037,15 @@ const char* tensorflow::TensorHandleInterface::BackingDeviceName(
if (!IsValid(status)) {
return nullptr;
}
tensorflow::Device* d = handle_->device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
if (VariantDeviceIsCustom(handle_->device())) {
return absl::get<tensorflow::CustomDevice*>(handle_->device())
->name()
.c_str();
} else {
tensorflow::Device* d = absl::get<tensorflow::Device*>(handle_->device());
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
}
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
@@ -1065,6 +1079,18 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
if (VariantDeviceIsCustom(handle_->device())) {
tensorflow::CustomDevice* custom_device =
absl::get<tensorflow::CustomDevice*>(handle_->device());
tensorflow::TensorHandle* copy;
*status = custom_device->CopyTensorFromDevice(
handle_, "/job:localhost/task:0/replica:0/device:CPU:0", &copy);
if (status->ok()) {
return TensorHandleInterface(copy).Resolve(status);
} else {
return nullptr;
}
}
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
if (handle_->IsRemote()) {
@@ -1110,6 +1136,11 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
if (!status->status.ok()) return nullptr;
return t->data();
}
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
@@ -1117,8 +1148,9 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
"handle.");
return nullptr;
}
if (handle->device() != nullptr) {
status->status = handle->device()->Sync();
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
if (device != nullptr) {
status->status = device->Sync();
if (!status->status.ok()) {
return nullptr;
}
@@ -1137,12 +1169,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
const int64_t* dims, int num_dims, void* data, size_t len,
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status) {
tensorflow::Device* device;
tensorflow::Device* device = nullptr;
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
status->status =
context->FindCustomDeviceFromName(device_name, &custom_device);
if (!status->status.ok()) {
deallocator(data, len, deallocator_arg);
return nullptr;
}
}
std::vector<tensorflow::int64> dimvec(num_dims);
for (int i = 0; i < num_dims; ++i) {
@@ -1166,8 +1203,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
tensorflow::TensorHandle* ret_handle;
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant_device;
if (custom_device == nullptr) {
variant_device = device;
} else {
variant_device = custom_device;
}
status->status = tensorflow::TensorHandle::CreateLocalHandle(
t, device, context, &ret_handle);
t, variant_device, context, &ret_handle);
if (!status->status.ok()) {
return nullptr;
}
@@ -1508,8 +1551,42 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
tensorflow::EagerContext* context = ctx->context;
status->status = context->FindDeviceFromName(device_name, &device);
if (!status->status.ok()) {
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
&handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
}
return nullptr;
}
// Handle tensor handles currently placed on a custom device
const char* handle_device_name = h->handle->DeviceName(&status->status);
if (!status->status.ok()) {
return nullptr;
}
tensorflow::CustomDevice* dev;
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
return nullptr;
}
// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
@@ -1648,3 +1725,94 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}
}
} // namespace tensorflow
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
public:
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
: device_(device), info_(info), name_(name) {}
~CustomDeviceAPI() override { device_.delete_device(info_); }
const string& name() override { return name_; }
tensorflow::Status CopyTensorToDevice(
tensorflow::TensorHandle* tensor,
tensorflow::TensorHandle** result) override {
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TF_Status status;
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
(*result)->Ref();
delete result_handle;
return status.status;
}
tensorflow::Status CopyTensorFromDevice(
tensorflow::TensorHandle* tensor,
const tensorflow::string& target_device_name,
tensorflow::TensorHandle** result) override {
TF_Status status;
tensor->Ref();
TFE_TensorHandle tensor_handle{
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
&tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
(*result)->Ref();
delete result_handle;
return status.status;
}
tensorflow::Status Execute(tensorflow::EagerOperation* op,
tensorflow::TensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> inputs;
inputs.reserve(op->Inputs().size());
for (int i = 0; i < op->Inputs().size(); ++i) {
op->Inputs()[i]->Ref();
inputs.push_back(new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
op->Inputs()[i])});
}
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
// TODO(allenl): figure out how to get attrs from EagerOperation
TF_Status status;
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
outputs[i]->handle.get())
->Handle();
retvals[i]->Ref();
}
}
for (auto inp : inputs) {
delete inp;
}
return status.status;
}
private:
TFE_CustomDevice device_;
void* info_;
string name_;
};
} // namespace
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info) {
auto custom_device =
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
}
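
Worth noting in CustomDeviceAPI above: handles passed into the C callbacks are temporaries owned by the adapter (it Refs the underlying TensorHandle before wrapping), and the TFE_TensorHandle* a callback returns is consumed by the adapter, which Refs the inner handle and then deletes the wrapper. A conforming copy_tensor_to_device can therefore simply return a freshly created handle. A sketch, reusing the hypothetical MyDeviceInfo from the registration sketch above:

TFE_TensorHandle* MyCopyToDevice(TFE_TensorHandle* tensor, TF_Status* status,
                                 void* device_info) {
  MyDeviceInfo* dev = reinterpret_cast<MyDeviceInfo*>(device_info);
  // Ownership of the returned handle passes to the runtime; `tensor` is
  // borrowed and must not be deleted here.
  return TFE_TensorHandleCopyToDevice(
      tensor, dev->ctx, "/job:localhost/replica:0/task:0/device:CPU:0",
      status);
}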

View File

@@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
}
#ifdef TENSORFLOW_EAGER_USE_XLA
tensorflow::Device* device = handle_->device();
tensorflow::Device* device = absl::get<Device*>(handle_->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
tensorflow::XlaDevice* xla_device =

View File

@@ -463,6 +463,57 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
TF_Buffer* buf);
#define TFE_CUSTOM_DEVICE_VERSION 0
// Struct to be filled in by the client to define a custom device.
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info);
// Method to execute an operation.
// TODO(allenl): figure out a generic way of passing attrs here
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
// Method to delete a device.
void (*delete_device)(void* device_info);
} TFE_CustomDevice;
// Registers a custom device for use with eager execution.
//
// Eager operations may be placed on this device, e.g. `with
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
//
// The custom device defines copy operations for moving TensorHandles on and
off, and an execution operation for named operations. Often execution will
// simply wrap op execution on one or more physical devices.
//
// device_info is an opaque caller-defined type stored with the custom device.
// It is passed to the callbacks referenced in the TFE_CustomDevice struct
// `device` (execute, delete_device, etc.), and can for example contain the
// names of wrapped devices.
//
// There are currently no graph semantics implemented for registered custom
// devices, so executing tf.functions which contain operations placed on custom
// devices will fail.
//
// This API is highly experimental, and in particular is expected to change when
// it starts supporting operations with attributes and when tf.function support
// is added.
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info);
#ifdef __cplusplus
} /* end extern "C" */
#endif
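
Once registered, ops are routed to the device purely by name: TFE_Execute ends up in EagerExecute (see the execute.cc diff below), which checks FindCustomDeviceFromName before taking the usual kernel path. A hedged placement sketch; `ctx` and `input` are assumed to exist, the device is assumed registered, and the op choice is arbitrary:

// Sketch: assumes `ctx` has a custom device registered under `name` and
// `input` is a handle already copied onto it (see the test further below).
void RunOnCustomDevice(TFE_Context* ctx, TFE_TensorHandle* input,
                       TF_Status* status) {
  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  TFE_Op* op = TFE_NewOp(ctx, "Identity", status);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_OpSetDevice(op, name, status);
  TFE_OpAddInput(op, input, status);
  TFE_TensorHandle* retval = nullptr;
  int num_retvals = 1;
  // EagerExecute sees the device name, finds the custom device, and calls
  // its `execute` callback instead of instantiating a kernel.
  TFE_Execute(op, &retval, &num_retvals, status);
  TFE_DeleteOp(op);
  if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(retval);
}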

View File

@@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// A simple logging device to test custom device registration.
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/test.h"
namespace {
struct LoggingDevice {
TFE_Context* ctx;
tensorflow::string device_name;
tensorflow::string underlying_device;
// Set to true whenever a TensorHandle is copied onto the device
bool* arrived_flag;
};
struct LoggedTensor {
TFE_TensorHandle* tensor;
LoggedTensor() = delete;
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
};
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
delete reinterpret_cast<LoggedTensor*>(data);
}
TFE_TensorHandle* MakeLoggedTensorHandle(
TFE_Context* ctx, const tensorflow::string& logging_device_name,
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
}
auto dtype = TFE_TensorHandleDataType(t->tensor);
return TFE_NewTensorHandleFromDeviceMemory(
ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
}
TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
TF_Status* status, void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
tensor, dev->ctx, dev->underlying_device.c_str(), status);
if (TF_GetCode(status) != TF_OK) return nullptr;
auto dst = std::make_unique<LoggedTensor>(t);
*(dev->arrived_flag) = true;
return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
status);
}
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
const char* target_device_name,
TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_INTERNAL,
"Trying to copy a tensor out of a logging device.");
return nullptr;
}
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
const char* operation_name, int* num_outputs,
TFE_TensorHandle** outputs, TF_Status* s,
void* device_info) {
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
for (int j = 0; j < num_inputs; ++j) {
TFE_TensorHandle* input = inputs[j];
const char* input_device = TFE_TensorHandleDeviceName(input, s);
if (TF_GetCode(s) != TF_OK) return;
if (dev->device_name == input_device) {
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
TFE_TensorHandleDevicePointer(input, s));
if (TF_GetCode(s) != TF_OK) return;
TFE_OpAddInput(op, t->tensor, s);
} else {
TFE_OpAddInput(op, input, s);
}
if (TF_GetCode(s) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
TFE_Execute(op, op_outputs.data(), num_outputs, s);
TFE_DeleteOp(op);
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < *num_outputs; ++i) {
auto logged_tensor = std::make_unique<LoggedTensor>(op_outputs[i]);
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
std::move(logged_tensor), s);
}
}
void DeleteLoggingDevice(void* device_info) {
delete reinterpret_cast<LoggingDevice*>(device_info);
}
void RegisterLoggingDevice(TFE_Context* context, const char* name,
bool* arrived_flag) {
TFE_CustomDevice custom_device;
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
custom_device.delete_device = &DeleteLoggingDevice;
custom_device.execute = &LoggingDeviceExecute;
LoggingDevice* device = new LoggingDevice;
device->ctx = context;
device->arrived_flag = arrived_flag;
device->device_name = name;
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
TFE_RegisterCustomDevice(context, custom_device, name, device);
}
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* context = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
bool arrived = false;
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
RegisterLoggingDevice(context, name, &arrived);
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
ASSERT_FALSE(arrived);
TFE_TensorHandle* hdevice =
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
ASSERT_TRUE(arrived);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteTensorHandle(hdevice);
TFE_DeleteContext(context);
}
} // namespace

View File

@@ -157,6 +157,7 @@ tf_cuda_library(
],
"//conditions:default": [
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",

View File

@@ -702,6 +702,21 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
return status;
}
Status EagerContext::FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const {
auto dev_it = custom_devices_.find(device_name);
if (dev_it == custom_devices_.end()) {
return errors::InvalidArgument(device_name, " unknown device.");
}
*dev = dev_it->second.get();
return Status::OK();
}
void EagerContext::RegisterCustomDevice(const string& device_name,
std::unique_ptr<CustomDevice> device) {
custom_devices_[device_name] = std::move(device);
}
bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
if (first == nullptr) first = HostCPU();
if (second == nullptr) second = HostCPU();

View File

@@ -106,6 +106,24 @@ class RunMetadataListener {
virtual void BeforeClearRunMetadata() = 0;
};
class TensorHandle;
class EagerOperation;
class CustomDevice {
public:
virtual ~CustomDevice() {}
virtual const string& name() = 0;
virtual Status CopyTensorToDevice(TensorHandle* tensor,
TensorHandle** result) = 0;
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
const string& target_device_name,
TensorHandle** result) = 0;
virtual Status Execute(EagerOperation* op, TensorHandle** retvals,
int* num_retvals) = 0;
};
class EagerContext : public core::RefCounted {
public:
static const uint64 kInvalidContextId = 0;
@@ -416,6 +434,12 @@ class EagerContext : public core::RefCounted {
Status FindDeviceFromName(const char* device_name, Device** device) const;
Status FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const;
void RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device);
bool OnSameTask(const Device* first, const Device* second) const;
// Gets the CPU device on the task of device.
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
@@ -492,6 +516,7 @@ class EagerContext : public core::RefCounted {
std::vector<DeviceType> prioritized_device_type_list_;
Rendezvous* rendezvous_;
std::function<Rendezvous*(const int64)> rendezvous_creator_;
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
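
On the runtime side, registration and lookup are keyed by the full device name string. A sketch of the internal flow, assuming an EagerContext* ctx and a CustomDevice subclass MyInternalDevice (hypothetical) implementing the four pure virtuals above:

tensorflow::string name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
ctx->RegisterCustomDevice(name, std::make_unique<MyInternalDevice>(name));

tensorflow::CustomDevice* found = nullptr;
tensorflow::Status s = ctx->FindCustomDeviceFromName(name, &found);
// On success `found` is the registered device; an unknown name yields
// errors::InvalidArgument("<name> unknown device.").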

View File

@@ -177,7 +177,13 @@ Status ValidateInputTypeAndPlacement(
for (int i = 0; i < n_inputs; ++i) {
TensorHandle* handle = op->Inputs()[i];
Device* expected_device = kernel->InputDevice(i);
Device* handle_device = handle->DeviceOrHostCPU(*ctx);
auto handle_device_variant = handle->DeviceOrHostCPU(*ctx);
if (VariantDeviceIsCustom(handle_device_variant)) {
return errors::Unimplemented(
"Custom devices and remote execution are not yet supported "
"together.");
}
Device* handle_device = absl::get<Device*>(handle_device_variant);
const bool maybe_copy = !skip_remote_copy || !handle->IsRemote();
// If the input is already on the right device, then nothing to do.
if (expected_device != handle_device && maybe_copy) {
@@ -229,10 +235,14 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
Device** result) {
if (TF_PREDICT_FALSE(VariantDeviceIsCustom(tensor_handle->device()))) {
return errors::Unimplemented(
"The kernel cache does not work with custom devices.");
}
Device* cpu_device = ctx.HostCPU();
string device_name;
if (tensor_handle->IsRemote()) {
Device* device = tensor_handle->device();
Device* device = absl::get<Device*>(tensor_handle->device());
device_name = device != nullptr ? device->name() : cpu_device->name();
*result = (device == nullptr ? cpu_device : device);
} else if (tensor_handle->dtype == DT_RESOURCE) {
@@ -251,7 +261,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
} else if (MTypeFromDType(tensor_handle->dtype) == HOST_MEMORY) {
*result = cpu_device;
} else {
Device* device = tensor_handle->device();
Device* device = absl::get<Device*>(tensor_handle->device());
device_name = device != nullptr ? device->name() : cpu_device->name();
*result = (device == nullptr ? cpu_device : device);
}
@@ -659,8 +669,10 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
!ctx.LazyCopyFunctionRemoteInputs() || !op->is_function();
for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::TensorHandle* input = op->Inputs()[i];
tensorflow::Device* input_device = input->device();
const string* input_device_name = &input->DeviceOrHostCPU(ctx)->name();
tensorflow::Device* input_device = absl::get<Device*>(input->device());
tensorflow::Device* input_device_or_cpu =
absl::get<Device*>(input->DeviceOrHostCPU(ctx));
const string* input_device_name = &input_device_or_cpu->name();
bool serialize_resource_dtype_and_shape = false;
if (op->Device() != input_device &&
// If the expected and actual devices are on the same task, don't
@@ -668,7 +680,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
// when the op is executed on the device.
!ctx.OnSameTask(op->Device(), input_device)) {
if (eagerly_copy_function_remote_inputs ||
input->DeviceOrHostCPU(ctx)->IsLocal()) {
input_device_or_cpu->IsLocal()) {
tensorflow::Device* remote_cpu_device;
TF_RETURN_IF_ERROR(
ctx.CPUDeviceOnTask(op->Device(), &remote_cpu_device));
@@ -678,7 +690,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
// correctly determined after the kernel is selected/instantiated,
// since the op might have its inputs on host memory.
TensorHandle* handle = op->Inputs()[i];
Device* handle_device = handle->DeviceOrHostCPU(ctx);
Device* handle_device =
absl::get<Device*>(handle->DeviceOrHostCPU(ctx));
// If the input is already on the right device, then nothing to do.
if (remote_cpu_device != handle_device) {
TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
@@ -854,7 +867,12 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
// ineligible for CPU pinning.
break;
} else if (all_inputs_eligible_for_cpu_pinning) {
Device* input_device = tensor_handle->DeviceOrHostCPU(ctx);
auto input_device_variant = tensor_handle->DeviceOrHostCPU(ctx);
if (VariantDeviceIsCustom(input_device_variant)) {
all_inputs_eligible_for_cpu_pinning = false;
continue;
}
Device* input_device = absl::get<Device*>(input_device_variant);
DVLOG(2) << "for op " << op->Name() << " input " << i << " "
<< DataTypeString(tensor_handle->dtype)
<< " input device = " << input_device->name()
@@ -902,6 +920,12 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
[&] { return absl::StrCat("EagerExecute: ", op->Name()); },
profiler::TraceMeLevel::kInfo);
TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
CustomDevice* custom_device;
if (op->EagerContext()
.FindCustomDeviceFromName(op->GetDeviceName(), &custom_device)
.ok()) {
return custom_device->Execute(op, retvals, num_retvals);
}
if (!op->Executor().Async()) {
// In sync mode, always clear error to maintain the same behavior as before.
@@ -996,7 +1020,7 @@ Status EagerKernelExecute(
for (int i = 0; i < retvals.size(); ++i) {
DCHECK_EQ(kernel->device(), retvals[i]->op_device());
DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
retvals[i]->device());
absl::get<Device*>(retvals[i]->device()));
TF_RETURN_IF_ERROR(retvals[i]->SetTensor(std::move(outputs[i])));
}
@@ -1031,9 +1055,12 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
EagerExecutor* executor, Device* device, bool mirror,
TensorHandle** result) {
Device* send_device = h->DeviceOrHostCPU(*ctx);
bool sender_is_local = send_device->IsLocal();
auto send_device = h->DeviceOrHostCPU(*ctx);
if (VariantDeviceIsCustom(send_device)) {
return errors::Unimplemented(
"Copying a TensorHandle from a custom device is not supported.");
}
bool sender_is_local = absl::get<Device*>(send_device)->IsLocal();
bool recver_is_local = device->IsLocal();

View File

@@ -71,9 +71,16 @@ Status ExecuteNodeArgs::Init(
serialize_remote_handle_ =
[ctx, &op_inputs](const int i,
eager::RemoteTensorHandle* handle) -> Status {
absl::variant<Device*, CustomDevice*> variant_device =
op_inputs[i]->device();
if (VariantDeviceIsCustom(variant_device)) {
return errors::Internal(
"Custom devices and remote execution are currently not supported "
"together.");
}
Device* device = absl::get<Device*>(variant_device);
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
op_inputs[i], handle, op_inputs[i]->device(),
op_inputs[i]->device()->name());
op_inputs[i], handle, device, device->name());
};
#endif // !IS_MOBILE_PLATFORM
}

View File

@@ -95,16 +95,25 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
Status TensorHandle::CreateLocalHandle(const class Tensor& t,
TensorHandle** h) {
// TODO(b/136608821): Move away from nullptr
return CreateLocalHandle(t, /*d=*/nullptr, /*op_device=*/nullptr,
return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr),
/*op_device=*/nullptr,
/*ctx=*/nullptr, h);
}
Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
Status TensorHandle::CreateLocalHandle(const class Tensor& t, VariantDevice d,
EagerContext* ctx, TensorHandle** h) {
return CreateLocalHandle(t, d, d, ctx, h);
Device* op_device;
if (VariantDeviceIsCustom(d)) {
// TODO(allenl): Figure out a better op_device story for custom devices,
// since always setting it to CPU=nullptr doesn't make much sense.
op_device = nullptr;
} else {
op_device = absl::get<Device*>(d);
}
return CreateLocalHandle(t, d, op_device, ctx, h);
}
Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
Status TensorHandle::CreateLocalHandle(const class Tensor& t, VariantDevice d,
Device* op_device, EagerContext* ctx,
TensorHandle** h) {
if (t.dtype() != DT_RESOURCE) {
@@ -120,7 +129,7 @@ Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
}
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
DataType dtype, Device* d, Device* op_device,
DataType dtype, VariantDevice d, Device* op_device,
EagerContext* ctx)
: dtype(dtype),
device_(d),
@@ -135,12 +144,14 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
is_async_(false),
is_ready_(true),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
}
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
const ResourceHandle& resource_handle, Device* d,
Device* op_device, EagerContext* ctx)
const ResourceHandle& resource_handle,
VariantDevice d, Device* op_device,
EagerContext* ctx)
: dtype(DT_RESOURCE),
device_(d),
op_device_(op_device),
@@ -155,7 +166,8 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
is_ready_(true),
handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
}
Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
@@ -170,7 +182,7 @@ Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
}
TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
bool async, Device* d, Device* op_device,
bool async, VariantDevice d, Device* op_device,
Device* resource_device, DataType dtype,
EagerContext* ctx)
: dtype(dtype),
@@ -187,7 +199,7 @@ TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
is_ready_(!async),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Async Local TensorHandle: " << this
<< " device: " << device_;
<< " device: " << VariantDeviceDebugString(device_);
}
#if !defined(IS_MOBILE_PLATFORM)
@@ -227,7 +239,7 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
is_ready_(true),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Remote TensorHandle: " << this
<< " device: " << device_;
<< " device: " << VariantDeviceDebugString(device_);
}
Status TensorHandle::CreateUnshapedRemoteHandle(
@@ -263,7 +275,7 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
is_ready_(false),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
<< " device: " << device_;
<< " device: " << VariantDeviceDebugString(device_);
}
#endif
@@ -297,8 +309,14 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
return tensor_handle_data_->TensorValue(t);
}
Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
return (device_ == nullptr) ? ctx.HostCPU() : device_;
TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
const EagerContext& ctx) const {
if (VariantDeviceIsCustom(device_)) {
return device_;
} else {
Device* d = absl::get<Device*>(device_);
return (d == nullptr) ? ctx.HostCPU() : d;
}
}
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
@@ -413,7 +431,7 @@ Status TensorHandle::NumElements(int64* num_elements) const {
#if !defined(IS_MOBILE_PLATFORM)
Status TensorHandle::RemoteAddress(Device* d, int64* op_id,
int32* output_num) const {
if (d != device_) {
if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
tf_shared_lock l(mu_);
auto mirror = remote_mirrors_.find(d);
if (mirror != remote_mirrors_.end()) {
@@ -517,7 +535,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
tensorflow::Device* d) {
DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d;
if (d != device_) {
if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
mutex_lock l(mu_);
if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
return errors::Internal(
@@ -593,7 +611,7 @@ void TensorHandle::Poison(Status status) {
Status TensorHandle::CopyToDevice(const EagerContext& ctx,
tensorflow::Device* dstd,
tensorflow::Tensor* output) {
tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
tensorflow::Device* srcd = absl::get<Device*>(DeviceOrHostCPU(ctx));
const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
bool is_same_device =
@@ -655,6 +673,20 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
return status;
}
bool VariantDeviceIsCustom(
absl::variant<Device*, CustomDevice*> variant_device) {
return variant_device.index() != 0;
}
string VariantDeviceDebugString(
absl::variant<Device*, CustomDevice*> variant_device) {
if (VariantDeviceIsCustom(variant_device)) {
return absl::get<CustomDevice*>(variant_device)->name();
} else {
Device* device = absl::get<Device*>(variant_device);
// A null Device* still means host CPU; keep the old "[]" rendering rather
// than dereferencing it.
return device == nullptr ? "[]" : device->DebugString();
}
}
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
if (ctx == nullptr) {
return nullptr;
@@ -671,10 +703,14 @@ string TensorHandle::DebugString() const {
DVLOG(1) << "Calling TensorHandle::DebugString() on " << this;
string out;
strings::StrAppend(&out, "Device: ", device_ ? device_->DebugString() : "[]");
// Consider supporting non-CPU tensors (when device_ is non-NULL) if needed.
string device_debug = VariantDeviceDebugString(device_);
strings::StrAppend(&out, "Device: ", device_debug);
bool is_cpu =
!VariantDeviceIsCustom(device_) && absl::get<Device*>(device_) != nullptr;
// Consider supporting non-CPU tensors and CPU tensors with a device_ set to
// non-NULL if needed.
strings::StrAppend(&out, ", Tensor: ",
device_ ? "?" : tensor_handle_data_->DebugString(), "\n");
is_cpu ? tensor_handle_data_->DebugString() : "?", "\n");
return out;
}

View File

@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
// clang-format on
#include "absl/types/variant.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/eager/context.h"
@ -43,6 +44,7 @@ limitations under the License.
#endif // IS_MOBILE_PLATFORM
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -60,15 +62,20 @@ namespace tensorflow {
// of the TFE_TensorHandle struct and the python EagerTensor class
// (unrelated to python TensorHandle).
class TensorHandle : public core::RefCounted {
// Custom devices do many of the same things as physical Devices, but have a
// much more restricted interface. We pass around ambiguous pointers since
// TensorHandles may be placed either on custom or physical devices.
using VariantDevice = absl::variant<Device*, CustomDevice*>;
// TensorHandle for dtype != DT_RESOURCE
TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
Device* d, Device* op_device, EagerContext* ctx);
VariantDevice d, Device* op_device, EagerContext* ctx);
// TensorHandle for dtype == DT_RESOURCE
TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
const ResourceHandle& resource_handle, Device* d,
const ResourceHandle& resource_handle, VariantDevice d,
Device* op_device, EagerContext* ctx);
TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
Device* d, Device* op_device, Device* resource_device,
VariantDevice d, Device* op_device, Device* resource_device,
DataType dtype, EagerContext* ctx);
#if !defined(IS_MOBILE_PLATFORM)
@@ -82,9 +89,9 @@ class TensorHandle : public core::RefCounted {
// TensorHandle with no assigned device
static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
// TensorHandle with device == op_device
static Status CreateLocalHandle(const class Tensor& t, Device* d,
static Status CreateLocalHandle(const class Tensor& t, VariantDevice d,
EagerContext* ctx, TensorHandle** h);
static Status CreateLocalHandle(const class Tensor& t, Device* d,
static Status CreateLocalHandle(const class Tensor& t, VariantDevice d,
Device* op_device, EagerContext* ctx,
TensorHandle** h);
static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
@@ -117,11 +124,11 @@ class TensorHandle : public core::RefCounted {
Status TensorValue(tensorflow::TensorValue* t);
Device* device() const { return device_; }
VariantDevice device() const { return device_; }
Device* op_device() const { return op_device_; }
Device* resource_device() const { return resource_device_; }
Device* DeviceOrHostCPU(const EagerContext& ctx) const;
VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
Status Shape(tensorflow::TensorShape* shape);
Status NumDims(int* num_dims) const;
@@ -188,8 +195,10 @@ class TensorHandle : public core::RefCounted {
// TODO(b/136608821): Move away from nullptr
bool OnHostCPU() const {
return device_ == nullptr ||
(ctx_ != nullptr && ctx_->HostCPU() == device_);
return (
device_.index() == 0 &&
(absl::get<Device*>(device_) == nullptr ||
(ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_))));
}
bool IsRemote() const { return is_remote_; }
@@ -216,7 +225,7 @@ class TensorHandle : public core::RefCounted {
// done and the handle is "ready".
Status WaitReady(const char* caller) const;
// TODO(b/136608821): device_ == nullptr iff Host CPU:0
// TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
// This was expedient, but perhaps worth revisiting ('device_' should always
// be a valid pointer?)
// This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
@@ -224,7 +233,7 @@ class TensorHandle : public core::RefCounted {
//
// TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
// TFE_TensorHandle does not outlive the TFE_Context from which it came?
tensorflow::Device* const device_;
VariantDevice const device_;
// Device in which the op producing this tensor was executed. Equals to
// device_ for constant tensors.
@@ -286,6 +295,12 @@ class TensorHandle : public core::RefCounted {
PartialTensorShape inference_shape_;
};
// Checks whether a VariantDevice contains a custom device.
bool VariantDeviceIsCustom(absl::variant<Device*, CustomDevice*> device);
// Wraps Device::DebugString() or CustomDevice::name(); a null Device* renders as "[]".
string VariantDeviceDebugString(absl::variant<Device*, CustomDevice*> device);
// Returns the device backing the resource. Else, returns nullptr.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);
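
The dispatch idiom these helpers support, used throughout this change: check the variant's alternative before absl::get, and treat a null Device* as host CPU. A sketch, assuming a valid TensorHandle* handle inside namespace tensorflow:

absl::variant<Device*, CustomDevice*> device = handle->device();
if (VariantDeviceIsCustom(device)) {
  CustomDevice* custom = absl::get<CustomDevice*>(device);
  DVLOG(1) << "custom device: " << custom->name();
} else {
  Device* d = absl::get<Device*>(device);
  // A null Device* still means host CPU (see the TODO above).
  DVLOG(1) << "physical device: " << (d == nullptr ? "CPU:0" : d->name());
}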

View File

@@ -685,7 +685,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
TF_ASSERT_OK(tensor_handle->Tensor(&t));
Device* device = tensor_handle->device();
Device* device = absl::get<Device*>(tensor_handle->device());
EXPECT_EQ(device, nullptr);
auto actual = t->flat<float>();

View File

@@ -77,7 +77,7 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
src_(src),
ctx_(ctx),
executor_(executor),
send_device_(src->DeviceOrHostCPU(*ctx)),
send_device_(absl::get<Device*>(src->DeviceOrHostCPU(*ctx))),
recv_device_(recv_device),
wire_id_(GetUniqueWireID()),
recv_op_id_(recv_op_id),
@@ -145,8 +145,8 @@ void RemoteCopyNode::StartSend() {
request.set_context_id(ctx_->GetContextId());
auto* remote_op = request.add_queue()->mutable_operation();
status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
src_, remote_op->add_inputs(), src_->device(),
src_->DeviceOrHostCPU(*ctx_)->name());
src_, remote_op->add_inputs(), absl::get<Device*>(src_->device()),
absl::get<Device*>(src_->DeviceOrHostCPU(*ctx_))->name());
if (!status.ok()) {
captured_state_->SetSendStatus(status);
return;

View File

@@ -75,8 +75,15 @@ Status RemoteMgr::GetMirroredResourceShape(
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
int64* op_id, int32* output_num) {
// TODO(allenl): Consider supporting remote handles on custom devices.
absl::variant<Device*, CustomDevice*> device = handle->device();
if (VariantDeviceIsCustom(device)) {
return errors::Unimplemented(
"Custom devices and remote execution are currently not supported "
"together.");
}
TF_RETURN_IF_ERROR(
handle->RemoteAddress(handle->device(), op_id, output_num));
handle->RemoteAddress(absl::get<Device*>(device), op_id, output_num));
tensorflow::TensorHandle* h;
TF_RETURN_IF_ERROR(
GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));

View File

@@ -1244,6 +1244,11 @@ StringPiece Tensor::tensor_data() const {
return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
}
void* Tensor::data() const {
if (buf_ == nullptr) return nullptr; // Don't die for empty tensors
return static_cast<void*>(buf_->data());
}
bool Tensor::SharesBufferWith(const Tensor& b) const {
return buf_ != nullptr && b.buf_ != nullptr &&
buf_->root_buffer() == b.buf_->root_buffer();

View File

@@ -601,6 +601,7 @@ class Tensor {
///
/// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
StringPiece tensor_data() const;
void* data() const;
/// Copy the other tensor into this tensor, reshape it and reinterpret the
/// buffer's datatype. If Status::OK() is returned, the two tensors now share
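
Tensor::data() is what makes TFE_TensorHandleDevicePointer meaningful for custom-device handles: whatever pointer was handed to TFE_NewTensorHandleFromDeviceMemory comes back out, which is how the logging device above smuggles a LoggedTensor* through. A round-trip sketch; `ctx`, `status`, and the Wrapped payload type are assumptions, and a custom device is assumed registered under the name used:

struct Wrapped {  // arbitrary caller-defined payload
  int value;
};

void DeleteWrapped(void* data, size_t len, void* arg) {
  delete reinterpret_cast<Wrapped*>(data);
}

TFE_TensorHandle* RoundTrip(TFE_Context* ctx, TF_Status* status) {
  // Store a Wrapped* as the handle's "device memory"; the claimed dtype and
  // shape describe the logical tensor, not the payload.
  int64_t dims[] = {2, 2};
  TFE_TensorHandle* h = TFE_NewTensorHandleFromDeviceMemory(
      ctx, "/job:localhost/replica:0/task:0/device:CUSTOM:0", TF_FLOAT, dims,
      2, new Wrapped{42}, 1, &DeleteWrapped, nullptr, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  // Tensor::data() hands the same pointer back, e.g. inside `execute`.
  Wrapped* w = reinterpret_cast<Wrapped*>(
      TFE_TensorHandleDevicePointer(h, status));
  (void)w;  // w->value == 42
  return h;
}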

View File

@@ -21,7 +21,7 @@ from __future__ import print_function
from tensorflow.python.util.tf_export import tf_export
_VALID_DEVICE_TYPES = {"CPU", "GPU", "TPU"}
_VALID_DEVICE_TYPES = frozenset({"CPU", "GPU", "TPU", "CUSTOM"})
# ==============================================================================

View File

@@ -95,7 +95,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
if (call->eager) {
TensorHandle* handle;
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
t, ctx->CanonicalDevice(device), ctx, &handle));
t, ctx->CanonicalDevice(device), nullptr, ctx, &handle));
arg = EagerTensorFromHandle(new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)});
if (arg == nullptr) {
@@ -149,7 +149,11 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
auto handle = down_cast<tensorflow::TensorHandleInterface*>(
EagerTensor_Handle(eager_tensor)->handle.get())
->Handle();
Device* actual_device = handle->device();
if (VariantDeviceIsCustom(handle->device())) {
return errors::Unimplemented(
"Custom devices are currently not supported with PyFuncs.");
}
Device* actual_device = absl::get<Device*>(handle->device());
TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
// actual_device may be nullptr, which implies local CPU.
if (expected_device == actual_device) return Status::OK();

View File

@@ -294,7 +294,8 @@ struct Converter {
}
tensorflow::TensorHandle* handle = nullptr;
auto status = tensorflow::TensorHandle::CreateLocalHandle(
result, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &handle);
if (!status.ok()) {
return status;
}
@@ -609,7 +610,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
if (cppstatus.ok()) {
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context,
&handle);
}
if (!cppstatus.ok()) {
PyErr_SetString(PyExc_ValueError,
@@ -806,7 +808,8 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
TensorShape(state.inferred_shape));
status = tensorflow::TensorHandle::CreateLocalHandle(
tensor, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &h);
tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
ctx->context, &h);
if (!status.ok()) {
PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
return nullptr;