Introduce a new method, InitializeRemoteWorker, in EagerContext, which allows us to set up the EagerClient cache on workers. This change is required so that tensor copies can be issued lazily.

This CL also includes the following changes:
 - Rename InitializeRemote -> InitializeRemoteMaster.
 - Use the same context_id on all workers, generated by the master eager context.
 - Remove rendezvous_id and use context_id instead.

PiperOrigin-RevId: 254313879
Authored by Xiao Yu on 2019-06-20 18:34:45 -07:00; committed by TensorFlower Gardener
parent 155207e65f
commit 81cc1b91cc
15 changed files with 299 additions and 164 deletions
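For orientation, the sketch below condenses how the two initialization paths fit together after this change. It is only an illustration assembled from the method signatures in the diff that follows: every argument (ctx, server, worker_env, remote_eager_workers, rendezvous, local_device_mgr, keep_alive_secs, cluster_flr, rendezvous_creator, and so on) stands for setup that is elided here, and the snippet is not meant to compile on its own.

  // Sketch: the master generates a single context_id and pushes it to every
  // worker; both sides then initialize their EagerContext with that same id.
  tensorflow::uint64 context_id = tensorflow::random::New64();

  // Master path (UpdateTFE_ContextWithServerDef below): the master owns the
  // ServerInterface and closes all remote contexts on shutdown.
  TF_RETURN_IF_ERROR(ctx->InitializeRemoteMaster(
      std::move(server), worker_env, worker_session,
      std::move(remote_eager_workers), std::move(remote_device_mgr),
      remote_workers, context_id, rendezvous, local_device_mgr,
      keep_alive_secs, cluster_flr));

  // Worker path (EagerServiceImpl::CreateContext below): the worker reuses the
  // master-generated context_id, only needs an EagerClient cache, the remote
  // DeviceMgr and a rendezvous creator, and never closes the other contexts.
  TF_RETURN_IF_ERROR(ctx->InitializeRemoteWorker(
      std::move(remote_eager_workers), remote_device_mgr,
      remote_workers, context_id, std::move(rendezvous_creator)));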


@@ -135,17 +135,16 @@ tensorflow::Status GetAllRemoteDevices(
 }
 tensorflow::Status CreateRemoteContexts(
-    const std::vector<string>& remote_workers, int64 rendezvous_id,
+    const std::vector<string>& remote_workers, tensorflow::uint64 context_id,
     int keep_alive_secs, const tensorflow::ServerDef& server_def,
     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
-    const tensorflow::eager::CreateContextRequest& base_request,
-    tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
+    const tensorflow::eager::CreateContextRequest& base_request) {
   for (int i = 0; i < remote_workers.size(); i++) {
     const string& remote_worker = remote_workers[i];
     tensorflow::eager::CreateContextRequest request(base_request);
     tensorflow::eager::CreateContextResponse response;
-    request.set_rendezvous_id(rendezvous_id);
+    request.set_context_id(context_id);
     tensorflow::DeviceNameUtils::ParsedName parsed_name;
     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
                                                     &parsed_name)) {
@@ -174,8 +173,6 @@ tensorflow::Status CreateRemoteContexts(
     });
     n.WaitForNotification();
     TF_RETURN_IF_ERROR(status);
-    remote_contexts->emplace(remote_worker, response.context_id());
   }
   return tensorflow::Status::OK();
 }
@@ -212,7 +209,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
-  int64 rendezvous_id = tensorflow::random::New64();
+  tensorflow::uint64 context_id = tensorflow::random::New64();
   std::vector<string> remote_workers;
   grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
@@ -242,22 +239,20 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
     *base_request.add_cluster_device_attributes() = da;
   }
-  std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
-      grpc_server->channel_cache();
-  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
-      tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
+  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
+  tensorflow::WorkerCacheFactoryOptions options(server_def);
+  LOG_AND_RETURN_IF_ERROR(
+      grpc_server->EagerClientCacheFactory(options, &remote_eager_workers));
   // Initialize remote eager workers.
-  tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
   LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
-      remote_workers, rendezvous_id, keep_alive_secs, server_def,
-      remote_eager_workers.get(), ctx->context->Async(), base_request,
-      &remote_contexts));
+      remote_workers, context_id, keep_alive_secs, server_def,
+      remote_eager_workers.get(), ctx->context->Async(), base_request));
   tensorflow::RemoteRendezvous* r =
-      grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
+      grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
-  auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
+  auto session_name = tensorflow::strings::StrCat("eager_", context_id);
   TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
       session_name, server_def, base_request.cluster_device_attributes(),
       true));
@@ -272,10 +267,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
   auto* device_mgr = grpc_server->worker_env()->device_mgr;
-  return ctx->context->InitializeRemote(
+  return ctx->context->InitializeRemoteMaster(
       std::move(server), grpc_server->worker_env(), worker_session,
      std::move(remote_eager_workers), std::move(remote_device_mgr),
-      remote_contexts, r, device_mgr, keep_alive_secs,
+      remote_workers, context_id, r, device_mgr, keep_alive_secs,
       worker_session->cluster_flr.get());
 #undef LOG_AND_RETURN_IF_ERROR
 }


@@ -76,19 +76,20 @@ TEST_F(XrtClientTest, XrtGrpcEagerClientWorks) {
   // Create and destroy a context to verify we can make RPCs.
   eager::CreateContextRequest request;
+  uint64 context_id = random::New64();
   ServerDef* server_def = request.mutable_server_def();
   *server_def->mutable_cluster() = cluster_def_;
   server_def->set_job_name("localhost");
   server_def->set_protocol("grpc");
   request.set_keep_alive_secs(60);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   eager::CreateContextResponse create_response;
   TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CreateContextAsync,
                                 &request, &create_response));
   eager::CloseContextRequest close_request;
-  close_request.set_context_id(create_response.context_id());
+  close_request.set_context_id(context_id);
   eager::CloseContextResponse close_response;
   TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CloseContextAsync,


@@ -45,7 +45,7 @@ XrtTfClient::XrtTfClient(ClusterDef cluster_def,
 xla::StatusOr<std::shared_ptr<XrtTfContext>> XrtTfContext::Create(
     const XrtTfContext::Options& options,
     std::shared_ptr<XrtTfClient> tf_client, const std::string& job, int task) {
-  int64 rendezvous_id = random::New64();
+  int64 context_id = random::New64();
   eager::CreateContextRequest request;
   ServerDef* server_def = request.mutable_server_def();
@@ -53,7 +53,7 @@ xla::StatusOr<std::shared_ptr<XrtTfContext>> XrtTfContext::Create(
   server_def->set_job_name(job);
   server_def->set_protocol("grpc");
   request.set_keep_alive_secs(60);
-  request.set_rendezvous_id(rendezvous_id);
+  request.set_context_id(context_id);
   request.set_async(options.async);
   eager::CreateContextResponse response;
@@ -98,8 +98,9 @@ xla::StatusOr<std::shared_ptr<XrtTfContext>> XrtTfContext::Create(
     return a.name() < b.name();
   });
   return std::make_shared<XrtTfContext>(options, tf_client, eager_client,
-                                        rendezvous_id, response.context_id(),
-                                        std::move(devices), cpu_device_id);
+                                        /*rendezvous_id=*/context_id,
+                                        context_id, std::move(devices),
+                                        cpu_device_id);
 }
 XrtTfContext::XrtTfContext(const XrtTfContext::Options& options,


@@ -64,15 +64,11 @@ EagerContext::EagerContext(
     ContextMirroringPolicy default_mirroring_policy, bool async,
     const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous,
     const CustomKernelCreator* custom_kernel_creator,
-    DistributedFunctionLibraryRuntime* cluster_flr,
-    std::function<Rendezvous*(const int64)> rendezvous_creator,
-    const DeviceMgr* remote_device_mgr)
+    DistributedFunctionLibraryRuntime* cluster_flr)
     : default_device_placement_policy_(default_device_placement_policy),
       default_mirroring_policy_(default_mirroring_policy),
-      remote_unowned_device_manager_(remote_device_mgr),
       devices_(device_mgr->ListDevices()),
       rendezvous_(rendezvous),
-      rendezvous_creator_(std::move(rendezvous_creator)),
       thread_pool_(NewThreadPoolFromSessionOptions(opts)),
       pflr_(new ProcessFunctionLibraryRuntime(
           device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_,
@@ -209,25 +205,22 @@ bool EagerContext::MirrorTensors() const {
 #if !defined(IS_MOBILE_PLATFORM)
 void EagerContext::CloseRemoteContexts() {
   // Close all remote contexts.
-  std::vector<eager::CloseContextRequest> requests(remote_contexts_.size());
+  eager::CloseContextRequest request;
+  request.set_context_id(context_id_);
   std::vector<eager::CloseContextResponse> responses(remote_contexts_.size());
   BlockingCounter counter(static_cast<int>(remote_contexts_.size()));
   int i = 0;
-  for (const auto& worker_and_context_id : remote_contexts_) {
+  for (const auto& worker : remote_contexts_) {
     eager::EagerClient* client;
-    Status s =
-        remote_eager_workers_->GetClient(worker_and_context_id.first, &client);
-    requests[i].set_context_id(worker_and_context_id.second);
+    Status s = remote_eager_workers_->GetClient(worker, &client);
     client->CloseContextAsync(
-        &requests[i], &responses[i],
-        [&worker_and_context_id, &counter](const Status& s) {
+        &request, &responses[i], [this, &worker, &counter](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Unable to close remote context with ID "
-                       << worker_and_context_id.second
-                       << " for worker: " << worker_and_context_id.first
-                       << " due to " << s.error_message();
+                       << context_id_ << " for worker: " << worker << " due to "
+                       << s.error_message();
          }
          counter.DecrementCount();
        });
@@ -261,8 +254,9 @@ EagerContext::~EagerContext() {
     keep_alive_thread_cv_.notify_all();
   }
   keep_alive_thread_.reset();
+  if (!remote_contexts_.empty() && keep_alive_thread_ != nullptr) {
     CloseRemoteContexts();
+  }
 #endif  // !IS_MOBILE_PLATFORM
   executor_.WaitForAllPendingNodes().IgnoreError();
@@ -368,22 +362,20 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) {
 #if !defined(IS_MOBILE_PLATFORM)
   BlockingCounter blocking_counter(static_cast<int>(remote_contexts_.size()));
-  std::vector<eager::RegisterFunctionRequest> requests(remote_contexts_.size());
+  eager::RegisterFunctionRequest request;
+  request.set_context_id(context_id_);
+  *request.mutable_function_def() = fdef;
   std::vector<eager::RegisterFunctionResponse> responses(
       remote_contexts_.size());
   std::vector<Status> statuses(remote_contexts_.size());
   int i = 0;
-  for (const auto& target_and_context_id : remote_contexts_) {
-    requests[i].set_context_id(target_and_context_id.second);
-    *requests[i].mutable_function_def() = fdef;
+  for (const auto& target : remote_contexts_) {
     eager::EagerClient* eager_client;
-    TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(
-        target_and_context_id.first, &eager_client));
+    TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client));
     eager_client->RegisterFunctionAsync(
-        &requests[i], &responses[i],
+        &request, &responses[i],
         [i, &statuses, &blocking_counter](const Status& status) {
           statuses[i] = status;
           blocking_counter.DecrementCount();
@@ -559,17 +551,15 @@ Status GetTaskName(Device* d, string* task_name) {
 }  // namespace
 #if !defined(IS_MOBILE_PLATFORM)
-Status EagerContext::GetClientAndContextID(Device* device,
-                                           eager::EagerClient** client,
-                                           uint64* context_id) {
+Status EagerContext::GetClient(Device* device, eager::EagerClient** client) {
   if (remote_eager_workers_ == nullptr) {
     return errors::Internal(
         "Haven't set up remote eager worker in this eager context yet.");
   }
   auto it = device_to_client_cache_.find(device);
   if (it != device_to_client_cache_.end()) {
-    *client = it->second.first;
-    *context_id = it->second.second;
+    *client = it->second;
+    return Status::OK();
   }
   string device_task_name;
   TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name));
@@ -582,18 +572,19 @@ Status EagerContext::GetClientAndContextID(Device* device,
         "Unable to find eager client corresponding to device ", device->name());
   }
-  auto context_iterator = remote_contexts_.find(device_task_name);
-  if (context_iterator == remote_contexts_.end()) {
+  if (std::find(remote_contexts_.begin(), remote_contexts_.end(),
+                device_task_name) == remote_contexts_.end()) {
     return errors::Internal("Unable to find a context for handle on task: ",
                             device_task_name, ". This should not be possible");
   }
-  *context_id = context_iterator->second;
-  device_to_client_cache_.insert({device, {*client, *context_id}});
+  device_to_client_cache_.insert({device, *client});
   return Status::OK();
 }
+uint64 EagerContext::GetContextId() { return context_id_; }
 Status EagerContext::StoreCollectiveOpsServer(
     std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
     CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) {
@@ -624,13 +615,13 @@ Status EagerContext::StoreCollectiveOpsServer(
   return Status::OK();
 }
-Status EagerContext::InitializeRemote(
+Status EagerContext::InitializeRemoteMaster(
     std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
     std::shared_ptr<WorkerSession> worker_session,
     std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
     std::unique_ptr<DeviceMgr> remote_device_manager,
-    const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
-    DeviceMgr* local_device_mgr, int keep_alive_secs,
+    const std::vector<string>& remote_contexts, uint64 context_id,
+    Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
     DistributedFunctionLibraryRuntime* cluster_flr) {
   mutex_lock l(remote_state_mu_);
@@ -638,6 +629,7 @@ Status EagerContext::InitializeRemote(
     CloseRemoteContexts();
   }
   remote_contexts_ = remote_contexts;
+  context_id_ = context_id;
   use_send_tensor_rpc_ =
       ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false);
@@ -666,11 +658,6 @@ Status EagerContext::InitializeRemote(
   worker_session_ = worker_session;
   remote_eager_workers_ = std::move(remote_eager_workers);
-  active_remote_contexts_.clear();
-  for (const auto& remote_context : remote_contexts_) {
-    active_remote_contexts_.insert(remote_context.second);
-  }
   device_to_client_cache_.clear();
   remote_device_manager_ = std::move(remote_device_manager);
@@ -705,16 +692,15 @@ Status EagerContext::InitializeRemote(
       mutex_lock l(remote_state_mu_);
       if (keep_alive_secs_ > 0) {
        {
-          for (const auto& worker_and_context_id : remote_contexts_) {
+          for (const auto& worker : remote_contexts_) {
            eager::EagerClient* client;
-            Status s = remote_eager_workers_->GetClient(
-                worker_and_context_id.first, &client);
+            Status s =
+                remote_eager_workers_->GetClient(worker, &client);
            if (!s.ok()) {
              LOG(WARNING) << "Keep-alive thread was unable to find "
                              "a client for target "
-                           << worker_and_context_id.first
-                           << ". Got error: " << s;
+                           << worker << ". Got error: " << s;
              continue;
            }
@@ -723,7 +709,7 @@ Status EagerContext::InitializeRemote(
            eager::KeepAliveResponse* response =
                new eager::KeepAliveResponse;
-            request->set_context_id(worker_and_context_id.second);
+            request->set_context_id(context_id_);
            client->KeepAliveAsync(
                request, response,
                [request, response](const Status& s) {
@@ -740,6 +726,36 @@ Status EagerContext::InitializeRemote(
   }
   return Status::OK();
 }
+Status EagerContext::InitializeRemoteWorker(
+    std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+    const DeviceMgr* remote_device_mgr,
+    const std::vector<string>& remote_contexts, uint64 context_id,
+    std::function<Rendezvous*(const int64)> rendezvous_creator) {
+  mutex_lock l(remote_state_mu_);
+  if (remote_device_manager_ != nullptr || server_ != nullptr ||
+      keep_alive_thread_ != nullptr) {
+    return errors::FailedPrecondition(
+        "EagerContext::InitializeRemoteWorker Failed. ",
+        "Already initialized remote as a master context.");
+  }
+  remote_contexts_ = remote_contexts;
+  context_id_ = context_id;
+  rendezvous_creator_ = std::move(rendezvous_creator);
+  remote_eager_workers_ = std::move(remote_eager_workers);
+  device_to_client_cache_.clear();
+  remote_unowned_device_manager_ = remote_device_mgr;
+  InitDeviceMapAndAsync();
+  ClearCaches();
+  executor_.ClearError();
+  return Status::OK();
+}
 #endif  // !IS_MOBILE_PLATFORM
 }  // namespace tensorflow


@@ -25,6 +25,7 @@ limitations under the License.
 // clang-format off
 // Required for IS_MOBILE_PLATFORM
+#include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/platform.h"
 // clang-format on
@@ -97,15 +98,13 @@ class RunMetadataListener {
 class EagerContext : public core::RefCounted {
  public:
-  EagerContext(
-      const SessionOptions& opts,
+  EagerContext(const SessionOptions& opts,
      ContextDevicePlacementPolicy default_device_placement_policy,
      ContextMirroringPolicy default_mirroring_policy, bool async,
      const DeviceMgr* device_mgr, bool device_mgr_owned,
-      Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator,
-      DistributedFunctionLibraryRuntime* cluster_flr = nullptr,
-      std::function<Rendezvous*(const int64)> rendezvous_creator = nullptr,
-      const DeviceMgr* remote_device_mgr = nullptr);
+      Rendezvous* rendezvous,
+      const CustomKernelCreator* custom_kernel_creator,
+      DistributedFunctionLibraryRuntime* cluster_flr = nullptr);
   ~EagerContext() override;
@@ -254,13 +253,16 @@ class EagerContext : public core::RefCounted {
   FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; }
 #if !defined(IS_MOBILE_PLATFORM)
-  Status GetClientAndContextID(Device* device, eager::EagerClient** client,
-                               uint64* context_id);
+  Status GetClient(Device* device, eager::EagerClient** client);
+  uint64 GetContextId();
   // TODO(nareshmodi): Encapsulate remote state into a separate
   // class/struct.
   //
-  // Enables the eager context to communicate with remote devices.
+  // Enables the eager context to communicate with remote devices. When
+  // initializing with this method, this context will be the master context,
+  // which will kill all its slaves in shutdown.
   //
   // - server: A ServerInterface that exports the tensorflow.WorkerService.
   //   Note that this class expects the server to already have been started.
@@ -268,20 +270,23 @@ class EagerContext : public core::RefCounted {
   //   communicate with remote eager services.
   // - remote_device_mgr: A DeviceMgr* which contains all remote devices
   //   (should contain no local devices).
-  // - remote_contexts: A map containing task name to remote context ID.
-  Status InitializeRemote(
+  // - remote_contexts: A vector containing task names.
+  Status InitializeRemoteMaster(
      std::unique_ptr<ServerInterface> server, WorkerEnv* worker_env,
      std::shared_ptr<WorkerSession> worker_session,
      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
      std::unique_ptr<DeviceMgr> remote_device_manager,
-      const gtl::FlatMap<string, uint64>& remote_contexts, Rendezvous* r,
-      DeviceMgr* local_device_mgr, int keep_alive_secs,
+      const std::vector<string>& remote_contexts, uint64 context_id,
+      Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs,
      DistributedFunctionLibraryRuntime* cluster_flr);
-  bool HasActiveRemoteContext(uint64 context_id) {
-    return active_remote_contexts_.find(context_id) !=
-           active_remote_contexts_.end();
-  }
+  // Similar with InitializeRemoteMaster but this context will not kill remote
+  // contexts in shutdown.
+  Status InitializeRemoteWorker(
+      std::unique_ptr<eager::EagerClientCache> remote_eager_workers,
+      const DeviceMgr* remote_device_mgr,
+      const std::vector<string>& remote_contexts, uint64 context_id,
+      std::function<Rendezvous*(const int64)> rendezvous_creator);
   Status StoreCollectiveOpsServer(
      std::unique_ptr<ServerInterface> server, DeviceMgr* device_mgr,
@@ -328,7 +333,7 @@ class EagerContext : public core::RefCounted {
   // Only one of the below is set. remote_unowned_device_manager_ is set on
   // remote worker to allow running multi-device function on remote worker.
   std::unique_ptr<DeviceMgr> remote_device_manager_;
-  const DeviceMgr* remote_unowned_device_manager_;
+  const DeviceMgr* remote_unowned_device_manager_ = nullptr;
   // Devices owned by device_manager
   std::vector<Device*> devices_;
@@ -405,10 +410,9 @@ class EagerContext : public core::RefCounted {
   mutex remote_state_mu_;
-  gtl::FlatMap<string, uint64> remote_contexts_;
-  gtl::FlatSet<uint64> active_remote_contexts_;
-  gtl::FlatMap<Device*, std::pair<eager::EagerClient*, uint64>>
-      device_to_client_cache_;
+  uint64 context_id_;
+  std::vector<string> remote_contexts_;
+  gtl::FlatMap<Device*, eager::EagerClient*> device_to_client_cache_;
   int keep_alive_secs_ GUARDED_BY(remote_state_mu_);
   std::atomic<int> sleep_for_secs_;


@@ -659,9 +659,8 @@ Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h,
   }
   eager::EagerClient* eager_client;
-  uint64 context_id;
-  TF_RETURN_IF_ERROR(
-      ctx->GetClientAndContextID(recv_device, &eager_client, &context_id));
+  uint64 context_id = ctx->GetContextId();
+  TF_RETURN_IF_ERROR(ctx->GetClient(recv_device, &eager_client));
   eager::SendTensorRequest request;
   eager::SendTensorResponse response;
@@ -732,9 +731,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
   EagerContext* ctx = op->EagerContext();
   eager::EagerClient* eager_client;
-  uint64 context_id;
-  TF_RETURN_IF_ERROR(
-      ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id));
+  uint64 context_id = ctx->GetContextId();
+  TF_RETURN_IF_ERROR(ctx->GetClient(op->Device(), &eager_client));
   std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
   eager::EnqueueResponse response;


@@ -112,6 +112,7 @@ tf_cc_test(
     deps = [
         ":worker_session",
         "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",


@@ -97,8 +97,9 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
     cluster_device_attributes.push_back(cluster_device);
   }
-  auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id());
-  auto session_name = strings::StrCat("eager_", request->rendezvous_id());
+  auto* r = env_->rendezvous_mgr->Find(request->context_id());
+  auto session_name =
+      tensorflow::strings::StrCat("eager_", request->context_id());
   TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession(
       session_name, request->server_def(), request->cluster_device_attributes(),
       true));
@@ -123,8 +124,30 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
       SessionOptions(),
       tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
       tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(),
-      device_mgr, false, r, nullptr, worker_session->cluster_flr.get(),
-      std::move(rendezvous_creator), worker_session->remote_device_mgr());
+      device_mgr, false, r, nullptr, worker_session->cluster_flr.get());
+  Status s;
+  std::vector<string> remote_workers;
+  worker_session->worker_cache->ListWorkers(&remote_workers);
+  remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(),
+                                   worker_session->worker_name),
+                       remote_workers.end());
+  std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers;
+  s = env_->eager_client_cache_factory(request->server_def(),
+                                       &remote_eager_workers);
+  if (!s.ok()) {
+    delete ctx;
+    return s;
+  }
+  s = ctx->InitializeRemoteWorker(
+      std::move(remote_eager_workers), worker_session->remote_device_mgr(),
+      remote_workers, request->context_id(), std::move(rendezvous_creator));
+  if (!s.ok()) {
+    delete ctx;
+    return s;
+  }
   std::vector<DeviceAttributes> device_attributes;
   device_mgr->ListDeviceAttributes(&device_attributes);
@@ -132,17 +155,17 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request,
   for (const auto& da : device_attributes) {
     *response->add_device_attributes() = da;
   }
-  uint64 context_id;
   {
     mutex_lock l(contexts_mu_);
-    do {
-      context_id = random::New64();
-    } while (contexts_.find(context_id) != contexts_.end());
-    contexts_.emplace(context_id,
+    if (contexts_.find(request->context_id()) != contexts_.end()) {
+      delete ctx;
+      return errors::InvalidArgument("EagerService:CreateContext failed. ",
+                                     "Context id: <", request->context_id(),
+                                     "> already exists.");
+    }
+    contexts_.emplace(request->context_id(),
                      new ServerContext(ctx, request->keep_alive_secs(), env_));
   }
-  response->set_context_id(context_id);
   return Status::OK();
 }


@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/distributed_runtime/worker_env.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -51,22 +52,49 @@ class TestEagerServiceImpl : public EagerServiceImpl {
   }
 };
+class DummyWorkerCache : public WorkerCacheInterface {
+  void ListWorkers(std::vector<string>* workers) const override {}
+  void ListWorkersInJob(const string& job_name,
+                        std::vector<string>* workers) const override {}
+  WorkerInterface* GetOrCreateWorker(const string& target) override {
+    return nullptr;
+  }
+  bool GetDeviceLocalityNonBlocking(const string& device,
+                                    DeviceLocality* locality) override {
+    return false;
+  }
+  void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
+                              StatusCallback done) override {}
+};
+class DummyEagerClientCache : public EagerClientCache {
+  Status GetClient(const string& target, EagerClient** client) override {
+    return errors::Unimplemented("");
+  }
+};
 class EagerServiceImplTest : public ::testing::Test {
  public:
   EagerServiceImplTest()
       : rendezvous_mgr_(&worker_env_),
         session_mgr_(new SessionMgr(
             &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0",
-            std::unique_ptr<WorkerCacheInterface>(),
+            absl::make_unique<DummyWorkerCache>(),
             [](const ServerDef& server_def,
               WorkerCacheInterface** worker_cache) {
-              *worker_cache = nullptr;
+              *worker_cache = new DummyWorkerCache();
              return Status::OK();
            })) {
     worker_env_.env = Env::Default();
     worker_env_.rendezvous_mgr = &rendezvous_mgr_;
     worker_env_.session_mgr = session_mgr_.get();
+    worker_env_.eager_client_cache_factory =
+        [](const ServerDef& server_def,
+           std::unique_ptr<EagerClientCache>* eager_client_cache) {
+          eager_client_cache->reset(new DummyEagerClientCache());
+          return Status::OK();
+        };
     device_mgr_ = absl::make_unique<DeviceMgr>(DeviceFactory::NewDevice(
         "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0"));
@@ -76,6 +104,7 @@ class EagerServiceImplTest : public ::testing::Test {
  protected:
   WorkerEnv worker_env_;
+  std::unique_ptr<DummyWorkerCache> worker_cache_;
   tensorflow::RpcRendezvousMgr rendezvous_mgr_;
   std::unique_ptr<SessionMgr> session_mgr_;
   std::unique_ptr<DeviceMgr> device_mgr_;
@@ -153,15 +182,16 @@ tensorflow::FunctionDef MatMulFunction() {
 TEST_F(EagerServiceImplTest, BasicTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
+  uint64 context_id = random::New64();
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   CreateContextResponse response;
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
-  uint64 context_id = response.context_id();
   EnqueueRequest remote_enqueue_request;
   remote_enqueue_request.set_context_id(context_id);
@@ -202,7 +232,7 @@ TEST_F(EagerServiceImplTest, BasicTest) {
   tensorflow::TensorHandle* tensor_handle;
   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-      response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
   // This should be OK to do since we've placed all computation on the CPU
   // device.
@@ -229,16 +259,16 @@ TEST_F(EagerServiceImplTest, BasicTest) {
 TEST_F(EagerServiceImplTest, BasicFunctionTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
+  uint64 context_id = random::New64();
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   CreateContextResponse response;
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
-  uint64 context_id = response.context_id();
   RegisterFunctionRequest register_function_request;
   register_function_request.set_context_id(context_id);
   *register_function_request.mutable_function_def() = MatMulFunction();
@@ -273,7 +303,7 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
   const tensorflow::Tensor* t = nullptr;
   tensorflow::TensorHandle* tensor_handle;
   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-      response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
   TF_ASSERT_OK(tensor_handle->Tensor(&t));
   auto actual = t->flat<float>();
@@ -296,15 +326,16 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) {
 TEST_F(EagerServiceImplTest, SendTensorTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
+  uint64 context_id = random::New64();
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   CreateContextResponse response;
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
-  uint64 context_id = response.context_id();
   SendTensorRequest send_tensor_request;
   send_tensor_request.set_context_id(context_id);
@@ -339,7 +370,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
   const tensorflow::Tensor* t = nullptr;
   tensorflow::TensorHandle* tensor_handle;
   TF_ASSERT_OK(eager_service_impl.GetTensorHandle(
-      response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle));
+      context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle));
   TF_ASSERT_OK(tensor_handle->Tensor(&t));
   Device* device = tensor_handle->device();
@@ -364,10 +395,11 @@ TEST_F(EagerServiceImplTest, SendTensorTest) {
 TEST_F(EagerServiceImplTest, KeepAliveTest) {
   TestEagerServiceImpl eager_service_impl(&worker_env_);
+  uint64 context_id = random::New64();
   CreateContextRequest request;
   request.mutable_server_def()->set_job_name("localhost");
   request.mutable_server_def()->set_task_index(0);
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(context_id);
   request.set_keep_alive_secs(3);
   CreateContextResponse response;
@@ -379,7 +411,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
   KeepAliveRequest keep_alive_request;
   KeepAliveResponse keep_alive_response;
-  keep_alive_request.set_context_id(response.context_id());
+  keep_alive_request.set_context_id(context_id);
   Status status =
       eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response);
@@ -388,15 +420,16 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) {
   EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id",
                       status.error_message());
+  uint64 new_context_id = random::New64();
   // Create a new context.
-  request.set_rendezvous_id(random::New64());
+  request.set_context_id(new_context_id);
   TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response));
   // The context should not be GC'd.
   worker_env_.env->SleepForMicroseconds(1 *
                                         tensorflow::EnvTime::kSecondsToMicros);
-  keep_alive_request.set_context_id(response.context_id());
+  keep_alive_request.set_context_id(new_context_id);
   TF_ASSERT_OK(
       eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response));


@@ -29,7 +29,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx,
                                int output_num) {
   auto cleanup = gtl::MakeCleanup([ctx]() { ctx->Unref(); });
-  if (!ctx->HasActiveRemoteContext(context_id)) {
+  if (ctx->GetContextId() != context_id) {
     // This means that this tensor was pointing to a remote device, which
     // has been changed out from under us. Simply return since there is
     // nothing we can do.


@@ -296,6 +296,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed",
         "//tensorflow/core/distributed_runtime:device_resolver_distributed",
         "//tensorflow/core/distributed_runtime:graph_mgr",
@@ -308,6 +309,8 @@ cc_library(
         "//tensorflow/core/distributed_runtime:session_mgr",
         "//tensorflow/core/distributed_runtime:worker_cache_wrapper",
         "//tensorflow/core/distributed_runtime:worker_env",
+        "//tensorflow/core/distributed_runtime/eager:eager_client",
+        "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
         "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl",
    ],
    alwayslink = 1,


@@ -24,12 +24,12 @@ limitations under the License.
 #include "grpcpp/grpcpp.h"
 #include "grpcpp/security/credentials.h"
 #include "grpcpp/server_builder.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
+#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
 #include "tensorflow/core/distributed_runtime/local_master.h"
 #include "tensorflow/core/distributed_runtime/master.h"
@@ -49,6 +49,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/public/session_options.h"
@@ -253,6 +254,12 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
         return WorkerCacheFactory(options, worker_cache);
       });
   worker_env_.compute_pool = ComputePool(sess_opts);
+  worker_env_.eager_client_cache_factory =
+      [this](const ServerDef& server_def,
+             std::unique_ptr<eager::EagerClientCache>* eager_client_cahce) {
+        WorkerCacheFactoryOptions options(server_def);
+        return EagerClientCacheFactory(options, eager_client_cahce);
+      };
   // Finish setting up master environment.
   master_env_.ops = OpRegistry::Global();
@@ -310,25 +317,13 @@ Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
 Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                       WorkerCacheInterface** worker_cache) {
-  if (options.job_name == nullptr || options.job_name->empty()) {
-    Status s = errors::InvalidArgument(
-        "The master (current machine) is not included in the provided "
-        "cluster_def. ",
-        options.cluster_def->DebugString());
-    LOG(WARNING) << s;
-    return s;
-  }
-  GrpcChannelSpec channel_spec;
-  TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
-  channel_cache_.reset(
-      NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
+  std::shared_ptr<GrpcChannelCache> channel_cache;
+  TF_RETURN_IF_ERROR(FindOrCreateChannelCache(options, &channel_cache));
   string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
                                        "/task:", options.task_index);
-  const string host_port = channel_cache_->TranslateTask(name_prefix);
+  const string host_port = channel_cache->TranslateTask(name_prefix);
   int requested_port;
   auto colon_index = host_port.find_last_of(':');
@@ -343,11 +338,50 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
                                    " differs from expected port ", bound_port_);
   }
-  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache_,
+  *worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache,
                                                     worker_impl(), name_prefix);
   return Status::OK();
 }
+Status GrpcServer::EagerClientCacheFactory(
+    const WorkerCacheFactoryOptions& options,
+    std::unique_ptr<eager::EagerClientCache>* eager_client_cache) {
+  std::shared_ptr<GrpcChannelCache> channel_cache;
+  TF_RETURN_IF_ERROR(FindOrCreateChannelCache(options, &channel_cache));
+  eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache));
+  return Status::OK();
+}
+Status GrpcServer::FindOrCreateChannelCache(
+    const WorkerCacheFactoryOptions& options,
+    std::shared_ptr<GrpcChannelCache>* cache) {
+  if (options.job_name == nullptr || options.job_name->empty()) {
+    Status s = errors::InvalidArgument(
+        "The master (current machine) is not included in the provided "
+        "cluster_def. ",
+        options.cluster_def->DebugString());
+    LOG(ERROR) << s;
+    return s;
+  }
+  string cluster = "";
+  if (options.cluster_def != nullptr) {
+    options.cluster_def->SerializeToString(&cluster);
+  }
+  Fprint128 cache_key = Fingerprint128(cluster);
+  mutex_lock l(channel_mu_);
+  *cache = gtl::FindPtrOrNull(channel_caches_, cache_key);
+  if (*cache == nullptr) {
+    GrpcChannelSpec channel_spec;
+    TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
+    *cache = std::shared_ptr<GrpcChannelCache>(
+        NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
+    channel_caches_.emplace(cache_key, *cache);
+  }
+  return Status::OK();
+}
 Status GrpcServer::Start() {
   mutex_lock l(mu_);
   switch (state_) {


@@ -22,11 +22,11 @@ limitations under the License.
 #include "grpcpp/grpcpp.h"
 #include "grpcpp/security/credentials.h"
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
 #include "tensorflow/core/distributed_runtime/master_env.h"
 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
-#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
 #include "tensorflow/core/distributed_runtime/server_lib.h"
@@ -34,7 +34,10 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/worker_env.h"
 #include "tensorflow/core/framework/collective.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/core/refcount.h"
 #include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/fingerprint.h"
+#include "tensorflow/core/platform/types.h"
 namespace tensorflow {
@@ -95,7 +98,9 @@ class GrpcServer : public ServerInterface {
   WorkerEnv* worker_env() { return &worker_env_; }
   MasterEnv* master_env() { return &master_env_; }
-  std::shared_ptr<GrpcChannelCache> channel_cache() { return channel_cache_; }
+  virtual Status EagerClientCacheFactory(
+      const WorkerCacheFactoryOptions& options,
+      std::unique_ptr<eager::EagerClientCache>* eager_client_cache);
  protected:
   virtual Status GetPort(int* port) const;
@@ -124,11 +129,9 @@ class GrpcServer : public ServerInterface {
   const ServerDef& server_def() const { return server_def_; }
   GrpcWorker* worker_impl() const { return worker_impl_.get(); }
-  void set_channel_cache(GrpcChannelCache* channel_cache) {
-    channel_cache_.reset(channel_cache);
-  }
  private:
+  Status FindOrCreateChannelCache(const WorkerCacheFactoryOptions& options,
+                                  std::shared_ptr<GrpcChannelCache>* cache);
   // The overall server configuration.
   const ServerDef server_def_;
   Env* env_;
@@ -156,7 +159,12 @@ class GrpcServer : public ServerInterface {
   std::unique_ptr<Master> master_impl_;
   AsyncServiceInterface* master_service_ = nullptr;
   std::unique_ptr<Thread> master_thread_ GUARDED_BY(mu_);
-  std::shared_ptr<GrpcChannelCache> channel_cache_;
+  mutex channel_mu_;
+  // TODO(fishx): Cleanup channel caches.
+  std::unordered_map<Fprint128, std::shared_ptr<GrpcChannelCache>,
+                     Fprint128Hasher>
+      channel_caches_ GUARDED_BY(channel_mu_);
   // Implementation of a TensorFlow worker, and RPC polling thread.
   WorkerEnv worker_env_;


@@ -17,6 +17,8 @@ limitations under the License.
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
 #include <vector>
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/types.h"
 namespace tensorflow {
@@ -25,12 +27,21 @@ namespace thread {
 class ThreadPool;
 }  // namespace thread
+namespace eager {
+class EagerClientCache;
+}  // namespace eager
 class CollectiveExecutorMgrInterface;
 class Device;
 class DeviceMgr;
 class Env;
 class RendezvousMgrInterface;
 class SessionMgr;
+class ServerDef;
+typedef std::function<Status(const ServerDef&,
+                             std::unique_ptr<eager::EagerClientCache>*)>
+    EagerClientCacheFactory;
 // The worker environment class, which holds a bag of pointers to
 // per-worker singletons.
@@ -64,6 +75,15 @@ struct WorkerEnv {
   // A pool of threads for scheduling compute work.
   thread::ThreadPool* compute_pool = nullptr;
+  // A factory function to create eager client cache.
+  EagerClientCacheFactory eager_client_cache_factory =
+      [](const ServerDef& s, std::unique_ptr<eager::EagerClientCache>* c) {
+        return errors::Unimplemented(
+            "EagerClientCacheFactory unimplemented. "
+            "It is probably because you didn't use GRPC. Right now "
+            "EagerClient only supports GRPC.");
+      };
 };
 }  // end namespace tensorflow
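The Unimplemented default above keeps non-gRPC builds compiling while making the missing piece obvious at runtime; GrpcServer::Init overrides it in the grpc_server_lib.cc diff earlier. Purely as a hypothetical sketch, an alternative transport would install its factory the same way; MyTransportEagerClientCache below is a made-up placeholder for a custom eager::EagerClientCache implementation, not a real TensorFlow class.

  // Hypothetical wiring only; MyTransportEagerClientCache does not exist in
  // TensorFlow and stands in for a custom eager::EagerClientCache subclass.
  worker_env->eager_client_cache_factory =
      [](const tensorflow::ServerDef& server_def,
         std::unique_ptr<tensorflow::eager::EagerClientCache>* cache) {
        cache->reset(new MyTransportEagerClientCache(server_def));
        return tensorflow::Status::OK();
      };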


@@ -67,21 +67,19 @@ message CreateContextRequest {
   // This is the version for all the ops that will be enqueued by the client.
   VersionDef version_def = 4;
-  // This ID will be used for all future communications. It is essential that
-  // both ends use this ID for selecting a rendezvous to get everything to
-  // match.
-  int64 rendezvous_id = 5;
   // Device attributes in the cluster
   repeated DeviceAttributes cluster_device_attributes = 6;
-}
-message CreateContextResponse {
   // The ID of the created context. This is usually a randomly generated number,
   // that will be used to identify the context in future requests to the
   // service. Contexts are not persisted through server restarts.
-  fixed64 context_id = 1;
+  // This ID will be used for all future communications as well. It is essential
+  // that both ends use this ID for selecting a rendezvous to get everything to
+  // match.
+  fixed64 context_id = 7;
+}
+message CreateContextResponse {
   // List of devices that are locally accessible to the worker.
   repeated DeviceAttributes device_attributes = 2;
 }
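With rendezvous_id gone, the caller (the master eager context, or a test) now picks the context ID, sends it in CreateContextRequest, and reuses the exact same value for every follow-up RPC, which is what the XRT and eager service tests above do. A minimal sketch of that lifecycle, with the actual RPC transport and error handling elided and server_def standing in for cluster configuration built elsewhere:

  // The ID is chosen client-side; the server no longer returns one.
  tensorflow::uint64 context_id = tensorflow::random::New64();

  tensorflow::eager::CreateContextRequest create_request;
  create_request.set_context_id(context_id);
  *create_request.mutable_server_def() = server_def;  // assumed to exist
  tensorflow::eager::CreateContextResponse create_response;
  // ... send CreateContext through an eager::EagerClient ...

  // Every later request addresses the context by the same ID.
  tensorflow::eager::KeepAliveRequest keep_alive_request;
  keep_alive_request.set_context_id(context_id);

  tensorflow::eager::CloseContextRequest close_request;
  close_request.set_context_id(context_id);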