From 731984bfd0f642f91c24781eb5fd09fa05268a10 Mon Sep 17 00:00:00 2001
From: Yujing Zhang <yujingzhang@google.com>
Date: Thu, 12 Sep 2019 14:47:10 -0700
Subject: [PATCH] Introduce EagerClusterFunctionLibraryRuntime. - Allow inputs
 on remote devices. - Run remote functions through eager service instead of
 worker service.

This CL introduces no behavior change, because EagerClusterFunctionLibraryRuntime is not in use.

PiperOrigin-RevId: 268770768
---
 .../core/common_runtime/eager/context.cc      |   7 +-
 .../core/common_runtime/eager/context.h       |   3 +-
 .../core/distributed_runtime/eager/BUILD      |  26 +++
 .../eager/cluster_function_library_runtime.cc | 160 +++++++++++++++++
 .../eager/cluster_function_library_runtime.h  |  87 ++++++++++
 .../eager/eager_service_impl.cc               |   5 +-
 .../eager/eager_service_impl_test.cc          | 161 +++++++++++++++++-
 tensorflow/core/framework/function.h          |  20 +++
 tensorflow/core/protobuf/eager_service.proto  |   4 +
 9 files changed, 466 insertions(+), 7 deletions(-)
 create mode 100644 tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
 create mode 100644 tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h

diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5515a3ab09b..b4fda6bda50 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -413,7 +413,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
   return Status::OK();
 }
 
-Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
+Status EagerContext::AddFunctionDef(const FunctionDef& fdef,
+                                    const bool add_to_local_only) {
   bool is_first_ref = false;
   {
     mutex_lock l(cache_mu_);
@@ -432,7 +433,9 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) {
   }
   if (is_first_ref) {
     TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef));
-    return MaybeRegisterFunctionRemotely(fdef);
+    if (!add_to_local_only) {
+      return MaybeRegisterFunctionRemotely(fdef);
+    }
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 4e744dd179e..47080742126 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -185,7 +185,8 @@ class EagerContext : public core::RefCounted {
 
   EagerExecutor& Executor();
 
-  Status AddFunctionDef(const FunctionDef& fdef);
+  Status AddFunctionDef(const FunctionDef& fdef,
+                        const bool add_to_local_only = false);
 
   Status RemoveFunction(const string& func);
 
diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD
index 34fc44af097..235a84c6d15 100644
--- a/tensorflow/core/distributed_runtime/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/eager/BUILD
@@ -21,6 +21,29 @@ cc_library(
     ],
 )
 
