Move custom device placement from eager/execute.cc to c_api.cc. Then it can be reused by TFRT.

PiperOrigin-RevId: 356389689
Change-Id: Ibd3df16e2a4bd0607389edbd42d01cd04d24d0aa
This commit is contained in:
Xiao Yu 2021-02-08 17:21:58 -08:00 committed by TensorFlower Gardener
parent 063eb2465f
commit 6df72f44ff
15 changed files with 358 additions and 205 deletions

View File

@ -73,9 +73,11 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:context_distributed_manager",
"//tensorflow/core/common_runtime/eager:core",
"//tensorflow/core/common_runtime/eager:custom_device",
"//tensorflow/core/common_runtime/eager:custom_device_op_handler",
"//tensorflow/core/common_runtime/eager:eager_executor",
"//tensorflow/core/common_runtime/eager:execute",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/common_runtime/eager:placement_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",

View File

@ -41,7 +41,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/placement_utils.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@ -532,7 +534,8 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
tensorflow::CustomDevice* device = nullptr;
if (!context->FindCustomDeviceFromName(device_name, &device)) {
if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
&device)) {
deallocator(data, arg);
status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@ -562,7 +565,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
status->status = context->FindDeviceFromName(device_name, &device);
tensorflow::CustomDevice* custom_device = nullptr;
if (!status->status.ok()) {
if (!context->FindCustomDeviceFromName(device_name, &custom_device)) {
if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(
device_name, &custom_device)) {
deallocator(data, len, deallocator_arg);
status->status =
tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@ -654,8 +658,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
}
TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
return tensorflow::wrap(
&(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
}
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
@ -889,11 +892,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
status->status = tensorflow::unwrap(op)->Execute(
absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
tensorflow::unwrap(retvals)),
*num_retvals),
num_retvals);
tensorflow::ImmediateExecutionOperation* unwrapped_op =
tensorflow::unwrap(op);
status->status =
unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
unwrapped_op,
reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
retvals),
num_retvals);
}
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
@ -1150,10 +1157,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
}
auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
ctx, device, device_info, device_name);
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
status->status =
context->RegisterCustomDevice(device_name, std::move(custom_device));
status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
device_name, std::move(custom_device));
}
} // extern "C"

View File

