diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 912cd184b77..5a39c17e1d9 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
 }
 
 #if !defined(IS_MOBILE_PLATFORM)
+bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
+                               const tensorflow::ServerDef& server_def) {
+  if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
+    return false;
+  }
+  return server_def.default_session_config().SerializeAsString() ==
+         context->session_options().config.SerializeAsString();
+}
+
 tensorflow::Status AddRemoteDevicesToMgr(
     const std::vector<string>& added_remote_workers,
     tensorflow::WorkerCacheInterface* worker_cache,
@@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
   tensorflow::GrpcServer* grpc_server;
   if (reset_context) {
-    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
+    const tensorflow::DeviceMgr* device_mgr =
+        AreLocalDevicesCompatible(context, server_def)
+            ? context->local_device_mgr()
+            : nullptr;
+    LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
+        server_def, {device_mgr}, &new_server));
     grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
     LOG_AND_RETURN_IF_ERROR(
-        ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
+        ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
   } else {
     LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
                                               &curr_remote_workers));
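For context, the net effect of the c_api.cc change is that `UpdateTFE_ContextWithServerDef` hands the eager context's existing local `DeviceMgr` to the newly created gRPC server whenever the job name and default session config are unchanged. A minimal caller-side sketch of that decision is below; `MakeServerForContext` is a hypothetical helper, not part of this patch, and it simply restates the file-local `AreLocalDevicesCompatible` check against the public APIs touched here.

```cpp
// Sketch only: mirrors the reuse decision made in c_api.cc above.
// MakeServerForContext is illustrative and not part of this patch.
#include <memory>

#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

tensorflow::Status MakeServerForContext(
    tensorflow::EagerContext* context, const tensorflow::ServerDef& server_def,
    std::unique_ptr<tensorflow::ServerInterface>* out_server) {
  // Reuse is only safe when the job name and session config are unchanged;
  // otherwise pass nullptr so GrpcServer::Init builds a fresh device set.
  const bool reuse =
      server_def.job_name() == context->HostCPU()->parsed_name().job &&
      server_def.default_session_config().SerializeAsString() ==
          context->session_options().config.SerializeAsString();
  const tensorflow::DeviceMgr* device_mgr =
      reuse ? context->local_device_mgr() : nullptr;
  return tensorflow::NewServerWithOptions(server_def, {device_mgr},
                                          out_server);
}
```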
diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc
index 252a0408758..f8c702d592a 100644
--- a/tensorflow/c/eager/c_api_cluster_test.cc
+++ b/tensorflow/c/eager/c_api_cluster_test.cc
@@ -41,7 +41,7 @@ tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
   for (int i = 0; i < num_tasks; i++) {
     int port = tensorflow::testing::PickUnusedPortOrDie();
     job_def->mutable_tasks()->insert(
-        {i, tensorflow::strings::StrCat("localhost:", port)});
+        {i, tensorflow::strings::StrCat("localhost", ":", port)});
   }
   return server_def;
 }
@@ -430,4 +430,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
   TestRemoteExecuteUpdateServerDefWithFailures(true);
 }
 
+void TestConnectToCluster(bool keep_localhost_for_first_connect) {
+  // Fail fast on GetStatus requests so we get errors instead of timeouts
+  // when updating the cluster with a non-existent worker.
+  tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
+
+  const string first_name =
+      keep_localhost_for_first_connect ? "localhost" : "abc";
+  tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
+
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
+  TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
+  EXPECT_NE(var_handle0, nullptr);
+
+  tensorflow::Status status2;
+  EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
+
+  // Rename the local device.
+  // This server def has the task index set to 0.
+  string serialized = server_def.SerializeAsString();
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  const string dev1_name =
+      absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
+  TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
+  EXPECT_NE(var_handle1, nullptr);
+  EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
+
+  // Rename the local device again.
+  const string second_name = "def";
+  server_def.set_job_name(second_name);
+  server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
+  (*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
+      absl::StrCat(second_name, ":",
+                   tensorflow::testing::PickUnusedPortOrDie());
+
+  serialized = server_def.SerializeAsString();
+  TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
+  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+
+  const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
+  TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
+  EXPECT_NE(var_handle2, nullptr);
+  EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
+
+  TFE_DeleteTensorHandle(var_handle0);
+  TFE_DeleteTensorHandle(var_handle1);
+  TFE_DeleteTensorHandle(var_handle2);
+
+  TFE_DeleteContext(ctx);
+  TF_DeleteStatus(status);
+
+  tensorflow::unsetenv("GRPC_FAIL_FAST");
+}
+
+TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
+
+TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
+
 }  // namespace
diff --git a/tensorflow/c/experimental/network.cc b/tensorflow/c/experimental/network.cc
index 94375cf9983..97e63ec6259 100644
--- a/tensorflow/c/experimental/network.cc
+++ b/tensorflow/c/experimental/network.cc
@@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
         delete_function_(delete_function),
         rendezvous_builder_(rendezvous_builder) {}
 
-  Status NewServer(const ServerDef& server_def,
+  Status NewServer(const ServerDef& server_def, const Options& options,
                    std::unique_ptr<ServerInterface>* out_server) override {
     TF_RETURN_IF_ERROR(CGrpcServer::Create(
        server_def, init_function_, start_function_, stop_function_,
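The `ServerFactory::NewServer` signature change is source-breaking for out-of-tree factories such as `CServerFactory` above. A hedged sketch of the adaptation for a factory that does not need the new options follows; the class and protocol names are placeholders, not part of this patch.

```cpp
// Sketch: adapting a custom ServerFactory to the two-argument NewServer().
// MyServerFactory and "my_protocol" are illustrative, not part of this patch.
#include <memory>

#include "tensorflow/core/distributed_runtime/server_lib.h"

class MyServerFactory : public tensorflow::ServerFactory {
 public:
  bool AcceptsOptions(const tensorflow::ServerDef& server_def) override {
    return server_def.protocol() == "my_protocol";
  }

  // The `options` parameter is required by the new pure-virtual signature;
  // factories that build their own devices can simply ignore it, exactly as
  // TestServerFactory in server_lib_test.cc does.
  tensorflow::Status NewServer(
      const tensorflow::ServerDef& server_def, const Options& options,
      std::unique_ptr<tensorflow::ServerInterface>* out_server) override {
    // A real factory would construct and start its server here.
    return tensorflow::Status::OK();
  }
};
```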
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 207c6a02d5b..1024f3caabd 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -81,7 +81,8 @@ EagerContext::EagerContext(
     bool device_mgr_owned, Rendezvous* rendezvous,
     const CustomKernelCreator* custom_kernel_creator,
     DistributedFunctionLibraryRuntime* cluster_flr)
-    : default_device_placement_policy_(default_device_placement_policy),
+    : opts_(opts),
+      default_device_placement_policy_(default_device_placement_policy),
       default_mirroring_policy_(default_mirroring_policy),
       local_device_manager_(device_mgr, device_mgr_owned),
       host_cpu_device_(device_mgr->HostCPU()),
@@ -1051,7 +1052,7 @@ void EagerContext::IncrementContextViewId() {
 // Set collective ops related state in the context. Passing nullptr to
 // `new_server` will reuse the existing GRPC server in context.
 Status EagerContext::StoreCollectiveOpsServer(
-    std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
+    std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
     CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
   collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
 
@@ -1176,7 +1177,7 @@ Status EagerContext::InitializeRemoteMaster(
     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
     std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
     const std::vector<string>& remote_contexts, uint64 context_id,
-    Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
+    Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
     DistributedFunctionLibraryRuntime* cluster_flr,
     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
         remote_mgr) {
@@ -1275,7 +1276,7 @@ Status EagerContext::SetMasterContextState(
     std::shared_ptr<WorkerSession> worker_session,
     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
     std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
-    uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
+    uint64 context_view_id, Rendezvous* r, const DeviceMgr* local_device_mgr,
     int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
     std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
         remote_mgr) {
@@ -1287,7 +1288,13 @@ Status EagerContext::SetMasterContextState(
   use_send_tensor_rpc_ =
       ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
 
-  local_device_manager_.Reset(local_device_mgr);
+  if (local_device_mgr != local_device_manager_.Get()) {
+    if (local_device_manager_.Owned()) {
+      old_local_device_managers_.push_back(
+          std::move(local_device_manager_.owned_object));
+    }
+    local_device_manager_.Reset(local_device_mgr);
+  }
   host_cpu_device_ = local_device_manager_.Get()->HostCPU();
 
   if (rendezvous_ != nullptr) rendezvous_->Unref();
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index d03a91c817a..cceb883a965 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -399,7 +399,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
       std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
       const std::vector<string>& remote_contexts, uint64 context_id,
-      Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
+      Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
       DistributedFunctionLibraryRuntime* cluster_flr,
       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
           remote_mgr);
@@ -436,7 +436,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
       const std::vector<string>& remote_contexts, uint64 context_id);
 
   Status StoreCollectiveOpsServer(
-      std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
+      std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
       CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
 
   // For the specified remote worker, preprocess and set its device filters.
@@ -510,6 +510,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
   // Gets the CPU device on the task of device.
   Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
 
+  const SessionOptions& session_options() const { return opts_; }
+
  private:
   ~EagerContext() override;
 
@@ -563,6 +565,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
     T* unowned_object_ptr = nullptr;
   };
 
+  SessionOptions opts_;
   const ContextDevicePlacementPolicy default_device_placement_policy_;
   const ContextMirroringPolicy default_mirroring_policy_;
 
@@ -575,6 +578,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
       TF_GUARDED_BY(policy_map_mu_);
 
   OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_;
+  // Maintain copy of all previously created local device managers.
+  std::vector<std::unique_ptr<const DeviceMgr>> old_local_device_managers_;
 
   // Unowned DynamicDeviceMgr is set on remote worker to allow running
   // multi-device function on remote worker.
@@ -662,7 +667,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
       std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
       uint64 context_id, uint64 context_view_id, Rendezvous* r,
-      DeviceMgr* local_device_mgr, int keep_alive_secs,
+      const DeviceMgr* local_device_mgr, int keep_alive_secs,
       DistributedFunctionLibraryRuntime* cluster_flr,
       std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
          remote_mgr);
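On `old_local_device_managers_`: when `SetMasterContextState` swaps in a server-owned `DeviceMgr`, tensor handles created before the update may still reference devices owned by the previous manager, so a previously owned manager is retired rather than destroyed. Reduced to a self-contained sketch, this is the ownership scheme only, not TensorFlow code:

```cpp
// Standalone illustration of the "retire, don't delete" pattern behind
// local_device_manager_ / old_local_device_managers_.
#include <memory>
#include <vector>

template <typename T>
class CurrentAndRetired {
 public:
  // Takes ownership of the initial object.
  explicit CurrentAndRetired(std::unique_ptr<T> initial)
      : owned_(std::move(initial)), current_(owned_.get()) {}

  // Switches to `replacement` (not owned). If the previous object was owned,
  // it is retired rather than destroyed, so raw pointers handed out earlier
  // stay valid for the lifetime of this helper.
  void Reset(const T* replacement) {
    if (replacement == current_) return;
    if (owned_ != nullptr) retired_.push_back(std::move(owned_));
    current_ = replacement;
  }

  const T* Get() const { return current_; }

 private:
  std::unique_ptr<T> owned_;
  const T* current_ = nullptr;
  std::vector<std::unique_ptr<T>> retired_;
};
```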
diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
index 6dc03cbc527..5327cbb6480 100644
--- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc
@@ -238,7 +238,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
       session_name, &worker_session));
 
-  tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
+  const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
 
   // Initialize remote tensor communication based on worker session.
   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
@@ -355,7 +355,7 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
   TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
       session_name, &worker_session));
 
-  tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
+  const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
 
   std::vector<string> remote_workers;
   worker_session->worker_cache()->ListWorkers(&remote_workers);
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 8b363e66d87..fe353d7d76c 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -55,7 +55,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
+GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr)
     : worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
   // The default value of sync_on_finish will be flipped soon and this
   // environment variable will be removed as well.
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index 50190ab337e..e768c0907b6 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -69,7 +69,7 @@ class WorkerSession;
 //   EXPECT_EQ(out["c"], Tensor({4, 6}));
 class GraphMgr {
  public:
-  explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
+  explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
   ~GraphMgr();
 
   // Registers a graph. Fills in "handle". The registered graph retains a
@@ -145,7 +145,7 @@ class GraphMgr {
   };
 
   const WorkerEnv* worker_env_;  // Not owned.
-  DeviceMgr* device_mgr_;
+  const DeviceMgr* device_mgr_;
 
   CostModelManager cost_model_manager_;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 754209082fd..6523d2fb4dd 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -130,9 +130,6 @@ GrpcServer::~GrpcServer() {
   // OpSegments.)
   if (worker_env_.session_mgr != nullptr) {
     delete worker_env_.session_mgr;  // Deletes graph_mgr's.
-  } else {
-    // Note: session_mgr's legacy_session_ deletes device_mgr now.
-    delete worker_env_.device_mgr;
   }
 
   // Do not delete (as these are not owned by the server):
@@ -204,12 +201,18 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
   string name_prefix =
       strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
                       "/task:", server_def_.task_index());
-  std::vector<std::unique_ptr<Device>> devices;
-  TF_RETURN_IF_ERROR(
-      DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
-  worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
-  master_env_.local_devices = worker_env_.device_mgr->ListDevices();
+  if (opts.local_device_mgr == nullptr) {
+    std::vector<std::unique_ptr<Device>> devices;
+    TF_RETURN_IF_ERROR(
+        DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
+    worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
+    owned_device_manager_.reset(worker_env_.device_mgr);
+  } else {
+    worker_env_.device_mgr = opts.local_device_mgr;
+    owned_device_manager_.reset(nullptr);
+  }
   worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
+  master_env_.local_devices = worker_env_.device_mgr->ListDevices();
   worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
                                    ? new RpcRendezvousMgr(&worker_env_)
                                    : opts.rendezvous_mgr_func(&worker_env_);
@@ -527,12 +530,13 @@ std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
 
 /* static */
 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
+                          const DeviceMgr* local_device_mgr,
                           std::unique_ptr<ServerInterface>* out_server) {
   std::unique_ptr<GrpcServer> ret(
       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
-  ServiceInitFunction service_func = nullptr;
   GrpcServerOptions options;
   options.rendezvous_mgr_func = NewRpcRendezvousMgr;
+  options.local_device_mgr = local_device_mgr;
   Status s = ret->Init(options);
   if (!s.ok()) {
     LOG(ERROR) << s;
@@ -542,19 +546,21 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
   return Status::OK();
 }
 
+/* static */
+Status GrpcServer::Create(const ServerDef& server_def, Env* env,
+                          std::unique_ptr<ServerInterface>* out_server) {
+  return Create(server_def, env, nullptr, out_server);
+}
+
 /* static */
 Status GrpcServer::Create(const ServerDef& server_def, Env* env,
                           std::unique_ptr<GrpcServer>* out_server) {
-  std::unique_ptr<GrpcServer> ret(
-      new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
-  GrpcServerOptions options;
-  options.rendezvous_mgr_func = NewRpcRendezvousMgr;
-  Status s = ret->Init(options);
+  std::unique_ptr<ServerInterface> server;
+  Status s = Create(server_def, env, nullptr, &server);
   if (!s.ok()) {
-    LOG(ERROR) << s;
     return s;
   }
-  *out_server = std::move(ret);
+  out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
   return Status::OK();
 }
 
@@ -566,9 +572,10 @@ class GrpcServerFactory : public ServerFactory {
     return server_def.protocol() == "grpc";
   }
 
-  Status NewServer(const ServerDef& server_def,
+  Status NewServer(const ServerDef& server_def, const Options& options,
                    std::unique_ptr<ServerInterface>* out_server) override {
-    return GrpcServer::Create(server_def, Env::Default(), out_server);
+    return GrpcServer::Create(server_def, Env::Default(),
+                              options.local_device_mgr, out_server);
   }
 };
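The new `GrpcServer::Create` overload lets a caller stand the server up on a device set built elsewhere, as long as that `DeviceMgr` outlives the server. A hedged sketch using standard device-factory APIs; the helper name is illustrative and `name_prefix` mirrors what `Init` itself computes:

```cpp
// Sketch: create a GrpcServer that borrows an externally owned DeviceMgr.
// CreateServerWithBorrowedDevices is illustrative, not part of this patch.
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session_options.h"

tensorflow::Status CreateServerWithBorrowedDevices(
    const tensorflow::ServerDef& server_def,
    std::unique_ptr<const tensorflow::DeviceMgr>* device_mgr,
    std::unique_ptr<tensorflow::ServerInterface>* out_server) {
  tensorflow::SessionOptions sess_opts;
  sess_opts.config = server_def.default_session_config();
  const std::string name_prefix =
      tensorflow::strings::StrCat("/job:", server_def.job_name(), "/replica:0",
                                  "/task:", server_def.task_index());
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  TF_RETURN_IF_ERROR(
      tensorflow::DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
  // The caller owns the DeviceMgr; the server only borrows it.
  device_mgr->reset(new tensorflow::StaticDeviceMgr(std::move(devices)));
  return tensorflow::GrpcServer::Create(server_def, tensorflow::Env::Default(),
                                        device_mgr->get(), out_server);
}
```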
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index b3fa7d1f303..0474c5a517f 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -68,11 +68,14 @@ struct GrpcServerOptions {
   WorkerCreationFunction worker_func = nullptr;
   StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
   GrpcWorkerServiceOptions worker_service_options;
+  const DeviceMgr* local_device_mgr = nullptr;
 };
 
 class GrpcServer : public ServerInterface {
  protected:
   GrpcServer(const ServerDef& server_def, Env* env);
+  GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr,
+             Env* env);
   // Allow children classes to override this and provide custom args to the
   // server before it is constructed. Default behavior is to do nothing.
   virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder);
@@ -82,6 +85,10 @@ class GrpcServer : public ServerInterface {
                        std::unique_ptr<ServerInterface>* out_server);
   static Status Create(const ServerDef& server_def, Env* env,
                        std::unique_ptr<GrpcServer>* out_server);
+  // Reuse the local_device_mgr.
+  static Status Create(const ServerDef& server_def, Env* env,
+                       const DeviceMgr* local_device_mgr,
+                       std::unique_ptr<ServerInterface>* out_server);
 
   // Destruction is only supported in the factory method. Clean
   // shutdown is not currently implemented for this server type.
@@ -163,6 +170,7 @@ class GrpcServer : public ServerInterface {
 
   // Implementation of a TensorFlow worker, and RPC polling thread.
   WorkerEnv worker_env_;
+  std::unique_ptr<const DeviceMgr> owned_device_manager_;
   std::unique_ptr<GrpcWorker> worker_impl_;
   AsyncServiceInterface* worker_service_ = nullptr;
   std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_);
diff --git a/tensorflow/core/distributed_runtime/server_lib.cc b/tensorflow/core/distributed_runtime/server_lib.cc
index 62a2011db39..12baa75976a 100644
--- a/tensorflow/core/distributed_runtime/server_lib.cc
+++ b/tensorflow/core/distributed_runtime/server_lib.cc
@@ -73,7 +73,17 @@ Status NewServer(const ServerDef& server_def,
                  std::unique_ptr<ServerInterface>* out_server) {
   ServerFactory* factory;
   TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory));
-  return factory->NewServer(server_def, out_server);
+  return factory->NewServer(server_def, ServerFactory::Options(), out_server);
+}
+
+// Creates a server based on the given `server_def`, and stores it in
+// `*out_server`. Returns OK on success, otherwise returns an error.
+Status NewServerWithOptions(const ServerDef& server_def,
+                            const ServerFactory::Options& options,
+                            std::unique_ptr<ServerInterface>* out_server) {
+  ServerFactory* factory;
+  TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory));
+  return factory->NewServer(server_def, options, out_server);
 }
 
 }  // namespace tensorflow
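Subclasses that drive `GrpcServer::Init` themselves get the same reuse through the new `GrpcServerOptions::local_device_mgr` field. A minimal sketch, assuming the protected constructor and `Init(const GrpcServerOptions&)` remain accessible to subclasses; the `CustomGrpcServer` name is illustrative:

```cpp
// Sketch: a GrpcServer subclass that borrows a caller-provided DeviceMgr.
// CustomGrpcServer is a placeholder, not part of this patch.
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"

class CustomGrpcServer : public tensorflow::GrpcServer {
 public:
  CustomGrpcServer(const tensorflow::ServerDef& server_def,
                   tensorflow::Env* env)
      : tensorflow::GrpcServer(server_def, env) {}

  // If `local_device_mgr` is non-null, Init() uses it instead of creating
  // (and owning) a new StaticDeviceMgr.
  tensorflow::Status InitWithDevices(
      const tensorflow::DeviceMgr* local_device_mgr) {
    tensorflow::GrpcServerOptions options;
    options.local_device_mgr = local_device_mgr;  // may be nullptr
    return Init(options);
  }
};
```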
diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h
index 275f526d311..7b4b4892848 100644
--- a/tensorflow/core/distributed_runtime/server_lib.h
+++ b/tensorflow/core/distributed_runtime/server_lib.h
@@ -24,6 +24,8 @@ limitations under the License.
 
 namespace tensorflow {
 
+class DeviceMgr;
+
 // This library supports a registration/factory-based mechanism for
 // creating TensorFlow server objects. Each server implementation must
 // have an accompanying implementation of ServerFactory, and create a
@@ -63,10 +65,14 @@ class ServerInterface {
 
 class ServerFactory {
  public:
+  struct Options {
+    // Local DeviceMgr to use.
+    const tensorflow::DeviceMgr* local_device_mgr;
+  };
   // Creates a new server based on the given `server_def`, and stores
   // it in `*out_server`. Returns OK on success, otherwise returns an
   // error.
-  virtual Status NewServer(const ServerDef& server_def,
+  virtual Status NewServer(const ServerDef& server_def, const Options& options,
                            std::unique_ptr<ServerInterface>* out_server) = 0;
 
   // Returns true if and only if this factory can create a server
@@ -92,6 +98,9 @@ class ServerFactory {
 // `*out_server`. Returns OK on success, otherwise returns an error.
 Status NewServer(const ServerDef& server_def,
                  std::unique_ptr<ServerInterface>* out_server);
+Status NewServerWithOptions(const ServerDef& server_def,
+                            const ServerFactory::Options& options,
+                            std::unique_ptr<ServerInterface>* out_server);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/server_lib_test.cc b/tensorflow/core/distributed_runtime/server_lib_test.cc
index 77048c24b47..2152ff986d6 100644
--- a/tensorflow/core/distributed_runtime/server_lib_test.cc
+++ b/tensorflow/core/distributed_runtime/server_lib_test.cc
@@ -26,7 +26,7 @@ class TestServerFactory : public ServerFactory {
     return server_def.protocol() == "test_protocol";
   }
 
-  Status NewServer(const ServerDef& server_def,
+  Status NewServer(const ServerDef& server_def, const Options& options,
                    std::unique_ptr<ServerInterface>* out_server) override {
     return Status::OK();
   }
diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc
index e2151e068f6..1d9a22a5817 100644
--- a/tensorflow/core/distributed_runtime/session_mgr.cc
+++ b/tensorflow/core/distributed_runtime/session_mgr.cc
@@ -171,7 +171,7 @@ Status SessionMgr::UpdateSession(
 
   std::vector<std::unique_ptr<Device>> cluster_devices;
 
-  DeviceMgr* local_device_mgr = worker_session->device_mgr();
+  const DeviceMgr* local_device_mgr = worker_session->device_mgr();
   DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
   std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
   std::vector<std::unique_ptr<Device>> added_remote_devices;
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 7850ecc46b2..f857a63e64d 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -38,7 +38,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
 void Worker::GetStatusAsync(const GetStatusRequest* request,
                             GetStatusResponse* response, bool fail_fast,
                             StatusCallback done) {
-  DeviceMgr* dm = env_->device_mgr;
+  const DeviceMgr* dm = env_->device_mgr;
   std::vector<DeviceAttributes> devices;
   dm->ListDeviceAttributes(&devices);
   response->mutable_device_attributes()->Reserve(devices.size());
diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h
index 93d933bfa60..ecc3313d0ce 100644
--- a/tensorflow/core/distributed_runtime/worker_env.h
+++ b/tensorflow/core/distributed_runtime/worker_env.h
@@ -53,7 +53,7 @@ struct WorkerEnv {
   // Note: Please use the device_mgr associated with your session if appropriate
   // instead of this one. Using this device_mgr does not support ClusterSpec
   // propagated sessions.
-  DeviceMgr* device_mgr = nullptr;
+  const DeviceMgr* device_mgr = nullptr;
 
   // A set of rendezvous keyed by step ids.
   RendezvousMgrInterface* rendezvous_mgr = nullptr;
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index ca4f25f08f5..3aed73fa358 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -144,7 +144,7 @@ Status WorkerSession::UpdateWorkerCacheAndDevices(
 std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
     const string& session_name, const string& worker_name,
     std::unique_ptr<WorkerCacheInterface> worker_cache,
-    DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
+    const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
   return std::shared_ptr<WorkerSession>(new WorkerSession(
       session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
@@ -154,7 +154,7 @@ std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
 WorkerSession::WorkerSession(
     const string& session_name, const string& worker_name,
     std::unique_ptr<WorkerCacheInterface> worker_cache,
-    DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
+    const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
     std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
     : session_name_(session_name),
       worker_name_(worker_name),
diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h
index 3b2d1122558..f870a8c064b 100644
--- a/tensorflow/core/distributed_runtime/worker_session.h
+++ b/tensorflow/core/distributed_runtime/worker_session.h
@@ -37,7 +37,7 @@ class WorkerSession {
   // sessions created with `isolate_session_state == false`. In the
   // those cases, this method returns a pointer to a borrowed
   // DeviceMgr (typically the `worker_env.device_mgr`).
-  DeviceMgr* device_mgr() {
+  const DeviceMgr* device_mgr() {
     return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_;
   }
 
@@ -65,7 +65,7 @@ class WorkerSession {
   static std::shared_ptr<WorkerSession> CreateWithBorrowedDeviceMgr(
       const string& session_name, const string& worker_name,
       std::unique_ptr<WorkerCacheInterface> worker_cache,
-      DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
+      const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
       std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);
 
   // In the eager runtime we allow WorkerSession to be updated, where the
@@ -90,7 +90,7 @@ class WorkerSession {
  private:
   WorkerSession(const string& session_name, const string& worker_name,
                 std::unique_ptr<WorkerCacheInterface> worker_cache,
-                DeviceMgr* borrowed_device_mgr,
+                const DeviceMgr* borrowed_device_mgr,
                 std::unique_ptr<GraphMgr> graph_mgr,
                 std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);
 
@@ -113,8 +113,8 @@ class WorkerSession {
 
   std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
 
-  const std::unique_ptr<DeviceMgr> device_mgr_;
-  DeviceMgr* const borrowed_device_mgr_;  // Not owned.
+  const std::unique_ptr<const DeviceMgr> device_mgr_;
+  const DeviceMgr* const borrowed_device_mgr_;  // Not owned.
   std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_;
 };
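End to end, the observable behavior (covered by the new `ConnectToCluster*` tests) is that eager handles created before `TFE_ContextSetServerDef` stay usable when the job name and session config do not change, because the same local `DeviceMgr` is handed to the replacement server. A condensed sketch of that flow through the C API, assuming `server_def` is a single-task ServerDef for job `localhost` built like `GetServerDef("localhost", 1)` in the test:

```cpp
// Sketch: a handle created before connecting remains valid across
// TFE_ContextSetServerDef when the local job name/config are unchanged.
#include <string>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

void ReconnectKeepingLocalDevices(TFE_Context* ctx,
                                  const tensorflow::ServerDef& server_def) {
  TF_Status* status = TF_NewStatus();
  const std::string serialized = server_def.SerializeAsString();
  // Handles created before this call (e.g. a resource variable placed on
  // /job:localhost/replica:0/task:0/device:CPU:0) keep pointing at devices
  // owned by the reused DeviceMgr.
  TFE_ContextSetServerDef(ctx, /*keep_alive_secs=*/0, serialized.data(),
                          serialized.size(), status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
}
```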