Experimental API for custom devices in TFE.
Custom devices are an experimental hook into eager op execution, allowing experimentation outside the TensorFlow codebase. These devices do not work in traced code at the moment. PiperOrigin-RevId: 293615055 Change-Id: I031da213e964caa7d4e11e0f491a3985d034b175
This commit is contained in:
parent
91a3741164
commit
a4064a389e
@ -2,6 +2,7 @@
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
"tf_cuda_cc_test",
|
||||
"tf_cuda_library",
|
||||
@ -289,6 +290,27 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "custom_device_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"custom_device_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tape",
|
||||
hdrs = ["tape.h"],
|
||||
@ -301,7 +323,10 @@ cc_library(
|
||||
|
||||
filegroup(
|
||||
name = "headers",
|
||||
srcs = ["c_api.h"],
|
||||
srcs = [
|
||||
"c_api.h",
|
||||
"c_api_experimental.h",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
|
@ -103,7 +103,12 @@ const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
|
||||
return op_def;
|
||||
}
|
||||
|
||||
bool IsCPU(const tensorflow::Device* d) {
|
||||
bool IsCPU(
|
||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
|
||||
if (VariantDeviceIsCustom(variant)) {
|
||||
return false;
|
||||
}
|
||||
tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
|
||||
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
|
||||
}
|
||||
|
||||
@ -1009,6 +1014,9 @@ const char* tensorflow::TensorHandleInterface::DeviceName(
|
||||
if (!IsValid(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (VariantDeviceIsCustom(handle_->device())) {
|
||||
return absl::get<CustomDevice*>(handle_->device())->name().c_str();
|
||||
}
|
||||
tensorflow::Device* d = handle_->op_device();
|
||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
: d->name().c_str();
|
||||
@ -1029,9 +1037,15 @@ const char* tensorflow::TensorHandleInterface::BackingDeviceName(
|
||||
if (!IsValid(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::Device* d = handle_->device();
|
||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
: d->name().c_str();
|
||||
if (VariantDeviceIsCustom(handle_->device())) {
|
||||
return absl::get<tensorflow::CustomDevice*>(handle_->device())
|
||||
->name()
|
||||
.c_str();
|
||||
} else {
|
||||
tensorflow::Device* d = absl::get<tensorflow::Device*>(handle_->device());
|
||||
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
: d->name().c_str();
|
||||
}
|
||||
}
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
|
||||
@ -1065,6 +1079,18 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
|
||||
if (!IsValid(status)) {
|
||||
return nullptr;
|
||||
}
|
||||
if (VariantDeviceIsCustom(handle_->device())) {
|
||||
tensorflow::CustomDevice* custom_device =
|
||||
absl::get<tensorflow::CustomDevice*>(handle_->device());
|
||||
tensorflow::TensorHandle* copy;
|
||||
*status = custom_device->CopyTensorFromDevice(
|
||||
handle_, "/job:localhost/task:0/replica:0/device:CPU:0", ©);
|
||||
if (status->ok()) {
|
||||
return TensorHandleInterface(copy).Resolve(status);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(agarwal): move this implementation inside TFE_TensorHandle.
|
||||
if (handle_->IsRemote()) {
|
||||
@ -1110,6 +1136,11 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle();
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
const tensorflow::Tensor* t;
|
||||
status->status = handle->Tensor(&t);
|
||||
return t->data();
|
||||
}
|
||||
|
||||
if (handle->IsRemote()) {
|
||||
status->status = tensorflow::errors::InvalidArgument(
|
||||
@ -1117,8 +1148,9 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
|
||||
"handle.");
|
||||
return nullptr;
|
||||
}
|
||||
if (handle->device() != nullptr) {
|
||||
status->status = handle->device()->Sync();
|
||||
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
|
||||
if (device != nullptr) {
|
||||
status->status = device->Sync();
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -1137,12 +1169,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
const int64_t* dims, int num_dims, void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg, TF_Status* status) {
|
||||
tensorflow::Device* device;
|
||||
tensorflow::Device* device = nullptr;
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
tensorflow::CustomDevice* custom_device = nullptr;
|
||||
if (!status->status.ok()) {
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
status->status =
|
||||
context->FindCustomDeviceFromName(device_name, &custom_device);
|
||||
if (!status->status.ok()) {
|
||||
deallocator(data, len, deallocator_arg);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
std::vector<tensorflow::int64> dimvec(num_dims);
|
||||
for (int i = 0; i < num_dims; ++i) {
|
||||
@ -1166,8 +1203,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
tensorflow::TensorShape(dimvec), buf);
|
||||
buf->Unref();
|
||||
tensorflow::TensorHandle* ret_handle;
|
||||
absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant_device;
|
||||
if (custom_device == nullptr) {
|
||||
variant_device = device;
|
||||
} else {
|
||||
variant_device = custom_device;
|
||||
}
|
||||
status->status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, device, context, &ret_handle);
|
||||
t, variant_device, context, &ret_handle);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
@ -1508,8 +1551,42 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
|
||||
tensorflow::EagerContext* context = ctx->context;
|
||||
status->status = context->FindDeviceFromName(device_name, &device);
|
||||
if (!status->status.ok()) {
|
||||
tensorflow::CustomDevice* dev;
|
||||
status->status = context->FindCustomDeviceFromName(device_name, &dev);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorToDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
h->handle.get())
|
||||
->Handle(),
|
||||
&handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
// Handle tensor handles currently in custom devices
|
||||
const char* handle_device_name = h->handle->DeviceName(&status->status);
|
||||
if (!status->status.ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
tensorflow::CustomDevice* dev;
|
||||
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
|
||||
if (status->status.ok()) {
|
||||
status->status = dev->CopyTensorFromDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
h->handle.get())
|
||||
->Handle(),
|
||||
device_name, &handle);
|
||||
if (status->status.ok()) {
|
||||
return new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Handle regular case.
|
||||
status->status = tensorflow::EagerCopyToDevice(
|
||||
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
|
||||
->Handle(),
|
||||
@ -1648,3 +1725,94 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
|
||||
}
|
||||
}
|
||||
} // namespace tensorflow
|
||||
|
||||
namespace {
|
||||
class CustomDeviceAPI : public tensorflow::CustomDevice {
|
||||
public:
|
||||
CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
|
||||
: device_(device), info_(info), name_(name) {}
|
||||
|
||||
~CustomDeviceAPI() override { device_.delete_device(info_); }
|
||||
|
||||
const string& name() override { return name_; }
|
||||
|
||||
tensorflow::Status CopyTensorToDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
|
||||
TF_Status status;
|
||||
TFE_TensorHandle* result_handle =
|
||||
device_.copy_tensor_to_device(&tensor_handle, &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
->Handle();
|
||||
(*result)->Ref();
|
||||
delete result_handle;
|
||||
return status.status;
|
||||
}
|
||||
|
||||
tensorflow::Status CopyTensorFromDevice(
|
||||
tensorflow::TensorHandle* tensor,
|
||||
const tensorflow::string& target_device_name,
|
||||
tensorflow::TensorHandle** result) override {
|
||||
TF_Status status;
|
||||
tensor->Ref();
|
||||
TFE_TensorHandle tensor_handle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
|
||||
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
|
||||
&tensor_handle, target_device_name.c_str(), &status, info_);
|
||||
if (!status.status.ok()) return status.status;
|
||||
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
result_handle->handle.get())
|
||||
->Handle();
|
||||
(*result)->Ref();
|
||||
delete result_handle;
|
||||
return status.status;
|
||||
}
|
||||
|
||||
tensorflow::Status Execute(tensorflow::EagerOperation* op,
|
||||
tensorflow::TensorHandle** retvals,
|
||||
int* num_retvals) override {
|
||||
std::vector<TFE_TensorHandle*> inputs;
|
||||
inputs.reserve(op->Inputs().size());
|
||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||
op->Inputs()[i]->Ref();
|
||||
inputs.push_back(new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(
|
||||
op->Inputs()[i])});
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
|
||||
// TODO(allenl): figure out how to get attrs from EagerOperation
|
||||
TF_Status status;
|
||||
device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
|
||||
num_retvals, outputs.data(), &status, info_);
|
||||
if (status.status.ok()) {
|
||||
for (int i = 0; i < *num_retvals; ++i) {
|
||||
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
|
||||
outputs[i]->handle.get())
|
||||
->Handle();
|
||||
retvals[i]->Ref();
|
||||
}
|
||||
}
|
||||
for (auto inp : inputs) {
|
||||
delete inp;
|
||||
}
|
||||
return status.status;
|
||||
}
|
||||
|
||||
private:
|
||||
TFE_CustomDevice device_;
|
||||
void* info_;
|
||||
string name_;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info) {
|
||||
auto custom_device =
|
||||
std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
|
||||
ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
|
||||
}
|
||||
|
||||
#ifdef TENSORFLOW_EAGER_USE_XLA
|
||||
tensorflow::Device* device = handle_->device();
|
||||
tensorflow::Device* device = absl::get<Device*>(handle_->device());
|
||||
|
||||
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
|
||||
tensorflow::XlaDevice* xla_device =
|
||||
|
@ -463,6 +463,57 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
|
||||
TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
|
||||
TF_Buffer* buf);
|
||||
|
||||
#define TFE_CUSTOM_DEVICE_VERSION 0
|
||||
|
||||
// Struct to be filled in
|
||||
typedef struct TFE_CustomDevice {
|
||||
int version = TFE_CUSTOM_DEVICE_VERSION;
|
||||
// Method to copy a tensor to the custom device.
|
||||
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
|
||||
TF_Status* status,
|
||||
void* device_info) = nullptr;
|
||||
|
||||
// Method to copy a tensor from the custom device to a target device.
|
||||
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info);
|
||||
|
||||
// Method to execute an operation.
|
||||
// TODO(allenl) figure out a generic way of passing attrs here
|
||||
void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
|
||||
|
||||
// Method to delete a device.
|
||||
void (*delete_device)(void* device_info);
|
||||
} TFE_CustomDevice;
|
||||
|
||||
// Registers a custom device for use with eager execution.
|
||||
//
|
||||
// Eager operations may be placed on this device, e.g. `with
|
||||
// tf.device("CUSTOM"):` from Python if `device_name` for this call is
|
||||
// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
|
||||
//
|
||||
// The custom device defines copy operations for moving TensorHandles on and
|
||||
// off, and an an execution operation for named operations. Often execution will
|
||||
// simply wrap op execution on one or more physical devices.
|
||||
//
|
||||
// device_info is an opaque caller-defined type stored with the custom device
|
||||
// which is passed to the functions referenced in the TFE_CustomDevice struct
|
||||
// `device` (execute, delete_device, etc.). It can for example contain the
|
||||
// names of wrapped devices.
|
||||
//
|
||||
// There are currently no graph semantics implemented for registered custom
|
||||
// devices, so executing tf.functions which contain operations placed on custom
|
||||
// devices will fail.
|
||||
//
|
||||
// This API is highly experimental, and in particular is expected to change when
|
||||
// it starts supporting operations with attributes and when tf.function support
|
||||
// is added.
|
||||
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
|
||||
const char* device_name, void* device_info);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
159
tensorflow/c/eager/custom_device_test.cc
Normal file
159
tensorflow/c/eager/custom_device_test.cc
Normal file
@ -0,0 +1,159 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// A simple logging device to test custom device registration.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace {
|
||||
|
||||
struct LoggingDevice {
|
||||
TFE_Context* ctx;
|
||||
tensorflow::string device_name;
|
||||
tensorflow::string underlying_device;
|
||||
// Set to true whenever a TensorHandle is copied onto the device
|
||||
bool* arrived_flag;
|
||||
};
|
||||
|
||||
struct LoggedTensor {
|
||||
TFE_TensorHandle* tensor;
|
||||
LoggedTensor() = delete;
|
||||
explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
|
||||
~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
|
||||
};
|
||||
|
||||
void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
|
||||
delete reinterpret_cast<LoggedTensor*>(data);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* MakeLoggedTensorHandle(
|
||||
TFE_Context* ctx, const tensorflow::string& logging_device_name,
|
||||
std::unique_ptr<LoggedTensor> t, TF_Status* status) {
|
||||
std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
}
|
||||
auto dtype = TFE_TensorHandleDataType(t->tensor);
|
||||
return TFE_NewTensorHandleFromDeviceMemory(
|
||||
ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
|
||||
t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
|
||||
TF_Status* status, void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
|
||||
tensor, dev->ctx, dev->underlying_device.c_str(), status);
|
||||
if (TF_GetCode(status) != TF_OK) return nullptr;
|
||||
auto dst = std::make_unique<LoggedTensor>(t);
|
||||
*(dev->arrived_flag) = true;
|
||||
return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
|
||||
status);
|
||||
}
|
||||
|
||||
TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
|
||||
const char* target_device_name,
|
||||
TF_Status* status,
|
||||
void* device_info) {
|
||||
TF_SetStatus(status, TF_INTERNAL,
|
||||
"Trying to copy a tensor out of a logging device.");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
|
||||
const char* operation_name, int* num_outputs,
|
||||
TFE_TensorHandle** outputs, TF_Status* s,
|
||||
void* device_info) {
|
||||
LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
|
||||
TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
|
||||
for (int j = 0; j < num_inputs; ++j) {
|
||||
TFE_TensorHandle* input = inputs[j];
|
||||
const char* input_device = TFE_TensorHandleDeviceName(input, s);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
if (dev->device_name == input_device) {
|
||||
LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
|
||||
TFE_TensorHandleDevicePointer(input, s));
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
TFE_OpAddInput(op, t->tensor, s);
|
||||
} else {
|
||||
TFE_OpAddInput(op, input, s);
|
||||
}
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
}
|
||||
std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
|
||||
TFE_Execute(op, op_outputs.data(), num_outputs, s);
|
||||
TFE_DeleteOp(op);
|
||||
if (TF_GetCode(s) != TF_OK) return;
|
||||
std::vector<TFE_TensorHandle*> unwrapped_outputs;
|
||||
for (auto* handle : op_outputs) {
|
||||
unwrapped_outputs.push_back(handle);
|
||||
}
|
||||
for (int i = 0; i < *num_outputs; ++i) {
|
||||
auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
|
||||
outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
|
||||
std::move(logged_tensor), s);
|
||||
}
|
||||
}
|
||||
|
||||
void DeleteLoggingDevice(void* device_info) {
|
||||
delete reinterpret_cast<LoggingDevice*>(device_info);
|
||||
}
|
||||
|
||||
void RegisterLoggingDevice(TFE_Context* context, const char* name,
|
||||
bool* arrived_flag) {
|
||||
TFE_CustomDevice custom_device;
|
||||
custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
|
||||
custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
|
||||
custom_device.delete_device = &DeleteLoggingDevice;
|
||||
custom_device.execute = &LoggingDeviceExecute;
|
||||
LoggingDevice* device = new LoggingDevice;
|
||||
device->ctx = context;
|
||||
device->arrived_flag = arrived_flag;
|
||||
device->device_name = name;
|
||||
device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_RegisterCustomDevice(context, custom_device, name, device);
|
||||
}
|
||||
|
||||
TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* context = TFE_NewContext(opts, status.get());
|
||||
TFE_DeleteContextOptions(opts);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
bool arrived = false;
|
||||
const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
RegisterLoggingDevice(context, name, &arrived);
|
||||
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
|
||||
ASSERT_FALSE(arrived);
|
||||
TFE_TensorHandle* hdevice =
|
||||
TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
|
||||
ASSERT_TRUE(arrived);
|
||||
TFE_DeleteTensorHandle(hcpu);
|
||||
TFE_DeleteTensorHandle(hdevice);
|
||||
TFE_DeleteContext(context);
|
||||
}
|
||||
|
||||
} // namespace
|
@ -157,6 +157,7 @@ tf_cuda_library(
|
||||
],
|
||||
"//conditions:default": [
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:variant",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
|
@ -702,6 +702,21 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
|
||||
return status;
|
||||
}
|
||||
|
||||
Status EagerContext::FindCustomDeviceFromName(const string& device_name,
|
||||
CustomDevice** dev) const {
|
||||
auto dev_it = custom_devices_.find(device_name);
|
||||
if (dev_it == custom_devices_.end()) {
|
||||
return errors::InvalidArgument(device_name, " unknown device.");
|
||||
}
|
||||
*dev = dev_it->second.get();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void EagerContext::RegisterCustomDevice(const string& device_name,
|
||||
std::unique_ptr<CustomDevice> device) {
|
||||
custom_devices_[device_name] = std::move(device);
|
||||
}
|
||||
|
||||
bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
|
||||
if (first == nullptr) first = HostCPU();
|
||||
if (second == nullptr) second = HostCPU();
|
||||
|
@ -106,6 +106,24 @@ class RunMetadataListener {
|
||||
virtual void BeforeClearRunMetadata() = 0;
|
||||
};
|
||||
|
||||
class TensorHandle;
|
||||
class EagerOperation;
|
||||
|
||||
class CustomDevice {
|
||||
public:
|
||||
virtual ~CustomDevice() {}
|
||||
virtual const string& name() = 0;
|
||||
virtual Status CopyTensorToDevice(TensorHandle* tensor,
|
||||
TensorHandle** result) = 0;
|
||||
|
||||
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
|
||||
const string& target_device_name,
|
||||
TensorHandle** result) = 0;
|
||||
|
||||
virtual Status Execute(EagerOperation* op, TensorHandle** retvals,
|
||||
int* num_retvals) = 0;
|
||||
};
|
||||
|
||||
class EagerContext : public core::RefCounted {
|
||||
public:
|
||||
static const uint64 kInvalidContextId = 0;
|
||||
@ -416,6 +434,12 @@ class EagerContext : public core::RefCounted {
|
||||
|
||||
Status FindDeviceFromName(const char* device_name, Device** device) const;
|
||||
|
||||
Status FindCustomDeviceFromName(const string& device_name,
|
||||
CustomDevice** dev) const;
|
||||
|
||||
void RegisterCustomDevice(const string& name,
|
||||
std::unique_ptr<CustomDevice> device);
|
||||
|
||||
bool OnSameTask(const Device* first, const Device* second) const;
|
||||
// Gets the CPU device on the task of device.
|
||||
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
||||
@ -492,6 +516,7 @@ class EagerContext : public core::RefCounted {
|
||||
std::vector<DeviceType> prioritized_device_type_list_;
|
||||
Rendezvous* rendezvous_;
|
||||
std::function<Rendezvous*(const int64)> rendezvous_creator_;
|
||||
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
|
||||
|
||||
FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
|
||||
|
||||
|
@ -177,7 +177,13 @@ Status ValidateInputTypeAndPlacement(
|
||||
for (int i = 0; i < n_inputs; ++i) {
|
||||
TensorHandle* handle = op->Inputs()[i];
|
||||
Device* expected_device = kernel->InputDevice(i);
|
||||
Device* handle_device = handle->DeviceOrHostCPU(*ctx);
|
||||
auto handle_device_variant = handle->DeviceOrHostCPU(*ctx);
|
||||
if (VariantDeviceIsCustom(handle_device_variant)) {
|
||||
return errors::Unimplemented(
|
||||
"Custom devices and remote execution are not yet supported "
|
||||
"together.");
|
||||
}
|
||||
Device* handle_device = absl::get<Device*>(handle_device_variant);
|
||||
const bool maybe_copy = !skip_remote_copy || !handle->IsRemote();
|
||||
// If the input is already on the right device, then nothing to do.
|
||||
if (expected_device != handle_device && maybe_copy) {
|
||||
@ -229,10 +235,14 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
|
||||
|
||||
Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
|
||||
Device** result) {
|
||||
if (TF_PREDICT_FALSE(VariantDeviceIsCustom(tensor_handle->device()))) {
|
||||
return errors::Unimplemented(
|
||||
"The kernel cache does not work with custom devices.");
|
||||
}
|
||||
Device* cpu_device = ctx.HostCPU();
|
||||
string device_name;
|
||||
if (tensor_handle->IsRemote()) {
|
||||
Device* device = tensor_handle->device();
|
||||
Device* device = absl::get<Device*>(tensor_handle->device());
|
||||
device_name = device != nullptr ? device->name() : cpu_device->name();
|
||||
*result = (device == nullptr ? cpu_device : device);
|
||||
} else if (tensor_handle->dtype == DT_RESOURCE) {
|
||||
@ -251,7 +261,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
|
||||
} else if (MTypeFromDType(tensor_handle->dtype) == HOST_MEMORY) {
|
||||
*result = cpu_device;
|
||||
} else {
|
||||
Device* device = tensor_handle->device();
|
||||
Device* device = absl::get<Device*>(tensor_handle->device());
|
||||
device_name = device != nullptr ? device->name() : cpu_device->name();
|
||||
*result = (device == nullptr ? cpu_device : device);
|
||||
}
|
||||
@ -659,8 +669,10 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
!ctx.LazyCopyFunctionRemoteInputs() || !op->is_function();
|
||||
for (int i = 0; i < op->Inputs().size(); i++) {
|
||||
tensorflow::TensorHandle* input = op->Inputs()[i];
|
||||
tensorflow::Device* input_device = input->device();
|
||||
const string* input_device_name = &input->DeviceOrHostCPU(ctx)->name();
|
||||
tensorflow::Device* input_device = absl::get<Device*>(input->device());
|
||||
tensorflow::Device* input_device_or_cpu =
|
||||
absl::get<Device*>(input->DeviceOrHostCPU(ctx));
|
||||
const string* input_device_name = &input_device_or_cpu->name();
|
||||
bool serialize_resource_dtype_and_shape = false;
|
||||
if (op->Device() != input_device &&
|
||||
// If the expected and actual devices are on the same task, don't
|
||||
@ -668,7 +680,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
// when the op is executed on the device.
|
||||
!ctx.OnSameTask(op->Device(), input_device)) {
|
||||
if (eagerly_copy_function_remote_inputs ||
|
||||
input->DeviceOrHostCPU(ctx)->IsLocal()) {
|
||||
input_device_or_cpu->IsLocal()) {
|
||||
tensorflow::Device* remote_cpu_device;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx.CPUDeviceOnTask(op->Device(), &remote_cpu_device));
|
||||
@ -678,7 +690,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
// correctly determined after the kernel is selected/instantiated,
|
||||
// since the op might have its inputs on host memory.
|
||||
TensorHandle* handle = op->Inputs()[i];
|
||||
Device* handle_device = handle->DeviceOrHostCPU(ctx);
|
||||
Device* handle_device =
|
||||
absl::get<Device*>(handle->DeviceOrHostCPU(ctx));
|
||||
// If the input is already on the right device, then nothing to do.
|
||||
if (remote_cpu_device != handle_device) {
|
||||
TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
|
||||
@ -854,7 +867,12 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
|
||||
// ineligible for CPU pinning.
|
||||
break;
|
||||
} else if (all_inputs_eligible_for_cpu_pinning) {
|
||||
Device* input_device = tensor_handle->DeviceOrHostCPU(ctx);
|
||||
auto input_device_variant = tensor_handle->DeviceOrHostCPU(ctx);
|
||||
if (VariantDeviceIsCustom(input_device_variant)) {
|
||||
all_inputs_eligible_for_cpu_pinning = false;
|
||||
continue;
|
||||
}
|
||||
Device* input_device = absl::get<Device*>(input_device_variant);
|
||||
DVLOG(2) << "for op " << op->Name() << " input " << i << " "
|
||||
<< DataTypeString(tensor_handle->dtype)
|
||||
<< " input device = " << input_device->name()
|
||||
@ -902,6 +920,12 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
[&] { return absl::StrCat("EagerExecute: ", op->Name()); },
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
|
||||
CustomDevice* custom_device;
|
||||
if (op->EagerContext()
|
||||
.FindCustomDeviceFromName(op->GetDeviceName(), &custom_device)
|
||||
.ok()) {
|
||||
return custom_device->Execute(op, retvals, num_retvals);
|
||||
}
|
||||
|
||||
if (!op->Executor().Async()) {
|
||||
// In sync mode, always clear error to maintain the same behavior as before.
|
||||
@ -996,7 +1020,7 @@ Status EagerKernelExecute(
|
||||
for (int i = 0; i < retvals.size(); ++i) {
|
||||
DCHECK_EQ(kernel->device(), retvals[i]->op_device());
|
||||
DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
|
||||
retvals[i]->device());
|
||||
absl::get<Device*>(retvals[i]->device()));
|
||||
|
||||
TF_RETURN_IF_ERROR(retvals[i]->SetTensor(std::move(outputs[i])));
|
||||
}
|
||||
@ -1031,9 +1055,12 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
EagerExecutor* executor, Device* device, bool mirror,
|
||||
TensorHandle** result) {
|
||||
Device* send_device = h->DeviceOrHostCPU(*ctx);
|
||||
|
||||
bool sender_is_local = send_device->IsLocal();
|
||||
auto send_device = h->DeviceOrHostCPU(*ctx);
|
||||
if (VariantDeviceIsCustom(send_device)) {
|
||||
return errors::Unimplemented(
|
||||
"Copying a TensorHandle from a custom device is not supported.");
|
||||
}
|
||||
bool sender_is_local = absl::get<Device*>(send_device)->IsLocal();
|
||||
|
||||
bool recver_is_local = device->IsLocal();
|
||||
|
||||
|
@ -71,9 +71,16 @@ Status ExecuteNodeArgs::Init(
|
||||
serialize_remote_handle_ =
|
||||
[ctx, &op_inputs](const int i,
|
||||
eager::RemoteTensorHandle* handle) -> Status {
|
||||
absl::variant<Device*, CustomDevice*> variant_device =
|
||||
op_inputs[i]->device();
|
||||
if (VariantDeviceIsCustom(variant_device)) {
|
||||
return errors::Internal(
|
||||
"Custom devices and remote execution are currently not supported "
|
||||
"together.");
|
||||
}
|
||||
Device* device = absl::get<Device*>(variant_device);
|
||||
return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
|
||||
op_inputs[i], handle, op_inputs[i]->device(),
|
||||
op_inputs[i]->device()->name());
|
||||
op_inputs[i], handle, device, device->name());
|
||||
};
|
||||
#endif // !IS_MOBILE_PLATFORM
|
||||
}
|
||||
|
@ -95,16 +95,25 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
|
||||
Status TensorHandle::CreateLocalHandle(const class Tensor& t,
|
||||
TensorHandle** h) {
|
||||
// TODO(b/136608821): Move away from nullptr
|
||||
return CreateLocalHandle(t, /*d=*/nullptr, /*op_device=*/nullptr,
|
||||
return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr),
|
||||
/*op_device=*/nullptr,
|
||||
/*ctx=*/nullptr, h);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
|
||||
Status TensorHandle::CreateLocalHandle(const class Tensor& t, VariantDevice d,
|
||||
EagerContext* ctx, TensorHandle** h) {
|
||||
return CreateLocalHandle(t, d, d, ctx, h);
|
||||
Device* op_device;
|
||||
if (VariantDeviceIsCustom(d)) {
|
||||
// TODO(allenl): Figure out a better op_device story for custom devices,
|
||||
// since always setting it to CPU=nullptr doesn't make much sense.
|
||||
op_device = nullptr;
|
||||
} else {
|
||||
op_device = absl::get<Device*>(d);
|
||||
}
|
||||
return CreateLocalHandle(t, d, op_device, ctx, h);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
|
||||
Status TensorHandle::CreateLocalHandle(const class Tensor& t, VariantDevice d,
|
||||
Device* op_device, EagerContext* ctx,
|
||||
TensorHandle** h) {
|
||||
if (t.dtype() != DT_RESOURCE) {
|
||||
@ -120,7 +129,7 @@ Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
DataType dtype, Device* d, Device* op_device,
|
||||
DataType dtype, VariantDevice d, Device* op_device,
|
||||
EagerContext* ctx)
|
||||
: dtype(dtype),
|
||||
device_(d),
|
||||
@ -135,12 +144,14 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
is_async_(false),
|
||||
is_ready_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
const ResourceHandle& resource_handle, Device* d,
|
||||
Device* op_device, EagerContext* ctx)
|
||||
const ResourceHandle& resource_handle,
|
||||
VariantDevice d, Device* op_device,
|
||||
EagerContext* ctx)
|
||||
: dtype(DT_RESOURCE),
|
||||
device_(d),
|
||||
op_device_(op_device),
|
||||
@ -155,7 +166,8 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
is_ready_(true),
|
||||
handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
|
||||
DVLOG(3) << "Creating Local TensorHandle: " << this
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
|
||||
@ -170,7 +182,7 @@ Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
|
||||
bool async, Device* d, Device* op_device,
|
||||
bool async, VariantDevice d, Device* op_device,
|
||||
Device* resource_device, DataType dtype,
|
||||
EagerContext* ctx)
|
||||
: dtype(dtype),
|
||||
@ -187,7 +199,7 @@ TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
|
||||
is_ready_(!async),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Async Local TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
@ -227,7 +239,7 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
|
||||
is_ready_(true),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Remote TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateUnshapedRemoteHandle(
|
||||
@ -263,7 +275,7 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
|
||||
is_ready_(false),
|
||||
tensor_handle_data_(std::move(t)) {
|
||||
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
|
||||
<< " device: " << device_;
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -297,8 +309,14 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
|
||||
return tensor_handle_data_->TensorValue(t);
|
||||
}
|
||||
|
||||
Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
|
||||
return (device_ == nullptr) ? ctx.HostCPU() : device_;
|
||||
TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
|
||||
const EagerContext& ctx) const {
|
||||
if (VariantDeviceIsCustom(device_)) {
|
||||
return device_;
|
||||
} else {
|
||||
Device* d = absl::get<Device*>(device_);
|
||||
return (d == nullptr) ? ctx.HostCPU() : d;
|
||||
}
|
||||
}
|
||||
|
||||
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
|
||||
@ -413,7 +431,7 @@ Status TensorHandle::NumElements(int64* num_elements) const {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
Status TensorHandle::RemoteAddress(Device* d, int64* op_id,
|
||||
int32* output_num) const {
|
||||
if (d != device_) {
|
||||
if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
|
||||
tf_shared_lock l(mu_);
|
||||
auto mirror = remote_mirrors_.find(d);
|
||||
if (mirror != remote_mirrors_.end()) {
|
||||
@ -517,7 +535,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
|
||||
tensorflow::Device* d) {
|
||||
DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d;
|
||||
|
||||
if (d != device_) {
|
||||
if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
|
||||
mutex_lock l(mu_);
|
||||
if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
|
||||
return errors::Internal(
|
||||
@ -593,7 +611,7 @@ void TensorHandle::Poison(Status status) {
|
||||
Status TensorHandle::CopyToDevice(const EagerContext& ctx,
|
||||
tensorflow::Device* dstd,
|
||||
tensorflow::Tensor* output) {
|
||||
tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
|
||||
tensorflow::Device* srcd = absl::get<Device*>(DeviceOrHostCPU(ctx));
|
||||
const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
|
||||
const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
|
||||
bool is_same_device =
|
||||
@ -655,6 +673,20 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
|
||||
return status;
|
||||
}
|
||||
|
||||
bool VariantDeviceIsCustom(
|
||||
absl::variant<Device*, CustomDevice*> variant_device) {
|
||||
return variant_device.index() != 0;
|
||||
}
|
||||
|
||||
string VariantDeviceDebugString(
|
||||
absl::variant<Device*, CustomDevice*> variant_device) {
|
||||
if (VariantDeviceIsCustom(variant_device)) {
|
||||
return absl::get<CustomDevice*>(variant_device)->name();
|
||||
} else {
|
||||
return absl::get<Device*>(variant_device)->DebugString();
|
||||
}
|
||||
}
|
||||
|
||||
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
|
||||
if (ctx == nullptr) {
|
||||
return nullptr;
|
||||
@ -671,10 +703,14 @@ string TensorHandle::DebugString() const {
|
||||
DVLOG(1) << "Calling TensorHandle::DebugString() on " << this;
|
||||
|
||||
string out;
|
||||
strings::StrAppend(&out, "Device: ", device_ ? device_->DebugString() : "[]");
|
||||
// Consider supporting non-CPU tensors (when device_ is non-NULL) if needed.
|
||||
string device_debug = VariantDeviceDebugString(device_);
|
||||
strings::StrAppend(&out, "Device: ", device_debug);
|
||||
bool is_cpu =
|
||||
!VariantDeviceIsCustom(device_) && absl::get<Device*>(device_) != nullptr;
|
||||
// Consider supporting non-CPU tensors and CPU tensors with a device_ set to
|
||||
// non-NULL if needed.
|
||||
strings::StrAppend(&out, ", Tensor: ",
|
||||
device_ ? "?" : tensor_handle_data_->DebugString(), "\n");
|
||||
is_cpu ? tensor_handle_data_->DebugString() : "?", "\n");
|
||||
return out;
|
||||
}
|
||||
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/platform.h"
|
||||
// clang-format on
|
||||
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
@ -43,6 +44,7 @@ limitations under the License.
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
@ -60,15 +62,20 @@ namespace tensorflow {
|
||||
// of the TFE_TensorHandle struct and the python EagerTensor class
|
||||
// (unrelated to python TensorHandle).
|
||||
class TensorHandle : public core::RefCounted {
|
||||
// Custom devices do many of the same things as physical Devices, but have a
|
||||
// much more restricted interface. We pass around ambiguous pointers since
|
||||
// TensorHandles may be placed either on custom or physical devices.
|
||||
using VariantDevice = absl::variant<Device*, CustomDevice*>;
|
||||
|
||||
// TensorHandle for dtype != DT_RESOURCE
|
||||
TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
|
||||
Device* d, Device* op_device, EagerContext* ctx);
|
||||
VariantDevice d, Device* op_device, EagerContext* ctx);
|
||||
// TensorHandle for dtype == DT_RESOURCE
|
||||
TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
|
||||
const ResourceHandle& resource_handle, Device* d,
|
||||
const ResourceHandle& resource_handle, VariantDevice d,
|
||||
Device* op_device, EagerContext* ctx);
|
||||
TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
|
||||
Device* d, Device* op_device, Device* resource_device,
|
||||
VariantDevice d, Device* op_device, Device* resource_device,
|
||||
DataType dtype, EagerContext* ctx);
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
@ -82,9 +89,9 @@ class TensorHandle : public core::RefCounted {
|
||||
// TensorHandle with no assigned device
|
||||
static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
|
||||
// TensorHandle with device == op_device
|
||||
static Status CreateLocalHandle(const class Tensor& t, Device* d,
|
||||
static Status CreateLocalHandle(const class Tensor& t, VariantDevice d,
|
||||
EagerContext* ctx, TensorHandle** h);
|
||||
static Status CreateLocalHandle(const class Tensor& t, Device* d,
|
||||
static Status CreateLocalHandle(const class Tensor& t, VariantDevice d,
|
||||
Device* op_device, EagerContext* ctx,
|
||||
TensorHandle** h);
|
||||
static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
|
||||
@ -117,11 +124,11 @@ class TensorHandle : public core::RefCounted {
|
||||
|
||||
Status TensorValue(tensorflow::TensorValue* t);
|
||||
|
||||
Device* device() const { return device_; }
|
||||
VariantDevice device() const { return device_; }
|
||||
Device* op_device() const { return op_device_; }
|
||||
Device* resource_device() const { return resource_device_; }
|
||||
|
||||
Device* DeviceOrHostCPU(const EagerContext& ctx) const;
|
||||
VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
|
||||
|
||||
Status Shape(tensorflow::TensorShape* shape);
|
||||
Status NumDims(int* num_dims) const;
|
||||
@ -188,8 +195,10 @@ class TensorHandle : public core::RefCounted {
|
||||
|
||||
// TODO(b/136608821): Move away from nullptr
|
||||
bool OnHostCPU() const {
|
||||
return device_ == nullptr ||
|
||||
(ctx_ != nullptr && ctx_->HostCPU() == device_);
|
||||
return (
|
||||
device_.index() == 0 &&
|
||||
(absl::get<Device*>(device_) == nullptr ||
|
||||
(ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_))));
|
||||
}
|
||||
|
||||
bool IsRemote() const { return is_remote_; }
|
||||
@ -216,7 +225,7 @@ class TensorHandle : public core::RefCounted {
|
||||
// done and the handle is "ready".
|
||||
Status WaitReady(const char* caller) const;
|
||||
|
||||
// TODO(b/136608821): device_ == nullptr iff Host CPU:0
|
||||
// TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
|
||||
// This was expedient, but perhaps worth revisiting ('device_' should always
|
||||
// be a valid pointer?)
|
||||
// This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
|
||||
@ -224,7 +233,7 @@ class TensorHandle : public core::RefCounted {
|
||||
//
|
||||
// TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
|
||||
// TFE_TensorHandle does not outlive the TFE_Context from which it came?
|
||||
tensorflow::Device* const device_;
|
||||
VariantDevice const device_;
|
||||
|
||||
// Device in which the op producing this tensor was executed. Equals to
|
||||
// device_ for constant tensors.
|
||||
@ -286,6 +295,12 @@ class TensorHandle : public core::RefCounted {
|
||||
PartialTensorShape inference_shape_;
|
||||
};
|
||||
|
||||
// Checks whether a VariantDevice contains a custom device.
|
||||
bool VariantDeviceIsCustom(absl::variant<Device*, CustomDevice*> device);
|
||||
|
||||
// Wraps device->DebugString() or CustomDevice->name().
|
||||
string VariantDeviceDebugString(absl::variant<Device*, CustomDevice*> device);
|
||||
|
||||
// Returns the device backing the resource. Else, returns nullptr.
|
||||
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);
|
||||
|
||||
|
@ -685,7 +685,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
|
||||
context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
|
||||
TF_ASSERT_OK(tensor_handle->Tensor(&t));
|
||||
|
||||
Device* device = tensor_handle->device();
|
||||
Device* device = absl::get<Device*>(tensor_handle->device());
|
||||
EXPECT_EQ(device, nullptr);
|
||||
|
||||
auto actual = t->flat<float>();
|
||||
|
@ -77,7 +77,7 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
|
||||
src_(src),
|
||||
ctx_(ctx),
|
||||
executor_(executor),
|
||||
send_device_(src->DeviceOrHostCPU(*ctx)),
|
||||
send_device_(absl::get<Device*>(src->DeviceOrHostCPU(*ctx))),
|
||||
recv_device_(recv_device),
|
||||
wire_id_(GetUniqueWireID()),
|
||||
recv_op_id_(recv_op_id),
|
||||
@ -145,8 +145,8 @@ void RemoteCopyNode::StartSend() {
|
||||
request.set_context_id(ctx_->GetContextId());
|
||||
auto* remote_op = request.add_queue()->mutable_operation();
|
||||
status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
|
||||
src_, remote_op->add_inputs(), src_->device(),
|
||||
src_->DeviceOrHostCPU(*ctx_)->name());
|
||||
src_, remote_op->add_inputs(), absl::get<Device*>(src_->device()),
|
||||
absl::get<Device*>(src_->DeviceOrHostCPU(*ctx_))->name());
|
||||
if (!status.ok()) {
|
||||
captured_state_->SetSendStatus(status);
|
||||
return;
|
||||
|
@ -75,8 +75,15 @@ Status RemoteMgr::GetMirroredResourceShape(
|
||||
|
||||
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
|
||||
int64* op_id, int32* output_num) {
|
||||
// TODO(allenl): Consider supporting remote handles on custom devices.
|
||||
absl::variant<Device*, CustomDevice*> device = handle->device();
|
||||
if (VariantDeviceIsCustom(device)) {
|
||||
return errors::Unimplemented(
|
||||
"Custom devices and remote execution are currently not supported "
|
||||
"together.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
handle->RemoteAddress(handle->device(), op_id, output_num));
|
||||
handle->RemoteAddress(absl::get<Device*>(device), op_id, output_num));
|
||||
tensorflow::TensorHandle* h;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
|
||||
|
@ -1244,6 +1244,11 @@ StringPiece Tensor::tensor_data() const {
|
||||
return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
|
||||
}
|
||||
|
||||
void* Tensor::data() const {
|
||||
if (buf_ == nullptr) return nullptr; // Don't die for empty tensors
|
||||
return static_cast<void*>(buf_->data());
|
||||
}
|
||||
|
||||
bool Tensor::SharesBufferWith(const Tensor& b) const {
|
||||
return buf_ != nullptr && b.buf_ != nullptr &&
|
||||
buf_->root_buffer() == b.buf_->root_buffer();
|
||||
|
@ -601,6 +601,7 @@ class Tensor {
|
||||
///
|
||||
/// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
|
||||
StringPiece tensor_data() const;
|
||||
void* data() const;
|
||||
|
||||
/// Copy the other tensor into this tensor, reshape it and reinterpret the
|
||||
/// buffer's datatype. If Status::OK() is returned, the two tensors now share
|
||||
|
@ -21,7 +21,7 @@ from __future__ import print_function
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
_VALID_DEVICE_TYPES = {"CPU", "GPU", "TPU"}
|
||||
_VALID_DEVICE_TYPES = frozenset({"CPU", "GPU", "TPU", "CUSTOM"})
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
@ -95,7 +95,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
|
||||
if (call->eager) {
|
||||
TensorHandle* handle;
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
|
||||
t, ctx->CanonicalDevice(device), ctx, &handle));
|
||||
t, ctx->CanonicalDevice(device), nullptr, ctx, &handle));
|
||||
arg = EagerTensorFromHandle(new TFE_TensorHandle{
|
||||
std::make_unique<tensorflow::TensorHandleInterface>(handle)});
|
||||
if (arg == nullptr) {
|
||||
@ -149,7 +149,11 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
|
||||
auto handle = down_cast<tensorflow::TensorHandleInterface*>(
|
||||
EagerTensor_Handle(eager_tensor)->handle.get())
|
||||
->Handle();
|
||||
Device* actual_device = handle->device();
|
||||
if (VariantDeviceIsCustom(handle->device())) {
|
||||
return errors::Unimplemented(
|
||||
"Custom devices are currently not supported with PyFuncs.");
|
||||
}
|
||||
Device* actual_device = absl::get<Device*>(handle->device());
|
||||
TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
|
||||
// actual_device may be nullptr, which implies local CPU.
|
||||
if (expected_device == actual_device) return Status::OK();
|
||||
|
@ -294,7 +294,8 @@ struct Converter {
|
||||
}
|
||||
tensorflow::TensorHandle* handle = nullptr;
|
||||
auto status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
result, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
|
||||
result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
|
||||
ctx->context, &handle);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
@ -609,7 +610,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
|
||||
auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
|
||||
if (cppstatus.ok()) {
|
||||
cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
|
||||
t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context,
|
||||
&handle);
|
||||
}
|
||||
if (!cppstatus.ok()) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
@ -806,7 +808,8 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
|
||||
Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
|
||||
TensorShape(state.inferred_shape));
|
||||
status = tensorflow::TensorHandle::CreateLocalHandle(
|
||||
tensor, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &h);
|
||||
tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
|
||||
ctx->context, &h);
|
||||
if (!status.ok()) {
|
||||
PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
|
||||
return nullptr;
|
||||
|
Loading…
Reference in New Issue
Block a user