From a4064a389e92976ce8e26770cb19478b6068ed94 Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Thu, 6 Feb 2020 09:58:02 -0800
Subject: [PATCH] Experimental API for custom devices in TFE.

Custom devices are an experimental hook into eager op execution, allowing experimentation outside the TensorFlow codebase. These devices do not work in traced code at the moment.

PiperOrigin-RevId: 293615055
Change-Id: I031da213e964caa7d4e11e0f491a3985d034b175
---
 tensorflow/c/eager/BUILD                      |  27 ++-
 tensorflow/c/eager/c_api.cc                   | 188 +++++++++++++++++-
 tensorflow/c/eager/c_api_debug.cc             |   2 +-
 tensorflow/c/eager/c_api_experimental.h       |  51 +++++
 tensorflow/c/eager/custom_device_test.cc      | 159 +++++++++++++++
 tensorflow/core/common_runtime/eager/BUILD    |   1 +
 .../core/common_runtime/eager/context.cc      |  15 ++
 .../core/common_runtime/eager/context.h       |  25 +++
 .../core/common_runtime/eager/execute.cc      |  51 +++--
 .../core/common_runtime/eager/execute_node.cc |  11 +-
 .../common_runtime/eager/tensor_handle.cc     |  78 ++++++--
 .../core/common_runtime/eager/tensor_handle.h |  37 +++-
 .../eager/eager_service_impl_test.cc          |   2 +-
 .../eager/remote_copy_node.cc                 |   6 +-
 .../distributed_runtime/eager/remote_mgr.cc   |   9 +-
 tensorflow/core/framework/tensor.cc           |   5 +
 tensorflow/core/framework/tensor.h            |   1 +
 tensorflow/python/framework/device_spec.py    |   2 +-
 tensorflow/python/lib/core/py_func.cc         |   8 +-
 tensorflow/python/lib/core/py_seq_tensor.cc   |   9 +-
 20 files changed, 618 insertions(+), 69 deletions(-)
 create mode 100644 tensorflow/c/eager/custom_device_test.cc

diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 6c952d7c67f..ff2d243dc34 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -2,6 +2,7 @@
 
 load(
     "//tensorflow:tensorflow.bzl",
+    "tf_cc_test",
     "tf_copts",
     "tf_cuda_cc_test",
     "tf_cuda_library",
@@ -289,6 +290,27 @@ tf_cuda_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "custom_device_test",
+    size = "small",
+    srcs = [
+        "custom_device_test.cc",
+    ],
+    deps = [
+        ":c_api",
+        ":c_api_experimental",
+        ":c_api_test_util",
+        "//tensorflow/c:c_api",
+        "//tensorflow/c:c_test_util",
+        "//tensorflow/cc/profiler",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "tape",
     hdrs = ["tape.h"],
@@ -301,7 +323,10 @@ cc_library(
 
 filegroup(
     name = "headers",
-    srcs = ["c_api.h"],
+    srcs = [
+        "c_api.h",
+        "c_api_experimental.h",
+    ],
     visibility = ["//tensorflow:__subpackages__"],
 )
 
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 67da9c4f0a4..c3cbdb1ade3 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -103,7 +103,12 @@ const tensorflow::OpDef* GetOpDef(TFE_Op* op, TF_Status* status) {
   return op_def;
 }
 
-bool IsCPU(const tensorflow::Device* d) {
+bool IsCPU(
+    absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant) {
+  if (VariantDeviceIsCustom(variant)) {
+    return false;
+  }
+  tensorflow::Device* d = absl::get<tensorflow::Device*>(variant);
   return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
 }
 
@@ -1009,6 +1014,9 @@ const char* tensorflow::TensorHandleInterface::DeviceName(
   if (!IsValid(status)) {
     return nullptr;
   }
+  if (VariantDeviceIsCustom(handle_->device())) {
+    return absl::get<CustomDevice*>(handle_->device())->name().c_str();
+  }
   tensorflow::Device* d = handle_->op_device();
   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                         : d->name().c_str();
@@ -1029,9 +1037,15 @@ const char* tensorflow::TensorHandleInterface::BackingDeviceName(
   if (!IsValid(status)) {
     return nullptr;
   }
-  tensorflow::Device* d = handle_->device();
-  return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
-                        : d->name().c_str();
+  if (VariantDeviceIsCustom(handle_->device())) {
+    return absl::get<tensorflow::CustomDevice*>(handle_->device())
+        ->name()
+        .c_str();
+  } else {
+    tensorflow::Device* d = absl::get<tensorflow::Device*>(handle_->device());
+    return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
+                          : d->name().c_str();
+  }
 }
 
 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
@@ -1065,6 +1079,18 @@ TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
   if (!IsValid(status)) {
     return nullptr;
   }
+  if (VariantDeviceIsCustom(handle_->device())) {
+    tensorflow::CustomDevice* custom_device =
+        absl::get<tensorflow::CustomDevice*>(handle_->device());
+    tensorflow::TensorHandle* copy;
+    *status = custom_device->CopyTensorFromDevice(
+        handle_, "/job:localhost/replica:0/task:0/device:CPU:0", &copy);
+    if (status->ok()) {
+      return TensorHandleInterface(copy).Resolve(status);
+    } else {
+      return nullptr;
+    }
+  }
 
   // TODO(agarwal): move this implementation inside TFE_TensorHandle.
   if (handle_->IsRemote()) {
@@ -1110,6 +1136,11 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
   tensorflow::TensorHandle* handle =
       tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
           ->Handle();
+  if (VariantDeviceIsCustom(handle->device())) {
+    const tensorflow::Tensor* t;
+    status->status = handle->Tensor(&t);
+    return t->data();
+  }
 
   if (handle->IsRemote()) {
     status->status = tensorflow::errors::InvalidArgument(
@@ -1117,8 +1148,9 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
         "handle.");
     return nullptr;
   }
-  if (handle->device() != nullptr) {
-    status->status = handle->device()->Sync();
+  tensorflow::Device* device(absl::get<tensorflow::Device*>(handle->device()));
+  if (device != nullptr) {
+    status->status = device->Sync();
     if (!status->status.ok()) {
       return nullptr;
     }
@@ -1137,12 +1169,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
     const int64_t* dims, int num_dims, void* data, size_t len,
     void (*deallocator)(void* data, size_t len, void* arg),
     void* deallocator_arg, TF_Status* status) {
-  tensorflow::Device* device;
+  tensorflow::Device* device = nullptr;
   tensorflow::EagerContext* context = ctx->context;
   status->status = context->FindDeviceFromName(device_name, &device);
+  tensorflow::CustomDevice* custom_device = nullptr;
   if (!status->status.ok()) {
-    deallocator(data, len, deallocator_arg);
-    return nullptr;
+    status->status =
+        context->FindCustomDeviceFromName(device_name, &custom_device);
+    if (!status->status.ok()) {
+      deallocator(data, len, deallocator_arg);
+      return nullptr;
+    }
   }
   std::vector<tensorflow::int64> dimvec(num_dims);
   for (int i = 0; i < num_dims; ++i) {
@@ -1166,8 +1203,14 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
                        tensorflow::TensorShape(dimvec), buf);
   buf->Unref();
   tensorflow::TensorHandle* ret_handle;
+  absl::variant<tensorflow::Device*, tensorflow::CustomDevice*> variant_device;
+  if (custom_device == nullptr) {
+    variant_device = device;
+  } else {
+    variant_device = custom_device;
+  }
   status->status = tensorflow::TensorHandle::CreateLocalHandle(
-      t, device, context, &ret_handle);
+      t, variant_device, context, &ret_handle);
   if (!status->status.ok()) {
     return nullptr;
   }
@@ -1508,8 +1551,42 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
   tensorflow::EagerContext* context = ctx->context;
   status->status = context->FindDeviceFromName(device_name, &device);
   if (!status->status.ok()) {
+    tensorflow::CustomDevice* dev;
+    status->status = context->FindCustomDeviceFromName(device_name, &dev);
+    if (status->status.ok()) {
+      status->status = dev->CopyTensorToDevice(
+          tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+              h->handle.get())
+              ->Handle(),
+          &handle);
+      if (status->status.ok()) {
+        return new TFE_TensorHandle{
+            std::make_unique<tensorflow::TensorHandleInterface>(handle)};
+      }
+    }
     return nullptr;
   }
+  // Handle tensor handles currently in custom devices
+  const char* handle_device_name = h->handle->DeviceName(&status->status);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+  tensorflow::CustomDevice* dev;
+  status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
+  if (status->status.ok()) {
+    status->status = dev->CopyTensorFromDevice(
+        tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+            h->handle.get())
+            ->Handle(),
+        device_name, &handle);
+    if (status->status.ok()) {
+      return new TFE_TensorHandle{
+          std::make_unique<tensorflow::TensorHandleInterface>(handle)};
+    }
+    return nullptr;
+  }
+
+  // Handle regular case.
   status->status = tensorflow::EagerCopyToDevice(
       tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
           ->Handle(),
@@ -1648,3 +1725,94 @@ void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
   }
 }
 }  // namespace tensorflow
+
+namespace {
+class CustomDeviceAPI : public tensorflow::CustomDevice {
+ public:
+  CustomDeviceAPI(TFE_CustomDevice device, void* info, string name)
+      : device_(device), info_(info), name_(name) {}
+
+  ~CustomDeviceAPI() override { device_.delete_device(info_); }
+
+  const string& name() override { return name_; }
+
+  tensorflow::Status CopyTensorToDevice(
+      tensorflow::TensorHandle* tensor,
+      tensorflow::TensorHandle** result) override {
+    tensor->Ref();
+    TFE_TensorHandle tensor_handle{
+        std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
+    TF_Status status;
+    TFE_TensorHandle* result_handle =
+        device_.copy_tensor_to_device(&tensor_handle, &status, info_);
+    if (!status.status.ok()) return status.status;
+    *result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+                  result_handle->handle.get())
+                  ->Handle();
+    (*result)->Ref();
+    delete result_handle;
+    return status.status;
+  }
+
+  tensorflow::Status CopyTensorFromDevice(
+      tensorflow::TensorHandle* tensor,
+      const tensorflow::string& target_device_name,
+      tensorflow::TensorHandle** result) override {
+    TF_Status status;
+    tensor->Ref();
+    TFE_TensorHandle tensor_handle{
+        std::make_unique<tensorflow::TensorHandleInterface>(tensor)};
+    TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
+        &tensor_handle, target_device_name.c_str(), &status, info_);
+    if (!status.status.ok()) return status.status;
+    *result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+                  result_handle->handle.get())
+                  ->Handle();
+    (*result)->Ref();
+    delete result_handle;
+    return status.status;
+  }
+
+  tensorflow::Status Execute(tensorflow::EagerOperation* op,
+                             tensorflow::TensorHandle** retvals,
+                             int* num_retvals) override {
+    std::vector<TFE_TensorHandle*> inputs;
+    inputs.reserve(op->Inputs().size());
+    for (int i = 0; i < op->Inputs().size(); ++i) {
+      op->Inputs()[i]->Ref();
+      inputs.push_back(new TFE_TensorHandle{
+          std::make_unique<tensorflow::TensorHandleInterface>(
+              op->Inputs()[i])});
+    }
+    std::vector<TFE_TensorHandle*> outputs(*num_retvals);
+    // TODO(allenl): figure out how to get attrs from EagerOperation
+    TF_Status status;
+    device_.execute(inputs.size(), inputs.data(), op->Name().c_str(),
+                    num_retvals, outputs.data(), &status, info_);
+    if (status.status.ok()) {
+      for (int i = 0; i < *num_retvals; ++i) {
+        retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
+                         outputs[i]->handle.get())
+                         ->Handle();
+        retvals[i]->Ref();
+      }
+    }
+    for (auto inp : inputs) {
+      delete inp;
+    }
+    return status.status;
+  }
+
+ private:
+  TFE_CustomDevice device_;
+  void* info_;
+  string name_;
+};
+}  // namespace
+
+void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
+                              const char* device_name, void* device_info) {
+  auto custom_device =
+      std::make_unique<CustomDeviceAPI>(device, device_info, device_name);
+  ctx->context->RegisterCustomDevice(device_name, std::move(custom_device));
+}
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
index e8069e19cf1..2d6dd21e12b 100644
--- a/tensorflow/c/eager/c_api_debug.cc
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -66,7 +66,7 @@ TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
   }
 
 #ifdef TENSORFLOW_EAGER_USE_XLA
