Move custom device placement from eager/execute.cc to c_api.cc. Then it can be reused by TFRT.

PiperOrigin-RevId: 356389689
Change-Id: Ibd3df16e2a4bd0607389edbd42d01cd04d24d0aa
This commit is contained in:
Xiao Yu 2021-02-08 17:21:58 -08:00 committed by TensorFlower Gardener
parent 063eb2465f
commit 6df72f44ff
15 changed files with 358 additions and 205 deletions

View File

@ -73,9 +73,11 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context_distributed_manager", "//tensorflow/core/common_runtime/eager:context_distributed_manager",
"//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:custom_device", "//tensorflow/core/common_runtime/eager:custom_device",
"//tensorflow/core/common_runtime/eager:custom_device_op_handler",
"//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle", "//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:placement_utils",
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",

View File

@ -41,7 +41,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/execute.h" #include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/placement_utils.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value.pb.h"
@ -532,7 +534,8 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
tensorflow::EagerContext* context = tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::CustomDevice* device = nullptr; tensorflow::CustomDevice* device = nullptr;
if (!context->FindCustomDeviceFromName(device_name, &device)) { if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
&device)) {
deallocator(data, arg); deallocator(data, arg);
status->status = status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device."); tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@ -562,7 +565,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
status->status = context->FindDeviceFromName(device_name, &device); status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr; tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) { if (!status->status.ok()) {
if (!context->FindCustomDeviceFromName(device_name, &custom_device)) { if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(
device_name, &custom_device)) {
deallocator(data, len, deallocator_arg); deallocator(data, len, deallocator_arg);
status->status = status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device."); tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@ -654,8 +658,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
} }
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
return tensorflow::wrap( return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
&(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
} }
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
@ -889,11 +892,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) { TF_Status* status) {
status->status = tensorflow::unwrap(op)->Execute( tensorflow::ImmediateExecutionOperation* unwrapped_op =
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>( tensorflow::unwrap(op);
tensorflow::unwrap(retvals)),
*num_retvals), status->status =
num_retvals); unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
unwrapped_op,
reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
retvals),
num_retvals);
} }
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
@ -1150,10 +1157,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
} }
auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>( auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
ctx, device, device_info, device_name); ctx, device, device_info, device_name);
tensorflow::EagerContext* context = status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); device_name, std::move(custom_device));
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
} }
} // extern "C" } // extern "C"

View File

@ -38,6 +38,9 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
class EagerExecutor; class EagerExecutor;
class EagerContext;
class CustomDevice;
class CustomDeviceOpHandler;
// LINT.IfChange // LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h. // Note: Keep in sync with exported copy of enum in eager/c_api.h.
@ -122,6 +125,7 @@ class ImmediateExecutionContext : public AbstractContext {
// Return the ParsedName of Host CPU device. // Return the ParsedName of Host CPU device.
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0; virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
virtual const string& HostCPUName() const = 0;
// Configure soft device placement policy. // Configure soft device placement policy.
virtual void SetAllowSoftPlacement(bool enable) = 0; virtual void SetAllowSoftPlacement(bool enable) = 0;
@ -147,6 +151,18 @@ class ImmediateExecutionContext : public AbstractContext {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt; return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
} }
//===--------------------------------------------------------------------===//
// Experimental Custom Device.
//===--------------------------------------------------------------------===//
virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
// Register a custom device. It will return an error if the device name is
// already registered.
// TODO(tfrt-devs): Remove this method. Let caller register it directly into
// CustomDeviceOpHandler.
virtual Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device) = 0;
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Following are features in current TF Eager Runtime. // Following are features in current TF Eager Runtime.
// TODO(tfrt-devs): Figure out a way to deprecate following features after // TODO(tfrt-devs): Figure out a way to deprecate following features after

View File

