commit d5e752fe18
parent dcbf7f7bc0
@@ -129,17 +129,16 @@ tensorflow::Status GetAllRemoteDevices(
 }
 
 tensorflow::Status CreateRemoteContexts(
-    const std::vector<string>& remote_workers, int64 rendezvous_id,
+    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
     int keep_alive_secs, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
-    const tensorflow::eager::CreateContextRequest& base_request,
-    tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
+    const tensorflow::eager::CreateContextRequest& base_request) {
   for (int i = 0; i < remote_workers.size(); i++) {
     const string& remote_worker = remote_workers[i];
 
     tensorflow::eager::CreateContextRequest request(base_request);
     tensorflow::eager::CreateContextResponse response;
-    request.set_rendezvous_id(rendezvous_id);
+    request.set_context_id(context_id);
     tensorflow::DeviceNameUtils::ParsedName parsed_name;
     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                     &parsed_name)) {
@@ -168,8 +167,6 @@ tensorflow::Status CreateRemoteContexts(
         });
     n.WaitForNotification();
     TF_RETURN_IF_ERROR(status);
-
-    remote_contexts->emplace(remote_worker, response.context_id());
   }
   return tensorflow::Status::OK();
 }
@@ -206,7 +203,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 
   LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
 
-  int64 rendezvous_id = tensorflow::random::New64();
+  tensorflow::uint64 context_id = tensorflow::random::New64();
 
   std::vector<string> remote_workers;
   grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
@@ -242,16 +239,14 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
       &remote_eager_workers));
 
   // Initialize remote eager workers.
-  tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
   LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-      remote_workers, rendezvous_id, keep_alive_secs, server_def,
-      remote_eager_workers.get(), ctx->context->Async(), base_request,
-      &remote_contexts));
+      remote_workers, context_id, keep_alive_secs, server_def,
+      remote_eager_workers.get(), ctx->context->Async(), base_request));
 
   tensorflow::RemoteRendezvous* r =
-      grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
+      grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
 
-  auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
+  auto session_name = tensorflow::strings::StrCat("eager_", context_id);
   TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
       session_name, server_def, base_request.cluster_device_attributes(),
       true));
@@ -266,10 +261,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
 
   auto* device_mgr = grpc_server->worker_env()->device_mgr;
 
-  return ctx->context->InitializeRemote(
+  return ctx->context->InitializeRemoteMaster(
       std::move(server), grpc_server->worker_env(), worker_session,
       std::move(remote_eager_workers), std::move(remote_device_mgr),
-      remote_contexts, r, device_mgr, keep_alive_secs,
+      remote_workers, context_id, r, device_mgr, keep_alive_secs,
       worker_session->cluster_flr.get());
 #undef LOG_AND_RETURN_IF_ERROR
 }
@@ -76,19 +76,20 @@ TEST_F(XrtClientTest, XrtGrpcEagerClientWorks) {
 
   // Create and destroy a context to verify we can make RPCs.
   eager::CreateContextRequest request;
+  uint64 context_id = random::New64();
   ServerDef* server_def = request.mutable_server_def();
   *server_def->mutable_cluster() = cluster_def_;
   server_def->set_job_name("localhost");
   server_def->set_protocol("grpc");
   request.set_keep_alive_secs(60);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
 
   eager::CreateContextResponse create_response;
   TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CreateContextAsync,
                                 &request, &create_response));
 
   eager::CloseContextRequest close_request;
-  close_request.set_context_id(create_response.context_id());
+  close_request.set_context_id(context_id);
 
   eager::CloseContextResponse close_response;
   TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CloseContextAsync,
@@ -45,7 +45,7 @@ XrtTfClient::XrtTfClient(ClusterDef cluster_def,
 xla::StatusOr<std::shared_ptr<XrtTfContext>> XrtTfContext::Create(
     const XrtTfContext::Options& options,
     std::shared_ptr<XrtTfClient> tf_client, const std::string& job, int task) {
-  int64 rendezvous_id = random::New64();
+  int64 context_id = random::New64();
 
   eager::CreateContextRequest request;
   ServerDef* server_def = request.mutable_server_def();
@@ -53,7 +53,7 @@ xla::StatusOr<std::shared_ptr<XrtTfContext>> XrtTfContext::Create(
   server_def->set_job_name(job);
   server_def->set_protocol("grpc");
   request.set_keep_alive_secs(60);
-  request.set_rendezvous_id(rendezvous_id);
+  request.set_context_id(context_id);
   request.set_async(options.async);
 
   eager::CreateContextResponse response;
@@ -98,8 +98,9 @@ xla::StatusOr<std::shared_ptr<XrtTfContext>> XrtTfContext::Create(
     return a.name() < b.name();
   });
   return std::make_shared<XrtTfContext>(options, tf_client, eager_client,
-                                        rendezvous_id, response.context_id(),
-                                        std::move(devices), cpu_device_id);
+                                        /*rendezvous_id=*/context_id,
+                                        context_id, std::move(devices),
+                                        cpu_device_id);
 }
 
 XrtTfContext::XrtTfContext(const XrtTfContext::Options& options,
@@ -65,15 +65,11 @@ EagerContext::EagerContext(
     ContextMirroringPolicy default_mirroring_policy, bool async,
     const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
     const CustomKernelCreator* custom_kernel_creator,
-    DistributedFunctionLibraryRuntime* cluster_flr,
-    std::function<Rendezvous*(const int64)> rendezvous_creator,
-    const DeviceMgr* remote_device_mgr)
+    DistributedFunctionLibraryRuntime* cluster_flr)
     : default_device_placement_policy_(default_device_placement_policy),
       default_mirroring_policy_(default_mirroring_policy),
-      remote_unowned_device_manager_(remote_device_mgr),
      devices_(device_mgr->ListDevices()),
       rendezvous_(rendezvous),
-      rendezvous_creator_(std::move(rendezvous_creator)),
       thread_pool_(NewThreadPoolFromSessionOptions(opts)),
       custom_kernel_creator_(custom_kernel_creator),
       pflr_(new ProcessFunctionLibraryRuntime(
@@ -211,25 +207,22 @@ bool EagerContext::MirrorTensors() const {
 #if !defined(IS_MOBILE_PLATFORM)
 void EagerContext::CloseRemoteContexts() {
   // Close all remote contexts.
-  std::vector<eager::CloseContextRequest> requests(remote_contexts_.size());
+  eager::CloseContextRequest request;
+  request.set_context_id(context_id_);
   std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
   BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
 
   int i = 0;
-  for (const auto& worker_and_context_id : remote_contexts_) {
+  for (const auto& worker : remote_contexts_) {
     eager::EagerClient* client;
-    Status s =
-        remote_eager_workers_->GetClient(worker_and_context_id.first, &client);
+    Status s = remote_eager_workers_->GetClient(worker, &client);
 
-    requests[i].set_context_id(worker_and_context_id.second);
     client->CloseContextAsync(
-        &requests[i], &responses[i],
-        [&worker_and_context_id, &counter](const Status& s) {
+        &request, &responses[i], [this, &worker, &counter](const Status& s) {
          if (!s.ok()) {
             LOG(ERROR) << "Unable to close remote context with ID "
-                       << worker_and_context_id.second
-                       << " for worker: " << worker_and_context_id.first
-                       << " due to " << s.error_message();
+                       << context_id_ << " for worker: " << worker << " due to "
+                       << s.error_message();
           }
           counter.DecrementCount();
         });
@@ -263,8 +256,9 @@ EagerContext::~EagerContext() {
     keep_alive_thread_cv_.notify_all();
   }
   keep_alive_thread_.reset();
+  if (!remote_contexts_.empty() && keep_alive_thread_ != nullptr) {
     CloseRemoteContexts();
+  }
 #endif // !IS_MOBILE_PLATFORM
 
   executor_.WaitForAllPendingNodes().IgnoreError();
@@ -370,22 +364,20 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
 #if !defined(IS_MOBILE_PLATFORM)
   BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));
 
-  std::vector<eager::RegisterFunctionRequest> requests(remote_contexts_.size());
+  eager::RegisterFunctionRequest request;
+  request.set_context_id(context_id_);
+  *request.mutable_function_def() = fdef;
   std::vector<eager::RegisterFunctionResponse> responses(
       remote_contexts_.size());
   std::vector<Status> statuses(remote_contexts_.size());
 
   int i = 0;
-  for (const auto& target_and_context_id : remote_contexts_) {
-    requests[i].set_context_id(target_and_context_id.second);
-    *requests[i].mutable_function_def() = fdef;
-
+  for (const auto& target : remote_contexts_) {
     eager::EagerClient* eager_client;
-    TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(
-        target_and_context_id.first, &eager_client));
+    TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
 
     eager_client->RegisterFunctionAsync(
-        &requests[i], &responses[i],
+        &request, &responses[i],
         [i, &statuses, &blocking_counter](const Status& status) {
           statuses[i] = status;
           blocking_counter.DecrementCount();
@@ -561,17 +553,15 @@ Status GetTaskName(Device* d, string* task_name) {
 } // namespace
 
 #if !defined(IS_MOBILE_PLATFORM)
-Status EagerContext::GetClientAndContextID(Device* device,
-                                           eager::EagerClient** client,
-                                           uint64* context_id) {
+Status EagerContext::GetClient(Device* device, eager::EagerClient** client) {
   if (remote_eager_workers_ == nullptr) {
     return errors::Internal(
         "Haven't set up remote eager worker in this eager context yet.");
   }
   auto it = device_to_client_cache_.find(device);
   if (it != device_to_client_cache_.end()) {
-    *client = it->second.first;
-    *context_id = it->second.second;
+    *client = it->second;
+    return Status::OK();
   }
   string device_task_name;
   TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name));
@@ -584,18 +574,19 @@ Status EagerContext::GetClientAndContextID(Device* device,
         "Unable to find eager client corresponding to device ", device->name());
   }
 
-  auto context_iterator = remote_contexts_.find(device_task_name);
-  if (context_iterator == remote_contexts_.end()) {
+  if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
+                device_task_name) == remote_contexts_.end()) {
     return errors::Internal("Unable to find a context for handle on task: ",
                             device_task_name, ". This should not be possible");
   }
-  *context_id = context_iterator->second;
 
-  device_to_client_cache_.insert({device, {*client, *context_id}});
+  device_to_client_cache_.insert({device, *client});
 
   return Status::OK();
 }
 
+uint64 EagerContext::GetContextId() { return context_id_; }
+
 Status EagerContext::StoreCollectiveOpsServer(
     std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
     CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
@@ -626,13 +617,13 @@ Status EagerContext::StoreCollectiveOpsServer(
   return Status::OK();
 }
 
-Status EagerContext::InitializeRemote(
+Status EagerContext::InitializeRemoteMaster(
     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
     std::shared_ptr<WorkerSession> worker_session,
     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
     std::unique_ptr<DeviceMgr> remote_device_manager,
-    const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
-    DeviceMgr* local_device_mgr, int keep_alive_secs,
+    const std::vector<string>& remote_contexts, uint64 context_id,
+    Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
     DistributedFunctionLibraryRuntime* cluster_flr) {
   mutex_lock l(remote_state_mu_);
 
@@ -640,6 +631,7 @@ Status EagerContext::InitializeRemote(
     CloseRemoteContexts();
   }
   remote_contexts_ = remote_contexts;
+  context_id_ = context_id;
 
   use_send_tensor_rpc_ =
       ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);
@@ -668,11 +660,6 @@ Status EagerContext::InitializeRemote(
   worker_session_ = worker_session;
   remote_eager_workers_ = std::move(remote_eager_workers);
 
-  active_remote_contexts_.clear();
-  for (const auto& remote_context : remote_contexts_) {
-    active_remote_contexts_.insert(remote_context.second);
-  }
-
   device_to_client_cache_.clear();
   remote_device_manager_ = std::move(remote_device_manager);
 
@@ -707,16 +694,15 @@ Status EagerContext::InitializeRemote(
         mutex_lock l(remote_state_mu_);
         if (keep_alive_secs_ > 0) {
           {
-            for (const auto& worker_and_context_id : remote_contexts_) {
+            for (const auto& worker : remote_contexts_) {
               eager::EagerClient* client;
-              Status s = remote_eager_workers_->GetClient(
-                  worker_and_context_id.first, &client);
+              Status s =
+                  remote_eager_workers_->GetClient(worker, &client);
 
               if (!s.ok()) {
                 LOG(WARNING) << "Keep-alive thread was unable to find "
                                 "a client for target "
-                             << worker_and_context_id.first
-                             << ". Got error: " << s;
+                             << worker << ". Got error: " << s;
                 continue;
               }
 
@@ -725,7 +711,7 @@ Status EagerContext::InitializeRemote(
               eager::KeepAliveResponse* response =
                   new eager::KeepAliveResponse;
 
-              request->set_context_id(worker_and_context_id.second);
+              request->set_context_id(context_id_);
               client->KeepAliveAsync(
                   request, response,
                   [request, response](const Status& s) {
@@ -742,6 +728,36 @@ Status EagerContext::InitializeRemote(
   }
   return Status::OK();
 }
 
+Status EagerContext::InitializeRemoteWorker(
+    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+    const DeviceMgr* remote_device_mgr,
+    const std::vector<string>& remote_contexts, uint64 context_id,
+    std::function<Rendezvous*(const int64)> rendezvous_creator) {
+  mutex_lock l(remote_state_mu_);
+
+  if (remote_device_manager_ != nullptr || server_ != nullptr ||
+      keep_alive_thread_ != nullptr) {
+    return errors::FailedPrecondition(
+        "EagerContext::InitializeRemoteWorker Failed. ",
+        "Already initialized remote as a master context.");
+  }
+
+  remote_contexts_ = remote_contexts;
+  context_id_ = context_id;
+
+  rendezvous_creator_ = std::move(rendezvous_creator);
+  remote_eager_workers_ = std::move(remote_eager_workers);
+
+  device_to_client_cache_.clear();
+  remote_unowned_device_manager_ = remote_device_mgr;
+  InitDeviceMapAndAsync();
+
+  ClearCaches();
+  executor_.ClearError();
+
+  return Status::OK();
+}
 #endif // !IS_MOBILE_PLATFORM
 
 } // namespace tensorflow
@@ -25,6 +25,7 @@ limitations under the License.
 
 // clang-format off
 // Required for IS_MOBILE_PLATFORM
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
 
@@ -97,15 +98,13 @@ class RunMetadataListener {
 
 class EagerContext : public core::RefCounted {
  public:
-  EagerContext(
-      const SessionOptions& opts,
-      ContextDevicePlacementPolicy default_device_placement_policy,
-      ContextMirroringPolicy default_mirroring_policy, bool async,
-      const DeviceMgr* device_mgr, bool device_mgr_owned,
-      Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator,
-      DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
-      std::function<Rendezvous*(const int64)> rendezvous_creator = nullptr,
-      const DeviceMgr* remote_device_mgr = nullptr);
+  EagerContext(const SessionOptions& opts,
+               ContextDevicePlacementPolicy default_device_placement_policy,
+               ContextMirroringPolicy default_mirroring_policy, bool async,
+               const DeviceMgr* device_mgr, bool device_mgr_owned,
+               Rendezvous* rendezvous,
+               const CustomKernelCreator* custom_kernel_creator,
+               DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
 
   ~EagerContext() override;
 
@@ -254,13 +253,16 @@ class EagerContext : public core::RefCounted {
   FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
 
 #if !defined(IS_MOBILE_PLATFORM)
-  Status GetClientAndContextID(Device* device, eager::EagerClient** client,
-                               uint64* context_id);
+  Status GetClient(Device* device, eager::EagerClient** client);
+  uint64 GetContextId();
 
   // TODO(nareshmodi): Encapsulate remote state into a separate
   // class/struct.
   //
-  // Enables the eager context to communicate with remote devices.
+  // Enables the eager context to communicate with remote devices. When
+  // initializing with this method, this context will be the master context,
+  // which will kill all its slaves in shutdown.
   //
   // - server: A ServerInterface that exports the tensorflow.WorkerService.
   // Note that this class expects the server to already have been started.
@@ -268,20 +270,23 @@ class EagerContext : public core::RefCounted {
   // communicate with remote eager services.
   // - remote_device_mgr: A DeviceMgr* which contains all remote devices
   // (should contain no local devices).
-  // - remote_contexts: A map containing task name to remote context ID.
-  Status InitializeRemote(
+  // - remote_contexts: A vector containing task names.
+  Status InitializeRemoteMaster(
       std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
       std::shared_ptr<WorkerSession> worker_session,
       std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
       std::unique_ptr<DeviceMgr> remote_device_manager,
-      const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
-      DeviceMgr* local_device_mgr, int keep_alive_secs,
+      const std::vector<string>& remote_contexts, uint64 context_id,
+      Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
       DistributedFunctionLibraryRuntime* cluster_flr);
 
-  bool HasActiveRemoteContext(uint64 context_id) {
-    return active_remote_contexts_.find(context_id) !=
-           active_remote_contexts_.end();
-  }
+  // Similar with InitializeRemoteMaster but this context will not kill remote
+  // contexts in shutdown.
+  Status InitializeRemoteWorker(
+      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+      const DeviceMgr* remote_device_mgr,
+      const std::vector<string>& remote_contexts, uint64 context_id,
+      std::function<Rendezvous*(const int64)> rendezvous_creator);
 
   Status StoreCollectiveOpsServer(
       std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
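
Illustrative sketch only (not part of the diff): how the two entry points declared above are wired together elsewhere in this same commit. The master-side call is taken from the UpdateTFE_ContextWithServerDef hunks earlier in the change, and the worker-side call from the EagerServiceImpl::CreateContext hunks later in it; the surrounding variables (grpc_server, worker_session, remote_workers, etc.) are assumed to exist exactly as they do in those hunks.

// Master side: pick one context_id, create the remote contexts under that ID,
// then record the worker task names plus the ID in the local context.
tensorflow::uint64 context_id = tensorflow::random::New64();
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
    remote_workers, context_id, keep_alive_secs, server_def,
    remote_eager_workers.get(), ctx->context->Async(), base_request));
return ctx->context->InitializeRemoteMaster(
    std::move(server), grpc_server->worker_env(), worker_session,
    std::move(remote_eager_workers), std::move(remote_device_mgr),
    remote_workers, context_id, r, device_mgr, keep_alive_secs,
    worker_session->cluster_flr.get());

// Worker side: reuse the ID received in the CreateContextRequest; this
// context never closes its peers on shutdown.
Status s = ctx->InitializeRemoteWorker(
    std::move(remote_eager_workers), worker_session->remote_device_mgr(),
    remote_workers, request->context_id(), std::move(rendezvous_creator));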
@@ -328,7 +333,7 @@ class EagerContext : public core::RefCounted {
   // Only one of the below is set. remote_unowned_device_manager_ is set on
   // remote worker to allow running multi-device function on remote worker.
   std::unique_ptr<DeviceMgr> remote_device_manager_;
-  const DeviceMgr* remote_unowned_device_manager_;
+  const DeviceMgr* remote_unowned_device_manager_ = nullptr;
 
   // Devices owned by device_manager
   std::vector<Device*> devices_;
@@ -406,10 +411,9 @@ class EagerContext : public core::RefCounted {
 
   mutex remote_state_mu_;
 
-  gtl::FlatMap<string, uint64> remote_contexts_;
-  gtl::FlatSet<uint64> active_remote_contexts_;
-  gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
-      device_to_client_cache_;
+  uint64 context_id_;
+  std::vector<string> remote_contexts_;
+  gtl::FlatMap<Device*, eager::EagerClient*> device_to_client_cache_;
 
   int keep_alive_secs_ GUARDED_BY(remote_state_mu_);
   std::atomic<int> sleep_for_secs_;
@@ -644,9 +644,8 @@ Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
                              Device* recv_device, bool mirror,
                              TensorHandle** result) {
   eager::EagerClient* eager_client;
-  uint64 context_id;
-  TF_RETURN_IF_ERROR(
-      ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));
+  uint64 context_id = ctx->GetContextId();
+  TF_RETURN_IF_ERROR(ctx->GetClient(recv_device, &eager_client));
 
   eager::SendTensorRequest request;
   eager::SendTensorResponse response;
@@ -750,9 +749,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   EagerContext* ctx = op->EagerContext();
 
   eager::EagerClient* eager_client;
-  uint64 context_id;
-  TF_RETURN_IF_ERROR(
-      ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
+  uint64 context_id = ctx->GetContextId();
+  TF_RETURN_IF_ERROR(ctx->GetClient(op->Device(), &eager_client));
 
   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
   eager::EnqueueResponse response;
@@ -1207,9 +1205,9 @@ Status ExecuteSend(EagerContext* ctx, Device* device, TensorHandle* h,
         kernel->Run(input_vector, nullptr, nullptr, nullptr, nullptr));
   } else {
     eager::EagerClient* eager_client;
-    uint64 context_id;
+    uint64 context_id = ctx->GetContextId();
     TF_RETURN_IF_ERROR(
-        ctx->GetClientAndContextID(device, &eager_client, &context_id));
+        ctx->GetClient(device, &eager_client));
 
     std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
     eager::EnqueueResponse response;
@@ -1279,9 +1277,9 @@ Status ExecuteRecv(EagerContext* ctx, Device* device, DataType dtype,
         /* op_device= */ kernel->device(), ctx, result));
   } else {
     eager::EagerClient* eager_client;
-    uint64 context_id;
+    uint64 context_id = ctx->GetContextId();
     TF_RETURN_IF_ERROR(
-        ctx->GetClientAndContextID(device, &eager_client, &context_id));
+        ctx->GetClient(device, &eager_client));
 
     std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
     eager::EnqueueResponse response;
@@ -110,6 +110,7 @@ tf_cc_test(
     deps = [
         ":worker_session",
         "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
@@ -274,11 +275,7 @@ cc_library(
 cc_library(
     name = "worker_cache_wrapper",
     hdrs = ["worker_cache_wrapper.h"],
-    deps = [
-        ":worker_cache",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-    ],
+    deps = [":worker_cache"],
 )
 
 cc_library(
@@ -411,7 +408,6 @@ cc_library(
     hdrs = ["master_env.h"],
     deps = [
         ":worker_cache",
-        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:session_options",
     ],
@@ -489,16 +485,13 @@ cc_library(
     srcs = ["rpc_collective_executor_mgr.cc"],
     hdrs = ["rpc_collective_executor_mgr.h"],
     deps = [
-        ":base_rendezvous_mgr",
         ":collective_param_resolver_distributed",
         ":collective_rma_distributed",
         ":device_resolver_distributed",
         ":worker_cache",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
-        "//tensorflow/core:worker_proto_cc",
     ],
 )
 
@@ -530,7 +523,6 @@ cc_library(
         ":worker_cache",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal", # protobuf::Any
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:worker_proto_cc",
@@ -561,15 +553,11 @@ cc_library(
     srcs = ["collective_param_resolver_distributed.cc"],
     hdrs = ["collective_param_resolver_distributed.h"],
     deps = [
-        ":call_options",
         ":cancellable_call",
         ":device_resolver_distributed",
         ":worker_cache",
         "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:worker_proto_cc",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -615,9 +603,7 @@ cc_library(
         ":worker_cache",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:worker_proto_cc",
         "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
@@ -63,7 +63,6 @@ cc_library(
     ],
     deps = [
         ":remote_tensor_handle",
-        "//tensorflow:grpc",
         "//tensorflow:grpc++",
         "//tensorflow/c:c_api_internal",
         "//tensorflow/c:tf_status_helper",
@@ -104,6 +103,7 @@ tf_cc_test(
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime/eager:tensor_handle",
         "//tensorflow/core/distributed_runtime:session_mgr",
+        "//tensorflow/core/distributed_runtime:test_utils",
         "//tensorflow/core/distributed_runtime:worker_env",
         "//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
     ],
@@ -98,8 +98,9 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
     cluster_device_attributes.push_back(cluster_device);
   }
 
-  auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
-  auto session_name = strings::StrCat("eager_", request->rendezvous_id());
+  auto* r = env_->rendezvous_mgr->Find(request->context_id());
+  auto session_name =
+      tensorflow::strings::StrCat("eager_", request->context_id());
   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
       session_name, request->server_def(), request->cluster_device_attributes(),
       true));
@@ -125,8 +126,29 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
       device_mgr, false, r, GetDefaultCustomKernelCreator(),
-      worker_session->cluster_flr.get(), std::move(rendezvous_creator),
-      worker_session->remote_device_mgr());
+      worker_session->cluster_flr.get());
+
+  Status s;
+  std::vector<string> remote_workers;
+  worker_session->worker_cache->ListWorkers(&remote_workers);
+  remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
                                    worker_session->worker_name),
+                       remote_workers.end());
+
+  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
+  s = worker_session->worker_cache->GetEagerClientCache(&remote_eager_workers);
+  if (!s.ok()) {
+    delete ctx;
+    return s;
+  }
+
+  s = ctx->InitializeRemoteWorker(
+      std::move(remote_eager_workers), worker_session->remote_device_mgr(),
+      remote_workers, request->context_id(), std::move(rendezvous_creator));
+  if (!s.ok()) {
+    delete ctx;
+    return s;
+  }
 
   std::vector<DeviceAttributes> device_attributes;
   device_mgr->ListDeviceAttributes(&device_attributes);
@@ -134,17 +156,17 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
   for (const auto& da : device_attributes) {
     *response->add_device_attributes() = da;
   }
 
-  uint64 context_id;
   {
     mutex_lock l(contexts_mu_);
-    do {
-      context_id = random::New64();
-    } while (contexts_.find(context_id) != contexts_.end());
-    contexts_.emplace(context_id,
+    if (contexts_.find(request->context_id()) != contexts_.end()) {
+      delete ctx;
+      return errors::InvalidArgument("EagerService:CreateContext failed. ",
+                                     "Context id: <", request->context_id(),
+                                     "> already exists.");
+    }
+    contexts_.emplace(request->context_id(),
                       new ServerContext(ctx, request->keep_alive_secs(), env_));
   }
-  response->set_context_id(context_id);
 
   return Status::OK();
 }
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
 #include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/distributed_runtime/test_utils.h"
 #include "tensorflow/core/distributed_runtime/worker_env.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -51,16 +52,30 @@ class TestEagerServiceImpl : public EagerServiceImpl {
   }
 };
 
+class DummyEagerClientCache : public EagerClientCache {
+  Status GetClient(const string& target, EagerClient** client) override {
+    return errors::Unimplemented("");
+  }
+};
+
+class FakeCache : public TestWorkerCache {
+  Status GetEagerClientCache(
+      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
+    eager_client_cache->reset(new DummyEagerClientCache);
+    return Status::OK();
+  }
+};
+
 class EagerServiceImplTest : public ::testing::Test {
  public:
   EagerServiceImplTest()
       : rendezvous_mgr_(&worker_env_),
         session_mgr_(new SessionMgr(
             &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
-            std::unique_ptr<WorkerCacheInterface>(),
+            std::unique_ptr<WorkerCacheInterface>(new FakeCache),
             [](const ServerDef& server_def,
                WorkerCacheInterface** worker_cache) {
-              *worker_cache = nullptr;
+              *worker_cache = new FakeCache;
               return Status::OK();
             })) {
     worker_env_.env = Env::Default();
@@ -153,15 +168,16 @@ tensorflow::FunctionDef MatMulFunction() {
 TEST_F(EagerServiceImplTest, BasicTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
 
+  uint64 context_id = random::New64();
+
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   CreateContextResponse response;
 
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
 
-  uint64 context_id = response.context_id();
 
   EnqueueRequest remote_enqueue_request;
   remote_enqueue_request.set_context_id(context_id);
@@ -202,7 +218,7 @@ TEST_F(EagerServiceImplTest, BasicTest) {
 
   tensorflow::TensorHandle* tensor_handle;
   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-      response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
 
   // This should be OK to do since we've placed all computation on the CPU
   // device.
@@ -229,16 +245,16 @@ TEST_F(EagerServiceImplTest, BasicTest) {
 TEST_F(EagerServiceImplTest, BasicFunctionTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
 
+  uint64 context_id = random::New64();
+
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   CreateContextResponse response;
 
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
 
-  uint64 context_id = response.context_id();
 
   RegisterFunctionRequest register_function_request;
   register_function_request.set_context_id(context_id);
   *register_function_request.mutable_function_def() = MatMulFunction();
@@ -273,7 +289,7 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
   const tensorflow::Tensor* t = nullptr;
   tensorflow::TensorHandle* tensor_handle;
   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-      response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
   TF_ASSERT_OK(tensor_handle->Tensor(&t));
 
   auto actual = t->flat<float>();
@@ -296,15 +312,16 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
 TEST_F(EagerServiceImplTest, SendTensorTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
 
+  uint64 context_id = random::New64();
+
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   CreateContextResponse response;
 
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
 
-  uint64 context_id = response.context_id();
 
   SendTensorRequest send_tensor_request;
   send_tensor_request.set_context_id(context_id);
@@ -339,7 +356,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
   const tensorflow::Tensor* t = nullptr;
   tensorflow::TensorHandle* tensor_handle;
   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-      response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
   TF_ASSERT_OK(tensor_handle->Tensor(&t));
 
   Device* device = tensor_handle->device();
@@ -364,10 +381,11 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
 TEST_F(EagerServiceImplTest, KeepAliveTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
 
+  uint64 context_id = random::New64();
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   request.set_keep_alive_secs(3);
   CreateContextResponse response;
 
@@ -379,7 +397,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
   KeepAliveRequest keep_alive_request;
   KeepAliveResponse keep_alive_response;
 
-  keep_alive_request.set_context_id(response.context_id());
+  keep_alive_request.set_context_id(context_id);
 
   Status status =
       eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
@@ -388,15 +406,16 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
   EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                       status.error_message());
 
+  uint64 new_context_id = random::New64();
   // Create a new context.
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(new_context_id);
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
 
   // The context should not be GC'd.
   worker_env_.env->SleepForMicroseconds(1 *
                                         tensorflow::EnvTime::kSecondsToMicros);
 
-  keep_alive_request.set_context_id(response.context_id());
+  keep_alive_request.set_context_id(new_context_id);
 
   TF_ASSERT_OK(
       eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));
@@ -29,7 +29,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
                                int output_num) {
   auto cleanup = gtl::MakeCleanup([ctx]() { ctx->Unref(); });
 
-  if (!ctx->HasActiveRemoteContext(context_id)) {
+  if (ctx->GetContextId() != context_id) {
     // This means that this tensor was pointing to a remote device, which
     // has been changed out from under us. Simply return since there is
     // nothing we can do.
@@ -67,23 +67,25 @@ message CreateContextRequest {
   // This is the version for all the ops that will be enqueued by the client.
   VersionDef version_def = 4;
 
-  // This ID will be used for all future communications. It is essential that
-  // both ends use this ID for selecting a rendezvous to get everything to
-  // match.
-  int64 rendezvous_id = 5;
 
   // Device attributes in the cluster
   repeated DeviceAttributes cluster_device_attributes = 6;
-}
 
-message CreateContextResponse {
   // The ID of the created context. This is usually a randomly generated number,
   // that will be used to identify the context in future requests to the
   // service. Contexts are not persisted through server restarts.
-  fixed64 context_id = 1;
+  // This ID will be used for all future communications as well. It is essential
+  // that both ends use this ID for selecting a rendezvous to get everything to
+  // match.
+  fixed64 context_id = 7;
+
+  reserved 5;
+}
+
+message CreateContextResponse {
   // List of devices that are locally accessible to the worker.
   repeated DeviceAttributes device_attributes = 2;
 
+  reserved 1;
 }
 
 message EnqueueRequest {
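
To make the new protocol concrete, here is a minimal client-side sketch of the flow, adapted from the XrtGrpcEagerClientWorks test earlier in this commit. It assumes that test's fixture (an XrtGrpcEagerClient* client and a cluster_def_) and the TensorFlow tree; it is an illustrative sketch, not part of the diff.

// The client now chooses the context ID itself and reuses it for every later
// RPC that refers to the context (KeepAlive, Enqueue, CloseContext, ...).
uint64 context_id = random::New64();

eager::CreateContextRequest request;
ServerDef* server_def = request.mutable_server_def();
*server_def->mutable_cluster() = cluster_def_;
server_def->set_job_name("localhost");
server_def->set_protocol("grpc");
request.set_keep_alive_secs(60);
request.set_context_id(context_id);  // replaces the old rendezvous_id field

eager::CreateContextResponse create_response;
TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CreateContextAsync,
                              &request, &create_response));

// The same ID identifies the context when it is torn down; the response no
// longer carries a server-generated context ID.
eager::CloseContextRequest close_request;
close_request.set_context_id(context_id);
eager::CloseContextResponse close_response;
TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CloseContextAsync,
                              &close_request, &close_response));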