-  tensorflow::Device* device = handle_->device();
+  tensorflow::Device* device = absl::get<Device*>(handle_->device());
 
   // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
   tensorflow::XlaDevice* xla_device =
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 0165eb3b781..fe95c238e52 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -463,6 +463,57 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
 TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
                                                 TF_Buffer* buf);
 
+#define TFE_CUSTOM_DEVICE_VERSION 0
+
+// Struct of function pointers to be filled in by the caller to define a custom device.
+typedef struct TFE_CustomDevice {
+  int version = TFE_CUSTOM_DEVICE_VERSION;
+  // Method to copy a tensor to the custom device.
+  TFE_TensorHandle* (*copy_tensor_to_device)(TFE_TensorHandle* tensor,
+                                             TF_Status* status,
+                                             void* device_info) = nullptr;
+
+  // Method to copy a tensor from the custom device to a target device.
+  TFE_TensorHandle* (*copy_tensor_from_device)(TFE_TensorHandle* tensor,
+                                               const char* target_device_name,
+                                               TF_Status* status,
+                                               void* device_info);
+
+  // Method to execute an operation.
+  // TODO(allenl) figure out a generic way of passing attrs here
+  void (*execute)(int num_inputs, TFE_TensorHandle** inputs,
+                  const char* operation_name, int* num_outputs,
+                  TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
+
+  // Method to delete a device.
+  void (*delete_device)(void* device_info);
+} TFE_CustomDevice;
+
+// Registers a custom device for use with eager execution.
+//
+// Eager operations may be placed on this device, e.g.  `with
+// tf.device("CUSTOM"):` from Python if `device_name` for this call is
+// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
+//
+// The custom device defines copy operations for moving TensorHandles on and
+// off, and an execution operation for named operations. Often execution will
+// simply wrap op execution on one or more physical devices.
+//
+// device_info is an opaque caller-defined type stored with the custom device
+// which is passed to the functions referenced in the TFE_CustomDevice struct
+// `device` (execute, delete_device, etc.). It can for example contain the
+// names of wrapped devices.
+//
+// There are currently no graph semantics implemented for registered custom
+// devices, so executing tf.functions which contain operations placed on custom
+// devices will fail.
+//
+// This API is highly experimental, and in particular is expected to change when
+// it starts supporting operations with attributes and when tf.function support
+// is added.
+void TFE_RegisterCustomDevice(TFE_Context* ctx, TFE_CustomDevice device,
+                              const char* device_name, void* device_info);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/c/eager/custom_device_test.cc b/tensorflow/c/eager/custom_device_test.cc
new file mode 100644
index 00000000000..cf09cf57de6
--- /dev/null
+++ b/tensorflow/c/eager/custom_device_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A simple logging device to test custom device registration.
+#include <memory>
+
+#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/eager/c_api_test_util.h"
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace {
+
+struct LoggingDevice {
+  TFE_Context* ctx;
+  tensorflow::string device_name;
+  tensorflow::string underlying_device;
+  // Set to true whenever a TensorHandle is copied onto the device
+  bool* arrived_flag;
+};
+
+struct LoggedTensor {
+  TFE_TensorHandle* tensor;
+  LoggedTensor() = delete;
+  explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
+  ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
+};
+
+void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
+  delete reinterpret_cast<LoggedTensor*>(data);
+}
+
+TFE_TensorHandle* MakeLoggedTensorHandle(
+    TFE_Context* ctx, const tensorflow::string& logging_device_name,
+    std::unique_ptr<LoggedTensor> t, TF_Status* status) {
+  std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  for (int i = 0; i < shape.size(); ++i) {
+    shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
+    if (TF_GetCode(status) != TF_OK) return nullptr;
+  }
+  auto dtype = TFE_TensorHandleDataType(t->tensor);
+  return TFE_NewTensorHandleFromDeviceMemory(
+      ctx, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
+      t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
+}
+
+TFE_TensorHandle* CopyToLoggingDevice(TFE_TensorHandle* tensor,
+                                      TF_Status* status, void* device_info) {
+  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
+  TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
+      tensor, dev->ctx, dev->underlying_device.c_str(), status);
+  if (TF_GetCode(status) != TF_OK) return nullptr;
+  auto dst = std::make_unique<LoggedTensor>(t);
+  *(dev->arrived_flag) = true;
+  return MakeLoggedTensorHandle(dev->ctx, dev->device_name, std::move(dst),
+                                status);
+}
+
+TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_TensorHandle* tensor,
+                                              const char* target_device_name,
+                                              TF_Status* status,
+                                              void* device_info) {
+  TF_SetStatus(status, TF_INTERNAL,
+               "Trying to copy a tensor out of a logging device.");
+  return nullptr;
+}
+
+void LoggingDeviceExecute(int num_inputs, TFE_TensorHandle** inputs,
+                          const char* operation_name, int* num_outputs,
+                          TFE_TensorHandle** outputs, TF_Status* s,
+                          void* device_info) {
+  LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
+  TFE_Op* op(TFE_NewOp(dev->ctx, operation_name, s));
+  if (TF_GetCode(s) != TF_OK) return;
+  TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
+  for (int j = 0; j < num_inputs; ++j) {
+    TFE_TensorHandle* input = inputs[j];
+    const char* input_device = TFE_TensorHandleDeviceName(input, s);
+    if (TF_GetCode(s) != TF_OK) return;
+    if (dev->device_name == input_device) {
+      LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
+          TFE_TensorHandleDevicePointer(input, s));
+      if (TF_GetCode(s) != TF_OK) return;
+      TFE_OpAddInput(op, t->tensor, s);
+    } else {
+      TFE_OpAddInput(op, input, s);
+    }
+    if (TF_GetCode(s) != TF_OK) return;
+  }
+  std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
+  TFE_Execute(op, op_outputs.data(), num_outputs, s);
+  TFE_DeleteOp(op);
+  if (TF_GetCode(s) != TF_OK) return;
+  std::vector<TFE_TensorHandle*> unwrapped_outputs;
+  for (auto* handle : op_outputs) {
+    unwrapped_outputs.push_back(handle);
+  }
+  for (int i = 0; i < *num_outputs; ++i) {
+    auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
+    outputs[i] = MakeLoggedTensorHandle(dev->ctx, dev->device_name,
+                                        std::move(logged_tensor), s);
+  }
+}
+
+void DeleteLoggingDevice(void* device_info) {
+  delete reinterpret_cast<LoggingDevice*>(device_info);
+}
+
+void RegisterLoggingDevice(TFE_Context* context, const char* name,
+                           bool* arrived_flag) {
+  TFE_CustomDevice custom_device;
+  custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
+  custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
+  custom_device.delete_device = &DeleteLoggingDevice;
+  custom_device.execute = &LoggingDeviceExecute;
+  LoggingDevice* device = new LoggingDevice;
+  device->ctx = context;
+  device->arrived_flag = arrived_flag;
+  device->device_name = name;
+  device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+  TFE_RegisterCustomDevice(context, custom_device, name, device);
+}
+
+TEST(CUSTOM_DEVICE, RegisterSimpleDevice) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* context = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  bool arrived = false;
+  const char* name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
+  RegisterLoggingDevice(context, name, &arrived);
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+  ASSERT_FALSE(arrived);
+  TFE_TensorHandle* hdevice =
+      TFE_TensorHandleCopyToDevice(hcpu, context, name, status.get());
+  ASSERT_TRUE(arrived);
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_DeleteTensorHandle(hdevice);
+  TFE_DeleteContext(context);
+}
+
+}  // namespace
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index c32c1a81ea9..8705f650a19 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -157,6 +157,7 @@ tf_cuda_library(
         ],
         "//conditions:default": [
             "@com_google_absl//absl/strings",
+            "@com_google_absl//absl/types:variant",
             "//tensorflow/core:core_cpu_lib",
             "//tensorflow/core:framework",
             "//tensorflow/core:framework_internal",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 4c9f6be66a2..4e93b0efaab 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -702,6 +702,21 @@ Status EagerContext::FindDeviceFromName(const char* device_name,
   return status;
 }
 
+Status EagerContext::FindCustomDeviceFromName(const string& device_name,
+                                              CustomDevice** dev) const {
+  auto dev_it = custom_devices_.find(device_name);
+  if (dev_it == custom_devices_.end()) {
+    return errors::InvalidArgument(device_name, " unknown device.");
+  }
+  *dev = dev_it->second.get();
+  return Status::OK();
+}
+
+void EagerContext::RegisterCustomDevice(const string& device_name,
+                                        std::unique_ptr<CustomDevice> device) {
+  custom_devices_[device_name] = std::move(device);
+}
+
 bool EagerContext::OnSameTask(const Device* first, const Device* second) const {
   if (first == nullptr) first = HostCPU();
   if (second == nullptr) second = HostCPU();
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 98c54035234..7c964b61e3d 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -106,6 +106,24 @@ class RunMetadataListener {
   virtual void BeforeClearRunMetadata() = 0;
 };
 
+class TensorHandle;
+class EagerOperation;
+
+class CustomDevice {
+ public:
+  virtual ~CustomDevice() {}
+  virtual const string& name() = 0;
+  virtual Status CopyTensorToDevice(TensorHandle* tensor,
+                                    TensorHandle** result) = 0;
+
+  virtual Status CopyTensorFromDevice(TensorHandle* tensor,
+                                      const string& target_device_name,
+                                      TensorHandle** result) = 0;
+
+  virtual Status Execute(EagerOperation* op, TensorHandle** retvals,
+                         int* num_retvals) = 0;
+};
+
 class EagerContext : public core::RefCounted {
  public:
   static const uint64 kInvalidContextId = 0;
@@ -416,6 +434,12 @@ class EagerContext : public core::RefCounted {
 
   Status FindDeviceFromName(const char* device_name, Device** device) const;
 
+  Status FindCustomDeviceFromName(const string& device_name,
+                                  CustomDevice** dev) const;
+
+  void RegisterCustomDevice(const string& name,
+                            std::unique_ptr<CustomDevice> device);
+
   bool OnSameTask(const Device* first, const Device* second) const;
   // Gets the CPU device on the task of device.
   Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
@@ -492,6 +516,7 @@ class EagerContext : public core::RefCounted {
   std::vector<DeviceType> prioritized_device_type_list_;
   Rendezvous* rendezvous_;
   std::function<Rendezvous*(const int64)> rendezvous_creator_;
+  std::unordered_map<string, std::unique_ptr<CustomDevice>> custom_devices_;
 
   FunctionLibraryDefinition func_lib_def_{OpRegistry::Global(), {}};
 
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index e6861a41447..e2586fe5f2d 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -177,7 +177,13 @@ Status ValidateInputTypeAndPlacement(
   for (int i = 0; i < n_inputs; ++i) {
     TensorHandle* handle = op->Inputs()[i];
     Device* expected_device = kernel->InputDevice(i);
-    Device* handle_device = handle->DeviceOrHostCPU(*ctx);
+    auto handle_device_variant = handle->DeviceOrHostCPU(*ctx);
+    if (VariantDeviceIsCustom(handle_device_variant)) {
+      return errors::Unimplemented(
+          "Custom devices and remote execution are not yet supported "
+          "together.");
+    }
+    Device* handle_device = absl::get<Device*>(handle_device_variant);
     const bool maybe_copy = !skip_remote_copy || !handle->IsRemote();
     // If the input is already on the right device, then nothing to do.
     if (expected_device != handle_device && maybe_copy) {
@@ -229,10 +235,14 @@ inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
 
 Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
                          Device** result) {
+  if (TF_PREDICT_FALSE(VariantDeviceIsCustom(tensor_handle->device()))) {
+    return errors::Unimplemented(
+        "The kernel cache does not work with custom devices.");
+  }
   Device* cpu_device = ctx.HostCPU();
   string device_name;
   if (tensor_handle->IsRemote()) {
-    Device* device = tensor_handle->device();
+    Device* device = absl::get<Device*>(tensor_handle->device());
     device_name = device != nullptr ? device->name() : cpu_device->name();
     *result = (device == nullptr ? cpu_device : device);
   } else if (tensor_handle->dtype == DT_RESOURCE) {
@@ -251,7 +261,7 @@ Status GetDeviceForInput(const EagerContext& ctx, TensorHandle* tensor_handle,
   } else if (MTypeFromDType(tensor_handle->dtype) == HOST_MEMORY) {
     *result = cpu_device;
   } else {
-    Device* device = tensor_handle->device();
+    Device* device = absl::get<Device*>(tensor_handle->device());
     device_name = device != nullptr ? device->name() : cpu_device->name();
     *result = (device == nullptr ? cpu_device : device);
   }
@@ -659,8 +669,10 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
         !ctx.LazyCopyFunctionRemoteInputs() || !op->is_function();
     for (int i = 0; i < op->Inputs().size(); i++) {
       tensorflow::TensorHandle* input = op->Inputs()[i];
-      tensorflow::Device* input_device = input->device();
-      const string* input_device_name = &input->DeviceOrHostCPU(ctx)->name();
+      tensorflow::Device* input_device = absl::get<Device*>(input->device());
+      tensorflow::Device* input_device_or_cpu =
+          absl::get<Device*>(input->DeviceOrHostCPU(ctx));
+      const string* input_device_name = &input_device_or_cpu->name();
       bool serialize_resource_dtype_and_shape = false;
       if (op->Device() != input_device &&
           // If the expected and actual devices are on the same task, don't
@@ -668,7 +680,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
           // when the op is executed on the device.
           !ctx.OnSameTask(op->Device(), input_device)) {
         if (eagerly_copy_function_remote_inputs ||
-            input->DeviceOrHostCPU(ctx)->IsLocal()) {
+            input_device_or_cpu->IsLocal()) {
           tensorflow::Device* remote_cpu_device;
           TF_RETURN_IF_ERROR(
               ctx.CPUDeviceOnTask(op->Device(), &remote_cpu_device));
@@ -678,7 +690,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
           // correctly determined after the kernel is selected/instantiated,
           // since the op might have its inputs on host memory.
           TensorHandle* handle = op->Inputs()[i];
-          Device* handle_device = handle->DeviceOrHostCPU(ctx);
+          Device* handle_device =
+              absl::get<Device*>(handle->DeviceOrHostCPU(ctx));
           // If the input is already on the right device, then nothing to do.
           if (remote_cpu_device != handle_device) {
             TF_RETURN_IF_ERROR(CopyInputToExpectedDevice(
@@ -854,7 +867,12 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
       // ineligible for CPU pinning.
       break;
     } else if (all_inputs_eligible_for_cpu_pinning) {
-      Device* input_device = tensor_handle->DeviceOrHostCPU(ctx);
+      auto input_device_variant = tensor_handle->DeviceOrHostCPU(ctx);
+      if (VariantDeviceIsCustom(input_device_variant)) {
+        all_inputs_eligible_for_cpu_pinning = false;
+        continue;
+      }
+      Device* input_device = absl::get<Device*>(input_device_variant);
       DVLOG(2) << "for op " << op->Name() << " input " << i << " "
                << DataTypeString(tensor_handle->dtype)
                << " input device = " << input_device->name()
@@ -902,6 +920,12 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals,
       [&] { return absl::StrCat("EagerExecute: ", op->Name()); },
       profiler::TraceMeLevel::kInfo);
   TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
+  CustomDevice* custom_device;
+  if (op->EagerContext()
+          .FindCustomDeviceFromName(op->GetDeviceName(), &custom_device)
+          .ok()) {
+    return custom_device->Execute(op, retvals, num_retvals);
+  }
 
   if (!op->Executor().Async()) {
     // In sync mode, always clear error to maintain the same behavior as before.
@@ -996,7 +1020,7 @@ Status EagerKernelExecute(
   for (int i = 0; i < retvals.size(); ++i) {
     DCHECK_EQ(kernel->device(), retvals[i]->op_device());
     DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
-              retvals[i]->device());
+              absl::get<Device*>(retvals[i]->device()));
 
     TF_RETURN_IF_ERROR(retvals[i]->SetTensor(std::move(outputs[i])));
   }
@@ -1031,9 +1055,12 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
 Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
                          EagerExecutor* executor, Device* device, bool mirror,
                          TensorHandle** result) {
-  Device* send_device = h->DeviceOrHostCPU(*ctx);
-
-  bool sender_is_local = send_device->IsLocal();
+  auto send_device = h->DeviceOrHostCPU(*ctx);
+  if (VariantDeviceIsCustom(send_device)) {
+    return errors::Unimplemented(
+        "Copying a TensorHandle from a custom device is not supported.");
+  }
+  bool sender_is_local = absl::get<Device*>(send_device)->IsLocal();
 
   bool recver_is_local = device->IsLocal();
 
diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc
index 42899447784..f63c8ea6221 100644
--- a/tensorflow/core/common_runtime/eager/execute_node.cc
+++ b/tensorflow/core/common_runtime/eager/execute_node.cc
@@ -71,9 +71,16 @@ Status ExecuteNodeArgs::Init(
     serialize_remote_handle_ =
         [ctx, &op_inputs](const int i,
                           eager::RemoteTensorHandle* handle) -> Status {
+      absl::variant<Device*, CustomDevice*> variant_device =
+          op_inputs[i]->device();
+      if (VariantDeviceIsCustom(variant_device)) {
+        return errors::Internal(
+            "Custom devices and remote execution are currently not supported "
+            "together.");
+      }
+      Device* device = absl::get<Device*>(variant_device);
       return ctx->RemoteMgr()->SerializeRemoteTensorHandle(
-          op_inputs[i], handle, op_inputs[i]->device(),
-          op_inputs[i]->device()->name());
+          op_inputs[i], handle, device, device->name());
     };
 #endif  // !IS_MOBILE_PLATFORM
   }
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 007bdb061bd..3747844d583 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -95,16 +95,25 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
 Status TensorHandle::CreateLocalHandle(const class Tensor& t,
                                        TensorHandle** h) {
   // TODO(b/136608821): Move away from nullptr
-  return CreateLocalHandle(t, /*d=*/nullptr, /*op_device=*/nullptr,
+  return CreateLocalHandle(t, /*d=*/static_cast<Device*>(nullptr),
+                           /*op_device=*/nullptr,
                            /*ctx=*/nullptr, h);
 }
 
-Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
+Status TensorHandle::CreateLocalHandle(const class Tensor& t, VariantDevice d,
                                        EagerContext* ctx, TensorHandle** h) {
-  return CreateLocalHandle(t, d, d, ctx, h);
+  Device* op_device;
+  if (VariantDeviceIsCustom(d)) {
+    // TODO(allenl): Figure out a better op_device story for custom devices,
+    // since always setting it to CPU=nullptr doesn't make much sense.
+    op_device = nullptr;
+  } else {
+    op_device = absl::get<Device*>(d);
+  }
+  return CreateLocalHandle(t, d, op_device, ctx, h);
 }
 
-Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
+Status TensorHandle::CreateLocalHandle(const class Tensor& t, VariantDevice d,
                                        Device* op_device, EagerContext* ctx,
                                        TensorHandle** h) {
   if (t.dtype() != DT_RESOURCE) {
@@ -120,7 +129,7 @@ Status TensorHandle::CreateLocalHandle(const class Tensor& t, Device* d,
 }
 
 TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           DataType dtype, Device* d, Device* op_device,
+                           DataType dtype, VariantDevice d, Device* op_device,
                            EagerContext* ctx)
     : dtype(dtype),
       device_(d),
@@ -135,12 +144,14 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
       is_async_(false),
       is_ready_(true),
       tensor_handle_data_(std::move(t)) {
-  DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
+  DVLOG(3) << "Creating Local TensorHandle: " << this
+           << " device: " << VariantDeviceDebugString(device_);
 }
 
 TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-                           const ResourceHandle& resource_handle, Device* d,
-                           Device* op_device, EagerContext* ctx)
+                           const ResourceHandle& resource_handle,
+                           VariantDevice d, Device* op_device,
+                           EagerContext* ctx)
     : dtype(DT_RESOURCE),
       device_(d),
       op_device_(op_device),
@@ -155,7 +166,8 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
       is_ready_(true),
       handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
       tensor_handle_data_(std::move(t)) {
-  DVLOG(3) << "Creating Local TensorHandle: " << this << " device: " << device_;
+  DVLOG(3) << "Creating Local TensorHandle: " << this
+           << " device: " << VariantDeviceDebugString(device_);
 }
 
 Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
@@ -170,7 +182,7 @@ Status TensorHandle::CreateEmptyLocalHandle(bool async, Device* d,
 }
 
 TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
-                           bool async, Device* d, Device* op_device,
+                           bool async, VariantDevice d, Device* op_device,
                            Device* resource_device, DataType dtype,
                            EagerContext* ctx)
     : dtype(dtype),
@@ -187,7 +199,7 @@ TensorHandle::TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t,
       is_ready_(!async),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Async Local TensorHandle: " << this
-           << " device: " << device_;
+           << " device: " << VariantDeviceDebugString(device_);
 }
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -227,7 +239,7 @@ TensorHandle::TensorHandle(std::unique_ptr<RemoteTensorHandleData> t,
       is_ready_(true),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Remote TensorHandle: " << this
-           << " device: " << device_;
+           << " device: " << VariantDeviceDebugString(device_);
 }
 
 Status TensorHandle::CreateUnshapedRemoteHandle(
@@ -263,7 +275,7 @@ TensorHandle::TensorHandle(std::unique_ptr<UnshapedRemoteTensorHandleData> t,
       is_ready_(false),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Unshaped Remote TensorHandle: " << this
-           << " device: " << device_;
+           << " device: " << VariantDeviceDebugString(device_);
 }
 #endif
 
@@ -297,8 +309,14 @@ Status TensorHandle::TensorValue(tensorflow::TensorValue* t) {
   return tensor_handle_data_->TensorValue(t);
 }
 
-Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const {
-  return (device_ == nullptr) ? ctx.HostCPU() : device_;
+TensorHandle::VariantDevice TensorHandle::DeviceOrHostCPU(
+    const EagerContext& ctx) const {
+  if (VariantDeviceIsCustom(device_)) {
+    return device_;
+  } else {
+    Device* d = absl::get<Device*>(device_);
+    return (d == nullptr) ? ctx.HostCPU() : d;
+  }
 }
 
 Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
@@ -413,7 +431,7 @@ Status TensorHandle::NumElements(int64* num_elements) const {
 #if !defined(IS_MOBILE_PLATFORM)
 Status TensorHandle::RemoteAddress(Device* d, int64* op_id,
                                    int32* output_num) const {
-  if (d != device_) {
+  if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
     tf_shared_lock l(mu_);
     auto mirror = remote_mirrors_.find(d);
     if (mirror != remote_mirrors_.end()) {
@@ -517,7 +535,7 @@ Status TensorHandle::SetRemoteShape(const TensorShape& shape,
                                     tensorflow::Device* d) {
   DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d;
 
-  if (d != device_) {
+  if (VariantDeviceIsCustom(device_) || d != absl::get<Device*>(device_)) {
     mutex_lock l(mu_);
     if (remote_mirrors_.find(d) != remote_mirrors_.end()) {
       return errors::Internal(
@@ -593,7 +611,7 @@ void TensorHandle::Poison(Status status) {
 Status TensorHandle::CopyToDevice(const EagerContext& ctx,
                                   tensorflow::Device* dstd,
                                   tensorflow::Tensor* output) {
-  tensorflow::Device* srcd = DeviceOrHostCPU(ctx);
+  tensorflow::Device* srcd = absl::get<Device*>(DeviceOrHostCPU(ctx));
   const bool dst_cpu = dstd->tensorflow_gpu_device_info() == nullptr;
   const bool src_cpu = srcd->tensorflow_gpu_device_info() == nullptr;
   bool is_same_device =
@@ -655,6 +673,20 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx,
   return status;
 }
 
+bool VariantDeviceIsCustom(
+    absl::variant<Device*, CustomDevice*> variant_device) {
+  return variant_device.index() != 0;
+}
+
+string VariantDeviceDebugString(
+    absl::variant<Device*, CustomDevice*> variant_device) {
+  if (VariantDeviceIsCustom(variant_device)) {
+    return absl::get<CustomDevice*>(variant_device)->name();
+  } else {
+    return absl::get<Device*>(variant_device)->DebugString();
+  }
+}
+
 Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) {
   if (ctx == nullptr) {
     return nullptr;
@@ -671,10 +703,14 @@ string TensorHandle::DebugString() const {
   DVLOG(1) << "Calling TensorHandle::DebugString() on " << this;
 
   string out;
-  strings::StrAppend(&out, "Device: ", device_ ? device_->DebugString() : "[]");
-  // Consider supporting non-CPU tensors (when device_ is non-NULL) if needed.
+  string device_debug = VariantDeviceDebugString(device_);
+  strings::StrAppend(&out, "Device: ", device_debug);
+  bool is_cpu =
+      !VariantDeviceIsCustom(device_) && absl::get<Device*>(device_) != nullptr;
+  // Consider supporting non-CPU tensors and CPU tensors with a device_ set to
+  // non-NULL if needed.
   strings::StrAppend(&out, ", Tensor: ",
-                     device_ ? "?" : tensor_handle_data_->DebugString(), "\n");
+                     is_cpu ? tensor_handle_data_->DebugString() : "?", "\n");
   return out;
 }
 
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index eb157577a3f..9f8825eaac8 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
 
+#include "absl/types/variant.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
@@ -43,6 +44,7 @@ limitations under the License.
 #endif  // IS_MOBILE_PLATFORM
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor.h"
+
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/gtl/map_util.h"
@@ -60,15 +62,20 @@ namespace tensorflow {
 // of the TFE_TensorHandle struct and the python EagerTensor class
 // (unrelated to python TensorHandle).
 class TensorHandle : public core::RefCounted {
+  // Custom devices do many of the same things as physical Devices, but have a
+  // much more restricted interface. We pass around ambiguous pointers since
+  // TensorHandles may be placed either on custom or physical devices.
+  using VariantDevice = absl::variant<Device*, CustomDevice*>;
+
   // TensorHandle for dtype != DT_RESOURCE
   TensorHandle(std::unique_ptr<LocalTensorHandleData> t, DataType dtype,
-               Device* d, Device* op_device, EagerContext* ctx);
+               VariantDevice d, Device* op_device, EagerContext* ctx);
   // TensorHandle for dtype == DT_RESOURCE
   TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
-               const ResourceHandle& resource_handle, Device* d,
+               const ResourceHandle& resource_handle, VariantDevice d,
                Device* op_device, EagerContext* ctx);
   TensorHandle(std::unique_ptr<EmptyLocalTensorHandleData> t, bool async,
-               Device* d, Device* op_device, Device* resource_device,
+               VariantDevice d, Device* op_device, Device* resource_device,
                DataType dtype, EagerContext* ctx);
 
 #if !defined(IS_MOBILE_PLATFORM)
@@ -82,9 +89,9 @@ class TensorHandle : public core::RefCounted {
   // TensorHandle with no assigned device
   static Status CreateLocalHandle(const class Tensor& t, TensorHandle** h);
   // TensorHandle with device == op_device
-  static Status CreateLocalHandle(const class Tensor& t, Device* d,
+  static Status CreateLocalHandle(const class Tensor& t, VariantDevice d,
                                   EagerContext* ctx, TensorHandle** h);
-  static Status CreateLocalHandle(const class Tensor& t, Device* d,
+  static Status CreateLocalHandle(const class Tensor& t, VariantDevice d,
                                   Device* op_device, EagerContext* ctx,
                                   TensorHandle** h);
   static Status CreateEmptyLocalHandle(bool async, Device* d, Device* op_device,
@@ -117,11 +124,11 @@ class TensorHandle : public core::RefCounted {
 
   Status TensorValue(tensorflow::TensorValue* t);
 
-  Device* device() const { return device_; }
+  VariantDevice device() const { return device_; }
   Device* op_device() const { return op_device_; }
   Device* resource_device() const { return resource_device_; }
 
-  Device* DeviceOrHostCPU(const EagerContext& ctx) const;
+  VariantDevice DeviceOrHostCPU(const EagerContext& ctx) const;
 
   Status Shape(tensorflow::TensorShape* shape);
   Status NumDims(int* num_dims) const;
@@ -188,8 +195,10 @@ class TensorHandle : public core::RefCounted {
 
   // TODO(b/136608821): Move away from nullptr
   bool OnHostCPU() const {
-    return device_ == nullptr ||
-           (ctx_ != nullptr && ctx_->HostCPU() == device_);
+    return (
+        device_.index() == 0 &&
+        (absl::get<Device*>(device_) == nullptr ||
+         (ctx_ != nullptr && ctx_->HostCPU() == absl::get<Device*>(device_))));
   }
 
   bool IsRemote() const { return is_remote_; }
@@ -216,7 +225,7 @@ class TensorHandle : public core::RefCounted {
   // done and the handle is "ready".
   Status WaitReady(const char* caller) const;
 
-  // TODO(b/136608821): device_ == nullptr iff Host CPU:0
+  // TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
   // This was expedient, but perhaps worth revisiting ('device_' should always
   // be a valid pointer?)
   // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are
@@ -224,7 +233,7 @@ class TensorHandle : public core::RefCounted {
   //
   // TODO(ashankar): Reference count TFE_Context to ensure that 'device_' of a
   // TFE_TensorHandle does not outlive the TFE_Context from which it came?
-  tensorflow::Device* const device_;
+  VariantDevice const device_;
 
   // Device in which the op producing this tensor was executed. Equals to
   // device_ for constant tensors.
@@ -286,6 +295,12 @@ class TensorHandle : public core::RefCounted {
   PartialTensorShape inference_shape_;
 };
 
+// Checks whether a VariantDevice contains a custom device.
+bool VariantDeviceIsCustom(absl::variant<Device*, CustomDevice*> device);
+
+// Wraps device->DebugString() or CustomDevice->name().
+string VariantDeviceDebugString(absl::variant<Device*, CustomDevice*> device);
+
 // Returns the device backing the resource. Else, returns nullptr.
 Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx);
 
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index 1fd063f617b..87459f4bb39 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -685,7 +685,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
       context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
   TF_ASSERT_OK(tensor_handle->Tensor(&t));
 
-  Device* device = tensor_handle->device();
+  Device* device = absl::get<Device*>(tensor_handle->device());
   EXPECT_EQ(device, nullptr);
 
   auto actual = t->flat<float>();
diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
index 01348026e0a..81c3a52e4d7 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc
@@ -77,7 +77,7 @@ RemoteCopyNode::RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor,
       src_(src),
       ctx_(ctx),
       executor_(executor),
-      send_device_(src->DeviceOrHostCPU(*ctx)),
+      send_device_(absl::get<Device*>(src->DeviceOrHostCPU(*ctx))),
       recv_device_(recv_device),
       wire_id_(GetUniqueWireID()),
       recv_op_id_(recv_op_id),
@@ -145,8 +145,8 @@ void RemoteCopyNode::StartSend() {
     request.set_context_id(ctx_->GetContextId());
     auto* remote_op = request.add_queue()->mutable_operation();
     status = ctx_->RemoteMgr()->SerializeRemoteTensorHandle(
-        src_, remote_op->add_inputs(), src_->device(),
-        src_->DeviceOrHostCPU(*ctx_)->name());
+        src_, remote_op->add_inputs(), absl::get<Device*>(src_->device()),
+        absl::get<Device*>(src_->DeviceOrHostCPU(*ctx_))->name());
     if (!status.ok()) {
       captured_state_->SetSendStatus(status);
       return;
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
index a77d0cf41d9..aefe86c654d 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
@@ -75,8 +75,15 @@ Status RemoteMgr::GetMirroredResourceShape(
 
 Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle,
                                         int64* op_id, int32* output_num) {
+  // TODO(allenl): Consider supporting remote handles on custom devices.
+  absl::variant<Device*, CustomDevice*> device = handle->device();
+  if (VariantDeviceIsCustom(device)) {
+    return errors::Unimplemented(
+        "Custom devices and remote execution are currently not supported "
+        "together.");
+  }
   TF_RETURN_IF_ERROR(
-      handle->RemoteAddress(handle->device(), op_id, output_num));
+      handle->RemoteAddress(absl::get<Device*>(device), op_id, output_num));
   tensorflow::TensorHandle* h;
   TF_RETURN_IF_ERROR(
       GetTensorHandleImpl(RemoteTensorHandleInternal(*op_id, *output_num), &h));
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 3a47cd35cbf..a7cc9f59b69 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -1244,6 +1244,11 @@ StringPiece Tensor::tensor_data() const {
   return StringPiece(static_cast<char*>(buf_->data()), TotalBytes());
 }
 
+void* Tensor::data() const {
+  if (buf_ == nullptr) return nullptr;  // Don't die for empty tensors
+  return static_cast<void*>(buf_->data());
+}
+
 bool Tensor::SharesBufferWith(const Tensor& b) const {
   return buf_ != nullptr && b.buf_ != nullptr &&
          buf_->root_buffer() == b.buf_->root_buffer();
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 24f4a744d6f..11910766ba8 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -601,6 +601,7 @@ class Tensor {
   ///
   /// REQUIRES: `DataTypeCanUseMemcpy(dtype())`.
   StringPiece tensor_data() const;
+  void* data() const;
 
   /// Copy the other tensor into this tensor, reshape it and reinterpret the
   /// buffer's datatype. If Status::OK() is returned, the two tensors now share
diff --git a/tensorflow/python/framework/device_spec.py b/tensorflow/python/framework/device_spec.py
index 7ceef783364..08875ad9452 100644
--- a/tensorflow/python/framework/device_spec.py
+++ b/tensorflow/python/framework/device_spec.py
@@ -21,7 +21,7 @@ from __future__ import print_function
 from tensorflow.python.util.tf_export import tf_export
 
 
-_VALID_DEVICE_TYPES = {"CPU", "GPU", "TPU"}
+_VALID_DEVICE_TYPES = frozenset({"CPU", "GPU", "TPU", "CUSTOM"})
 
 
 # ==============================================================================
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index fd54938de57..59c68f5983d 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -95,7 +95,7 @@ Status MakeArgTuple(const PyCall* call, EagerContext* ctx, PyObject** tuple) {
     if (call->eager) {
       TensorHandle* handle;
       TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
-          t, ctx->CanonicalDevice(device), ctx, &handle));
+          t, ctx->CanonicalDevice(device), nullptr, ctx, &handle));
       arg = EagerTensorFromHandle(new TFE_TensorHandle{
           std::make_unique<tensorflow::TensorHandleInterface>(handle)});
       if (arg == nullptr) {
@@ -149,7 +149,11 @@ tensorflow::Status ExtractTensorFromEagerTensor(const PyObject* eager_tensor,
   auto handle = down_cast<tensorflow::TensorHandleInterface*>(
                     EagerTensor_Handle(eager_tensor)->handle.get())
                     ->Handle();
-  Device* actual_device = handle->device();
+  if (VariantDeviceIsCustom(handle->device())) {
+    return errors::Unimplemented(
+        "Custom devices are currently not supported with PyFuncs.");
+  }
+  Device* actual_device = absl::get<Device*>(handle->device());
   TF_RETURN_IF_ERROR(handle->Tensor(output_tensor));
   // actual_device may be nullptr, which implies local CPU.
   if (expected_device == actual_device) return Status::OK();
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 6bbf901a2d8..75a7b840db9 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -294,7 +294,8 @@ struct Converter {
     }
     tensorflow::TensorHandle* handle = nullptr;
     auto status = tensorflow::TensorHandle::CreateLocalHandle(
-        result, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
+        result, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+        ctx->context, &handle);
     if (!status.ok()) {
       return status;
     }
@@ -609,7 +610,8 @@ TFE_TensorHandle* NumpyToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj) {
   auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
   if (cppstatus.ok()) {
     cppstatus = tensorflow::TensorHandle::CreateLocalHandle(
-        t, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &handle);
+        t, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr, ctx->context,
+        &handle);
   }
   if (!cppstatus.ok()) {
     PyErr_SetString(PyExc_ValueError,
@@ -806,7 +808,8 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj,
       Tensor tensor(requested_dtype == DT_INVALID ? DT_FLOAT : requested_dtype,
                     TensorShape(state.inferred_shape));
       status = tensorflow::TensorHandle::CreateLocalHandle(
-          tensor, /*d=*/nullptr, /*op_device=*/nullptr, ctx->context, &h);
+          tensor, /*d=*/ctx->context->HostCPU(), /*op_device=*/nullptr,
+          ctx->context, &h);
       if (!status.ok()) {
         PyErr_SetString(PyExc_ValueError, status.error_message().c_str());
         return nullptr;