From 6df72f44ff4edefb5ff00530665d202d53a40f98 Mon Sep 17 00:00:00 2001
From: Xiao Yu <fishx@google.com>
Date: Mon, 8 Feb 2021 17:21:58 -0800
Subject: [PATCH] Move custom device placement from eager/execute.cc to
 c_api.cc so that it can be reused by TFRT.

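This change introduces CustomDeviceOpHandler, which owns the custom device
registry and the custom device placement logic that previously lived in
EagerContext and eager/placement_utils.cc. TFE_Execute now routes through the
handler exposed on ImmediateExecutionContext, so another runtime (e.g. TFRT)
can reuse the same logic without depending on EagerContext.

A minimal sketch of the call pattern after this change, using only the
interfaces added in this patch (error handling elided; `ctx`, `op`, `retvals`
and `num_retvals` assumed to be set up by the caller):

    // Registration is forwarded to the handler by the context.
    TF_RETURN_IF_ERROR(
        ctx->RegisterCustomDevice(device_name, std::move(custom_device)));

    // Execute() pins the op to a custom device when the placement rules say
    // so; otherwise it copies any custom-device inputs to a physical device
    // and falls through to op->Execute().
    tensorflow::CustomDeviceOpHandler& handler = ctx->GetCustomDeviceOpHandler();
    TF_RETURN_IF_ERROR(handler.Execute(op, retvals, &num_retvals));
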
PiperOrigin-RevId: 356389689
Change-Id: Ibd3df16e2a4bd0607389edbd42d01cd04d24d0aa
---
 tensorflow/c/eager/BUILD                      |   2 +
 tensorflow/c/eager/c_api.cc                   |  31 ++--
 .../c/eager/immediate_execution_context.h     |  16 ++
 .../c/eager/immediate_execution_operation.h   |  11 ++
 tensorflow/core/common_runtime/eager/BUILD    |  25 +++
 .../core/common_runtime/eager/context.cc      |  29 +--
 .../core/common_runtime/eager/context.h       |  14 +-
 tensorflow/core/common_runtime/eager/core.cc  |  27 +--
 .../eager/custom_device_op_handler.cc         | 167 ++++++++++++++++++
 .../eager/custom_device_op_handler.h          |  51 ++++++
 .../eager/custom_device_test.cc               |  30 ++--
 .../common_runtime/eager/eager_operation.cc   |  64 +++----
 .../common_runtime/eager/eager_operation.h    |  18 +-
 .../common_runtime/eager/placement_utils.cc   |  70 --------
 .../common_runtime/eager/placement_utils.h    |   8 +-
 15 files changed, 358 insertions(+), 205 deletions(-)
 create mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.cc
 create mode 100644 tensorflow/core/common_runtime/eager/custom_device_op_handler.h

diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index cf135b03273..b7957e4fab2 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -73,9 +73,11 @@ tf_cuda_library(
             "//tensorflow/core/common_runtime/eager:context_distributed_manager",
             "//tensorflow/core/common_runtime/eager:core",
             "//tensorflow/core/common_runtime/eager:custom_device",
+            "//tensorflow/core/common_runtime/eager:custom_device_op_handler",
             "//tensorflow/core/common_runtime/eager:eager_executor",
             "//tensorflow/core/common_runtime/eager:execute",
             "//tensorflow/core/common_runtime/eager:tensor_handle",
+            "//tensorflow/core/common_runtime/eager:placement_utils",
             "//tensorflow/core:core_cpu_internal",
             "//tensorflow/core:framework",
             "//tensorflow/core:framework_internal",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 204db3078f4..5a31c434eaa 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -41,7 +41,9 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/common_runtime/eager/custom_device.h"
+#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
 #include "tensorflow/core/common_runtime/eager/execute.h"
+#include "tensorflow/core/common_runtime/eager/placement_utils.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
@@ -532,7 +534,8 @@ TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
   tensorflow::EagerContext* context =
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   tensorflow::CustomDevice* device = nullptr;
-  if (!context->FindCustomDeviceFromName(device_name, &device)) {
+  if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(device_name,
+                                                                    &device)) {
     deallocator(data, arg);
     status->status =
         tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@@ -562,7 +565,8 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
   status->status = context->FindDeviceFromName(device_name, &device);
   tensorflow::CustomDevice* custom_device = nullptr;
   if (!status->status.ok()) {
-    if (!context->FindCustomDeviceFromName(device_name, &custom_device)) {
+    if (!context->GetCustomDeviceOpHandler().FindCustomDeviceFromName(
+            device_name, &custom_device)) {
       deallocator(data, len, deallocator_arg);
       status->status =
           tensorflow::errors::InvalidArgument(device_name, " unknown device.");
@@ -654,8 +658,7 @@ const char* TFE_OpGetName(const TFE_Op* op, TF_Status* status) {
 }
 
 TFE_Context* TFE_OpGetContext(const TFE_Op* op, TF_Status* status) {
-  return tensorflow::wrap(
-      &(OperationFromInterface(tensorflow::unwrap(op))->EagerContext()));
+  return tensorflow::wrap(tensorflow::unwrap(op)->GetContext());
 }
 
 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
@@ -889,11 +892,15 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
 
 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                  TF_Status* status) {
-  status->status = tensorflow::unwrap(op)->Execute(
-      absl::MakeSpan(reinterpret_cast<tensorflow::AbstractTensorHandle**>(
-                         tensorflow::unwrap(retvals)),
-                     *num_retvals),
-      num_retvals);
+  tensorflow::ImmediateExecutionOperation* unwrapped_op =
+      tensorflow::unwrap(op);
+
+  status->status =
+      unwrapped_op->GetContext()->GetCustomDeviceOpHandler().Execute(
+          unwrapped_op,
+          reinterpret_cast<tensorflow::ImmediateExecutionTensorHandle**>(
+              retvals),
+          num_retvals);
 }
 
 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
@@ -1150,10 +1157,8 @@ void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
   }
   auto custom_device = std::make_unique<tensorflow::CustomDeviceAPI>(
       ctx, device, device_info, device_name);
-  tensorflow::EagerContext* context =
-      tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
-  status->status =
-      context->RegisterCustomDevice(device_name, std::move(custom_device));
+  status->status = tensorflow::unwrap(ctx)->RegisterCustomDevice(
+      device_name, std::move(custom_device));
 }
 
 }  // extern "C"
diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h
index 4faf7f24f78..6ed01411d72 100644
--- a/tensorflow/c/eager/immediate_execution_context.h
+++ b/tensorflow/c/eager/immediate_execution_context.h
@@ -38,6 +38,9 @@ limitations under the License.
 
 namespace tensorflow {
 class EagerExecutor;
+class EagerContext;
+class CustomDevice;
+class CustomDeviceOpHandler;
 
 // LINT.IfChange
 // Note: Keep in sync with exported copy of enum in eager/c_api.h.
@@ -122,6 +125,7 @@ class ImmediateExecutionContext : public AbstractContext {
 
   // Return the ParsedName of Host CPU device.
   virtual const DeviceNameUtils::ParsedName& HostCPUParsedName() const = 0;
+  virtual const string& HostCPUName() const = 0;
 
   // Configure soft device placement policy.
   virtual void SetAllowSoftPlacement(bool enable) = 0;
@@ -147,6 +151,18 @@ class ImmediateExecutionContext : public AbstractContext {
     return ptr->getKind() == kEager || ptr->getKind() == kTfrt;
   }
 
+  //===--------------------------------------------------------------------===//
+  // Experimental Custom Device.
+  //===--------------------------------------------------------------------===//
+  virtual CustomDeviceOpHandler& GetCustomDeviceOpHandler() = 0;
+
+  // Register a custom device. Returns an error if the device name is
+  // already registered.
+  // TODO(tfrt-devs): Remove this method. Let caller register it directly into
+  // CustomDeviceOpHandler.
+  virtual Status RegisterCustomDevice(const string& name,
+                                      std::unique_ptr<CustomDevice> device) = 0;
+
   //===--------------------------------------------------------------------===//
   // Following are features in current TF Eager Runtime.
   // TODO(tfrt-devs): Figure out a way to deprecate following features after
diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h
index 85af5a706e1..a23177b56d5 100644
--- a/tensorflow/c/eager/immediate_execution_operation.h
+++ b/tensorflow/c/eager/immediate_execution_operation.h
@@ -33,6 +33,8 @@ struct TFE_Op;
 
 namespace tensorflow {
 
+class ImmediateExecutionContext;
+
 // Abstract interface to an operation.
 class ImmediateExecutionOperation : public AbstractOperation {
  public:
@@ -41,6 +43,15 @@ class ImmediateExecutionOperation : public AbstractOperation {
   // Returns the inputs of this op.
   virtual absl::Span<ImmediateExecutionTensorHandle* const> GetInputs()
       const = 0;
+  virtual Status SetInput(size_t index,
+                          ImmediateExecutionTensorHandle* input) = 0;
+
+  virtual ImmediateExecutionContext* GetContext() const = 0;
+
+  // Used to support custom devices: returns true if the inputs contain a
+  // custom device tensor handle, which means the op needs to be handled by a
+  // custom device before execution.
+  virtual bool HasCustomDeviceInput() const = 0;
 
   virtual const tensorflow::OpDef* OpDef() const = 0;
 
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 8549c32417a..dddfe47de6b 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -87,6 +87,7 @@ tf_cuda_library(
     deps = [
         ":eager_executor",
         ":kernel_and_device",
+        ":custom_device_op_handler",
         ":custom_device",
         "@com_google_absl//absl/container:flat_hash_map",
         "//tensorflow/c:tf_tensor_internal",
@@ -140,6 +141,28 @@ tf_cuda_library(
     }),
 )
 
+tf_cuda_library(
+    name = "custom_device_op_handler",
+    srcs = ["custom_device_op_handler.cc"],
+    hdrs = ["custom_device_op_handler.h"],
+    visibility = ["//tensorflow:internal"],
+    deps = [
+        ":custom_device",
+    ] + select({
+        "//tensorflow:android": [
+            "//tensorflow/core:portable_tensorflow_lib_lite",
+        ],
+        "//conditions:default": [
+            "//tensorflow/core:framework",
+            "//tensorflow/core:lib",
+            "//tensorflow/c/eager:immediate_execution_context",
+            "//tensorflow/c/eager:immediate_execution_tensor_handle",
+            "//tensorflow/c/eager:immediate_execution_operation",
+            "//tensorflow/core/lib/core:status",
+        ],
+    }),
+)
+
 tf_cc_test(
     name = "custom_device_test",
     srcs = ["custom_device_test.cc"],
@@ -647,6 +670,7 @@ tf_cuda_library(
         ":custom_device",
         ":attr_builder",
         ":eager_operation",
+        "//tensorflow/c/eager:immediate_execution_operation",
         "//tensorflow/c/eager:immediate_execution_tensor_handle",
     ] + select({
         "//tensorflow:android": [
@@ -714,6 +738,7 @@ filegroup(
         "attr_builder.h",
         "context.h",
         "custom_device.h",
+        "custom_device_op_handler.h",
         "eager_executor.h",
         "eager_operation.h",
         "kernel_and_device.h",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 7c20766a1ce..7fe6e00928c 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -522,7 +522,7 @@ EagerContext::~EagerContext() {
 
   // Custom devices may have obtained references to various context components
   // (executors, thread pool). It's safer to run their destructors early.
-  custom_devices_.clear();
+  custom_device_op_handler_.Clear();
 
   ClearCachesAndThreadExecutors();
   std::unordered_map<std::thread::id, EagerExecutor*> executors_copy;
@@ -904,38 +904,15 @@ Status EagerContext::FindCompositeDeviceFromName(
   return errors::NotFound("Unknown composite device: ", device_name);
 }
 
-bool EagerContext::FindCustomDeviceFromName(const string& device_name,
-                                            CustomDevice** dev) const {
-  auto dev_it = custom_devices_.find(device_name);
-  if (dev_it == custom_devices_.end()) {
-    return false;
-  }
-  *dev = dev_it->second.get();
-  return true;
-}
-
 Status EagerContext::RegisterCustomDevice(
     const string& device_name, std::unique_ptr<CustomDevice> device) {
-  DeviceNameUtils::ParsedName parsed;
-  if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
-      !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
-      !parsed.has_type || !parsed.has_id) {
-    return errors::InvalidArgument(
-        device_name,
-        " could not be parsed as a device name. Use the full "
-        "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
-        "format.");
-  }
   Device* existing_physical_device = nullptr;
   if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) {
     return errors::AlreadyExists(device_name,
                                  " already registered as a physical device.");
   }
-  if (!custom_devices_.emplace(device_name, std::move(device)).second) {
-    return errors::AlreadyExists(device_name,
-                                 " already registered as a custom device.");
-  }
-  return Status::OK();
+  return custom_device_op_handler_.RegisterCustomDevice(device_name,
+                                                        std::move(device));
 }
 
 Status EagerContext::FindOrCreateCompositeDevice(
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 12a00c243d6..99cfb89e49d 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/custom_device.h"
+#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
 #include "tensorflow/core/common_runtime/eager/eager_executor.h"
 #include "tensorflow/core/common_runtime/eager/kernel_and_device.h"
 #include "tensorflow/core/common_runtime/function.h"
@@ -204,6 +205,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
     return HostCPU()->parsed_name();
   }
 
+  const string& HostCPUName() const override { return HostCPU()->name(); }
+
   GraphCollector* GetGraphCollector() { return &graph_collector_; }
 
   EagerExecutor& Executor() override;
@@ -473,11 +476,12 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
   Status FindCompositeDeviceFromName(StringPiece device_name,
                                      CompositeDevice** device) const;
 
-  bool FindCustomDeviceFromName(const string& device_name,
-                                CustomDevice** dev) const;
-
   Status RegisterCustomDevice(const string& name,
-                              std::unique_ptr<CustomDevice> device);
+                              std::unique_ptr<CustomDevice> device) override;
+
+  CustomDeviceOpHandler& GetCustomDeviceOpHandler() override {
+    return custom_device_op_handler_;
+  };
 
   // Find or create a composite device with the given `underlying_devices` and
   // `device_name` (if not empty).
@@ -587,7 +591,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted {
       TF_GUARDED_BY(device_type_list_mu_);
   Rendezvous* rendezvous_;
   std::function<Rendezvous*(const int64)> rendezvous_creator_;
-  std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
+  CustomDeviceOpHandler custom_device_op_handler_;
 
   mutable mutex composite_devices_mu_;
   // Maps from the fingerprint of a set of device names to a virtual
diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc
index 81b1e3594f2..905b1d94dad 100644
--- a/tensorflow/core/common_runtime/eager/core.cc
+++ b/tensorflow/core/common_runtime/eager/core.cc
@@ -111,7 +111,7 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
   *status = this->FindDeviceFromName(device_name, &device);
   if (!status->ok()) {
     tensorflow::CustomDevice* dev;
-    if (this->FindCustomDeviceFromName(device_name, &dev)) {
+    if (custom_device_op_handler_.FindCustomDeviceFromName(device_name, &dev)) {
       *status = dev->CopyTensorToDevice(handle, &result);
       if (status->ok()) {
         return result;
@@ -128,7 +128,8 @@ ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice(
     return nullptr;
   }
   tensorflow::CustomDevice* dev;
-  if (this->FindCustomDeviceFromName(handle_device_name, &dev)) {
+  if (custom_device_op_handler_.FindCustomDeviceFromName(handle_device_name,
+                                                         &dev)) {
     *status = dev->CopyTensorFromDevice(handle, device_name, &result);
     if (status->ok()) {
       return result;
@@ -202,28 +203,8 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
     }
   }
 
-  // Decide to either run the operation on a custom device or copy off all of
-  // the custom device inputs.
-  VariantDevice maybe_custom_device = Device();
-  if (absl::holds_alternative<CustomDevice*>(maybe_custom_device) ||
-      !inputs_are_tensor_handles_) {
-    // If the op wasn't placed on a custom device explicitly and there are no
-    // non-TensorHandle inputs, the op will definitely be placed on a physical
-    // device. Otherwise we need to check the inputs one by one.
-    TF_RETURN_IF_ERROR(
-        eager::MaybePinToCustomDevice(&maybe_custom_device, *this));
-    if (absl::holds_alternative<CustomDevice*>(maybe_custom_device)) {
-      ImmediateExecutionTensorHandle** retval_array =
-          reinterpret_cast<ImmediateExecutionTensorHandle**>(retvals.data());
-      return absl::get<CustomDevice*>(maybe_custom_device)
-          ->Execute(this, retval_array, num_retvals);
-    } else {
-      TF_RETURN_IF_ERROR(CopyOffCustomDeviceInputs());
-    }
-  }
-
   // Run eager placement logic.
-  class Device* device = absl::get<class Device*>(maybe_custom_device);
+  class Device* device = absl::get<class Device*>(Device());
   if (device == nullptr) {
     TF_RETURN_IF_ERROR(eager::MaybePinToResourceDevice(&device, *this));
   }
diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc
new file mode 100644
index 00000000000..28c7a1161f6
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc
@@ -0,0 +1,167 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/eager/custom_device_op_handler.h"
+
+#include "tensorflow/core/platform/errors.h"
+
+namespace tensorflow {
+
+void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); }
+
+Status CustomDeviceOpHandler::RegisterCustomDevice(
+    const string& device_name, std::unique_ptr<CustomDevice> device) {
+  DeviceNameUtils::ParsedName parsed;
+  if (!DeviceNameUtils::ParseFullName(device_name, &parsed) ||
+      !parsed.has_job || !parsed.has_replica || !parsed.has_task ||
+      !parsed.has_type || !parsed.has_id) {
+    return errors::InvalidArgument(
+        device_name,
+        " could not be parsed as a device name. Use the full "
+        "/job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num> "
+        "format.");
+  }
+
+  if (!custom_devices_.emplace(device_name, std::move(device)).second) {
+    return errors::AlreadyExists(device_name,
+                                 " already registered as a custom device.");
+  }
+  return Status::OK();
+}
+
+bool CustomDeviceOpHandler::FindCustomDeviceFromName(
+    const string& name, CustomDevice** device) const {
+  auto dev_it = custom_devices_.find(name);
+  if (dev_it == custom_devices_.end()) {
+    return false;
+  }
+  *device = dev_it->second.get();
+  return true;
+}
+
+Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op,
+                                      ImmediateExecutionTensorHandle** retvals,
+                                      int* num_retvals) {
+  tensorflow::CustomDevice* custom_device = nullptr;
+
+  TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op));
+
+  if (custom_device != nullptr) {
+    return custom_device->Execute(op, retvals, num_retvals);
+  }
+
+  // The op will be placed on a physical device, but it has custom device
+  // tensor handles as inputs. Those inputs must be copied to the physical
+  // device first.
+  if (op->HasCustomDeviceInput()) {
+    auto inputs = op->GetInputs();
+    for (int i = 0; i < inputs.size(); ++i) {
+      auto target_device = op->DeviceName();
+      if (target_device.empty()) {
+        target_device = op->GetContext()->HostCPUName();
+      }
+      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
+      // here.
+      if (tensorflow::CustomDeviceTensorHandle::classof(inputs[i])) {
+        tensorflow::CustomDeviceTensorHandle* previous =
+            tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
+                inputs[i]);
+        tensorflow::ImmediateExecutionTensorHandle* new_tensor;
+        TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
+            previous, target_device, &new_tensor));
+        Status s = op->SetInput(i, new_tensor);
+        new_tensor->Unref();
+        TF_RETURN_IF_ERROR(s);
+      }
+    }
+  }
+
+  return op->Execute(
+      absl::MakeSpan(
+          reinterpret_cast<tensorflow::AbstractTensorHandle**>(retvals),
+          *num_retvals),
+      num_retvals);
+}
+
+Status CustomDeviceOpHandler::MaybePinToCustomDevice(
+    CustomDevice** device, const ImmediateExecutionOperation& op) const {
+  *device = nullptr;
+  if (!FindCustomDeviceFromName(op.DeviceName(), device) &&
+      !op.HasCustomDeviceInput()) {
+    return Status::OK();
+  }
+
+  // Ops are placed on a custom device if there is no other explicitly
+  // requested placement and all custom device inputs come from a single
+  // custom device.
+  //
+  // Resource-dtype inputs take precedence over non-resource inputs and explicit
+  // placements; this function pins ops with a resource-dtype custom device
+  // input to that custom device.
+  CustomDevice* first = nullptr;
+  if (!op.GetInputs().empty()) {
+    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
+      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
+      // here.
+      if (CustomDeviceTensorHandle::classof(generic_input)) {
+        const CustomDeviceTensorHandle* input =
+            down_cast<const CustomDeviceTensorHandle*>(generic_input);
+        CustomDevice* current = input->device();
+        if (first == nullptr) {
+          first = current;
+        } else if (first != current) {
+          return errors::InvalidArgument(absl::StrCat(
+              "If an operation has one of its inputs in a custom device, then "
+              "all inputs should be on that same custom device or another "
+              "physical device. Operation ",
+              op.Name(),
+              " has one input in custom "
+              "device ",
+              first->name(),
+              " and at least one input in a different custom device ",
+              current->name()));
+        }
+      }
+    }
+    for (const ImmediateExecutionTensorHandle* generic_input : op.GetInputs()) {
+      if (generic_input->DataType() == DT_RESOURCE) {
+        if (CustomDeviceTensorHandle::classof(generic_input)) {
+          const CustomDeviceTensorHandle* input =
+              down_cast<const CustomDeviceTensorHandle*>(generic_input);
+          // There's only one custom device input, and it's a resource input, so
+          // we'll force-place the op onto that custom device. As with physical
+          // devices, this overrides any explicit placement for the op.
+          *device = input->device();
+          return Status::OK();
+        } else {
+          // Don't set a custom device if there's a physical-device resource
+          // input.
+          return Status::OK();
+        }
+      }
+    }
+  }
+  // Since there are no resource-dtype inputs, we'll respect explicit placements
+  // before considering input-based placement.
+  if (*device == nullptr && op.DeviceName().empty() && first != nullptr) {
+    // If there are non-resource inputs on a custom device we will default the
+    // op to that custom device, but not override an explicit op placement.
+    *device = first;
+    return Status::OK();
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.h b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h
new file mode 100644
index 00000000000..00ac5f324ba
--- /dev/null
+++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h
@@ -0,0 +1,51 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
+
+#include "tensorflow/c/eager/immediate_execution_operation.h"
+#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/core/common_runtime/eager/custom_device.h"
+#include "tensorflow/core/lib/core/status.h"
+namespace tensorflow {
+
+// TODO(tfrt-devs): Figure out a way to unify it with OpHandler in TFRT.
+class CustomDeviceOpHandler {
+ public:
+  ~CustomDeviceOpHandler() {}
+  // Register a new custom device.
+  Status RegisterCustomDevice(const string& device_name,
+                              std::unique_ptr<CustomDevice> device);
+
+  // Find the custom device with the given name. Returns true if one is found.
+  bool FindCustomDeviceFromName(const string& name,
+                                CustomDevice** device) const;
+
+  Status Execute(ImmediateExecutionOperation* op,
+                 ImmediateExecutionTensorHandle** retvals, int* num_retvals);
+
+  // Determine whether to place an op on a custom device. This method is
+  // public only for testing.
+  Status MaybePinToCustomDevice(CustomDevice** device,
+                                const ImmediateExecutionOperation& op) const;
+
+  void Clear();
+
+ private:
+  std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
+};
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_CUSTOM_DEVICE_OP_HANDLER_H_
diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc
index a642a816c76..cd7340e8463 100644
--- a/tensorflow/core/common_runtime/eager/custom_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc
@@ -138,43 +138,47 @@ TEST(CustomDevice, TestResourcePlacement) {
   TF_ASSERT_OK(op.Reset("AssignVariableOp", ""));
   TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
   TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
-  VariantDevice placed_device(kVariantDeviceNull);
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  CustomDevice* placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // MaybePinToCustomDevice has no opinion about ops which have physical
   // resource-dtype inputs. They'll get placed on physical devices.
-  EXPECT_EQ(kVariantDeviceNull, placed_device);
+  EXPECT_EQ(nullptr, placed_device);
 
   op.Clear();
   TF_ASSERT_OK(op.Reset("AssignVariableOp", custom_device_name.c_str()));
   TF_ASSERT_OK(op.AddInput(physical_resource_tensor.get()));
   TF_ASSERT_OK(op.AddInput(custom_float_tensor.get()));
-  placed_device = kVariantDeviceNull;
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // Explicit placement onto a custom device also doesn't trigger custom device
   // placement if there's a physical device resource input.
-  EXPECT_EQ(kVariantDeviceNull, placed_device);
+  EXPECT_EQ(nullptr, placed_device);
 
   op.Clear();
   TF_ASSERT_OK(
       op.Reset("Identity", "/job:localhost/replica:0/task:0/device:CPU:0"));
   TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
-  placed_device = kVariantDeviceNull;
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // Explicit placements typically override input-based placement onto a custom
   // device.
-  EXPECT_EQ(kVariantDeviceNull, placed_device);
+  EXPECT_EQ(nullptr, placed_device);
 
   op.Clear();
   TF_ASSERT_OK(op.Reset("AssignVariableOp",
                         "/job:localhost/replica:0/task:0/device:CPU:0"));
   TF_ASSERT_OK(op.AddInput(custom_resource_tensor.get()));
   TF_ASSERT_OK(op.AddInput(physical_float_tensor.get()));
-  placed_device = kVariantDeviceNull;
-  TF_ASSERT_OK(MaybePinToCustomDevice(&placed_device, op));
+  placed_device = nullptr;
+  TF_ASSERT_OK(ctx->GetCustomDeviceOpHandler().MaybePinToCustomDevice(
+      &placed_device, op));
   // Even with an explicit physical device placement, custom device resource
   // inputs place the op on the custom device.
-  ASSERT_TRUE(absl::holds_alternative<CustomDevice*>(placed_device));
-  EXPECT_EQ(&custom_device, absl::get<CustomDevice*>(placed_device));
+  ASSERT_NE(placed_device, nullptr);
+  EXPECT_EQ(&custom_device, placed_device);
 }
 
 }  // namespace
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc
index 883e9a8a8b0..de4a4495e87 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.cc
+++ b/tensorflow/core/common_runtime/eager/eager_operation.cc
@@ -36,7 +36,7 @@ void EagerOperation::Clear() {
     h->Unref();
   }
   inputs_.clear();
-  inputs_are_tensor_handles_ = true;
+  custom_device_tensor_handles_count_ = 0;
   ClearInferenceState();
 }
 
@@ -269,7 +269,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) {
       down_cast<ImmediateExecutionTensorHandle*>(input);
   // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
   if (CustomDeviceTensorHandle::classof(h)) {
-    inputs_are_tensor_handles_ = false;
+    custom_device_tensor_handles_count_++;
   }
   AddTensorHandle(h);
   return MaybeInferSingleInputAttrs(h);
@@ -281,7 +281,7 @@ Status EagerOperation::AddInputList(
     // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
     // here.
     if (CustomDeviceTensorHandle::classof(input)) {
-      inputs_are_tensor_handles_ = false;
+      custom_device_tensor_handles_count_++;
     }
     ImmediateExecutionTensorHandle* h =
         down_cast<ImmediateExecutionTensorHandle*>(input);
@@ -290,6 +290,25 @@ Status EagerOperation::AddInputList(
   return InferInputListAttrs(inputs.size());
 }
 
+Status EagerOperation::SetInput(size_t index,
+                                ImmediateExecutionTensorHandle* input) {
+  if (index >= inputs_.size()) {
+    return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index,
+                                   inputs_.size());
+  }
+  auto* previous = inputs_[index];
+  if (CustomDeviceTensorHandle::classof(previous)) {
+    custom_device_tensor_handles_count_--;
+  }
+  if (CustomDeviceTensorHandle::classof(input)) {
+    custom_device_tensor_handles_count_++;
+  }
+  input->Ref();
+  inputs_[index] = input;
+  previous->Unref();
+  return Status::OK();
+}
+
 Status EagerOperation::Reset(
     const char* op, const char* device_name, bool remote,
     EagerExecutor* executor,
@@ -407,7 +426,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) {
 
 Status EagerOperation::TensorHandleInputs(
     const absl::InlinedVector<TensorHandle*, 4>** inputs) const {
-  if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
+  if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
     *inputs = reinterpret_cast<const absl::InlinedVector<TensorHandle*, 4>*>(
         &inputs_);
     return Status::OK();
@@ -418,7 +437,7 @@ Status EagerOperation::TensorHandleInputs(
 
 Status EagerOperation::MutableTensorHandleInputs(
     absl::InlinedVector<TensorHandle*, 4>** inputs) {
-  if (TF_PREDICT_TRUE(inputs_are_tensor_handles_)) {
+  if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) {
     *inputs =
         reinterpret_cast<absl::InlinedVector<TensorHandle*, 4>*>(&inputs_);
     return Status::OK();
@@ -436,14 +455,7 @@ Status EagerOperation::SetDeviceName(const char* c_name) {
     }
     last_set_device_name_ = name;
     device_name_ = DeviceNameUtils::ParsedNameToString(device_parsed_name_);
-    CustomDevice* custom_device;
-    if (ctx_.FindCustomDeviceFromName(device_name_, &custom_device)) {
-      device_ = custom_device;
-    } else {
-      // Device placement for physical devices happens lazily in
-      // EagerExecute/EagerRemoteExecute, and can depend on the inputs.
-      device_ = kVariantDeviceNull;
-    }
+    device_ = kVariantDeviceNull;
   }
   return Status::OK();
 }
@@ -495,30 +507,4 @@ void EagerOperation::AddTensorHandle(ImmediateExecutionTensorHandle* h) {
   attrs_.NumInputs(static_cast<int>(inputs_.size()));
 }
 
-Status EagerOperation::CopyOffCustomDeviceInputs() {
-  if (absl::holds_alternative<CustomDevice*>(device_)) {
-    return errors::Internal(
-        "Trying to copy inputs to a custom device op off a custom device.");
-  }
-  for (int i = 0; i < inputs_.size(); ++i) {
-    // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
-    // here.
-    if (CustomDeviceTensorHandle::classof(inputs_[i])) {
-      CustomDeviceTensorHandle* previous =
-          down_cast<CustomDeviceTensorHandle*>(inputs_[i]);
-      class Device* target_device;
-      if (device_ == kVariantDeviceNull) {
-        target_device = ctx_.HostCPU();
-      } else {
-        target_device = absl::get<class Device*>(device_);
-      }
-      TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice(
-          previous, target_device->name(), &inputs_[i]));
-      previous->Unref();
-    }
-  }
-  inputs_are_tensor_handles_ = true;
-  return Status::OK();
-}
-
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h
index e440a4a79dd..e1cb20b7575 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.h
+++ b/tensorflow/core/common_runtime/eager/eager_operation.h
@@ -55,6 +55,8 @@ class EagerOperation : public ImmediateExecutionOperation {
 
   const string& DeviceName() const override { return device_name_; }
 
+  ImmediateExecutionContext* GetContext() const override { return &ctx_; }
+
   const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
     return device_parsed_name_;
   }
@@ -83,7 +85,11 @@ class EagerOperation : public ImmediateExecutionOperation {
 
   Status AddInput(AbstractTensorHandle* input) override;
   Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
+  Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override;
   absl::Span<ImmediateExecutionTensorHandle* const> GetInputs() const override;
+  bool HasCustomDeviceInput() const override {
+    return custom_device_tensor_handles_count_ > 0;
+  }
   Status Execute(absl::Span<AbstractTensorHandle*> retvals,
                  int* num_retvals) override;
   const tensorflow::OpDef* OpDef() const override { return op_def_; };
@@ -207,20 +213,14 @@ class EagerOperation : public ImmediateExecutionOperation {
   void InferMixedTypeInputListAttrs(const OpDef::ArgDef& input_def,
                                     const std::vector<DataType>& dtypes);
 
-  // Replaces input tensors placed on custom devices with physical device
-  // equivalents. Used if an op is placed on a physical device but may have
-  // custom device inputs.
-  Status CopyOffCustomDeviceInputs();
-
   tensorflow::EagerContext& ctx_;
   const char* op_name_ = nullptr;
   AttrBuilder attrs_;
   const AttrTypeMap* attr_types_;
 
-  // Toggled to indicate whether all inputs are known to be TensorHandles and
-  // not another type (e.g. custom device tensor handles). Explicitly set to
-  // false when custom device TensorHandles are added.
-  bool inputs_are_tensor_handles_ = true;
+  // The number of custom device TensorHandle inputs. These inputs need to be
+  // processed by CustomDeviceOpHandler first.
+  int custom_device_tensor_handles_count_ = 0;
   absl::InlinedVector<ImmediateExecutionTensorHandle*, 4> inputs_;
 
   // The last device name given to SetDeviceName.
diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc
index 77514d67e3a..3b9fa7bb2d1 100644
--- a/tensorflow/core/common_runtime/eager/placement_utils.cc
+++ b/tensorflow/core/common_runtime/eager/placement_utils.cc
@@ -77,11 +77,6 @@ bool IsFunction(StringPiece op_name) {
   return false;
 }
 
-bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx) {
-  CustomDevice* custom_device;
-  return ctx.FindCustomDeviceFromName(string(device_name), &custom_device);
-}
-
 Status MaybePinSmallOpsToCpu(
     bool* result, StringPiece op_name,
     absl::Span<ImmediateExecutionTensorHandle* const> args,
@@ -182,70 +177,5 @@ Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) {
   return Status::OK();
 }
 
-Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op) {
-  // Ops are placed on a custom device if there's no other explicit requested
-  // placement and there is only one custom device in the op
-  // inputs.
-  //
-  // Resource-dtype inputs take precedence over non-resource inputs and explicit
-  // placements; this function pins ops with a resource-dtype custom device
-  // input to that custom device.
-  CustomDevice* first = nullptr;
-  if (!op.Inputs().empty()) {
-    for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
-      // TODO(b/175427838): It would be nice to be able to use tensorflow::isa
-      // here.
-      if (CustomDeviceTensorHandle::classof(generic_input)) {
-        const CustomDeviceTensorHandle* input =
-            down_cast<const CustomDeviceTensorHandle*>(generic_input);
-        CustomDevice* current = input->device();
-        if (first == nullptr) {
-          first = current;
-        } else if (first != current) {
-          return errors::InvalidArgument(absl::StrCat(
-              "If an operation has one of its inputs in a custom device, then "
-              "all inputs should be on that same custom device or another "
-              "physical device. Operation ",
-              op.Name(),
-              " has one input in custom "
-              "device ",
-              first->name(),
-              " and at least one input in a different custom device ",
-              current->name()));
-        }
-      }
-    }
-    for (const ImmediateExecutionTensorHandle* generic_input : op.Inputs()) {
-      if (generic_input->DataType() == DT_RESOURCE) {
-        if (CustomDeviceTensorHandle::classof(generic_input)) {
-          const CustomDeviceTensorHandle* input =
-              down_cast<const CustomDeviceTensorHandle*>(generic_input);
-          // There's only one custom device input, and it's a resource input, so
-          // we'll force-place the op on to that custom device. As with physical
-          // devices, this overrides any explicit placement for the op.
-          *device = input->device();
-          return Status::OK();
-        } else {
-          // Don't set a custom device if there's a physical-device resource
-          // input.
-          return Status::OK();
-        }
-      }
-    }
-  }
-  // Since there are no resource-dtype inputs, we'll respect explicit placements
-  // before considering input-based placement.
-  if (absl::holds_alternative<CustomDevice*>(op.Device())) {
-    *device = op.Device();
-  } else if (op.DeviceName().empty() && first != nullptr) {
-    // If there are non-resource inputs on a custom device we will default the
-    // op to that custom device, but not override an explicit op placement.
-    *device = first;
-    return Status::OK();
-  }
-
-  return Status::OK();
-}
-
 }  // namespace eager
 }  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/placement_utils.h b/tensorflow/core/common_runtime/eager/placement_utils.h
index 7676fe01b43..9435f9848d3 100644
--- a/tensorflow/core/common_runtime/eager/placement_utils.h
+++ b/tensorflow/core/common_runtime/eager/placement_utils.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PLACEMENT_UTILS_H_
 
+#include "tensorflow/c/eager/immediate_execution_operation.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/stringpiece.h"
@@ -27,8 +28,6 @@ bool IsColocationExempt(StringPiece op_name);
 
 bool IsFunction(StringPiece op_name);
 
-bool IsCustomDevice(StringPiece device_name, const EagerContext& ctx);
-
 // TODO(b/154234908): Unify placement logic.
 // TODO(b/159647422): Add C++ unit tests for placement logic.
 
@@ -44,11 +43,6 @@ Status MaybePinSmallOpsToCpu(
 // the device the resource is, regardless of anything else that has been
 // specified. This is identical to the graph mode behavior.
 Status MaybePinToResourceDevice(Device** device, const EagerOperation& op);
-
-// If all the inputs are on the same custom device, use that custom
-// device. Otherwise, it is an error to have a custom device as an input.
-Status MaybePinToCustomDevice(VariantDevice* device, const EagerOperation& op);
-
 }  // namespace eager
 }  // namespace tensorflow