@ -33,6 +33,8 @@ struct TFE_Op;
namespace tensorflow { namespace tensorflow {
class ImmediateExecutionContext;
// Abstract interface to an operation. // Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation { class ImmediateExecutionOperation : public AbstractOperation {
public: public:
@ -41,6 +43,15 @@ class ImmediateExecutionOperation : public AbstractOperation {
// Returns the inputs of this op. // Returns the inputs of this op.
virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs()
const = 0; const = 0;
virtual Status SetInput(size_t index,
ImmediateExecutionTensorHandle* input) = 0;
virtual ImmediateExecutionContext* GetContext() const = 0;
// Following two methods are used to support custom device.
// Return true if the inputs contain a custom device tensor handle. It means
// that the op needs to be handled by a custom device.
virtual bool HasCustomDeviceInput() const = 0;
virtual const tensorflow::OpDef* OpDef() const = 0; virtual const tensorflow::OpDef* OpDef() const = 0;

View File

@ -87,6 +87,7 @@ tf_cuda_library(
deps = [ deps = [
":eager_executor", ":eager_executor",
":kernel_and_device", ":kernel_and_device",
":custom_device_op_handler",
":custom_device", ":custom_device",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_tensor_internal", "//tensorflow/c:tf_tensor_internal",
@ -140,6 +141,28 @@ tf_cuda_library(
}), }),
) )
tf_cuda_library(
name = "custom_device_op_handler",
srcs = ["custom_device_op_handler.cc"],
hdrs = ["custom_device_op_handler.h"],
visibility = ["//tensorflow:internal"],
deps = [
":custom_device",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/core/lib/core:status",
],
}),
)
tf_cc_test( tf_cc_test(
name = "custom_device_test", name = "custom_device_test",
srcs = ["custom_device_test.cc"], srcs = ["custom_device_test.cc"],
@ -647,6 +670,7 @@ tf_cuda_library(
":custom_device", ":custom_device",
":attr_builder", ":attr_builder",
":eager_operation", ":eager_operation",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:immediate_execution_tensor_handle",
] + select({ ] + select({
"//tensorflow:android": [ "//tensorflow:android": [
@ -714,6 +738,7 @@ filegroup(
"attr_builder.h", "attr_builder.h",
"context.h", "context.h",
"custom_device.h", "custom_device.h",
"custom_device_op_handler.h",
"eager_executor.h", "eager_executor.h",
"eager_operation.h", "eager_operation.h",
"kernel_and_device.h", "kernel_and_device.h",

View File

@ -522,7 +522,7 @@ EagerContext::~EagerContext() {
// Custom devices may have obtained references to various context components // Custom devices may have obtained references to various context components
// (executors, thread pool). It's safer to run their destructors early. // (executors, thread pool). It's safer to run their destructors early.
custom_devices_.clear(); custom_device_op_handler_.Clear();
ClearCachesAndThreadExecutors(); ClearCachesAndThreadExecutors();
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy; std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
@ -904,38 +904,15 @@ Status EagerContext::FindCompositeDeviceFromName(
return errors::NotFound("Unknown composite device: ", device_name); return errors::NotFound("Unknown composite device: ", device_name);
} }
bool EagerContext::FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const {
auto dev_it = custom_devices_.find(device_name);
if (dev_it == custom_devices_.end()) {
return false;
}
*dev = dev_it->second.get();
return true;
}
Status EagerContext::RegisterCustomDevice( Status EagerContext::RegisterCustomDevice(
const string& device_name, std::unique_ptr<CustomDevice> device) { const string& device_name, std::unique_ptr<CustomDevice> device) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
!parsed.has_job || !parsed.has_replica || !parsed.has_task ||
!parsed.has_type || !parsed.has_id) {
return errors::InvalidArgument(
device_name,
" could not be parsed as a device name. Use the full "
"/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
"format.");
}
Device* existing_physical_device = nullptr; Device* existing_physical_device = nullptr;
if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) { if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
return errors::AlreadyExists(device_name, return errors::AlreadyExists(device_name,
" already registered as a physical device."); " already registered as a physical device.");
} }
if (!custom_devices_.emplace(device_name, std::move(device)).second) { return custom_device_op_handler_.RegisterCustomDevice(device_name,
return errors::AlreadyExists(device_name, std::move(device));
" already registered as a custom device.");
}
return Status::OK();
} }
Status EagerContext::FindOrCreateCompositeDevice( Status EagerContext::FindOrCreateCompositeDevice(

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h" #include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
@ -204,6 +205,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
return HostCPU()->parsed_name(); return HostCPU()->parsed_name();
} }
const string& HostCPUName() const override { return HostCPU()->name(); }
GraphCollector* GetGraphCollector() { return &graph_collector_; } GraphCollector* GetGraphCollector() { return &graph_collector_; }
EagerExecutor& Executor() override; EagerExecutor& Executor() override;
@ -473,11 +476,12 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status FindCompositeDeviceFromName(StringPiece device_name, Status FindCompositeDeviceFromName(StringPiece device_name,
CompositeDevice** device) const; CompositeDevice** device) const;
bool FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const;
Status RegisterCustomDevice(const string& name, Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device); std::unique_ptr<CustomDevice> device) override;
CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
return custom_device_op_handler_;
};
// Find or create a composite device with the given `underlying_devices` and // Find or create a composite device with the given `underlying_devices` and
// `device_name` (if not empty). // `device_name` (if not empty).
@ -587,7 +591,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
TF_GUARDED_BY(device_type_list_mu_); TF_GUARDED_BY(device_type_list_mu_);
Rendezvous* rendezvous_; Rendezvous* rendezvous_;
std::function<Rendezvous*(const int64)> rendezvous_creator_; std::function<Rendezvous*(const int64)> rendezvous_creator_;
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_; CustomDeviceOpHandler custom_device_op_handler_;
mutable mutex composite_devices_mu_; mutable mutex composite_devices_mu_;
// Maps from the fingerprint of a set of device names to a virtual // Maps from the fingerprint of a set of device names to a virtual

View File

@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
*status = this->FindDeviceFromName(device_name, &device); *status = this->FindDeviceFromName(device_name, &device);
if (!status->ok()) { if (!status->ok()) {
tensorflow::CustomDevice* dev; tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(device_name, &dev)) { if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) {
*status = dev->CopyTensorToDevice(handle, &result); *status = dev->CopyTensorToDevice(handle, &result);
if (status->ok()) { if (status->ok()) {
return result; return result;
@ -128,7 +128,8 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
return nullptr; return nullptr;
} }
tensorflow::CustomDevice* dev; tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(handle_device_name, &dev)) { if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name,
&dev)) {
*status = dev->CopyTensorFromDevice(handle, device_name, &result); *status = dev->CopyTensorFromDevice(handle, device_name, &result);
if (status->ok()) { if (status->ok()) {
return result; return result;
@ -202,28 +203,8 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
} }
} }
// Decide to either run the operation on a custom device or copy off all of
// the custom device inputs.
VariantDevice maybe_custom_device = Device();
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) ||
!inputs_are_tensor_handles_) {
// If the op wasn't placed on a custom device explicitly and there are no
// non-TensorHandle inputs, the op will definitely be placed on a physical
// device. Otherwise we need to check the inputs one by one.
TF_RETURN_IF_ERROR(
eager::MaybePinToCustomDevice(&maybe_custom_device, *this));
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) {
ImmediateExecutionTensorHandle** retval_array =
reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data());
return absl::get<CustomDevice*>(maybe_custom_device)
->Execute(this, retval_array, num_retvals);
} else {
TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs());
}
}
// Run eager placement logic. // Run eager placement logic.
class Device* device = absl::get<class Device*>(maybe_custom_device); class Device* device = absl::get<class Device*>(Device());
if (device == nullptr) { if (device == nullptr) {
TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this)); TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
} }

