diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index bea3512fef3..ec72c435e74 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -135,17 +135,16 @@ tensorflow::Status GetAllRemoteDevices( } tensorflow::Status CreateRemoteContexts( - const std::vector& remote_workers, int64 rendezvous_id, + const std::vector& remote_workers, tensorflow::uint64 context_id, int keep_alive_secs, const tensorflow::ServerDef& server_def, tensorflow::eager::EagerClientCache* remote_eager_workers, bool async, - const tensorflow::eager::CreateContextRequest& base_request, - tensorflow::gtl::FlatMap* remote_contexts) { + const tensorflow::eager::CreateContextRequest& base_request) { for (int i = 0; i < remote_workers.size(); i++) { const string& remote_worker = remote_workers[i]; tensorflow::eager::CreateContextRequest request(base_request); tensorflow::eager::CreateContextResponse response; - request.set_rendezvous_id(rendezvous_id); + request.set_context_id(context_id); tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker, &parsed_name)) { @@ -174,8 +173,6 @@ tensorflow::Status CreateRemoteContexts( }); n.WaitForNotification(); TF_RETURN_IF_ERROR(status); - - remote_contexts->emplace(remote_worker, response.context_id()); } return tensorflow::Status::OK(); } @@ -212,7 +209,7 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( LOG_AND_RETURN_IF_ERROR(grpc_server->Start()); - int64 rendezvous_id = tensorflow::random::New64(); + tensorflow::uint64 context_id = tensorflow::random::New64(); std::vector remote_workers; grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers); @@ -242,22 +239,20 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( *base_request.add_cluster_device_attributes() = da; } - std::shared_ptr channel_cache = - grpc_server->channel_cache(); - std::unique_ptr remote_eager_workers( - tensorflow::eager::NewGrpcEagerClientCache(channel_cache)); + std::unique_ptr remote_eager_workers; + tensorflow::WorkerCacheFactoryOptions options(server_def); + LOG_AND_RETURN_IF_ERROR( + grpc_server->EagerClientCacheFactory(options, &remote_eager_workers)); // Initialize remote eager workers. 
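For orientation before the remainder of this hunk: UpdateTFE_ContextWithServerDef now draws a single 64-bit id on the client and threads it through every step that previously used either the rendezvous id or a per-worker context id returned in CreateContextResponse. A condensed sketch of the resulting flow, using the names from the surrounding hunk (error handling and unrelated setup elided):

```cpp
// One id, chosen by the master, shared by every remote worker.
tensorflow::uint64 context_id = tensorflow::random::New64();

// Every CreateContextRequest sent by CreateRemoteContexts carries this id;
// the per-worker map of returned context ids is gone.
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
    remote_workers, context_id, keep_alive_secs, server_def,
    remote_eager_workers.get(), ctx->context->Async(), base_request));

// The same id keys the rendezvous and names the worker session.
tensorflow::RemoteRendezvous* r =
    grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto session_name = tensorflow::strings::StrCat("eager_", context_id);
```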
- tensorflow::gtl::FlatMap remote_contexts; LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( - remote_workers, rendezvous_id, keep_alive_secs, server_def, - remote_eager_workers.get(), ctx->context->Async(), base_request, - &remote_contexts)); + remote_workers, context_id, keep_alive_secs, server_def, + remote_eager_workers.get(), ctx->context->Async(), base_request)); tensorflow::RemoteRendezvous* r = - grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id); + grpc_server->worker_env()->rendezvous_mgr->Find(context_id); - auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id); + auto session_name = tensorflow::strings::StrCat("eager_", context_id); TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( session_name, server_def, base_request.cluster_device_attributes(), true)); @@ -272,10 +267,10 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( auto* device_mgr = grpc_server->worker_env()->device_mgr; - return ctx->context->InitializeRemote( + return ctx->context->InitializeRemoteMaster( std::move(server), grpc_server->worker_env(), worker_session, std::move(remote_eager_workers), std::move(remote_device_mgr), - remote_contexts, r, device_mgr, keep_alive_secs, + remote_workers, context_id, r, device_mgr, keep_alive_secs, worker_session->cluster_flr.get()); #undef LOG_AND_RETURN_IF_ERROR } diff --git a/tensorflow/compiler/xrt/client/xrt_client_test.cc b/tensorflow/compiler/xrt/client/xrt_client_test.cc index d9e94b01d2c..4e0cebeb6f0 100644 --- a/tensorflow/compiler/xrt/client/xrt_client_test.cc +++ b/tensorflow/compiler/xrt/client/xrt_client_test.cc @@ -76,19 +76,20 @@ TEST_F(XrtClientTest, XrtGrpcEagerClientWorks) { // Create and destroy a context to verify we can make RPCs. eager::CreateContextRequest request; + uint64 context_id = random::New64(); ServerDef* server_def = request.mutable_server_def(); *server_def->mutable_cluster() = cluster_def_; server_def->set_job_name("localhost"); server_def->set_protocol("grpc"); request.set_keep_alive_secs(60); - request.set_rendezvous_id(random::New64()); + request.set_context_id(context_id); eager::CreateContextResponse create_response; TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CreateContextAsync, &request, &create_response)); eager::CloseContextRequest close_request; - close_request.set_context_id(create_response.context_id()); + close_request.set_context_id(context_id); eager::CloseContextResponse close_response; TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CloseContextAsync, diff --git a/tensorflow/compiler/xrt/client/xrt_tf_client.cc b/tensorflow/compiler/xrt/client/xrt_tf_client.cc index 5388338fd36..88d0d25f84a 100644 --- a/tensorflow/compiler/xrt/client/xrt_tf_client.cc +++ b/tensorflow/compiler/xrt/client/xrt_tf_client.cc @@ -45,7 +45,7 @@ XrtTfClient::XrtTfClient(ClusterDef cluster_def, xla::StatusOr> XrtTfContext::Create( const XrtTfContext::Options& options, std::shared_ptr tf_client, const std::string& job, int task) { - int64 rendezvous_id = random::New64(); + int64 context_id = random::New64(); eager::CreateContextRequest request; ServerDef* server_def = request.mutable_server_def(); @@ -53,7 +53,7 @@ xla::StatusOr> XrtTfContext::Create( server_def->set_job_name(job); server_def->set_protocol("grpc"); request.set_keep_alive_secs(60); - request.set_rendezvous_id(rendezvous_id); + request.set_context_id(context_id); request.set_async(options.async); eager::CreateContextResponse response; @@ -98,8 +98,9 @@ xla::StatusOr> XrtTfContext::Create( return a.name() < b.name(); }); 
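To recap the client-side pattern exercised by the xrt_client_test.cc hunk above (a sketch for illustration, not additional test code): the caller now picks the context id before CreateContext and reuses that exact id for every later RPC, including CloseContext, since CreateContextResponse no longer returns one.

```cpp
// The id is chosen up front by the caller...
uint64 context_id = random::New64();

eager::CreateContextRequest request;
request.set_context_id(context_id);
// ... populate server_def, keep_alive_secs, etc. as in the test above ...

eager::CreateContextResponse create_response;
TF_ASSERT_OK(client->SyncCall(&XrtGrpcEagerClient::CreateContextAsync,
                              &request, &create_response));

// ... and reused verbatim when tearing the context down.
eager::CloseContextRequest close_request;
close_request.set_context_id(context_id);
```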
return std::make_shared(options, tf_client, eager_client, - rendezvous_id, response.context_id(), - std::move(devices), cpu_device_id); + /*rendezvous_id=*/context_id, + context_id, std::move(devices), + cpu_device_id); } XrtTfContext::XrtTfContext(const XrtTfContext::Options& options, diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 12f92b46db7..011bbbc28a1 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -64,15 +64,11 @@ EagerContext::EagerContext( ContextMirroringPolicy default_mirroring_policy, bool async, const DeviceMgr* device_mgr, bool device_mgr_owned, Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator, - DistributedFunctionLibraryRuntime* cluster_flr, - std::function rendezvous_creator, - const DeviceMgr* remote_device_mgr) + DistributedFunctionLibraryRuntime* cluster_flr) : default_device_placement_policy_(default_device_placement_policy), default_mirroring_policy_(default_mirroring_policy), - remote_unowned_device_manager_(remote_device_mgr), devices_(device_mgr->ListDevices()), rendezvous_(rendezvous), - rendezvous_creator_(std::move(rendezvous_creator)), thread_pool_(NewThreadPoolFromSessionOptions(opts)), pflr_(new ProcessFunctionLibraryRuntime( device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, @@ -209,25 +205,22 @@ bool EagerContext::MirrorTensors() const { #if !defined(IS_MOBILE_PLATFORM) void EagerContext::CloseRemoteContexts() { // Close all remote contexts. - std::vector requests(remote_contexts_.size()); + eager::CloseContextRequest request; + request.set_context_id(context_id_); std::vector responses(remote_contexts_.size()); BlockingCounter counter(static_cast(remote_contexts_.size())); int i = 0; - for (const auto& worker_and_context_id : remote_contexts_) { + for (const auto& worker : remote_contexts_) { eager::EagerClient* client; - Status s = - remote_eager_workers_->GetClient(worker_and_context_id.first, &client); + Status s = remote_eager_workers_->GetClient(worker, &client); - requests[i].set_context_id(worker_and_context_id.second); client->CloseContextAsync( - &requests[i], &responses[i], - [&worker_and_context_id, &counter](const Status& s) { + &request, &responses[i], [this, &worker, &counter](const Status& s) { if (!s.ok()) { LOG(ERROR) << "Unable to close remote context with ID " - << worker_and_context_id.second - << " for worker: " << worker_and_context_id.first - << " due to " << s.error_message(); + << context_id_ << " for worker: " << worker << " due to " + << s.error_message(); } counter.DecrementCount(); }); @@ -261,8 +254,9 @@ EagerContext::~EagerContext() { keep_alive_thread_cv_.notify_all(); } keep_alive_thread_.reset(); - - CloseRemoteContexts(); + if (!remote_contexts_.empty() && keep_alive_thread_ != nullptr) { + CloseRemoteContexts(); + } #endif // !IS_MOBILE_PLATFORM executor_.WaitForAllPendingNodes().IgnoreError(); @@ -368,22 +362,20 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { #if !defined(IS_MOBILE_PLATFORM) BlockingCounter blocking_counter(static_cast(remote_contexts_.size())); - std::vector requests(remote_contexts_.size()); + eager::RegisterFunctionRequest request; + request.set_context_id(context_id_); + *request.mutable_function_def() = fdef; std::vector responses( remote_contexts_.size()); std::vector statuses(remote_contexts_.size()); int i = 0; - for (const auto& target_and_context_id : remote_contexts_) { - 
requests[i].set_context_id(target_and_context_id.second); - *requests[i].mutable_function_def() = fdef; - + for (const auto& target : remote_contexts_) { eager::EagerClient* eager_client; - TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient( - target_and_context_id.first, &eager_client)); + TF_RETURN_IF_ERROR(remote_eager_workers_->GetClient(target, &eager_client)); eager_client->RegisterFunctionAsync( - &requests[i], &responses[i], + &request, &responses[i], [i, &statuses, &blocking_counter](const Status& status) { statuses[i] = status; blocking_counter.DecrementCount(); @@ -559,17 +551,15 @@ Status GetTaskName(Device* d, string* task_name) { } // namespace #if !defined(IS_MOBILE_PLATFORM) -Status EagerContext::GetClientAndContextID(Device* device, - eager::EagerClient** client, - uint64* context_id) { +Status EagerContext::GetClient(Device* device, eager::EagerClient** client) { if (remote_eager_workers_ == nullptr) { return errors::Internal( "Haven't set up remote eager worker in this eager context yet."); } auto it = device_to_client_cache_.find(device); if (it != device_to_client_cache_.end()) { - *client = it->second.first; - *context_id = it->second.second; + *client = it->second; + return Status::OK(); } string device_task_name; TF_RETURN_IF_ERROR(GetTaskName(device, &device_task_name)); @@ -582,18 +572,19 @@ Status EagerContext::GetClientAndContextID(Device* device, "Unable to find eager client corresponding to device ", device->name()); } - auto context_iterator = remote_contexts_.find(device_task_name); - if (context_iterator == remote_contexts_.end()) { + if (std::find(remote_contexts_.begin(), remote_contexts_.end(), + device_task_name) == remote_contexts_.end()) { return errors::Internal("Unable to find a context for handle on task: ", device_task_name, ". 
This should not be possible"); } - *context_id = context_iterator->second; - device_to_client_cache_.insert({device, {*client, *context_id}}); + device_to_client_cache_.insert({device, *client}); return Status::OK(); } +uint64 EagerContext::GetContextId() { return context_id_; } + Status EagerContext::StoreCollectiveOpsServer( std::unique_ptr server, DeviceMgr* device_mgr, CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) { @@ -624,13 +615,13 @@ Status EagerContext::StoreCollectiveOpsServer( return Status::OK(); } -Status EagerContext::InitializeRemote( +Status EagerContext::InitializeRemoteMaster( std::unique_ptr server, WorkerEnv* worker_env, std::shared_ptr worker_session, std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_manager, - const gtl::FlatMap& remote_contexts, Rendezvous* r, - DeviceMgr* local_device_mgr, int keep_alive_secs, + const std::vector& remote_contexts, uint64 context_id, + Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr) { mutex_lock l(remote_state_mu_); @@ -638,6 +629,7 @@ Status EagerContext::InitializeRemote( CloseRemoteContexts(); } remote_contexts_ = remote_contexts; + context_id_ = context_id; use_send_tensor_rpc_ = ReadBoolFromEnvVar("TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC", false); @@ -666,11 +658,6 @@ Status EagerContext::InitializeRemote( worker_session_ = worker_session; remote_eager_workers_ = std::move(remote_eager_workers); - active_remote_contexts_.clear(); - for (const auto& remote_context : remote_contexts_) { - active_remote_contexts_.insert(remote_context.second); - } - device_to_client_cache_.clear(); remote_device_manager_ = std::move(remote_device_manager); @@ -705,16 +692,15 @@ Status EagerContext::InitializeRemote( mutex_lock l(remote_state_mu_); if (keep_alive_secs_ > 0) { { - for (const auto& worker_and_context_id : remote_contexts_) { + for (const auto& worker : remote_contexts_) { eager::EagerClient* client; - Status s = remote_eager_workers_->GetClient( - worker_and_context_id.first, &client); + Status s = + remote_eager_workers_->GetClient(worker, &client); if (!s.ok()) { LOG(WARNING) << "Keep-alive thread was unable to find " "a client for target " - << worker_and_context_id.first - << ". Got error: " << s; + << worker << ". Got error: " << s; continue; } @@ -723,7 +709,7 @@ Status EagerContext::InitializeRemote( eager::KeepAliveResponse* response = new eager::KeepAliveResponse; - request->set_context_id(worker_and_context_id.second); + request->set_context_id(context_id_); client->KeepAliveAsync( request, response, [request, response](const Status& s) { @@ -740,6 +726,36 @@ Status EagerContext::InitializeRemote( } return Status::OK(); } + +Status EagerContext::InitializeRemoteWorker( + std::unique_ptr remote_eager_workers, + const DeviceMgr* remote_device_mgr, + const std::vector& remote_contexts, uint64 context_id, + std::function rendezvous_creator) { + mutex_lock l(remote_state_mu_); + + if (remote_device_manager_ != nullptr || server_ != nullptr || + keep_alive_thread_ != nullptr) { + return errors::FailedPrecondition( + "EagerContext::InitializeRemoteWorker Failed. 
", + "Already initialized remote as a master context."); + } + + remote_contexts_ = remote_contexts; + context_id_ = context_id; + + rendezvous_creator_ = std::move(rendezvous_creator); + remote_eager_workers_ = std::move(remote_eager_workers); + + device_to_client_cache_.clear(); + remote_unowned_device_manager_ = remote_device_mgr; + InitDeviceMapAndAsync(); + + ClearCaches(); + executor_.ClearError(); + + return Status::OK(); +} #endif // !IS_MOBILE_PLATFORM } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 45f2b685124..2d01e5ea9ba 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -25,6 +25,7 @@ limitations under the License. // clang-format off // Required for IS_MOBILE_PLATFORM +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/platform.h" // clang-format on @@ -97,15 +98,13 @@ class RunMetadataListener { class EagerContext : public core::RefCounted { public: - EagerContext( - const SessionOptions& opts, - ContextDevicePlacementPolicy default_device_placement_policy, - ContextMirroringPolicy default_mirroring_policy, bool async, - const DeviceMgr* device_mgr, bool device_mgr_owned, - Rendezvous* rendezvous, const CustomKernelCreator* custom_kernel_creator, - DistributedFunctionLibraryRuntime* cluster_flr = nullptr, - std::function rendezvous_creator = nullptr, - const DeviceMgr* remote_device_mgr = nullptr); + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_device_placement_policy, + ContextMirroringPolicy default_mirroring_policy, bool async, + const DeviceMgr* device_mgr, bool device_mgr_owned, + Rendezvous* rendezvous, + const CustomKernelCreator* custom_kernel_creator, + DistributedFunctionLibraryRuntime* cluster_flr = nullptr); ~EagerContext() override; @@ -254,13 +253,16 @@ class EagerContext : public core::RefCounted { FunctionLibraryDefinition* FuncLibDef() { return &func_lib_def_; } #if !defined(IS_MOBILE_PLATFORM) - Status GetClientAndContextID(Device* device, eager::EagerClient** client, - uint64* context_id); + Status GetClient(Device* device, eager::EagerClient** client); + + uint64 GetContextId(); // TODO(nareshmodi): Encapsulate remote state into a separate // class/struct. // - // Enables the eager context to communicate with remote devices. + // Enables the eager context to communicate with remote devices. When + // initializing with this method, this context will be the master context, + // which will kill all its slaves in shutdown. // // - server: A ServerInterface that exports the tensorflow.WorkerService. // Note that this class expects the server to already have been started. @@ -268,20 +270,23 @@ class EagerContext : public core::RefCounted { // communicate with remote eager services. // - remote_device_mgr: A DeviceMgr* which contains all remote devices // (should contain no local devices). - // - remote_contexts: A map containing task name to remote context ID. - Status InitializeRemote( + // - remote_contexts: A vector containing task names. 
+ Status InitializeRemoteMaster( std::unique_ptr server, WorkerEnv* worker_env, std::shared_ptr worker_session, std::unique_ptr remote_eager_workers, std::unique_ptr remote_device_manager, - const gtl::FlatMap& remote_contexts, Rendezvous* r, - DeviceMgr* local_device_mgr, int keep_alive_secs, + const std::vector& remote_contexts, uint64 context_id, + Rendezvous* r, DeviceMgr* local_device_mgr, int keep_alive_secs, DistributedFunctionLibraryRuntime* cluster_flr); - bool HasActiveRemoteContext(uint64 context_id) { - return active_remote_contexts_.find(context_id) != - active_remote_contexts_.end(); - } + // Similar with InitializeRemoteMaster but this context will not kill remote + // contexts in shutdown. + Status InitializeRemoteWorker( + std::unique_ptr remote_eager_workers, + const DeviceMgr* remote_device_mgr, + const std::vector& remote_contexts, uint64 context_id, + std::function rendezvous_creator); Status StoreCollectiveOpsServer( std::unique_ptr server, DeviceMgr* device_mgr, @@ -328,7 +333,7 @@ class EagerContext : public core::RefCounted { // Only one of the below is set. remote_unowned_device_manager_ is set on // remote worker to allow running multi-device function on remote worker. std::unique_ptr remote_device_manager_; - const DeviceMgr* remote_unowned_device_manager_; + const DeviceMgr* remote_unowned_device_manager_ = nullptr; // Devices owned by device_manager std::vector devices_; @@ -405,10 +410,9 @@ class EagerContext : public core::RefCounted { mutex remote_state_mu_; - gtl::FlatMap remote_contexts_; - gtl::FlatSet active_remote_contexts_; - gtl::FlatMap> - device_to_client_cache_; + uint64 context_id_; + std::vector remote_contexts_; + gtl::FlatMap device_to_client_cache_; int keep_alive_secs_ GUARDED_BY(remote_state_mu_); std::atomic sleep_for_secs_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 8fa096616dc..89b70bdd700 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -659,9 +659,8 @@ Status EagerRemoteSendTensor(EagerContext* ctx, TensorHandle* h, } eager::EagerClient* eager_client; - uint64 context_id; - TF_RETURN_IF_ERROR( - ctx->GetClientAndContextID(recv_device, &eager_client, &context_id)); + uint64 context_id = ctx->GetContextId(); + TF_RETURN_IF_ERROR(ctx->GetClient(recv_device, &eager_client)); eager::SendTensorRequest request; eager::SendTensorResponse response; @@ -732,9 +731,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, EagerContext* ctx = op->EagerContext(); eager::EagerClient* eager_client; - uint64 context_id; - TF_RETURN_IF_ERROR( - ctx->GetClientAndContextID(op->Device(), &eager_client, &context_id)); + uint64 context_id = ctx->GetContextId(); + TF_RETURN_IF_ERROR(ctx->GetClient(op->Device(), &eager_client)); std::unique_ptr request(new eager::EnqueueRequest); eager::EnqueueResponse response; diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 280facc0f3e..cda6e248641 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -112,6 +112,7 @@ tf_cc_test( deps = [ ":worker_session", "//tensorflow/core:framework_internal", + "//tensorflow/core:lib_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc 
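Before the eager_service_impl.cc hunks that follow, note that the execute.cc changes above settle into a simple call-site pattern: the id is a property of the whole EagerContext rather than of a (device, client) pair, so it is fetched once and the client lookup no longer returns it. A minimal sketch with names from the hunks:

```cpp
eager::EagerClient* eager_client = nullptr;
uint64 context_id = ctx->GetContextId();
TF_RETURN_IF_ERROR(ctx->GetClient(op->Device(), &eager_client));

// Every request routed through this client is stamped with the same id.
std::unique_ptr<eager::EnqueueRequest> request(new eager::EnqueueRequest);
request->set_context_id(context_id);
```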
index 8d943a64023..b3ea176ef90 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -97,8 +97,9 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, cluster_device_attributes.push_back(cluster_device); } - auto* r = env_->rendezvous_mgr->Find(request->rendezvous_id()); - auto session_name = strings::StrCat("eager_", request->rendezvous_id()); + auto* r = env_->rendezvous_mgr->Find(request->context_id()); + auto session_name = + tensorflow::strings::StrCat("eager_", request->context_id()); TF_RETURN_IF_ERROR(env_->session_mgr->CreateSession( session_name, request->server_def(), request->cluster_device_attributes(), true)); @@ -123,8 +124,30 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, SessionOptions(), tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, tensorflow::ContextMirroringPolicy::MIRRORING_NONE, request->async(), - device_mgr, false, r, nullptr, worker_session->cluster_flr.get(), - std::move(rendezvous_creator), worker_session->remote_device_mgr()); + device_mgr, false, r, nullptr, worker_session->cluster_flr.get()); + + Status s; + std::vector remote_workers; + worker_session->worker_cache->ListWorkers(&remote_workers); + remote_workers.erase(std::remove(remote_workers.begin(), remote_workers.end(), + worker_session->worker_name), + remote_workers.end()); + + std::unique_ptr remote_eager_workers; + s = env_->eager_client_cache_factory(request->server_def(), + &remote_eager_workers); + if (!s.ok()) { + delete ctx; + return s; + } + + s = ctx->InitializeRemoteWorker( + std::move(remote_eager_workers), worker_session->remote_device_mgr(), + remote_workers, request->context_id(), std::move(rendezvous_creator)); + if (!s.ok()) { + delete ctx; + return s; + } std::vector device_attributes; device_mgr->ListDeviceAttributes(&device_attributes); @@ -132,17 +155,17 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, for (const auto& da : device_attributes) { *response->add_device_attributes() = da; } - - uint64 context_id; { mutex_lock l(contexts_mu_); - do { - context_id = random::New64(); - } while (contexts_.find(context_id) != contexts_.end()); - contexts_.emplace(context_id, + if (contexts_.find(request->context_id()) != contexts_.end()) { + delete ctx; + return errors::InvalidArgument("EagerService:CreateContext failed. ", + "Context id: <", request->context_id(), + "> already exists."); + } + contexts_.emplace(request->context_id(), new ServerContext(ctx, request->keep_alive_secs(), env_)); } - response->set_context_id(context_id); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 7a1463e8f04..28df7736510 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
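Taken together with the c_api.cc hunk earlier, the CreateContext changes above complete the master/worker split. A side-by-side sketch of the two initialization call sites, abridged from the hunks (not additional code to apply):

```cpp
// Master (tensorflow/c/eager/c_api.cc): owns the server and the keep-alive
// thread, and closes every remote context at shutdown.
ctx->context->InitializeRemoteMaster(
    std::move(server), grpc_server->worker_env(), worker_session,
    std::move(remote_eager_workers), std::move(remote_device_mgr),
    remote_workers, context_id, r, device_mgr, keep_alive_secs,
    worker_session->cluster_flr.get());

// Worker (eager_service_impl.cc): adopts the id supplied in the request and
// never closes the other workers' contexts.
ctx->InitializeRemoteWorker(
    std::move(remote_eager_workers), worker_session->remote_device_mgr(),
    remote_workers, request->context_id(), std::move(rendezvous_creator));
```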
#include "tensorflow/core/distributed_runtime/session_mgr.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -51,22 +52,49 @@ class TestEagerServiceImpl : public EagerServiceImpl { } }; +class DummyWorkerCache : public WorkerCacheInterface { + void ListWorkers(std::vector* workers) const override {} + void ListWorkersInJob(const string& job_name, + std::vector* workers) const override {} + WorkerInterface* GetOrCreateWorker(const string& target) override { + return nullptr; + } + bool GetDeviceLocalityNonBlocking(const string& device, + DeviceLocality* locality) override { + return false; + } + void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality, + StatusCallback done) override {} +}; + +class DummyEagerClientCache : public EagerClientCache { + Status GetClient(const string& target, EagerClient** client) override { + return errors::Unimplemented(""); + } +}; + class EagerServiceImplTest : public ::testing::Test { public: EagerServiceImplTest() : rendezvous_mgr_(&worker_env_), session_mgr_(new SessionMgr( &worker_env_, "/job:localhost/replica:0/task:0/device:CPU:0", - std::unique_ptr(), + absl::make_unique(), [](const ServerDef& server_def, WorkerCacheInterface** worker_cache) { - *worker_cache = nullptr; + *worker_cache = new DummyWorkerCache(); return Status::OK(); })) { worker_env_.env = Env::Default(); worker_env_.rendezvous_mgr = &rendezvous_mgr_; worker_env_.session_mgr = session_mgr_.get(); + worker_env_.eager_client_cache_factory = + [](const ServerDef& server_def, + std::unique_ptr* eager_client_cache) { + eager_client_cache->reset(new DummyEagerClientCache()); + return Status::OK(); + }; device_mgr_ = absl::make_unique(DeviceFactory::NewDevice( "CPU", {}, "/job:localhost/replica:0/task:0/device:CPU:0")); @@ -76,6 +104,7 @@ class EagerServiceImplTest : public ::testing::Test { protected: WorkerEnv worker_env_; + std::unique_ptr worker_cache_; tensorflow::RpcRendezvousMgr rendezvous_mgr_; std::unique_ptr session_mgr_; std::unique_ptr device_mgr_; @@ -153,15 +182,16 @@ tensorflow::FunctionDef MatMulFunction() { TEST_F(EagerServiceImplTest, BasicTest) { TestEagerServiceImpl eager_service_impl(&worker_env_); + uint64 context_id = random::New64(); + CreateContextRequest request; request.mutable_server_def()->set_job_name("localhost"); request.mutable_server_def()->set_task_index(0); - request.set_rendezvous_id(random::New64()); + request.set_context_id(context_id); CreateContextResponse response; TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); - uint64 context_id = response.context_id(); EnqueueRequest remote_enqueue_request; remote_enqueue_request.set_context_id(context_id); @@ -202,7 +232,7 @@ TEST_F(EagerServiceImplTest, BasicTest) { tensorflow::TensorHandle* tensor_handle; TF_ASSERT_OK(eager_service_impl.GetTensorHandle( - response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle)); + context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); // This should be OK to do since we've placed all computation on the CPU // device. 
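The fixture above wires the new WorkerEnv hook with a stub; embedders or other tests can plug in their own cache the same way. A minimal sketch, where FakeEagerClientCache stands in for any EagerClientCache implementation (the name is hypothetical, not part of this change):

```cpp
worker_env_.eager_client_cache_factory =
    [](const ServerDef& server_def,
       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) {
      // FakeEagerClientCache is a placeholder; any EagerClientCache works.
      eager_client_cache->reset(new FakeEagerClientCache());
      return Status::OK();
    };
```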
@@ -229,16 +259,16 @@ TEST_F(EagerServiceImplTest, BasicTest) { TEST_F(EagerServiceImplTest, BasicFunctionTest) { TestEagerServiceImpl eager_service_impl(&worker_env_); + uint64 context_id = random::New64(); + CreateContextRequest request; request.mutable_server_def()->set_job_name("localhost"); request.mutable_server_def()->set_task_index(0); - request.set_rendezvous_id(random::New64()); + request.set_context_id(context_id); CreateContextResponse response; TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); - uint64 context_id = response.context_id(); - RegisterFunctionRequest register_function_request; register_function_request.set_context_id(context_id); *register_function_request.mutable_function_def() = MatMulFunction(); @@ -273,7 +303,7 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) { const tensorflow::Tensor* t = nullptr; tensorflow::TensorHandle* tensor_handle; TF_ASSERT_OK(eager_service_impl.GetTensorHandle( - response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle)); + context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); TF_ASSERT_OK(tensor_handle->Tensor(&t)); auto actual = t->flat(); @@ -296,15 +326,16 @@ TEST_F(EagerServiceImplTest, BasicFunctionTest) { TEST_F(EagerServiceImplTest, SendTensorTest) { TestEagerServiceImpl eager_service_impl(&worker_env_); + uint64 context_id = random::New64(); + CreateContextRequest request; request.mutable_server_def()->set_job_name("localhost"); request.mutable_server_def()->set_task_index(0); - request.set_rendezvous_id(random::New64()); + request.set_context_id(context_id); CreateContextResponse response; TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); - uint64 context_id = response.context_id(); SendTensorRequest send_tensor_request; send_tensor_request.set_context_id(context_id); @@ -339,7 +370,7 @@ TEST_F(EagerServiceImplTest, SendTensorTest) { const tensorflow::Tensor* t = nullptr; tensorflow::TensorHandle* tensor_handle; TF_ASSERT_OK(eager_service_impl.GetTensorHandle( - response.context_id(), RemoteTensorHandleInternal(2, 0), &tensor_handle)); + context_id, RemoteTensorHandleInternal(2, 0), &tensor_handle)); TF_ASSERT_OK(tensor_handle->Tensor(&t)); Device* device = tensor_handle->device(); @@ -364,10 +395,11 @@ TEST_F(EagerServiceImplTest, SendTensorTest) { TEST_F(EagerServiceImplTest, KeepAliveTest) { TestEagerServiceImpl eager_service_impl(&worker_env_); + uint64 context_id = random::New64(); CreateContextRequest request; request.mutable_server_def()->set_job_name("localhost"); request.mutable_server_def()->set_task_index(0); - request.set_rendezvous_id(random::New64()); + request.set_context_id(context_id); request.set_keep_alive_secs(3); CreateContextResponse response; @@ -379,7 +411,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) { KeepAliveRequest keep_alive_request; KeepAliveResponse keep_alive_response; - keep_alive_request.set_context_id(response.context_id()); + keep_alive_request.set_context_id(context_id); Status status = eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); @@ -388,15 +420,16 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) { EXPECT_PRED_FORMAT2(::testing::IsSubstring, "Unable to find a context_id", status.error_message()); + uint64 new_context_id = random::New64(); // Create a new context. - request.set_rendezvous_id(random::New64()); + request.set_context_id(new_context_id); TF_ASSERT_OK(eager_service_impl.CreateContext(&request, &response)); // The context should not be GC'd. 
worker_env_.env->SleepForMicroseconds(1 * tensorflow::EnvTime::kSecondsToMicros); - keep_alive_request.set_context_id(response.context_id()); + keep_alive_request.set_context_id(new_context_id); TF_ASSERT_OK( eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response)); diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc index a8df0e2128d..290ceb50d2a 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc @@ -29,7 +29,7 @@ void DestoryRemoteTensorHandle(EagerContext* ctx, int output_num) { auto cleanup = gtl::MakeCleanup([ctx]() { ctx->Unref(); }); - if (!ctx->HasActiveRemoteContext(context_id)) { + if (ctx->GetContextId() != context_id) { // This means that this tensor was pointing to a remote device, which // has been changed out from under us. Simply return since there is // nothing we can do. diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 107541d1ff8..d47a7f0e913 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -296,6 +296,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:collective_param_resolver_distributed", "//tensorflow/core/distributed_runtime:device_resolver_distributed", "//tensorflow/core/distributed_runtime:graph_mgr", @@ -308,6 +309,8 @@ cc_library( "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:worker_cache_wrapper", "//tensorflow/core/distributed_runtime:worker_env", + "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service_impl", ], alwayslink = 1, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index c0407af29ba..66eabe2b643 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -24,12 +24,12 @@ limitations under the License. #include "grpcpp/grpcpp.h" #include "grpcpp/security/credentials.h" #include "grpcpp/server_builder.h" - #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" +#include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/graph_mgr.h" #include "tensorflow/core/distributed_runtime/local_master.h" #include "tensorflow/core/distributed_runtime/master.h" @@ -49,6 +49,7 @@ limitations under the License. 
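The remote_tensor_handle_data.cc hunk above relies on the new single-id invariant: a remote handle records the context id it was created under, and once the EagerContext has been re-initialized with a different id, the handle's remote state is assumed to have been torn down wholesale. A sketch of that check, with an explanatory comment:

```cpp
if (ctx->GetContextId() != context_id) {
  // The handle belongs to an earlier remote configuration. The whole remote
  // context it lived in is (or will be) closed by the master, so there is
  // nothing left to free for this individual handle.
  return;
}
```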
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/public/session_options.h" @@ -253,6 +254,12 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { return WorkerCacheFactory(options, worker_cache); }); worker_env_.compute_pool = ComputePool(sess_opts); + worker_env_.eager_client_cache_factory = + [this](const ServerDef& server_def, + std::unique_ptr* eager_client_cahce) { + WorkerCacheFactoryOptions options(server_def); + return EagerClientCacheFactory(options, eager_client_cahce); + }; // Finish setting up master environment. master_env_.ops = OpRegistry::Global(); @@ -310,25 +317,13 @@ Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, WorkerCacheInterface** worker_cache) { - if (options.job_name == nullptr || options.job_name->empty()) { - Status s = errors::InvalidArgument( - "The master (current machine) is not included in the provided " - "cluster_def. ", - options.cluster_def->DebugString()); - LOG(WARNING) << s; - return s; - } - - GrpcChannelSpec channel_spec; - TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec)); - - channel_cache_.reset( - NewGrpcChannelCache(channel_spec, GetChannelCreationFunction())); + std::shared_ptr channel_cache; + TF_RETURN_IF_ERROR(FindOrCreateChannelCache(options, &channel_cache)); string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0", "/task:", options.task_index); - const string host_port = channel_cache_->TranslateTask(name_prefix); + const string host_port = channel_cache->TranslateTask(name_prefix); int requested_port; auto colon_index = host_port.find_last_of(':'); @@ -343,11 +338,50 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, " differs from expected port ", bound_port_); } - *worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache_, + *worker_cache = NewGrpcWorkerCacheWithLocalWorker(channel_cache, worker_impl(), name_prefix); return Status::OK(); } +Status GrpcServer::EagerClientCacheFactory( + const WorkerCacheFactoryOptions& options, + std::unique_ptr* eager_client_cache) { + std::shared_ptr channel_cache; + TF_RETURN_IF_ERROR(FindOrCreateChannelCache(options, &channel_cache)); + + eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache)); + return Status::OK(); +} + +Status GrpcServer::FindOrCreateChannelCache( + const WorkerCacheFactoryOptions& options, + std::shared_ptr* cache) { + if (options.job_name == nullptr || options.job_name->empty()) { + Status s = errors::InvalidArgument( + "The master (current machine) is not included in the provided " + "cluster_def. 
", + options.cluster_def->DebugString()); + LOG(ERROR) << s; + return s; + } + string cluster = ""; + if (options.cluster_def != nullptr) { + options.cluster_def->SerializeToString(&cluster); + } + Fprint128 cache_key = Fingerprint128(cluster); + mutex_lock l(channel_mu_); + *cache = gtl::FindPtrOrNull(channel_caches_, cache_key); + if (*cache == nullptr) { + GrpcChannelSpec channel_spec; + TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec)); + + *cache = std::shared_ptr( + NewGrpcChannelCache(channel_spec, GetChannelCreationFunction())); + channel_caches_.emplace(cache_key, *cache); + } + return Status::OK(); +} + Status GrpcServer::Start() { mutex_lock l(mu_); switch (state_) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 17bc93588c3..a390095f1a4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -22,11 +22,11 @@ limitations under the License. #include "grpcpp/grpcpp.h" #include "grpcpp/security/credentials.h" - #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/stats_publisher_interface.h" #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" +#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" #include "tensorflow/core/distributed_runtime/server_lib.h" @@ -34,7 +34,10 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/fingerprint.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -95,7 +98,9 @@ class GrpcServer : public ServerInterface { WorkerEnv* worker_env() { return &worker_env_; } MasterEnv* master_env() { return &master_env_; } - std::shared_ptr channel_cache() { return channel_cache_; } + virtual Status EagerClientCacheFactory( + const WorkerCacheFactoryOptions& options, + std::unique_ptr* eager_client_cache); protected: virtual Status GetPort(int* port) const; @@ -124,11 +129,9 @@ class GrpcServer : public ServerInterface { const ServerDef& server_def() const { return server_def_; } GrpcWorker* worker_impl() const { return worker_impl_.get(); } - void set_channel_cache(GrpcChannelCache* channel_cache) { - channel_cache_.reset(channel_cache); - } - private: + Status FindOrCreateChannelCache(const WorkerCacheFactoryOptions& options, + std::shared_ptr* cache); // The overall server configuration. const ServerDef server_def_; Env* env_; @@ -156,7 +159,12 @@ class GrpcServer : public ServerInterface { std::unique_ptr master_impl_; AsyncServiceInterface* master_service_ = nullptr; std::unique_ptr master_thread_ GUARDED_BY(mu_); - std::shared_ptr channel_cache_; + + mutex channel_mu_; + // TODO(fishx): Cleanup channel caches. + std::unordered_map, + Fprint128Hasher> + channel_caches_ GUARDED_BY(channel_mu_); // Implementation of a TensorFlow worker, and RPC polling thread. 
   WorkerEnv worker_env_;
diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h
index 93d933bfa60..5181cc6435a 100644
--- a/tensorflow/core/distributed_runtime/worker_env.h
+++ b/tensorflow/core/distributed_runtime/worker_env.h
@@ -17,6 +17,8 @@ limitations under the License.
 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
 
 #include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
@@ -25,12 +27,21 @@ namespace thread {
 class ThreadPool;
 }  // namespace thread
 
+namespace eager {
+class EagerClientCache;
+}  // namespace eager
+
 class CollectiveExecutorMgrInterface;
 class Device;
 class DeviceMgr;
 class Env;
 class RendezvousMgrInterface;
 class SessionMgr;
+class ServerDef;
+
+typedef std::function<Status(const ServerDef&,
+                             std::unique_ptr<eager::EagerClientCache>*)>
+    EagerClientCacheFactory;
 
 // The worker environment class, which holds a bag of pointers to
 // per-worker singletons.
@@ -64,6 +75,15 @@ struct WorkerEnv {
 
   // A pool of threads for scheduling compute work.
   thread::ThreadPool* compute_pool = nullptr;
+
+  // A factory function that creates an eager client cache.
+  EagerClientCacheFactory eager_client_cache_factory =
+      [](const ServerDef& s, std::unique_ptr<eager::EagerClientCache>* c) {
+        return errors::Unimplemented(
+            "EagerClientCacheFactory is unimplemented. This probably means "
+            "the server in use is not a gRPC server; EagerClient currently "
+            "supports only gRPC.");
+      };
 };
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
index 295eb460ad9..0941f5affb3 100644
--- a/tensorflow/core/protobuf/eager_service.proto
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -67,21 +67,19 @@ message CreateContextRequest {
   // This is the version for all the ops that will be enqueued by the client.
   VersionDef version_def = 4;
 
-  // This ID will be used for all future communications. It is essential that
-  // both ends use this ID for selecting a rendezvous to get everything to
-  // match.
-  int64 rendezvous_id = 5;
-
   // Device attributes in the cluster
   repeated DeviceAttributes cluster_device_attributes = 6;
-}
 
-message CreateContextResponse {
   // The ID of the created context. This is usually a randomly generated number,
   // that will be used to identify the context in future requests to the
   // service. Contexts are not persisted through server restarts.
-  fixed64 context_id = 1;
+  // This ID is also used for all subsequent communications with the context.
+  // It is essential that both ends use this ID when selecting a rendezvous so
+  // that everything matches.
+  fixed64 context_id = 7;
+}
 
+message CreateContextResponse {
   // List of devices that are locally accessible to the worker.
   repeated DeviceAttributes device_attributes = 2;
 }
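As a wire-level recap of the proto change (a usage sketch, assuming the request/response types generated from eager_service.proto): the id now travels in CreateContextRequest, CreateContextResponse no longer returns one, and every later RPC for the context carries the same value.

```cpp
uint64 context_id = random::New64();

eager::CreateContextRequest create;
create.set_context_id(context_id);  // fixed64 field 7, chosen by the caller

// All follow-up requests reuse the caller-chosen id.
eager::RegisterFunctionRequest register_function;
register_function.set_context_id(context_id);

eager::KeepAliveRequest keep_alive;
keep_alive.set_context_id(context_id);

eager::CloseContextRequest close;
close.set_context_id(context_id);
```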