+# EagerClusterFunctionLibraryRuntime: a DistributedFunctionLibraryRuntime that
+# instantiates and runs functions across processes by making RPCs through the
+# eager service. Requests are sent with the EagerClient of the target device,
+# hence the dependency on :eager_client.
+cc_library(
+    name = "cluster_function_library_runtime",
+    srcs = ["cluster_function_library_runtime.cc"],
+    hdrs = ["cluster_function_library_runtime.h"],
+    deps = [
+        ":eager_client",
+        ":remote_execute_node",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/common_runtime/eager:eager_operation",
+        "//tensorflow/core/common_runtime/eager:tensor_handle",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 cc_library(
     name = "destroy_tensor_handle_node",
     hdrs = ["destroy_tensor_handle_node.h"],
@@ -64,6 +87,7 @@ cc_library(
         "eager_service_impl.h",
     ],
     deps = [
+        ":cluster_function_library_runtime",
         ":remote_mgr",
         ":remote_tensor_handle",
         "//tensorflow:grpc++",
@@ -95,6 +119,7 @@ tf_cc_test(
     name = "eager_service_impl_test",
     srcs = ["eager_service_impl_test.cc"],
     deps = [
+        ":cluster_function_library_runtime",
         ":eager_service_impl",
         ":remote_mgr",
         "//tensorflow/c:c_api",
@@ -111,6 +136,7 @@ tf_cc_test(
         "//tensorflow/core/distributed_runtime:test_utils",
         "//tensorflow/core/distributed_runtime:worker_env",
         "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
+        "@com_google_absl//absl/types:span",
     ],
 )
 
diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
new file mode 100644
index 00000000000..6c143e93fe2
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc
@@ -0,0 +1,203 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
+
+#include <map>
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace eager {
+
+// Registers the definition of `function_name` with the eager service that
+// serves `options.target` and stores, in `*handle`, the value to pass to
+// Run(). The definition is read from `options.lib_def` when that is set and
+// from `lib_def` otherwise. Blocks until the RegisterFunction RPC completes;
+// a failed lookup or RPC is returned as the status.
+Status EagerClusterFunctionLibraryRuntime::Instantiate(
+    const string& function_name, const FunctionLibraryDefinition& lib_def,
+    AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options,
+    FunctionLibraryRuntime::LocalHandle* handle) {
+  const tensorflow::AttrTypeMap* attr_types = nullptr;
+  bool is_function = false;
+  TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(function_name.c_str(),
+                                                  &attr_types, &is_function));
+  if (!is_function) {
+    return errors::Internal(function_name, " is not a function.");
+  }
+  auto op = absl::make_unique<EagerOperation>(ctx_, function_name.c_str(),
+                                              is_function, attr_types);
+  TF_RETURN_IF_ERROR(op->SetDeviceName(options.target.c_str()));
+
+  VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target
+          << " (this: " << this << ")";
+  eager::EagerClient* eager_client = nullptr;
+  Device* device = nullptr;
+  TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(options.target.c_str(), &device));
+  TF_RETURN_IF_ERROR(ctx_->GetClient(device, &eager_client));
+
+  if (eager_client == nullptr) {
+    return errors::InvalidArgument("Could not find eager client for target: ",
+                                   options.target);
+  }
+
+  const FunctionLibraryDefinition& func_lib_def =
+      options.lib_def ? *options.lib_def : lib_def;
+  // Find() yields nullptr for a name the library does not contain; report that
+  // instead of dereferencing it.
+  const FunctionDef* fdef = func_lib_def.Find(function_name);
+  if (fdef == nullptr) {
+    return errors::NotFound("Function ", function_name,
+                            " is not defined in the function library.");
+  }
+
+  RegisterFunctionRequest request;
+  const uint64 context_id = ctx_->GetContextId();
+  request.set_context_id(context_id);
+  // TODO(yujingzhang): add FunctionDefLibrary to RegisterFunctionRequest to
+  // support nested functions.
+  *request.mutable_function_def() = *fdef;
+  request.set_is_component_function(true);
+
+  Status status;
+  Notification done;
+  RegisterFunctionResponse response;
+  eager_client->RegisterFunctionAsync(&request, &response, [&](Status s) {
+    status = s;
+    done.Notify();
+  });
+  done.WaitForNotification();
+  TF_RETURN_IF_ERROR(status);
+
+  mutex_lock l(mu_);
+  *handle = function_data_.size();
+  function_data_.emplace_back(options.target, context_id, eager_client,
+                              std::move(op));
+  return Status::OK();
+}
+
+// Overload taking local tensors. It is not implemented: `done` always receives
+// an Unimplemented error.
+void EagerClusterFunctionLibraryRuntime::Run(
+    const FunctionLibraryRuntime::Options& opts,
+    FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args,
+    std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
+  done(errors::Unimplemented("Not implemented"));
+}
+
+// Enqueues the function instantiated as `handle` on its target through the
+// eager service, using `op_id` as the id of the remote operation. The
+// contents of `args` are swapped into the request, which leaves the caller's
+// RemoteTensorHandle protos cleared. `done` receives an error when the handle,
+// the client or the device cannot be resolved, and the status of the Enqueue
+// RPC otherwise.
+void EagerClusterFunctionLibraryRuntime::Run(
+    const FunctionLibraryRuntime::Options& opts,
+    FunctionLibraryRuntime::LocalHandle handle, const int64 op_id,
+    absl::Span<eager::RemoteTensorHandle* const> args,
+    FunctionLibraryRuntime::DoneCallback done) {
+  // Copy everything needed out of function_data_ while holding mu_: a
+  // concurrent Instantiate() can grow the vector and move its elements, so a
+  // pointer to an element is not safe to use once the lock is released. The
+  // EagerOperation is held through a unique_ptr, so its address is unaffected
+  // by such a move.
+  string target;
+  uint64 context_id = 0;
+  EagerClient* eager_client = nullptr;
+  EagerOperation* op = nullptr;
+  bool handle_is_valid = false;
+  {
+    mutex_lock l(mu_);
+    if (handle < function_data_.size()) {
+      const FunctionData& function_data = function_data_[handle];
+      target = function_data.target;
+      context_id = function_data.context_id;
+      eager_client = function_data.eager_client;
+      op = function_data.op.get();
+      handle_is_valid = true;
+    }
+  }
+  if (!handle_is_valid) {
+    done(errors::InvalidArgument("Invalid function handle: ", handle));
+    return;
+  }
+
+  if (eager_client == nullptr) {
+    done(errors::Internal("Could not find eager client"));
+    return;
+  }
+
+  // The lookup only verifies that the target still names a known device; the
+  // device itself is not needed to build the request.
+  Device* device = nullptr;
+  Status s = ctx_->FindDeviceFromName(target.c_str(), &device);
+  if (!s.ok()) {
+    done(errors::Internal("Failed to get device"));
+    return;
+  }
+
+  // The request and the response have to outlive this call; the completion
+  // callback below deletes both.
+  eager::EnqueueRequest* request = new eager::EnqueueRequest;
+  request->set_context_id(context_id);
+  eager::Operation* remote_op = request->add_queue()->mutable_operation();
+  for (eager::RemoteTensorHandle* arg : args) {
+    remote_op->add_inputs()->Swap(arg);
+  }
+  // TODO(yujingzhang): add step_id to eager::Operation to make sure that all
+  // component functions use the same step id.
+  // The remote component function should use the same op_id as its parent
+  // multi-device function's in order to get the global unique op_id generated
+  // by the master context.
+  remote_op->set_id(op_id);
+  remote_op->set_name(op->Name());
+  op->Attrs().FillAttrValueMap(remote_op->mutable_attrs());
+  remote_op->set_device(target);
+
+  // Hold a reference on each input of the operation until the RPC completes.
+  for (auto* input : op->Inputs()) {
+    input->Ref();
+  }
+
+  // TODO(yujingzhang): Use RemoteExecuteNode once we enable async execution.
+  EnqueueResponse* response = new EnqueueResponse;
+  eager_client->EnqueueAsync(
+      request, response,
+      [op, request, response, done](const Status& enqueue_status) {
+        for (auto* input : op->Inputs()) {
+          input->Unref();
+        }
+        done(enqueue_status);
+        delete request;
+        delete response;
+      });
+}
+
+// Does no work of its own; reports success through `done`.
+void EagerClusterFunctionLibraryRuntime::CleanUp(
+    uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
+    FunctionLibraryRuntime::DoneCallback done) {
+  done(Status::OK());
+}
+
+}  // namespace eager
+}  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h
new file mode 100644
index 00000000000..56a7ee189fe
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h
@@ -0,0 +1,108 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
+
+namespace tensorflow {
+
+struct WorkerSession;
+
+namespace eager {
+
+// EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run
+// functions across processes by making RPCs through eager service.
+class EagerClusterFunctionLibraryRuntime
+    : public DistributedFunctionLibraryRuntime {
+ public:
+  // Stores both pointers without taking ownership of either.
+  EagerClusterFunctionLibraryRuntime(EagerContext* ctx,
+                                     DeviceMgr* remote_device_mgr)
+      : ctx_(ctx), remote_device_mgr_(remote_device_mgr) {}
+
+  ~EagerClusterFunctionLibraryRuntime() override {}
+
+  // Registers `function_name` with the eager service serving `options.target`
+  // and stores in `*handle` the value to pass to Run().
+  Status Instantiate(const string& function_name,
+                     const FunctionLibraryDefinition& lib_def, AttrSlice attrs,
+                     const FunctionLibraryRuntime::InstantiateOptions& options,
+                     FunctionLibraryRuntime::LocalHandle* handle) override;
+
+  // Not implemented; `done` always receives an Unimplemented error.
+  void Run(const FunctionLibraryRuntime::Options& opts,
+           FunctionLibraryRuntime::LocalHandle handle,
+           gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
+           FunctionLibraryRuntime::DoneCallback done) override;
+
+  // Enqueues the function instantiated as `handle` on its target, passing
+  // `args` as the inputs and `op_id` as the id of the remote operation. `done`
+  // receives the status of the request.
+  void Run(const FunctionLibraryRuntime::Options& opts,
+           FunctionLibraryRuntime::LocalHandle handle, const int64 op_id,
+           absl::Span<eager::RemoteTensorHandle* const> args,
+           FunctionLibraryRuntime::DoneCallback done) override;
+
+  // Reports success through `done` without doing any other work.
+  void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle,
+               FunctionLibraryRuntime::DoneCallback done) override;
+
+  DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; }
+
+ private:
+  EagerContext* ctx_;             // not owned.
+  DeviceMgr* remote_device_mgr_;  // not owned.
+
+  // Everything Instantiate() records about one function. Run() reads it back
+  // using the handle as an index into function_data_.
+  struct FunctionData {
+    // Name of the device the function was instantiated on (options.target).
+    const string target;
+    // Id of the eager context at instantiation time.
+    const uint64 context_id;
+    // Client used to send requests to the eager service of `target`.
+    EagerClient* eager_client = nullptr;
+    // Operation created for the function by Instantiate(); Run() takes the
+    // name and attrs of the remote operation from it.
+    std::unique_ptr<EagerOperation> op;
+
+    FunctionData(const string& target, const uint64 context_id,
+                 EagerClient* eager_client, std::unique_ptr<EagerOperation> op)
+        : target(target),
+          context_id(context_id),
+          eager_client(eager_client),
+          op(std::move(op)) {}
+  };
+
+  mutable mutex mu_;
+  std::vector<FunctionData> function_data_ GUARDED_BY(mu_);
+};
+
+}  // namespace eager
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 1c90ef96cc8..26e33634404 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -367,7 +367,10 @@ Status EagerServiceImpl::RegisterFunction(
   TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context));
   core::ScopedUnref context_unref(context);
 
