Move custom device placement from eager/execute.cc to c_api.cc so that it can be reused by TFRT.

PiperOrigin-RevId: 356389689
Change-Id: Ibd3df16e2a4bd0607389edbd42d01cd04d24d0aa
commit 6df72f44ff
parent 063eb2465f
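The diff below moves custom-device placement behind a new CustomDeviceOpHandler that hangs off ImmediateExecutionContext, and the C API routes execution through it. As a rough sketch of the resulting call path (illustrative only: the helper function name is invented; the API names come from the diff and mirror the new TFE_Execute body):

#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"

// Hypothetical helper showing how any runtime (TF eager today, TFRT later)
// can reuse the shared placement logic: the context owns a
// CustomDeviceOpHandler, and op execution is funneled through it.
tensorflow::Status ExecuteViaHandler(
    tensorflow::ImmediateExecutionOperation* op,
    tensorflow::ImmediateExecutionTensorHandle** retvals, int* num_retvals) {
  return op->GetContext()->GetCustomDeviceOpHandler().Execute(op, retvals,
                                                              num_retvals);
}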
@@ -73,9 +73,11 @@ tf_cuda_library(
         "//tensorflow/core/common_runtime/eager:context_distributed_manager",
         "//tensorflow/core/common_runtime/eager:core",
         "//tensorflow/core/common_runtime/eager:custom_device",
+        "//tensorflow/core/common_runtime/eager:custom_device_op_handler",
         "//tensorflow/core/common_runtime/eager:eager_executor",
         "//tensorflow/core/common_runtime/eager:execute",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "//tensorflow/core/common_runtime/eager:placement_utils",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
@@ -41,7 +41,9 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/custom_device.h"
+#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
 #include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
@@ -532,7 +534,8 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   tensorflow::CustomDevice* device = nullptr;
-  if (!context->FindCustomDeviceFromName(device_name, &device)) {
+  if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
+                                                                    &device)) {
     deallocator(data, arg);
     status->status =
         tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@@ -562,7 +565,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
   status->status = context->FindDeviceFromName(device_name, &device);
   tensorflow::CustomDevice* custom_device = nullptr;
   if (!status->status.ok()) {
-    if (!context->FindCustomDeviceFromName(device_name, &custom_device)) {
+    if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(
+            device_name, &custom_device)) {
       deallocator(data, len, deallocator_arg);
       status->status =
           tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@@ -654,8 +658,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
 }
 
 TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
-  return tensorflow::wrap(
-      &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
+  return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
 }
 
 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
@@ -889,11 +892,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
 
 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                  TF_Status* status) {
-  status->status = tensorflow::unwrap(op)->Execute(
-      absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
-                         tensorflow::unwrap(retvals)),
-                     *num_retvals),
-      num_retvals);
+  tensorflow::ImmediateExecutionOperation* unwrapped_op =
+      tensorflow::unwrap(op);
+
+  status->status =
+      unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
+          unwrapped_op,
+          reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
+              retvals),
+          num_retvals);
 }
 
 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
@@ -1150,10 +1157,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
   }
   auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
       ctx, device, device_info, device_name);
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  status->status =
-      context->RegisterCustomDevice(device_name, std::move(custom_device));
+  status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
+      device_name, std::move(custom_device));
 }
 
 } // extern "C"
@@ -38,6 +38,9 @@ limitations under the License.
 
 namespace tensorflow {
 class EagerExecutor;
+class EagerContext;
+class CustomDevice;
+class CustomDeviceOpHandler;
 
 // LINT.IfChange
 // Note: Keep in sync with exported copy of enum in eager/c_api.h.
@@ -122,6 +125,7 @@ class ImmediateExecutionContext : public AbstractContext {
 
   // Return the ParsedName of Host CPU device.
   virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
+  virtual const string& HostCPUName() const = 0;
 
   // Configure soft device placement policy.
   virtual void SetAllowSoftPlacement(bool enable) = 0;
@@ -147,6 +151,18 @@ class ImmediateExecutionContext : public AbstractContext {
     return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
   }
 
+  //===--------------------------------------------------------------------===//
+  // Experimental Custom Device.
+  //===--------------------------------------------------------------------===//
+  virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
+
+  // Register a custom device. It will return an error if the device name is
+  // already registered.
+  // TODO(tfrt-devs): Remove this method. Let caller register it directly into
+  // CustomDeviceOpHandler.
+  virtual Status RegisterCustomDevice(const string& name,
+                                      std::unique_ptr<CustomDevice> device) = 0;
+
   //===--------------------------------------------------------------------===//
   // Following are features in current TF Eager Runtime.
   // TODO(tfrt-devs): Figure out a way to deprecate following features after
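For context implementations outside TF eager (the TFRT case the commit message mentions), the two new pure-virtual hooks above are the integration surface. A minimal sketch, assuming a hypothetical TfrtStyleContext class; only the two new overrides are shown, and the many other pure-virtual members of ImmediateExecutionContext are omitted, so this is not compilable as-is:

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"

class TfrtStyleContext : public tensorflow::ImmediateExecutionContext {
 public:
  // Expose the shared handler so the C API (TFE_Execute and friends) can
  // route custom-device placement through it.
  tensorflow::CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
    return custom_device_op_handler_;
  }

  // Forward registration to the handler, mirroring what EagerContext does in
  // the diff below.
  tensorflow::Status RegisterCustomDevice(
      const std::string& name,
      std::unique_ptr<tensorflow::CustomDevice> device) override {
    return custom_device_op_handler_.RegisterCustomDevice(name,
                                                          std::move(device));
  }

 private:
  tensorflow::CustomDeviceOpHandler custom_device_op_handler_;
};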
@@ -33,6 +33,8 @@ struct TFE_Op;
 
 namespace tensorflow {
 
+class ImmediateExecutionContext;
+
 // Abstract interface to an operation.
 class ImmediateExecutionOperation : public AbstractOperation {
  public:
@@ -41,6 +43,15 @@ class ImmediateExecutionOperation : public AbstractOperation {
   // Returns the inputs of this op.
   virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs()
       const = 0;
+  virtual Status SetInput(size_t index,
+                          ImmediateExecutionTensorHandle* input) = 0;
+
+  virtual ImmediateExecutionContext* GetContext() const = 0;
+
+  // Following two methods are used to support custom device.
+  // Return true if the inputs contain a custom device tensor handle. It means
+  // that the argument needs to be handled by a custom device.
+  virtual bool HasCustomDeviceInput() const = 0;
 
   virtual const tensorflow::OpDef* OpDef() const = 0;
 
@@ -87,6 +87,7 @@ tf_cuda_library(
     deps = [
         ":eager_executor",
         ":kernel_and_device",
+        ":custom_device_op_handler",
         ":custom_device",
         "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/c:tf_tensor_internal",
@@ -140,6 +141,28 @@ tf_cuda_library(
     }),
 )
 
+tf_cuda_library(
+    name = "custom_device_op_handler",
+    srcs = ["custom_device_op_handler.cc"],
+    hdrs = ["custom_device_op_handler.h"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":custom_device",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:framework",
+            "//tensorflow/core:lib",
+            "//tensorflow/c/eager:immediate_execution_context",
+            "//tensorflow/c/eager:immediate_execution_tensor_handle",
+            "//tensorflow/c/eager:immediate_execution_operation",
+            "//tensorflow/core/lib/core:status",
+        ],
+    }),
+)
+
 tf_cc_test(
     name = "custom_device_test",
     srcs = ["custom_device_test.cc"],
@@ -647,6 +670,7 @@ tf_cuda_library(
         ":custom_device",
         ":attr_builder",
         ":eager_operation",
+        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
     ] + select({
         "//tensorflow:android": [
@@ -714,6 +738,7 @@ filegroup(
         "attr_builder.h",
         "context.h",
         "custom_device.h",
+        "custom_device_op_handler.h",
         "eager_executor.h",
         "eager_operation.h",
         "kernel_and_device.h",
@@ -522,7 +522,7 @@ EagerContext::~EagerContext() {
 
   // Custom devices may have obtained references to various context components
   // (executors, thread pool). It's safer to run their destructors early.
-  custom_devices_.clear();
+  custom_device_op_handler_.Clear();
 
   ClearCachesAndThreadExecutors();
   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
@@ -904,38 +904,15 @@ Status EagerContext::FindCompositeDeviceFromName(
   return errors::NotFound("Unknown composite device: ", device_name);
 }
 
-bool EagerContext::FindCustomDeviceFromName(const string& device_name,
-                                            CustomDevice** dev) const {
-  auto dev_it = custom_devices_.find(device_name);
-  if (dev_it == custom_devices_.end()) {
-    return false;
-  }
-  *dev = dev_it->second.get();
-  return true;
-}
-
 Status EagerContext::RegisterCustomDevice(
     const string& device_name, std::unique_ptr<CustomDevice> device) {
-  DeviceNameUtils::ParsedName parsed;
-  if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
-      !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
-      !parsed.has_type || !parsed.has_id) {
-    return errors::InvalidArgument(
-        device_name,
-        " could not be parsed as a device name. Use the full "
-        "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
-        "format.");
-  }
   Device* existing_physical_device = nullptr;
   if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
     return errors::AlreadyExists(device_name,
                                  " already registered as a physical device.");
   }
-  if (!custom_devices_.emplace(device_name, std::move(device)).second) {
-    return errors::AlreadyExists(device_name,
-                                 " already registered as a custom device.");
-  }
-  return Status::OK();
+  return custom_device_op_handler_.RegisterCustomDevice(device_name,
+                                                        std::move(device));
 }
 
 Status EagerContext::FindOrCreateCompositeDevice(
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/custom_device.h"
+#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -204,6 +205,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
     return HostCPU()->parsed_name();
   }
 
+  const string& HostCPUName() const override { return HostCPU()->name(); }
+
   GraphCollector* GetGraphCollector() { return &graph_collector_; }
 
   EagerExecutor& Executor() override;
@@ -473,11 +476,12 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
   Status FindCompositeDeviceFromName(StringPiece device_name,
                                      CompositeDevice** device) const;
 
-  bool FindCustomDeviceFromName(const string& device_name,
-                                CustomDevice** dev) const;
-
   Status RegisterCustomDevice(const string& name,
-                              std::unique_ptr<CustomDevice> device);
+                              std::unique_ptr<CustomDevice> device) override;
+
+  CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
+    return custom_device_op_handler_;
+  };
 
   // Find or create a composite device with the given `underlying_devices` and
   // `device_name` (if not empty).
@@ -587,7 +591,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
       TF_GUARDED_BY(device_type_list_mu_);
   Rendezvous* rendezvous_;
   std::function<Rendezvous*(const int64)> rendezvous_creator_;
-  std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
+  CustomDeviceOpHandler custom_device_op_handler_;
 
   mutable mutex composite_devices_mu_;
   // Maps from the fingerprint of a set of device names to a virtual
@@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
   *status = this->FindDeviceFromName(device_name, &device);
   if (!status->ok()) {
     tensorflow::CustomDevice* dev;
-    if (this->FindCustomDeviceFromName(device_name, &dev)) {
+    if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) {
       *status = dev->CopyTensorToDevice(handle, &result);
       if (status->ok()) {
         return result;
@@ -128,7 +128,8 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
     return nullptr;
   }
   tensorflow::CustomDevice* dev;
-  if (this->FindCustomDeviceFromName(handle_device_name, &dev)) {
+  if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name,
+                                                         &dev)) {
     *status = dev->CopyTensorFromDevice(handle, device_name, &result);
     if (status->ok()) {
       return result;
@@ -202,28 +203,8 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
     }
   }
 
-  // Decide to either run the operation on a custom device or copy off all of
-  // the custom device inputs.
-  VariantDevice maybe_custom_device = Device();
-  if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) ||
-      !inputs_are_tensor_handles_) {
-    // If the op wasn't placed on a custom device explicitly and there are no
-    // non-TensorHandle inputs, the op will definitely be placed on a physical
-    // device. Otherwise we need to check the inputs one by one.
-    TF_RETURN_IF_ERROR(
-        eager::MaybePinToCustomDevice(&maybe_custom_device, *this));
-    if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) {
-      ImmediateExecutionTensorHandle** retval_array =
-          reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data());
-      return absl::get<CustomDevice*>(maybe_custom_device)
-          ->Execute(this, retval_array, num_retvals);
-    } else {
-      TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs());
-    }
-  }
-
   // Run eager placement logic.
-  class Device* device = absl::get<class Device*>(maybe_custom_device);
+  class Device* device = absl::get<class Device*>(Device());
   if (device == nullptr) {
     TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
   }
tensorflow/core/common_runtime/eager/custom_device_op_handler.cc (new file, 167 lines)
@@ -0,0 +1,167 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
+
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+
+void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }
+
+Status CustomDeviceOpHandler::RegisterCustomDevice(
+    const string& device_name, std::unique_ptr<CustomDevice> device) {
+  DeviceNameUtils::ParsedName parsed;
+  if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
+      !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
+      !parsed.has_type || !parsed.has_id) {
+    return errors::InvalidArgument(
+        device_name,
+        " could not be parsed as a device name. Use the full "
+        "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
+        "format.");
+  }
+
+  if (!custom_devices_.emplace(device_name, std::move(device)).second) {
+    return errors::AlreadyExists(device_name,
+                                 " already registered as a custom device.");
+  }
+  return Status::OK();
+}
+
+bool CustomDeviceOpHandler::FindCustomDeviceFromName(
+    const string& name, CustomDevice** device) const {
+  auto dev_it = custom_devices_.find(name);
+  if (dev_it == custom_devices_.end()) {
+    return false;
+  }
+  *device = dev_it->second.get();
+  return true;
+}
+
+Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
+                                      ImmediateExecutionTensorHandle** retvals,
+                                      int* num_retvals) {
+  tensorflow::CustomDevice* custom_device = nullptr;
+
+  TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));
+
+  if (custom_device != nullptr) {
+    return custom_device->Execute(op, retvals, num_retvals);
+  }
+
+  // The op will be placed on a physical device. However, it contains custom
+  // device tensor handles. The tensor handles will be copied to the physical
+  // device first.
+  if (op->HasCustomDeviceInput()) {
+    auto inputs = op->GetInputs();
+    for (int i = 0; i < inputs.size(); ++i) {
+      auto target_device = op->DeviceName();
+      if (target_device.empty()) {
+        target_device = op->GetContext()->HostCPUName();
+      }
+      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
+      // here.
+      if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
+        tensorflow::CustomDeviceTensorHandle* previous =
+            tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
+                inputs[i]);
+        tensorflow::ImmediateExecutionTensorHandle* new_tensor;
+        TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
+            previous, target_device, &new_tensor));
+        Status s = op->SetInput(i, new_tensor);
+        new_tensor->Unref();
+        TF_RETURN_IF_ERROR(s);
+      }
+    }
+  }
+
+  return op->Execute(
+      absl::MakeSpan(
+          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
+          *num_retvals),
+      num_retvals);
+}
+
+Status CustomDeviceOpHandler::MaybePinToCustomDevice(
+    CustomDevice** device, const ImmediateExecutionOperation& op) const {
+  *device = nullptr;
+  if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
+      !op.HasCustomDeviceInput()) {
+    return Status::OK();
+  }
+
+  // Ops are placed on a custom device if there's no other explicit requested
+  // placement and there is only one custom device in the op
+  // inputs.
+  //
+  // Resource-dtype inputs take precedence over non-resource inputs and explicit
+  // placements; this function pins ops with a resource-dtype custom device
+  // input to that custom device.
+  CustomDevice* first = nullptr;
+  if (!op.GetInputs().empty()) {
+    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
+      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
+      // here.
+      if (CustomDeviceTensorHandle::classof(generic_input)) {
+        const CustomDeviceTensorHandle* input =
+            down_cast<const CustomDeviceTensorHandle*>(generic_input);
+        CustomDevice* current = input->device();
+        if (first == nullptr) {
+          first = current;
+        } else if (first != current) {
+          return errors::InvalidArgument(absl::StrCat(
+              "If an operation has one of its inputs in a custom device, then "
+              "all inputs should be on that same custom device or another "
+              "physical device. Operation ",
+              op.Name(),
+              " has one input in custom "
+              "device ",
+              first->name(),
+              " and at least one input in a different custom device ",
+              current->name()));
+        }
+      }
+    }
+    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
+      if (generic_input->DataType() == DT_RESOURCE) {
+        if (CustomDeviceTensorHandle::classof(generic_input)) {
+          const CustomDeviceTensorHandle* input =
+              down_cast<const CustomDeviceTensorHandle*>(generic_input);
+          // There's only one custom device input, and it's a resource input, so
+          // we'll force-place the op on to that custom device. As with physical
+          // devices, this overrides any explicit placement for the op.
+          *device = input->device();
+          return Status::OK();
+        } else {
+          // Don't set a custom device if there's a physical-device resource
+          // input.
+          return Status::OK();
+        }
+      }
+    }
+  }
+  // Since there are no resource-dtype inputs, we'll respect explicit placements
+  // before considering input-based placement.
+  if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
+    // If there are non-resource inputs on a custom device we will default the
+    // op to that custom device, but not override an explicit op placement.
+    *device = first;
+    return Status::OK();
+  }
+  return Status::OK();
+}
+
+} // namespace tensorflow
tensorflow/core/common_runtime/eager/custom_device_op_handler.h (new file, 51 lines)
@@ -0,0 +1,51 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
+
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/common_runtime/eager/custom_device.h"
+#include "tensorflow/core/lib/core/status.h"
+namespace tensorflow {
+
+// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT.
+class CustomDeviceOpHandler {
+ public:
+  ~CustomDeviceOpHandler() {}
+  // Register a new custom device.
+  Status RegisterCustomDevice(const string& device_name,
+                              std::unique_ptr<CustomDevice> device);
+
+  // Find the custom device from given name. Return true if it finds one.
+  bool FindCustomDeviceFromName(const string& name,
+                                CustomDevice** device) const;
+
+  Status Execute(ImmediateExecutionOperation* op,
+                 ImmediateExecutionTensorHandle** retvals, int* num_retvals);
+
+  // Determine whether to place an op on a custom device. This method is
+  // exposed as public for test only.
+  Status MaybePinToCustomDevice(CustomDevice** device,
+                                const ImmediateExecutionOperation& op) const;
+
+  void Clear();
+
+ private:
+  std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
+};
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
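The new class keeps the name-to-device map and the pinning policy in one place. A hedged usage sketch under stated assumptions: MyDevice is a placeholder for any CustomDevice implementation, and the device name is made up but follows the full /job/replica/task/device format the handler requires:

#include <memory>
#include <utility>

#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"

// MyDevice stands in for a concrete tensorflow::CustomDevice subclass.
tensorflow::Status RegisterAndLookup(std::unique_ptr<MyDevice> my_device) {
  tensorflow::CustomDeviceOpHandler handler;

  // Fails with InvalidArgument unless the name parses as a full device name,
  // and with AlreadyExists if the name was registered before.
  tensorflow::Status s = handler.RegisterCustomDevice(
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      std::move(my_device));
  if (!s.ok()) return s;

  tensorflow::CustomDevice* found = nullptr;
  if (handler.FindCustomDeviceFromName(
          "/job:localhost/replica:0/task:0/device:CUSTOM:0", &found)) {
    // found now points at the registered device; handler.Execute() consults
    // the same map (via MaybePinToCustomDevice) before running an op.
  }
  return tensorflow::Status::OK();
}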
@@ -138,43 +138,47 @@ TEST(CustomDevice, TestResourcePlacement) {
   TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
   TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
   TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
-  VariantDevice placed_device(kVariantDeviceNull);
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  CustomDevice* placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // MaybePinToCustomDevice has no opinion about ops which have physical
   // resource-dtype inputs. They'll get placed on physical devices.
-  EXPECT_EQ(kVariantDeviceNull, placed_device);
+  EXPECT_EQ(nullptr, placed_device);
 
   op.Clear();
   TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str()));
   TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
   TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
-  placed_device = kVariantDeviceNull;
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // Explicit placement onto a custom device also doesn't trigger custom device
   // placement if there's a physical device resource input.
-  EXPECT_EQ(kVariantDeviceNull, placed_device);
+  EXPECT_EQ(nullptr, placed_device);
 
   op.Clear();
   TF_ASSERT_OK(
       op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0"));
   TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
-  placed_device = kVariantDeviceNull;
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // Explicit placements typically override input-based placement onto a custom
   // device.
-  EXPECT_EQ(kVariantDeviceNull, placed_device);
+  EXPECT_EQ(nullptr, placed_device);
 
   op.Clear();
   TF_ASSERT_OK(op.Reset("AssignVariableOp",
                         "/job:localhost/replica:0/task:0/device:CPU:0"));
   TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get()));
   TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
-  placed_device = kVariantDeviceNull;
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // Even with an explicit physical device placement, custom device resource
   // inputs place the op on the custom device.
-  ASSERT_TRUE(absl::holds_alternative<CustomDevice*>(placed_device));
-  EXPECT_EQ(&custom_device, absl::get<CustomDevice*>(placed_device));
+  ASSERT_NE(placed_device, nullptr);
+  EXPECT_EQ(&custom_device, placed_device);
 }
 
 } // namespace
@@ -36,7 +36,7 @@ void EagerOperation::Clear() {
     h->Unref();
   }
   inputs_.clear();
-  inputs_are_tensor_handles_ = true;
+  custom_device_tensor_handles_count_ = 0;
   ClearInferenceState();
 }
 
@@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
       down_cast<ImmediateExecutionTensorHandle*>(input);
   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
   if (CustomDeviceTensorHandle::classof(h)) {
-    inputs_are_tensor_handles_ = false;
+    custom_device_tensor_handles_count_++;
   }
   AddTensorHandle(h);
   return MaybeInferSingleInputAttrs(h);
@@ -281,7 +281,7 @@ Status EagerOperation::AddInputList(
     // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
     // here.
     if (CustomDeviceTensorHandle::classof(input)) {
-      inputs_are_tensor_handles_ = false;
+      custom_device_tensor_handles_count_++;
     }
     ImmediateExecutionTensorHandle* h =
         down_cast<ImmediateExecutionTensorHandle*>(input);
@@ -290,6 +290,25 @@ Status EagerOperation::AddInputList(
   return InferInputListAttrs(inputs.size());
 }
 
+Status EagerOperation::SetInput(size_t index,
+                                ImmediateExecutionTensorHandle* input) {
+  if (index >= inputs_.size()) {
+    return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index,
+                                   inputs_.size());
+  }
+  auto* previous = inputs_[index];
+  if (CustomDeviceTensorHandle::classof(previous)) {
+    custom_device_tensor_handles_count_--;
+  }
+  if (CustomDeviceTensorHandle::classof(input)) {
+    custom_device_tensor_handles_count_++;
+  }
+  input->Ref();
+  inputs_[index] = input;
+  previous->Unref();
+  return Status::OK();
+}
+
 Status EagerOperation::Reset(
     const char* op, const char* device_name, bool remote,
     EagerExecutor* executor,
|
|||||||
|
|
||||||
Status EagerOperation::TensorHandleInputs(
|
Status EagerOperation::TensorHandleInputs(
|
||||||
const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
|
const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
|
||||||
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
|
if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
|
||||||
*inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
|
*inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
|
||||||
&inputs_);
|
&inputs_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -418,7 +437,7 @@ Status EagerOperation::TensorHandleInputs(
|
|||||||
|
|
||||||
Status EagerOperation::MutableTensorHandleInputs(
|
Status EagerOperation::MutableTensorHandleInputs(
|
||||||
absl::InlinedVector<TensorHandle*, 4>** inputs) {
|
absl::InlinedVector<TensorHandle*, 4>** inputs) {
|
||||||
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
|
if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
|
||||||
*inputs =
|
*inputs =
|
||||||
reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
|
reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -436,14 +455,7 @@ Status EagerOperation::SetDeviceName(const char* c_name) {
|
|||||||
}
|
}
|
||||||
last_set_device_name_ = name;
|
last_set_device_name_ = name;
|
||||||
device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
|
device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
|
||||||
CustomDevice* custom_device;
|
device_ = kVariantDeviceNull;
|
||||||
if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) {
|
|
||||||
device_ = custom_device;
|
|
||||||
} else {
|
|
||||||
// Device placement for physical devices happens lazily in
|
|
||||||
// EagerExecute/EagerRemoteExecute, and can depend on the inputs.
|
|
||||||
device_ = kVariantDeviceNull;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -495,30 +507,4 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
|
|||||||
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
attrs_.NumInputs(static_cast<int>(inputs_.size()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status EagerOperation::CopyOffCustomDeviceInputs() {
|
|
||||||
if (absl::holds_alternative<CustomDevice*>(device_)) {
|
|
||||||
return errors::Internal(
|
|
||||||
"Trying to copy inputs to a custom device op off a custom device.");
|
|
||||||
}
|
|
||||||
for (int i = 0; i < inputs_.size(); ++i) {
|
|
||||||
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
|
|
||||||
// here.
|
|
||||||
if (CustomDeviceTensorHandle::classof(inputs_[i])) {
|
|
||||||
CustomDeviceTensorHandle* previous =
|
|
||||||
down_cast<CustomDeviceTensorHandle*>(inputs_[i]);
|
|
||||||
class Device* target_device;
|
|
||||||
if (device_ == kVariantDeviceNull) {
|
|
||||||
target_device = ctx_.HostCPU();
|
|
||||||
} else {
|
|
||||||
target_device = absl::get<class Device*>(device_);
|
|
||||||
}
|
|
||||||
TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
|
|
||||||
previous, target_device->name(), &inputs_[i]));
|
|
||||||
previous->Unref();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
inputs_are_tensor_handles_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@@ -55,6 +55,8 @@ class EagerOperation : public ImmediateExecutionOperation {
 
   const string& DeviceName() const override { return device_name_; }
 
+  ImmediateExecutionContext* GetContext() const override { return &ctx_; }
+
   const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
     return device_parsed_name_;
   }
@@ -83,7 +85,11 @@ class EagerOperation : public ImmediateExecutionOperation {
 
   Status AddInput(AbstractTensorHandle* input) override;
   Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
+  Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
   absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
+  bool HasCustomDeviceInput() const override {
+    return custom_device_tensor_handles_count_ > 0;
+  }
   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                  int* num_retvals) override;
   const tensorflow::OpDef* OpDef() const override { return op_def_; };
@@ -207,20 +213,14 @@ class EagerOperation : public ImmediateExecutionOperation {
   void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
                                     const std::vector<DataType>& dtypes);
 
-  // Replaces input tensors placed on custom devices with physical device
-  // equivalents. Used if an op is placed on a physical device but may have
-  // custom device inputs.
-  Status CopyOffCustomDeviceInputs();
-
   tensorflow::EagerContext& ctx_;
   const char* op_name_ = nullptr;
   AttrBuilder attrs_;
   const AttrTypeMap* attr_types_;
 
-  // Toggled to indicate whether all inputs are known to be TensorHandles and
-  // not another type (e.g. custom device tensor handles). Explicitly set to
-  // false when custom device TensorHandles are added.
-  bool inputs_are_tensor_handles_ = true;
+  // The number of custom device TensorHandle inputs. These inputs need to be
+  // processed by CustomDeviceOpHandler first.
+  int custom_device_tensor_handles_count_ = 0;
   absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
 
   // The last device name given to SetDeviceName.
@@ -77,11 +77,6 @@ bool IsFunction(StringPiece op_name) {
   return false;
 }
 
-bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
-  CustomDevice* custom_device;
-  return ctx.FindCustomDeviceFromName(string(device_name), &custom_device);
-}
-
 Status MaybePinSmallOpsToCpu(
     bool* result, StringPiece op_name,
     absl::Span<ImmediateExecutionTensorHandle* const> args,
@@ -182,70 +177,5 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
   return Status::OK();
 }
 
-Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
-  // Ops are placed on a custom device if there's no other explicit requested
-  // placement and there is only one custom device in the op
-  // inputs.
-  //
-  // Resource-dtype inputs take precedence over non-resource inputs and explicit
-  // placements; this function pins ops with a resource-dtype custom device
-  // input to that custom device.
-  CustomDevice* first = nullptr;
-  if (!op.Inputs().empty()) {
-    for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
-      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
-      // here.
-      if (CustomDeviceTensorHandle::classof(generic_input)) {
-        const CustomDeviceTensorHandle* input =
-            down_cast<const CustomDeviceTensorHandle*>(generic_input);
-        CustomDevice* current = input->device();
-        if (first == nullptr) {
-          first = current;
-        } else if (first != current) {
-          return errors::InvalidArgument(absl::StrCat(
-              "If an operation has one of its inputs in a custom device, then "
-              "all inputs should be on that same custom device or another "
-              "physical device. Operation ",
-              op.Name(),
-              " has one input in custom "
-              "device ",
-              first->name(),
-              " and at least one input in a different custom device ",
-              current->name()));
-        }
-      }
-    }
-    for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
-      if (generic_input->DataType() == DT_RESOURCE) {
-        if (CustomDeviceTensorHandle::classof(generic_input)) {
-          const CustomDeviceTensorHandle* input =
-              down_cast<const CustomDeviceTensorHandle*>(generic_input);
-          // There's only one custom device input, and it's a resource input, so
-          // we'll force-place the op on to that custom device. As with physical
-          // devices, this overrides any explicit placement for the op.
-          *device = input->device();
-          return Status::OK();
-        } else {
-          // Don't set a custom device if there's a physical-device resource
-          // input.
-          return Status::OK();
-        }
-      }
-    }
-  }
-  // Since there are no resource-dtype inputs, we'll respect explicit placements
-  // before considering input-based placement.
-  if (absl::holds_alternative<CustomDevice*>(op.Device())) {
-    *device = op.Device();
-  } else if (op.DeviceName().empty() && first != nullptr) {
-    // If there are non-resource inputs on a custom device we will default the
-    // op to that custom device, but not override an explicit op placement.
-    *device = first;
-    return Status::OK();
-  }
-
-  return Status::OK();
-}
-
 } // namespace eager
 } // namespace tensorflow
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
 
+#include "tensorflow/c/eager/immediate_execution_operation.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/stringpiece.h"
@@ -27,8 +28,6 @@ bool IsColocationExempt(StringPiece op_name);
 
 bool IsFunction(StringPiece op_name);
 
-bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);
-
 // TODO(b/154234908): Unify placement logic.
 // TODO(b/159647422): Add C++ unit tests for placement logic.
 
@@ -44,11 +43,6 @@ Status MaybePinSmallOpsToCpu(
 // the device the resource is, regardless of anything else that has been
 // specified. This is identical to the graph mode behavior.
 Status MaybePinToResourceDevice(Device** device, const EagerOperation& op);
-
-// If all the inputs are on the same custom device, use that custom
-// device. Otherwise, it is an error to have a custom device as an input.
-Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);
-
 } // namespace eager
 } // namespace tensorflow
 