From 731984bfd0f642f91c24781eb5fd09fa05268a10 Mon Sep 17 00:00:00 2001 From: Yujing Zhang <yujingzhang@google.com> Date: Thu, 12 Sep 2019 14:47:10 -0700 Subject: [PATCH] Introduce EagerClusterFunctionLibraryRuntime. - Allow inputs on remote devices. - Run remote functions through eager service instead of worker service. This cl has no behavior change since EagerClusterFunctionLibraryRuntime is not in use. PiperOrigin-RevId: 268770768 --- .../core/common_runtime/eager/context.cc | 7 +- .../core/common_runtime/eager/context.h | 3 +- .../core/distributed_runtime/eager/BUILD | 26 +++ .../eager/cluster_function_library_runtime.cc | 160 +++++++++++++++++ .../eager/cluster_function_library_runtime.h | 87 ++++++++++ .../eager/eager_service_impl.cc | 5 +- .../eager/eager_service_impl_test.cc | 161 +++++++++++++++++- tensorflow/core/framework/function.h | 20 +++ tensorflow/core/protobuf/eager_service.proto | 4 + 9 files changed, 466 insertions(+), 7 deletions(-) create mode 100644 tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc create mode 100644 tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 5515a3ab09b..b4fda6bda50 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -413,7 +413,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { return Status::OK(); } -Status EagerContext::AddFunctionDef(const FunctionDef& fdef) { +Status EagerContext::AddFunctionDef(const FunctionDef& fdef, + const bool add_to_local_only) { bool is_first_ref = false; { mutex_lock l(cache_mu_); @@ -432,7 +433,9 @@ Status EagerContext::AddFunctionDef(const FunctionDef& fdef) { } if (is_first_ref) { TF_RETURN_IF_ERROR(func_lib_def_.AddFunctionDef(fdef)); - return MaybeRegisterFunctionRemotely(fdef); + if (!add_to_local_only) { + return MaybeRegisterFunctionRemotely(fdef); + } } return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 4e744dd179e..47080742126 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -185,7 +185,8 @@ class EagerContext : public core::RefCounted { EagerExecutor& Executor(); - Status AddFunctionDef(const FunctionDef& fdef); + Status AddFunctionDef(const FunctionDef& fdef, + const bool add_to_local_only = false); Status RemoveFunction(const string& func); diff --git a/tensorflow/core/distributed_runtime/eager/BUILD b/tensorflow/core/distributed_runtime/eager/BUILD index 34fc44af097..235a84c6d15 100644 --- a/tensorflow/core/distributed_runtime/eager/BUILD +++ b/tensorflow/core/distributed_runtime/eager/BUILD @@ -21,6 +21,29 @@ cc_library( ], ) +cc_library( + name = "cluster_function_library_runtime", + srcs = [ + "cluster_function_library_runtime.cc", + ], + hdrs = [ + "cluster_function_library_runtime.h", + ], + deps = [ + ":eager_client", + ":remote_execute_node", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:eager_operation", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "destroy_tensor_handle_node", hdrs = ["destroy_tensor_handle_node.h"], @@ -64,6 +87,7 @@ cc_library( "eager_service_impl.h", ], deps = [ + ":cluster_function_library_runtime", ":remote_mgr", ":remote_tensor_handle", "//tensorflow:grpc++", @@ -95,6 +119,7 @@ tf_cc_test( name = "eager_service_impl_test", srcs = ["eager_service_impl_test.cc"], deps = [ + ":cluster_function_library_runtime", ":eager_service_impl", ":remote_mgr", "//tensorflow/c:c_api", @@ -111,6 +136,7 @@ tf_cc_test( "//tensorflow/core/distributed_runtime:test_utils", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc new file mode 100644 index 00000000000..6c143e93fe2 --- /dev/null +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" + +#include <map> + +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" +#include "tensorflow/core/distributed_runtime/eager/remote_execute_node.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" + +namespace tensorflow { +namespace eager { + +Status EagerClusterFunctionLibraryRuntime::Instantiate( + const string& function_name, const FunctionLibraryDefinition& lib_def, + AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle) { + const tensorflow::AttrTypeMap* attr_types; + bool is_function = false; + TF_RETURN_IF_ERROR(tensorflow::AttrTypeMapForOp(function_name.c_str(), + &attr_types, &is_function)); + if (!is_function) { + return errors::Internal(function_name, " is not a function."); + } + auto op = absl::make_unique<EagerOperation>(ctx_, function_name.c_str(), + is_function, attr_types); + TF_RETURN_IF_ERROR(op->SetDeviceName(options.target.c_str())); + + VLOG(1) << "CFLR::Instantiate: " << function_name << " on " << options.target + << " (this: " << this << ")"; + eager::EagerClient* eager_client = nullptr; + Device* device; + TF_RETURN_IF_ERROR(ctx_->FindDeviceFromName(options.target.c_str(), &device)); + TF_RETURN_IF_ERROR(ctx_->GetClient(device, &eager_client)); + + if (eager_client == nullptr) { + return errors::InvalidArgument("Could not find eager client for target: ", + options.target); + } + + const FunctionLibraryDefinition& func_lib_def = + options.lib_def ? *options.lib_def : lib_def; + + RegisterFunctionRequest request; + const uint64 context_id = ctx_->GetContextId(); + request.set_context_id(context_id); + // TODO(yujingzhang): add FunctionDefLibrary to RegisterFunctionRequest to + // support nested functions. + *request.mutable_function_def() = *func_lib_def.Find(function_name); + request.set_is_component_function(true); + + Status status; + Notification done; + RegisterFunctionResponse response; + eager_client->RegisterFunctionAsync(&request, &response, [&](Status s) { + status = s; + done.Notify(); + }); + done.WaitForNotification(); + TF_RETURN_IF_ERROR(status); + + mutex_lock l(mu_); + *handle = function_data_.size(); + function_data_.emplace_back(options.target, context_id, eager_client, + std::move(op)); + return Status::OK(); +} + +void EagerClusterFunctionLibraryRuntime::Run( + const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args, + std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) { + done(errors::Unimplemented("Not implemented")); +} + +void EagerClusterFunctionLibraryRuntime::Run( + const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, const int64 op_id, + absl::Span<eager::RemoteTensorHandle* const> args, + FunctionLibraryRuntime::DoneCallback done) { + FunctionData* function_data = nullptr; + { + mutex_lock l(mu_); + DCHECK_LE(handle, function_data_.size()); + function_data = &function_data_[handle]; + } + + EagerClient* eager_client = function_data->eager_client; + if (eager_client == nullptr) { + done(errors::Internal("Could not find eager client")); + return; + } + + Device* device; + Status s = ctx_->FindDeviceFromName(function_data->target.c_str(), &device); + if (!s.ok()) { + done(errors::Internal("Failed to get device")); + return; + } + + EagerOperation* op = function_data->op.get(); + + eager::EnqueueRequest* request = new eager::EnqueueRequest; + request->set_context_id(function_data->context_id); + eager::Operation* remote_op = request->add_queue()->mutable_operation(); + for (size_t i = 0; i < args.size(); ++i) { + remote_op->add_inputs()->Swap(args[i]); + } + // TODO(yujingzhang): add step_id to eager::Operation to make sure that all + // component functions use the same step id. + // The remote component function should use the same op_id as its parent + // multi-device function's in order to get the global unqiue op_id generated + // by the master context. + remote_op->set_id(op_id); + remote_op->set_name(op->Name()); + op->Attrs().FillAttrValueMap(remote_op->mutable_attrs()); + remote_op->set_device(function_data->target); + + for (auto handle : op->Inputs()) { + handle->Ref(); + } + + // TODO(yujingzhang): Use RemoteExecuteNode once we enable async execution. + EnqueueResponse* response = new EnqueueResponse; + eager_client->EnqueueAsync(request, response, + [op, request, response, done](const Status& s) { + for (auto handle : op->Inputs()) { + handle->Unref(); + } + done(s); + delete request; + delete response; + }); +} + +void EagerClusterFunctionLibraryRuntime::CleanUp( + uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, + FunctionLibraryRuntime::DoneCallback done) { + done(Status::OK()); +} + +} // namespace eager +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h new file mode 100644 index 00000000000..56a7ee189fe --- /dev/null +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h @@ -0,0 +1,87 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ + +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/eager_operation.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" + +namespace tensorflow { + +struct WorkerSession; + +namespace eager { + +// EagerClusterFunctionLibraryRuntime contains methods to Instantiate and Run +// functions across processes by making RPCs through eager service. +class EagerClusterFunctionLibraryRuntime + : public DistributedFunctionLibraryRuntime { + public: + EagerClusterFunctionLibraryRuntime(EagerContext* ctx, + DeviceMgr* remote_device_mgr) + : ctx_(ctx), remote_device_mgr_(remote_device_mgr) {} + + ~EagerClusterFunctionLibraryRuntime() override{}; + + Status Instantiate(const string& function_name, + const FunctionLibraryDefinition& lib_def, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::LocalHandle* handle) override; + + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, + FunctionLibraryRuntime::DoneCallback done) override; + + void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, const int64 op_id, + absl::Span<eager::RemoteTensorHandle* const> args, + FunctionLibraryRuntime::DoneCallback done) override; + + void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, + FunctionLibraryRuntime::DoneCallback done) override; + + DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; } + + private: + EagerContext* ctx_; + DeviceMgr* remote_device_mgr_; // not owned. + + struct FunctionData { + const string target; + const uint64 context_id; + EagerClient* eager_client = nullptr; + std::unique_ptr<EagerOperation> op; + + FunctionData(const string& target, const uint64 context_id, + EagerClient* eager_client, std::unique_ptr<EagerOperation> op) + : target(target), + context_id(context_id), + eager_client(eager_client), + op(std::move(op)) {} + }; + + mutable mutex mu_; + std::vector<FunctionData> function_data_ GUARDED_BY(mu_); +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_CLUSTER_FUNCTION_LIBRARY_RUNTIME_H_ diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 1c90ef96cc8..26e33634404 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -367,7 +367,10 @@ Status EagerServiceImpl::RegisterFunction( TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); core::ScopedUnref context_unref(context); - return context->Context()->AddFunctionDef(request->function_def()); + // If the function is a component of a multi-device function, we only need to + // register it locally. + return context->Context()->AddFunctionDef(request->function_def(), + request->is_component_function()); } Status EagerServiceImpl::SendTensor(const SendTensorRequest* request, diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index d278f56b99c..7f0915a471f 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -17,8 +17,10 @@ limitations under the License. #include <string.h> +#include "absl/types/span.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.h" #include "tensorflow/core/distributed_runtime/eager/remote_mgr.h" #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h" #include "tensorflow/core/distributed_runtime/session_mgr.h" @@ -42,6 +44,13 @@ namespace { class TestEagerServiceImpl : public EagerServiceImpl { public: explicit TestEagerServiceImpl(const WorkerEnv* env) : EagerServiceImpl(env) {} + Status GetEagerContext(const uint64 context_id, EagerContext** ctx) { + ServerContext* context = nullptr; + TF_RETURN_IF_ERROR(GetServerContext(context_id, &context)); + core::ScopedUnref context_unref(context); + *ctx = context->Context(); + return Status::OK(); + } Status GetTensorHandle(const uint64 context_id, const RemoteTensorHandleInternal& remote_handle, tensorflow::TensorHandle** handle) { @@ -54,10 +63,48 @@ class TestEagerServiceImpl : public EagerServiceImpl { } }; -class DummyEagerClientCache : public EagerClientCache { - Status GetClient(const string& target, EagerClient** client) override { - return errors::Unimplemented(""); +class FakeEagerClient : public EagerClient { + public: + FakeEagerClient() {} + ~FakeEagerClient() override {} + + void SetServiceImpl(TestEagerServiceImpl* impl) { impl_ = impl; } + +#define CLIENT_METHOD(method) \ + void method##Async(const method##Request* request, \ + method##Response* response, StatusCallback done) \ + override { \ + done(impl_->method(request, response)); \ } + + CLIENT_METHOD(CreateContext); + CLIENT_METHOD(Enqueue); + CLIENT_METHOD(WaitQueueDone); + CLIENT_METHOD(KeepAlive); + CLIENT_METHOD(CloseContext); + CLIENT_METHOD(RegisterFunction); +#undef CLIENT_METHOD + + void StreamingEnqueueAsync(const EnqueueRequest* request, + EnqueueResponse* response, + StatusCallback done) override { + done(errors::Unimplemented("")); + } + + private: + TestEagerServiceImpl* impl_; +}; + +class DummyEagerClientCache : public EagerClientCache { + public: + DummyEagerClientCache() : client_(new FakeEagerClient) {} + Status GetClient(const string& target, EagerClient** client) override { + *client = client_.get(); + return Status::OK(); + } + + private: + std::unique_ptr<EagerClient> client_; }; class FakeCache : public TestWorkerCache { @@ -66,6 +113,10 @@ class FakeCache : public TestWorkerCache { eager_client_cache->reset(new DummyEagerClientCache); return Status::OK(); } + + void ListWorkers(std::vector<string>* workers) const override { + workers->push_back("/job:localhost/replica:0/task:0"); + } }; class EagerServiceImplTest : public ::testing::Test { @@ -311,6 +362,110 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) { &close_context_response)); } +// Test executes a function through EagerClusterFunctionLibraryRuntime. +TEST_F(EagerServiceImplTest, ClusterFLRTest) { + TestEagerServiceImpl eager_service_impl(&worker_env_); + + uint64 context_id = random::New64(); + + CreateContextRequest request; + request.mutable_server_def()->set_job_name("localhost"); + request.mutable_server_def()->set_task_index(0); + request.set_context_id(context_id); + CreateContextResponse response; + TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); + + const string target_device = "/job:localhost/replica:0/task:0/device:CPU:0"; + + // Make the fake EagerClient use the local eager_service_impl. + EagerContext* ctx = nullptr; + TF_ASSERT_OK(eager_service_impl.GetEagerContext(context_id, &ctx)); + Device* device; + TF_ASSERT_OK(ctx->FindDeviceFromName(target_device.c_str(), &device)); + EagerClient* client; + TF_ASSERT_OK(ctx->GetClient(device, &client)); + FakeEagerClient* fake_client = static_cast<FakeEagerClient*>(client); + fake_client->SetServiceImpl(&eager_service_impl); + + auto eager_cluster_flr = + absl::make_unique<EagerClusterFunctionLibraryRuntime>(ctx, nullptr); + tensorflow::FunctionDef fdef = MatMulFunction(); + + // Create the remote input for MatMulFunction. + EnqueueRequest remote_enqueue_request; + remote_enqueue_request.set_context_id(context_id); + EnqueueResponse remote_enqueue_response; + std::unordered_map<string, AttrValue> const_attrs; + AttrValue val; + val.set_type(tensorflow::DataType::DT_FLOAT); + const_attrs.insert({"dtype", val}); + val.Clear(); + SetTensorProto(val.mutable_tensor()); + const_attrs.insert({"value", val}); + AddOperationToEnqueueRequest(1, "Const", {}, const_attrs, target_device, + &remote_enqueue_request); + TF_ASSERT_OK(eager_service_impl.Enqueue(&remote_enqueue_request, + &remote_enqueue_response)); + + // Instantiate MatMulFunction. + FunctionLibraryRuntime::InstantiateOptions options; + options.target = target_device; + options.is_multi_device_function = true; + options.input_devices.push_back(target_device); + FunctionLibraryRuntime::Handle handle; + FunctionLibraryDefinition func_lib_def{OpRegistry::Global(), {}}; + TF_ASSERT_OK(func_lib_def.AddFunctionDef(fdef)); + TF_ASSERT_OK(eager_cluster_flr->Instantiate( + fdef.signature().name(), func_lib_def, AttrSlice(&fdef.attr()), options, + &handle)); + + // Run MatMulFunction. + FunctionLibraryRuntime::Options opts; + const int64 step_id = opts.step_id; + Notification done; + Status status; + RemoteTensorHandle input; + input.set_op_id(1); + input.set_output_num(0); + input.set_op_device(target_device); + input.set_device(target_device); + eager_cluster_flr->Run(opts, handle, 2, {&input}, + [&status, &done](const Status& s) { + status = s; + done.Notify(); + }); + done.WaitForNotification(); + TF_ASSERT_OK(status); + + const tensorflow::Tensor* t = nullptr; + tensorflow::TensorHandle* tensor_handle; + TF_ASSERT_OK(eager_service_impl.GetTensorHandle( + context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); + TF_ASSERT_OK(tensor_handle->Tensor(&t)); + auto actual = t->flat<float>(); + EXPECT_EQ(4, actual.size()); + EXPECT_EQ(7, actual(0)); + EXPECT_EQ(10, actual(1)); + EXPECT_EQ(15, actual(2)); + EXPECT_EQ(22, actual(3)); + + Status cleanup_status; + bool callback_is_called = false; + eager_cluster_flr->CleanUp( + step_id, handle, [&cleanup_status, &callback_is_called](const Status& s) { + callback_is_called = true; + cleanup_status.Update(s); + }); + EXPECT_TRUE(callback_is_called); + TF_ASSERT_OK(cleanup_status); + + CloseContextRequest close_context_request; + close_context_request.set_context_id(context_id); + CloseContextResponse close_context_response; + TF_ASSERT_OK(eager_service_impl.CloseContext(&close_context_request, + &close_context_response)); +} + // Test creates a context and attempts to send a tensor (using the RPC), and // then use the tensor. TEST_F(EagerServiceImplTest, SendTensorTest) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 476c6055801..a6efad47fa1 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -18,6 +18,11 @@ limitations under the License. #include <vector> +// clang-format off +// Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/platform/platform.h" +// clang-format on + #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function.pb.h" @@ -34,6 +39,9 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/protobuf/config.pb.h" +#if !defined(IS_MOBILE_PLATFORM) +#include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" +#endif // IS_MOBILE_PLATFORM namespace tensorflow { @@ -816,6 +824,18 @@ class DistributedFunctionLibraryRuntime { FunctionLibraryRuntime::LocalHandle handle, gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) = 0; + +#if !defined(IS_MOBILE_PLATFORM) + // TODO(yujingzhang): Support outputting tensors on remote devices. + virtual void Run(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::LocalHandle handle, + const int64 op_id, + absl::Span<eager::RemoteTensorHandle* const> args, + FunctionLibraryRuntime::DoneCallback done) { + done(errors::Unimplemented("Unimplemented.")); + } +#endif // IS_MOBILE_PLATFORM + virtual void CleanUp(uint64 step_id, FunctionLibraryRuntime::LocalHandle handle, FunctionLibraryRuntime::DoneCallback done) = 0; diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto index 99534a1fa96..60b39a2b97a 100644 --- a/tensorflow/core/protobuf/eager_service.proto +++ b/tensorflow/core/protobuf/eager_service.proto @@ -122,6 +122,10 @@ message RegisterFunctionRequest { fixed64 context_id = 1; FunctionDef function_def = 2; + + // If true, it means that function_def is produced by graph partition during + // multi-device function instantiation. + bool is_component_function = 3; } message RegisterFunctionResponse {}