View File

@ -0,0 +1,167 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
// Drops every registered custom device, releasing their resources.
void CustomDeviceOpHandler::Clear() {
  custom_devices_.clear();
}
// Registers `device` under `device_name`. The name must be a fully specified
// device name; registration fails if the name is malformed or already taken.
Status CustomDeviceOpHandler::RegisterCustomDevice(
    const string& device_name, std::unique_ptr<CustomDevice> device) {
  DeviceNameUtils::ParsedName parsed;
  const bool fully_specified =
      DeviceNameUtils::ParseFullName(device_name, &parsed) && parsed.has_job &&
      parsed.has_replica && parsed.has_task && parsed.has_type && parsed.has_id;
  if (!fully_specified) {
    return errors::InvalidArgument(
        device_name,
        " could not be parsed as a device name. Use the full "
        "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
        "format.");
  }
  const auto insertion =
      custom_devices_.emplace(device_name, std::move(device));
  if (!insertion.second) {
    return errors::AlreadyExists(device_name,
                                 " already registered as a custom device.");
  }
  return Status::OK();
}
// Looks up the custom device registered under `name`. On success sets
// `*device` (non-owning) and returns true; otherwise leaves `*device`
// untouched and returns false.
bool CustomDeviceOpHandler::FindCustomDeviceFromName(
    const string& name, CustomDevice** device) const {
  const auto it = custom_devices_.find(name);
  if (it != custom_devices_.end()) {
    *device = it->second.get();
    return true;
  }
  return false;
}
// Executes `op`, routing it to a custom device when placement rules say so
// (see MaybePinToCustomDevice). Otherwise the op runs on a physical device;
// any custom-device tensor inputs are first copied onto the target physical
// device (the op's explicit placement, or host CPU when unplaced).
Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
                                      ImmediateExecutionTensorHandle** retvals,
                                      int* num_retvals) {
  tensorflow::CustomDevice* custom_device = nullptr;
  TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));

  if (custom_device != nullptr) {
    return custom_device->Execute(op, retvals, num_retvals);
  }

  // The op will be placed on a physical device, but it may still contain
  // custom device tensor handles. Copy those inputs to the physical device
  // before executing.
  if (op->HasCustomDeviceInput()) {
    // The copy target does not depend on the input, so compute it once.
    auto target_device = op->DeviceName();
    if (target_device.empty()) {
      target_device = op->GetContext()->HostCPUName();
    }
    auto inputs = op->GetInputs();
    for (size_t i = 0; i < inputs.size(); ++i) {
      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
      // here.
      if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
        tensorflow::CustomDeviceTensorHandle* previous =
            tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
                inputs[i]);
        tensorflow::ImmediateExecutionTensorHandle* new_tensor;
        TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
            previous, target_device, &new_tensor));
        // SetInput refs the replacement, so drop our reference afterwards.
        Status s = op->SetInput(i, new_tensor);
        new_tensor->Unref();
        TF_RETURN_IF_ERROR(s);
      }
    }
  }

  return op->Execute(
      absl::MakeSpan(
          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
          *num_retvals),
      num_retvals);
}
// Decides whether `op` should be pinned to a custom device, writing the chosen
// device to `*device` (or nullptr for "no custom device"). Placement rules,
// in priority order:
//   1. A resource-dtype custom-device input force-places the op on that
//      custom device, overriding explicit placement.
//   2. A resource-dtype physical-device input keeps the op off custom devices.
//   3. Otherwise an explicit device name wins (custom or physical).
//   4. Otherwise, if all custom-device inputs share one custom device, the op
//      defaults to it. Mixing two different custom devices is an error.
Status CustomDeviceOpHandler::MaybePinToCustomDevice(
    CustomDevice** device, const ImmediateExecutionOperation& op) const {
  *device = nullptr;
  // Fast path: no explicit custom-device placement (the lookup also seeds
  // *device when the op names a custom device) and no custom-device inputs.
  if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
      !op.HasCustomDeviceInput()) {
    return Status::OK();
  }

  // Ops are placed on a custom device if there's no other explicit requested
  // placement and there is only one custom device in the op
  // inputs.
  //
  // Resource-dtype inputs take precedence over non-resource inputs and explicit
  // placements; this function pins ops with a resource-dtype custom device
  // input to that custom device.
  CustomDevice* first = nullptr;
  if (!op.GetInputs().empty()) {
    // First pass: verify all custom-device inputs agree on one device.
    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
      // here.
      if (CustomDeviceTensorHandle::classof(generic_input)) {
        const CustomDeviceTensorHandle* input =
            down_cast<const CustomDeviceTensorHandle*>(generic_input);
        CustomDevice* current = input->device();
        if (first == nullptr) {
          first = current;
        } else if (first != current) {
          return errors::InvalidArgument(absl::StrCat(
              "If an operation has one of its inputs in a custom device, then "
              "all inputs should be on that same custom device or another "
              "physical device. Operation ",
              op.Name(),
              " has one input in custom "
              "device ",
              first->name(),
              " and at least one input in a different custom device ",
              current->name()));
        }
      }
    }
    // Second pass: resource-dtype inputs dictate placement outright.
    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
      if (generic_input->DataType() == DT_RESOURCE) {
        if (CustomDeviceTensorHandle::classof(generic_input)) {
          const CustomDeviceTensorHandle* input =
              down_cast<const CustomDeviceTensorHandle*>(generic_input);
          // There's only one custom device input, and it's a resource input, so
          // we'll force-place the op on to that custom device. As with physical
          // devices, this overrides any explicit placement for the op.
          *device = input->device();
          return Status::OK();
        } else {
          // Don't set a custom device if there's a physical-device resource
          // input.
          return Status::OK();
        }
      }
    }
  }
  // Since there are no resource-dtype inputs, we'll respect explicit placements
  // before considering input-based placement.
  if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
    // If there are non-resource inputs on a custom device we will default the
    // op to that custom device, but not override an explicit op placement.
    *device = first;
    return Status::OK();
  }
  return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// Handles custom-device placement for eagerly executed ops: decides whether an