-  return context->Context()->AddFunctionDef(request->function_def());
+  // If the function is a component of a multi-device function, we only need to
+  // register it locally.
+  return context->Context()->AddFunctionDef(request->function_def(),
+                                            request->is_component_function());
 }
 
 Status EagerServiceImpl::SendTensor(const SendTensorRequest* request,
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
index d278f56b99c..7f0915a471f 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc
@@ -17,8 +17,10 @@ limitations under the License.
 
 #include <string.h>
 
+#include "absl/types/span.h"
 #include "tensorflow/c/c_api_internal.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h"
 #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h"
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
@@ -42,6 +44,13 @@ namespace {
 class TestEagerServiceImpl : public EagerServiceImpl {
  public:
   explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {}
+  Status GetEagerContext(const uint64 context_id, EagerContext** ctx) {
+    ServerContext* context = nullptr;
+    TF_RETURN_IF_ERROR(GetServerContext(context_id, &context));
+    core::ScopedUnref context_unref(context);
+    *ctx = context->Context();  // handed out without taking a reference
+    return Status::OK();
+  }
   Status GetTensorHandle(const uint64 context_id,
                          const RemoteTensorHandleInternal& remote_handle,
                          tensorflow::TensorHandle** handle) {
@@ -54,10 +63,48 @@ class TestEagerServiceImpl : public EagerServiceImpl {
   }
 };
 
-class DummyEagerClientCache : public EagerClientCache {
-  Status GetClient(const string& target, EagerClient** client) override {
-    return errors::Unimplemented("");
+class FakeEagerClient : public EagerClient {
+ public:
+  FakeEagerClient() {}
+  ~FakeEagerClient() override {}
+
+  void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; }
+
+#define CLIENT_METHOD(method)                                         \
+  void method##Async(const method##Request* request,                  \
+                     method##Response* response, StatusCallback done) \
+      override {                                                      \
+    done(impl_->method(request, response));                           \
   }
+
+  CLIENT_METHOD(CreateContext);
+  CLIENT_METHOD(Enqueue);
+  CLIENT_METHOD(WaitQueueDone);
+  CLIENT_METHOD(KeepAlive);
+  CLIENT_METHOD(CloseContext);
+  CLIENT_METHOD(RegisterFunction);
+#undef CLIENT_METHOD
+
+  void StreamingEnqueueAsync(const EnqueueRequest* request,
+                             EnqueueResponse* response,
+                             StatusCallback done) override {
+    done(errors::Unimplemented(""));  // streaming is not faked
+  }
+
+ private:
+  TestEagerServiceImpl* impl_ = nullptr;  // not owned; set by SetServiceImpl()
+};
+
+class DummyEagerClientCache : public EagerClientCache {
+ public:
+  DummyEagerClientCache() : client_(new FakeEagerClient) {}
+  Status GetClient(const string& target, EagerClient** client) override {
+    *client = client_.get();  // `target` is ignored: one client serves all
+    return Status::OK();
+  }
+
+ private:
+  std::unique_ptr<EagerClient> client_;  // the FakeEagerClient handed out
+};
 
 class FakeCache : public TestWorkerCache {
@@ -66,6 +113,10 @@ class FakeCache : public TestWorkerCache {
     eager_client_cache->reset(new DummyEagerClientCache);
     return Status::OK();
   }
+
+  void ListWorkers(std::vector<string>* workers) const override {
+    workers->push_back("/job:localhost/replica:0/task:0");  // fixed name
+  }
 };
 
 class EagerServiceImplTest : public ::testing::Test {
@@ -311,6 +362,110 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
                                                &close_context_response));
 }
 
+// Registers and runs MatMulFunction via EagerClusterFunctionLibraryRuntime.
+TEST_F(EagerServiceImplTest, ClusterFLRTest) {
+  TestEagerServiceImpl eager_service_impl(&worker_env_);
+
+  uint64 context_id = random::New64();
+
+  CreateContextRequest request;
+  request.mutable_server_def()->set_job_name("localhost");
+  request.mutable_server_def()->set_task_index(0);
+  request.set_context_id(context_id);
+  CreateContextResponse response;
+  TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
+
+  const string target_device = "/job:localhost/replica:0/task:0/device:CPU:0";
+
+  // Route the fake EagerClient's calls to the local eager_service_impl.
+  EagerContext* ctx = nullptr;
+  TF_ASSERT_OK(eager_service_impl.GetEagerContext(context_id, &ctx));
+  Device* device;
+  TF_ASSERT_OK(ctx->FindDeviceFromName(target_device.c_str(), &device));
+  EagerClient* client;
+  TF_ASSERT_OK(ctx->GetClient(device, &client));
+  FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client);
+  fake_client->SetServiceImpl(&eager_service_impl);
+
+  auto eager_cluster_flr =
+      absl::make_unique<EagerClusterFunctionLibraryRuntime>(ctx, nullptr);
+  tensorflow::FunctionDef fdef = MatMulFunction();
+
+  // Create the remote input for MatMulFunction: a Const enqueued as op 1.
+  EnqueueRequest remote_enqueue_request;
+  remote_enqueue_request.set_context_id(context_id);
+  EnqueueResponse remote_enqueue_response;
+  std::unordered_map<string, AttrValue> const_attrs;
+  AttrValue val;
+  val.set_type(tensorflow::DataType::DT_FLOAT);
+  const_attrs.insert({"dtype", val});
+  val.Clear();
+  SetTensorProto(val.mutable_tensor());
+  const_attrs.insert({"value", val});
+  AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, target_device,
+                               &remote_enqueue_request);
+  TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request,
+                                          &remote_enqueue_response));
+
+  // Instantiate MatMulFunction on the target device.
+  FunctionLibraryRuntime::InstantiateOptions options;
+  options.target = target_device;
+  options.is_multi_device_function = true;
+  options.input_devices.push_back(target_device);
+  FunctionLibraryRuntime::Handle handle;
+  FunctionLibraryDefinition func_lib_def{OpRegistry::Global(), {}};
+  TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef));
+  TF_ASSERT_OK(eager_cluster_flr->Instantiate(
+      fdef.signature().name(), func_lib_def, AttrSlice(&fdef.attr()), options,
+      &handle));
+
+  // Run MatMulFunction as op 2, feeding it output 0 of op 1.
+  FunctionLibraryRuntime::Options opts;
+  const int64 step_id = opts.step_id;
+  Notification done;
+  Status status;
+  RemoteTensorHandle input;
+  input.set_op_id(1);
+  input.set_output_num(0);
+  input.set_op_device(target_device);
+  input.set_device(target_device);
+  eager_cluster_flr->Run(opts, handle, 2, {&input},
+                         [&status, &done](const Status& s) {
+                           status = s;
+                           done.Notify();
+                         });
+  done.WaitForNotification();
+  TF_ASSERT_OK(status);
+
+  const tensorflow::Tensor* t = nullptr;
+  tensorflow::TensorHandle* tensor_handle;
+  TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
+  TF_ASSERT_OK(tensor_handle->Tensor(&t));
+  auto actual = t->flat<float>();
+  EXPECT_EQ(4, actual.size());
+  EXPECT_EQ(7, actual(0));
+  EXPECT_EQ(10, actual(1));
+  EXPECT_EQ(15, actual(2));
+  EXPECT_EQ(22, actual(3));
+
+  Status cleanup_status;
+  bool callback_is_called = false;
+  eager_cluster_flr->CleanUp(
+      step_id, handle, [&cleanup_status, &callback_is_called](const Status& s) {
+        callback_is_called = true;
+        cleanup_status.Update(s);
+      });
+  EXPECT_TRUE(callback_is_called);
+  TF_ASSERT_OK(cleanup_status);
+
+  CloseContextRequest close_context_request;
+  close_context_request.set_context_id(context_id);
+  CloseContextResponse close_context_response;
+  TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request,
+                                               &close_context_response));
+}
+
 // Test creates a context and attempts to send a tensor (using the RPC), and
 // then use the tensor.
 TEST_F(EagerServiceImplTest, SendTensorTest) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 476c6055801..a6efad47fa1 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -18,6 +18,11 @@ limitations under the License.
 
 #include <vector>
 
