Move custom device placement from eager/execute.cc to c_api.cc. Then it can be reused by TFRT.
PiperOrigin-RevId: 356389689 Change-Id: Ibd3df16e2a4bd0607389edbd42d01cd04d24d0aa
This commit is contained in: parent 063eb2465f · commit 6df72f44ff
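The change routes eager execution through a CustomDeviceOpHandler owned by the context, so the same custom-device dispatch can be reused by TFRT. For orientation only, here is a self-contained sketch of that register/find/dispatch shape; every name below (ToyOpHandler, ToyCustomDevice, LoggingDevice) is an illustrative stand-in, not a TensorFlow type.

// Sketch of the handler pattern this commit introduces (stand-in types only).
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Minimal stand-in for a custom device that can run an op itself.
struct ToyCustomDevice {
  virtual ~ToyCustomDevice() = default;
  virtual void Execute(const std::string& op_name) = 0;
};

// Stand-in for CustomDeviceOpHandler: owns the name -> device map and decides
// whether an op goes to a custom device or falls through to normal execution.
class ToyOpHandler {
 public:
  bool Register(const std::string& name, std::unique_ptr<ToyCustomDevice> d) {
    return devices_.emplace(name, std::move(d)).second;  // false if duplicate
  }
  ToyCustomDevice* Find(const std::string& name) const {
    auto it = devices_.find(name);
    return it == devices_.end() ? nullptr : it->second.get();
  }
  void Execute(const std::string& op_name, const std::string& device_name) {
    if (ToyCustomDevice* d = Find(device_name)) {
      d->Execute(op_name);  // the custom device handles the op
    } else {
      std::cout << op_name << ": regular placement path\n";
    }
  }

 private:
  std::map<std::string, std::unique_ptr<ToyCustomDevice>> devices_;
};

struct LoggingDevice : ToyCustomDevice {
  void Execute(const std::string& op_name) override {
    std::cout << op_name << ": handled by custom device\n";
  }
};

int main() {
  ToyOpHandler handler;
  handler.Register("/job:localhost/replica:0/task:0/device:CUSTOM:0",
                   std::make_unique<LoggingDevice>());
  handler.Execute("MatMul", "/job:localhost/replica:0/task:0/device:CUSTOM:0");
  handler.Execute("MatMul", "/job:localhost/replica:0/task:0/device:CPU:0");
  return 0;
}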
@@ -73,9 +73,11 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:custom_device",
"//tensorflow/core/common_runtime/eager:custom_device_op_handler",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:placement_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
@@ -41,7 +41,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/placement_utils.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -532,7 +534,8 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::CustomDevice* device = nullptr;
if (!context->FindCustomDeviceFromName(device_name, &device)) {
if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
&device)) {
deallocator(data, arg);
status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@@ -562,7 +565,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
if (!context->FindCustomDeviceFromName(device_name, &custom_device)) {
if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(
device_name, &custom_device)) {
deallocator(data, len, deallocator_arg);
status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@@ -654,8 +658,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
}

TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
return tensorflow::wrap(
&(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
}

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
@@ -889,11 +892,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
status->status = tensorflow::unwrap(op)->Execute(
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(retvals)),
*num_retvals),
num_retvals);
tensorflow::ImmediateExecutionOperation* unwrapped_op =
tensorflow::unwrap(op);

status->status =
unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
unwrapped_op,
reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
retvals),
num_retvals);
}

TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
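Nothing changes at the public C API surface: TFE_Execute keeps its signature, and the custom-device dispatch now happens inside via the context's op handler. For reference, a minimal, self-contained use of the standard eager C API (ordinary public calls, not tied to this diff's internals; error handling is deliberately thin):

// Minimal eager C API usage sketch; compiles as C++ against the public headers.
#include <cstdio>
#include <cstring>
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);

  // Build a scalar float tensor handle.
  float value = 2.0f;
  TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float));
  std::memcpy(TF_TensorData(t), &value, sizeof(float));
  TFE_TensorHandle* x = TFE_NewTensorHandle(t, status);

  // Square it; TFE_Execute routes through the context's op handler internally.
  TFE_Op* op = TFE_NewOp(ctx, "Square", status);
  TFE_OpAddInput(op, x, status);
  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(op, retvals, &num_retvals, status);
  std::printf("status: %s\n", TF_Message(status));

  if (TF_GetCode(status) == TF_OK) TFE_DeleteTensorHandle(retvals[0]);
  TFE_DeleteOp(op);
  TFE_DeleteTensorHandle(x);
  TF_DeleteTensor(t);
  TFE_DeleteContext(ctx);
  TF_DeleteStatus(status);
  return 0;
}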
@@ -1150,10 +1157,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
}
auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
ctx, device, device_info, device_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
device_name, std::move(custom_device));
}

} // extern "C"
@@ -38,6 +38,9 @@ limitations under the License.

namespace tensorflow {
class EagerExecutor;
class EagerContext;
class CustomDevice;
class CustomDeviceOpHandler;

// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
@@ -122,6 +125,7 @@ class ImmediateExecutionContext : public AbstractContext {

// Return the ParsedName of Host CPU device.
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
virtual const string& HostCPUName() const = 0;

// Configure soft device placement policy.
virtual void SetAllowSoftPlacement(bool enable) = 0;
@@ -147,6 +151,18 @@ class ImmediateExecutionContext : public AbstractContext {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}

//===--------------------------------------------------------------------===//
// Experimental Custom Device.
//===--------------------------------------------------------------------===//
virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;

// Register a custom device. It will return an error if the device name is
// already registered.
// TODO(tfrt-devs): Remove this method. Let caller register it directly into
// CustomDeviceOpHandler.
virtual Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device) = 0;

//===--------------------------------------------------------------------===//
// Following are features in current TF Eager Runtime.
// TODO(tfrt-devs): Figure out a way to deprecate following features after
@@ -33,6 +33,8 @@ struct TFE_Op;

namespace tensorflow {

class ImmediateExecutionContext;

// Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation {
public:
@@ -41,6 +43,15 @@ class ImmediateExecutionOperation : public AbstractOperation {
// Returns the inputs of this op.
virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs()
const = 0;
virtual Status SetInput(size_t index,
ImmediateExecutionTensorHandle* input) = 0;

virtual ImmediateExecutionContext* GetContext() const = 0;

// The following two methods are used to support custom devices.
// Returns true if the inputs contain a custom device tensor handle, meaning
// that the operation needs to be handled by a custom device.
virtual bool HasCustomDeviceInput() const = 0;

virtual const tensorflow::OpDef* OpDef() const = 0;

@@ -87,6 +87,7 @@ tf_cuda_library(
deps = [
":eager_executor",
":kernel_and_device",
":custom_device_op_handler",
":custom_device",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_tensor_internal",
@@ -140,6 +141,28 @@ tf_cuda_library(
}),
)

tf_cuda_library(
name = "custom_device_op_handler",
srcs = ["custom_device_op_handler.cc"],
hdrs = ["custom_device_op_handler.h"],
visibility = ["//tensorflow:internal"],
deps = [
":custom_device",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/core/lib/core:status",
],
}),
)

tf_cc_test(
name = "custom_device_test",
srcs = ["custom_device_test.cc"],
@@ -647,6 +670,7 @@ tf_cuda_library(
":custom_device",
":attr_builder",
":eager_operation",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
] + select({
"//tensorflow:android": [
@@ -714,6 +738,7 @@ filegroup(
"attr_builder.h",
"context.h",
"custom_device.h",
"custom_device_op_handler.h",
"eager_executor.h",
"eager_operation.h",
"kernel_and_device.h",
@@ -522,7 +522,7 @@ EagerContext::~EagerContext() {

// Custom devices may have obtained references to various context components
// (executors, thread pool). It's safer to run their destructors early.
custom_devices_.clear();
custom_device_op_handler_.Clear();

ClearCachesAndThreadExecutors();
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
@@ -904,38 +904,15 @@ Status EagerContext::FindCompositeDeviceFromName(
return errors::NotFound("Unknown composite device: ", device_name);
}

bool EagerContext::FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const {
auto dev_it = custom_devices_.find(device_name);
if (dev_it == custom_devices_.end()) {
return false;
}
*dev = dev_it->second.get();
return true;
}

Status EagerContext::RegisterCustomDevice(
const string& device_name, std::unique_ptr<CustomDevice> device) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
!parsed.has_job || !parsed.has_replica || !parsed.has_task ||
!parsed.has_type || !parsed.has_id) {
return errors::InvalidArgument(
device_name,
" could not be parsed as a device name. Use the full "
"/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
"format.");
}
Device* existing_physical_device = nullptr;
if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
return errors::AlreadyExists(device_name,
" already registered as a physical device.");
}
if (!custom_devices_.emplace(device_name, std::move(device)).second) {
return errors::AlreadyExists(device_name,
" already registered as a custom device.");
}
return Status::OK();
return custom_device_op_handler_.RegisterCustomDevice(device_name,
std::move(device));
}

Status EagerContext::FindOrCreateCompositeDevice(
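The name validation shown above (now performed inside CustomDeviceOpHandler::RegisterCustomDevice) only accepts fully qualified device names. A hypothetical example of a name that satisfies the /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> requirement quoted in the error message:

#include <string>

// Hypothetical custom device name in the fully qualified form the
// registration check requires; "CUSTOM" is an illustrative device type.
const std::string kExampleCustomDeviceName =
    "/job:localhost/replica:0/task:0/device:CUSTOM:0";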
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -204,6 +205,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
return HostCPU()->parsed_name();
}

const string& HostCPUName() const override { return HostCPU()->name(); }

GraphCollector* GetGraphCollector() { return &graph_collector_; }

EagerExecutor& Executor() override;
@@ -473,11 +476,12 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status FindCompositeDeviceFromName(StringPiece device_name,
CompositeDevice** device) const;

bool FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const;

Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device);
std::unique_ptr<CustomDevice> device) override;

CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
return custom_device_op_handler_;
};

// Find or create a composite device with the given `underlying_devices` and
// `device_name` (if not empty).
@@ -587,7 +591,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
TF_GUARDED_BY(device_type_list_mu_);
Rendezvous* rendezvous_;
std::function<Rendezvous*(const int64)> rendezvous_creator_;
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
CustomDeviceOpHandler custom_device_op_handler_;

mutable mutex composite_devices_mu_;
// Maps from the fingerprint of a set of device names to a virtual
@@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
*status = this->FindDeviceFromName(device_name, &device);
if (!status->ok()) {
tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(device_name, &dev)) {
if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) {
*status = dev->CopyTensorToDevice(handle, &result);
if (status->ok()) {
return result;
@@ -128,7 +128,8 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
return nullptr;
}
tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(handle_device_name, &dev)) {
if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name,
&dev)) {
*status = dev->CopyTensorFromDevice(handle, device_name, &result);
if (status->ok()) {
return result;
@@ -202,28 +203,8 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
}
}

// Decide to either run the operation on a custom device or copy off all of
// the custom device inputs.
VariantDevice maybe_custom_device = Device();
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) ||
!inputs_are_tensor_handles_) {
// If the op wasn't placed on a custom device explicitly and there are no
// non-TensorHandle inputs, the op will definitely be placed on a physical
// device. Otherwise we need to check the inputs one by one.
TF_RETURN_IF_ERROR(
eager::MaybePinToCustomDevice(&maybe_custom_device, *this));
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) {
ImmediateExecutionTensorHandle** retval_array =
reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data());
return absl::get<CustomDevice*>(maybe_custom_device)
->Execute(this, retval_array, num_retvals);
} else {
TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs());
}
}

// Run eager placement logic.
class Device* device = absl::get<class Device*>(maybe_custom_device);
class Device* device = absl::get<class Device*>(Device());
if (device == nullptr) {
TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
}
tensorflow/core/common_runtime/eager/custom_device_op_handler.cc (new file, 167 lines)
@@ -0,0 +1,167 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }

Status CustomDeviceOpHandler::RegisterCustomDevice(
const string& device_name, std::unique_ptr<CustomDevice> device) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
!parsed.has_job || !parsed.has_replica || !parsed.has_task ||
!parsed.has_type || !parsed.has_id) {
return errors::InvalidArgument(
device_name,
" could not be parsed as a device name. Use the full "
"/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
"format.");
}

if (!custom_devices_.emplace(device_name, std::move(device)).second) {
return errors::AlreadyExists(device_name,
" already registered as a custom device.");
}
return Status::OK();
}

bool CustomDeviceOpHandler::FindCustomDeviceFromName(
const string& name, CustomDevice** device) const {
auto dev_it = custom_devices_.find(name);
if (dev_it == custom_devices_.end()) {
return false;
}
*device = dev_it->second.get();
return true;
}

Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
ImmediateExecutionTensorHandle** retvals,
int* num_retvals) {
tensorflow::CustomDevice* custom_device = nullptr;

TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));

if (custom_device != nullptr) {
return custom_device->Execute(op, retvals, num_retvals);
}

// The op will be placed on a physical device. However, it contains custom
// device tensor handles. The tensor handles will be copied to the physical
// device first.
if (op->HasCustomDeviceInput()) {
auto inputs = op->GetInputs();
for (int i = 0; i < inputs.size(); ++i) {
auto target_device = op->DeviceName();
if (target_device.empty()) {
target_device = op->GetContext()->HostCPUName();
}
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
tensorflow::CustomDeviceTensorHandle* previous =
tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
inputs[i]);
tensorflow::ImmediateExecutionTensorHandle* new_tensor;
TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
previous, target_device, &new_tensor));
Status s = op->SetInput(i, new_tensor);
new_tensor->Unref();
TF_RETURN_IF_ERROR(s);
}
}
}

return op->Execute(
absl::MakeSpan(
reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
*num_retvals),
num_retvals);
}

Status CustomDeviceOpHandler::MaybePinToCustomDevice(
CustomDevice** device, const ImmediateExecutionOperation& op) const {
*device = nullptr;
if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
!op.HasCustomDeviceInput()) {
return Status::OK();
}

// Ops are placed on a custom device if there's no other explicitly requested
// placement and there is only one custom device in the op
// inputs.
//
// Resource-dtype inputs take precedence over non-resource inputs and explicit
// placements; this function pins ops with a resource-dtype custom device
// input to that custom device.
CustomDevice* first = nullptr;
if (!op.GetInputs().empty()) {
for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
CustomDevice* current = input->device();
if (first == nullptr) {
first = current;
} else if (first != current) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same custom device or another "
"physical device. Operation ",
op.Name(),
" has one input in custom "
"device ",
first->name(),
" and at least one input in a different custom device ",
current->name()));
}
}
}
for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
if (generic_input->DataType() == DT_RESOURCE) {
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
// There's only one custom device input, and it's a resource input, so
// we'll force-place the op on to that custom device. As with physical
// devices, this overrides any explicit placement for the op.
*device = input->device();
return Status::OK();
} else {
// Don't set a custom device if there's a physical-device resource
// input.
return Status::OK();
}
}
}
}
// Since there are no resource-dtype inputs, we'll respect explicit placements
// before considering input-based placement.
if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
// If there are non-resource inputs on a custom device we will default the
// op to that custom device, but not override an explicit op placement.
*device = first;
return Status::OK();
}
return Status::OK();
}

} // namespace tensorflow
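To make the placement precedence above concrete, here is a small self-contained sketch of the same decision order, under the assumption that all custom-device inputs agree (the real handler returns InvalidArgument when they do not, and resolves an explicit custom-device name earlier via FindCustomDeviceFromName). CustomDevice, Input, and PickCustomDevice are illustrative stand-ins, not TensorFlow types.

// Toy illustration of the pinning order: resource-dtype inputs decide
// placement outright, explicit placements win next, and otherwise a lone
// custom-device input provides the default.
#include <cassert>
#include <string>
#include <vector>

struct CustomDevice { std::string name; };

struct Input {
  bool is_resource = false;
  const CustomDevice* custom = nullptr;  // nullptr => physical-device input.
};

// Returns the custom device to pin to, or nullptr when the op should go
// through normal physical placement.
const CustomDevice* PickCustomDevice(const std::string& explicit_device,
                                     const std::vector<Input>& inputs) {
  const CustomDevice* first = nullptr;
  for (const Input& in : inputs) {
    if (in.custom != nullptr) first = in.custom;  // assumes all agree
  }
  for (const Input& in : inputs) {
    if (in.is_resource) {
      // Resource inputs decide placement outright, as in the handler above.
      return in.custom;  // custom resource pins; physical resource => nullptr
    }
  }
  // Otherwise an explicit placement wins over input-based placement.
  if (!explicit_device.empty()) return nullptr;
  return first;
}

int main() {
  CustomDevice custom{"/job:localhost/replica:0/task:0/device:CUSTOM:0"};
  // Custom-device resource input pins even with an explicit CPU placement.
  assert(PickCustomDevice("/device:CPU:0", {{true, &custom}}) == &custom);
  // A physical resource input keeps the op off the custom device.
  assert(PickCustomDevice("", {{true, nullptr}, {false, &custom}}) == nullptr);
  // A lone custom-device (non-resource) input provides the default placement.
  assert(PickCustomDevice("", {{false, &custom}}) == &custom);
  return 0;
}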
@@ -0,0 +1,51 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_

#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {

// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT.
class CustomDeviceOpHandler {
public:
~CustomDeviceOpHandler() {}
// Register a new custom device.
Status RegisterCustomDevice(const string& device_name,
std::unique_ptr<CustomDevice> device);

// Find the custom device from the given name. Returns true if one is found.
bool FindCustomDeviceFromName(const string& name,
CustomDevice** device) const;

Status Execute(ImmediateExecutionOperation* op,
ImmediateExecutionTensorHandle** retvals, int* num_retvals);

// Determine whether to place an op on a custom device. This method is
// exposed as public for testing only.
Status MaybePinToCustomDevice(CustomDevice** device,
const ImmediateExecutionOperation& op) const;

void Clear();

private:
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
};
} // namespace tensorflow

#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
@@ -138,43 +138,47 @@ TEST(CustomDevice, TestResourcePlacement) {
TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
VariantDevice placed_device(kVariantDeviceNull);
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
CustomDevice* placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// MaybePinToCustomDevice has no opinion about ops which have physical
// resource-dtype inputs. They'll get placed on physical devices.
EXPECT_EQ(kVariantDeviceNull, placed_device);
EXPECT_EQ(nullptr, placed_device);

op.Clear();
TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str()));
TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
placed_device = kVariantDeviceNull;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Explicit placement onto a custom device also doesn't trigger custom device
// placement if there's a physical device resource input.
EXPECT_EQ(kVariantDeviceNull, placed_device);
EXPECT_EQ(nullptr, placed_device);

op.Clear();
TF_ASSERT_OK(
op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0"));
TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
placed_device = kVariantDeviceNull;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Explicit placements typically override input-based placement onto a custom
// device.
EXPECT_EQ(kVariantDeviceNull, placed_device);
EXPECT_EQ(nullptr, placed_device);

op.Clear();
TF_ASSERT_OK(op.Reset("AssignVariableOp",
"/job:localhost/replica:0/task:0/device:CPU:0"));
TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
placed_device = kVariantDeviceNull;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Even with an explicit physical device placement, custom device resource
// inputs place the op on the custom device.
ASSERT_TRUE(absl::holds_alternative<CustomDevice*>(placed_device));
EXPECT_EQ(&custom_device, absl::get<CustomDevice*>(placed_device));
ASSERT_NE(placed_device, nullptr);
EXPECT_EQ(&custom_device, placed_device);
}

} // namespace
@@ -36,7 +36,7 @@ void EagerOperation::Clear() {
h->Unref();
}
inputs_.clear();
inputs_are_tensor_handles_ = true;
custom_device_tensor_handles_count_ = 0;
ClearInferenceState();
}

@@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
down_cast<ImmediateExecutionTensorHandle*>(input);
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (CustomDeviceTensorHandle::classof(h)) {
inputs_are_tensor_handles_ = false;
custom_device_tensor_handles_count_++;
}
AddTensorHandle(h);
return MaybeInferSingleInputAttrs(h);
@@ -281,7 +281,7 @@ Status EagerOperation::AddInputList(
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(input)) {
inputs_are_tensor_handles_ = false;
custom_device_tensor_handles_count_++;
}
ImmediateExecutionTensorHandle* h =
down_cast<ImmediateExecutionTensorHandle*>(input);
@@ -290,6 +290,25 @@ Status EagerOperation::AddInputList(
return InferInputListAttrs(inputs.size());
}

Status EagerOperation::SetInput(size_t index,
ImmediateExecutionTensorHandle* input) {
if (index >= inputs_.size()) {
return errors::InvalidArgument("Index >= inputs.size: ", index,
" >= ", inputs_.size());
}
auto* previous = inputs_[index];
if (CustomDeviceTensorHandle::classof(previous)) {
custom_device_tensor_handles_count_--;
}
if (CustomDeviceTensorHandle::classof(input)) {
custom_device_tensor_handles_count_++;
}
input->Ref();
inputs_[index] = input;
previous->Unref();
return Status::OK();
}

Status EagerOperation::Reset(
const char* op, const char* device_name, bool remote,
EagerExecutor* executor,
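One detail worth noting in SetInput above is the Ref/Unref ordering: the new handle is referenced and stored before the previous handle is released, which keeps the operation safe even when the same handle is passed back in. A tiny self-contained sketch of that pattern (toy Handle type, not TensorFlow's ref-counted handles):

#include <cassert>
#include <cstddef>
#include <vector>

// Toy reference-counted handle standing in for ImmediateExecutionTensorHandle.
struct Handle {
  int refs = 1;
  void Ref() { ++refs; }
  void Unref() { --refs; /* a real handle would delete itself at zero */ }
};

// Mirrors the SetInput ordering: Ref the incoming handle, swap it in, then
// Unref the outgoing one. Unref-ing first would drop the only reference when
// `input` aliases the handle already stored at `index`.
void SetInput(std::vector<Handle*>& inputs, std::size_t index, Handle* input) {
  Handle* previous = inputs[index];
  input->Ref();
  inputs[index] = input;
  previous->Unref();
}

int main() {
  Handle h;
  std::vector<Handle*> inputs = {&h};
  SetInput(inputs, 0, &h);  // self-assignment: ref count stays at 1
  assert(h.refs == 1);
  return 0;
}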
@@ -407,7 +426,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {

Status EagerOperation::TensorHandleInputs(
const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
*inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
&inputs_);
return Status::OK();
@@ -418,7 +437,7 @@ Status EagerOperation::TensorHandleInputs(

Status EagerOperation::MutableTensorHandleInputs(
absl::InlinedVector<TensorHandle*, 4>** inputs) {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
*inputs =
reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
return Status::OK();
@@ -436,14 +455,7 @@ Status EagerOperation::SetDeviceName(const char* c_name) {
}
last_set_device_name_ = name;
device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
CustomDevice* custom_device;
if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) {
device_ = custom_device;
} else {
// Device placement for physical devices happens lazily in
// EagerExecute/EagerRemoteExecute, and can depend on the inputs.
device_ = kVariantDeviceNull;
}
device_ = kVariantDeviceNull;
}
return Status::OK();
}
@@ -495,30 +507,4 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}

Status EagerOperation::CopyOffCustomDeviceInputs() {
if (absl::holds_alternative<CustomDevice*>(device_)) {
return errors::Internal(
"Trying to copy inputs to a custom device op off a custom device.");
}
for (int i = 0; i < inputs_.size(); ++i) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(inputs_[i])) {
CustomDeviceTensorHandle* previous =
down_cast<CustomDeviceTensorHandle*>(inputs_[i]);
class Device* target_device;
if (device_ == kVariantDeviceNull) {
target_device = ctx_.HostCPU();
} else {
target_device = absl::get<class Device*>(device_);
}
TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
previous, target_device->name(), &inputs_[i]));
previous->Unref();
}
}
inputs_are_tensor_handles_ = true;
return Status::OK();
}

} // namespace tensorflow
@@ -55,6 +55,8 @@ class EagerOperation : public ImmediateExecutionOperation {

const string& DeviceName() const override { return device_name_; }

ImmediateExecutionContext* GetContext() const override { return &ctx_; }

const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
return device_parsed_name_;
}
@@ -83,7 +85,11 @@ class EagerOperation : public ImmediateExecutionOperation {

Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
bool HasCustomDeviceInput() const override {
return custom_device_tensor_handles_count_ > 0;
}
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override { return op_def_; };
@@ -207,20 +213,14 @@ class EagerOperation : public ImmediateExecutionOperation {
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
const std::vector<DataType>& dtypes);

// Replaces input tensors placed on custom devices with physical device
// equivalents. Used if an op is placed on a physical device but may have
// custom device inputs.
Status CopyOffCustomDeviceInputs();

tensorflow::EagerContext& ctx_;
const char* op_name_ = nullptr;
AttrBuilder attrs_;
const AttrTypeMap* attr_types_;

// Toggled to indicate whether all inputs are known to be TensorHandles and
// not another type (e.g. custom device tensor handles). Explicitly set to
// false when custom device TensorHandles are added.
bool inputs_are_tensor_handles_ = true;
// The number of custom device TensorHandle inputs. These inputs need to be
// processed by CustomDeviceOpHandler first.
int custom_device_tensor_handles_count_ = 0;
absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;

// The last device name given to SetDeviceName.
@@ -77,11 +77,6 @@ bool IsFunction(StringPiece op_name) {
return false;
}

bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
CustomDevice* custom_device;
return ctx.FindCustomDeviceFromName(string(device_name), &custom_device);
}

Status MaybePinSmallOpsToCpu(
bool* result, StringPiece op_name,
absl::Span<ImmediateExecutionTensorHandle* const> args,
@@ -182,70 +177,5 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
return Status::OK();
}

Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
// Ops are placed on a custom device if there's no other explicit requested
// placement and there is only one custom device in the op
// inputs.
//
// Resource-dtype inputs take precedence over non-resource inputs and explicit
// placements; this function pins ops with a resource-dtype custom device
// input to that custom device.
CustomDevice* first = nullptr;
if (!op.Inputs().empty()) {
for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
CustomDevice* current = input->device();
if (first == nullptr) {
first = current;
} else if (first != current) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same custom device or another "
"physical device. Operation ",
op.Name(),
" has one input in custom "
"device ",
first->name(),
" and at least one input in a different custom device ",
current->name()));
}
}
}
for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
if (generic_input->DataType() == DT_RESOURCE) {
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
// There's only one custom device input, and it's a resource input, so
// we'll force-place the op on to that custom device. As with physical
// devices, this overrides any explicit placement for the op.
*device = input->device();
return Status::OK();
} else {
// Don't set a custom device if there's a physical-device resource
// input.
return Status::OK();
}
}
}
}
// Since there are no resource-dtype inputs, we'll respect explicit placements
// before considering input-based placement.
if (absl::holds_alternative<CustomDevice*>(op.Device())) {
*device = op.Device();
} else if (op.DeviceName().empty() && first != nullptr) {
// If there are non-resource inputs on a custom device we will default the
// op to that custom device, but not override an explicit op placement.
*device = first;
return Status::OK();
}

return Status::OK();
}

} // namespace eager
} // namespace tensorflow
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_

#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
@@ -27,8 +28,6 @@ bool IsColocationExempt(StringPiece op_name);

bool IsFunction(StringPiece op_name);

bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);

// TODO(b/154234908): Unify placement logic.
// TODO(b/159647422): Add C++ unit tests for placement logic.

@@ -44,11 +43,6 @@ Status MaybePinSmallOpsToCpu(
// the device the resource is, regardless of anything else that has been
// specified. This is identical to the graph mode behavior.
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op);

// If all the inputs are on the same custom device, use that custom
// device. Otherwise, it is an error to have a custom device as an input.
Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);

} // namespace eager
} // namespace tensorflow