@ -38,6 +38,9 @@ limitations under the License.
namespace tensorflow {
class EagerExecutor;
class EagerContext;
class CustomDevice;
class CustomDeviceOpHandler;
// LINT.IfChange
// Note: Keep in sync with exported copy of enum in eager/c_api.h.
@ -122,6 +125,7 @@ class ImmediateExecutionContext : public AbstractContext {
// Return the ParsedName of Host CPU device.
virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
virtual const string& HostCPUName() const = 0;
// Configure soft device placement policy.
virtual void SetAllowSoftPlacement(bool enable) = 0;
@ -147,6 +151,18 @@ class ImmediateExecutionContext : public AbstractContext {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
}
//===--------------------------------------------------------------------===//
// Experimental Custom Device.
//===--------------------------------------------------------------------===//
virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
// Register a custom device. It will return error is the device name is
// already registered.
// TODO(tfrt-devs): Remove this method. Let caller register it directly into
// CustomDeviceOpHandler.
virtual Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device) = 0;
//===--------------------------------------------------------------------===//
// Following are features in current TF Eager Runtime.
// TODO(tfrt-devs): Figure out a way to deprecate following features after

View File

@ -33,6 +33,8 @@ struct TFE_Op;
namespace tensorflow {
class ImmediateExecutionContext;
// Abstract interface to an operation.
class ImmediateExecutionOperation : public AbstractOperation {
public:
@ -41,6 +43,15 @@ class ImmediateExecutionOperation : public AbstractOperation {
// Returns the inputs of this op.
virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs()
const = 0;
virtual Status SetInput(size_t index,
ImmediateExecutionTensorHandle* input) = 0;
virtual ImmediateExecutionContext* GetContext() const = 0;
// Following two methods are used to support custom device.
// Return true if the inputs contain custom device tensor handle. It means
// that the argument need to be handled by a custom device.
virtual bool HasCustomDeviceInput() const = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;

View File

@ -87,6 +87,7 @@ tf_cuda_library(
deps = [
":eager_executor",
":kernel_and_device",
":custom_device_op_handler",
":custom_device",
"@com_google_absl//absl/container:flat_hash_map",
"//tensorflow/c:tf_tensor_internal",
@ -140,6 +141,28 @@ tf_cuda_library(
}),
)
tf_cuda_library(
name = "custom_device_op_handler",
srcs = ["custom_device_op_handler.cc"],
hdrs = ["custom_device_op_handler.h"],
visibility = ["//tensorflow:internal"],
deps = [
":custom_device",
] + select({
"//tensorflow:android": [
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/core/lib/core:status",
],
}),
)
tf_cc_test(
name = "custom_device_test",
srcs = ["custom_device_test.cc"],
@ -647,6 +670,7 @@ tf_cuda_library(
":custom_device",
":attr_builder",
":eager_operation",
"//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
] + select({
"//tensorflow:android": [
@ -714,6 +738,7 @@ filegroup(
"attr_builder.h",
"context.h",
"custom_device.h",
"custom_device_op_handler.h",
"eager_executor.h",
"eager_operation.h",
"kernel_and_device.h",

View File

@ -522,7 +522,7 @@ EagerContext::~EagerContext() {
// Custom devices may have obtained references to various context components
// (executors, thread pool). It's safer to run their destructors early.
custom_devices_.clear();
custom_device_op_handler_.Clear();
ClearCachesAndThreadExecutors();
std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
@ -904,38 +904,15 @@ Status EagerContext::FindCompositeDeviceFromName(
return errors::NotFound("Unknown composite device: ", device_name);
}
bool EagerContext::FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const {
auto dev_it = custom_devices_.find(device_name);
if (dev_it == custom_devices_.end()) {
return false;
}
*dev = dev_it->second.get();
return true;
}
Status EagerContext::RegisterCustomDevice(
const string& device_name, std::unique_ptr<CustomDevice> device) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
!parsed.has_job || !parsed.has_replica || !parsed.has_task ||
!parsed.has_type || !parsed.has_id) {
return errors::InvalidArgument(
device_name,
" could not be parsed as a device name. Use the full "
"/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
"format.");
}
Device* existing_physical_device = nullptr;
if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
return errors::AlreadyExists(device_name,
" already registered as a physical device.");
}
if (!custom_devices_.emplace(device_name, std::move(device)).second) {
return errors::AlreadyExists(device_name,
" already registered as a custom device.");
}
return Status::OK();
return custom_device_op_handler_.RegisterCustomDevice(device_name,
std::move(device));
}
Status EagerContext::FindOrCreateCompositeDevice(

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/common_runtime/eager/eager_executor.h"
#include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
#include "tensorflow/core/common_runtime/function.h"
@ -204,6 +205,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
return HostCPU()->parsed_name();
}
const string& HostCPUName() const override { return HostCPU()->name(); }
GraphCollector* GetGraphCollector() { return &graph_collector_; }
EagerExecutor& Executor() override;
@ -473,11 +476,12 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
Status FindCompositeDeviceFromName(StringPiece device_name,
CompositeDevice** device) const;
bool FindCustomDeviceFromName(const string& device_name,
CustomDevice** dev) const;
Status RegisterCustomDevice(const string& name,
std::unique_ptr<CustomDevice> device);
std::unique_ptr<CustomDevice> device) override;
CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
return custom_device_op_handler_;
};
// Find or create a composite device with the given `underlying_devices` and
// `device_name` (if not empty).
@ -587,7 +591,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
TF_GUARDED_BY(device_type_list_mu_);
Rendezvous* rendezvous_;
std::function<Rendezvous*(const int64)> rendezvous_creator_;
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
CustomDeviceOpHandler custom_device_op_handler_;
mutable mutex composite_devices_mu_;
// Maps from the fingerprint of a set of device names to a virtual

View File

@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
*status = this->FindDeviceFromName(device_name, &device);
if (!status->ok()) {
tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(device_name, &dev)) {
if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) {
*status = dev->CopyTensorToDevice(handle, &result);
if (status->ok()) {
return result;
@ -128,7 +128,8 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
return nullptr;
}
tensorflow::CustomDevice* dev;
if (this->FindCustomDeviceFromName(handle_device_name, &dev)) {
if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name,
&dev)) {
*status = dev->CopyTensorFromDevice(handle, device_name, &result);
if (status->ok()) {
return result;
@ -202,28 +203,8 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
}
}
// Decide to either run the operation on a custom device or copy off all of
// the custom device inputs.
VariantDevice maybe_custom_device = Device();
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) ||
!inputs_are_tensor_handles_) {
// If the op wasn't placed on a custom device explicitly and there are no
// non-TensorHandle inputs, the op will definitely be placed on a physical
// device. Otherwise we need to check the inputs one by one.
TF_RETURN_IF_ERROR(
eager::MaybePinToCustomDevice(&maybe_custom_device, *this));
if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) {
ImmediateExecutionTensorHandle** retval_array =
reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data());
return absl::get<CustomDevice*>(maybe_custom_device)
->Execute(this, retval_array, num_retvals);
} else {
TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs());
}
}
// Run eager placement logic.
class Device* device = absl::get<class Device*>(maybe_custom_device);
class Device* device = absl::get<class Device*>(Device());
if (device == nullptr) {
TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
}