+// clang-format off
+// Required for IS_MOBILE_PLATFORM
+#include "tensorflow/core/platform/platform.h"
+// clang-format on
+
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/attr_value_util.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -34,6 +39,9 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/protobuf/config.pb.h"
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h"
+#endif  // IS_MOBILE_PLATFORM
 
 namespace tensorflow {
 
@@ -816,6 +824,18 @@ class DistributedFunctionLibraryRuntime {
                    FunctionLibraryRuntime::LocalHandle handle,
                    gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
                    FunctionLibraryRuntime::DoneCallback done) = 0;
+
+#if !defined(IS_MOBILE_PLATFORM)
+  // TODO(yujingzhang): Support outputting tensors on remote devices.
+  virtual void Run(const FunctionLibraryRuntime::Options& opts,
+                   FunctionLibraryRuntime::LocalHandle handle,
+                   const int64 op_id,
+                   absl::Span<eager::RemoteTensorHandle* const> args,
+                   FunctionLibraryRuntime::DoneCallback done) {
+    done(errors::Unimplemented("Unimplemented."));  // default for subclasses
+  }
+#endif  // IS_MOBILE_PLATFORM
+
   virtual void CleanUp(uint64 step_id,
                        FunctionLibraryRuntime::LocalHandle handle,
                        FunctionLibraryRuntime::DoneCallback done) = 0;
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 99534a1fa96..60b39a2b97a 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -122,6 +122,10 @@ message RegisterFunctionRequest {
   fixed64 context_id = 1;
 
   FunctionDef function_def = 2;
+
+  // If true, it means that function_def is produced by graph partition during
+  // multi-device function instantiation.
+  bool is_component_function = 3;
 }
 
 message RegisterFunctionResponse {}