// op runs on a registered custom device, and copies custom-device tensor
// inputs onto physical devices when the op itself is physically placed.
// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT.
class CustomDeviceOpHandler {
 public:
  // Rule of Zero: no user-declared destructor, so implicit move operations
  // stay available. `custom_devices_` cleans up via unique_ptr.

  // Register a new custom device. Returns an error if `device_name` cannot be
  // parsed as a fully specified device name or is already registered.
  Status RegisterCustomDevice(const string& device_name,
                              std::unique_ptr<CustomDevice> device);

  // Find the custom device from given name. Return true if it finds one, in
  // which case `*device` is set to a non-owning pointer.
  bool FindCustomDeviceFromName(const string& name,
                                CustomDevice** device) const;

  // Execute `op`, either on a custom device chosen by MaybePinToCustomDevice
  // or on a physical device after copying off custom-device inputs.
  Status Execute(ImmediateExecutionOperation* op,
                 ImmediateExecutionTensorHandle** retvals, int* num_retvals);

  // Determine whether to place an op on a custom device. This method is
  // exposed as public for test only.
  Status MaybePinToCustomDevice(CustomDevice** device,
                                const ImmediateExecutionOperation& op) const;

  // Drop all registered custom devices.
  void Clear();

 private:
  // Maps fully specified device names to the owned custom device instances.
  std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_

View File

@ -138,43 +138,47 @@ TEST(CustomDevice, TestResourcePlacement) {
TF_ASSERT_OK(op.Reset("AssignVariableOp", "")); TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(custom_float_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
VariantDevice placed_device(kVariantDeviceNull); CustomDevice* placed_device = nullptr;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// MaybePinToCustomDevice has no opinion about ops which have physical // MaybePinToCustomDevice has no opinion about ops which have physical
// resource-dtype inputs. They'll get placed on physical devices. // resource-dtype inputs. They'll get placed on physical devices.
EXPECT_EQ(kVariantDeviceNull, placed_device); EXPECT_EQ(nullptr, placed_device);
op.Clear(); op.Clear();
TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str())); TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str()));
TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(custom_float_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
placed_device = kVariantDeviceNull; placed_device = nullptr;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Explicit placement onto a custom device also doesn't trigger custom device // Explicit placement onto a custom device also doesn't trigger custom device
// placement if there's a physical device resource input. // placement if there's a physical device resource input.
EXPECT_EQ(kVariantDeviceNull, placed_device); EXPECT_EQ(nullptr, placed_device);
op.Clear(); op.Clear();
TF_ASSERT_OK( TF_ASSERT_OK(
op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0")); op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0"));
TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
placed_device = kVariantDeviceNull; placed_device = nullptr;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Explicit placements typically override input-based placement onto a custom // Explicit placements typically override input-based placement onto a custom
// device. // device.
EXPECT_EQ(kVariantDeviceNull, placed_device); EXPECT_EQ(nullptr, placed_device);
op.Clear(); op.Clear();
TF_ASSERT_OK(op.Reset("AssignVariableOp", TF_ASSERT_OK(op.Reset("AssignVariableOp",
"/job:localhost/replica:0/task:0/device:CPU:0")); "/job:localhost/replica:0/task:0/device:CPU:0"));
TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
placed_device = kVariantDeviceNull; placed_device = nullptr;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Even with an explicit physical device placement, custom device resource // Even with an explicit physical device placement, custom device resource
// inputs place the op on the custom device. // inputs place the op on the custom device.
ASSERT_TRUE(absl::holds_alternative<CustomDevice*>(placed_device)); ASSERT_NE(placed_device, nullptr);
EXPECT_EQ(&custom_device, absl::get<CustomDevice*>(placed_device)); EXPECT_EQ(&custom_device, placed_device);
} }
} // namespace } // namespace