View File

@ -0,0 +1,167 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }
Status CustomDeviceOpHandler::RegisterCustomDevice(
const string& device_name, std::unique_ptr<CustomDevice> device) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
!parsed.has_job || !parsed.has_replica || !parsed.has_task ||
!parsed.has_type || !parsed.has_id) {
return errors::InvalidArgument(
device_name,
" could not be parsed as a device name. Use the full "
"/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
"format.");
}
if (!custom_devices_.emplace(device_name, std::move(device)).second) {
return errors::AlreadyExists(device_name,
" already registered as a custom device.");
}
return Status::OK();
}
bool CustomDeviceOpHandler::FindCustomDeviceFromName(
const string& name, CustomDevice** device) const {
auto dev_it = custom_devices_.find(name);
if (dev_it == custom_devices_.end()) {
return false;
}
*device = dev_it->second.get();
return true;
}
Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
ImmediateExecutionTensorHandle** retvals,
int* num_retvals) {
tensorflow::CustomDevice* custom_device = nullptr;
TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));
if (custom_device != nullptr) {
return custom_device->Execute(op, retvals, num_retvals);
}
// The op will be placed on physical device. However, it contains custom
// device tensor handles. The tensor handles will be copy to physical device
// first.
if (op->HasCustomDeviceInput()) {
auto inputs = op->GetInputs();
for (int i = 0; i < inputs.size(); ++i) {
auto target_device = op->DeviceName();
if (target_device.empty()) {
target_device = op->GetContext()->HostCPUName();
}
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
tensorflow::CustomDeviceTensorHandle* previous =
tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
inputs[i]);
tensorflow::ImmediateExecutionTensorHandle* new_tesnor;
TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
previous, target_device, &new_tesnor));
Status s = op->SetInput(i, new_tesnor);
new_tesnor->Unref();
TF_RETURN_IF_ERROR(s);
}
}
}
return op->Execute(
absl::MakeSpan(
reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
*num_retvals),
num_retvals);
}
Status CustomDeviceOpHandler::MaybePinToCustomDevice(
CustomDevice** device, const ImmediateExecutionOperation& op) const {
*device = nullptr;
if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
!op.HasCustomDeviceInput()) {
return Status::OK();
}
// Ops are placed on a custom device if there's no other explicit requested
// placement and there is only one custom device in the op
// inputs.
//
// Resource-dtype inputs take precedence over non-resource inputs and explicit
// placements; this function pins ops with a resource-dtype custom device
// input to that custom device.
CustomDevice* first = nullptr;
if (!op.GetInputs().empty()) {
for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
CustomDevice* current = input->device();
if (first == nullptr) {
first = current;
} else if (first != current) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same custom device or another "
"physical device. Operation ",
op.Name(),
" has one input in custom "
"device ",
first->name(),
" and at least one input in a different custom device ",
current->name()));
}
}
}
for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
if (generic_input->DataType() == DT_RESOURCE) {
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
// There's only one custom device input, and it's a resource input, so
// we'll force-place the op on to that custom device. As with physical
// devices, this overrides any explicit placement for the op.
*device = input->device();
return Status::OK();
} else {
// Don't set a custom device if there's a physical-device resource
// input.
return Status::OK();
}
}
}
}
// Since there are no resource-dtype inputs, we'll respect explicit placements
// before considering input-based placement.
if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
// If there are non-resource inputs on a custom device we will default the
// op to that custom device, but not override an explicit op placement.
*device = first;
return Status::OK();
}
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,51 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/custom_device.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT.
class CustomDeviceOpHandler {
public:
~CustomDeviceOpHandler() {}
// Register a new custom device.
Status RegisterCustomDevice(const string& device_name,
std::unique_ptr<CustomDevice> device);
// Find the custom device from given name. Return true if it finds one.
bool FindCustomDeviceFromName(const string& name,
CustomDevice** device) const;
Status Execute(ImmediateExecutionOperation* op,
ImmediateExecutionTensorHandle** retvals, int* num_retvals);
// Determine whether to place an op on a custom device. This method is
// exposed as public for test only.
Status MaybePinToCustomDevice(CustomDevice** device,
const ImmediateExecutionOperation& op) const;
void Clear();
private:
std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_

View File

@ -138,43 +138,47 @@ TEST(CustomDevice, TestResourcePlacement) {
TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
VariantDevice placed_device(kVariantDeviceNull);
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
CustomDevice* placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// MaybePinToCustomDevice has no opinion about ops which have physical
// resource-dtype inputs. They'll get placed on physical devices.
EXPECT_EQ(kVariantDeviceNull, placed_device);
EXPECT_EQ(nullptr, placed_device);
op.Clear();
TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str()));
TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
placed_device = kVariantDeviceNull;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Explicit placement onto a custom device also doesn't trigger custom device
// placement if there's a physical device resource input.
EXPECT_EQ(kVariantDeviceNull, placed_device);
EXPECT_EQ(nullptr, placed_device);
op.Clear();
TF_ASSERT_OK(
op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0"));
TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
placed_device = kVariantDeviceNull;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Explicit placements typically override input-based placement onto a custom
// device.
EXPECT_EQ(kVariantDeviceNull, placed_device);
EXPECT_EQ(nullptr, placed_device);
op.Clear();
TF_ASSERT_OK(op.Reset("AssignVariableOp",
"/job:localhost/replica:0/task:0/device:CPU:0"));
TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get()));
TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
placed_device = kVariantDeviceNull;
TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
placed_device = nullptr;
TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
&placed_device, op));
// Even with an explicit physical device placement, custom device resource
// inputs place the op on the custom device.
ASSERT_TRUE(absl::holds_alternative<CustomDevice*>(placed_device));
EXPECT_EQ(&custom_device, absl::get<CustomDevice*>(placed_device));
ASSERT_NE(placed_device, nullptr);
EXPECT_EQ(&custom_device, placed_device);
}
} // namespace

