Stop holding custom devices in TensorHandles
My hope is to not change the custom device API much, just clean up the implementation. Previously TensorHandles held Tensors which held the void* custom device tensor handle data. This required a bunch of special cases, mostly because the TensorHandle's device wasn't a physical device.

Now EagerOperations still accept custom device TensorHandles, but deal with them before execution (either by copying them off the custom device or by executing the operation on a custom device). This means the rest of the runtime can assume TensorHandles are on physical devices, and gives custom device tensor handles some freedom to evolve.

Rolling cl/350684489 forward with a fix for making packed tensors from custom device tensors. This requires one new custom device method to maintain the previous (accidental) functionality.

PiperOrigin-RevId: 351406348
Change-Id: I9c2ffd40a687b06434fab40e2db9e90129b9f2b7
parent 0e43a67584
commit 253111e23b
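To make the API shape concrete before the diff: below is a minimal sketch (not part of this change) of registering a version-4 custom device through the C API shown in the c_api_experimental.h hunk further down. The pass-through copy callbacks, the ExampleDeviceInfo struct, and the device name are illustrative assumptions; the commit-specific detail is only that the new `pack` callback is optional and may be left null.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

namespace {

// Hypothetical per-device state; a real device would keep whatever it needs.
struct ExampleDeviceInfo {};

TFE_TensorHandle* CopyToDevice(TFE_Context* context, TFE_TensorHandle* tensor,
                               TF_Status* status, void* device_info) {
  // Sketch only: share the underlying tensor instead of really copying it.
  return TFE_TensorHandleCopySharingTensor(tensor, status);
}

TFE_TensorHandle* CopyFromDevice(TFE_Context* context,
                                 TFE_TensorHandle* tensor,
                                 const char* target_device_name,
                                 TF_Status* status, void* device_info) {
  return TFE_TensorHandleCopySharingTensor(tensor, status);
}

void Execute(const TFE_Op* op, int* num_retvals, TFE_TensorHandle** retvals,
             TF_Status* s, void* device_info) {
  TF_SetStatus(s, TF_UNIMPLEMENTED, "Execution is not part of this sketch.");
}

void DeleteDevice(void* device_info) {
  delete static_cast<ExampleDeviceInfo*>(device_info);
}

}  // namespace

void RegisterExampleDevice(TFE_Context* ctx, TF_Status* status) {
  TFE_CustomDevice device;  // version defaults to TFE_CUSTOM_DEVICE_VERSION.
  device.copy_tensor_to_device = &CopyToDevice;
  device.copy_tensor_from_device = &CopyFromDevice;
  device.execute = &Execute;
  device.delete_device = &DeleteDevice;
  // New in version 4: `pack` is optional and may be left null.
  device.pack = nullptr;
  TFE_RegisterCustomDevice(ctx, device,
                           "/job:localhost/replica:0/task:0/device:CUSTOM:0",
                           new ExampleDeviceInfo, status);
}

When `pack` is left null, TFE_RegisterCustomDevice (in the diff below) substitutes DefaultCustomDevicePack, which makes TFE_CreatePackedTensorHandle report an unimplemented error for tensors on that device.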
@@ -52,7 +52,6 @@ tf_cuda_library(
        ":immediate_execution_operation",
        ":immediate_execution_tensor_handle",
        ":immediate_execution_distributed_manager",
        ":abstract_tensor_handle",
        ":tfe_context_internal",
        ":tfe_cancellation_manager_internal",
        ":tfe_executor_internal",
@@ -73,6 +72,7 @@ tf_cuda_library(
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:context_distributed_manager",
        "//tensorflow/core/common_runtime/eager:core",
        "//tensorflow/core/common_runtime/eager:custom_device",
        "//tensorflow/core/common_runtime/eager:eager_executor",
        "//tensorflow/core/common_runtime/eager:execute",
        "//tensorflow/core/common_runtime/eager:tensor_handle",
@@ -86,6 +86,7 @@ tf_cuda_library(
        ],
    }) + [
        "@com_google_absl//absl/memory",
        ":abstract_tensor_handle",
        "//tensorflow/core/common_runtime/eager:eager_operation",
        "//tensorflow/core/distributed_runtime/eager:remote_mgr",
        "//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime",
@@ -480,12 +481,17 @@ cc_library(
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:refcount",
        "//tensorflow/core/platform:status",
    ],
    deps = select({
        "//tensorflow:android": [
            "//tensorflow/core:portable_tensorflow_lib_lite",
        ],
        "//conditions:default": [
            "//tensorflow/core:framework",
            "//tensorflow/core:protos_all_cc",
            "//tensorflow/core/platform:refcount",
            "//tensorflow/core/platform:status",
        ],
    }),
)

cc_library(
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -363,13 +364,21 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h));
  if (VariantDeviceIsCustom(handle->device())) {
    const tensorflow::Tensor* t;
    status->status = handle->Tensor(&t);
    return t->data();
  tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
      tensorflow::unwrap(h);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
    return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
               unwrapped_handle)
        ->DevicePointer();
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle =
      tensorflow::TensorHandleFromInterface(unwrapped_handle);

  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
@@ -377,7 +386,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
        handle->TypeString(), " tensor handle.");
    return nullptr;
  }
  tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
  tensorflow::Device* device(handle->device());
  if (device != nullptr) {
    status->status = device->Sync();
    if (!status->status.ok()) {
@@ -393,6 +402,125 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
      static_cast<const void*>(tensor->tensor_data().data()));
}

namespace tensorflow {
namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
 public:
  CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
                  string name)
      : context_(context), device_(device), info_(info), name_(name) {}

  ~CustomDeviceAPI() override { device_.delete_device(info_); }

  const string& name() override { return name_; }

  tensorflow::Status CopyTensorToDevice(
      ImmediateExecutionTensorHandle* handle,
      ImmediateExecutionTensorHandle** result) override {
    handle->Ref();
    TF_Status status;
    TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
        context_, tensorflow::wrap(handle), &status, info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status CopyTensorFromDevice(
      ImmediateExecutionTensorHandle* handle,
      const tensorflow::string& target_device_name,
      ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    handle->Ref();
    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
        context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
        info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::unwrap(result_handle);
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status Execute(const ImmediateExecutionOperation* op,
                             ImmediateExecutionTensorHandle** retvals,
                             int* num_retvals) override {
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    TF_Status status;
    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                    info_);
    if (status.status.ok()) {
      for (int i = 0; i < *num_retvals; ++i) {
        retvals[i] = tensorflow::unwrap(outputs[i]);
        retvals[i]->Ref();
        TFE_DeleteTensorHandle(outputs[i]);
      }
    }
    return status.status;
  }

  tensorflow::Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
                          ImmediateExecutionTensorHandle** result) override {
    TF_Status status;
    *result = tensorflow::unwrap(device_.pack(context_,
                                              tensorflow::wrap(handles.data()),
                                              handles.size(), &status, info_));
    return status.status;
  }

 private:
  TFE_Context* context_;
  TFE_CustomDevice device_;
  void* info_;
  string name_;
};

// An adapter which wraps the shape/data produced by C custom devices and uses
// it to implement custom device methods.
class CAPICustomDeviceTensorHandle
    : public tensorflow::CustomDeviceTensorHandle {
 public:
  CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context,
                               tensorflow::CustomDevice* device,
                               tensorflow::DataType dtype, void* data,
                               size_t len, std::vector<tensorflow::int64> shape,
                               void (*deallocator)(void* data, size_t len,
                                                   void* arg),
                               void* deallocator_arg)
      : tensorflow::CustomDeviceTensorHandle(context, device, dtype),
        data_(data),
        len_(len),
        shape_(shape),
        deallocator_(deallocator),
        deallocator_arg_(deallocator_arg) {}
  ~CAPICustomDeviceTensorHandle() override {
    deallocator_(data_, len_, deallocator_arg_);
  }
  void* DevicePointer() const override { return data_; }
  Status NumDims(int* num_dims) const override {
    *num_dims = shape_.size();
    return Status::OK();
  }
  Status Dim(int dim_index, int64* dim) const override {
    *dim = shape_[dim_index];
    return Status::OK();
  }

 private:
  void* const data_;
  size_t len_;
  std::vector<tensorflow::int64> shape_;
  void (*const deallocator_)(void* data, size_t len, void* arg);
  void* const deallocator_arg_;
};

}  // namespace
}  // namespace tensorflow

TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
    TFE_Context* ctx, const char* device_name, TF_DataType dtype,
    const int64_t* dims, int num_dims, void* data, size_t len,
@@ -417,6 +545,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
  for (int i = 0; i < num_dims; ++i) {
    dimvec[i] = static_cast<tensorflow::int64>(dims[i]);
  }
  if (custom_device != nullptr) {
    return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle(
        context, custom_device,
        *reinterpret_cast<tensorflow::DataType*>(&dtype), data, len, dimvec,
        deallocator, deallocator_arg));
  }

  // TODO(apassos) do we need to wrap the deallocator here to make sure to sync
  // the device?
@@ -427,13 +561,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
  tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
                       tensorflow::TensorShape(dimvec), buf);
  buf->Unref();
  if (custom_device == nullptr) {
    return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
        std::move(t), device, device, context));
  } else {
    return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
        std::move(t), custom_device, context));
  }
  return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle(
      std::move(t), device, device, context));
}

// This function will block till the operation that produces `h` has
@@ -961,74 +1090,14 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
}  // namespace tensorflow

namespace {
class CustomDeviceAPI : public tensorflow::CustomDevice {
 public:
  CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info,
                  string name)
      : context_(context), device_(device), info_(info), name_(name) {}

  ~CustomDeviceAPI() override { device_.delete_device(info_); }

  const string& name() override { return name_; }

  tensorflow::Status CopyTensorToDevice(
      tensorflow::TensorHandle* handle,
      tensorflow::TensorHandle** result) override {
    handle->Ref();
    TF_Status status;
    TFE_TensorHandle* result_handle = device_.copy_tensor_to_device(
        context_, tensorflow::wrap(handle), &status, info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::TensorHandleFromInterface(
        tensorflow::unwrap(result_handle));
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status CopyTensorFromDevice(
      tensorflow::TensorHandle* handle,
      const tensorflow::string& target_device_name,
      tensorflow::TensorHandle** result) override {
    TF_Status status;
    handle->Ref();
    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
        context_, tensorflow::wrap(handle), target_device_name.c_str(), &status,
        info_);
    handle->Release();
    if (!status.status.ok()) return status.status;
    *result = tensorflow::TensorHandleFromInterface(
        tensorflow::unwrap(result_handle));
    (*result)->Ref();
    TFE_DeleteTensorHandle(result_handle);
    return status.status;
  }

  tensorflow::Status Execute(const tensorflow::EagerOperation* op,
                             tensorflow::TensorHandle** retvals,
                             int* num_retvals) override {
    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
    TF_Status status;
    device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status,
                    info_);
    if (status.status.ok()) {
      for (int i = 0; i < *num_retvals; ++i) {
        retvals[i] = tensorflow::TensorHandleFromInterface(
            tensorflow::unwrap(outputs[i]));
        retvals[i]->Ref();
        TFE_DeleteTensorHandle(outputs[i]);
      }
    }
    return status.status;
  }

 private:
  TFE_Context* context_;
  TFE_CustomDevice device_;
  void* info_;
  string name_;
};
TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context,
                                          TFE_TensorHandle** handles,
                                          int num_handles, TF_Status* status,
                                          void* device_info) {
  TF_SetStatus(status, TF_UNIMPLEMENTED,
               "This custom device does not support packing tensors.");
  return nullptr;
}
}  // namespace

extern "C" {
@@ -1036,8 +1105,12 @@ extern "C" {
void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
                              const char* device_name, void* device_info,
                              TF_Status* status) {
  auto custom_device =
      std::make_unique<CustomDeviceAPI>(ctx, device, device_info, device_name);
  // Fill in default values for optional functionality.
  if (device.pack == nullptr) {
    device.pack = &DefaultCustomDevicePack;
  }
  auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
      ctx, device, device_info, device_name);
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
  status->status =
@@ -613,8 +613,23 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
  std::vector<tensorflow::TensorHandle*> tensor_handles;
  tensor_handles.reserve(*num_handles);
  for (int i = 0; i < *num_handles; ++i) {
    tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
        tensorflow::unwrap(handles[i]);
    if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
      // One of the inputs we're trying to pack is on a custom device. We'll let
      // the first custom device we see handle all of the packing.
      auto* custom_device_handle =
          tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
              unwrapped_handle);
      tensorflow::ImmediateExecutionTensorHandle* result;
      status->status = custom_device_handle->device()->Pack(
          absl::Span<tensorflow::ImmediateExecutionTensorHandle*>(
              tensorflow::unwrap(handles), *num_handles),
          &result);
      return tensorflow::wrap(result);
    }
    tensor_handles.push_back(
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i])));
        tensorflow::TensorHandleFromInterface(unwrapped_handle));
  }
  tensorflow::EagerContext* context =
      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
@@ -435,16 +435,16 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
// to have a non-string representation of devices (TF_Device) extracted from
// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.

#define TFE_CUSTOM_DEVICE_VERSION 3
#define TFE_CUSTOM_DEVICE_VERSION 4

// Struct to be filled in
// Struct to be filled in. Functions are required except where indicated.
typedef struct TFE_CustomDevice {
  int version = TFE_CUSTOM_DEVICE_VERSION;
  // Method to copy a tensor to the custom device.
  TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
                                             TFE_TensorHandle* tensor,
                                             TF_Status* status,
                                             void* device_info) = nullptr;
                                             void* device_info);

  // Method to copy a tensor from the custom device to a target device.
  TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
@@ -468,6 +468,16 @@ typedef struct TFE_CustomDevice {

  // Method to delete a device.
  void (*delete_device)(void* device_info);

  // Implements TFE_CreatePackedTensorHandle when one of `handles` is on this
  // custom device.
  //
  // Many devices will want to simply return an "unimplemented" status
  // here. This is the default behavior if `pack` is null when passed to
  // TFE_RegisterCustomDevice.
  TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles,
                            int num_handles, TF_Status* s,
                            void* device_info) = nullptr;
} TFE_CustomDevice;

// Registers a custom device for use with eager execution.
@@ -424,7 +424,7 @@ void TensorHandleSilentCopy(bool async,
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu));
    auto gpu_arg =
        tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu));
    auto gpu_device = absl::get<tensorflow::Device*>(gpu_arg->device());
    auto gpu_device = gpu_arg->device();
    ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device));

    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
@@ -40,14 +40,6 @@ std::unique_ptr<CompositeDevice> CompositeDevice::MakeDevice(
        errors::InvalidArgument("underlying_devices should not be empty."));
    return nullptr;
  }
  std::set<string> unique_devices;
  for (const string& device : underlying_devices) {
    if (!unique_devices.insert(device).second) {
      status->Update(errors::InvalidArgument(
          "Got a duplicated device in underlying_devices: ", device));
      return nullptr;
    }
  }
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(underlying_devices.at(0), &parsed_name)) {
    status->Update(tensorflow::errors::InvalidArgument(

@@ -50,21 +50,6 @@ TEST(CompositeDeviceTest, Basic) {
    EXPECT_EQ(underlying_devices, *composite_device->underlying_devices());
  }

  {
    Status status;
    underlying_devices.push_back(
        "/job:localhost/replica:0/task:0/device:CPU:0");
    std::unique_ptr<CompositeDevice> composite_device =
        CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1,
                                    parsed_host_name, &status);
    EXPECT_EQ(composite_device, nullptr);
    EXPECT_EQ(error::INVALID_ARGUMENT, status.code());
    EXPECT_TRUE(
        absl::StrContains(status.error_message(), "Got a duplicated device"))
        << status.ToString();
    underlying_devices.pop_back();
  }

  {
    Status status;
    underlying_devices.push_back(
@@ -124,6 +124,7 @@ tf_cuda_library(
        "//tensorflow/core:framework",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/core/lib/core:status",
        ],
    }),
@@ -228,6 +229,7 @@ tf_cuda_library(
    visibility = ["//tensorflow:internal"],
    deps = [
        ":attr_builder",
        ":custom_device",
        ":context",
        ":eager_executor",
        ":kernel_and_device",
@@ -631,6 +633,7 @@ tf_cuda_library(
    visibility = ["//tensorflow:internal"],
    deps = [
        ":context",
        ":custom_device",
        ":attr_builder",
        ":eager_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
@@ -24,11 +24,7 @@ limitations under the License.

namespace {

bool IsCPU(tensorflow::VariantDevice variant) {
  if (VariantDeviceIsCustom(variant)) {
    return false;
  }
  tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
bool IsCPU(tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

@@ -43,20 +39,6 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
  if (!status->ok()) {
    return nullptr;
  }
  if (VariantDeviceIsCustom(device())) {
    auto* custom_device = absl::get<CustomDevice*>(device());
    TensorHandle* copy;
    *status = custom_device->CopyTensorFromDevice(this, ctx_->HostCPU()->name(),
                                                  &copy);
    if (status->ok()) {
      auto result = copy->Resolve(status);
      copy->Unref();
      return result;
    } else {
      return nullptr;
    }
  }

  if (Type() == REMOTE) {
    const tensorflow::Tensor* t = nullptr;
    TensorHandle* h_cpu = nullptr;
@@ -124,14 +106,13 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) {
ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
    ImmediateExecutionTensorHandle* handle, const char* device_name,
    Status* status) {
  TensorHandle* input = TensorHandleFromInterface(handle);
  TensorHandle* result = nullptr;
  ImmediateExecutionTensorHandle* result = nullptr;
  Device* device;
  *status = this->FindDeviceFromName(device_name, &device);
  if (!status->ok()) {
    tensorflow::CustomDevice* dev;
    if (this->FindCustomDeviceFromName(device_name, &dev)) {
      *status = dev->CopyTensorToDevice(input, &result);
      *status = dev->CopyTensorToDevice(handle, &result);
      if (status->ok()) {
        return result;
      }
@@ -142,13 +123,13 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
    return nullptr;
  }
  // Handle tensor handles currently in custom devices
  const char* handle_device_name = input->DeviceName(status);
  const char* handle_device_name = handle->DeviceName(status);
  if (!status->ok()) {
    return nullptr;
  }
  tensorflow::CustomDevice* dev;
  if (this->FindCustomDeviceFromName(handle_device_name, &dev)) {
    *status = dev->CopyTensorFromDevice(input, device_name, &result);
    *status = dev->CopyTensorFromDevice(handle, device_name, &result);
    if (status->ok()) {
      return result;
    }
@@ -156,8 +137,10 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
  }

  // Handle regular case.
  TensorHandle* input = TensorHandleFromInterface(handle);
  *status =
      EagerCopyToDevice(input, this, &this->Executor(), device, false, &result);
      EagerCopyToDevice(input, this, &this->Executor(), device, false,
                        reinterpret_cast<tensorflow::TensorHandle**>(&result));
  if (status->ok()) {
    return result;
  }
@@ -213,16 +196,38 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) {
// eager_operation.cc we can avoid a circular dependency between them.
Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
                               int* num_retvals) {
  for (int i = 0; i < Inputs().size(); ++i) {
    TF_RETURN_IF_ERROR(Inputs()[i]->WaitUnknownDevice());
  for (ImmediateExecutionTensorHandle* handle : inputs_) {
    if (TensorHandle::classof(handle)) {
      TF_RETURN_IF_ERROR(down_cast<TensorHandle*>(handle)->WaitUnknownDevice());
    }
  }

  // Decide to either run the operation on a custom device or copy off all of
  // the custom device inputs.
  VariantDevice maybe_custom_device = Device();
  if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) ||
      !inputs_are_tensor_handles_) {
    // If the op wasn't placed on a custom device explicitly and there are no
    // non-TensorHandle inputs, the op will definitely be placed on a physical
    // device. Otherwise we need to check the inputs one by one.
    TF_RETURN_IF_ERROR(
        eager::MaybePinToCustomDevice(&maybe_custom_device, *this));
    if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) {
      ImmediateExecutionTensorHandle** retval_array =
          reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data());
      return absl::get<CustomDevice*>(maybe_custom_device)
          ->Execute(this, retval_array, num_retvals);
    } else {
      TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs());
    }
  }

  // Run eager placement logic.
  VariantDevice device;
  TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this));
  if (device == kVariantDeviceNull) {
  class Device* device = absl::get<class Device*>(maybe_custom_device);
  if (device == nullptr) {
    TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
  }
  if (device == kVariantDeviceNull && ctx_.PinSmallOpsToCPU()) {
  if (device == nullptr && ctx_.PinSmallOpsToCPU()) {
    bool pin_to_cpu;
    TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu(
        &pin_to_cpu, Name(), GetInputs(), ctx_.HostCPU()->name()));
@@ -231,16 +236,13 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
    }
  }

  tensorflow::TensorHandle** retval_array =
      reinterpret_cast<tensorflow::TensorHandle**>(retvals.data());
  if (VariantDeviceIsCustom(device)) {
    return absl::get<CustomDevice*>(device)->Execute(this, retval_array,
                                                     num_retvals);
  }

  if (device != kVariantDeviceNull) {
  if (device != nullptr) {
    SetDevice(device);
  }
  // At this point all inputs and outputs are TensorHandles associated with
  // physical devices.
  tensorflow::TensorHandle** retval_array =
      reinterpret_cast<tensorflow::TensorHandle**>(retvals.data());
  return EagerExecute(this, retval_array, num_retvals);
}
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/device_name_utils.h"
@@ -26,6 +27,7 @@ namespace tensorflow {

class TensorHandle;
class EagerOperation;
class CustomDeviceTensorHandle;

// Custom devices intercept the execution of operations (the `Execute` method),
// typically implemented with one or more of the custom device's own executions.
@@ -33,15 +35,22 @@ class CustomDevice {
 public:
  virtual ~CustomDevice() {}
  virtual const string& name() = 0;
  virtual Status CopyTensorToDevice(TensorHandle* tensor,
                                    TensorHandle** result) = 0;
  virtual Status CopyTensorToDevice(
      ImmediateExecutionTensorHandle* tensor,
      ImmediateExecutionTensorHandle** result) = 0;

  virtual Status CopyTensorFromDevice(TensorHandle* tensor,
                                      const string& target_device_name,
                                      TensorHandle** result) = 0;
  virtual Status CopyTensorFromDevice(
      ImmediateExecutionTensorHandle* tensor, const string& target_device_name,
      ImmediateExecutionTensorHandle** result) = 0;

  virtual Status Execute(const EagerOperation* op, TensorHandle** retvals,
  virtual Status Execute(const ImmediateExecutionOperation* op,
                         ImmediateExecutionTensorHandle** retvals,
                         int* num_retvals) = 0;

  // Creates a packed TensorHandle from a group of custom device TensorHandles,
  // one of which is on this custom device.
  virtual Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
                      ImmediateExecutionTensorHandle** result) = 0;
};

// Custom devices do many of the same things as physical Devices, but have a
@@ -49,6 +58,10 @@ class CustomDevice {
// operations may be placed either on custom or physical devices.
using VariantDevice = absl::variant<Device*, CustomDevice*>;

// Indicates either HostCPU or an unset physical device. We never set a null
// CustomDevice*.
const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);

// A tensor handle produced by a custom device. Generally they can only be
// consumed by executing an operation on the same custom device that produced it
// originally, or by attempting to copy the handle off the custom device.
@@ -65,6 +78,10 @@ class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle {
        device_(device),
        dtype_(dtype) {}

  // TODO(allenl): Should this be a generic method of
  // ImmediateExecutionTensorHandle to support TFE_TensorHandleDevicePointer?
  virtual void* DevicePointer() const = 0;

  tensorflow::DataType DataType() const override { return dtype_; }
  Status Shape(PartialTensorShape* shape) const override;
  Status NumElements(int64* num_elements) const override;
@@ -28,24 +28,31 @@ class TestCustomDevice : public CustomDevice {
 public:
  explicit TestCustomDevice(std::string name) : name_(name) {}
  const std::string& name() override { return name_; }
  Status CopyTensorToDevice(TensorHandle* tensor,
                            TensorHandle** result) override {
  Status CopyTensorToDevice(ImmediateExecutionTensorHandle* tensor,
                            ImmediateExecutionTensorHandle** result) override {
    tensor->Ref();
    *result = tensor;
    return Status::OK();
  }
  Status CopyTensorFromDevice(TensorHandle* tensor,
                              const std::string& target_device_name,
                              TensorHandle** result) override {
  Status CopyTensorFromDevice(
      ImmediateExecutionTensorHandle* tensor,
      const std::string& target_device_name,
      ImmediateExecutionTensorHandle** result) override {
    tensor->Ref();
    *result = tensor;
    return Status::OK();
  }
  Status Execute(const EagerOperation* op, TensorHandle** retvals,
  Status Execute(const ImmediateExecutionOperation* op,
                 ImmediateExecutionTensorHandle** retvals,
                 int* num_retvals) override {
    return errors::Unimplemented("Not implemented");
  }

  Status Pack(absl::Span<ImmediateExecutionTensorHandle*> handles,
              ImmediateExecutionTensorHandle** result) override {
    return errors::Unimplemented("Packing is not implemented");
  }

 private:
  std::string name_;
};
@@ -57,6 +64,7 @@ class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle {
                               tensorflow::DataType dtype)
      : CustomDeviceTensorHandle(context, device, dtype) {}

  void* DevicePointer() const override { return nullptr; }
  Status NumDims(int* num_dims) const override {
    *num_dims = 1;
    return Status::OK();
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
@@ -31,10 +32,11 @@ namespace tensorflow {
// Clear(), and then Reset(...) with the same arguments that would have
// been provided to the constructor.
void EagerOperation::Clear() {
  for (TensorHandle* h : inputs_) {
  for (ImmediateExecutionTensorHandle* h : inputs_) {
    h->Unref();
  }
  inputs_.clear();
  inputs_are_tensor_handles_ = true;
  ClearInferenceState();
}

@@ -263,7 +265,12 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) {
}

Status EagerOperation::AddInput(AbstractTensorHandle* input) {
  TensorHandle* h = TensorHandleFromInterface(input);
  ImmediateExecutionTensorHandle* h =
      down_cast<ImmediateExecutionTensorHandle*>(input);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (CustomDeviceTensorHandle::classof(h)) {
    inputs_are_tensor_handles_ = false;
  }
  AddTensorHandle(h);
  return MaybeInferSingleInputAttrs(h);
}
@@ -271,7 +278,13 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
Status EagerOperation::AddInputList(
    absl::Span<AbstractTensorHandle* const> inputs) {
  for (auto& input : inputs) {
    TensorHandle* h = TensorHandleFromInterface(input);
    // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
    // here.
    if (CustomDeviceTensorHandle::classof(input)) {
      inputs_are_tensor_handles_ = false;
    }
    ImmediateExecutionTensorHandle* h =
        down_cast<ImmediateExecutionTensorHandle*>(input);
    AddTensorHandle(h);
  }
  return InferInputListAttrs(inputs.size());
@@ -317,7 +330,8 @@ Status EagerOperation::Reset(
  return SetDeviceName(device_name);
}

Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
Status EagerOperation::MaybeInferSingleInputAttrs(
    ImmediateExecutionTensorHandle* handle) {
  if (!op_def_) return Status::OK();

  const auto& input_def = op_def_->input_arg(inference_arg_idx_++);
@@ -334,7 +348,7 @@ Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
  const std::string& type_attr = input_def.type_attr();
  if (!type_attr.empty() &&
      inference_attrs_.find(type_attr) == inference_attrs_.end()) {
    MutableAttrs()->Set(type_attr, handle->dtype);
    MutableAttrs()->Set(type_attr, handle->DataType());
    inference_attrs_.insert(type_attr);
  }
  return Status::OK();
@@ -372,12 +386,13 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
  if (!input_def.type_list_attr().empty()) {
    std::vector<DataType> dtypes(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      dtypes[i] = inputs_[start + i]->dtype;
      dtypes[i] = inputs_[start + i]->DataType();
    }
    InferMixedTypeInputListAttrs(input_def, dtypes);
  } else if (!input_def.type_attr().empty() &&
             !input_def.number_attr().empty()) {
    InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs);
    InferSingleTypeInputListAttrs(input_def, inputs_[start]->DataType(),
                                  num_inputs);
  } else if (!input_def.number_attr().empty()) {
    if (inference_attrs_.find(input_def.number_attr()) ==
        inference_attrs_.end()) {
@@ -390,6 +405,28 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
  return Status::OK();
}

Status EagerOperation::TensorHandleInputs(
    const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
  if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
    *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
        &inputs_);
    return Status::OK();
  } else {
    return errors::Internal("The operation unexpectedly had custom devices.");
  }
}

Status EagerOperation::MutableTensorHandleInputs(
    absl::InlinedVector<TensorHandle*, 4>** inputs) {
  if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
    *inputs =
        reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
    return Status::OK();
  } else {
    return errors::Internal("The operation unexpectedly had custom devices.");
  }
}

Status EagerOperation::SetDeviceName(const char* c_name) {
  string name(c_name != nullptr ? c_name : "");
  if (name != last_set_device_name_) {
@@ -423,6 +460,16 @@ bool EagerOperation::IsLocal() const {
         device_parsed_name_.task == host_cpu_name.task;
}

string VariantDeviceDebugString(VariantDevice device) {
  if (device == kVariantDeviceNull) {
    return "[]";
  } else if (absl::holds_alternative<CustomDevice*>(device)) {
    return absl::get<CustomDevice*>(device)->name();
  } else {
    return absl::get<Device*>(device)->DebugString();
  }
}

string EagerOperation::DebugString() const {
  string out;
  VLOG(1) << "EagerOperation::DebugString() over " << this;
@@ -442,10 +489,36 @@ string EagerOperation::DebugString() const {
  return out;
}

void EagerOperation::AddTensorHandle(TensorHandle* h) {
void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
  h->Ref();
  inputs_.push_back(h);
  attrs_.NumInputs(static_cast<int>(inputs_.size()));
}

Status EagerOperation::CopyOffCustomDeviceInputs() {
  if (absl::holds_alternative<CustomDevice*>(device_)) {
    return errors::Internal(
        "Trying to copy inputs to a custom device op off a custom device.");
  }
  for (int i = 0; i < inputs_.size(); ++i) {
    // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
    // here.
    if (CustomDeviceTensorHandle::classof(inputs_[i])) {
      CustomDeviceTensorHandle* previous =
          down_cast<CustomDeviceTensorHandle*>(inputs_[i]);
      class Device* target_device;
      if (device_ == kVariantDeviceNull) {
        target_device = ctx_.HostCPU();
      } else {
        target_device = absl::get<class Device*>(device_);
      }
      TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
          previous, target_device->name(), &inputs_[i]));
      previous->Unref();
    }
  }
  inputs_are_tensor_handles_ = true;
  return Status::OK();
}

}  // namespace tensorflow
@@ -39,7 +39,7 @@ class EagerOperation : public ImmediateExecutionOperation {
  explicit EagerOperation(tensorflow::EagerContext* ctx)
      : ImmediateExecutionOperation(kEager), ctx_(*ctx) {}
  ~EagerOperation() override {
    for (TensorHandle* h : inputs_) {
    for (ImmediateExecutionTensorHandle* h : inputs_) {
      h->Unref();
    }
  }
@@ -69,8 +69,9 @@ class EagerOperation : public ImmediateExecutionOperation {

  void SetDevice(VariantDevice device) {
    device_ = device;
    device_name_ =
        device == kVariantDeviceNull ? "" : VariantDeviceName(device);
    device_name_ = absl::visit(
        [](auto* device) { return device == nullptr ? "" : device->name(); },
        device);
    DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_);
    // TODO(b/154133594): Due to intricacies of external logic, we can not
    // set this do device_name_ as it would be natural, because we need the
@@ -141,10 +142,18 @@ class EagerOperation : public ImmediateExecutionOperation {
  AttrBuilder* MutableAttrs() { return &attrs_; }
  const AttrBuilder& Attrs() const { return attrs_; }

  const absl::InlinedVector<TensorHandle*, 4>& Inputs() const {
  // TensorHandleInputs and MutableTensorHandleInputs first check that all
  // inputs are TensorHandles, i.e. that there are no custom device inputs. They
  // return a bad status otherwise.
  Status TensorHandleInputs(
      const absl::InlinedVector<TensorHandle*, 4>** inputs) const;
  Status MutableTensorHandleInputs(
      absl::InlinedVector<TensorHandle*, 4>** inputs);

  const absl::InlinedVector<ImmediateExecutionTensorHandle*, 4>& Inputs()
      const {
    return inputs_;
  }
  absl::InlinedVector<TensorHandle*, 4>* MutableInputs() { return &inputs_; }

  void UpdateInput(int i, TensorHandle* h);

@@ -180,7 +189,7 @@ class EagerOperation : public ImmediateExecutionOperation {
  }

 private:
  void AddTensorHandle(TensorHandle* h);
  void AddTensorHandle(ImmediateExecutionTensorHandle* h);

  const tensorflow::OpDef* GetOpDef(Status* status);

@@ -190,7 +199,7 @@ class EagerOperation : public ImmediateExecutionOperation {
    inference_attrs_.clear_no_resize();
  }

  Status MaybeInferSingleInputAttrs(TensorHandle* handle);
  Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle);
  Status InferInputListAttrs(int num_inputs);

  void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def,
@@ -198,11 +207,21 @@ class EagerOperation : public ImmediateExecutionOperation {
  void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
                                    const std::vector<DataType>& dtypes);

  // Replaces input tensors placed on custom devices with physical device
  // equivalents. Used if an op is placed on a physical device but may have
  // custom device inputs.
  Status CopyOffCustomDeviceInputs();

  tensorflow::EagerContext& ctx_;
  const char* op_name_ = nullptr;
  AttrBuilder attrs_;
  const AttrTypeMap* attr_types_;
  absl::InlinedVector<TensorHandle*, 4> inputs_;

  // Toggled to indicate whether all inputs are known to be TensorHandles and
  // not another type (e.g. custom device tensor handles). Explicitly set to
  // false when custom device TensorHandles are added.
  bool inputs_are_tensor_handles_ = true;
  absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;

  // The last device name given to SetDeviceName.
  // This is used to avoid having to re-process the same device in repeated
@@ -240,8 +259,8 @@ class EagerOperation : public ImmediateExecutionOperation {
};

inline void EagerOperation::UpdateInput(int i, TensorHandle* h) {
  TensorHandle** slot = &inputs_[i];
  TensorHandle* existing = *slot;
  ImmediateExecutionTensorHandle** slot = &inputs_[i];
  ImmediateExecutionTensorHandle* existing = *slot;
  if (existing != h) {
    h->Ref();
    existing->Unref();
@@ -81,14 +81,6 @@ const string& DeviceNameOrUnspecified(Device* device) {
  return (device == nullptr) ? *unspecified_string : device->name();
}

const string& DeviceNameOrUnspecified(VariantDevice device) {
  if (VariantDeviceIsCustom(device)) {
    return absl::get<CustomDevice*>(device)->name();
  } else {
    return DeviceNameOrUnspecified(absl::get<Device*>(device));
  }
}

// Returns whether a kernel should be cached.
bool KernelCacheEnabled(const OpDef& op_def) {
  if (data::DatasetOpKernel::IsDatasetOp(&op_def)) {
@@ -200,9 +192,10 @@ Status ValidateInputTypeAndPlacement(
  const bool is_function = kernel->IsFunction();
  if (n_inputs > 0) {
    const DataType* input_types = &kernel->input_dtypes()[0];
    TensorHandle* const* handles = &op->Inputs()[0];
    const absl::InlinedVector<TensorHandle*, 4>* handles;
    TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles));
    for (int i = 0; i < n_inputs; ++i) {
      TensorHandle* handle = handles[i];
      TensorHandle* handle = (*handles)[i];
      Device* expected_device = kernel->InputDevice(i);
      if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) {
        // Extract a handle on the op device from a packed input.
@@ -220,13 +213,7 @@ Status ValidateInputTypeAndPlacement(
          }
        }
      }
      auto handle_device_variant = handle->DeviceOrHostCPU(*ctx);
      if (VariantDeviceIsCustom(handle_device_variant)) {
        return errors::Unimplemented(
            "Custom devices and remote execution are not yet supported "
            "together.");
      }
      Device* handle_device = absl::get<Device*>(handle_device_variant);
      Device* handle_device = handle->DeviceOrHostCPU(*ctx);
      const bool maybe_copy =
          !is_function || handle->Type() != TensorHandle::REMOTE;
      // If the input is already on the right device, then nothing to do.
@@ -280,14 +267,10 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,

Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
                         Device** result) {
  if (TF_PREDICT_FALSE(VariantDeviceIsCustom(tensor_handle->device()))) {
    return errors::Unimplemented(
        "The kernel cache does not work with custom devices.");
  }
  Device* cpu_device = ctx.HostCPU();
  string device_name;
  if (tensor_handle->Type() != TensorHandle::LOCAL) {
    Device* device = absl::get<Device*>(tensor_handle->device());
    Device* device = tensor_handle->device();
    device_name = device != nullptr ? device->name() : cpu_device->name();
    *result = (device == nullptr ? cpu_device : device);
  } else if (tensor_handle->dtype == DT_RESOURCE) {
@@ -304,7 +287,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
        ctx.FindDeviceFromName(device_name.c_str(), &input_device));
    *result = input_device;
  } else {
    Device* device = absl::get<Device*>(tensor_handle->device());
    Device* device = tensor_handle->device();
    const bool is_tpu = device != nullptr && device->device_type() == "TPU";
    // int32 return values can be placed on TPUs.
    const bool use_host_memory =
@@ -431,8 +414,10 @@ Status GetOrCreateKernelAndDevice(
    profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey",
                               profiler::TraceMeLevel::kInfo);
    input_dev_ptrs.reserve(op->Inputs().size());
    for (int i = 0, end = op->Inputs().size(); i < end; i++) {
      TensorHandle* input = op->Inputs()[i];
    const absl::InlinedVector<TensorHandle*, 4>* inputs;
    TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
    for (int i = 0, end = inputs->size(); i < end; i++) {
      TensorHandle* input = (*inputs)[i];

      // Get device for this input, and add it to 'cache_key'.
      Device* input_device;
@@ -477,7 +462,7 @@ Status GetOrCreateKernelAndDevice(
  core::RefCountPtr<KernelAndDevice> kernel = ctx.GetCachedKernel(cache_key);
  if (kernel == nullptr) {
    DVLOG(2) << "Creating new kernel for " << op->Name() << " on device "
             << DeviceNameOrUnspecified(op->Device());
             << DeviceNameOrUnspecified(absl::get<Device*>(op->Device()));
    bool run_function_with_flr = false;
    bool function_outputs_on_op_device = false;
    if (op->is_function()) {
@@ -656,9 +641,11 @@ Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
            remote_func_params, &ctx, &retvals[i]));
      }
    }
    const absl::InlinedVector<TensorHandle*, 4>* inputs;
    TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
    auto node = absl::make_unique<AsyncExecuteNode>(
        &ctx, op->Inputs(), remote_func_params, std::move(kernel),
        graph_collector, op->GetCancellationManager(),
        &ctx, *inputs, remote_func_params, std::move(kernel), graph_collector,
        op->GetCancellationManager(),
        absl::Span<TensorHandle*>(retvals, num_outputs), op->GetStackTrace());
    // Release the inputs from the eager operation since the AsyncExecuteNode
    // would have taken ownership. This allows the inputs to be forwarded if
@@ -673,8 +660,10 @@ Status AddOrExecuteNode(core::RefCountPtr<KernelAndDevice> kernel,
    for (int i = 0, end = num_outputs; i < end; ++i) {
      retvals[i] = nullptr;
    }
    ExecuteNode node(&ctx, op->Inputs(), remote_func_params, kernel,
                     graph_collector, op->GetCancellationManager(),
    const absl::InlinedVector<TensorHandle*, 4>* inputs;
    TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
    ExecuteNode node(&ctx, *inputs, remote_func_params, kernel, graph_collector,
                     op->GetCancellationManager(),
                     {retvals, static_cast<size_t>(num_outputs)});
    Status s = executor.SyncExecute(&node);
    // We release the inputs AFTER executing the operation in sync mode since
@@ -764,8 +753,10 @@ Status MaybePackInputTensor(EagerOperation* op) {
    return Status::OK();
  }
  EagerContext& ctx = op->EagerContext();
  for (int i = 0; i < op->Inputs().size(); ++i) {
    TensorHandle* handle = op->Inputs()[i];
  const absl::InlinedVector<TensorHandle*, 4>* inputs;
  TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
  for (int i = 0; i < inputs->size(); ++i) {
    TensorHandle* handle = (*inputs)[i];
    if (handle->Type() == TensorHandle::PACKED) {
      EagerOperation pack_op(&ctx);
      TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr,
@@ -842,7 +833,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
  if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) {
    return errors::InvalidArgument(
        "Unable to find remote task corresponding to device ",
        VariantDeviceName(op->Device()));
        op->DeviceName());
  }

  std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
@@ -855,11 +846,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
    profiler::TraceMe activity("CopyInputToExpectedDevice",
                               profiler::TraceMeLevel::kInfo);
    const bool is_function = op->is_function();
    for (int i = 0, end = op->Inputs().size(); i < end; i++) {
      tensorflow::TensorHandle* input = op->Inputs()[i];
      tensorflow::Device* input_device = absl::get<Device*>(input->device());
      tensorflow::Device* input_device_or_cpu =
          absl::get<Device*>(input->DeviceOrHostCPU(ctx));
    const absl::InlinedVector<TensorHandle*, 4>* inputs;
    TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));
    for (int i = 0, end = inputs->size(); i < end; i++) {
      tensorflow::TensorHandle* input = (*inputs)[i];
      tensorflow::Device* input_device = input->device();
      tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx);
      const string* input_device_name = &input_device_or_cpu->name();
      bool serialize_resource_dtype_and_shape = false;
      if (op_device != input_device &&
@@ -876,9 +868,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
        // Always copy to the remote CPU so that the actual device can be
        // correctly determined after the kernel is selected/instantiated,
        // since the op might have its inputs on host memory.
        TensorHandle* handle = op->Inputs()[i];
        Device* handle_device =
            absl::get<Device*>(handle->DeviceOrHostCPU(ctx));
        TensorHandle* handle = input;
        Device* handle_device = handle->DeviceOrHostCPU(ctx);
        // If the input is already on the right device, then nothing to do.
        if (remote_cpu_device != handle_device) {
          TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
@@ -959,11 +950,14 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
  DVLOG(4) << "Execute remote eager op: " << op->Name()
           << " (is async?: " << executor.Async() << ").";

  const absl::InlinedVector<TensorHandle*, 4>* inputs;
  TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs));

  std::unique_ptr<EagerNode> node(new eager::RemoteExecuteNode(
      &op->EagerContext(), std::move(request), op_device,
      ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(),
      op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(),
      op->Inputs(), {retvals, num_outputs}));
      *inputs, {retvals, num_outputs}));

  if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) {
    string msg = strings::StrCat(
@@ -1020,7 +1014,7 @@ Status GetKernelOutputs(
          "kernel. This should never happen.");
    }
    if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) !=
                         absl::get<Device*>(retvals[i]->device()))) {
                         retvals[i]->device())) {
      return errors::Internal(
          "Kernel output tensor handle locates on a different device than "
          "the specified kernel output device. This should never happen.");
@@ -1037,8 +1031,8 @@ Status GetKernelOutputs(
            "Remote outputs are not available on mobile devices.");
#else  // !IS_MOBILE_PLATFORM
        TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape(
            absl::get<TensorShape>(ret),
            absl::get<Device*>(retvals[i]->device()), ctx->GetContextViewId()));
            absl::get<TensorShape>(ret), retvals[i]->device(),
            ctx->GetContextViewId()));
#endif  // !IS_MOBILE_PLATFORM
      }
    }
@@ -1218,11 +1212,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                         TensorHandle** result) {
  TF_RETURN_IF_ERROR(h->WaitUnknownDevice());
  auto send_device = h->DeviceOrHostCPU(*ctx);
  if (VariantDeviceIsCustom(send_device)) {
    return errors::Unimplemented(
        "Copying a TensorHandle from a custom device is not supported.");
  }
  bool sender_is_local = absl::get<Device*>(send_device)->IsLocal();
  bool sender_is_local = send_device->IsLocal();

  bool receiver_is_local = device->IsLocal();

@@ -1363,11 +1353,6 @@ void EagerKernelExecuteAsync(
// triggered after execution with its status.
void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
                            int* num_retvals, StatusCallback done) {
  if (VariantDeviceIsCustom(op->Device())) {
    done(errors::Unimplemented(
        "Custom device is not supported in EagerLocalExecuteAsync."));
    return;
  }
  if (!op->IsLocal()) {
    done(errors::InvalidArgument(
        "Remote execution is not supported in async EagerLocalExecuteAsync"));
@@ -1419,8 +1404,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals,
                            output_dtypes[i], &ctx);
  }

  const absl::InlinedVector<TensorHandle*, 4>* inputs;
  s = op->TensorHandleInputs(&inputs);
  if (!s.ok()) {
    done(s);
    return;
  }
  EagerKernelExecuteAsync(
      &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
      &ctx, *inputs, op->remote_func_params(), std::move(kernel),
      graph_collector, op->GetCancellationManager(), retvals, num_outputs,
      [op, num_outputs, retvals, done = std::move(done)](const Status& s) {
        op->Clear();
@@ -44,8 +44,7 @@ Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx,
    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
    // We have validated that h->device() is not a CustomDevice when
    // constructing a pack TensorHandle.
    const Status status =
        h->TensorValue(absl::get<Device*>(h->device()), &packed_arg_flat[i]);
    const Status status = h->TensorValue(h->device(), &packed_arg_flat[i]);
    if (!status.ok()) {
#if !defined(IS_MOBILE_PLATFORM)
      if (IsRemote(ctx, input_device, h)) {
@@ -107,13 +106,7 @@ Status ExecuteNodeArgs::Init(
        TF_RETURN_IF_ERROR(
            op_inputs[index.index]->ExtractPackedHandle(index.sub_index, &h));
      }
      VariantDevice variant_device = h->device();
      if (VariantDeviceIsCustom(variant_device)) {
        return errors::Internal(
            "Custom devices and remote execution are currently not supported "
            "together.");
      }
      Device* device = absl::get<Device*>(variant_device);
      Device* device = h->device();
      // For a multi-device function, a remote RunComponentFunction request is
      // not sent through StreamingEnqueueAsync. It could arrive at a remote
      // worker before a remote execution request which produces an input of the
@ -105,7 +105,7 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) {

  std::vector<Device*> input_devices;
  for (auto* h : inputs) {
    input_devices.push_back(absl::get<Device*>(h->DeviceOrHostCPU(*ctx)));
    input_devices.push_back(h->DeviceOrHostCPU(*ctx));
  }
  const core::RefCountPtr<KernelAndDevice> kernel(
      new TestKernelAndDeviceFunc(std::move(input_devices), device0));

@ -17,6 +17,7 @@ limitations under the License.

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/op_def.pb.h"
@ -138,17 +139,18 @@ Status MaybePinSmallOpsToCpu(
  return Status::OK();
}

Status MaybePinToResourceDevice(VariantDevice* device,
                                const EagerOperation& op) {
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
  if (op.colocation_exempt()) {
    return Status::OK();
  }
  EagerContext& ctx = op.EagerContext();
  const absl::InlinedVector<TensorHandle*, 4>* inputs;
  TF_RETURN_IF_ERROR(op.TensorHandleInputs(&inputs));
  Device* op_device = op.Device() == kVariantDeviceNull
                          ? ctx.HostCPU()
                          : absl::get<Device*>(op.Device());
  for (int i = 0; i < op.Inputs().size(); ++i) {
    TensorHandle* tensor_handle = op.Inputs()[i];
  for (int i = 0; i < inputs->size(); ++i) {
    TensorHandle* tensor_handle = (*inputs)[i];
    if (tensor_handle->dtype == DT_RESOURCE) {
      if (tensor_handle->resource_remote_device_incarnation() != 0) {
        TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice(
@ -182,7 +184,7 @@ Status MaybePinToResourceDevice(VariantDevice* device,

Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
  // If operation was already placed on a custom device, use it.
  if (VariantDeviceIsCustom(op.Device())) {
  if (absl::holds_alternative<CustomDevice*>(op.Device())) {
    *device = op.Device();
    return Status::OK();
  } else if (!op.DeviceName().empty()) {
@ -194,9 +196,13 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
  // placement and there is only one custom device in the op inputs.
  if (!op.Inputs().empty()) {
    CustomDevice* first = nullptr;
    for (const TensorHandle* input : op.Inputs()) {
      if (VariantDeviceIsCustom(input->device())) {
        CustomDevice* current = absl::get<CustomDevice*>(input->device());
    for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
      // here.
      if (CustomDeviceTensorHandle::classof(generic_input)) {
        const CustomDeviceTensorHandle* input =
            down_cast<const CustomDeviceTensorHandle*>(generic_input);
        CustomDevice* current = input->device();
        if (first == nullptr) {
          first = current;
        } else if (first != current) {
@ -207,9 +213,9 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
              op.Name(),
              " has one input in custom "
              "device ",
              VariantDeviceName(first),
              first->name(),
              " and at least one input in a different custom device ",
              VariantDeviceName(current)));
              current->name()));
        }
      }
    }

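The hunk above swaps absl::variant queries (holds_alternative / get) for a classof check plus down_cast on the tensor handle class hierarchy, the LLVM-style RTTI idiom the TODO alludes to with tensorflow::isa. A compact standalone sketch of that idiom (illustrative Kind enum and class names, not the real TF hierarchy):

    #include <cassert>
    #include <iostream>

    class HandleBase {
     public:
      enum class Kind { kEager, kCustomDevice };
      explicit HandleBase(Kind k) : kind_(k) {}
      virtual ~HandleBase() = default;
      Kind kind() const { return kind_; }

     private:
      Kind kind_;
    };

    class CustomDeviceHandle : public HandleBase {
     public:
      CustomDeviceHandle() : HandleBase(Kind::kCustomDevice) {}
      // classof lets callers test the dynamic type without dynamic_cast.
      static bool classof(const HandleBase* h) {
        return h->kind() == Kind::kCustomDevice;
      }
      const char* device_name() const { return "/device:CUSTOM:0"; }
    };

    // Checked static downcast: assert the tag matches, then static_cast.
    template <typename To, typename From>
    const To* down_cast(const From* f) {
      assert(To::classof(f) && "down_cast to the wrong dynamic type");
      return static_cast<const To*>(f);
    }

    int main() {
      CustomDeviceHandle c;
      const HandleBase* h = &c;
      if (CustomDeviceHandle::classof(h)) {
        std::cout << down_cast<CustomDeviceHandle>(h)->device_name() << "\n";
      }
    }
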
@ -43,8 +43,7 @@ Status MaybePinSmallOpsToCpu(
// If a resource touching input is specified, all resource-touching ops run in
// the device the resource is, regardless of anything else that has been
// specified. This is identical to the graph mode behavior.
Status MaybePinToResourceDevice(VariantDevice* device,
                                const EagerOperation& op);
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op);

// If all the inputs are on the same custom device, use that custom
// device. Otherwise, it is an error to have a custom device as an input.

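The header comment above states the pinning rule MaybePinToCustomDevice implements. A minimal standalone scan that captures it (hypothetical types: nullptr entries stand in for physical-device inputs): pick the custom device shared by all custom-device inputs, and fail if two different ones appear.

    #include <iostream>
    #include <string>
    #include <vector>

    struct CustomDevice {
      std::string name;
    };

    // Returns the unique custom device among inputs, or nullptr if there is
    // none; sets *error when two different custom devices are mixed.
    const CustomDevice* PickCustomDevice(
        const std::vector<const CustomDevice*>& inputs, std::string* error) {
      const CustomDevice* first = nullptr;
      for (const CustomDevice* current : inputs) {
        if (current == nullptr) continue;  // not a custom device input
        if (first == nullptr) {
          first = current;
        } else if (first != current) {
          *error = "inputs on different custom devices: " + first->name +
                   " vs " + current->name;
          return nullptr;
        }
      }
      return first;
    }

    int main() {
      CustomDevice d{"/device:CUSTOM:0"};
      std::string error;
      const CustomDevice* picked = PickCustomDevice({&d, nullptr, &d}, &error);
      std::cout << (picked ? picked->name : error) << "\n";
    }
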
@ -54,6 +54,14 @@ int64 GetRemoteDeviceIncarnation(Device* device) {
  if (device == nullptr || device->IsLocal()) return 0;
  return device->attributes().incarnation();
}

string SafeDeviceDebugString(Device* device) {
  if (device == nullptr) {
    return "[]";
  } else {
    return device->DebugString();
  }
}
}  // namespace

TensorHandle::PackedTensorHandleData::PackedTensorHandleData(
@ -231,12 +239,6 @@ TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
  }
}

TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t,
                                              CustomDevice* d,
                                              EagerContext* ctx) {
  return new TensorHandle(std::move(t), d, ctx);
}

TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
                           Device* resource_device, EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
@ -249,7 +251,7 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
      ctx_(ctx),
      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this
           << " device: " << VariantDeviceDebugString(device_)
           << " device: " << SafeDeviceDebugString(device_)
           << " tensor: " << t.DeviceSafeDebugString();
}

@ -268,26 +270,10 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
              t.flat<class ResourceHandle>()(0).dtypes_and_shapes()),
      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
  DVLOG(3) << "Creating Local TensorHandle: " << this
           << " device: " << VariantDeviceDebugString(device_)
           << " device: " << SafeDeviceDebugString(device_)
           << " tensor: " << t.DeviceSafeDebugString();
}

TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
                           EagerContext* ctx)
    : ImmediateExecutionTensorHandle(kEager),
      dtype(t.dtype()),
      device_(d),
      op_device_(nullptr),
      resource_device_(nullptr),
      resource_remote_device_incarnation_(0),
      ctx_(ctx),
      data_(absl::in_place_type<LocalTensorHandleData>, std::move(t)) {
  // TODO(allenl): Figure out a better op_device story for custom devices,
  // since always setting it to CPU=nullptr doesn't make much sense.
  DVLOG(3) << "Creating Local TensorHandle: " << this
           << " custom device: " << VariantDeviceDebugString(device_)
           << " tensor: " << t.DeviceSafeDebugString();
}

TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
                                                   Device* resource_device,
@ -309,7 +295,7 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
      ctx_(ctx),
      data_(absl::in_place_type<LocalTensorHandleData>) {
  DVLOG(3) << "Creating empty Local TensorHandle: " << this
           << " device: " << VariantDeviceDebugString(device_);
           << " device: " << SafeDeviceDebugString(device_);
}

Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
@ -328,13 +314,10 @@ Status TensorHandle::CreatePackedHandle(std::vector<TensorHandle*>&& handles,
        handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes));
  }
  std::vector<string> devices;
  devices.reserve(handles.size());
  for (auto* handle : handles) {
    if (VariantDeviceIsCustom(handle->device())) {
      devices.push_back(absl::get<CustomDevice*>(handle->device())->name());
    } else {
      devices.push_back(handle->op_device() ? handle->op_device()->name()
                                            : ctx->HostCPU()->name());
    }
    devices.push_back(handle->op_device() ? handle->op_device()->name()
                                          : ctx->HostCPU()->name());
  }

  CompositeDevice* composite_device = nullptr;
@ -378,7 +361,7 @@ TensorHandle::TensorHandle(std::vector<TensorHandle*>&& handles, Device* device,
      data_(absl::in_place_type<PackedTensorHandleData>, std::move(handles),
            shape) {
  DVLOG(3) << "Creating a packed TensorHandle: " << this
           << " device: " << VariantDeviceDebugString(device_);
           << " device: " << SafeDeviceDebugString(device_);
}

#if !defined(IS_MOBILE_PLATFORM)
@ -406,7 +389,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
            remote_task, ctx) {
  DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
           << " device: " << VariantDeviceDebugString(device_);
           << " device: " << SafeDeviceDebugString(device_);
}

TensorHandle* TensorHandle::CreateLazyRemoteHandle(
@ -429,7 +412,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
      data_(absl::in_place_type<RemoteTensorHandleData>, op_id, output_num,
            ctx->GetContextViewId(), is_ready) {
  DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this
           << " device: " << VariantDeviceDebugString(device_);
           << " device: " << SafeDeviceDebugString(device_);
}
#endif

@ -487,7 +470,7 @@ Status TensorHandle::TensorFromDevice(const Device* d,
                                      const tensorflow::Tensor** t) const {
  DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d;

  if (d == absl::get<Device*>(device_)) {
  if (d == device_) {
    if (Type() != LOCAL) {
      return errors::Internal("Invalid Tensor call on a ", TypeString(),
                              " handle: ", this);
@ -511,13 +494,7 @@ Status TensorHandle::TensorFromDevice(const Device* d,
Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) {
  DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d;

  if (VariantDeviceIsCustom(device_)) {
    return errors::Internal(
        "TensorHandle::TensorValue not supported for custom devices yet. "
        "Handle device: ",
        VariantDeviceDebugString(device_),
        ", requested device: ", d != nullptr ? d->name() : "(nil)");
  } else if (d == absl::get<Device*>(device_)) {
  if (d == device_) {
    if (Type() != LOCAL) {
      return errors::Internal("Invalid TensorValue call on a ", TypeString(),
                              " handle: ", this);
@ -549,13 +526,8 @@ Status TensorHandle::WaitUnknownDevice() const {
  return Status::OK();
}

VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
  if (VariantDeviceIsCustom(device_)) {
    return device_;
  } else {
    Device* d = absl::get<Device*>(device_);
    return (d == nullptr) ? ctx.HostCPU() : d;
  }
Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
  return (device_ == nullptr) ? ctx.HostCPU() : device_;
}

Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
@ -691,7 +663,7 @@ Status TensorHandle::NumElements(int64* num_elements) const {
Status TensorHandle::Unprotect(const Device* d) {
  DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d;

  if (d == absl::get<Device*>(device_)) {
  if (d == device_) {
    return absl::visit([](auto& data) { return data.Unprotect(); }, data_);
  }

@ -718,7 +690,7 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) {
  DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this
           << " device: " << d;

  if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
  if (d == device_) {
    return errors::Internal("Cannot add mirror for primary device.");
  }

@ -739,7 +711,7 @@ Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready,
  DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
  if (d != device_) {
    tf_shared_lock l(mu_);
    auto mirror = remote_mirrors_.find(d->name());
    if (mirror != remote_mirrors_.end()) {
@ -854,7 +826,7 @@ Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape,
  DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
  if (d != device_) {
    tf_shared_lock l(mu_);
    auto remote_mirror = remote_mirrors_.find(d->name());
    if (remote_mirror == remote_mirrors_.end()) {
@ -916,7 +888,7 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,
  DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d
           << " " << d->name();

  if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
  if (d == device_) {
    DCHECK(Type() == REMOTE)
        << "Poison can only be on remote handles: " << this;

@ -936,7 +908,7 @@ void TensorHandle::PoisonRemote(Status status, const Device* d,

Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
                                    const Device* d) {
  if (d == absl::get<Device*>(device_)) {
  if (d == device_) {
    return errors::Internal(
        "Local mirror assign conflicts with primary device.");
  }
@ -955,7 +927,7 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor,
Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
  DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d;

  if (d == absl::get<Device*>(device_)) {
  if (d == device_) {
    DCHECK(Type() == LOCAL) << "SetTensor is not called on local handles.";

    if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
@ -982,7 +954,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) {
void TensorHandle::Poison(Status status, const Device* d) {
  DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d;

  if (!VariantDeviceIsCustom(device_) && d == absl::get<Device*>(device_)) {
  if (d == device_) {
    DCHECK(Type() != REMOTE) << "Poison can only be on local handles: " << this;
    absl::visit([status](auto& data) { data.Poison(status); }, data_);
  } else {
@ -1001,7 +973,7 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
                                  tensorflow::Device* d,
                                  tensorflow::Tensor* output) const {
  tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d;
  tensorflow::Device* srcd = absl::get<Device*>(DeviceOrHostCPU(ctx));
  tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
  const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
  const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
  bool is_same_device =
@ -1063,27 +1035,6 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
  return status;
}

bool VariantDeviceIsCustom(VariantDevice variant_device) {
  return variant_device.index() != 0;
}

string VariantDeviceName(VariantDevice device) {
  if (device == kVariantDeviceNull) {
    return "[]";
  }
  return absl::visit([](auto* device) { return device->name(); }, device);
}

string VariantDeviceDebugString(VariantDevice device) {
  if (device == kVariantDeviceNull) {
    return "[]";
  } else if (VariantDeviceIsCustom(device)) {
    return absl::get<CustomDevice*>(device)->name();
  } else {
    return absl::get<Device*>(device)->DebugString();
  }
}

Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
  if (ctx == nullptr) {
    return nullptr;
@ -1100,10 +1051,9 @@ string TensorHandle::DebugString() const {
  DVLOG(4) << "Calling TensorHandle::DebugString() on " << this;

  string out;
  string device_debug = VariantDeviceDebugString(device_);
  string device_debug = SafeDeviceDebugString(device_);
  strings::StrAppend(&out, "Device: ", device_debug);
  bool is_cpu =
      !VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull;
  bool is_cpu = device_ != nullptr;
  // Consider supporting non-CPU tensors and CPU tensors with a device_ set to
  // non-NULL if needed.
  strings::StrAppend(
@ -1115,9 +1065,6 @@ string TensorHandle::DebugString() const {
}

const char* TensorHandle::DeviceName(Status* status) const {
  if (VariantDeviceIsCustom(device())) {
    return absl::get<CustomDevice*>(device())->name().c_str();
  }
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
@ -1125,33 +1072,19 @@ const char* TensorHandle::DeviceName(Status* status) const {
}

const char* TensorHandle::BackingDeviceName(Status* status) const {
  if (VariantDeviceIsCustom(device())) {
    return absl::get<tensorflow::CustomDevice*>(device())->name().c_str();
  } else {
    status->Update(WaitUnknownDevice());
    tensorflow::Device* d = absl::get<tensorflow::Device*>(device());
    return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                          : d->name().c_str();
  }
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = device();
  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                        : d->name().c_str();
}

const char* TensorHandle::DeviceType(Status* status) const {
  if (VariantDeviceIsCustom(device())) {
    status->Update(
        tensorflow::errors::Unimplemented("Custom device unsupported"));
    return nullptr;
  }
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str();
}

int TensorHandle::DeviceId(Status* status) const {
  if (VariantDeviceIsCustom(device())) {
    status->Update(
        tensorflow::errors::Unimplemented("Custom device unsupported"));
    return -1;
  }
  status->Update(WaitUnknownDevice());
  tensorflow::Device* d = op_device();
  return (d == nullptr) ? 0 : d->parsed_name().id;

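Once device_ is a plain pointer, the hunks above converge on two tiny helpers: a null-safe debug string (nullptr encodes "HostCPU or not yet known", printed as "[]") and a CPU fallback. A sketch with simplified stand-in types; the helper names mirror the diff:

    #include <iostream>
    #include <string>

    struct Device {
      std::string name;
      std::string DebugString() const { return "Device(" + name + ")"; }
    };

    std::string SafeDeviceDebugString(const Device* device) {
      // nullptr encodes "HostCPU or not yet known", matching the "[]" output.
      return device == nullptr ? "[]" : device->DebugString();
    }

    const Device* DeviceOrHostCPU(const Device* device, const Device* host_cpu) {
      return device == nullptr ? host_cpu : device;
    }

    int main() {
      Device cpu{"CPU:0"};
      const Device* unset = nullptr;
      std::cout << SafeDeviceDebugString(unset) << "\n";        // []
      std::cout << DeviceOrHostCPU(unset, &cpu)->name << "\n";  // CPU:0
    }
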
@ -60,7 +60,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
  // TensorHandle for dtype == DT_RESOURCE
  TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
               EagerContext* ctx);
  TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx);
  TensorHandle(Device* d, Device* op_device, Device* resource_device,
               tensorflow::DataType dtype, EagerContext* ctx);

@ -81,8 +80,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
                                         Device* op_device,
                                         Device* resource_device,
                                         EagerContext* ctx);
  static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t,
                                         CustomDevice* d, EagerContext* ctx);
  static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
                                              Device* resource_device,
                                              tensorflow::DataType dtype,
@ -150,7 +147,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
  // requesting the HostCPU.
  Status TensorValue(const Device* d, tensorflow::TensorValue* t);

  VariantDevice device() const { return device_; }
  Device* device() const { return device_; }
  Device* op_device() const { return op_device_; }
  Device* resource_device() const { return resource_device_; }
  int64 resource_remote_device_incarnation() const {
@ -161,7 +158,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
  // are set (data is ready).
  Status WaitUnknownDevice() const;

  VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
  Device* DeviceOrHostCPU(const EagerContext& ctx) const;

  Status Shape(tensorflow::TensorShape* shape);

@ -286,7 +283,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
  bool IsReady() const;
  Status WaitReady(const char* caller) const;

  VariantDevice device_;
  tensorflow::Device* device_;

  // Device in which the op producing this tensor was executed. Equals to
  // device_ for constant tensors.
@ -391,19 +388,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
  PartialTensorShape inference_shape_;
};

// Checks whether a VariantDevice contains a custom device.
bool VariantDeviceIsCustom(VariantDevice device);

// Wraps device->name() or CustomDevice->name().
string VariantDeviceName(VariantDevice device);

// Wraps device->DebugString() or CustomDevice->name().
string VariantDeviceDebugString(VariantDevice device);

// Indicates either HostCPU or an unset physical device. We never set a null
// CustomDevice*.
const VariantDevice kVariantDeviceNull = static_cast<Device*>(nullptr);

// Returns the device backing the resource. Else, returns nullptr.
Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);

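For context on what the header hunk deletes: VariantDevice was (approximately) an absl::variant over the two device pointer types, and "is custom" was a check on the active alternative, exactly as the removed VariantDeviceIsCustom body (variant_device.index() != 0) shows. A compact std::variant re-creation:

    #include <iostream>
    #include <variant>

    struct Device {};
    struct CustomDevice {};
    using VariantDevice = std::variant<Device*, CustomDevice*>;

    bool VariantDeviceIsCustom(const VariantDevice& d) {
      // Alternative 0 is Device*, so any other index means a custom device.
      return d.index() != 0;
    }

    int main() {
      Device physical;
      VariantDevice d = &physical;
      std::cout << std::boolalpha << VariantDeviceIsCustom(d) << "\n";  // false
    }

Replacing this variant with a plain tensorflow::Device* is what lets every absl::get<Device*>(...) call site in the hunks below collapse to a direct member access.
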
@ -189,8 +189,8 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
  EXPECT_EQ(dtypes_and_shapes.at(0).dtype, DT_FLOAT);
  EXPECT_EQ(dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), true);

  CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
      absl::get<Device*>(packed_handle->device()));
  CompositeDevice* device =
      reinterpret_cast<CompositeDevice*>(packed_handle->device());
  EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0");
  EXPECT_EQ(device->underlying_devices()->size(), 4);

@ -200,7 +200,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) {
  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
    TensorHandle* h = nullptr;
    TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h));
    EXPECT_EQ(absl::get<Device*>(h->device()), ListDevices().at(i));
    EXPECT_EQ(h->device(), ListDevices().at(i));
    EXPECT_EQ(h->Type(), expected_handle_types.at(i));
  }
  EXPECT_FALSE(IsReady(packed_handle));
@ -236,14 +236,14 @@ TEST_F(PackedTensorHandleTest, PackedSingleHandle) {
  TF_ASSERT_OK(packed_handle->Shape(&packed_shape));
  EXPECT_EQ(packed_shape, shape);

  CompositeDevice* device = reinterpret_cast<CompositeDevice*>(
      absl::get<Device*>(packed_handle->device()));
  CompositeDevice* device =
      reinterpret_cast<CompositeDevice*>(packed_handle->device());
  EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0");
  EXPECT_EQ(device->underlying_devices()->size(), 1);
  EXPECT_EQ(packed_handle->NumPackedHandles(), 1);
  TensorHandle* h0 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &h0));
  EXPECT_EQ(absl::get<Device*>(h0->device()), d);
  EXPECT_EQ(h0->device(), d);
  EXPECT_TRUE(IsReady(packed_handle));
  packed_handle->Unref();
}
@ -392,7 +392,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
  TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle(
      /*op_id=*/0, /*output_num=*/0, remote_task, dtype, d1, context,
      /*unknown_device=*/true);
  EXPECT_EQ(absl::get<Device*>(h->device()), d1);
  EXPECT_EQ(h->device(), d1);

  Device* d2 = device_mgr.ListDevices().at(2);
  TF_ASSERT_OK(h->SetRemoteShapeAndDevice(
@ -400,7 +400,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) {
  Status s;
  EXPECT_EQ(h->BackingDeviceName(&s), d2->name());
  TF_EXPECT_OK(s);
  EXPECT_EQ(absl::get<Device*>(h->device()), d2);
  EXPECT_EQ(h->device(), d2);
  h->Unref();
  context->Unref();
}

@ -186,7 +186,7 @@ Status AddOpRetvalsToResponse(
  for (int i = 0; i < num_retvals; i++) {
    TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn()));
    if (add_device_fn) {
      Device* device = absl::get<Device*>(retvals[i]->device());
      Device* device = retvals[i]->device();
      *add_device_fn() = device ? device->name() : "";
    }
    if (retvals[i]->Type() == TensorHandle::REMOTE) {

@ -1086,8 +1086,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
  TF_ASSERT_OK(tensor_handle->Tensor(&t));

  Device* device = absl::get<Device*>(tensor_handle->device());
  EXPECT_EQ(device, nullptr);
  EXPECT_EQ(tensor_handle->device(), nullptr);

  auto actual = t->flat<float>();
  EXPECT_EQ(4, actual.size());
@ -1168,8 +1167,7 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {

  EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED);
  EXPECT_EQ(packed_handle->NumPackedHandles(), 3);
  EXPECT_EQ(absl::get<Device*>(packed_handle->device())->name(),
            composite_device);
  EXPECT_EQ(packed_handle->device()->name(), composite_device);

  TensorHandle* handle0 = nullptr;
  TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0));
@ -1198,7 +1196,7 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) {
  EXPECT_EQ(handle2->op_device()->name(), device2);
  int64 op_id;
  int32 output_num;
  TF_ASSERT_OK(handle2->RemoteAddress(absl::get<Device*>(handle2->device()),
  TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(),
                                      /*wait_until_ready=*/true, &op_id,
                                      &output_num));
  EXPECT_EQ(op_id, 2);

@ -37,7 +37,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) {
  remote_op->set_name(op->Name());

  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
  remote_op->set_device(VariantDeviceName(op->Device()));
  remote_op->set_device(op->DeviceName());
}

Status CreateUncachedKernelAndDeviceOp(
@ -80,7 +80,7 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
      src_(src),
      ctx_(ctx),
      executor_(executor),
      send_device_(absl::get<Device*>(src->DeviceOrHostCPU(*ctx))),
      send_device_(src->DeviceOrHostCPU(*ctx)),
      recv_device_(recv_device),
      wire_id_(GetUniqueWireID()),
      recv_op_id_(recv_op_id),
@ -149,9 +149,8 @@ void RemoteCopyNode::StartSend() {
    auto* remote_op = request.add_queue()->mutable_operation();
    status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
        src_, /*wait_until_ready=*/false,
        remote_op->add_op_inputs()->mutable_remote_handle(),
        absl::get<Device*>(src_->device()),
        absl::get<Device*>(src_->DeviceOrHostCPU(*ctx_))->name());
        remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(),
        src_->DeviceOrHostCPU(*ctx_)->name());
    if (!status.ok()) {
      captured_state_->SetSendStatus(status);
      return;
@ -310,7 +309,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
                             const Device* target_device, EagerContext* ctx,
                             SendPackedHandleOp* op) {
  op->set_op_id(op_id);
  op->set_device_name(VariantDeviceName(packed_handle->DeviceOrHostCPU(*ctx)));
  op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name());
  for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) {
    TensorHandle* h = nullptr;
    TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h));
@ -329,7 +328,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
      // If src_device is on the same task of target_device, the handle is a
      // local handle on the target device, which means the resource dtype and
      // shape are known on the target device.
      Device* src_device = absl::get<Device*>(h->device());
      Device* src_device = h->device();
      const bool serialize_resource_dtype_and_shape =
          (i == 0) && (h->dtype == DT_RESOURCE) &&
          (!ctx->OnSameTask(src_device, target_device));
@ -341,7 +340,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle,
      TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle(
          h, /*wait_until_ready=*/true,
          op->add_handles()->mutable_remote_handle(), src_device,
          absl::get<Device*>(h->DeviceOrHostCPU(*ctx))->name(),
          h->DeviceOrHostCPU(*ctx)->name(),
          serialize_resource_dtype_and_shape));
    } else {
      return errors::InvalidArgument("Nested packed handles are not supported");

@ -83,15 +83,8 @@ Status RemoteMgr::GetMirroredResourceShape(
Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
                                        const bool wait_until_ready,
                                        int64* op_id, int32* output_num) {
  // TODO(allenl): Consider supporting remote handles on custom devices.
  VariantDevice device = handle->device();
  if (VariantDeviceIsCustom(device)) {
    return errors::Unimplemented(
        "Custom devices and remote execution are currently not supported "
        "together.");
  }
  TF_RETURN_IF_ERROR(handle->RemoteAddress(
      absl::get<Device*>(device), wait_until_ready, op_id, output_num));
  TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready,
                                           op_id, output_num));
  tensorflow::TensorHandle* h;
  TF_RETURN_IF_ERROR(
      GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));

@ -268,14 +268,11 @@ class OpNode {
    return tensorflow::Status::OK();
  }

  void ClearEagerInputs() {
    for (tensorflow::TensorHandle* h : *op_->MutableInputs()) {
      if (h) h->Unref();
    }
    op_->MutableInputs()->clear();
  }
  void ClearEagerInputs() { op_->Clear(); }

  tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
    absl::InlinedVector<tensorflow::TensorHandle*, 4>* op_inputs;
    TF_RETURN_IF_ERROR(op_->MutableTensorHandleInputs(&op_inputs));
    for (int i = 0; i < inputs_.Size(); ++i) {
      int input_index = inputs_.TfLiteIndex(i);
      TensorSource s = inputs_.GetTensorSource(i);
@ -290,14 +287,14 @@ class OpNode {
        tensorflow::TensorHandle* handle =
            tensorflow::TensorHandle::CreateLocalHandle(
                buffer_map->GetTensor(input_index));
        op_->MutableInputs()->push_back(handle);
        op_inputs->push_back(handle);
      } else {
        // If this is a forwardable tensor, we will remove it from the previous
        // op's list, giving TF the opportunity to reuse its buffer.
        bool unref_handle = inputs_.IsForwardable(i);
        auto* handle =
            s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
        op_->MutableInputs()->push_back(handle);
        op_inputs->push_back(handle);
      }
    }
    return tensorflow::Status::OK();

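The hunk above shifts ownership of the unref from the flex delegate to the operation itself: callers no longer walk MutableInputs() dropping references, they just call op_->Clear(). A toy refcounted sketch of that centralization (illustrative types only):

    #include <cassert>
    #include <vector>

    struct Handle {
      int refs = 1;
      void Ref() { ++refs; }
      void Unref() { --refs; }
    };

    class Operation {
     public:
      void AddInput(Handle* h) {
        h->Ref();  // the operation owns one reference per input
        inputs_.push_back(h);
      }
      // Equivalent of EagerOperation::Clear(): drop all held references.
      void Clear() {
        for (Handle* h : inputs_) {
          if (h) h->Unref();
        }
        inputs_.clear();
      }

     private:
      std::vector<Handle*> inputs_;
    };

    int main() {
      Handle h;
      Operation op;
      op.AddInput(&h);
      op.Clear();  // the caller no longer loops over the input list itself
      assert(h.refs == 1);
    }
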
@ -155,11 +155,7 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
      tensorflow::unwrap(ctx)->TFTensorHandleFromInterface(
          tensorflow::unwrap(EagerTensor_Handle(eager_tensor))));

  if (VariantDeviceIsCustom(handle->device())) {
    return errors::Unimplemented(
        "Custom devices are currently not supported with PyFuncs.");
  }
  Device* actual_device = absl::get<Device*>(handle->device());
  Device* actual_device = handle->device();
  TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
  // actual_device may be nullptr, which implies local CPU.
  if (expected_device == actual_device) return Status::OK();