Stop holding custom devices in TensorHandles

My hope is to avoid changing the custom device API much, just to clean up the implementation.

Previously, TensorHandles held Tensors which in turn held the void* custom device tensor handle data. This required many special cases, mostly because the TensorHandle's device wasn't a physical device.

EagerOperations still accept custom device TensorHandles, but now deal with them before execution, either by copying them off the custom device or by executing the whole operation on the custom device (a sketch of the flow follows). The rest of the runtime can then assume TensorHandles are on physical devices, which also gives custom device tensor handles some freedom to evolve.
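
As a rough sketch of that flow, here is a toy model in plain C++. This is not the TensorFlow code from this change; the types and functions are illustrative stand-ins for EagerOperation::Execute, CustomDevice::Execute, and CustomDevice::CopyTensorFromDevice:

#include <string>
#include <variant>
#include <vector>

struct PhysicalDevice { std::string name; };
struct CustomDevice;

// A handle lives either on a physical device or on a custom device.
struct Handle {
  std::variant<PhysicalDevice*, CustomDevice*> device;
};

struct CustomDevice {
  std::string name;
  // Stand-in for CustomDevice::CopyTensorFromDevice in this change.
  Handle CopyOff(PhysicalDevice* target) { return Handle{target}; }
  // Stand-in for CustomDevice::Execute.
  void Execute(std::vector<Handle>& inputs) {}
};

void ExecuteOp(std::vector<Handle>& inputs, CustomDevice* placed_on,
               PhysicalDevice* host_cpu) {
  if (placed_on != nullptr) {
    // Case 1: the op is placed on a custom device; hand it the whole op.
    placed_on->Execute(inputs);
    return;
  }
  // Case 2: the op runs on a physical device; copy custom device inputs
  // off first, as CopyOffCustomDeviceInputs does in the diff below.
  for (Handle& h : inputs) {
    if (auto* custom = std::get_if<CustomDevice*>(&h.device)) {
      h = (*custom)->CopyOff(host_cpu);
    }
  }
  // From here on, the runtime only sees handles on physical devices.
}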

This rolls cl/350684489 forward with a fix for creating packed tensors from custom device tensors. Maintaining the previous (accidental) functionality requires one new custom device method, sketched below.
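
The new method is the optional `pack` member of TFE_CustomDevice, used by TFE_CreatePackedTensorHandle when one of the handles is on a custom device (both appear in the diffs below). As a sketch, a device that does not support packing can leave `pack` null, since TFE_RegisterCustomDevice now fills in a default, or supply an equivalent callback itself. `MyDevicePack` and the include path are illustrative assumptions, not part of this change:

#include "tensorflow/c/eager/c_api_experimental.h"  // assumed home of TFE_CustomDevice

// Mirrors DefaultCustomDevicePack from this change: packing simply reports
// TF_UNIMPLEMENTED for devices that do not support it.
TFE_TensorHandle* MyDevicePack(TFE_Context* context, TFE_TensorHandle** handles,
                               int num_handles, TF_Status* status,
                               void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED,
               "This custom device does not support packing tensors.");
  return nullptr;
}

// When filling in the struct before TFE_RegisterCustomDevice:
//   TFE_CustomDevice device;         // version is now 4
//   device.pack = &MyDevicePack;     // or device.pack = nullptr; a null pack
//                                    // gets the same default installed.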

PiperOrigin-RevId: 351406348
Change-Id: I9c2ffd40a687b06434fab40e2db9e90129b9f2b7
Author: Allen Lavoie 2021-01-12 11:03:38 -08:00 (committed by TensorFlower Gardener)
parent 0e43a67584
commit 253111e23b
27 changed files with 524 additions and 432 deletions


@@ -52,7 +52,6 @@ tf_cuda_library(
":immediate_execution_operation",
":immediate_execution_tensor_handle",
":immediate_execution_distributed_manager",
":abstract_tensor_handle",
":tfe_context_internal",
":tfe_cancellation_manager_internal",
":tfe_executor_internal",
@@ -73,6 +72,7 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:custom_device",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
@@ -86,6 +86,7 @@ tf_cuda_library(
],
}) + [
"@com_google_absl//absl/memory",
":abstract_tensor_handle",
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/distributed_runtime/eager:remote_mgr",
"//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime",
@@ -480,12 +481,17 @@ cc_library(
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/platform:status",
],
deps = select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/platform:status",
],
}),
)
cc_library(


@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -363,13 +364,21 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
return t->data();
tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
tensorflow::unwrap(h);
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
unwrapped_handle)
->DevicePointer();
}
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
status->status = tensorflow::errors::InvalidArgument("Invalid handle");
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::TensorHandleFromInterface(unwrapped_handle);
if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
status->status = tensorflow::errors::InvalidArgument(
@@ -377,7 +386,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
handle->TypeString(), " tensor handle.");
return nullptr;
}
tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
tensorflow::Device* device(handle->device());
if (device != nullptr) {
status->status = device->Sync();
if (!status->status.ok()) {
@@ -393,6 +402,125 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
static_cast<const void*>(tensor->tensor_data().data()));
}
namespace tensorflow {
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
public:
CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
string name)
: context_(context), device_(device), info_(info), name_(name) {}
~CustomDeviceAPI() override { device_.delete_device(info_); }
const string& name() override { return name_; }
tensorflow::Status CopyTensorToDevice(
ImmediateExecutionTensorHandle* handle,
ImmediateExecutionTensorHandle** result) override {
handle->Ref();
TF_Status status;
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
context_, tensorflow::wrap(handle), &status, info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::unwrap(result_handle);
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
tensorflow::Status CopyTensorFromDevice(
ImmediateExecutionTensorHandle* handle,
const tensorflow::string& target_device_name,
ImmediateExecutionTensorHandle** result) override {
TF_Status status;
handle->Ref();
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::unwrap(result_handle);
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
tensorflow::Status Execute(const ImmediateExecutionOperation* op,
ImmediateExecutionTensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::unwrap(outputs[i]);
retvals[i]->Ref();
TFE_DeleteTensorHandle(outputs[i]);
}
}
return status.status;
}
tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
ImmediateExecutionTensorHandle** result) override {
TF_Status status;
*result = tensorflow::unwrap(device_.pack(context_,
tensorflow::wrap(handles.data()),
handles.size(), &status, info_));
return status.status;
}
private:
TFE_Context* context_;
TFE_CustomDevice device_;
void* info_;
string name_;
};
// An adapter which wraps the shape/data produced by C custom devices and uses
// it to implement custom device methods.
class CAPICustomDeviceTensorHandle
: public tensorflow::CustomDeviceTensorHandle {
public:
CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
tensorflow::CustomDevice* device,
tensorflow::DataType dtype, void* data,
size_t len, std::vector<tensorflow::int64> shape,
void (*deallocator)(void* data, size_t len,
void* arg),
void* deallocator_arg)
: tensorflow::CustomDeviceTensorHandle(context, device, dtype),
data_(data),
len_(len),
shape_(shape),
deallocator_(deallocator),
deallocator_arg_(deallocator_arg) {}
~CAPICustomDeviceTensorHandle() override {
deallocator_(data_, len_, deallocator_arg_);
}
void* DevicePointer() const override { return data_; }
Status NumDims(int* num_dims) const override {
*num_dims = shape_.size();
return Status::OK();
}
Status Dim(int dim_index, int64* dim) const override {
*dim = shape_[dim_index];
return Status::OK();
}
private:
void* const data_;
size_t len_;
std::vector<tensorflow::int64> shape_;
void (*const deallocator_)(void* data, size_t len, void* arg);
void* const deallocator_arg_;
};
} // namespace
} // namespace tensorflow
TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
TFE_Context* ctx, const char* device_name, TF_DataType dtype,
const int64_t* dims, int num_dims, void* data, size_t len,
@@ -417,6 +545,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
for (int i = 0; i < num_dims; ++i) {
dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
}
if (custom_device != nullptr) {
return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
context, custom_device,
*reinterpret_cast<tensorflow::DataType*>(&dtype), data, len, dimvec,
deallocator, deallocator_arg));
}
// TODO(apassos) do we need to wrap the deallocator here to make sure to sync
// the device?
@@ -427,13 +561,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
if (custom_device == nullptr) {
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context));
} else {
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context));
}
return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context));
}
// This function will block till the operation that produces `h` has
@@ -961,74 +1090,14 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
} // namespace tensorflow
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
public:
CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
string name)
: context_(context), device_(device), info_(info), name_(name) {}
~CustomDeviceAPI() override { device_.delete_device(info_); }
const string& name() override { return name_; }
tensorflow::Status CopyTensorToDevice(
tensorflow::TensorHandle* handle,
tensorflow::TensorHandle** result) override {
handle->Ref();
TF_Status status;
TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
context_, tensorflow::wrap(handle), &status, info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
tensorflow::Status CopyTensorFromDevice(
tensorflow::TensorHandle* handle,
const tensorflow::string& target_device_name,
tensorflow::TensorHandle** result) override {
TF_Status status;
handle->Ref();
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
info_);
handle->Release();
if (!status.status.ok()) return status.status;
*result = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(result_handle));
(*result)->Ref();
TFE_DeleteTensorHandle(result_handle);
return status.status;
}
tensorflow::Status Execute(const tensorflow::EagerOperation* op,
tensorflow::TensorHandle** retvals,
int* num_retvals) override {
std::vector<TFE_TensorHandle*> outputs(*num_retvals);
TF_Status status;
device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::TensorHandleFromInterface(
tensorflow::unwrap(outputs[i]));
retvals[i]->Ref();
TFE_DeleteTensorHandle(outputs[i]);
}
}
return status.status;
}
private:
TFE_Context* context_;
TFE_CustomDevice device_;
void* info_;
string name_;
};
TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
TFE_TensorHandle** handles,
int num_handles, TF_Status* status,
void* device_info) {
TF_SetStatus(status, TF_UNIMPLEMENTED,
"This custom device does not support packing tensors.");
return nullptr;
}
} // namespace
extern "C" {
@@ -1036,8 +1105,12 @@ extern "C" {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
const char* device_name, void* device_info,
TF_Status* status) {
auto custom_device =
std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
// Fill in default values for optional functionality.
if (device.pack == nullptr) {
device.pack = &DefaultCustomDevicePack;
}
auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
ctx, device, device_info, device_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =


@@ -613,8 +613,23 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
std::vector<tensorflow::TensorHandle*> tensor_handles;
tensor_handles.reserve(*num_handles);
for (int i = 0; i < *num_handles; ++i) {
tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
tensorflow::unwrap(handles[i]);
if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
// One of the inputs we're trying to pack is on a custom device. We'll let
// the first custom device we see handle all of the packing.
auto* custom_device_handle =
tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
unwrapped_handle);
tensorflow::ImmediateExecutionTensorHandle* result;
status->status = custom_device_handle->device()->Pack(
absl::Span<tensorflow::ImmediateExecutionTensorHandle*>(
tensorflow::unwrap(handles), *num_handles),
&result);
return tensorflow::wrap(result);
}
tensor_handles.push_back(
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
tensorflow::TensorHandleFromInterface(unwrapped_handle));
}
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));


@@ -435,16 +435,16 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
// to have a non-string representation of devices (TF_Device) extracted from
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
#define TFE_CUSTOM_DEVICE_VERSION 3
#define TFE_CUSTOM_DEVICE_VERSION 4
// Struct to be filled in
// Struct to be filled in. Functions are required except where indicated.
typedef struct TFE_CustomDevice {
int version = TFE_CUSTOM_DEVICE_VERSION;
// Method to copy a tensor to the custom device.
TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
TFE_TensorHandle* tensor,
TF_Status* status,
void* device_info) = nullptr;
void* device_info);
// Method to copy a tensor from the custom device to a target device.
TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
@@ -468,6 +468,16 @@ typedef struct TFE_CustomDevice {
// Method to delete a device.
void (*delete_device)(void* device_info);
// Implements TFE_CreatePackedTensorHandle when one of `handles` is on this
// custom device.
//
// Many devices will want to simply return an "unimplemented" status
// here. This is the default behavior if `pack` is null when passed to
// TFE_RegisterCustomDevice.
TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles,
int num_handles, TF_Status* s,
void* device_info) = nullptr;
} TFE_CustomDevice;
// Registers a custom device for use with eager execution.


@@ -424,7 +424,7 @@ void TensorHandleSilentCopy(bool async,
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
auto gpu_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
auto gpu_device = gpu_arg->device();
ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);


@@ -40,14 +40,6 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
errors::InvalidArgument("underlying_devices should not be empty."));
return nullptr;
}
std::set<string> unique_devices;
for (const string& device : underlying_devices) {
if (!unique_devices.insert(device).second) {
status->Update(errors::InvalidArgument(
"Got a duplicated device in underlying_devices: ", device));
return nullptr;
}
}
DeviceNameUtils::ParsedName parsed_name;
if (!DeviceNameUtils::ParseFullName(underlying_devices.at(0), &parsed_name)) {
status->Update(tensorflow::errors::InvalidArgument(


@@ -50,21 +50,6 @@ TEST(CompositeDeviceTest, Basic) {
EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
}
{
Status status;
underlying_devices.push_back(
"/job:localhost/replica:0/task:0/device:CPU:0");
std::unique_ptr<CompositeDevice> composite_device =
CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
parsed_host_name, &status);
EXPECT_EQ(composite_device, nullptr);
EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
EXPECT_TRUE(
absl::StrContains(status.error_message(), "Got a duplicated device"))
<< status.ToString();
underlying_devices.pop_back();
}
{
Status status;
underlying_devices.push_back(


@@ -124,6 +124,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/core/lib/core:status",
],
}),
@@ -228,6 +229,7 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":attr_builder",
":custom_device",
":context",
":eager_executor",
":kernel_and_device",
@@ -631,6 +633,7 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":context",
":custom_device",
":attr_builder",
":eager_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",


@@ -24,11 +24,7 @@ limitations under the License.
namespace {
bool IsCPU(tensorflow::VariantDevice variant) {
if (VariantDeviceIsCustom(variant)) {
return false;
}
tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
bool IsCPU(tensorflow::Device* d) {
return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}
@@ -43,20 +39,6 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
if (!status->ok()) {
return nullptr;
}
if (VariantDeviceIsCustom(device())) {
auto* custom_device = absl::get<CustomDevice*>(device());
TensorHandle* copy;
*status = custom_device->CopyTensorFromDevice(this, ctx_->HostCPU()->name(),
&copy);
if (status->ok()) {
auto result = copy->Resolve(status);
copy->Unref();
return result;
} else {
return nullptr;
}
}
if (Type() == REMOTE) {
const tensorflow::Tensor* t = nullptr;
TensorHandle* h_cpu = nullptr;
@@ -124,14 +106,13 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
ImmediateExecutionTensorHandle* handle, const char* device_name,
Status* status) {
TensorHandle* input = TensorHandleFromInterface(handle);
TensorHandle* result = nullptr;
ImmediateExecutionTensorHandle* result = nullptr;
Device* device;
*status = this->FindDeviceFromName(device_name, &device);
if (!status->ok()) {
tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(device_name, &dev)) {
*status = dev->CopyTensorToDevice(input, &result);
*status = dev->CopyTensorToDevice(handle, &result);
if (status->ok()) {
return result;
}
@@ -142,13 +123,13 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
return nullptr;
}
// Handle tensor handles currently in custom devices
const char* handle_device_name = input->DeviceName(status);
const char* handle_device_name = handle->DeviceName(status);
if (!status->ok()) {
return nullptr;
}
tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(handle_device_name, &dev)) {
*status = dev->CopyTensorFromDevice(input, device_name, &result);
*status = dev->CopyTensorFromDevice(handle, device_name, &result);
if (status->ok()) {
return result;
}
@@ -156,8 +137,10 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
}
// Handle regular case.
TensorHandle* input = TensorHandleFromInterface(handle);
*status =
EagerCopyToDevice(input, this, &this->Executor(), device, false, &result);
EagerCopyToDevice(input, this, &this->Executor(), device, false,
reinterpret_cast<tensorflow::TensorHandle**>(&result));
if (status->ok()) {
return result;
}
@@ -213,16 +196,38 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) {
// eager_operation.cc we can avoid a circular dependency between them.
Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
for (int i = 0; i < Inputs().size(); ++i) {
TF_RETURN_IF_ERROR(Inputs()[i]->WaitUnknownDevice());
for (ImmediateExecutionTensorHandle* handle : inputs_) {
if (TensorHandle::classof(handle)) {
TF_RETURN_IF_ERROR(down_cast<TensorHandle*>(handle)->WaitUnknownDevice());
}
}
// Decide to either run the operation on a custom device or copy off all of
// the custom device inputs.
VariantDevice maybe_custom_device = Device();
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) ||
!inputs_are_tensor_handles_) {
// If the op wasn't placed on a custom device explicitly and there are no
// non-TensorHandle inputs, the op will definitely be placed on a physical
// device. Otherwise we need to check the inputs one by one.
TF_RETURN_IF_ERROR(
eager::MaybePinToCustomDevice(&maybe_custom_device, *this));
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) {
ImmediateExecutionTensorHandle** retval_array =
reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data());
return absl::get<CustomDevice*>(maybe_custom_device)
->Execute(this, retval_array, num_retvals);
} else {
TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs());
}
}
// Run eager placement logic.
VariantDevice device;
TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this));
if (device == kVariantDeviceNull) {
class Device* device = absl::get<class Device*>(maybe_custom_device);
if (device == nullptr) {
TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
}
if (device == kVariantDeviceNull && ctx_.PinSmallOpsToCPU()) {
if (device == nullptr && ctx_.PinSmallOpsToCPU()) {
bool pin_to_cpu;
TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu(
&pin_to_cpu, Name(), GetInputs(), ctx_.HostCPU()->name()));
@@ -231,16 +236,13 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
}
}
tensorflow::TensorHandle** retval_array =
reinterpret_cast<tensorflow::TensorHandle**>(retvals.data());
if (VariantDeviceIsCustom(device)) {
return absl::get<CustomDevice*>(device)->Execute(this, retval_array,
num_retvals);
}
if (device != kVariantDeviceNull) {
if (device != nullptr) {
SetDevice(device);
}
// At this point all inputs and outputs are TensorHandles associated with
// physical devices.
tensorflow::TensorHandle** retval_array =
reinterpret_cast<tensorflow::TensorHandle**>(retvals.data());
return EagerExecute(this, retval_array, num_retvals);
}


@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -26,6 +27,7 @@ namespace tensorflow {
class TensorHandle;
class EagerOperation;
class CustomDeviceTensorHandle;
// Custom devices intercept the execution of operations (the `Execute` method),
// typically implemented with one or more of the custom device's own executions.
@@ -33,15 +35,22 @@ class CustomDevice {
public:
virtual ~CustomDevice() {}
virtual const string& name() = 0;
virtual Status CopyTensorToDevice(TensorHandle* tensor,
TensorHandle** result) = 0;
virtual Status CopyTensorToDevice(
ImmediateExecutionTensorHandle* tensor,
ImmediateExecutionTensorHandle** result) = 0;
virtual Status CopyTensorFromDevice(TensorHandle* tensor,
const string& target_device_name,
TensorHandle** result) = 0;
virtual Status CopyTensorFromDevice(
ImmediateExecutionTensorHandle* tensor, const string& target_device_name,
ImmediateExecutionTensorHandle** result) = 0;
virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
virtual Status Execute(const ImmediateExecutionOperation* op,
ImmediateExecutionTensorHandle** retvals,
int* num_retvals) = 0;
// Creates a packed TensorHandle from a group of custom device TensorHandles,
// one of which is on this custom device.
virtual Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
ImmediateExecutionTensorHandle** result) = 0;
};
// Custom devices do many of the same things as physical Devices, but have a
@@ -49,6 +58,10 @@ class CustomDevice {
// operations may be placed either on custom or physical devices.
using VariantDevice = absl::variant<Device*, CustomDevice*>;
// Indicates either HostCPU or an unset physical device. We never set a null
// CustomDevice*.
const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);
// A tensor handle produced by a custom device. Generally they can only be
// consumed by executing an operation on the same custom device that produced it
// originally, or by attempting to copy the handle off the custom device.
@ -65,6 +78,10 @@ class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle {
device_(device),
dtype_(dtype) {}
// TODO(allenl): Should this be a generic method of
// ImmediateExecutionTensorHandle to support TFE_TensorHandleDevicePointer?
virtual void* DevicePointer() const = 0;
tensorflow::DataType DataType() const override { return dtype_; }
Status Shape(PartialTensorShape* shape) const override;
Status NumElements(int64* num_elements) const override;


@@ -28,24 +28,31 @@ class TestCustomDevice : public CustomDevice {
public:
explicit TestCustomDevice(std::string name) : name_(name) {}
const std::string& name() override { return name_; }
Status CopyTensorToDevice(TensorHandle* tensor,
TensorHandle** result) override {
Status CopyTensorToDevice(ImmediateExecutionTensorHandle* tensor,
ImmediateExecutionTensorHandle** result) override {
tensor->Ref();
*result = tensor;
return Status::OK();
}
Status CopyTensorFromDevice(TensorHandle* tensor,
const std::string& target_device_name,
TensorHandle** result) override {
Status CopyTensorFromDevice(
ImmediateExecutionTensorHandle* tensor,
const std::string& target_device_name,
ImmediateExecutionTensorHandle** result) override {
tensor->Ref();
*result = tensor;
return Status::OK();
}
Status Execute(const EagerOperation* op, TensorHandle** retvals,
Status Execute(const ImmediateExecutionOperation* op,
ImmediateExecutionTensorHandle** retvals,
int* num_retvals) override {
return errors::Unimplemented("Not implemented");
}
Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
ImmediateExecutionTensorHandle** result) override {
return errors::Unimplemented("Packing is not implemented");
}
private:
std::string name_;
};
@@ -57,6 +64,7 @@ class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
tensorflow::DataType dtype)
: CustomDeviceTensorHandle(context, device, dtype) {}
void* DevicePointer() const override { return nullptr; }
Status NumDims(int* num_dims) const override {
*num_dims = 1;
return Status::OK();


@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
@@ -31,10 +32,11 @@ namespace tensorflow {
// Clear(), and then Reset(...) with the same arguments that would have
// been provided to the constructor.
void EagerOperation::Clear() {
for (TensorHandle* h : inputs_) {
for (ImmediateExecutionTensorHandle* h : inputs_) {
h->Unref();
}
inputs_.clear();
inputs_are_tensor_handles_ = true;
ClearInferenceState();
}
@@ -263,7 +265,12 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
}
Status EagerOperation::AddInput(AbstractTensorHandle* input) {
TensorHandle* h = TensorHandleFromInterface(input);
ImmediateExecutionTensorHandle* h =
down_cast<ImmediateExecutionTensorHandle*>(input);
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (CustomDeviceTensorHandle::classof(h)) {
inputs_are_tensor_handles_ = false;
}
AddTensorHandle(h);
return MaybeInferSingleInputAttrs(h);
}
@@ -271,7 +278,13 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
Status EagerOperation::AddInputList(
absl::Span<AbstractTensorHandle* const> inputs) {
for (auto& input : inputs) {
TensorHandle* h = TensorHandleFromInterface(input);
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(input)) {
inputs_are_tensor_handles_ = false;
}
ImmediateExecutionTensorHandle* h =
down_cast<ImmediateExecutionTensorHandle*>(input);
AddTensorHandle(h);
}
return InferInputListAttrs(inputs.size());
@@ -317,7 +330,8 @@ Status EagerOperation::Reset(
return SetDeviceName(device_name);
}
Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
Status EagerOperation::MaybeInferSingleInputAttrs(
ImmediateExecutionTensorHandle* handle) {
if (!op_def_) return Status::OK();
const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
@@ -334,7 +348,7 @@ Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
const std::string& type_attr = input_def.type_attr();
if (!type_attr.empty() &&
inference_attrs_.find(type_attr) == inference_attrs_.end()) {
MutableAttrs()->Set(type_attr, handle->dtype);
MutableAttrs()->Set(type_attr, handle->DataType());
inference_attrs_.insert(type_attr);
}
return Status::OK();
@@ -372,12 +386,13 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
if (!input_def.type_list_attr().empty()) {
std::vector<DataType> dtypes(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs_[start + i]->dtype;
dtypes[i] = inputs_[start + i]->DataType();
}
InferMixedTypeInputListAttrs(input_def, dtypes);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
InferSingleTypeInputListAttrs(input_def, inputs_[start]->DataType(),
num_inputs);
} else if (!input_def.number_attr().empty()) {
if (inference_attrs_.find(input_def.number_attr()) ==
inference_attrs_.end()) {
@@ -390,6 +405,28 @@
return Status::OK();
}
Status EagerOperation::TensorHandleInputs(
const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
*inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
&inputs_);
return Status::OK();
} else {
return errors::Internal("The operation unexpectedly had custom devices.");
}
}
Status EagerOperation::MutableTensorHandleInputs(
absl::InlinedVector<TensorHandle*, 4>** inputs) {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
*inputs =
reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
return Status::OK();
} else {
return errors::Internal("The operation unexpectedly had custom devices.");
}
}
Status EagerOperation::SetDeviceName(const char* c_name) {
string name(c_name != nullptr ? c_name : "");
if (name != last_set_device_name_) {
@@ -423,6 +460,16 @@ bool EagerOperation::IsLocal() const {
device_parsed_name_.task == host_cpu_name.task;
}
string VariantDeviceDebugString(VariantDevice device) {
if (device == kVariantDeviceNull) {
return "[]";
} else if (absl::holds_alternative<CustomDevice*>(device)) {
return absl::get<CustomDevice*>(device)->name();
} else {
return absl::get<Device*>(device)->DebugString();
}
}
string EagerOperation::DebugString() const {
string out;
VLOG(1) << "EagerOperation::DebugString() over " << this;
@@ -442,10 +489,36 @@ string EagerOperation::DebugString() const {
return out;
}
void EagerOperation::AddTensorHandle(TensorHandle* h) {
void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
h->Ref();
inputs_.push_back(h);
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
Status EagerOperation::CopyOffCustomDeviceInputs() {
if (absl::holds_alternative<CustomDevice*>(device_)) {
return errors::Internal(
"Trying to copy inputs to a custom device op off a custom device.");
}
for (int i = 0; i < inputs_.size(); ++i) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(inputs_[i])) {
CustomDeviceTensorHandle* previous =
down_cast<CustomDeviceTensorHandle*>(inputs_[i]);
class Device* target_device;
if (device_ == kVariantDeviceNull) {
target_device = ctx_.HostCPU();
} else {
target_device = absl::get<class Device*>(device_);
}
TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
previous, target_device->name(), &inputs_[i]));
previous->Unref();
}
}
inputs_are_tensor_handles_ = true;
return Status::OK();
}
} // namespace tensorflow


@@ -39,7 +39,7 @@ class EagerOperation : public ImmediateExecutionOperation {
explicit EagerOperation(tensorflow::EagerContext* ctx)
: ImmediateExecutionOperation(kEager), ctx_(*ctx) {}
~EagerOperation() override {
for (TensorHandle* h : inputs_) {
for (ImmediateExecutionTensorHandle* h : inputs_) {
h->Unref();
}
}
@@ -69,8 +69,9 @@ class EagerOperation : public ImmediateExecutionOperation {
void SetDevice(VariantDevice device) {
device_ = device;
device_name_ =
device == kVariantDeviceNull ? "" : VariantDeviceName(device);
device_name_ = absl::visit(
[](auto* device) { return device == nullptr ? "" : device->name(); },
device);
DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
// TODO(b/154133594): Due to intricacies of external logic, we can not
// set this do device_name_ as it would be natural, because we need the
@@ -141,10 +142,18 @@ class EagerOperation : public ImmediateExecutionOperation {
AttrBuilder* MutableAttrs() { return &attrs_; }
const AttrBuilder& Attrs() const { return attrs_; }
const absl::InlinedVector<TensorHandle*, 4>& Inputs() const {
// TensorHandleInputs and MutableTensorHandleInputs first check that all
// inputs are TensorHandles, i.e. that there are no custom device inputs. They
// return a bad status otherwise.
Status TensorHandleInputs(
const absl::InlinedVector<TensorHandle*, 4>** inputs) const;
Status MutableTensorHandleInputs(
absl::InlinedVector<TensorHandle*, 4>** inputs);
const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs()
const {
return inputs_;
}
absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }
void UpdateInput(int i, TensorHandle* h);
@@ -180,7 +189,7 @@ class EagerOperation : public ImmediateExecutionOperation {
}
private:
void AddTensorHandle(TensorHandle* h);
void AddTensorHandle(ImmediateExecutionTensorHandle* h);
const tensorflow::OpDef* GetOpDef(Status* status);
@@ -190,7 +199,7 @@ class EagerOperation : public ImmediateExecutionOperation {
inference_attrs_.clear_no_resize();
}
Status MaybeInferSingleInputAttrs(TensorHandle* handle);
Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle);
Status InferInputListAttrs(int num_inputs);
void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
@@ -198,11 +207,21 @@ class EagerOperation : public ImmediateExecutionOperation {
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
const std::vector<DataType>& dtypes);
// Replaces input tensors placed on custom devices with physical device
// equivalents. Used if an op is placed on a physical device but may have
// custom device inputs.
Status CopyOffCustomDeviceInputs();
tensorflow::EagerContext& ctx_;
const char* op_name_ = nullptr;
AttrBuilder attrs_;
const AttrTypeMap* attr_types_;
absl::InlinedVector<TensorHandle*, 4> inputs_;
// Toggled to indicate whether all inputs are known to be TensorHandles and
// not another type (e.g. custom device tensor handles). Explicitly set to
// false when custom device TensorHandles are added.
bool inputs_are_tensor_handles_ = true;
absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
// The last device name given to SetDeviceName.
// This is used to avoid having to re-process the same device in repeated
@@ -240,8 +259,8 @@ class EagerOperation : public ImmediateExecutionOperation {
};
inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
TensorHandle** slot = &inputs_[i];
TensorHandle* existing = *slot;
ImmediateExecutionTensorHandle** slot = &inputs_[i];
ImmediateExecutionTensorHandle* existing = *slot;
if (existing != h) {
h->Ref();
existing->Unref();


@@ -81,14 +81,6 @@ const string& DeviceNameOrUnspecified(Device* device) {
return (device == nullptr) ? *unspecified_string : device->name();
}
const string& DeviceNameOrUnspecified(VariantDevice device) {
if (VariantDeviceIsCustom(device)) {
return absl::get<CustomDevice*>(device)->name();
} else {
return DeviceNameOrUnspecified(absl::get<Device*>(device));
}
}
// Returns whether a kernel should be cached.
bool KernelCacheEnabled(const OpDef& op_def) {
if (data::DatasetOpKernel::IsDatasetOp(&op_def)) {
@@ -200,9 +192,10 @@ Status ValidateInputTypeAndPlacement(
const bool is_function = kernel->IsFunction();
if (n_inputs > 0) {
const DataType* input_types = &kernel->input_dtypes()[0];
TensorHandle* const* handles = &op->Inputs()[0];
const absl::InlinedVector<TensorHandle*, 4>* handles;
TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
for (int i = 0; i < n_inputs; ++i) {
TensorHandle* handle = handles[i];
TensorHandle* handle = (*handles)[i];
Device* expected_device = kernel->InputDevice(i);
if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
// Extract a handle on the op device from a packed input.
@@ -220,13 +213,7 @@
}
}
}
auto handle_device_variant = handle->DeviceOrHostCPU(*ctx);
if (VariantDeviceIsCustom(handle_device_variant)) {
return errors::Unimplemented(
"Custom devices and remote execution are not yet supported "
"together.");
}
Device* handle_device = absl::get<Device*>(handle_device_variant);
Device* handle_device = handle->DeviceOrHostCPU(*ctx);
const bool maybe_copy =
!is_function || handle->Type() != TensorHandle::REMOTE;
// If the input is already on the right device, then nothing to do.
@@ -280,14 +267,10 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
Device** result) {
if (TF_PREDICT_FALSE(VariantDeviceIsCustom(tensor_handle->device()))) {
return errors::Unimplemented(
"The kernel cache does not work with custom devices.");
}
Device* cpu_device = ctx.HostCPU();
string device_name;
if (tensor_handle->Type() != TensorHandle::LOCAL) {
Device* device = absl::get<Device*>(tensor_handle->device());
Device* device = tensor_handle->device();
device_name = device != nullptr ? device->name() : cpu_device->name();
*result = (device == nullptr ? cpu_device : device);
} else if (tensor_handle->dtype == DT_RESOURCE) {
@@ -304,7 +287,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
ctx.FindDeviceFromName(device_name.c_str(), &input_device));
*result = input_device;
} else {
Device* device = absl::get<Device*>(tensor_handle->device());
Device* device = tensor_handle->device();
const bool is_tpu = device != nullptr && device->device_type() == "TPU";
// int32 return values can be placed on TPUs.
const bool use_host_memory =
@@ -431,8 +414,10 @@ Status GetOrCreateKernelAndDevice(
profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
profiler::TraceMeLevel::kInfo);
input_dev_ptrs.reserve(op->Inputs().size());
for (int i = 0, end = op->Inputs().size(); i < end; i++) {
TensorHandle* input = op->Inputs()[i];
const absl::InlinedVector<TensorHandle*, 4>* inputs;
TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
for (int i = 0, end = inputs->size(); i < end; i++) {
TensorHandle* input = (*inputs)[i];
// Get device for this input, and add it to 'cache_key'.
Device* input_device;
@@ -477,7 +462,7 @@ Status GetOrCreateKernelAndDevice(
core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
if (kernel == nullptr) {
DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
<< DeviceNameOrUnspecified(op->Device());
<< DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
bool run_function_with_flr = false;
bool function_outputs_on_op_device = false;
if (op->is_function()) {
@@ -656,9 +641,11 @@ Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
remote_func_params, &ctx, &retvals[i]));
}
}
const absl::InlinedVector<TensorHandle*, 4>* inputs;
TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
auto node = absl::make_unique<AsyncExecuteNode>(
&ctx, op->Inputs(), remote_func_params, std::move(kernel),
graph_collector, op->GetCancellationManager(),
&ctx, *inputs, remote_func_params, std::move(kernel), graph_collector,
op->GetCancellationManager(),
absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
// Release the inputs from the eager operation since the AsyncExecuteNode
// would have taken ownership. This allows the inputs to be forwarded if
@@ -673,8 +660,10 @@ Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
for (int i = 0, end = num_outputs; i < end; ++i) {
retvals[i] = nullptr;
}
ExecuteNode node(&ctx, op->Inputs(), remote_func_params, kernel,
graph_collector, op->GetCancellationManager(),
const absl::InlinedVector<TensorHandle*, 4>* inputs;
TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
ExecuteNode node(&ctx, *inputs, remote_func_params, kernel, graph_collector,
op->GetCancellationManager(),
{retvals, static_cast<size_t>(num_outputs)});
Status s = executor.SyncExecute(&node);
// We release the inputs AFTER executing the operation in sync mode since
@@ -764,8 +753,10 @@ Status MaybePackInputTensor(EagerOperation* op) {
return Status::OK();
}
EagerContext& ctx = op->EagerContext();
for (int i = 0; i < op->Inputs().size(); ++i) {
TensorHandle* handle = op->Inputs()[i];
const absl::InlinedVector<TensorHandle*, 4>* inputs;
TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
for (int i = 0; i < inputs->size(); ++i) {
TensorHandle* handle = (*inputs)[i];
if (handle->Type() == TensorHandle::PACKED) {
EagerOperation pack_op(&ctx);
TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
@@ -842,7 +833,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
return errors::InvalidArgument(
"Unable to find remote task corresponding to device ",
VariantDeviceName(op->Device()));
op->DeviceName());
}
std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
@@ -855,11 +846,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
profiler::TraceMe activity("CopyInputToExpectedDevice",
profiler::TraceMeLevel::kInfo);
const bool is_function = op->is_function();
for (int i = 0, end = op->Inputs().size(); i < end; i++) {
tensorflow::TensorHandle* input = op->Inputs()[i];
tensorflow::Device* input_device = absl::get<Device*>(input->device());
tensorflow::Device* input_device_or_cpu =
absl::get<Device*>(input->DeviceOrHostCPU(ctx));
const absl::InlinedVector<TensorHandle*, 4>* inputs;
TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
for (int i = 0, end = inputs->size(); i < end; i++) {
tensorflow::TensorHandle* input = (*inputs)[i];
tensorflow::Device* input_device = input->device();
tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
const string* input_device_name = &input_device_or_cpu->name();
bool serialize_resource_dtype_and_shape = false;
if (op_device != input_device &&
@@ -876,9 +868,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
// Always copy to the remote CPU so that the actual device can be
// correctly determined after the kernel is selected/instantiated,
// since the op might have its inputs on host memory.
TensorHandle* handle = op->Inputs()[i];
Device* handle_device =
absl::get<Device*>(handle->DeviceOrHostCPU(ctx));
TensorHandle* handle = input;
Device* handle_device = handle->DeviceOrHostCPU(ctx);
// If the input is already on the right device, then nothing to do.
if (remote_cpu_device != handle_device) {
TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
@@ -959,11 +950,14 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
DVLOG(4) << "Execute remote eager op: " << op->Name()
<< " (is async?: " << executor.Async() << ").";
const absl::InlinedVector<TensorHandle*, 4>* inputs;
TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
&op->EagerContext(), std::move(request), op_device,
ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
op->Inputs(), {retvals, num_outputs}));
*inputs, {retvals, num_outputs}));
if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
string msg = strings::StrCat(
@@ -1020,7 +1014,7 @@ Status GetKernelOutputs(
"kernel. This should never happen.");
}
if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
absl::get<Device*>(retvals[i]->device()))) {
retvals[i]->device())) {
return errors::Internal(
"Kernel output tensor handle locates on a different device than "
"the specified kernel output device. This should never happen.");
@@ -1037,8 +1031,8 @@
"Remote outputs are not available on mobile devices.");
#else // !IS_MOBILE_PLATFORM
TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
absl::get<TensorShape>(ret),
absl::get<Device*>(retvals[i]->device()), ctx->GetContextViewId()));
absl::get<TensorShape>(ret), retvals[i]->device(),
ctx->GetContextViewId()));
#endif // !IS_MOBILE_PLATFORM
}
}
@@ -1218,11 +1212,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
TensorHandle** result) {
TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
auto send_device = h->DeviceOrHostCPU(*ctx);
if (VariantDeviceIsCustom(send_device)) {
return errors::Unimplemented(
"Copying a TensorHandle from a custom device is not supported.");
}
bool sender_is_local = absl::get<Device*>(send_device)->IsLocal();
bool sender_is_local = send_device->IsLocal();
bool receiver_is_local = device->IsLocal();
@@ -1363,11 +1353,6 @@ void EagerKernelExecuteAsync(
// triggered after execution with its status.
void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
int* num_retvals, StatusCallback done) {
if (VariantDeviceIsCustom(op->Device())) {
done(errors::Unimplemented(
"Custom device is not supported in EagerLocalExecuteAsync."));
return;
}
if (!op->IsLocal()) {
done(errors::InvalidArgument(
"Remote execution is not supported in async EagerLocalExecuteAsync"));
@@ -1419,8 +1404,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
output_dtypes[i], &ctx);
}
const absl::InlinedVector<TensorHandle*, 4>* inputs;
s = op->TensorHandleInputs(&inputs);
if (!s.ok()) {
done(s);
return;
}
EagerKernelExecuteAsync(
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
&ctx, *inputs, op->remote_func_params(), std::move(kernel),
graph_collector, op->GetCancellationManager(), retvals, num_outputs,
[op, num_outputs, retvals, done = std::move(done)](const Status& s) {
op->Clear();


@@ -44,8 +44,7 @@ Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx,
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
// We have validated that h->device() is not a CustomDevice when
// constructing a pack TensorHandle.
const Status status =
h->TensorValue(absl::get<Device*>(h->device()), &packed_arg_flat[i]);
const Status status = h->TensorValue(h->device(), &packed_arg_flat[i]);
if (!status.ok()) {
#if !defined(IS_MOBILE_PLATFORM)
if (IsRemote(ctx, input_device, h)) {
@@ -107,13 +106,7 @@ Status ExecuteNodeArgs::Init(
TF_RETURN_IF_ERROR(
op_inputs[index.index]->ExtractPackedHandle(index.sub_index, &h));
}
VariantDevice variant_device = h->device();
if (VariantDeviceIsCustom(variant_device)) {
return errors::Internal(
"Custom devices and remote execution are currently not supported "
"together.");
}
Device* device = absl::get<Device*>(variant_device);
Device* device = h->device();
// For a multi-device function, a remote RunComponentFunction request is
// not sent through StreamingEnqueueAsync. It could arrive at a remote
// worker before a remote execution request which produces an input of the


@@ -105,7 +105,7 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) {
std::vector<Device*> input_devices;
for (auto* h : inputs) {
input_devices.push_back(absl::get<Device*>(h->DeviceOrHostCPU(*ctx)));
input_devices.push_back(h->DeviceOrHostCPU(*ctx));
}
const core::RefCountPtr<KernelAndDevice> kernel(
new TestKernelAndDeviceFunc(std::move(input_devices), device0));


@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/op_def.pb.h"
@@ -138,17 +139,18 @@ Status MaybePinSmallOpsToCpu(
return Status::OK();
}
Status MaybePinToResourceDevice(VariantDevice* device,
const EagerOperation& op) {
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
if (op.colocation_exempt()) {
return Status::OK();
}
EagerContext& ctx = op.EagerContext();
const absl::InlinedVector<TensorHandle*, 4>* inputs;
TF_RETURN_IF_ERROR(op.TensorHandleInputs(&inputs));
Device* op_device = op.Device() == kVariantDeviceNull
? ctx.HostCPU()
: absl::get<Device*>(op.Device());
for (int i = 0; i < op.Inputs().size(); ++i) {
TensorHandle* tensor_handle = op.Inputs()[i];
for (int i = 0; i < inputs->size(); ++i) {
TensorHandle* tensor_handle = (*inputs)[i];
if (tensor_handle->dtype == DT_RESOURCE) {
if (tensor_handle->resource_remote_device_incarnation() != 0) {
TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
@@ -182,7 +184,7 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
// If operation was already placed on a custom device, use it.
if (VariantDeviceIsCustom(op.Device())) {
if (absl::holds_alternative<CustomDevice*>(op.Device())) {
*device = op.Device();
return Status::OK();
} else if (!op.DeviceName().empty()) {
@@ -194,9 +196,13 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
// placement and there is only one custom device in the op inputs.
if (!op.Inputs().empty()) {
CustomDevice* first = nullptr;
for (const TensorHandle* input : op.Inputs()) {
if (VariantDeviceIsCustom(input->device())) {
CustomDevice* current = absl::get<CustomDevice*>(input->device());
for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
CustomDevice* current = input->device();
if (first == nullptr) {
first = current;
} else if (first != current) {
@@ -207,9 +213,9 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
op.Name(),
" has one input in custom "
"device ",
VariantDeviceName(first),
first->name(),
" and at least one input in a different custom device ",
VariantDeviceName(current)));
current->name()));
}
}
}


@@ -43,8 +43,7 @@ Status MaybePinSmallOpsToCpu(
// If a resource touching input is specified, all resource-touching ops run in
// the device the resource is, regardless of anything else that has been
// specified. This is identical to the graph mode behavior.
Status MaybePinToResourceDevice(VariantDevice* device,
const EagerOperation& op);
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op);
// If all the inputs are on the same custom device, use that custom
// device. Otherwise, it is an error to have a custom device as an input.


@@ -54,6 +54,14 @@ int64 GetRemoteDeviceIncarnation(Device* device) {
if (device == nullptr || device->IsLocal()) return 0;
return device->attributes().incarnation();
}
string SafeDeviceDebugString(Device* device) {
if (device == nullptr) {
return "[]";
} else {
return device->DebugString();
}
}
} // namespace
TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
@@ -231,12 +239,6 @@ TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
}
}
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t,
CustomDevice* d,
EagerContext* ctx) {
return new TensorHandle(std::move(t), d, ctx);
}
TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
Device* resource_device, EagerContext* ctx)
: ImmediateExecutionTensorHandle(kEager),
@@ -249,7 +251,7 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
ctx_(ctx),
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_)
<< " device: " << SafeDeviceDebugString(device_)
<< " tensor: " << t.DeviceSafeDebugString();
}
@@ -268,26 +270,10 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_)
<< " device: " << SafeDeviceDebugString(device_)
<< " tensor: " << t.DeviceSafeDebugString();
}
TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
EagerContext* ctx)
: ImmediateExecutionTensorHandle(kEager),
dtype(t.dtype()),
device_(d),
op_device_(nullptr),
resource_device_(nullptr),
resource_remote_device_incarnation_(0),
ctx_(ctx),
data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
// TODO(allenl): Figure out a better op_device story for custom devices,
// since always setting it to CPU=nullptr doesn't make much sense.
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " custom device: " << VariantDeviceDebugString(device_)
<< " tensor: " << t.DeviceSafeDebugString();
}
TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device,
@ -309,7 +295,7 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
ctx_(ctx),
data_(absl::in_place_type<LocalTensorHandleData>) {
DVLOG(3) << "Creating empty Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
<< " device: " << SafeDeviceDebugString(device_);
}
Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
@@ -328,13 +314,10 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
}
std::vector<string> devices;
devices.reserve(handles.size());
for (auto* handle : handles) {
if (VariantDeviceIsCustom(handle->device())) {
devices.push_back(absl::get<CustomDevice*>(handle->device())->name());
} else {
devices.push_back(handle->op_device() ? handle->op_device()->name()
: ctx->HostCPU()->name());
}
devices.push_back(handle->op_device() ? handle->op_device()->name()
: ctx->HostCPU()->name());
}
CompositeDevice* composite_device = nullptr;
@@ -378,7 +361,7 @@ TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
shape) {
DVLOG(3) << "Creating a packed TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
<< " device: " << SafeDeviceDebugString(device_);
}
#if !defined(IS_MOBILE_PLATFORM)
@@ -406,7 +389,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
remote_task, ctx) {
DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
<< " device: " << SafeDeviceDebugString(device_);
}
TensorHandle* TensorHandle::CreateLazyRemoteHandle(
@@ -429,7 +412,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
ctx->GetContextViewId(), is_ready) {
DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
<< " device: " << SafeDeviceDebugString(device_);
}
#endif
@ -487,7 +470,7 @@ Status TensorHandle::TensorFromDevice(const Device* d,
const tensorflow::Tensor** t) const {
DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
if (d == device_) {
if (Type() != LOCAL) {
return errors::Internal("Invalid Tensor call on a ", TypeString(),
" handle: ", this);
@ -511,13 +494,7 @@ Status TensorHandle::TensorFromDevice(const Device* d,
Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;
if (VariantDeviceIsCustom(device_)) {
return errors::Internal(
"TensorHandle::TensorValue not supported for custom devices yet. "
"Handle device: ",
VariantDeviceDebugString(device_),
", requested device: ", d != nullptr ? d->name() : "(nil)");
} else if (d == absl::get<Device*>(device_)) {
if (d == device_) {
if (Type() != LOCAL) {
return errors::Internal("Invalid TensorValue call on a ", TypeString(),
" handle: ", this);
@ -549,13 +526,8 @@ Status TensorHandle::WaitUnknownDevice() const {
return Status::OK();
}
VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
if (VariantDeviceIsCustom(device_)) {
return device_;
} else {
Device* d = absl::get<Device*>(device_);
return (d == nullptr) ? ctx.HostCPU() : d;
}
Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
return (device_ == nullptr) ? ctx.HostCPU() : device_;
}
Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
@ -691,7 +663,7 @@ Status TensorHandle::NumElements(int64* num_elements) const {
Status TensorHandle::Unprotect(const Device* d) {
DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
if (d == device_) {
return absl::visit([](auto& data) { return data.Unprotect(); }, data_);
}
@ -718,7 +690,7 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this
<< " device: " << d;
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
if (d == device_) {
return errors::Internal("Cannot add mirror for primary device.");
}
@ -739,7 +711,7 @@ Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready,
DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d
<< " " << d->name();
if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
if (d != device_) {
tf_shared_lock l(mu_);
auto mirror = remote_mirrors_.find(d->name());
if (mirror != remote_mirrors_.end()) {
@ -854,7 +826,7 @@ Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape,
DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d
<< " " << d->name();
if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
if (d != device_) {
tf_shared_lock l(mu_);
auto remote_mirror = remote_mirrors_.find(d->name());
if (remote_mirror == remote_mirrors_.end()) {
@ -916,7 +888,7 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d
<< " " << d->name();
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
if (d == device_) {
DCHECK(Type() == REMOTE)
<< "Poison can only be on remote handles: " << this;
@ -936,7 +908,7 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
const Device* d) {
if (d == absl::get<Device*>(device_)) {
if (d == device_) {
return errors::Internal(
"Local mirror assign conflicts with primary device.");
}
@ -955,7 +927,7 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;
if (d == absl::get<Device*>(device_)) {
if (d == device_) {
DCHECK(Type() == LOCAL) << "SetTensor should only be called on local handles.";
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
@ -982,7 +954,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
void TensorHandle::Poison(Status status, const Device* d) {
DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;
if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
if (d == device_) {
DCHECK(Type() != REMOTE) << "Poison can only be on local handles: " << this;
absl::visit([status](auto& data) { data.Poison(status); }, data_);
} else {
@ -1001,7 +973,7 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
tensorflow::Device* d,
tensorflow::Tensor* output) const {
tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d;
tensorflow::Device* srcd = absl::get<Device*>(DeviceOrHostCPU(ctx));
tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
bool is_same_device =
@ -1063,27 +1035,6 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
return status;
}
bool VariantDeviceIsCustom(VariantDevice variant_device) {
return variant_device.index() != 0;
}
string VariantDeviceName(VariantDevice device) {
if (device == kVariantDeviceNull) {
return "[]";
}
return absl::visit([](auto* device) { return device->name(); }, device);
}
string VariantDeviceDebugString(VariantDevice device) {
if (device == kVariantDeviceNull) {
return "[]";
} else if (VariantDeviceIsCustom(device)) {
return absl::get<CustomDevice*>(device)->name();
} else {
return absl::get<Device*>(device)->DebugString();
}
}
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
if (ctx == nullptr) {
return nullptr;
@ -1100,10 +1051,9 @@ string TensorHandle::DebugString() const {
DVLOG(4) << "Calling TensorHandle::DebugString() on " << this;
string out;
string device_debug = VariantDeviceDebugString(device_);
string device_debug = SafeDeviceDebugString(device_);
strings::StrAppend(&out, "Device: ", device_debug);
bool is_cpu =
!VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull;
bool is_cpu = device_ != nullptr;
// Consider supporting non-CPU tensors and CPU tensors with a device_ set to
// non-NULL if needed.
strings::StrAppend(
@ -1115,9 +1065,6 @@ string TensorHandle::DebugString() const {
}
const char* TensorHandle::DeviceName(Status* status) const {
if (VariantDeviceIsCustom(device())) {
return absl::get<CustomDevice*>(device())->name().c_str();
}
status->Update(WaitUnknownDevice());
tensorflow::Device* d = op_device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
@ -1125,33 +1072,19 @@ const char* TensorHandle::DeviceName(Status* status) const {
}
const char* TensorHandle::BackingDeviceName(Status* status) const {
if (VariantDeviceIsCustom(device())) {
return absl::get<tensorflow::CustomDevice*>(device())->name().c_str();
} else {
status->Update(WaitUnknownDevice());
tensorflow::Device* d = absl::get<tensorflow::Device*>(device());
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
status->Update(WaitUnknownDevice());
tensorflow::Device* d = device();
return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
: d->name().c_str();
}
const char* TensorHandle::DeviceType(Status* status) const {
if (VariantDeviceIsCustom(device())) {
status->Update(
tensorflow::errors::Unimplemented("Custom device unsupported"));
return nullptr;
}
status->Update(WaitUnknownDevice());
tensorflow::Device* d = op_device();
return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str();
}
int TensorHandle::DeviceId(Status* status) const {
if (VariantDeviceIsCustom(device())) {
status->Update(
tensorflow::errors::Unimplemented("Custom device unsupported"));
return -1;
}
status->Update(WaitUnknownDevice());
tensorflow::Device* d = op_device();
return (d == nullptr) ? 0 : d->parsed_name().id;


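The tensor_handle.cc hunks above all apply one pattern: device_ is no longer an absl::variant<Device*, CustomDevice*>, so the absl::get<Device*> unwrapping and the VariantDeviceIsCustom branches collapse into plain pointer comparisons. A standalone sketch of that pattern, using stub types rather than TensorFlow's (everything below is illustrative C++17, not code from this change):

#include <cassert>
#include <string>
#include <variant>

struct Device { std::string name; };
struct CustomDevice { std::string name; };

// Before: a two-way variant, so every accessor had to visit or branch.
using VariantDevice = std::variant<Device*, CustomDevice*>;
std::string VariantDeviceName(VariantDevice d) {
  return std::visit([](auto* dev) { return dev->name; }, d);
}

// After: custom devices are dealt with before execution, so a handle only
// ever stores a physical Device*, and callers dereference directly.
std::string DeviceName(const Device* d) { return d->name; }

int main() {
  Device gpu{"GPU:0"};
  CustomDevice logger{"CUSTOM:0"};
  assert(VariantDeviceName(&gpu) == "GPU:0");        // old path, physical
  assert(VariantDeviceName(&logger) == "CUSTOM:0");  // old path, custom
  assert(DeviceName(&gpu) == "GPU:0");               // new path: no unwrapping
  return 0;
}
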
@ -60,7 +60,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
// TensorHandle for dtype == DT_RESOURCE
TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
EagerContext* ctx);
TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx);
TensorHandle(Device* d, Device* op_device, Device* resource_device,
tensorflow::DataType dtype, EagerContext* ctx);
@ -81,8 +80,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
Device* op_device,
Device* resource_device,
EagerContext* ctx);
static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t,
CustomDevice* d, EagerContext* ctx);
static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device,
tensorflow::DataType dtype,
@ -150,7 +147,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
// requesting the HostCPU.
Status TensorValue(const Device* d, tensorflow::TensorValue* t);
VariantDevice device() const { return device_; }
Device* device() const { return device_; }
Device* op_device() const { return op_device_; }
Device* resource_device() const { return resource_device_; }
int64 resource_remote_device_incarnation() const {
@ -161,7 +158,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
// are set (data is ready).
Status WaitUnknownDevice() const;
VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
Device* DeviceOrHostCPU(const EagerContext& ctx) const;
Status Shape(tensorflow::TensorShape* shape);
@ -286,7 +283,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
bool IsReady() const;
Status WaitReady(const char* caller) const;
VariantDevice device_;
tensorflow::Device* device_;
// Device on which the op producing this tensor was executed. Equals
// device_ for constant tensors.
@ -391,19 +388,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
PartialTensorShape inference_shape_;
};
// Checks whether a VariantDevice contains a custom device.
bool VariantDeviceIsCustom(VariantDevice device);
// Wraps device->name() or CustomDevice->name().
string VariantDeviceName(VariantDevice device);
// Wraps device->DebugString() or CustomDevice->name().
string VariantDeviceDebugString(VariantDevice device);
// Indicates either HostCPU or an unset physical device. We never set a null
// CustomDevice*.
const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);
// Returns the device backing the resource, or nullptr.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);


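With kVariantDeviceNull gone from the header, "no device" is simply a null Device*, and DeviceOrHostCPU is the one place that normalizes it. A minimal sketch of that convention under stub types (EagerContext here is a placeholder, not TensorFlow's class):

#include <cassert>
#include <string>

struct Device { std::string name; };

struct EagerContext {
  Device* host_cpu;
  Device* HostCPU() const { return host_cpu; }
};

// A null device means "unset, i.e. host CPU"; callers that need a concrete
// device normalize through this helper instead of branching themselves.
Device* DeviceOrHostCPU(Device* device, const EagerContext& ctx) {
  return (device == nullptr) ? ctx.HostCPU() : device;
}

int main() {
  Device cpu{"CPU:0"}, gpu{"GPU:0"};
  EagerContext ctx{&cpu};
  assert(DeviceOrHostCPU(nullptr, ctx) == &cpu);  // unset -> host CPU
  assert(DeviceOrHostCPU(&gpu, ctx) == &gpu);     // set -> unchanged
  return 0;
}
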
@ -189,8 +189,8 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
EXPECT_EQ(dtypes_and_shapes.at(0).dtype, DT_FLOAT);
EXPECT_EQ(dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), true);
CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
absl::get<Device*>(packed_handle->device()));
CompositeDevice* device =
reinterpret_cast<CompositeDevice*>(packed_handle->device());
EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0");
EXPECT_EQ(device->underlying_devices()->size(), 4);
@ -200,7 +200,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
TensorHandle* h = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h));
EXPECT_EQ(absl::get<Device*>(h->device()), ListDevices().at(i));
EXPECT_EQ(h->device(), ListDevices().at(i));
EXPECT_EQ(h->Type(), expected_handle_types.at(i));
}
EXPECT_FALSE(IsReady(packed_handle));
@ -236,14 +236,14 @@ TEST_F(PackedTensorHandleTest, PackedSingleHandle) {
TF_ASSERT_OK(packed_handle->Shape(&packed_shape));
EXPECT_EQ(packed_shape, shape);
CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
absl::get<Device*>(packed_handle->device()));
CompositeDevice* device =
reinterpret_cast<CompositeDevice*>(packed_handle->device());
EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0");
EXPECT_EQ(device->underlying_devices()->size(), 1);
EXPECT_EQ(packed_handle->NumPackedHandles(), 1);
TensorHandle* h0 = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &h0));
EXPECT_EQ(absl::get<Device*>(h0->device()), d);
EXPECT_EQ(h0->device(), d);
EXPECT_TRUE(IsReady(packed_handle));
packed_handle->Unref();
}
@ -392,7 +392,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
/*op_id=*/0, /*output_num=*/0, remote_task, dtype, d1, context,
/*unknown_device=*/true);
EXPECT_EQ(absl::get<Device*>(h->device()), d1);
EXPECT_EQ(h->device(), d1);
Device* d2 = device_mgr.ListDevices().at(2);
TF_ASSERT_OK(h->SetRemoteShapeAndDevice(
@ -400,7 +400,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
Status s;
EXPECT_EQ(h->BackingDeviceName(&s), d2->name());
TF_EXPECT_OK(s);
EXPECT_EQ(absl::get<Device*>(h->device()), d2);
EXPECT_EQ(h->device(), d2);
h->Unref();
context->Unref();
}


@ -186,7 +186,7 @@ Status AddOpRetvalsToResponse(
for (int i = 0; i < num_retvals; i++) {
TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
if (add_device_fn) {
Device* device = absl::get<Device*>(retvals[i]->device());
Device* device = retvals[i]->device();
*add_device_fn() = device ? device->name() : "";
}
if (retvals[i]->Type() == TensorHandle::REMOTE) {


@ -1086,8 +1086,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
TF_ASSERT_OK(tensor_handle->Tensor(&t));
Device* device = absl::get<Device*>(tensor_handle->device());
EXPECT_EQ(device, nullptr);
EXPECT_EQ(tensor_handle->device(), nullptr);
auto actual = t->flat<float>();
EXPECT_EQ(4, actual.size());
@ -1168,8 +1167,7 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
EXPECT_EQ(absl::get<Device*>(packed_handle->device())->name(),
composite_device);
EXPECT_EQ(packed_handle->device()->name(), composite_device);
TensorHandle* handle0 = nullptr;
TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
@ -1198,7 +1196,7 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
EXPECT_EQ(handle2->op_device()->name(), device2);
int64 op_id;
int32 output_num;
TF_ASSERT_OK(handle2->RemoteAddress(absl::get<Device*>(handle2->device()),
TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(),
/*wait_until_ready=*/true, &op_id,
&output_num));
EXPECT_EQ(op_id, 2);


@ -37,7 +37,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
remote_op->set_name(op->Name());
op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
remote_op->set_device(VariantDeviceName(op->Device()));
remote_op->set_device(op->DeviceName());
}
Status CreateUncachedKernelAndDeviceOp(
@ -80,7 +80,7 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
src_(src),
ctx_(ctx),
executor_(executor),
send_device_(absl::get<Device*>(src->DeviceOrHostCPU(*ctx))),
send_device_(src->DeviceOrHostCPU(*ctx)),
recv_device_(recv_device),
wire_id_(GetUniqueWireID()),
recv_op_id_(recv_op_id),
@ -149,9 +149,8 @@ void RemoteCopyNode::StartSend() {
auto* remote_op = request.add_queue()->mutable_operation();
status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
src_, /*wait_until_ready=*/false,
remote_op->add_op_inputs()->mutable_remote_handle(),
absl::get<Device*>(src_->device()),
absl::get<Device*>(src_->DeviceOrHostCPU(*ctx_))->name());
remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
src_->DeviceOrHostCPU(*ctx_)->name());
if (!status.ok()) {
captured_state_->SetSendStatus(status);
return;
@ -310,7 +309,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
const Device* target_device, EagerContext* ctx,
SendPackedHandleOp* op) {
op->set_op_id(op_id);
op->set_device_name(VariantDeviceName(packed_handle->DeviceOrHostCPU(*ctx)));
op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
TensorHandle* h = nullptr;
TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
@ -329,7 +328,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
// If src_device is on the same task as target_device, the handle is a
// local handle on the target device, which means the resource dtype and
// shape are known on the target device.
Device* src_device = absl::get<Device*>(h->device());
Device* src_device = h->device();
const bool serialize_resource_dtype_and_shape =
(i == 0) && (h->dtype == DT_RESOURCE) &&
(!ctx->OnSameTask(src_device, target_device));
@ -341,7 +340,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
h, /*wait_until_ready=*/true,
op->add_handles()->mutable_remote_handle(), src_device,
absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
h->DeviceOrHostCPU(*ctx)->name(),
serialize_resource_dtype_and_shape));
} else {
return errors::InvalidArgument("Nested packed handles are not supported");


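One design point in the RemoteCopyNode change above: the constructor normalizes the source device once (send_device_(src->DeviceOrHostCPU(*ctx))), so the later send/recv steps work with plain, non-null pointers. A small sketch of that capture-then-use shape, again with stub types (CopyNode and WireKey are invented names, not TensorFlow's):

#include <cassert>
#include <string>

struct Device { std::string name; };
struct Context { Device* host_cpu; };

Device* DeviceOrHostCPU(Device* d, const Context& ctx) {
  return (d == nullptr) ? ctx.host_cpu : d;
}

struct CopyNode {
  Device* send_device;  // normalized at construction, never null afterwards
  Device* recv_device;
  CopyNode(Device* src_device, Device* recv, const Context& ctx)
      : send_device(DeviceOrHostCPU(src_device, ctx)), recv_device(recv) {}
  // Later stages can key the transfer on concrete device names directly.
  std::string WireKey() const {
    return send_device->name + "->" + recv_device->name;
  }
};

int main() {
  Device cpu{"CPU:0"}, gpu{"GPU:0"};
  Context ctx{&cpu};
  CopyNode node(/*src_device=*/nullptr, &gpu, ctx);
  assert(node.WireKey() == "CPU:0->GPU:0");  // null source normalized up front
  return 0;
}
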
@ -83,15 +83,8 @@ Status RemoteMgr::GetMirroredResourceShape(
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
const bool wait_until_ready,
int64* op_id, int32* output_num) {
// TODO(allenl): Consider supporting remote handles on custom devices.
VariantDevice device = handle->device();
if (VariantDeviceIsCustom(device)) {
return errors::Unimplemented(
"Custom devices and remote execution are currently not supported "
"together.");
}
TF_RETURN_IF_ERROR(handle->RemoteAddress(
absl::get<Device*>(device), wait_until_ready, op_id, output_num));
TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
op_id, output_num));
tensorflow::TensorHandle* h;
TF_RETURN_IF_ERROR(
GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));


@ -268,14 +268,11 @@ class OpNode {
return tensorflow::Status::OK();
}
void ClearEagerInputs() {
for (tensorflow::TensorHandle* h : *op_->MutableInputs()) {
if (h) h->Unref();
}
op_->MutableInputs()->clear();
}
void ClearEagerInputs() { op_->Clear(); }
tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
absl::InlinedVector<tensorflow::TensorHandle*, 4>* op_inputs;
TF_RETURN_IF_ERROR(op_->MutableTensorHandleInputs(&op_inputs));
for (int i = 0; i < inputs_.Size(); ++i) {
int input_index = inputs_.TfLiteIndex(i);
TensorSource s = inputs_.GetTensorSource(i);
@ -290,14 +287,14 @@ class OpNode {
tensorflow::TensorHandle* handle =
tensorflow::TensorHandle::CreateLocalHandle(
buffer_map->GetTensor(input_index));
op_->MutableInputs()->push_back(handle);
op_inputs->push_back(handle);
} else {
// If this is a forwardable tensor, we will remove it from the previous
// op's list, giving TF the opportunity to reuse its buffer.
bool unref_handle = inputs_.IsForwardable(i);
auto* handle =
s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
op_->MutableInputs()->push_back(handle);
op_inputs->push_back(handle);
}
}
return tensorflow::Status::OK();


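The TF Lite delegate above no longer reaches into the operation's raw input vector: clearing goes through op_->Clear(), and BuildEagerInputs asks for a typed view via the Status-returning MutableTensorHandleInputs. A sketch of that accessor pattern with stub types; note the real accessor hands back a pointer to the operation's own inlined vector, while this sketch fills a caller-owned one, and the exact failure condition (an input not being a physical-device handle) is an assumption here:

#include <cassert>
#include <vector>

struct Status { bool ok; };

// Stub handle hierarchy: the generic type an operation stores, and the
// physical-device subtype the runtime wants to hand out.
struct ImmediateHandle { virtual ~ImmediateHandle() = default; };
struct PhysicalTensorHandle : ImmediateHandle {};
struct CustomDeviceHandle : ImmediateHandle {};

class Operation {
 public:
  void AddInput(ImmediateHandle* h) { inputs_.push_back(h); }
  void Clear() { inputs_.clear(); }  // the real Clear also unrefs handles

  // Succeeds only if every input is a physical-device handle, so callers
  // get a typed vector instead of casting at each use site.
  Status MutableTensorHandleInputs(std::vector<PhysicalTensorHandle*>* out) {
    out->clear();
    for (ImmediateHandle* h : inputs_) {
      auto* th = dynamic_cast<PhysicalTensorHandle*>(h);
      if (th == nullptr) return {false};
      out->push_back(th);
    }
    return {true};
  }

 private:
  std::vector<ImmediateHandle*> inputs_;
};

int main() {
  Operation op;
  PhysicalTensorHandle a, b;
  op.AddInput(&a);
  op.AddInput(&b);
  std::vector<PhysicalTensorHandle*> typed;
  assert(op.MutableTensorHandleInputs(&typed).ok && typed.size() == 2);

  CustomDeviceHandle c;
  op.AddInput(&c);
  assert(!op.MutableTensorHandleInputs(&typed).ok);  // mixed inputs fail
  return 0;
}
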
@ -155,11 +155,7 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
tensorflow::unwrap(ctx)->TFTensorHandleFromInterface(
tensorflow::unwrap(EagerTensor_Handle(eager_tensor))));
if (VariantDeviceIsCustom(handle->device())) {
return errors::Unimplemented(
"Custom devices are currently not supported with PyFuncs.");
}
Device* actual_device = absl::get<Device*>(handle->device());
Device* actual_device = handle->device();
TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
// actual_device may be nullptr, which implies local CPU.
if (expected_device == actual_device) return Status::OK();