View File

@ -36,7 +36,7 @@ void EagerOperation::Clear() {
h->Unref();
}
inputs_.clear();
inputs_are_tensor_handles_ = true;
custom_device_tensor_handles_count_ = 0;
ClearInferenceState();
}
@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
down_cast<ImmediateExecutionTensorHandle*>(input);
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
if (CustomDeviceTensorHandle::classof(h)) {
inputs_are_tensor_handles_ = false;
custom_device_tensor_handles_count_++;
}
AddTensorHandle(h);
return MaybeInferSingleInputAttrs(h);
@ -281,7 +281,7 @@ Status EagerOperation::AddInputList(
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(input)) {
inputs_are_tensor_handles_ = false;
custom_device_tensor_handles_count_++;
}
ImmediateExecutionTensorHandle* h =
down_cast<ImmediateExecutionTensorHandle*>(input);
@ -290,6 +290,25 @@ Status EagerOperation::AddInputList(
return InferInputListAttrs(inputs.size());
}
Status EagerOperation::SetInput(size_t index,
ImmediateExecutionTensorHandle* input) {
if (index >= inputs_.size()) {
return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index,
inputs_.size());
}
auto* previous = inputs_[index];
if (CustomDeviceTensorHandle::classof(previous)) {
custom_device_tensor_handles_count_--;
}
if (CustomDeviceTensorHandle::classof(input)) {
custom_device_tensor_handles_count_++;
}
input->Ref();
inputs_[index] = input;
previous->Unref();
return Status::OK();
}
Status EagerOperation::Reset(
const char* op, const char* device_name, bool remote,
EagerExecutor* executor,
@ -407,7 +426,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
Status EagerOperation::TensorHandleInputs(
const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
*inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
&inputs_);
return Status::OK();
@ -418,7 +437,7 @@ Status EagerOperation::TensorHandleInputs(
Status EagerOperation::MutableTensorHandleInputs(
absl::InlinedVector<TensorHandle*, 4>** inputs) {
if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
*inputs =
reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
return Status::OK();
@ -436,14 +455,7 @@ Status EagerOperation::SetDeviceName(const char* c_name) {
}
last_set_device_name_ = name;
device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
CustomDevice* custom_device;
if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) {
device_ = custom_device;
} else {
// Device placement for physical devices happens lazily in
// EagerExecute/EagerRemoteExecute, and can depend on the inputs.
device_ = kVariantDeviceNull;
}
device_ = kVariantDeviceNull;
}
return Status::OK();
}
@ -495,30 +507,4 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
attrs_.NumInputs(static_cast<int>(inputs_.size()));
}
Status EagerOperation::CopyOffCustomDeviceInputs() {
if (absl::holds_alternative<CustomDevice*>(device_)) {
return errors::Internal(
"Trying to copy inputs to a custom device op off a custom device.");
}
for (int i = 0; i < inputs_.size(); ++i) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(inputs_[i])) {
CustomDeviceTensorHandle* previous =
down_cast<CustomDeviceTensorHandle*>(inputs_[i]);
class Device* target_device;
if (device_ == kVariantDeviceNull) {
target_device = ctx_.HostCPU();
} else {
target_device = absl::get<class Device*>(device_);
}
TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
previous, target_device->name(), &inputs_[i]));
previous->Unref();
}
}
inputs_are_tensor_handles_ = true;
return Status::OK();
}
} // namespace tensorflow

