From 253111e23b612585e61f28acaf1b6dd6a965249b Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 12 Jan 2021 11:03:38 -0800 Subject: [PATCH] Stop holding custom devices in TensorHandles My hope is to not change the custom device API much, just clean up the implementation. Previously TensorHandles held Tensors which held the void* custom device tensor handle data. This required a bunch of special cases, mostly because the TensorHandle's device wasn't a physical device. Now EagerOperations still accept custom device TensorHandles, but deal with them before execution (either by copying them off the custom device or by executing the operation on a custom device). This means the rest of the runtime can assume TensorHandles are on physical devices, and gives custom device tensor handles some freedom to evolve. Rolling cl/350684489 forward with a fix for making packed tensors from custom device tensors. This requires one new custom device method to maintain the previous (accidental) functionality. PiperOrigin-RevId: 351406348 Change-Id: I9c2ffd40a687b06434fab40e2db9e90129b9f2b7 --- tensorflow/c/eager/BUILD | 20 +- tensorflow/c/eager/c_api.cc | 241 ++++++++++++------ tensorflow/c/eager/c_api_experimental.cc | 17 +- tensorflow/c/eager/c_api_experimental.h | 16 +- tensorflow/c/eager/c_api_test.cc | 2 +- .../core/common_runtime/composite_device.cc | 8 - .../common_runtime/composite_device_test.cc | 15 -- tensorflow/core/common_runtime/eager/BUILD | 3 + tensorflow/core/common_runtime/eager/core.cc | 80 +++--- .../core/common_runtime/eager/custom_device.h | 29 ++- .../eager/custom_device_test.cc | 20 +- .../common_runtime/eager/eager_operation.cc | 89 ++++++- .../common_runtime/eager/eager_operation.h | 39 ++- .../core/common_runtime/eager/execute.cc | 103 ++++---- .../core/common_runtime/eager/execute_node.cc | 11 +- .../common_runtime/eager/execute_node_test.cc | 2 +- .../common_runtime/eager/placement_utils.cc | 26 +- .../common_runtime/eager/placement_utils.h | 3 +- .../common_runtime/eager/tensor_handle.cc | 139 +++------- .../core/common_runtime/eager/tensor_handle.h | 22 +- .../eager/tensor_handle_test.cc | 16 +- .../eager/eager_service_impl.cc | 2 +- .../eager/eager_service_impl_test.cc | 8 +- .../eager/remote_copy_node.cc | 15 +- .../distributed_runtime/eager/remote_mgr.cc | 11 +- tensorflow/lite/delegates/flex/kernel.cc | 13 +- tensorflow/python/lib/core/py_func.cc | 6 +- 27 files changed, 524 insertions(+), 432 deletions(-) diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 09d5e654107..0600253da3c 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -52,7 +52,6 @@ tf_cuda_library( ":immediate_execution_operation", ":immediate_execution_tensor_handle", ":immediate_execution_distributed_manager", - ":abstract_tensor_handle", ":tfe_context_internal", ":tfe_cancellation_manager_internal", ":tfe_executor_internal", @@ -73,6 +72,7 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:context_distributed_manager", "//tensorflow/core/common_runtime/eager:core", + "//tensorflow/core/common_runtime/eager:custom_device", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:tensor_handle", @@ -86,6 +86,7 @@ tf_cuda_library( ], }) + [ "@com_google_absl//absl/memory", + ":abstract_tensor_handle", "//tensorflow/core/common_runtime/eager:eager_operation", "//tensorflow/core/distributed_runtime/eager:remote_mgr", "//tensorflow/core/distributed_runtime/eager:cluster_function_library_runtime", @@ -480,12 +481,17 @@ cc_library( visibility = [ "//tensorflow:internal", ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:refcount", - "//tensorflow/core/platform:status", - ], + deps = select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:refcount", + "//tensorflow/core/platform:status", + ], + }), ) cc_library( diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 056d17e9a4b..7896f8d2414 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" @@ -363,13 +364,21 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { status->status = tensorflow::errors::InvalidArgument("Invalid handle"); return nullptr; } - tensorflow::TensorHandle* handle = - tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h)); - if (VariantDeviceIsCustom(handle->device())) { - const tensorflow::Tensor* t; - status->status = handle->Tensor(&t); - return t->data(); + tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle = + tensorflow::unwrap(h); + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. + if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) { + return tensorflow::down_cast( + unwrapped_handle) + ->DevicePointer(); } + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. + if (!tensorflow::TensorHandle::classof(unwrapped_handle)) { + status->status = tensorflow::errors::InvalidArgument("Invalid handle"); + return nullptr; + } + tensorflow::TensorHandle* handle = + tensorflow::TensorHandleFromInterface(unwrapped_handle); if (handle->Type() != tensorflow::TensorHandle::LOCAL) { status->status = tensorflow::errors::InvalidArgument( @@ -377,7 +386,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { handle->TypeString(), " tensor handle."); return nullptr; } - tensorflow::Device* device(absl::get(handle->device())); + tensorflow::Device* device(handle->device()); if (device != nullptr) { status->status = device->Sync(); if (!status->status.ok()) { @@ -393,6 +402,125 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) { static_cast(tensor->tensor_data().data())); } +namespace tensorflow { +namespace { +class CustomDeviceAPI : public tensorflow::CustomDevice { + public: + CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info, + string name) + : context_(context), device_(device), info_(info), name_(name) {} + + ~CustomDeviceAPI() override { device_.delete_device(info_); } + + const string& name() override { return name_; } + + tensorflow::Status CopyTensorToDevice( + ImmediateExecutionTensorHandle* handle, + ImmediateExecutionTensorHandle** result) override { + handle->Ref(); + TF_Status status; + TFE_TensorHandle* result_handle = device_.copy_tensor_to_device( + context_, tensorflow::wrap(handle), &status, info_); + handle->Release(); + if (!status.status.ok()) return status.status; + *result = tensorflow::unwrap(result_handle); + (*result)->Ref(); + TFE_DeleteTensorHandle(result_handle); + return status.status; + } + + tensorflow::Status CopyTensorFromDevice( + ImmediateExecutionTensorHandle* handle, + const tensorflow::string& target_device_name, + ImmediateExecutionTensorHandle** result) override { + TF_Status status; + handle->Ref(); + TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( + context_, tensorflow::wrap(handle), target_device_name.c_str(), &status, + info_); + handle->Release(); + if (!status.status.ok()) return status.status; + *result = tensorflow::unwrap(result_handle); + (*result)->Ref(); + TFE_DeleteTensorHandle(result_handle); + return status.status; + } + + tensorflow::Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals) override { + std::vector outputs(*num_retvals); + TF_Status status; + device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, + info_); + if (status.status.ok()) { + for (int i = 0; i < *num_retvals; ++i) { + retvals[i] = tensorflow::unwrap(outputs[i]); + retvals[i]->Ref(); + TFE_DeleteTensorHandle(outputs[i]); + } + } + return status.status; + } + + tensorflow::Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) override { + TF_Status status; + *result = tensorflow::unwrap(device_.pack(context_, + tensorflow::wrap(handles.data()), + handles.size(), &status, info_)); + return status.status; + } + + private: + TFE_Context* context_; + TFE_CustomDevice device_; + void* info_; + string name_; +}; + +// An adapter which wraps the shape/data produced by C custom devices and uses +// it to implement custom device methods. +class CAPICustomDeviceTensorHandle + : public tensorflow::CustomDeviceTensorHandle { + public: + CAPICustomDeviceTensorHandle(tensorflow::ImmediateExecutionContext* context, + tensorflow::CustomDevice* device, + tensorflow::DataType dtype, void* data, + size_t len, std::vector shape, + void (*deallocator)(void* data, size_t len, + void* arg), + void* deallocator_arg) + : tensorflow::CustomDeviceTensorHandle(context, device, dtype), + data_(data), + len_(len), + shape_(shape), + deallocator_(deallocator), + deallocator_arg_(deallocator_arg) {} + ~CAPICustomDeviceTensorHandle() override { + deallocator_(data_, len_, deallocator_arg_); + } + void* DevicePointer() const override { return data_; } + Status NumDims(int* num_dims) const override { + *num_dims = shape_.size(); + return Status::OK(); + } + Status Dim(int dim_index, int64* dim) const override { + *dim = shape_[dim_index]; + return Status::OK(); + } + + private: + void* const data_; + size_t len_; + std::vector shape_; + void (*const deallocator_)(void* data, size_t len, void* arg); + void* const deallocator_arg_; +}; + +} // namespace +} // namespace tensorflow + TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( TFE_Context* ctx, const char* device_name, TF_DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, @@ -417,6 +545,12 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( for (int i = 0; i < num_dims; ++i) { dimvec[i] = static_cast(dims[i]); } + if (custom_device != nullptr) { + return tensorflow::wrap(new tensorflow::CAPICustomDeviceTensorHandle( + context, custom_device, + *reinterpret_cast(&dtype), data, len, dimvec, + deallocator, deallocator_arg)); + } // TODO(apassos) do we need to wrap the deallocator here to make sure to sync // the device? @@ -427,13 +561,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( tensorflow::Tensor t(static_cast(dtype), tensorflow::TensorShape(dimvec), buf); buf->Unref(); - if (custom_device == nullptr) { - return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), device, device, context)); - } else { - return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( - std::move(t), custom_device, context)); - } + return tensorflow::wrap(tensorflow::TensorHandle::CreateLocalHandle( + std::move(t), device, device, context)); } // This function will block till the operation that produces `h` has @@ -961,74 +1090,14 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op, } // namespace tensorflow namespace { -class CustomDeviceAPI : public tensorflow::CustomDevice { - public: - CustomDeviceAPI(TFE_Context* context, TFE_CustomDevice device, void* info, - string name) - : context_(context), device_(device), info_(info), name_(name) {} - - ~CustomDeviceAPI() override { device_.delete_device(info_); } - - const string& name() override { return name_; } - - tensorflow::Status CopyTensorToDevice( - tensorflow::TensorHandle* handle, - tensorflow::TensorHandle** result) override { - handle->Ref(); - TF_Status status; - TFE_TensorHandle* result_handle = device_.copy_tensor_to_device( - context_, tensorflow::wrap(handle), &status, info_); - handle->Release(); - if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface( - tensorflow::unwrap(result_handle)); - (*result)->Ref(); - TFE_DeleteTensorHandle(result_handle); - return status.status; - } - - tensorflow::Status CopyTensorFromDevice( - tensorflow::TensorHandle* handle, - const tensorflow::string& target_device_name, - tensorflow::TensorHandle** result) override { - TF_Status status; - handle->Ref(); - TFE_TensorHandle* result_handle = device_.copy_tensor_from_device( - context_, tensorflow::wrap(handle), target_device_name.c_str(), &status, - info_); - handle->Release(); - if (!status.status.ok()) return status.status; - *result = tensorflow::TensorHandleFromInterface( - tensorflow::unwrap(result_handle)); - (*result)->Ref(); - TFE_DeleteTensorHandle(result_handle); - return status.status; - } - - tensorflow::Status Execute(const tensorflow::EagerOperation* op, - tensorflow::TensorHandle** retvals, - int* num_retvals) override { - std::vector outputs(*num_retvals); - TF_Status status; - device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, - info_); - if (status.status.ok()) { - for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = tensorflow::TensorHandleFromInterface( - tensorflow::unwrap(outputs[i])); - retvals[i]->Ref(); - TFE_DeleteTensorHandle(outputs[i]); - } - } - return status.status; - } - - private: - TFE_Context* context_; - TFE_CustomDevice device_; - void* info_; - string name_; -}; +TFE_TensorHandle* DefaultCustomDevicePack(TFE_Context* context, + TFE_TensorHandle** handles, + int num_handles, TF_Status* status, + void* device_info) { + TF_SetStatus(status, TF_UNIMPLEMENTED, + "This custom device does not support packing tensors."); + return nullptr; +} } // namespace extern "C" { @@ -1036,8 +1105,12 @@ extern "C" { void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, const char* device_name, void* device_info, TF_Status* status) { - auto custom_device = - std::make_unique(ctx, device, device_info, device_name); + // Fill in default values for optional functionality. + if (device.pack == nullptr) { + device.pack = &DefaultCustomDevicePack; + } + auto custom_device = std::make_unique( + ctx, device, device_info, device_name); tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); status->status = diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 90e9cdc162d..e8cdc61ba93 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -613,8 +613,23 @@ TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, std::vector tensor_handles; tensor_handles.reserve(*num_handles); for (int i = 0; i < *num_handles; ++i) { + tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle = + tensorflow::unwrap(handles[i]); + if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) { + // One of the inputs we're trying to pack is on a custom device. We'll let + // the first custom device we see handle all of the packing. + auto* custom_device_handle = + tensorflow::down_cast( + unwrapped_handle); + tensorflow::ImmediateExecutionTensorHandle* result; + status->status = custom_device_handle->device()->Pack( + absl::Span( + tensorflow::unwrap(handles), *num_handles), + &result); + return tensorflow::wrap(result); + } tensor_handles.push_back( - tensorflow::TensorHandleFromInterface(tensorflow::unwrap(handles[i]))); + tensorflow::TensorHandleFromInterface(unwrapped_handle)); } tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 30044244acf..b2850b0a076 100644 --- a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -435,16 +435,16 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op, // to have a non-string representation of devices (TF_Device) extracted from // tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc. -#define TFE_CUSTOM_DEVICE_VERSION 3 +#define TFE_CUSTOM_DEVICE_VERSION 4 -// Struct to be filled in +// Struct to be filled in. Functions are required except where indicated. typedef struct TFE_CustomDevice { int version = TFE_CUSTOM_DEVICE_VERSION; // Method to copy a tensor to the custom device. TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status, - void* device_info) = nullptr; + void* device_info); // Method to copy a tensor from the custom device to a target device. TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context, @@ -468,6 +468,16 @@ typedef struct TFE_CustomDevice { // Method to delete a device. void (*delete_device)(void* device_info); + + // Implements TFE_CreatePackedTensorHandle when one of `handles` is on this + // custom device. + // + // Many devices will want to simply return an "unimplemented" status + // here. This is the default behavior if `pack` is null when passed to + // TFE_RegisterCustomDevice. + TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles, + int num_handles, TF_Status* s, + void* device_info) = nullptr; } TFE_CustomDevice; // Registers a custom device for use with eager execution. diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3037669ac9d..c08ed5b6cc5 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -424,7 +424,7 @@ void TensorHandleSilentCopy(bool async, tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hcpu)); auto gpu_arg = tensorflow::TensorHandleFromInterface(tensorflow::unwrap(hgpu)); - auto gpu_device = absl::get(gpu_arg->device()); + auto gpu_device = gpu_arg->device(); ASSERT_FALSE(cpu_arg->HasLocalMirror(gpu_device)); TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); diff --git a/tensorflow/core/common_runtime/composite_device.cc b/tensorflow/core/common_runtime/composite_device.cc index d4548946cbf..afde64faf17 100644 --- a/tensorflow/core/common_runtime/composite_device.cc +++ b/tensorflow/core/common_runtime/composite_device.cc @@ -40,14 +40,6 @@ std::unique_ptr CompositeDevice::MakeDevice( errors::InvalidArgument("underlying_devices should not be empty.")); return nullptr; } - std::set unique_devices; - for (const string& device : underlying_devices) { - if (!unique_devices.insert(device).second) { - status->Update(errors::InvalidArgument( - "Got a duplicated device in underlying_devices: ", device)); - return nullptr; - } - } DeviceNameUtils::ParsedName parsed_name; if (!DeviceNameUtils::ParseFullName(underlying_devices.at(0), &parsed_name)) { status->Update(tensorflow::errors::InvalidArgument( diff --git a/tensorflow/core/common_runtime/composite_device_test.cc b/tensorflow/core/common_runtime/composite_device_test.cc index 7d195a7a08e..219fcba954b 100644 --- a/tensorflow/core/common_runtime/composite_device_test.cc +++ b/tensorflow/core/common_runtime/composite_device_test.cc @@ -50,21 +50,6 @@ TEST(CompositeDeviceTest, Basic) { EXPECT_EQ(underlying_devices, *composite_device->underlying_devices()); } - { - Status status; - underlying_devices.push_back( - "/job:localhost/replica:0/task:0/device:CPU:0"); - std::unique_ptr composite_device = - CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/1, - parsed_host_name, &status); - EXPECT_EQ(composite_device, nullptr); - EXPECT_EQ(error::INVALID_ARGUMENT, status.code()); - EXPECT_TRUE( - absl::StrContains(status.error_message(), "Got a duplicated device")) - << status.ToString(); - underlying_devices.pop_back(); - } - { Status status; underlying_devices.push_back( diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index b9c2d3caaf2..164c9d87d57 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -124,6 +124,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/core/lib/core:status", ], }), @@ -228,6 +229,7 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":attr_builder", + ":custom_device", ":context", ":eager_executor", ":kernel_and_device", @@ -631,6 +633,7 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":context", + ":custom_device", ":attr_builder", ":eager_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index d80952a05a1..81b1e3594f2 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -24,11 +24,7 @@ limitations under the License. namespace { -bool IsCPU(tensorflow::VariantDevice variant) { - if (VariantDeviceIsCustom(variant)) { - return false; - } - tensorflow::Device* d = absl::get(variant); +bool IsCPU(tensorflow::Device* d) { return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; } @@ -43,20 +39,6 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) { if (!status->ok()) { return nullptr; } - if (VariantDeviceIsCustom(device())) { - auto* custom_device = absl::get(device()); - TensorHandle* copy; - *status = custom_device->CopyTensorFromDevice(this, ctx_->HostCPU()->name(), - ©); - if (status->ok()) { - auto result = copy->Resolve(status); - copy->Unref(); - return result; - } else { - return nullptr; - } - } - if (Type() == REMOTE) { const tensorflow::Tensor* t = nullptr; TensorHandle* h_cpu = nullptr; @@ -124,14 +106,13 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) { ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( ImmediateExecutionTensorHandle* handle, const char* device_name, Status* status) { - TensorHandle* input = TensorHandleFromInterface(handle); - TensorHandle* result = nullptr; + ImmediateExecutionTensorHandle* result = nullptr; Device* device; *status = this->FindDeviceFromName(device_name, &device); if (!status->ok()) { tensorflow::CustomDevice* dev; if (this->FindCustomDeviceFromName(device_name, &dev)) { - *status = dev->CopyTensorToDevice(input, &result); + *status = dev->CopyTensorToDevice(handle, &result); if (status->ok()) { return result; } @@ -142,13 +123,13 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( return nullptr; } // Handle tensor handles currently in custom devices - const char* handle_device_name = input->DeviceName(status); + const char* handle_device_name = handle->DeviceName(status); if (!status->ok()) { return nullptr; } tensorflow::CustomDevice* dev; if (this->FindCustomDeviceFromName(handle_device_name, &dev)) { - *status = dev->CopyTensorFromDevice(input, device_name, &result); + *status = dev->CopyTensorFromDevice(handle, device_name, &result); if (status->ok()) { return result; } @@ -156,8 +137,10 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( } // Handle regular case. + TensorHandle* input = TensorHandleFromInterface(handle); *status = - EagerCopyToDevice(input, this, &this->Executor(), device, false, &result); + EagerCopyToDevice(input, this, &this->Executor(), device, false, + reinterpret_cast(&result)); if (status->ok()) { return result; } @@ -213,16 +196,38 @@ Status EagerContext::RegisterFunction(AbstractFunction* f) { // eager_operation.cc we can avoid a circular dependency between them. Status EagerOperation::Execute(absl::Span retvals, int* num_retvals) { - for (int i = 0; i < Inputs().size(); ++i) { - TF_RETURN_IF_ERROR(Inputs()[i]->WaitUnknownDevice()); + for (ImmediateExecutionTensorHandle* handle : inputs_) { + if (TensorHandle::classof(handle)) { + TF_RETURN_IF_ERROR(down_cast(handle)->WaitUnknownDevice()); + } } + + // Decide to either run the operation on a custom device or copy off all of + // the custom device inputs. + VariantDevice maybe_custom_device = Device(); + if (absl::holds_alternative(maybe_custom_device) || + !inputs_are_tensor_handles_) { + // If the op wasn't placed on a custom device explicitly and there are no + // non-TensorHandle inputs, the op will definitely be placed on a physical + // device. Otherwise we need to check the inputs one by one. + TF_RETURN_IF_ERROR( + eager::MaybePinToCustomDevice(&maybe_custom_device, *this)); + if (absl::holds_alternative(maybe_custom_device)) { + ImmediateExecutionTensorHandle** retval_array = + reinterpret_cast(retvals.data()); + return absl::get(maybe_custom_device) + ->Execute(this, retval_array, num_retvals); + } else { + TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs()); + } + } + // Run eager placement logic. - VariantDevice device; - TF_RETURN_IF_ERROR(eager::MaybePinToCustomDevice(&device, *this)); - if (device == kVariantDeviceNull) { + class Device* device = absl::get(maybe_custom_device); + if (device == nullptr) { TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this)); } - if (device == kVariantDeviceNull && ctx_.PinSmallOpsToCPU()) { + if (device == nullptr && ctx_.PinSmallOpsToCPU()) { bool pin_to_cpu; TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu( &pin_to_cpu, Name(), GetInputs(), ctx_.HostCPU()->name())); @@ -231,16 +236,13 @@ Status EagerOperation::Execute(absl::Span retvals, } } - tensorflow::TensorHandle** retval_array = - reinterpret_cast(retvals.data()); - if (VariantDeviceIsCustom(device)) { - return absl::get(device)->Execute(this, retval_array, - num_retvals); - } - - if (device != kVariantDeviceNull) { + if (device != nullptr) { SetDevice(device); } + // At this point all inputs and outputs are TensorHandles associated with + // physical devices. + tensorflow::TensorHandle** retval_array = + reinterpret_cast(retvals.data()); return EagerExecute(this, retval_array, num_retvals); } diff --git a/tensorflow/core/common_runtime/eager/custom_device.h b/tensorflow/core/common_runtime/eager/custom_device.h index e3168b6265b..337a78b28b7 100644 --- a/tensorflow/core/common_runtime/eager/custom_device.h +++ b/tensorflow/core/common_runtime/eager/custom_device.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/device_name_utils.h" @@ -26,6 +27,7 @@ namespace tensorflow { class TensorHandle; class EagerOperation; +class CustomDeviceTensorHandle; // Custom devices intercept the execution of operations (the `Execute` method), // typically implemented with one or more of the custom device's own executions. @@ -33,15 +35,22 @@ class CustomDevice { public: virtual ~CustomDevice() {} virtual const string& name() = 0; - virtual Status CopyTensorToDevice(TensorHandle* tensor, - TensorHandle** result) = 0; + virtual Status CopyTensorToDevice( + ImmediateExecutionTensorHandle* tensor, + ImmediateExecutionTensorHandle** result) = 0; - virtual Status CopyTensorFromDevice(TensorHandle* tensor, - const string& target_device_name, - TensorHandle** result) = 0; + virtual Status CopyTensorFromDevice( + ImmediateExecutionTensorHandle* tensor, const string& target_device_name, + ImmediateExecutionTensorHandle** result) = 0; - virtual Status Execute(const EagerOperation* op, TensorHandle** retvals, + virtual Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, int* num_retvals) = 0; + + // Creates a packed TensorHandle from a group of custom device TensorHandles, + // one of which is on this custom device. + virtual Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) = 0; }; // Custom devices do many of the same things as physical Devices, but have a @@ -49,6 +58,10 @@ class CustomDevice { // operations may be placed either on custom or physical devices. using VariantDevice = absl::variant; +// Indicates either HostCPU or an unset physical device. We never set a null +// CustomDevice*. +const VariantDevice kVariantDeviceNull = static_cast(nullptr); + // A tensor handle produced by a custom device. Generally they can only be // consumed by executing an operation on the same custom device that produced it // originally, or by attempting to copy the handle off the custom device. @@ -65,6 +78,10 @@ class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle { device_(device), dtype_(dtype) {} + // TODO(allenl): Should this be a generic method of + // ImmediateExecutionTensorHandle to support TFE_TensorHandleDevicePointer? + virtual void* DevicePointer() const = 0; + tensorflow::DataType DataType() const override { return dtype_; } Status Shape(PartialTensorShape* shape) const override; Status NumElements(int64* num_elements) const override; diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc index 9f512ea0828..772a21f0276 100644 --- a/tensorflow/core/common_runtime/eager/custom_device_test.cc +++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc @@ -28,24 +28,31 @@ class TestCustomDevice : public CustomDevice { public: explicit TestCustomDevice(std::string name) : name_(name) {} const std::string& name() override { return name_; } - Status CopyTensorToDevice(TensorHandle* tensor, - TensorHandle** result) override { + Status CopyTensorToDevice(ImmediateExecutionTensorHandle* tensor, + ImmediateExecutionTensorHandle** result) override { tensor->Ref(); *result = tensor; return Status::OK(); } - Status CopyTensorFromDevice(TensorHandle* tensor, - const std::string& target_device_name, - TensorHandle** result) override { + Status CopyTensorFromDevice( + ImmediateExecutionTensorHandle* tensor, + const std::string& target_device_name, + ImmediateExecutionTensorHandle** result) override { tensor->Ref(); *result = tensor; return Status::OK(); } - Status Execute(const EagerOperation* op, TensorHandle** retvals, + Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, int* num_retvals) override { return errors::Unimplemented("Not implemented"); } + Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) override { + return errors::Unimplemented("Packing is not implemented"); + } + private: std::string name_; }; @@ -57,6 +64,7 @@ class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle { tensorflow::DataType dtype) : CustomDeviceTensorHandle(context, device, dtype) {} + void* DevicePointer() const override { return nullptr; } Status NumDims(int* num_dims) const override { *num_dims = 1; return Status::OK(); diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 41ab54a91e9..883e9a8a8b0 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" @@ -31,10 +32,11 @@ namespace tensorflow { // Clear(), and then Reset(...) with the same arguments that would have // been provided to the constructor. void EagerOperation::Clear() { - for (TensorHandle* h : inputs_) { + for (ImmediateExecutionTensorHandle* h : inputs_) { h->Unref(); } inputs_.clear(); + inputs_are_tensor_handles_ = true; ClearInferenceState(); } @@ -263,7 +265,12 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) { } Status EagerOperation::AddInput(AbstractTensorHandle* input) { - TensorHandle* h = TensorHandleFromInterface(input); + ImmediateExecutionTensorHandle* h = + down_cast(input); + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. + if (CustomDeviceTensorHandle::classof(h)) { + inputs_are_tensor_handles_ = false; + } AddTensorHandle(h); return MaybeInferSingleInputAttrs(h); } @@ -271,7 +278,13 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) { Status EagerOperation::AddInputList( absl::Span inputs) { for (auto& input : inputs) { - TensorHandle* h = TensorHandleFromInterface(input); + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa + // here. + if (CustomDeviceTensorHandle::classof(input)) { + inputs_are_tensor_handles_ = false; + } + ImmediateExecutionTensorHandle* h = + down_cast(input); AddTensorHandle(h); } return InferInputListAttrs(inputs.size()); @@ -317,7 +330,8 @@ Status EagerOperation::Reset( return SetDeviceName(device_name); } -Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) { +Status EagerOperation::MaybeInferSingleInputAttrs( + ImmediateExecutionTensorHandle* handle) { if (!op_def_) return Status::OK(); const auto& input_def = op_def_->input_arg(inference_arg_idx_++); @@ -334,7 +348,7 @@ Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) { const std::string& type_attr = input_def.type_attr(); if (!type_attr.empty() && inference_attrs_.find(type_attr) == inference_attrs_.end()) { - MutableAttrs()->Set(type_attr, handle->dtype); + MutableAttrs()->Set(type_attr, handle->DataType()); inference_attrs_.insert(type_attr); } return Status::OK(); @@ -372,12 +386,13 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { if (!input_def.type_list_attr().empty()) { std::vector dtypes(num_inputs); for (int i = 0; i < num_inputs; ++i) { - dtypes[i] = inputs_[start + i]->dtype; + dtypes[i] = inputs_[start + i]->DataType(); } InferMixedTypeInputListAttrs(input_def, dtypes); } else if (!input_def.type_attr().empty() && !input_def.number_attr().empty()) { - InferSingleTypeInputListAttrs(input_def, inputs_[start]->dtype, num_inputs); + InferSingleTypeInputListAttrs(input_def, inputs_[start]->DataType(), + num_inputs); } else if (!input_def.number_attr().empty()) { if (inference_attrs_.find(input_def.number_attr()) == inference_attrs_.end()) { @@ -390,6 +405,28 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { return Status::OK(); } +Status EagerOperation::TensorHandleInputs( + const absl::InlinedVector** inputs) const { + if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { + *inputs = reinterpret_cast*>( + &inputs_); + return Status::OK(); + } else { + return errors::Internal("The operation unexpectedly had custom devices."); + } +} + +Status EagerOperation::MutableTensorHandleInputs( + absl::InlinedVector** inputs) { + if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { + *inputs = + reinterpret_cast*>(&inputs_); + return Status::OK(); + } else { + return errors::Internal("The operation unexpectedly had custom devices."); + } +} + Status EagerOperation::SetDeviceName(const char* c_name) { string name(c_name != nullptr ? c_name : ""); if (name != last_set_device_name_) { @@ -423,6 +460,16 @@ bool EagerOperation::IsLocal() const { device_parsed_name_.task == host_cpu_name.task; } +string VariantDeviceDebugString(VariantDevice device) { + if (device == kVariantDeviceNull) { + return "[]"; + } else if (absl::holds_alternative(device)) { + return absl::get(device)->name(); + } else { + return absl::get(device)->DebugString(); + } +} + string EagerOperation::DebugString() const { string out; VLOG(1) << "EagerOperation::DebugString() over " << this; @@ -442,10 +489,36 @@ string EagerOperation::DebugString() const { return out; } -void EagerOperation::AddTensorHandle(TensorHandle* h) { +void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) { h->Ref(); inputs_.push_back(h); attrs_.NumInputs(static_cast(inputs_.size())); } +Status EagerOperation::CopyOffCustomDeviceInputs() { + if (absl::holds_alternative(device_)) { + return errors::Internal( + "Trying to copy inputs to a custom device op off a custom device."); + } + for (int i = 0; i < inputs_.size(); ++i) { + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa + // here. + if (CustomDeviceTensorHandle::classof(inputs_[i])) { + CustomDeviceTensorHandle* previous = + down_cast(inputs_[i]); + class Device* target_device; + if (device_ == kVariantDeviceNull) { + target_device = ctx_.HostCPU(); + } else { + target_device = absl::get(device_); + } + TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice( + previous, target_device->name(), &inputs_[i])); + previous->Unref(); + } + } + inputs_are_tensor_handles_ = true; + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 5463158ae61..e440a4a79dd 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -39,7 +39,7 @@ class EagerOperation : public ImmediateExecutionOperation { explicit EagerOperation(tensorflow::EagerContext* ctx) : ImmediateExecutionOperation(kEager), ctx_(*ctx) {} ~EagerOperation() override { - for (TensorHandle* h : inputs_) { + for (ImmediateExecutionTensorHandle* h : inputs_) { h->Unref(); } } @@ -69,8 +69,9 @@ class EagerOperation : public ImmediateExecutionOperation { void SetDevice(VariantDevice device) { device_ = device; - device_name_ = - device == kVariantDeviceNull ? "" : VariantDeviceName(device); + device_name_ = absl::visit( + [](auto* device) { return device == nullptr ? "" : device->name(); }, + device); DeviceNameUtils::ParseFullName(device_name_, &device_parsed_name_); // TODO(b/154133594): Due to intricacies of external logic, we can not // set this do device_name_ as it would be natural, because we need the @@ -141,10 +142,18 @@ class EagerOperation : public ImmediateExecutionOperation { AttrBuilder* MutableAttrs() { return &attrs_; } const AttrBuilder& Attrs() const { return attrs_; } - const absl::InlinedVector& Inputs() const { + // TensorHandleInputs and MutableTensorHandleInputs first check that all + // inputs are TensorHandles, i.e. that there are no custom device inputs. They + // return a bad status otherwise. + Status TensorHandleInputs( + const absl::InlinedVector** inputs) const; + Status MutableTensorHandleInputs( + absl::InlinedVector** inputs); + + const absl::InlinedVector& Inputs() + const { return inputs_; } - absl::InlinedVector* MutableInputs() { return &inputs_; } void UpdateInput(int i, TensorHandle* h); @@ -180,7 +189,7 @@ class EagerOperation : public ImmediateExecutionOperation { } private: - void AddTensorHandle(TensorHandle* h); + void AddTensorHandle(ImmediateExecutionTensorHandle* h); const tensorflow::OpDef* GetOpDef(Status* status); @@ -190,7 +199,7 @@ class EagerOperation : public ImmediateExecutionOperation { inference_attrs_.clear_no_resize(); } - Status MaybeInferSingleInputAttrs(TensorHandle* handle); + Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle); Status InferInputListAttrs(int num_inputs); void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, @@ -198,11 +207,21 @@ class EagerOperation : public ImmediateExecutionOperation { void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, const std::vector& dtypes); + // Replaces input tensors placed on custom devices with physical device + // equivalents. Used if an op is placed on a physical device but may have + // custom device inputs. + Status CopyOffCustomDeviceInputs(); + tensorflow::EagerContext& ctx_; const char* op_name_ = nullptr; AttrBuilder attrs_; const AttrTypeMap* attr_types_; - absl::InlinedVector inputs_; + + // Toggled to indicate whether all inputs are known to be TensorHandles and + // not another type (e.g. custom device tensor handles). Explicitly set to + // false when custom device TensorHandles are added. + bool inputs_are_tensor_handles_ = true; + absl::InlinedVector inputs_; // The last device name given to SetDeviceName. // This is used to avoid having to re-process the same device in repeated @@ -240,8 +259,8 @@ class EagerOperation : public ImmediateExecutionOperation { }; inline void EagerOperation::UpdateInput(int i, TensorHandle* h) { - TensorHandle** slot = &inputs_[i]; - TensorHandle* existing = *slot; + ImmediateExecutionTensorHandle** slot = &inputs_[i]; + ImmediateExecutionTensorHandle* existing = *slot; if (existing != h) { h->Ref(); existing->Unref(); diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 7163944c3e2..248f1aadc13 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -81,14 +81,6 @@ const string& DeviceNameOrUnspecified(Device* device) { return (device == nullptr) ? *unspecified_string : device->name(); } -const string& DeviceNameOrUnspecified(VariantDevice device) { - if (VariantDeviceIsCustom(device)) { - return absl::get(device)->name(); - } else { - return DeviceNameOrUnspecified(absl::get(device)); - } -} - // Returns whether a kernel should be cached. bool KernelCacheEnabled(const OpDef& op_def) { if (data::DatasetOpKernel::IsDatasetOp(&op_def)) { @@ -200,9 +192,10 @@ Status ValidateInputTypeAndPlacement( const bool is_function = kernel->IsFunction(); if (n_inputs > 0) { const DataType* input_types = &kernel->input_dtypes()[0]; - TensorHandle* const* handles = &op->Inputs()[0]; + const absl::InlinedVector* handles; + TF_RETURN_IF_ERROR(op->TensorHandleInputs(&handles)); for (int i = 0; i < n_inputs; ++i) { - TensorHandle* handle = handles[i]; + TensorHandle* handle = (*handles)[i]; Device* expected_device = kernel->InputDevice(i); if (!kernel->IsFunction() && handle->Type() == TensorHandle::PACKED) { // Extract a handle on the op device from a packed input. @@ -220,13 +213,7 @@ Status ValidateInputTypeAndPlacement( } } } - auto handle_device_variant = handle->DeviceOrHostCPU(*ctx); - if (VariantDeviceIsCustom(handle_device_variant)) { - return errors::Unimplemented( - "Custom devices and remote execution are not yet supported " - "together."); - } - Device* handle_device = absl::get(handle_device_variant); + Device* handle_device = handle->DeviceOrHostCPU(*ctx); const bool maybe_copy = !is_function || handle->Type() != TensorHandle::REMOTE; // If the input is already on the right device, then nothing to do. @@ -280,14 +267,10 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a, Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle, Device** result) { - if (TF_PREDICT_FALSE(VariantDeviceIsCustom(tensor_handle->device()))) { - return errors::Unimplemented( - "The kernel cache does not work with custom devices."); - } Device* cpu_device = ctx.HostCPU(); string device_name; if (tensor_handle->Type() != TensorHandle::LOCAL) { - Device* device = absl::get(tensor_handle->device()); + Device* device = tensor_handle->device(); device_name = device != nullptr ? device->name() : cpu_device->name(); *result = (device == nullptr ? cpu_device : device); } else if (tensor_handle->dtype == DT_RESOURCE) { @@ -304,7 +287,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle, ctx.FindDeviceFromName(device_name.c_str(), &input_device)); *result = input_device; } else { - Device* device = absl::get(tensor_handle->device()); + Device* device = tensor_handle->device(); const bool is_tpu = device != nullptr && device->device_type() == "TPU"; // int32 return values can be placed on TPUs. const bool use_host_memory = @@ -431,8 +414,10 @@ Status GetOrCreateKernelAndDevice( profiler::TraceMe activity("EagerCopyToDeviceAndAddCacheKey", profiler::TraceMeLevel::kInfo); input_dev_ptrs.reserve(op->Inputs().size()); - for (int i = 0, end = op->Inputs().size(); i < end; i++) { - TensorHandle* input = op->Inputs()[i]; + const absl::InlinedVector* inputs; + TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs)); + for (int i = 0, end = inputs->size(); i < end; i++) { + TensorHandle* input = (*inputs)[i]; // Get device for this input, and add it to 'cache_key'. Device* input_device; @@ -477,7 +462,7 @@ Status GetOrCreateKernelAndDevice( core::RefCountPtr kernel = ctx.GetCachedKernel(cache_key); if (kernel == nullptr) { DVLOG(2) << "Creating new kernel for " << op->Name() << " on device " - << DeviceNameOrUnspecified(op->Device()); + << DeviceNameOrUnspecified(absl::get(op->Device())); bool run_function_with_flr = false; bool function_outputs_on_op_device = false; if (op->is_function()) { @@ -656,9 +641,11 @@ Status AddOrExecuteNode(core::RefCountPtr kernel, remote_func_params, &ctx, &retvals[i])); } } + const absl::InlinedVector* inputs; + TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs)); auto node = absl::make_unique( - &ctx, op->Inputs(), remote_func_params, std::move(kernel), - graph_collector, op->GetCancellationManager(), + &ctx, *inputs, remote_func_params, std::move(kernel), graph_collector, + op->GetCancellationManager(), absl::Span(retvals, num_outputs), op->GetStackTrace()); // Release the inputs from the eager operation since the AsyncExecuteNode // would have taken ownership. This allows the inputs to be forwarded if @@ -673,8 +660,10 @@ Status AddOrExecuteNode(core::RefCountPtr kernel, for (int i = 0, end = num_outputs; i < end; ++i) { retvals[i] = nullptr; } - ExecuteNode node(&ctx, op->Inputs(), remote_func_params, kernel, - graph_collector, op->GetCancellationManager(), + const absl::InlinedVector* inputs; + TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs)); + ExecuteNode node(&ctx, *inputs, remote_func_params, kernel, graph_collector, + op->GetCancellationManager(), {retvals, static_cast(num_outputs)}); Status s = executor.SyncExecute(&node); // We release the inputs AFTER executing the operation in sync mode since @@ -764,8 +753,10 @@ Status MaybePackInputTensor(EagerOperation* op) { return Status::OK(); } EagerContext& ctx = op->EagerContext(); - for (int i = 0; i < op->Inputs().size(); ++i) { - TensorHandle* handle = op->Inputs()[i]; + const absl::InlinedVector* inputs; + TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs)); + for (int i = 0; i < inputs->size(); ++i) { + TensorHandle* handle = (*inputs)[i]; if (handle->Type() == TensorHandle::PACKED) { EagerOperation pack_op(&ctx); TF_RETURN_IF_ERROR(pack_op.Reset("Pack", /*device_name=*/nullptr, @@ -842,7 +833,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, if (!DeviceNameUtils::GetTaskName(op->GetDeviceParsedName(), &remote_task)) { return errors::InvalidArgument( "Unable to find remote task corresponding to device ", - VariantDeviceName(op->Device())); + op->DeviceName()); } std::unique_ptr request(new eager::EnqueueRequest); @@ -855,11 +846,12 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, profiler::TraceMe activity("CopyInputToExpectedDevice", profiler::TraceMeLevel::kInfo); const bool is_function = op->is_function(); - for (int i = 0, end = op->Inputs().size(); i < end; i++) { - tensorflow::TensorHandle* input = op->Inputs()[i]; - tensorflow::Device* input_device = absl::get(input->device()); - tensorflow::Device* input_device_or_cpu = - absl::get(input->DeviceOrHostCPU(ctx)); + const absl::InlinedVector* inputs; + TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs)); + for (int i = 0, end = inputs->size(); i < end; i++) { + tensorflow::TensorHandle* input = (*inputs)[i]; + tensorflow::Device* input_device = input->device(); + tensorflow::Device* input_device_or_cpu = input->DeviceOrHostCPU(ctx); const string* input_device_name = &input_device_or_cpu->name(); bool serialize_resource_dtype_and_shape = false; if (op_device != input_device && @@ -876,9 +868,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, // Always copy to the remote CPU so that the actual device can be // correctly determined after the kernel is selected/instantiated, // since the op might have its inputs on host memory. - TensorHandle* handle = op->Inputs()[i]; - Device* handle_device = - absl::get(handle->DeviceOrHostCPU(ctx)); + TensorHandle* handle = input; + Device* handle_device = handle->DeviceOrHostCPU(ctx); // If the input is already on the right device, then nothing to do. if (remote_cpu_device != handle_device) { TF_RETURN_IF_ERROR(CopyInputToExpectedDevice( @@ -959,11 +950,14 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, DVLOG(4) << "Execute remote eager op: " << op->Name() << " (is async?: " << executor.Async() << ")."; + const absl::InlinedVector* inputs; + TF_RETURN_IF_ERROR(op->TensorHandleInputs(&inputs)); + std::unique_ptr node(new eager::RemoteExecuteNode( &op->EagerContext(), std::move(request), op_device, ctx.GetContextViewId(), eager_client.get(), op->GetCancellationManager(), op->MutableAttrs()->BuildNodeDef(), op->EagerContext().FuncLibDef(), - op->Inputs(), {retvals, num_outputs})); + *inputs, {retvals, num_outputs})); if (op->EagerContext().LogDevicePlacement() || VLOG_IS_ON(1)) { string msg = strings::StrCat( @@ -1020,7 +1014,7 @@ Status GetKernelOutputs( "kernel. This should never happen."); } if (TF_PREDICT_FALSE(ctx->CanonicalDevice(kernel->OutputDevice(i)) != - absl::get(retvals[i]->device()))) { + retvals[i]->device())) { return errors::Internal( "Kernel output tensor handle locates on a different device than " "the specified kernel output device. This should never happen."); @@ -1037,8 +1031,8 @@ Status GetKernelOutputs( "Remote outputs are not available on mobile devices."); #else // !IS_MOBILE_PLATFORM TF_RETURN_IF_ERROR(retvals[i]->SetRemoteShape( - absl::get(ret), - absl::get(retvals[i]->device()), ctx->GetContextViewId())); + absl::get(ret), retvals[i]->device(), + ctx->GetContextViewId())); #endif // !IS_MOBILE_PLATFORM } } @@ -1218,11 +1212,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, TensorHandle** result) { TF_RETURN_IF_ERROR(h->WaitUnknownDevice()); auto send_device = h->DeviceOrHostCPU(*ctx); - if (VariantDeviceIsCustom(send_device)) { - return errors::Unimplemented( - "Copying a TensorHandle from a custom device is not supported."); - } - bool sender_is_local = absl::get(send_device)->IsLocal(); + bool sender_is_local = send_device->IsLocal(); bool receiver_is_local = device->IsLocal(); @@ -1363,11 +1353,6 @@ void EagerKernelExecuteAsync( // triggered after execution with its status. void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, int* num_retvals, StatusCallback done) { - if (VariantDeviceIsCustom(op->Device())) { - done(errors::Unimplemented( - "Custom device is not supported in EagerLocalExecuteAsync.")); - return; - } if (!op->IsLocal()) { done(errors::InvalidArgument( "Remote execution is not supported in async EagerLocalExecuteAsync")); @@ -1419,8 +1404,14 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, output_dtypes[i], &ctx); } + const absl::InlinedVector* inputs; + s = op->TensorHandleInputs(&inputs); + if (!s.ok()) { + done(s); + return; + } EagerKernelExecuteAsync( - &ctx, op->Inputs(), op->remote_func_params(), std::move(kernel), + &ctx, *inputs, op->remote_func_params(), std::move(kernel), graph_collector, op->GetCancellationManager(), retvals, num_outputs, [op, num_outputs, retvals, done = std::move(done)](const Status& s) { op->Clear(); diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index 27503cfd99d..39237181fe8 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -44,8 +44,7 @@ Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx, TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h)); // We have validated that h->device() is not a CustomDevice when // constructing a pack TensorHandle. - const Status status = - h->TensorValue(absl::get(h->device()), &packed_arg_flat[i]); + const Status status = h->TensorValue(h->device(), &packed_arg_flat[i]); if (!status.ok()) { #if !defined(IS_MOBILE_PLATFORM) if (IsRemote(ctx, input_device, h)) { @@ -107,13 +106,7 @@ Status ExecuteNodeArgs::Init( TF_RETURN_IF_ERROR( op_inputs[index.index]->ExtractPackedHandle(index.sub_index, &h)); } - VariantDevice variant_device = h->device(); - if (VariantDeviceIsCustom(variant_device)) { - return errors::Internal( - "Custom devices and remote execution are currently not supported " - "together."); - } - Device* device = absl::get(variant_device); + Device* device = h->device(); // For a multi-device function, a remote RunComponentFunction request is // not sent through StreamingEnqueueAsync. It could arrive at a remote // worker before a remote execution request which produces an input of the diff --git a/tensorflow/core/common_runtime/eager/execute_node_test.cc b/tensorflow/core/common_runtime/eager/execute_node_test.cc index 424da531816..2fa3fbd7f0e 100644 --- a/tensorflow/core/common_runtime/eager/execute_node_test.cc +++ b/tensorflow/core/common_runtime/eager/execute_node_test.cc @@ -105,7 +105,7 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) { std::vector input_devices; for (auto* h : inputs) { - input_devices.push_back(absl::get(h->DeviceOrHostCPU(*ctx))); + input_devices.push_back(h->DeviceOrHostCPU(*ctx)); } const core::RefCountPtr kernel( new TestKernelAndDeviceFunc(std::move(input_devices), device0)); diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc index 619715f1cae..b55fdf82644 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" +#include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -138,17 +139,18 @@ Status MaybePinSmallOpsToCpu( return Status::OK(); } -Status MaybePinToResourceDevice(VariantDevice* device, - const EagerOperation& op) { +Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) { if (op.colocation_exempt()) { return Status::OK(); } EagerContext& ctx = op.EagerContext(); + const absl::InlinedVector* inputs; + TF_RETURN_IF_ERROR(op.TensorHandleInputs(&inputs)); Device* op_device = op.Device() == kVariantDeviceNull ? ctx.HostCPU() : absl::get(op.Device()); - for (int i = 0; i < op.Inputs().size(); ++i) { - TensorHandle* tensor_handle = op.Inputs()[i]; + for (int i = 0; i < inputs->size(); ++i) { + TensorHandle* tensor_handle = (*inputs)[i]; if (tensor_handle->dtype == DT_RESOURCE) { if (tensor_handle->resource_remote_device_incarnation() != 0) { TF_RETURN_IF_ERROR(ValidateTensorHandleRemoteDevice( @@ -182,7 +184,7 @@ Status MaybePinToResourceDevice(VariantDevice* device, Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) { // If operation was already placed on a custom device, use it. - if (VariantDeviceIsCustom(op.Device())) { + if (absl::holds_alternative(op.Device())) { *device = op.Device(); return Status::OK(); } else if (!op.DeviceName().empty()) { @@ -194,9 +196,13 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) { // placement and there is only one custom device in the op inputs. if (!op.Inputs().empty()) { CustomDevice* first = nullptr; - for (const TensorHandle* input : op.Inputs()) { - if (VariantDeviceIsCustom(input->device())) { - CustomDevice* current = absl::get(input->device()); + for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) { + // TODO(b/175427838): It would be nice to be able to use tensorflow::isa + // here. + if (CustomDeviceTensorHandle::classof(generic_input)) { + const CustomDeviceTensorHandle* input = + down_cast(generic_input); + CustomDevice* current = input->device(); if (first == nullptr) { first = current; } else if (first != current) { @@ -207,9 +213,9 @@ Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) { op.Name(), " has one input in custom " "device ", - VariantDeviceName(first), + first->name(), " and at least one input in a different custom device ", - VariantDeviceName(current))); + current->name())); } } } diff --git a/tensorflow/core/common_runtime/eager/placement_utils.h b/tensorflow/core/common_runtime/eager/placement_utils.h index b051e13ea08..7676fe01b43 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.h +++ b/tensorflow/core/common_runtime/eager/placement_utils.h @@ -43,8 +43,7 @@ Status MaybePinSmallOpsToCpu( // If a resource touching input is specified, all resource-touching ops run in // the device the resource is, regardless of anything else that has been // specified. This is identical to the graph mode behavior. -Status MaybePinToResourceDevice(VariantDevice* device, - const EagerOperation& op); +Status MaybePinToResourceDevice(Device** device, const EagerOperation& op); // If all the inputs are on the same custom device, use that custom // device. Otherwise, it is an error to have a custom device as an input. diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index ca13190748c..297536baad0 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -54,6 +54,14 @@ int64 GetRemoteDeviceIncarnation(Device* device) { if (device == nullptr || device->IsLocal()) return 0; return device->attributes().incarnation(); } + +string SafeDeviceDebugString(Device* device) { + if (device == nullptr) { + return "[]"; + } else { + return device->DebugString(); + } +} } // namespace TensorHandle::PackedTensorHandleData::PackedTensorHandleData( @@ -231,12 +239,6 @@ TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d, } } -TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, - CustomDevice* d, - EagerContext* ctx) { - return new TensorHandle(std::move(t), d, ctx); -} - TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, Device* resource_device, EagerContext* ctx) : ImmediateExecutionTensorHandle(kEager), @@ -249,7 +251,7 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, ctx_(ctx), data_(absl::in_place_type, std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_) + << " device: " << SafeDeviceDebugString(device_) << " tensor: " << t.DeviceSafeDebugString(); } @@ -268,26 +270,10 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, t.flat()(0).dtypes_and_shapes()), data_(absl::in_place_type, std::move(t)) { DVLOG(3) << "Creating Local TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_) + << " device: " << SafeDeviceDebugString(device_) << " tensor: " << t.DeviceSafeDebugString(); } -TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, - EagerContext* ctx) - : ImmediateExecutionTensorHandle(kEager), - dtype(t.dtype()), - device_(d), - op_device_(nullptr), - resource_device_(nullptr), - resource_remote_device_incarnation_(0), - ctx_(ctx), - data_(absl::in_place_type, std::move(t)) { - // TODO(allenl): Figure out a better op_device story for custom devices, - // since always setting it to CPU=nullptr doesn't make much sense. - DVLOG(3) << "Creating Local TensorHandle: " << this - << " custom device: " << VariantDeviceDebugString(device_) - << " tensor: " << t.DeviceSafeDebugString(); -} TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device, Device* resource_device, @@ -309,7 +295,7 @@ TensorHandle::TensorHandle(Device* d, Device* op_device, ctx_(ctx), data_(absl::in_place_type) { DVLOG(3) << "Creating empty Local TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_); + << " device: " << SafeDeviceDebugString(device_); } Status TensorHandle::CreatePackedHandle(std::vector&& handles, @@ -328,13 +314,10 @@ Status TensorHandle::CreatePackedHandle(std::vector&& handles, handles.at(0)->GetResourceHandleDtypesAndShapes(&dtypes_and_shapes)); } std::vector devices; + devices.reserve(handles.size()); for (auto* handle : handles) { - if (VariantDeviceIsCustom(handle->device())) { - devices.push_back(absl::get(handle->device())->name()); - } else { - devices.push_back(handle->op_device() ? handle->op_device()->name() - : ctx->HostCPU()->name()); - } + devices.push_back(handle->op_device() ? handle->op_device()->name() + : ctx->HostCPU()->name()); } CompositeDevice* composite_device = nullptr; @@ -378,7 +361,7 @@ TensorHandle::TensorHandle(std::vector&& handles, Device* device, data_(absl::in_place_type, std::move(handles), shape) { DVLOG(3) << "Creating a packed TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_); + << " device: " << SafeDeviceDebugString(device_); } #if !defined(IS_MOBILE_PLATFORM) @@ -406,7 +389,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num, data_(absl::in_place_type, op_id, output_num, remote_task, ctx) { DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_); + << " device: " << SafeDeviceDebugString(device_); } TensorHandle* TensorHandle::CreateLazyRemoteHandle( @@ -429,7 +412,7 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num, data_(absl::in_place_type, op_id, output_num, ctx->GetContextViewId(), is_ready) { DVLOG(3) << "Creating Lazy Remote TensorHandle: " << this - << " device: " << VariantDeviceDebugString(device_); + << " device: " << SafeDeviceDebugString(device_); } #endif @@ -487,7 +470,7 @@ Status TensorHandle::TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const { DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d; - if (d == absl::get(device_)) { + if (d == device_) { if (Type() != LOCAL) { return errors::Internal("Invalid Tensor call on a ", TypeString(), " handle: ", this); @@ -511,13 +494,7 @@ Status TensorHandle::TensorFromDevice(const Device* d, Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) { DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d; - if (VariantDeviceIsCustom(device_)) { - return errors::Internal( - "TensorHandle::TensorValue not supported for custom devices yet. " - "Handle device: ", - VariantDeviceDebugString(device_), - ", requested device: ", d != nullptr ? d->name() : "(nil)"); - } else if (d == absl::get(device_)) { + if (d == device_) { if (Type() != LOCAL) { return errors::Internal("Invalid TensorValue call on a ", TypeString(), " handle: ", this); @@ -549,13 +526,8 @@ Status TensorHandle::WaitUnknownDevice() const { return Status::OK(); } -VariantDevice TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const { - if (VariantDeviceIsCustom(device_)) { - return device_; - } else { - Device* d = absl::get(device_); - return (d == nullptr) ? ctx.HostCPU() : d; - } +Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const { + return (device_ == nullptr) ? ctx.HostCPU() : device_; } Status TensorHandle::Shape(tensorflow::TensorShape* shape) { @@ -691,7 +663,7 @@ Status TensorHandle::NumElements(int64* num_elements) const { Status TensorHandle::Unprotect(const Device* d) { DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d; - if (d == absl::get(device_)) { + if (d == device_) { return absl::visit([](auto& data) { return data.Unprotect(); }, data_); } @@ -718,7 +690,7 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) { DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this << " device: " << d; - if (!VariantDeviceIsCustom(device_) && d == absl::get(device_)) { + if (d == device_) { return errors::Internal("Cannot add mirror for primary device."); } @@ -739,7 +711,7 @@ Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready, DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d << " " << d->name(); - if (VariantDeviceIsCustom(device_) || d != absl::get(device_)) { + if (d != device_) { tf_shared_lock l(mu_); auto mirror = remote_mirrors_.find(d->name()); if (mirror != remote_mirrors_.end()) { @@ -854,7 +826,7 @@ Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape, DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d << " " << d->name(); - if (VariantDeviceIsCustom(device_) || d != absl::get(device_)) { + if (d != device_) { tf_shared_lock l(mu_); auto remote_mirror = remote_mirrors_.find(d->name()); if (remote_mirror == remote_mirrors_.end()) { @@ -916,7 +888,7 @@ void TensorHandle::PoisonRemote(Status status, const Device* d, DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d << " " << d->name(); - if (!VariantDeviceIsCustom(device_) && d == absl::get(device_)) { + if (d == device_) { DCHECK(Type() == REMOTE) << "Poison can only be on remote handles: " << this; @@ -936,7 +908,7 @@ void TensorHandle::PoisonRemote(Status status, const Device* d, Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d) { - if (d == absl::get(device_)) { + if (d == device_) { return errors::Internal( "Local mirror assign conflicts with primary device."); } @@ -955,7 +927,7 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) { DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d; - if (d == absl::get(device_)) { + if (d == device_) { DCHECK(Type() == LOCAL) << "SetTensor is not called on local handles."; if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) { @@ -982,7 +954,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) { void TensorHandle::Poison(Status status, const Device* d) { DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d; - if (!VariantDeviceIsCustom(device_) && d == absl::get(device_)) { + if (d == device_) { DCHECK(Type() != REMOTE) << "Poison can only be on local handles: " << this; absl::visit([status](auto& data) { data.Poison(status); }, data_); } else { @@ -1001,7 +973,7 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx, tensorflow::Device* d, tensorflow::Tensor* output) const { tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d; - tensorflow::Device* srcd = absl::get(DeviceOrHostCPU(ctx)); + tensorflow::Device* srcd = DeviceOrHostCPU(ctx); const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr; const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr; bool is_same_device = @@ -1063,27 +1035,6 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx, return status; } -bool VariantDeviceIsCustom(VariantDevice variant_device) { - return variant_device.index() != 0; -} - -string VariantDeviceName(VariantDevice device) { - if (device == kVariantDeviceNull) { - return "[]"; - } - return absl::visit([](auto* device) { return device->name(); }, device); -} - -string VariantDeviceDebugString(VariantDevice device) { - if (device == kVariantDeviceNull) { - return "[]"; - } else if (VariantDeviceIsCustom(device)) { - return absl::get(device)->name(); - } else { - return absl::get(device)->DebugString(); - } -} - Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) { if (ctx == nullptr) { return nullptr; @@ -1100,10 +1051,9 @@ string TensorHandle::DebugString() const { DVLOG(4) << "Calling TensorHandle::DebugString() on " << this; string out; - string device_debug = VariantDeviceDebugString(device_); + string device_debug = SafeDeviceDebugString(device_); strings::StrAppend(&out, "Device: ", device_debug); - bool is_cpu = - !VariantDeviceIsCustom(device_) && device_ != kVariantDeviceNull; + bool is_cpu = device_ != nullptr; // Consider supporting non-CPU tensors and CPU tensors with a device_ set to // non-NULL if needed. strings::StrAppend( @@ -1115,9 +1065,6 @@ string TensorHandle::DebugString() const { } const char* TensorHandle::DeviceName(Status* status) const { - if (VariantDeviceIsCustom(device())) { - return absl::get(device())->name().c_str(); - } status->Update(WaitUnknownDevice()); tensorflow::Device* d = op_device(); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" @@ -1125,33 +1072,19 @@ const char* TensorHandle::DeviceName(Status* status) const { } const char* TensorHandle::BackingDeviceName(Status* status) const { - if (VariantDeviceIsCustom(device())) { - return absl::get(device())->name().c_str(); - } else { - status->Update(WaitUnknownDevice()); - tensorflow::Device* d = absl::get(device()); - return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" - : d->name().c_str(); - } + status->Update(WaitUnknownDevice()); + tensorflow::Device* d = device(); + return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : d->name().c_str(); } const char* TensorHandle::DeviceType(Status* status) const { - if (VariantDeviceIsCustom(device())) { - status->Update( - tensorflow::errors::Unimplemented("Custom device unsupported")); - return nullptr; - } status->Update(WaitUnknownDevice()); tensorflow::Device* d = op_device(); return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str(); } int TensorHandle::DeviceId(Status* status) const { - if (VariantDeviceIsCustom(device())) { - status->Update( - tensorflow::errors::Unimplemented("Custom device unsupported")); - return -1; - } status->Update(WaitUnknownDevice()); tensorflow::Device* d = op_device(); return (d == nullptr) ? 0 : d->parsed_name().id; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 396af4166c7..1072ad5a3ab 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -60,7 +60,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // TensorHandle for dtype == DT_RESOURCE TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device, EagerContext* ctx); - TensorHandle(tensorflow::Tensor&& t, CustomDevice* d, EagerContext* ctx); TensorHandle(Device* d, Device* op_device, Device* resource_device, tensorflow::DataType dtype, EagerContext* ctx); @@ -81,8 +80,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle { Device* op_device, Device* resource_device, EagerContext* ctx); - static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, - CustomDevice* d, EagerContext* ctx); static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device, Device* resource_device, tensorflow::DataType dtype, @@ -150,7 +147,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // requesting the HostCPU. Status TensorValue(const Device* d, tensorflow::TensorValue* t); - VariantDevice device() const { return device_; } + Device* device() const { return device_; } Device* op_device() const { return op_device_; } Device* resource_device() const { return resource_device_; } int64 resource_remote_device_incarnation() const { @@ -161,7 +158,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // are set (data is ready). Status WaitUnknownDevice() const; - VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const; + Device* DeviceOrHostCPU(const EagerContext& ctx) const; Status Shape(tensorflow::TensorShape* shape); @@ -286,7 +283,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle { bool IsReady() const; Status WaitReady(const char* caller) const; - VariantDevice device_; + tensorflow::Device* device_; // Device in which the op producing this tensor was executed. Equals to // device_ for constant tensors. @@ -391,19 +388,6 @@ class TensorHandle : public ImmediateExecutionTensorHandle { PartialTensorShape inference_shape_; }; -// Checks whether a VariantDevice contains a custom device. -bool VariantDeviceIsCustom(VariantDevice device); - -// Wraps device->name() or CustomDevice->name(). -string VariantDeviceName(VariantDevice device); - -// Wraps device->DebugString() or CustomDevice->name(). -string VariantDeviceDebugString(VariantDevice device); - -// Indicates either HostCPU or an unset physical device. We never set a null -// CustomDevice*. -const VariantDevice kVariantDeviceNull = static_cast(nullptr); - // Returns the device backing the resource. Else, returns nullptr. Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 936b35bda8b..5c729c6560e 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -189,8 +189,8 @@ TEST_F(PackedTensorHandleTest, PackedHandle) { EXPECT_EQ(dtypes_and_shapes.at(0).dtype, DT_FLOAT); EXPECT_EQ(dtypes_and_shapes.at(0).shape.IsIdenticalTo({2, 2}), true); - CompositeDevice* device = reinterpret_cast( - absl::get(packed_handle->device())); + CompositeDevice* device = + reinterpret_cast(packed_handle->device()); EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0"); EXPECT_EQ(device->underlying_devices()->size(), 4); @@ -200,7 +200,7 @@ TEST_F(PackedTensorHandleTest, PackedHandle) { for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) { TensorHandle* h = nullptr; TF_ASSERT_OK(packed_handle->ExtractPackedHandle(i, &h)); - EXPECT_EQ(absl::get(h->device()), ListDevices().at(i)); + EXPECT_EQ(h->device(), ListDevices().at(i)); EXPECT_EQ(h->Type(), expected_handle_types.at(i)); } EXPECT_FALSE(IsReady(packed_handle)); @@ -236,14 +236,14 @@ TEST_F(PackedTensorHandleTest, PackedSingleHandle) { TF_ASSERT_OK(packed_handle->Shape(&packed_shape)); EXPECT_EQ(packed_shape, shape); - CompositeDevice* device = reinterpret_cast( - absl::get(packed_handle->device())); + CompositeDevice* device = + reinterpret_cast(packed_handle->device()); EXPECT_EQ(device->name(), "/job:worker/replica:0/task:0/device:COMPOSITE:0"); EXPECT_EQ(device->underlying_devices()->size(), 1); EXPECT_EQ(packed_handle->NumPackedHandles(), 1); TensorHandle* h0 = nullptr; TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &h0)); - EXPECT_EQ(absl::get(h0->device()), d); + EXPECT_EQ(h0->device(), d); EXPECT_TRUE(IsReady(packed_handle)); packed_handle->Unref(); } @@ -392,7 +392,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) { TensorHandle* h = TensorHandle::CreateUnshapedRemoteHandle( /*op_id=*/0, /*output_num=*/0, remote_task, dtype, d1, context, /*unknown_device=*/true); - EXPECT_EQ(absl::get(h->device()), d1); + EXPECT_EQ(h->device(), d1); Device* d2 = device_mgr.ListDevices().at(2); TF_ASSERT_OK(h->SetRemoteShapeAndDevice( @@ -400,7 +400,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) { Status s; EXPECT_EQ(h->BackingDeviceName(&s), d2->name()); TF_EXPECT_OK(s); - EXPECT_EQ(absl::get(h->device()), d2); + EXPECT_EQ(h->device(), d2); h->Unref(); context->Unref(); } diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 385fb39223f..7377fad0a39 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -186,7 +186,7 @@ Status AddOpRetvalsToResponse( for (int i = 0; i < num_retvals; i++) { TF_RETURN_IF_ERROR(TensorHandleShape(retvals[i], add_shape_proto_fn())); if (add_device_fn) { - Device* device = absl::get(retvals[i]->device()); + Device* device = retvals[i]->device(); *add_device_fn() = device ? device->name() : ""; } if (retvals[i]->Type() == TensorHandle::REMOTE) { diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 1b7ca04318b..77a38d09a29 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -1086,8 +1086,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) { context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); TF_ASSERT_OK(tensor_handle->Tensor(&t)); - Device* device = absl::get(tensor_handle->device()); - EXPECT_EQ(device, nullptr); + EXPECT_EQ(tensor_handle->device(), nullptr); auto actual = t->flat(); EXPECT_EQ(4, actual.size()); @@ -1168,8 +1167,7 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) { EXPECT_EQ(packed_handle->Type(), TensorHandle::PACKED); EXPECT_EQ(packed_handle->NumPackedHandles(), 3); - EXPECT_EQ(absl::get(packed_handle->device())->name(), - composite_device); + EXPECT_EQ(packed_handle->device()->name(), composite_device); TensorHandle* handle0 = nullptr; TF_ASSERT_OK(packed_handle->ExtractPackedHandle(0, &handle0)); @@ -1198,7 +1196,7 @@ TEST_F(EagerServiceImplTest, SendPackedHandleTest) { EXPECT_EQ(handle2->op_device()->name(), device2); int64 op_id; int32 output_num; - TF_ASSERT_OK(handle2->RemoteAddress(absl::get(handle2->device()), + TF_ASSERT_OK(handle2->RemoteAddress(handle2->device(), /*wait_until_ready=*/true, &op_id, &output_num)); EXPECT_EQ(op_id, 2); diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index f391af5d4be..5ab016130f4 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -37,7 +37,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) { remote_op->set_name(op->Name()); op->Attrs().FillAttrValueMap(remote_op->mutable_attrs()); - remote_op->set_device(VariantDeviceName(op->Device())); + remote_op->set_device(op->DeviceName()); } Status CreateUncachedKernelAndDeviceOp( @@ -80,7 +80,7 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, src_(src), ctx_(ctx), executor_(executor), - send_device_(absl::get(src->DeviceOrHostCPU(*ctx))), + send_device_(src->DeviceOrHostCPU(*ctx)), recv_device_(recv_device), wire_id_(GetUniqueWireID()), recv_op_id_(recv_op_id), @@ -149,9 +149,8 @@ void RemoteCopyNode::StartSend() { auto* remote_op = request.add_queue()->mutable_operation(); status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle( src_, /*wait_until_ready=*/false, - remote_op->add_op_inputs()->mutable_remote_handle(), - absl::get(src_->device()), - absl::get(src_->DeviceOrHostCPU(*ctx_))->name()); + remote_op->add_op_inputs()->mutable_remote_handle(), src_->device(), + src_->DeviceOrHostCPU(*ctx_)->name()); if (!status.ok()) { captured_state_->SetSendStatus(status); return; @@ -310,7 +309,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle, const Device* target_device, EagerContext* ctx, SendPackedHandleOp* op) { op->set_op_id(op_id); - op->set_device_name(VariantDeviceName(packed_handle->DeviceOrHostCPU(*ctx))); + op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name()); for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) { TensorHandle* h = nullptr; TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h)); @@ -329,7 +328,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle, // If src_device is on the same task of target_device, the handle is a // local handle on the target device, which means the resource dtype and // shape are known on the target device. - Device* src_device = absl::get(h->device()); + Device* src_device = h->device(); const bool serialize_resource_dtype_and_shape = (i == 0) && (h->dtype == DT_RESOURCE) && (!ctx->OnSameTask(src_device, target_device)); @@ -341,7 +340,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle, TF_RETURN_IF_ERROR(ctx->RemoteMgr()->SerializeRemoteTensorHandle( h, /*wait_until_ready=*/true, op->add_handles()->mutable_remote_handle(), src_device, - absl::get(h->DeviceOrHostCPU(*ctx))->name(), + h->DeviceOrHostCPU(*ctx)->name(), serialize_resource_dtype_and_shape)); } else { return errors::InvalidArgument("Nested packed handles are not supported"); diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index 7a3a447042e..79fcb99ba10 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -83,15 +83,8 @@ Status RemoteMgr::GetMirroredResourceShape( Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, const bool wait_until_ready, int64* op_id, int32* output_num) { - // TODO(allenl): Consider supporting remote handles on custom devices. - VariantDevice device = handle->device(); - if (VariantDeviceIsCustom(device)) { - return errors::Unimplemented( - "Custom devices and remote execution are currently not supported " - "together."); - } - TF_RETURN_IF_ERROR(handle->RemoteAddress( - absl::get(device), wait_until_ready, op_id, output_num)); + TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready, + op_id, output_num)); tensorflow::TensorHandle* h; TF_RETURN_IF_ERROR( GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h)); diff --git a/tensorflow/lite/delegates/flex/kernel.cc b/tensorflow/lite/delegates/flex/kernel.cc index f21c984fe3e..b4db4511334 100644 --- a/tensorflow/lite/delegates/flex/kernel.cc +++ b/tensorflow/lite/delegates/flex/kernel.cc @@ -268,14 +268,11 @@ class OpNode { return tensorflow::Status::OK(); } - void ClearEagerInputs() { - for (tensorflow::TensorHandle* h : *op_->MutableInputs()) { - if (h) h->Unref(); - } - op_->MutableInputs()->clear(); - } + void ClearEagerInputs() { op_->Clear(); } tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) { + absl::InlinedVector* op_inputs; + TF_RETURN_IF_ERROR(op_->MutableTensorHandleInputs(&op_inputs)); for (int i = 0; i < inputs_.Size(); ++i) { int input_index = inputs_.TfLiteIndex(i); TensorSource s = inputs_.GetTensorSource(i); @@ -290,14 +287,14 @@ class OpNode { tensorflow::TensorHandle* handle = tensorflow::TensorHandle::CreateLocalHandle( buffer_map->GetTensor(input_index)); - op_->MutableInputs()->push_back(handle); + op_inputs->push_back(handle); } else { // If this is a forwardable tensor, we will remove it from the previous // op's list, giving TF the opportunity to reuse its buffer. bool unref_handle = inputs_.IsForwardable(i); auto* handle = s.node->outputs_.GetHandle(s.node_output_index, unref_handle); - op_->MutableInputs()->push_back(handle); + op_inputs->push_back(handle); } } return tensorflow::Status::OK(); diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 52bc6ee8233..1d04a584a9a 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -155,11 +155,7 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor, tensorflow::unwrap(ctx)->TFTensorHandleFromInterface( tensorflow::unwrap(EagerTensor_Handle(eager_tensor)))); - if (VariantDeviceIsCustom(handle->device())) { - return errors::Unimplemented( - "Custom devices are currently not supported with PyFuncs."); - } - Device* actual_device = absl::get(handle->device()); + Device* actual_device = handle->device(); TF_RETURN_IF_ERROR(handle->Tensor(output_tensor)); // actual_device may be nullptr, which implies local CPU. if (expected_device == actual_device) return Status::OK();