From 6df72f44ff4edefb5ff00530665d202d53a40f98 Mon Sep 17 00:00:00 2001 From: Xiao Yu <fishx@google.com> Date: Mon, 8 Feb 2021 17:21:58 -0800 Subject: [PATCH] Move custom device placement from eager/execute.cc to c_api.cc. Then it can be reused by TFRT. PiperOrigin-RevId: 356389689 Change-Id: Ibd3df16e2a4bd0607389edbd42d01cd04d24d0aa --- tensorflow/c/eager/BUILD | 2 + tensorflow/c/eager/c_api.cc | 31 ++-- .../c/eager/immediate_execution_context.h | 16 ++ .../c/eager/immediate_execution_operation.h | 11 ++ tensorflow/core/common_runtime/eager/BUILD | 25 +++ .../core/common_runtime/eager/context.cc | 29 +-- .../core/common_runtime/eager/context.h | 14 +- tensorflow/core/common_runtime/eager/core.cc | 27 +-- .../eager/custom_device_op_handler.cc | 167 ++++++++++++++++++ .../eager/custom_device_op_handler.h | 51 ++++++ .../eager/custom_device_test.cc | 30 ++-- .../common_runtime/eager/eager_operation.cc | 64 +++---- .../common_runtime/eager/eager_operation.h | 18 +- .../common_runtime/eager/placement_utils.cc | 70 -------- .../common_runtime/eager/placement_utils.h | 8 +- 15 files changed, 358 insertions(+), 205 deletions(-) create mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.cc create mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.h diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index cf135b03273..b7957e4fab2 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -73,9 +73,11 @@ tf_cuda_library( "//tensorflow/core/common_runtime/eager:context_distributed_manager", "//tensorflow/core/common_runtime/eager:core", "//tensorflow/core/common_runtime/eager:custom_device", + "//tensorflow/core/common_runtime/eager:custom_device_op_handler", "//tensorflow/core/common_runtime/eager:eager_executor", "//tensorflow/core/common_runtime/eager:execute", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/common_runtime/eager:placement_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 204db3078f4..5a31c434eaa 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -41,7 +41,9 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/attr_builder.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" #include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/placement_utils.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -532,7 +534,8 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle( tensorflow::EagerContext* context = tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::CustomDevice* device = nullptr; - if (!context->FindCustomDeviceFromName(device_name, &device)) { + if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name, + &device)) { deallocator(data, arg); status->status = tensorflow::errors::InvalidArgument(device_name, " unknown device."); @@ -562,7 +565,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { - if (!context->FindCustomDeviceFromName(device_name, &custom_device)) { + if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName( + device_name, &custom_device)) { deallocator(data, len, deallocator_arg); status->status = tensorflow::errors::InvalidArgument(device_name, " unknown device."); @@ -654,8 +658,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) { } TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) { - return tensorflow::wrap( - &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext())); + return tensorflow::wrap(tensorflow::unwrap(op)->GetContext()); } void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { @@ -889,11 +892,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { - status->status = tensorflow::unwrap(op)->Execute( - absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>( - tensorflow::unwrap(retvals)), - *num_retvals), - num_retvals); + tensorflow::ImmediateExecutionOperation* unwrapped_op = + tensorflow::unwrap(op); + + status->status = + unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute( + unwrapped_op, + reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>( + retvals), + num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, @@ -1150,10 +1157,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device, } auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>( ctx, device, device_info, device_name); - tensorflow::EagerContext* context = - tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); - status->status = - context->RegisterCustomDevice(device_name, std::move(custom_device)); + status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice( + device_name, std::move(custom_device)); } } // extern "C" diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 4faf7f24f78..6ed01411d72 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -38,6 +38,9 @@ limitations under the License. 
 namespace tensorflow {
 
 class EagerExecutor;
+class EagerContext;
+class CustomDevice;
+class CustomDeviceOpHandler;
 
 // LINT.IfChange
 // Note: Keep in sync with exported copy of enum in eager/c_api.h.
@@ -122,6 +125,7 @@ class ImmediateExecutionContext : public AbstractContext {
 
   // Return the ParsedName of Host CPU device.
   virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
+  virtual const string& HostCPUName() const = 0;
 
   // Configure soft device placement policy.
   virtual void SetAllowSoftPlacement(bool enable) = 0;
@@ -147,6 +151,18 @@ class ImmediateExecutionContext : public AbstractContext {
     return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
   }
 
+  //===--------------------------------------------------------------------===//
+  // Experimental Custom Device.
+  //===--------------------------------------------------------------------===//
+  virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
+
+  // Register a custom device. Returns an error if the device name is
+  // already registered.
+  // TODO(tfrt-devs): Remove this method. Let the caller register it directly
+  // into CustomDeviceOpHandler.
+  virtual Status RegisterCustomDevice(const string& name,
+                                      std::unique_ptr<CustomDevice> device) = 0;
+
   //===--------------------------------------------------------------------===//
   // Following are features in current TF Eager Runtime.
   // TODO(tfrt-devs): Figure out a way to deprecate following features after
diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h
index 85af5a706e1..a23177b56d5 100644
--- a/tensorflow/c/eager/immediate_execution_operation.h
+++ b/tensorflow/c/eager/immediate_execution_operation.h
@@ -33,6 +33,8 @@ struct TFE_Op;
 
 namespace tensorflow {
 
+class ImmediateExecutionContext;
+
 // Abstract interface to an operation.
 class ImmediateExecutionOperation : public AbstractOperation {
  public:
@@ -41,6 +43,15 @@ class ImmediateExecutionOperation : public AbstractOperation {
   // Returns the inputs of this op.
   virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs()
       const = 0;
+  virtual Status SetInput(size_t index,
+                          ImmediateExecutionTensorHandle* input) = 0;
+
+  virtual ImmediateExecutionContext* GetContext() const = 0;
+
+  // The following methods are used to support custom devices.
+  // Returns true if the inputs contain a custom device tensor handle, which
+  // means the op needs to be handled by a custom device.
+ virtual bool HasCustomDeviceInput() const = 0; virtual const tensorflow::OpDef* OpDef() const = 0; diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index 8549c32417a..dddfe47de6b 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -87,6 +87,7 @@ tf_cuda_library( deps = [ ":eager_executor", ":kernel_and_device", + ":custom_device_op_handler", ":custom_device", "@com_google_absl//absl/container:flat_hash_map", "//tensorflow/c:tf_tensor_internal", @@ -140,6 +141,28 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "custom_device_op_handler", + srcs = ["custom_device_op_handler.cc"], + hdrs = ["custom_device_op_handler.h"], + visibility = ["//tensorflow:internal"], + deps = [ + ":custom_device", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:portable_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/core/lib/core:status", + ], + }), +) + tf_cc_test( name = "custom_device_test", srcs = ["custom_device_test.cc"], @@ -647,6 +670,7 @@ tf_cuda_library( ":custom_device", ":attr_builder", ":eager_operation", + "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", ] + select({ "//tensorflow:android": [ @@ -714,6 +738,7 @@ filegroup( "attr_builder.h", "context.h", "custom_device.h", + "custom_device_op_handler.h", "eager_executor.h", "eager_operation.h", "kernel_and_device.h", diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 7c20766a1ce..7fe6e00928c 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -522,7 +522,7 @@ EagerContext::~EagerContext() { // Custom devices may have obtained references to various context components // (executors, thread pool). It's safer to run their destructors early. - custom_devices_.clear(); + custom_device_op_handler_.Clear(); ClearCachesAndThreadExecutors(); std::unordered_map<std::thread::id, EagerExecutor*> executors_copy; @@ -904,38 +904,15 @@ Status EagerContext::FindCompositeDeviceFromName( return errors::NotFound("Unknown composite device: ", device_name); } -bool EagerContext::FindCustomDeviceFromName(const string& device_name, - CustomDevice** dev) const { - auto dev_it = custom_devices_.find(device_name); - if (dev_it == custom_devices_.end()) { - return false; - } - *dev = dev_it->second.get(); - return true; -} - Status EagerContext::RegisterCustomDevice( const string& device_name, std::unique_ptr<CustomDevice> device) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device_name, &parsed) || - !parsed.has_job || !parsed.has_replica || !parsed.has_task || - !parsed.has_type || !parsed.has_id) { - return errors::InvalidArgument( - device_name, - " could not be parsed as a device name. 
Use the full " - "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> " - "format."); - } Device* existing_physical_device = nullptr; if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) { return errors::AlreadyExists(device_name, " already registered as a physical device."); } - if (!custom_devices_.emplace(device_name, std::move(device)).second) { - return errors::AlreadyExists(device_name, - " already registered as a custom device."); - } - return Status::OK(); + return custom_device_op_handler_.RegisterCustomDevice(device_name, + std::move(device)); } Status EagerContext::FindOrCreateCompositeDevice( diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 12a00c243d6..99cfb89e49d 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/custom_device.h" +#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h" #include "tensorflow/core/common_runtime/eager/eager_executor.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" #include "tensorflow/core/common_runtime/function.h" @@ -204,6 +205,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { return HostCPU()->parsed_name(); } + const string& HostCPUName() const override { return HostCPU()->name(); } + GraphCollector* GetGraphCollector() { return &graph_collector_; } EagerExecutor& Executor() override; @@ -473,11 +476,12 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { Status FindCompositeDeviceFromName(StringPiece device_name, CompositeDevice** device) const; - bool FindCustomDeviceFromName(const string& device_name, - CustomDevice** dev) const; - Status RegisterCustomDevice(const string& name, - std::unique_ptr<CustomDevice> device); + std::unique_ptr<CustomDevice> device) override; + + CustomDeviceOpHandler& GetCustomDeviceOpHandler() override { + return custom_device_op_handler_; + }; // Find or create a composite device with the given `underlying_devices` and // `device_name` (if not empty). 
@@ -587,7 +591,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { TF_GUARDED_BY(device_type_list_mu_); Rendezvous* rendezvous_; std::function<Rendezvous*(const int64)> rendezvous_creator_; - std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_; + CustomDeviceOpHandler custom_device_op_handler_; mutable mutex composite_devices_mu_; // Maps from the fingerprint of a set of device names to a virtual diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index 81b1e3594f2..905b1d94dad 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( *status = this->FindDeviceFromName(device_name, &device); if (!status->ok()) { tensorflow::CustomDevice* dev; - if (this->FindCustomDeviceFromName(device_name, &dev)) { + if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) { *status = dev->CopyTensorToDevice(handle, &result); if (status->ok()) { return result; @@ -128,7 +128,8 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( return nullptr; } tensorflow::CustomDevice* dev; - if (this->FindCustomDeviceFromName(handle_device_name, &dev)) { + if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name, + &dev)) { *status = dev->CopyTensorFromDevice(handle, device_name, &result); if (status->ok()) { return result; @@ -202,28 +203,8 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals, } } - // Decide to either run the operation on a custom device or copy off all of - // the custom device inputs. - VariantDevice maybe_custom_device = Device(); - if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) || - !inputs_are_tensor_handles_) { - // If the op wasn't placed on a custom device explicitly and there are no - // non-TensorHandle inputs, the op will definitely be placed on a physical - // device. Otherwise we need to check the inputs one by one. - TF_RETURN_IF_ERROR( - eager::MaybePinToCustomDevice(&maybe_custom_device, *this)); - if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) { - ImmediateExecutionTensorHandle** retval_array = - reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data()); - return absl::get<CustomDevice*>(maybe_custom_device) - ->Execute(this, retval_array, num_retvals); - } else { - TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs()); - } - } - // Run eager placement logic. - class Device* device = absl::get<class Device*>(maybe_custom_device); + class Device* device = absl::get<class Device*>(Device()); if (device == nullptr) { TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this)); } diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc new file mode 100644 index 00000000000..28c7a1161f6 --- /dev/null +++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc @@ -0,0 +1,167 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
+
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+
+void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }
+
+Status CustomDeviceOpHandler::RegisterCustomDevice(
+    const string& device_name, std::unique_ptr<CustomDevice> device) {
+  DeviceNameUtils::ParsedName parsed;
+  if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
+      !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
+      !parsed.has_type || !parsed.has_id) {
+    return errors::InvalidArgument(
+        device_name,
+        " could not be parsed as a device name. Use the full "
+        "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
+        "format.");
+  }
+
+  if (!custom_devices_.emplace(device_name, std::move(device)).second) {
+    return errors::AlreadyExists(device_name,
+                                 " already registered as a custom device.");
+  }
+  return Status::OK();
+}
+
+bool CustomDeviceOpHandler::FindCustomDeviceFromName(
+    const string& name, CustomDevice** device) const {
+  auto dev_it = custom_devices_.find(name);
+  if (dev_it == custom_devices_.end()) {
+    return false;
+  }
+  *device = dev_it->second.get();
+  return true;
+}
+
+Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
+                                      ImmediateExecutionTensorHandle** retvals,
+                                      int* num_retvals) {
+  tensorflow::CustomDevice* custom_device = nullptr;
+
+  TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));
+
+  if (custom_device != nullptr) {
+    return custom_device->Execute(op, retvals, num_retvals);
+  }
+
+  // The op will be placed on a physical device. However, it may contain
+  // custom device tensor handles; those handles will be copied to a physical
+  // device first.
+  if (op->HasCustomDeviceInput()) {
+    auto inputs = op->GetInputs();
+    for (int i = 0; i < inputs.size(); ++i) {
+      auto target_device = op->DeviceName();
+      if (target_device.empty()) {
+        target_device = op->GetContext()->HostCPUName();
+      }
+      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
+      // here.
+      if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
+        tensorflow::CustomDeviceTensorHandle* previous =
+            tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
+                inputs[i]);
+        tensorflow::ImmediateExecutionTensorHandle* new_tensor;
+        TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
+            previous, target_device, &new_tensor));
+        Status s = op->SetInput(i, new_tensor);
+        new_tensor->Unref();
+        TF_RETURN_IF_ERROR(s);
+      }
+    }
+  }
+
+  return op->Execute(
+      absl::MakeSpan(
+          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
+          *num_retvals),
+      num_retvals);
+}
+
+Status CustomDeviceOpHandler::MaybePinToCustomDevice(
+    CustomDevice** device, const ImmediateExecutionOperation& op) const {
+  *device = nullptr;
+  if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
+      !op.HasCustomDeviceInput()) {
+    return Status::OK();
+  }
+
+  // Ops are placed on a custom device if there's no other explicitly requested
+  // placement and there is only one custom device in the op
+  // inputs.
+  //
+  // Resource-dtype inputs take precedence over non-resource inputs and explicit
+  // placements; this function pins ops with a resource-dtype custom device
+  // input to that custom device.
+  CustomDevice* first = nullptr;
+  if (!op.GetInputs().empty()) {
+    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
+      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
+      // here.
+      if (CustomDeviceTensorHandle::classof(generic_input)) {
+        const CustomDeviceTensorHandle* input =
+            down_cast<const CustomDeviceTensorHandle*>(generic_input);
+        CustomDevice* current = input->device();
+        if (first == nullptr) {
+          first = current;
+        } else if (first != current) {
+          return errors::InvalidArgument(absl::StrCat(
+              "If an operation has one of its inputs in a custom device, then "
+              "all inputs should be on that same custom device or another "
+              "physical device. Operation ",
+              op.Name(),
+              " has one input in custom "
+              "device ",
+              first->name(),
+              " and at least one input in a different custom device ",
+              current->name()));
+        }
+      }
+    }
+    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
+      if (generic_input->DataType() == DT_RESOURCE) {
+        if (CustomDeviceTensorHandle::classof(generic_input)) {
+          const CustomDeviceTensorHandle* input =
+              down_cast<const CustomDeviceTensorHandle*>(generic_input);
+          // There's only one custom device input, and it's a resource input, so
+          // we'll force-place the op on to that custom device. As with physical
+          // devices, this overrides any explicit placement for the op.
+          *device = input->device();
+          return Status::OK();
+        } else {
+          // Don't set a custom device if there's a physical-device resource
+          // input.
+          return Status::OK();
+        }
+      }
+    }
+  }
+  // Since there are no resource-dtype inputs, we'll respect explicit placements
+  // before considering input-based placement.
+  if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
+    // If there are non-resource inputs on a custom device we will default the
+    // op to that custom device, but not override an explicit op placement.
+    *device = first;
+    return Status::OK();
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.h b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h
new file mode 100644
index 00000000000..00ac5f324ba
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h
@@ -0,0 +1,51 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
+
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/common_runtime/eager/custom_device.h"
+#include "tensorflow/core/lib/core/status.h"
+namespace tensorflow {
+
+// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT.
+class CustomDeviceOpHandler {
+ public:
+  ~CustomDeviceOpHandler() {}
+  // Register a new custom device.
+  Status RegisterCustomDevice(const string& device_name,
+                              std::unique_ptr<CustomDevice> device);
+
+  // Find the custom device from the given name. Returns true if one is found.
+  bool FindCustomDeviceFromName(const string& name,
+                                CustomDevice** device) const;
+
+  Status Execute(ImmediateExecutionOperation* op,
+                 ImmediateExecutionTensorHandle** retvals, int* num_retvals);
+
+  // Determine whether to place an op on a custom device. This method is
+  // exposed as public for testing only.
+  Status MaybePinToCustomDevice(CustomDevice** device,
+                                const ImmediateExecutionOperation& op) const;
+
+  void Clear();
+
+ private:
+  std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
+};
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc
index a642a816c76..cd7340e8463 100644
--- a/tensorflow/core/common_runtime/eager/custom_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc
@@ -138,43 +138,47 @@ TEST(CustomDevice, TestResourcePlacement) {
   TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
   TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
   TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
-  VariantDevice placed_device(kVariantDeviceNull);
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  CustomDevice* placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // MaybePinToCustomDevice has no opinion about ops which have physical
   // resource-dtype inputs. They'll get placed on physical devices.
- EXPECT_EQ(kVariantDeviceNull, placed_device); + EXPECT_EQ(nullptr, placed_device); op.Clear(); TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str())); TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(custom_float_tensor.get())); - placed_device = kVariantDeviceNull; - TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); + placed_device = nullptr; + TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( + &placed_device, op)); // Explicit placement onto a custom device also doesn't trigger custom device // placement if there's a physical device resource input. - EXPECT_EQ(kVariantDeviceNull, placed_device); + EXPECT_EQ(nullptr, placed_device); op.Clear(); TF_ASSERT_OK( op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0")); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); - placed_device = kVariantDeviceNull; - TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); + placed_device = nullptr; + TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( + &placed_device, op)); // Explicit placements typically override input-based placement onto a custom // device. - EXPECT_EQ(kVariantDeviceNull, placed_device); + EXPECT_EQ(nullptr, placed_device); op.Clear(); TF_ASSERT_OK(op.Reset("AssignVariableOp", "/job:localhost/replica:0/task:0/device:CPU:0")); TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get())); TF_ASSERT_OK(op.AddInput(physical_float_tensor.get())); - placed_device = kVariantDeviceNull; - TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op)); + placed_device = nullptr; + TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice( + &placed_device, op)); // Even with an explicit physical device placement, custom device resource // inputs place the op on the custom device. - ASSERT_TRUE(absl::holds_alternative<CustomDevice*>(placed_device)); - EXPECT_EQ(&custom_device, absl::get<CustomDevice*>(placed_device)); + ASSERT_NE(placed_device, nullptr); + EXPECT_EQ(&custom_device, placed_device); } } // namespace diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 883e9a8a8b0..de4a4495e87 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -36,7 +36,7 @@ void EagerOperation::Clear() { h->Unref(); } inputs_.clear(); - inputs_are_tensor_handles_ = true; + custom_device_tensor_handles_count_ = 0; ClearInferenceState(); } @@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) { down_cast<ImmediateExecutionTensorHandle*>(input); // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. if (CustomDeviceTensorHandle::classof(h)) { - inputs_are_tensor_handles_ = false; + custom_device_tensor_handles_count_++; } AddTensorHandle(h); return MaybeInferSingleInputAttrs(h); @@ -281,7 +281,7 @@ Status EagerOperation::AddInputList( // TODO(b/175427838): It would be nice to be able to use tensorflow::isa // here. 
if (CustomDeviceTensorHandle::classof(input)) { - inputs_are_tensor_handles_ = false; + custom_device_tensor_handles_count_++; } ImmediateExecutionTensorHandle* h = down_cast<ImmediateExecutionTensorHandle*>(input); @@ -290,6 +290,25 @@ Status EagerOperation::AddInputList( return InferInputListAttrs(inputs.size()); } +Status EagerOperation::SetInput(size_t index, + ImmediateExecutionTensorHandle* input) { + if (index >= inputs_.size()) { + return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index, + inputs_.size()); + } + auto* previous = inputs_[index]; + if (CustomDeviceTensorHandle::classof(previous)) { + custom_device_tensor_handles_count_--; + } + if (CustomDeviceTensorHandle::classof(input)) { + custom_device_tensor_handles_count_++; + } + input->Ref(); + inputs_[index] = input; + previous->Unref(); + return Status::OK(); +} + Status EagerOperation::Reset( const char* op, const char* device_name, bool remote, EagerExecutor* executor, @@ -407,7 +426,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { Status EagerOperation::TensorHandleInputs( const absl::InlinedVector<TensorHandle*, 4>** inputs) const { - if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { + if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>( &inputs_); return Status::OK(); @@ -418,7 +437,7 @@ Status EagerOperation::TensorHandleInputs( Status EagerOperation::MutableTensorHandleInputs( absl::InlinedVector<TensorHandle*, 4>** inputs) { - if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) { + if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { *inputs = reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_); return Status::OK(); @@ -436,14 +455,7 @@ Status EagerOperation::SetDeviceName(const char* c_name) { } last_set_device_name_ = name; device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_); - CustomDevice* custom_device; - if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) { - device_ = custom_device; - } else { - // Device placement for physical devices happens lazily in - // EagerExecute/EagerRemoteExecute, and can depend on the inputs. - device_ = kVariantDeviceNull; - } + device_ = kVariantDeviceNull; } return Status::OK(); } @@ -495,30 +507,4 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) { attrs_.NumInputs(static_cast<int>(inputs_.size())); } -Status EagerOperation::CopyOffCustomDeviceInputs() { - if (absl::holds_alternative<CustomDevice*>(device_)) { - return errors::Internal( - "Trying to copy inputs to a custom device op off a custom device."); - } - for (int i = 0; i < inputs_.size(); ++i) { - // TODO(b/175427838): It would be nice to be able to use tensorflow::isa - // here. 
- if (CustomDeviceTensorHandle::classof(inputs_[i])) { - CustomDeviceTensorHandle* previous = - down_cast<CustomDeviceTensorHandle*>(inputs_[i]); - class Device* target_device; - if (device_ == kVariantDeviceNull) { - target_device = ctx_.HostCPU(); - } else { - target_device = absl::get<class Device*>(device_); - } - TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice( - previous, target_device->name(), &inputs_[i])); - previous->Unref(); - } - } - inputs_are_tensor_handles_ = true; - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index e440a4a79dd..e1cb20b7575 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -55,6 +55,8 @@ class EagerOperation : public ImmediateExecutionOperation { const string& DeviceName() const override { return device_name_; } + ImmediateExecutionContext* GetContext() const override { return &ctx_; } + const DeviceNameUtils::ParsedName& GetDeviceParsedName() const { return device_parsed_name_; } @@ -83,7 +85,11 @@ class EagerOperation : public ImmediateExecutionOperation { Status AddInput(AbstractTensorHandle* input) override; Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override; + Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override; absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override; + bool HasCustomDeviceInput() const override { + return custom_device_tensor_handles_count_ > 0; + } Status Execute(absl::Span<AbstractTensorHandle*> retvals, int* num_retvals) override; const tensorflow::OpDef* OpDef() const override { return op_def_; }; @@ -207,20 +213,14 @@ class EagerOperation : public ImmediateExecutionOperation { void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def, const std::vector<DataType>& dtypes); - // Replaces input tensors placed on custom devices with physical device - // equivalents. Used if an op is placed on a physical device but may have - // custom device inputs. - Status CopyOffCustomDeviceInputs(); - tensorflow::EagerContext& ctx_; const char* op_name_ = nullptr; AttrBuilder attrs_; const AttrTypeMap* attr_types_; - // Toggled to indicate whether all inputs are known to be TensorHandles and - // not another type (e.g. custom device tensor handles). Explicitly set to - // false when custom device TensorHandles are added. - bool inputs_are_tensor_handles_ = true; + // The number of custom device TensorHandle inputs. These inputs need to be + // processed by CustomDeviceOpHandler first. + int custom_device_tensor_handles_count_ = 0; absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_; // The last device name given to SetDeviceName. 
diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc index 77514d67e3a..3b9fa7bb2d1 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils.cc @@ -77,11 +77,6 @@ bool IsFunction(StringPiece op_name) { return false; } -bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) { - CustomDevice* custom_device; - return ctx.FindCustomDeviceFromName(string(device_name), &custom_device); -} - Status MaybePinSmallOpsToCpu( bool* result, StringPiece op_name, absl::Span<ImmediateExecutionTensorHandle* const> args, @@ -182,70 +177,5 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) { return Status::OK(); } -Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) { - // Ops are placed on a custom device if there's no other explicit requested - // placement and there is only one custom device in the op - // inputs. - // - // Resource-dtype inputs take precedence over non-resource inputs and explicit - // placements; this function pins ops with a resource-dtype custom device - // input to that custom device. - CustomDevice* first = nullptr; - if (!op.Inputs().empty()) { - for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) { - // TODO(b/175427838): It would be nice to be able to use tensorflow::isa - // here. - if (CustomDeviceTensorHandle::classof(generic_input)) { - const CustomDeviceTensorHandle* input = - down_cast<const CustomDeviceTensorHandle*>(generic_input); - CustomDevice* current = input->device(); - if (first == nullptr) { - first = current; - } else if (first != current) { - return errors::InvalidArgument(absl::StrCat( - "If an operation has one of its inputs in a custom device, then " - "all inputs should be on that same custom device or another " - "physical device. Operation ", - op.Name(), - " has one input in custom " - "device ", - first->name(), - " and at least one input in a different custom device ", - current->name())); - } - } - } - for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) { - if (generic_input->DataType() == DT_RESOURCE) { - if (CustomDeviceTensorHandle::classof(generic_input)) { - const CustomDeviceTensorHandle* input = - down_cast<const CustomDeviceTensorHandle*>(generic_input); - // There's only one custom device input, and it's a resource input, so - // we'll force-place the op on to that custom device. As with physical - // devices, this overrides any explicit placement for the op. - *device = input->device(); - return Status::OK(); - } else { - // Don't set a custom device if there's a physical-device resource - // input. - return Status::OK(); - } - } - } - } - // Since there are no resource-dtype inputs, we'll respect explicit placements - // before considering input-based placement. - if (absl::holds_alternative<CustomDevice*>(op.Device())) { - *device = op.Device(); - } else if (op.DeviceName().empty() && first != nullptr) { - // If there are non-resource inputs on a custom device we will default the - // op to that custom device, but not override an explicit op placement. 
- *device = first; - return Status::OK(); - } - - return Status::OK(); -} - } // namespace eager } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/placement_utils.h b/tensorflow/core/common_runtime/eager/placement_utils.h index 7676fe01b43..9435f9848d3 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.h +++ b/tensorflow/core/common_runtime/eager/placement_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_ +#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" @@ -27,8 +28,6 @@ bool IsColocationExempt(StringPiece op_name); bool IsFunction(StringPiece op_name); -bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx); - // TODO(b/154234908): Unify placement logic. // TODO(b/159647422): Add C++ unit tests for placement logic. @@ -44,11 +43,6 @@ Status MaybePinSmallOpsToCpu( // the device the resource is, regardless of anything else that has been // specified. This is identical to the graph mode behavior. Status MaybePinToResourceDevice(Device** device, const EagerOperation& op); - -// If all the inputs are on the same custom device, use that custom -// device. Otherwise, it is an error to have a custom device as an input. -Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op); - } // namespace eager } // namespace tensorflow
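
Usage note (illustrative sketch, not part of the patch): the snippet below shows how a runtime that implements ImmediateExecutionContext, such as TFRT, could reuse the relocated placement logic now that TFE_Execute routes through CustomDeviceOpHandler::Execute. The wrapper function ExecuteWithCustomDeviceSupport is hypothetical; only the handler interface introduced in the diff above is assumed.

#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"

namespace tensorflow {

// Hypothetical helper: route an op through the custom-device placement rules
// before falling back to the regular per-runtime execution path.
Status ExecuteWithCustomDeviceSupport(ImmediateExecutionOperation* op,
                                      ImmediateExecutionTensorHandle** retvals,
                                      int* num_retvals) {
  // The handler lives on the context, so EagerContext (and a future TFRT
  // context) can expose the same entry point that TFE_Execute now calls.
  CustomDeviceOpHandler& handler =
      op->GetContext()->GetCustomDeviceOpHandler();
  // Execute() pins the op to a custom device when the placement rules apply,
  // copies any custom-device tensor handle inputs onto a physical device
  // otherwise, and then delegates to op->Execute().
  return handler.Execute(op, retvals, num_retvals);
}

}  // namespace tensorflow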