View File

@ -55,6 +55,8 @@ class EagerOperation : public ImmediateExecutionOperation {
const string& DeviceName() const override { return device_name_; }
ImmediateExecutionContext* GetContext() const override { return &ctx_; }
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
return device_parsed_name_;
}
@ -83,7 +85,11 @@ class EagerOperation : public ImmediateExecutionOperation {
Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
bool HasCustomDeviceInput() const override {
return custom_device_tensor_handles_count_ > 0;
}
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
const tensorflow::OpDef* OpDef() const override { return op_def_; };
@ -207,20 +213,14 @@ class EagerOperation : public ImmediateExecutionOperation {
void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
const std::vector<DataType>& dtypes);
// Replaces input tensors placed on custom devices with physical device
// equivalents. Used if an op is placed on a physical device but may have
// custom device inputs.
Status CopyOffCustomDeviceInputs();
tensorflow::EagerContext& ctx_;
const char* op_name_ = nullptr;
AttrBuilder attrs_;
const AttrTypeMap* attr_types_;
// Toggled to indicate whether all inputs are known to be TensorHandles and
// not another type (e.g. custom device tensor handles). Explicitly set to
// false when custom device TensorHandles are added.
bool inputs_are_tensor_handles_ = true;
// The number of custom device TensorHandle inputs. These inputs need to be
// processed by CustomDeviceOpHandler first.
int custom_device_tensor_handles_count_ = 0;
absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
// The last device name given to SetDeviceName.

View File

@ -77,11 +77,6 @@ bool IsFunction(StringPiece op_name) {
return false;
}
bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
CustomDevice* custom_device;
return ctx.FindCustomDeviceFromName(string(device_name), &custom_device);
}
Status MaybePinSmallOpsToCpu(
bool* result, StringPiece op_name,
absl::Span<ImmediateExecutionTensorHandle* const> args,
@ -182,70 +177,5 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
return Status::OK();
}
Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
// Ops are placed on a custom device if there's no other explicit requested
// placement and there is only one custom device in the op
// inputs.
//
// Resource-dtype inputs take precedence over non-resource inputs and explicit
// placements; this function pins ops with a resource-dtype custom device
// input to that custom device.
CustomDevice* first = nullptr;
if (!op.Inputs().empty()) {
for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
// TODO(b/175427838): It would be nice to be able to use tensorflow::isa
// here.
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
CustomDevice* current = input->device();
if (first == nullptr) {
first = current;
} else if (first != current) {
return errors::InvalidArgument(absl::StrCat(
"If an operation has one of its inputs in a custom device, then "
"all inputs should be on that same custom device or another "
"physical device. Operation ",
op.Name(),
" has one input in custom "
"device ",
first->name(),
" and at least one input in a different custom device ",
current->name()));
}
}
}
for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
if (generic_input->DataType() == DT_RESOURCE) {
if (CustomDeviceTensorHandle::classof(generic_input)) {
const CustomDeviceTensorHandle* input =
down_cast<const CustomDeviceTensorHandle*>(generic_input);
// There's only one custom device input, and it's a resource input, so
// we'll force-place the op on to that custom device. As with physical
// devices, this overrides any explicit placement for the op.
*device = input->device();
return Status::OK();
} else {
// Don't set a custom device if there's a physical-device resource
// input.
return Status::OK();
}
}
}
}
// Since there are no resource-dtype inputs, we'll respect explicit placements
// before considering input-based placement.
if (absl::holds_alternative<CustomDevice*>(op.Device())) {
*device = op.Device();
} else if (op.DeviceName().empty() && first != nullptr) {
// If there are non-resource inputs on a custom device we will default the
// op to that custom device, but not override an explicit op placement.
*device = first;
return Status::OK();
}
return Status::OK();
}
} // namespace eager
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
@ -27,8 +28,6 @@ bool IsColocationExempt(StringPiece op_name);
bool IsFunction(StringPiece op_name);
bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);
// TODO(b/154234908): Unify placement logic.
// TODO(b/159647422): Add C++ unit tests for placement logic.
@ -44,11 +43,6 @@ Status MaybePinSmallOpsToCpu(
// the device the resource is, regardless of anything else that has been
// specified. This is identical to the graph mode behavior.
Status MaybePinToResourceDevice(Device** device, const EagerOperation& op);
// If all the inputs are on the same custom device, use that custom
// device. Otherwise, it is an error to have a custom device as an input.
Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);
} // namespace eager
} // namespace tensorflow