From 550581f6bd90ade2e72bdd6af843873c5b92861d Mon Sep 17 00:00:00 2001 From: Bramandia Ramadhana Date: Wed, 20 May 2020 08:50:21 -0700 Subject: [PATCH] When calling connect_to_cluster, if the options are identical and there is no renaming of local device, reuse existing local DeviceManager, otherwise we keep the old DeviceManager around to allow the old Tensor created to be usable. PiperOrigin-RevId: 312489501 Change-Id: Id392d0324aba7e7f9e92f8efeaf33683157470e1 --- tensorflow/c/eager/c_api.cc | 18 ++++- tensorflow/c/eager/c_api_cluster_test.cc | 68 ++++++++++++++++++- tensorflow/c/experimental/network.cc | 2 +- .../core/common_runtime/eager/context.cc | 17 +++-- .../core/common_runtime/eager/context.h | 11 ++- .../eager/eager_service_impl.cc | 4 +- .../core/distributed_runtime/graph_mgr.cc | 2 +- .../core/distributed_runtime/graph_mgr.h | 4 +- .../rpc/grpc_server_lib.cc | 43 +++++++----- .../distributed_runtime/rpc/grpc_server_lib.h | 8 +++ .../core/distributed_runtime/server_lib.cc | 12 +++- .../core/distributed_runtime/server_lib.h | 11 ++- .../distributed_runtime/server_lib_test.cc | 2 +- .../core/distributed_runtime/session_mgr.cc | 2 +- tensorflow/core/distributed_runtime/worker.cc | 2 +- .../core/distributed_runtime/worker_env.h | 2 +- .../distributed_runtime/worker_session.cc | 4 +- .../core/distributed_runtime/worker_session.h | 10 +-- 18 files changed, 174 insertions(+), 48 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 912cd184b77..5a39c17e1d9 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) { } #if !defined(IS_MOBILE_PLATFORM) +bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context, + const tensorflow::ServerDef& server_def) { + if (server_def.job_name() != context->HostCPU()->parsed_name().job) { + return false; + } + return server_def.default_session_config().SerializeAsString() == +
context->session_options().config.SerializeAsString(); +} + tensorflow::Status AddRemoteDevicesToMgr( const std::vector& added_remote_workers, tensorflow::WorkerCacheInterface* worker_cache, @@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); tensorflow::GrpcServer* grpc_server; if (reset_context) { - LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server)); + const tensorflow::DeviceMgr* device_mgr = + AreLocalDevicesCompatible(context, server_def) + ? context->local_device_mgr() + : nullptr; + LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions( + server_def, {device_mgr}, &new_server)); grpc_server = dynamic_cast(new_server.get()); LOG_AND_RETURN_IF_ERROR( - ListRemoteWorkers(grpc_server, worker_name, &remote_workers)); + ListRemoteWorkers(new_server.get(), worker_name, &remote_workers)); } else { LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name, &curr_remote_workers)); diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc index 252a0408758..f8c702d592a 100644 --- a/tensorflow/c/eager/c_api_cluster_test.cc +++ b/tensorflow/c/eager/c_api_cluster_test.cc @@ -41,7 +41,7 @@ tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) { for (int i = 0; i < num_tasks; i++) { int port = tensorflow::testing::PickUnusedPortOrDie(); job_def->mutable_tasks()->insert( - {i, tensorflow::strings::StrCat("localhost:", port)}); + {i, tensorflow::strings::StrCat("localhost", ":", port)}); } return server_def; } @@ -430,4 +430,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) { TestRemoteExecuteUpdateServerDefWithFailures(true); } +void TestConnectToCluster(bool keep_localhost_for_first_connect) { + // Fail fast on GetStatus requests so we can get errors instead of timeout + // when updating cluster with non-existent worker + tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", +
/*overwrite=*/1); + + const string first_name = + keep_localhost_for_first_connect ? "localhost" : "abc"; + tensorflow::ServerDef server_def = GetServerDef(first_name, 1); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0"; + TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name); + EXPECT_NE(var_handle0, nullptr); + + tensorflow::Status status2; + EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name); + + // Rename local device + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const string dev1_name = + absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0"); + TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name); + EXPECT_NE(var_handle1, nullptr); + EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name); + + // Another renaming of local device + const string second_name = "def"; + server_def.set_job_name(second_name); + server_def.mutable_cluster()->mutable_job(0)->set_name(second_name); + (*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] = + absl::StrCat(second_name, ":", + tensorflow::testing::PickUnusedPortOrDie()); + + serialized = server_def.SerializeAsString(); + TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0"; + TFE_TensorHandle* var_handle2 = TestVariable(ctx, 
2.0, dev2_name); + EXPECT_NE(var_handle2, nullptr); + EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name); + + TFE_DeleteTensorHandle(var_handle0); + TFE_DeleteTensorHandle(var_handle1); + TFE_DeleteTensorHandle(var_handle2); + + TFE_DeleteContext(ctx); + TF_DeleteStatus(status); + + tensorflow::unsetenv("GRPC_FAIL_FAST"); +} + +TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); } + +TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); } + } // namespace diff --git a/tensorflow/c/experimental/network.cc b/tensorflow/c/experimental/network.cc index 94375cf9983..97e63ec6259 100644 --- a/tensorflow/c/experimental/network.cc +++ b/tensorflow/c/experimental/network.cc @@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory { delete_function_(delete_function), rendezvous_builder_(rendezvous_builder) {} - Status NewServer(const ServerDef& server_def, + Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr* out_server) override { TF_RETURN_IF_ERROR(CGrpcServer::Create( server_def, init_function_, start_function_, stop_function_, diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 207c6a02d5b..1024f3caabd 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -81,7 +81,8 @@ EagerContext::EagerContext( bool device_mgr_owned, Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator, DistributedFunctionLibraryRuntime* cluster_flr) - : default_device_placement_policy_(default_device_placement_policy), + : opts_(opts), + default_device_placement_policy_(default_device_placement_policy), default_mirroring_policy_(default_mirroring_policy), local_device_manager_(device_mgr, device_mgr_owned), host_cpu_device_(device_mgr->HostCPU()), @@ -1051,7 +1052,7 @@ void EagerContext::IncrementContextViewId() { // Set collective ops related state in 
the context. Passing nullptr to // `new_server` will reuse the existing GRPC server in context. Status EagerContext::StoreCollectiveOpsServer( - std::unique_ptr new_server, DeviceMgr* device_mgr, + std::unique_ptr new_server, const DeviceMgr* device_mgr, CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) { collective_executor_mgr_.Reset(rpc_collective_executor_mgr); @@ -1176,7 +1177,7 @@ Status EagerContext::InitializeRemoteMaster( std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_manager, const std::vector& remote_contexts, uint64 context_id, - Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, + Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr> remote_mgr) { @@ -1275,7 +1276,7 @@ Status EagerContext::SetMasterContextState( std::shared_ptr worker_session, std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_manager, uint64 context_id, - uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr, + uint64 context_view_id, Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr> remote_mgr) { @@ -1287,7 +1288,13 @@ Status EagerContext::SetMasterContextState( use_send_tensor_rpc_ = ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true); - local_device_manager_.Reset(local_device_mgr); + if (local_device_mgr != local_device_manager_.Get()) { + if (local_device_manager_.Owned()) { + old_local_device_managers_.push_back( + std::move(local_device_manager_.owned_object)); + } + local_device_manager_.Reset(local_device_mgr); + } host_cpu_device_ = local_device_manager_.Get()->HostCPU(); if (rendezvous_ != nullptr) rendezvous_->Unref(); diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index d03a91c817a..cceb883a965 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ 
b/tensorflow/core/common_runtime/eager/context.h @@ -399,7 +399,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_manager, const std::vector& remote_contexts, uint64 context_id, - Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, + Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr> remote_mgr); @@ -436,7 +436,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { const std::vector& remote_contexts, uint64 context_id); Status StoreCollectiveOpsServer( - std::unique_ptr new_server, DeviceMgr* device_mgr, + std::unique_ptr new_server, const DeviceMgr* device_mgr, CollectiveExecutorMgrInterface* rpc_collective_executor_mgr); // For the specified remote worker, preprocess and set its device filters. @@ -510,6 +510,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { // Gets the CPU device on the task of device. Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; + const SessionOptions& session_options() const { return opts_; } + private: ~EagerContext() override; @@ -563,6 +565,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { T* unowned_object_ptr = nullptr; }; + SessionOptions opts_; const ContextDevicePlacementPolicy default_device_placement_policy_; const ContextMirroringPolicy default_mirroring_policy_; @@ -575,6 +578,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { TF_GUARDED_BY(policy_map_mu_); OwnedOrUnownedHelper local_device_manager_; + // Maintain copy of all previously created local device managers. + std::vector> old_local_device_managers_; // Unowned DynamicDeviceMgr is set on remote worker to allow running // multi-device function on remote worker. 
@@ -662,7 +667,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted { std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_manager, uint64 context_id, uint64 context_view_id, Rendezvous* r, - DeviceMgr* local_device_mgr, int keep_alive_secs, + const DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr, std::unique_ptr> remote_mgr); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index 6dc03cbc527..5327cbb6480 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -238,7 +238,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( session_name, &worker_session)); - tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); + const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); // Initialize remote tensor communication based on worker session. TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); @@ -355,7 +355,7 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession( session_name, &worker_session)); - tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); + const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr(); std::vector remote_workers; worker_session->worker_cache()->ListWorkers(&remote_workers); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 8b363e66d87..fe353d7d76c 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -55,7 +55,7 @@ limitations under the License. 
namespace tensorflow { -GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr) +GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr) : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) { // The default value of sync_on_finish will be flipped soon and this // environment variable will be removed as well. diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 50190ab337e..e768c0907b6 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -69,7 +69,7 @@ class WorkerSession; // EXPECT_EQ(out["c"], Tensor({4, 6})); class GraphMgr { public: - explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr); + explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr); ~GraphMgr(); // Registers a graph. Fills in "handle". The registered graph retains a @@ -145,7 +145,7 @@ class GraphMgr { }; const WorkerEnv* worker_env_; // Not owned. - DeviceMgr* device_mgr_; + const DeviceMgr* device_mgr_; CostModelManager cost_model_manager_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 754209082fd..6523d2fb4dd 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -130,9 +130,6 @@ GrpcServer::~GrpcServer() { // OpSegments.) if (worker_env_.session_mgr != nullptr) { delete worker_env_.session_mgr; // Deletes graph_mgr's. - } else { - // Note: session_mgr's legacy_session_ deletes device_mgr now. 
- delete worker_env_.device_mgr; } // Do not delete (as these are not owned by the server): @@ -204,12 +201,18 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { string name_prefix = strings::StrCat("/job:", server_def_.job_name(), "/replica:0", "/task:", server_def_.task_index()); - std::vector> devices; - TF_RETURN_IF_ERROR( - DeviceFactory::AddDevices(sess_opts, name_prefix, &devices)); - worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices)); - master_env_.local_devices = worker_env_.device_mgr->ListDevices(); + if (opts.local_device_mgr == nullptr) { + std::vector> devices; + TF_RETURN_IF_ERROR( + DeviceFactory::AddDevices(sess_opts, name_prefix, &devices)); + worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices)); + owned_device_manager_.reset(worker_env_.device_mgr); + } else { + worker_env_.device_mgr = opts.local_device_mgr; + owned_device_manager_.reset(nullptr); + } worker_env_.local_devices = worker_env_.device_mgr->ListDevices(); + master_env_.local_devices = worker_env_.device_mgr->ListDevices(); worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr ? new RpcRendezvousMgr(&worker_env_) : opts.rendezvous_mgr_func(&worker_env_); @@ -527,12 +530,13 @@ std::unique_ptr GrpcServer::CreateMaster(MasterEnv* master_env) { /* static */ Status GrpcServer::Create(const ServerDef& server_def, Env* env, + const DeviceMgr* local_device_mgr, std::unique_ptr* out_server) { std::unique_ptr ret( new GrpcServer(server_def, env == nullptr ? 
Env::Default() : env)); - ServiceInitFunction service_func = nullptr; GrpcServerOptions options; options.rendezvous_mgr_func = NewRpcRendezvousMgr; + options.local_device_mgr = local_device_mgr; Status s = ret->Init(options); if (!s.ok()) { LOG(ERROR) << s; @@ -542,19 +546,21 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, return Status::OK(); } +/* static */ +Status GrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server) { + return Create(server_def, env, nullptr, out_server); +} + /* static */ Status GrpcServer::Create(const ServerDef& server_def, Env* env, std::unique_ptr* out_server) { - std::unique_ptr ret( - new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); - GrpcServerOptions options; - options.rendezvous_mgr_func = NewRpcRendezvousMgr; - Status s = ret->Init(options); + std::unique_ptr server; + Status s = Create(server_def, env, nullptr, &server); if (!s.ok()) { - LOG(ERROR) << s; return s; } - *out_server = std::move(ret); + out_server->reset(dynamic_cast(server.release())); return Status::OK(); } @@ -566,9 +572,10 @@ class GrpcServerFactory : public ServerFactory { return server_def.protocol() == "grpc"; } - Status NewServer(const ServerDef& server_def, + Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr* out_server) override { - return GrpcServer::Create(server_def, Env::Default(), out_server); + return GrpcServer::Create(server_def, Env::Default(), + options.local_device_mgr, out_server); } }; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index b3fa7d1f303..0474c5a517f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -68,11 +68,14 @@ struct GrpcServerOptions { WorkerCreationFunction worker_func = nullptr; StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher; 
GrpcWorkerServiceOptions worker_service_options; + const DeviceMgr* local_device_mgr = nullptr; }; class GrpcServer : public ServerInterface { protected: GrpcServer(const ServerDef& server_def, Env* env); + GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr, + Env* env); // Allow children classes to override this and provide custom args to the // server before it is constructed. Default behavior is to do nothing. virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder); @@ -82,6 +85,10 @@ class GrpcServer : public ServerInterface { std::unique_ptr* out_server); static Status Create(const ServerDef& server_def, Env* env, std::unique_ptr* out_server); + // Reuse the local_device_mgr. + static Status Create(const ServerDef& server_def, Env* env, + const DeviceMgr* local_device_mgr, + std::unique_ptr* out_server); // Destruction is only supported in the factory method. Clean // shutdown is not currently implemented for this server type. @@ -163,6 +170,7 @@ class GrpcServer : public ServerInterface { // Implementation of a TensorFlow worker, and RPC polling thread. WorkerEnv worker_env_; + std::unique_ptr owned_device_manager_; std::unique_ptr worker_impl_; AsyncServiceInterface* worker_service_ = nullptr; std::unique_ptr worker_thread_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/distributed_runtime/server_lib.cc b/tensorflow/core/distributed_runtime/server_lib.cc index 62a2011db39..12baa75976a 100644 --- a/tensorflow/core/distributed_runtime/server_lib.cc +++ b/tensorflow/core/distributed_runtime/server_lib.cc @@ -73,7 +73,17 @@ Status NewServer(const ServerDef& server_def, std::unique_ptr* out_server) { ServerFactory* factory; TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory)); - return factory->NewServer(server_def, out_server); + return factory->NewServer(server_def, ServerFactory::Options(), out_server); +} + +// Creates a server based on the given `server_def`, and stores it in +// `*out_server`. 
Returns OK on success, otherwise returns an error. +Status NewServerWithOptions(const ServerDef& server_def, + const ServerFactory::Options& options, + std::unique_ptr* out_server) { + ServerFactory* factory; + TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory)); + return factory->NewServer(server_def, options, out_server); } } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h index 275f526d311..7b4b4892848 100644 --- a/tensorflow/core/distributed_runtime/server_lib.h +++ b/tensorflow/core/distributed_runtime/server_lib.h @@ -24,6 +24,8 @@ limitations under the License. namespace tensorflow { +class DeviceMgr; + // This library supports a registration/factory-based mechanism for // creating TensorFlow server objects. Each server implementation must // have an accompanying implementation of ServerFactory, and create a @@ -63,10 +65,14 @@ class ServerInterface { class ServerFactory { public: + struct Options { + // Local DeviceMgr to use. + const tensorflow::DeviceMgr* local_device_mgr; + }; // Creates a new server based on the given `server_def`, and stores // it in `*out_server`. Returns OK on success, otherwise returns an // error. - virtual Status NewServer(const ServerDef& server_def, + virtual Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr* out_server) = 0; // Returns true if and only if this factory can create a server @@ -92,6 +98,9 @@ class ServerFactory { // `*out_server`. Returns OK on success, otherwise returns an error. 
Status NewServer(const ServerDef& server_def, std::unique_ptr* out_server); +Status NewServerWithOptions(const ServerDef& server_def, + const ServerFactory::Options& options, + std::unique_ptr* out_server); } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/server_lib_test.cc b/tensorflow/core/distributed_runtime/server_lib_test.cc index 77048c24b47..2152ff986d6 100644 --- a/tensorflow/core/distributed_runtime/server_lib_test.cc +++ b/tensorflow/core/distributed_runtime/server_lib_test.cc @@ -26,7 +26,7 @@ class TestServerFactory : public ServerFactory { return server_def.protocol() == "test_protocol"; } - Status NewServer(const ServerDef& server_def, + Status NewServer(const ServerDef& server_def, const Options& options, std::unique_ptr* out_server) override { return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index e2151e068f6..1d9a22a5817 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -171,7 +171,7 @@ Status SessionMgr::UpdateSession( std::vector> cluster_devices; - DeviceMgr* local_device_mgr = worker_session->device_mgr(); + const DeviceMgr* local_device_mgr = worker_session->device_mgr(); DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr(); std::vector curr_remote_devices = remote_device_mgr->ListDevices(); std::vector> added_remote_devices; diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 7850ecc46b2..f857a63e64d 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -38,7 +38,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) { void Worker::GetStatusAsync(const GetStatusRequest* request, GetStatusResponse* response, bool fail_fast, StatusCallback done) { - DeviceMgr* dm = env_->device_mgr; + const DeviceMgr* 
dm = env_->device_mgr; std::vector devices; dm->ListDeviceAttributes(&devices); response->mutable_device_attributes()->Reserve(devices.size()); diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h index 93d933bfa60..ecc3313d0ce 100644 --- a/tensorflow/core/distributed_runtime/worker_env.h +++ b/tensorflow/core/distributed_runtime/worker_env.h @@ -53,7 +53,7 @@ struct WorkerEnv { // Note: Please use the device_mgr associated with your session if appropriate // instead of this one. Using this device_mgr does not support ClusterSpec // propagated sessions. - DeviceMgr* device_mgr = nullptr; + const DeviceMgr* device_mgr = nullptr; // A set of rendezvous keyed by step ids. RendezvousMgrInterface* rendezvous_mgr = nullptr; diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index ca4f25f08f5..3aed73fa358 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -144,7 +144,7 @@ Status WorkerSession::UpdateWorkerCacheAndDevices( std::shared_ptr WorkerSession::CreateWithBorrowedDeviceMgr( const string& session_name, const string& worker_name, std::unique_ptr worker_cache, - DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr) { return std::shared_ptr(new WorkerSession( session_name, worker_name, std::move(worker_cache), borrowed_device_mgr, @@ -154,7 +154,7 @@ std::shared_ptr WorkerSession::CreateWithBorrowedDeviceMgr( WorkerSession::WorkerSession( const string& session_name, const string& worker_name, std::unique_ptr worker_cache, - DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr) : session_name_(session_name), worker_name_(worker_name), diff --git 
a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index 3b2d1122558..f870a8c064b 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -37,7 +37,7 @@ class WorkerSession { // sessions created with `isolate_session_state == false`. In the // those cases, this method returns a pointer to a borrowed // DeviceMgr (typically the `worker_env.device_mgr`). - DeviceMgr* device_mgr() { + const DeviceMgr* device_mgr() { return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_; } @@ -65,7 +65,7 @@ class WorkerSession { static std::shared_ptr CreateWithBorrowedDeviceMgr( const string& session_name, const string& worker_name, std::unique_ptr worker_cache, - DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr); // In the eager runtime we allow WorkerSession to be updated, where the @@ -90,7 +90,7 @@ class WorkerSession { private: WorkerSession(const string& session_name, const string& worker_name, std::unique_ptr worker_cache, - DeviceMgr* borrowed_device_mgr, + const DeviceMgr* borrowed_device_mgr, std::unique_ptr graph_mgr, std::unique_ptr remote_device_mgr); @@ -113,8 +113,8 @@ class WorkerSession { std::unique_ptr cluster_flr_; - const std::unique_ptr device_mgr_; - DeviceMgr* const borrowed_device_mgr_; // Not owned. + const std::unique_ptr device_mgr_; + const DeviceMgr* const borrowed_device_mgr_; // Not owned. std::unique_ptr remote_device_mgr_; };