View File

@ -36,7 +36,7 @@ void EagerOperation::Clear() {
h->Unref(); h->Unref();
} }
inputs_.clear(); inputs_.clear();
inputs_are_tensor_handles_ = true; custom_device_tensor_handles_count_ = 0;
ClearInferenceState(); ClearInferenceState();
} }
@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
down_cast<ImmediateExecutionTensorHandle*>(input); down_cast<ImmediateExecutionTensorHandle*>(input);
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (CustomDeviceTensorHandle::classof(h)) { if (CustomDeviceTensorHandle::classof(h)) {
inputs_are_tensor_handles_ = false; custom_device_tensor_handles_count_++;
} }
AddTensorHandle(h); AddTensorHandle(h);
return MaybeInferSingleInputAttrs(h); return MaybeInferSingleInputAttrs(h);
@ -281,7 +281,7 @@ Status EagerOperation::AddInputList(
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here. // here.
if (CustomDeviceTensorHandle::classof(input)) { if (CustomDeviceTensorHandle::classof(input)) {
inputs_are_tensor_handles_ = false; custom_device_tensor_handles_count_++;
} }
ImmediateExecutionTensorHandle* h = ImmediateExecutionTensorHandle* h =
down_cast<ImmediateExecutionTensorHandle*>(input); down_cast<ImmediateExecutionTensorHandle*>(input);
@ -290,6 +290,25 @@ Status EagerOperation::AddInputList(
return InferInputListAttrs(inputs.size()); return InferInputListAttrs(inputs.size());
} }
Status EagerOperation::SetInput(size_t index,
ImmediateExecutionTensorHandle* input) {
if (index >= inputs_.size()) {
return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index,
inputs_.size());
}
auto* previous = inputs_[index];
if (CustomDeviceTensorHandle::classof(previous)) {
custom_device_tensor_handles_count_--;
}
if (CustomDeviceTensorHandle::classof(input)) {
custom_device_tensor_handles_count_++;
}
input->Ref();
inputs_[index] = input;
previous->Unref();
return Status::OK();
}
Status EagerOperation::Reset( Status EagerOperation::Reset(
const char* op, const char* device_name, bool remote, const char* op, const char* device_name, bool remote,
EagerExecutor* executor, EagerExecutor* executor,
@ -407,7 +426,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
Status EagerOperation::TensorHandleInputs( Status EagerOperation::TensorHandleInputs(
const absl::InlinedVector<TensorHandle*, 4>** inputs) const { const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
*inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>( *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
&inputs_); &inputs_);
return Status::OK(); return Status::OK();
@ -418,7 +437,7 @@ Status EagerOperation::TensorHandleInputs(
Status EagerOperation::MutableTensorHandleInputs( Status EagerOperation::MutableTensorHandleInputs(
absl::InlinedVector<TensorHandle*, 4>** inputs) { absl::InlinedVector<TensorHandle*, 4>** inputs) {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
*inputs = *inputs =
reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_); reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
return Status::OK(); return Status::OK();
@ -436,14 +455,7 @@ Status EagerOperation::SetDeviceName(const char* c_name) {
} }
last_set_device_name_ = name; last_set_device_name_ = name;
device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_); device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
CustomDevice* custom_device; device_ = kVariantDeviceNull;
if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) {
device_ = custom_device;
} else {
// Device placement for physical devices happens lazily in
// EagerExecute/EagerRemoteExecute, and can depend on the inputs.
device_ = kVariantDeviceNull;
}
} }
return Status::OK(); return Status::OK();
} }
@ -495,30 +507,4 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
attrs_.NumInputs(static_cast<int>(inputs_.size())); attrs_.NumInputs(static_cast<int>(inputs_.size()));
} }
Status EagerOperation::CopyOffCustomDeviceInputs() {
if (absl::holds_alternative<CustomDevice*>(device_)) {
return errors::Internal(
"Trying to copy inputs to a custom device op off a custom device.");
}
for (int i = 0; i < inputs_.size(); ++i) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(inputs_[i])) {
CustomDeviceTensorHandle* previous =
down_cast<CustomDeviceTensorHandle*>(inputs_[i]);
class Device* target_device;
if (device_ == kVariantDeviceNull) {
target_device = ctx_.HostCPU();
} else {
target_device = absl::get<class Device*>(device_);
}
TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
previous, target_device->name(), &inputs_[i]));
previous->Unref();
}
}
inputs_are_tensor_handles_ = true;
return Status::OK();
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -55,6 +55,8 @@ class EagerOperation : public ImmediateExecutionOperation {
const string& DeviceName() const override { return device_name_; } const string& DeviceName() const override { return device_name_; }
ImmediateExecutionContext* GetContext() const override { return &ctx_; }
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
return device_parsed_name_; return device_parsed_name_;
} }
@ -83,7 +85,11 @@ class EagerOperation : public ImmediateExecutionOperation {
Status AddInput(AbstractTensorHandle* input) override; Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override; Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override; absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
bool HasCustomDeviceInput() const override {
return custom_device_tensor_handles_count_ > 0;
}
Status Execute(absl::Span<AbstractTensorHandle*> retvals, Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override; int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override { return op_def_; }; const tensorflow::OpDef* OpDef() const override { return op_def_; };
@ -207,20 +213,14 @@ class EagerOperation : public ImmediateExecutionOperation {
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
const std::vector<DataType>& dtypes); const std::vector<DataType>& dtypes);
// Replaces input tensors placed on custom devices with physical device
// equivalents. Used if an op is placed on a physical device but may have
// custom device inputs.
Status CopyOffCustomDeviceInputs();
tensorflow::EagerContext& ctx_; tensorflow::EagerContext& ctx_;
const char* op_name_ = nullptr; const char* op_name_ = nullptr;
AttrBuilder attrs_; AttrBuilder attrs_;
const AttrTypeMap* attr_types_; const AttrTypeMap* attr_types_;
// Toggled to indicate whether all inputs are known to be TensorHandles and // The number of custom device TensorHandle inputs. These inputs need to be
// not another type (e.g. custom device tensor handles). Explicitly set to // processed by CustomDeviceOpHandler first.
// false when custom device TensorHandles are added. int custom_device_tensor_handles_count_ = 0;
bool inputs_are_tensor_handles_ = true;
absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_; absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
// The last device name given to SetDeviceName. // The last device name given to SetDeviceName.

View File

@ -77,11 +77,6 @@ bool IsFunction(StringPiece op_name) {
return false; return false;
} }
bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
CustomDevice* custom_device;
return ctx.FindCustomDeviceFromName(string(device_name), &custom_device);
}
Status MaybePinSmallOpsToCpu( Status MaybePinSmallOpsToCpu(
bool* result, StringPiece op_name, bool* result, StringPiece op_name,
absl::Span<ImmediateExecutionTensorHandle* const> args, absl::Span<ImmediateExecutionTensorHandle* const> args,
@ -182,70 +177,5 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
return Status::OK(); return Status::OK();
} }
Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
// Ops are placed on a custom device if there's no other explicit requested
// placement and there is only one custom device in the op
// inputs.
//
// Resource-dtype inputs take precedence over non-resource inputs and explicit
// placements; this function pins ops with a resource-dtype custom device
// input to that custom device.
CustomDevice* first = nullptr;
if (!op.Inputs().empty()) {
for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
CustomDevice* current = input->device();
if (first == nullptr) {
first = current;
} else if (first != current) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same custom device or another "
"physical device. Operation ",
op.Name(),
" has one input in custom "
"device ",
first->name(),
" and at least one input in a different custom device ",
current->name()));
}
}
}
for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
if (generic_input->DataType() == DT_RESOURCE) {
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
// There's only one custom device input, and it's a resource input, so
// we'll force-place the op on to that custom device. As with physical
// devices, this overrides any explicit placement for the op.
*device = input->device();
return Status::OK();
} else {
// Don't set a custom device if there's a physical-device resource
// input.
return Status::OK();
}
}
}
}
// Since there are no resource-dtype inputs, we'll respect explicit placements
// before considering input-based placement.
if (absl::holds_alternative<CustomDevice*>(op.Device())) {
*device = op.Device();
} else if (op.DeviceName().empty() && first != nullptr) {
// If there are non-resource inputs on a custom device we will default the
// op to that custom device, but not override an explicit op placement.
*device = first;
return Status::OK();
}
return Status::OK();
}
} // namespace eager } // namespace eager
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/stringpiece.h"
@ -27,8 +28,6 @@ bool IsColocationExempt(StringPiece op_name);
bool IsFunction(StringPiece op_name); bool IsFunction(StringPiece op_name);
bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);
// TODO(b/154234908): Unify placement logic. // TODO(b/154234908): Unify placement logic.
// TODO(b/159647422): Add C++ unit tests for placement logic. // TODO(b/159647422): Add C++ unit tests for placement logic.
@ -44,11 +43,6 @@ Status MaybePinSmallOpsToCpu(
// the device the resource is, regardless of anything else that has been // the device the resource is, regardless of anything else that has been
// specified. This is identical to the graph mode behavior. // specified. This is identical to the graph mode behavior.
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op); Status MaybePinToResourceDevice(Device** device, const EagerOperation& op);
// If all the inputs are on the same custom device, use that custom
// device. Otherwise, it is an error to have a custom device as an input.
Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);
} // namespace eager } // namespace eager
} // namespace tensorflow } // namespace tensorflow