When calling connect_to_cluser, if the options are identical and there is no renaming of local device, reuse existing local DeviceManager, otherwise we keep the old DeviceManager around to allow the old Tensor created to be usable.
PiperOrigin-RevId: 312489501 Change-Id: Id392d0324aba7e7f9e92f8efeaf33683157470e1
This commit is contained in:
parent
6e509432c0
commit
550581f6bd
|
@ -102,6 +102,15 @@ string DeviceName(const tensorflow::Device* d) {
|
|||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
bool AreLocalDevicesCompatible(const tensorflow::EagerContext* context,
|
||||
const tensorflow::ServerDef& server_def) {
|
||||
if (server_def.job_name() != context->HostCPU()->parsed_name().job) {
|
||||
return false;
|
||||
}
|
||||
return server_def.default_session_config().SerializeAsString() ==
|
||||
context->session_options().config.SerializeAsString();
|
||||
}
|
||||
|
||||
tensorflow::Status AddRemoteDevicesToMgr(
|
||||
const std::vector<string>& added_remote_workers,
|
||||
tensorflow::WorkerCacheInterface* worker_cache,
|
||||
|
@ -469,10 +478,15 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
|||
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
|
||||
tensorflow::GrpcServer* grpc_server;
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &new_server));
|
||||
const tensorflow::DeviceMgr* device_mgr =
|
||||
AreLocalDevicesCompatible(context, server_def)
|
||||
? context->local_device_mgr()
|
||||
: nullptr;
|
||||
LOG_AND_RETURN_IF_ERROR(tensorflow::NewServerWithOptions(
|
||||
server_def, {device_mgr}, &new_server));
|
||||
grpc_server = dynamic_cast<tensorflow::GrpcServer*>(new_server.get());
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
ListRemoteWorkers(grpc_server, worker_name, &remote_workers));
|
||||
ListRemoteWorkers(new_server.get(), worker_name, &remote_workers));
|
||||
} else {
|
||||
LOG_AND_RETURN_IF_ERROR(ListRemoteWorkers(context->GetServer(), worker_name,
|
||||
&curr_remote_workers));
|
||||
|
|
|
@ -41,7 +41,7 @@ tensorflow::ServerDef GetServerDef(const string& job_name, int num_tasks) {
|
|||
for (int i = 0; i < num_tasks; i++) {
|
||||
int port = tensorflow::testing::PickUnusedPortOrDie();
|
||||
job_def->mutable_tasks()->insert(
|
||||
{i, tensorflow::strings::StrCat("localhost:", port)});
|
||||
{i, tensorflow::strings::StrCat("localhost", ":", port)});
|
||||
}
|
||||
return server_def;
|
||||
}
|
||||
|
@ -430,4 +430,70 @@ TEST(CAPI, RemoteExecuteUpdateServerDefWithFailuresAsync) {
|
|||
TestRemoteExecuteUpdateServerDefWithFailures(true);
|
||||
}
|
||||
|
||||
void TestConnectToCluster(bool keep_localhost_for_first_connect) {
|
||||
// Fail fast on GetStatus requests so we can get errors instead of timeout
|
||||
// when updating cluster with non-exsitent worker
|
||||
tensorflow::setenv("GRPC_FAIL_FAST", "TRUE", /*overwrite=*/1);
|
||||
|
||||
const string first_name =
|
||||
keep_localhost_for_first_connect ? "localhost" : "abc";
|
||||
tensorflow::ServerDef server_def = GetServerDef(first_name, 1);
|
||||
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
const string dev0_name = "/job:localhost/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name);
|
||||
EXPECT_NE(var_handle0, nullptr);
|
||||
|
||||
tensorflow::Status status2;
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name);
|
||||
|
||||
// Rename local device
|
||||
// This server def has the task index set to 0.
|
||||
string serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev1_name =
|
||||
absl::StrCat("/job:", first_name, "/replica:0/task:0/device:CPU:0");
|
||||
TFE_TensorHandle* var_handle1 = TestVariable(ctx, 2.0, dev1_name);
|
||||
EXPECT_NE(var_handle1, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle1)->DeviceName(&status2), dev1_name);
|
||||
|
||||
// Another renaming of local device
|
||||
const string second_name = "def";
|
||||
server_def.set_job_name(second_name);
|
||||
server_def.mutable_cluster()->mutable_job(0)->set_name(second_name);
|
||||
(*server_def.mutable_cluster()->mutable_job(0)->mutable_tasks())[0] =
|
||||
absl::StrCat(second_name, ":",
|
||||
tensorflow::testing::PickUnusedPortOrDie());
|
||||
|
||||
serialized = server_def.SerializeAsString();
|
||||
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
|
||||
const string dev2_name = "/job:def/replica:0/task:0/device:CPU:0";
|
||||
TFE_TensorHandle* var_handle2 = TestVariable(ctx, 2.0, dev2_name);
|
||||
EXPECT_NE(var_handle2, nullptr);
|
||||
EXPECT_EQ(tensorflow::unwrap(var_handle2)->DeviceName(&status2), dev2_name);
|
||||
|
||||
TFE_DeleteTensorHandle(var_handle0);
|
||||
TFE_DeleteTensorHandle(var_handle1);
|
||||
TFE_DeleteTensorHandle(var_handle2);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TF_DeleteStatus(status);
|
||||
|
||||
tensorflow::unsetenv("GRPC_FAIL_FAST");
|
||||
}
|
||||
|
||||
TEST(CAPI, ConnectToClusterLocalhostFirst) { TestConnectToCluster(false); }
|
||||
|
||||
TEST(CAPI, ConnectToClusterRenameFirst) { TestConnectToCluster(true); }
|
||||
|
||||
} // namespace
|
||||
|
|
|
@ -108,7 +108,7 @@ class CServerFactory : public ServerFactory {
|
|||
delete_function_(delete_function),
|
||||
rendezvous_builder_(rendezvous_builder) {}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
Status NewServer(const ServerDef& server_def, const Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
TF_RETURN_IF_ERROR(CGrpcServer::Create(
|
||||
server_def, init_function_, start_function_, stop_function_,
|
||||
|
|
|
@ -81,7 +81,8 @@ EagerContext::EagerContext(
|
|||
bool device_mgr_owned, Rendezvous* rendezvous,
|
||||
const CustomKernelCreator* custom_kernel_creator,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr)
|
||||
: default_device_placement_policy_(default_device_placement_policy),
|
||||
: opts_(opts),
|
||||
default_device_placement_policy_(default_device_placement_policy),
|
||||
default_mirroring_policy_(default_mirroring_policy),
|
||||
local_device_manager_(device_mgr, device_mgr_owned),
|
||||
host_cpu_device_(device_mgr->HostCPU()),
|
||||
|
@ -1051,7 +1052,7 @@ void EagerContext::IncrementContextViewId() {
|
|||
// Set collective ops related state in the context. Passing nullptr to
|
||||
// `new_server` will reuse the existing GRPC server in context.
|
||||
Status EagerContext::StoreCollectiveOpsServer(
|
||||
std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
|
||||
std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
|
||||
CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
|
||||
collective_executor_mgr_.Reset(rpc_collective_executor_mgr);
|
||||
|
||||
|
@ -1176,7 +1177,7 @@ Status EagerContext::InitializeRemoteMaster(
|
|||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
|
||||
const std::vector<string>& remote_contexts, uint64 context_id,
|
||||
Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
|
||||
Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
|
||||
remote_mgr) {
|
||||
|
@ -1275,7 +1276,7 @@ Status EagerContext::SetMasterContextState(
|
|||
std::shared_ptr<WorkerSession> worker_session,
|
||||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_manager, uint64 context_id,
|
||||
uint64 context_view_id, Rendezvous* r, DeviceMgr* local_device_mgr,
|
||||
uint64 context_view_id, Rendezvous* r, const DeviceMgr* local_device_mgr,
|
||||
int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
|
||||
remote_mgr) {
|
||||
|
@ -1287,7 +1288,13 @@ Status EagerContext::SetMasterContextState(
|
|||
use_send_tensor_rpc_ =
|
||||
ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", true);
|
||||
|
||||
local_device_manager_.Reset(local_device_mgr);
|
||||
if (local_device_mgr != local_device_manager_.Get()) {
|
||||
if (local_device_manager_.Owned()) {
|
||||
old_local_device_managers_.push_back(
|
||||
std::move(local_device_manager_.owned_object));
|
||||
}
|
||||
local_device_manager_.Reset(local_device_mgr);
|
||||
}
|
||||
host_cpu_device_ = local_device_manager_.Get()->HostCPU();
|
||||
|
||||
if (rendezvous_ != nullptr) rendezvous_->Unref();
|
||||
|
|
|
@ -399,7 +399,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
|
||||
const std::vector<string>& remote_contexts, uint64 context_id,
|
||||
Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
|
||||
Rendezvous* r, const DeviceMgr* local_device_mgr, int keep_alive_secs,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
|
||||
remote_mgr);
|
||||
|
@ -436,7 +436,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
const std::vector<string>& remote_contexts, uint64 context_id);
|
||||
|
||||
Status StoreCollectiveOpsServer(
|
||||
std::unique_ptr<ServerInterface> new_server, DeviceMgr* device_mgr,
|
||||
std::unique_ptr<ServerInterface> new_server, const DeviceMgr* device_mgr,
|
||||
CollectiveExecutorMgrInterface* rpc_collective_executor_mgr);
|
||||
|
||||
// For the specified remote worker, preprocess and set its device filters.
|
||||
|
@ -510,6 +510,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
// Gets the CPU device on the task of device.
|
||||
Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const;
|
||||
|
||||
const SessionOptions& session_options() const { return opts_; }
|
||||
|
||||
private:
|
||||
~EagerContext() override;
|
||||
|
||||
|
@ -563,6 +565,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
T* unowned_object_ptr = nullptr;
|
||||
};
|
||||
|
||||
SessionOptions opts_;
|
||||
const ContextDevicePlacementPolicy default_device_placement_policy_;
|
||||
const ContextMirroringPolicy default_mirroring_policy_;
|
||||
|
||||
|
@ -575,6 +578,8 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
TF_GUARDED_BY(policy_map_mu_);
|
||||
|
||||
OwnedOrUnownedHelper<const DeviceMgr> local_device_manager_;
|
||||
// Maintain copy of all previously created local device managers.
|
||||
std::vector<std::unique_ptr<const DeviceMgr>> old_local_device_managers_;
|
||||
|
||||
// Unowned DynamicDeviceMgr is set on remote worker to allow running
|
||||
// multi-device function on remote worker.
|
||||
|
@ -662,7 +667,7 @@ class EagerContext : public AbstractContextInterface, public core::RefCounted {
|
|||
std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_manager,
|
||||
uint64 context_id, uint64 context_view_id, Rendezvous* r,
|
||||
DeviceMgr* local_device_mgr, int keep_alive_secs,
|
||||
const DeviceMgr* local_device_mgr, int keep_alive_secs,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
std::unique_ptr<eager::RemoteMgr, std::function<void(eager::RemoteMgr*)>>
|
||||
remote_mgr);
|
||||
|
|
|
@ -238,7 +238,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
|
|||
TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
|
||||
const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
|
||||
|
||||
// Initialize remote tensor communication based on worker session.
|
||||
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
@ -355,7 +355,7 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request,
|
|||
TF_RETURN_IF_ERROR(env_->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
|
||||
const tensorflow::DeviceMgr* device_mgr = worker_session->device_mgr();
|
||||
|
||||
std::vector<string> remote_workers;
|
||||
worker_session->worker_cache()->ListWorkers(&remote_workers);
|
||||
|
|
|
@ -55,7 +55,7 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
|
||||
GraphMgr::GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr)
|
||||
: worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
|
||||
// The default value of sync_on_finish will be flipped soon and this
|
||||
// environment variable will be removed as well.
|
||||
|
|
|
@ -69,7 +69,7 @@ class WorkerSession;
|
|||
// EXPECT_EQ(out["c"], Tensor({4, 6}));
|
||||
class GraphMgr {
|
||||
public:
|
||||
explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
|
||||
explicit GraphMgr(const WorkerEnv* worker_env, const DeviceMgr* device_mgr);
|
||||
~GraphMgr();
|
||||
|
||||
// Registers a graph. Fills in "handle". The registered graph retains a
|
||||
|
@ -145,7 +145,7 @@ class GraphMgr {
|
|||
};
|
||||
|
||||
const WorkerEnv* worker_env_; // Not owned.
|
||||
DeviceMgr* device_mgr_;
|
||||
const DeviceMgr* device_mgr_;
|
||||
|
||||
CostModelManager cost_model_manager_;
|
||||
|
||||
|
|
|
@ -130,9 +130,6 @@ GrpcServer::~GrpcServer() {
|
|||
// OpSegments.)
|
||||
if (worker_env_.session_mgr != nullptr) {
|
||||
delete worker_env_.session_mgr; // Deletes graph_mgr's.
|
||||
} else {
|
||||
// Note: session_mgr's legacy_session_ deletes device_mgr now.
|
||||
delete worker_env_.device_mgr;
|
||||
}
|
||||
|
||||
// Do not delete (as these are not owned by the server):
|
||||
|
@ -204,12 +201,18 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
|||
string name_prefix =
|
||||
strings::StrCat("/job:", server_def_.job_name(), "/replica:0",
|
||||
"/task:", server_def_.task_index());
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
|
||||
worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
|
||||
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||
if (opts.local_device_mgr == nullptr) {
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
TF_RETURN_IF_ERROR(
|
||||
DeviceFactory::AddDevices(sess_opts, name_prefix, &devices));
|
||||
worker_env_.device_mgr = new StaticDeviceMgr(std::move(devices));
|
||||
owned_device_manager_.reset(worker_env_.device_mgr);
|
||||
} else {
|
||||
worker_env_.device_mgr = opts.local_device_mgr;
|
||||
owned_device_manager_.reset(nullptr);
|
||||
}
|
||||
worker_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||
master_env_.local_devices = worker_env_.device_mgr->ListDevices();
|
||||
worker_env_.rendezvous_mgr = opts.rendezvous_mgr_func == nullptr
|
||||
? new RpcRendezvousMgr(&worker_env_)
|
||||
: opts.rendezvous_mgr_func(&worker_env_);
|
||||
|
@ -527,12 +530,13 @@ std::unique_ptr<Master> GrpcServer::CreateMaster(MasterEnv* master_env) {
|
|||
|
||||
/* static */
|
||||
Status GrpcServer::Create(const ServerDef& server_def, Env* env,
|
||||
const DeviceMgr* local_device_mgr,
|
||||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
std::unique_ptr<GrpcServer> ret(
|
||||
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
|
||||
ServiceInitFunction service_func = nullptr;
|
||||
GrpcServerOptions options;
|
||||
options.rendezvous_mgr_func = NewRpcRendezvousMgr;
|
||||
options.local_device_mgr = local_device_mgr;
|
||||
Status s = ret->Init(options);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << s;
|
||||
|
@ -542,19 +546,21 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status GrpcServer::Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
return Create(server_def, env, nullptr, out_server);
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status GrpcServer::Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<GrpcServer>* out_server) {
|
||||
std::unique_ptr<GrpcServer> ret(
|
||||
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
|
||||
GrpcServerOptions options;
|
||||
options.rendezvous_mgr_func = NewRpcRendezvousMgr;
|
||||
Status s = ret->Init(options);
|
||||
std::unique_ptr<ServerInterface> server;
|
||||
Status s = Create(server_def, env, nullptr, &server);
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << s;
|
||||
return s;
|
||||
}
|
||||
*out_server = std::move(ret);
|
||||
out_server->reset(dynamic_cast<GrpcServer*>(server.release()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -566,9 +572,10 @@ class GrpcServerFactory : public ServerFactory {
|
|||
return server_def.protocol() == "grpc";
|
||||
}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
Status NewServer(const ServerDef& server_def, const Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
return GrpcServer::Create(server_def, Env::Default(), out_server);
|
||||
return GrpcServer::Create(server_def, Env::Default(),
|
||||
options.local_device_mgr, out_server);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -68,11 +68,14 @@ struct GrpcServerOptions {
|
|||
WorkerCreationFunction worker_func = nullptr;
|
||||
StatsPublisherFactory stats_factory = CreateNoOpStatsPublisher;
|
||||
GrpcWorkerServiceOptions worker_service_options;
|
||||
const DeviceMgr* local_device_mgr = nullptr;
|
||||
};
|
||||
|
||||
class GrpcServer : public ServerInterface {
|
||||
protected:
|
||||
GrpcServer(const ServerDef& server_def, Env* env);
|
||||
GrpcServer(const ServerDef& server_def, DeviceMgr* local_device_mgr,
|
||||
Env* env);
|
||||
// Allow children classes to override this and provide custom args to the
|
||||
// server before it is constructed. Default behavior is to do nothing.
|
||||
virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder);
|
||||
|
@ -82,6 +85,10 @@ class GrpcServer : public ServerInterface {
|
|||
std::unique_ptr<ServerInterface>* out_server);
|
||||
static Status Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<GrpcServer>* out_server);
|
||||
// Reuse the local_device_mgr.
|
||||
static Status Create(const ServerDef& server_def, Env* env,
|
||||
const DeviceMgr* local_device_mgr,
|
||||
std::unique_ptr<ServerInterface>* out_server);
|
||||
|
||||
// Destruction is only supported in the factory method. Clean
|
||||
// shutdown is not currently implemented for this server type.
|
||||
|
@ -163,6 +170,7 @@ class GrpcServer : public ServerInterface {
|
|||
|
||||
// Implementation of a TensorFlow worker, and RPC polling thread.
|
||||
WorkerEnv worker_env_;
|
||||
std::unique_ptr<const DeviceMgr> owned_device_manager_;
|
||||
std::unique_ptr<GrpcWorker> worker_impl_;
|
||||
AsyncServiceInterface* worker_service_ = nullptr;
|
||||
std::unique_ptr<Thread> worker_thread_ TF_GUARDED_BY(mu_);
|
||||
|
|
|
@ -73,7 +73,17 @@ Status NewServer(const ServerDef& server_def,
|
|||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
ServerFactory* factory;
|
||||
TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory));
|
||||
return factory->NewServer(server_def, out_server);
|
||||
return factory->NewServer(server_def, ServerFactory::Options(), out_server);
|
||||
}
|
||||
|
||||
// Creates a server based on the given `server_def`, and stores it in
|
||||
// `*out_server`. Returns OK on success, otherwise returns an error.
|
||||
Status NewServerWithOptions(const ServerDef& server_def,
|
||||
const ServerFactory::Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
ServerFactory* factory;
|
||||
TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory));
|
||||
return factory->NewServer(server_def, options, out_server);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -24,6 +24,8 @@ limitations under the License.
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
class DeviceMgr;
|
||||
|
||||
// This library supports a registration/factory-based mechanism for
|
||||
// creating TensorFlow server objects. Each server implementation must
|
||||
// have an accompanying implementation of ServerFactory, and create a
|
||||
|
@ -63,10 +65,14 @@ class ServerInterface {
|
|||
|
||||
class ServerFactory {
|
||||
public:
|
||||
struct Options {
|
||||
// Local DeviceMgr to use.
|
||||
const tensorflow::DeviceMgr* local_device_mgr;
|
||||
};
|
||||
// Creates a new server based on the given `server_def`, and stores
|
||||
// it in `*out_server`. Returns OK on success, otherwise returns an
|
||||
// error.
|
||||
virtual Status NewServer(const ServerDef& server_def,
|
||||
virtual Status NewServer(const ServerDef& server_def, const Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) = 0;
|
||||
|
||||
// Returns true if and only if this factory can create a server
|
||||
|
@ -92,6 +98,9 @@ class ServerFactory {
|
|||
// `*out_server`. Returns OK on success, otherwise returns an error.
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
std::unique_ptr<ServerInterface>* out_server);
|
||||
Status NewServerWithOptions(const ServerDef& server_def,
|
||||
const ServerFactory::Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ class TestServerFactory : public ServerFactory {
|
|||
return server_def.protocol() == "test_protocol";
|
||||
}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
Status NewServer(const ServerDef& server_def, const Options& options,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -171,7 +171,7 @@ Status SessionMgr::UpdateSession(
|
|||
|
||||
std::vector<std::unique_ptr<Device>> cluster_devices;
|
||||
|
||||
DeviceMgr* local_device_mgr = worker_session->device_mgr();
|
||||
const DeviceMgr* local_device_mgr = worker_session->device_mgr();
|
||||
DeviceMgr* remote_device_mgr = worker_session->remote_device_mgr();
|
||||
std::vector<Device*> curr_remote_devices = remote_device_mgr->ListDevices();
|
||||
std::vector<std::unique_ptr<Device>> added_remote_devices;
|
||||
|
|
|
@ -38,7 +38,7 @@ Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {
|
|||
void Worker::GetStatusAsync(const GetStatusRequest* request,
|
||||
GetStatusResponse* response, bool fail_fast,
|
||||
StatusCallback done) {
|
||||
DeviceMgr* dm = env_->device_mgr;
|
||||
const DeviceMgr* dm = env_->device_mgr;
|
||||
std::vector<DeviceAttributes> devices;
|
||||
dm->ListDeviceAttributes(&devices);
|
||||
response->mutable_device_attributes()->Reserve(devices.size());
|
||||
|
|
|
@ -53,7 +53,7 @@ struct WorkerEnv {
|
|||
// Note: Please use the device_mgr associated with your session if appropriate
|
||||
// instead of this one. Using this device_mgr does not support ClusterSpec
|
||||
// propagated sessions.
|
||||
DeviceMgr* device_mgr = nullptr;
|
||||
const DeviceMgr* device_mgr = nullptr;
|
||||
|
||||
// A set of rendezvous keyed by step ids.
|
||||
RendezvousMgrInterface* rendezvous_mgr = nullptr;
|
||||
|
|
|
@ -144,7 +144,7 @@ Status WorkerSession::UpdateWorkerCacheAndDevices(
|
|||
std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
|
||||
const string& session_name, const string& worker_name,
|
||||
std::unique_ptr<WorkerCacheInterface> worker_cache,
|
||||
DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
|
||||
const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr) {
|
||||
return std::shared_ptr<WorkerSession>(new WorkerSession(
|
||||
session_name, worker_name, std::move(worker_cache), borrowed_device_mgr,
|
||||
|
@ -154,7 +154,7 @@ std::shared_ptr<WorkerSession> WorkerSession::CreateWithBorrowedDeviceMgr(
|
|||
WorkerSession::WorkerSession(
|
||||
const string& session_name, const string& worker_name,
|
||||
std::unique_ptr<WorkerCacheInterface> worker_cache,
|
||||
DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
|
||||
const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr)
|
||||
: session_name_(session_name),
|
||||
worker_name_(worker_name),
|
||||
|
|
|
@ -37,7 +37,7 @@ class WorkerSession {
|
|||
// sessions created with `isolate_session_state == false`. In the
|
||||
// those cases, this method returns a pointer to a borrowed
|
||||
// DeviceMgr (typically the `worker_env.device_mgr`).
|
||||
DeviceMgr* device_mgr() {
|
||||
const DeviceMgr* device_mgr() {
|
||||
return device_mgr_ ? device_mgr_.get() : borrowed_device_mgr_;
|
||||
}
|
||||
|
||||
|
@ -65,7 +65,7 @@ class WorkerSession {
|
|||
static std::shared_ptr<WorkerSession> CreateWithBorrowedDeviceMgr(
|
||||
const string& session_name, const string& worker_name,
|
||||
std::unique_ptr<WorkerCacheInterface> worker_cache,
|
||||
DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
|
||||
const DeviceMgr* borrowed_device_mgr, std::unique_ptr<GraphMgr> graph_mgr,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);
|
||||
|
||||
// In the eager runtime we allow WorkerSession to be updated, where the
|
||||
|
@ -90,7 +90,7 @@ class WorkerSession {
|
|||
private:
|
||||
WorkerSession(const string& session_name, const string& worker_name,
|
||||
std::unique_ptr<WorkerCacheInterface> worker_cache,
|
||||
DeviceMgr* borrowed_device_mgr,
|
||||
const DeviceMgr* borrowed_device_mgr,
|
||||
std::unique_ptr<GraphMgr> graph_mgr,
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr);
|
||||
|
||||
|
@ -113,8 +113,8 @@ class WorkerSession {
|
|||
|
||||
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
|
||||
|
||||
const std::unique_ptr<DeviceMgr> device_mgr_;
|
||||
DeviceMgr* const borrowed_device_mgr_; // Not owned.
|
||||
const std::unique_ptr<const DeviceMgr> device_mgr_;
|
||||
const DeviceMgr* const borrowed_device_mgr_; // Not owned.
|
||||
std::unique_ptr<DynamicDeviceMgr> remote_device_mgr_;
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue