diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc index ba988e83eef..527c1901e3f 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc @@ -36,9 +36,11 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { TF_CHECK_OK(spec.AddHostPortsJob("localhost", cluster_->targets())); ChannelCreationFunction channel_func = ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + grpc_worker_env_.reset(CreateGrpcWorkerEnv()); + std::shared_ptr channel_cache( + NewGrpcChannelCache(spec, channel_func)); std::unique_ptr worker_cache( - NewGrpcWorkerCache(std::shared_ptr( - NewGrpcChannelCache(spec, channel_func)))); + NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get())); worker_session_.reset(new WorkerSession( "cluster_test_session", "/job:localhost/replica:0/task:0", @@ -110,6 +112,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { std::unique_ptr cluster_; std::unique_ptr worker_session_; std::unique_ptr cluster_flr_; + std::unique_ptr grpc_worker_env_; }; TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) { diff --git a/tensorflow/core/distributed_runtime/remote_device_test.cc b/tensorflow/core/distributed_runtime/remote_device_test.cc index 62082aa6d59..ff1acd9c015 100644 --- a/tensorflow/core/distributed_runtime/remote_device_test.cc +++ b/tensorflow/core/distributed_runtime/remote_device_test.cc @@ -39,6 +39,7 @@ class RemoteDeviceTest : public ::testing::Test { WorkerInterface* wi_; std::vector devices_; std::unique_ptr cluster_; + std::unique_ptr grpc_worker_env_; RemoteDeviceTest() { SessionOptions options; @@ -51,7 +52,9 @@ class RemoteDeviceTest : public ::testing::Test { ConvertToChannelCreationFunction(NewHostPortGrpcChannel); std::shared_ptr channel_cache( NewGrpcChannelCache(spec, channel_func)); - worker_cache_.reset(NewGrpcWorkerCache(channel_cache)); + grpc_worker_env_.reset(CreateGrpcWorkerEnv()); + worker_cache_.reset( + NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get())); remote_name_ = "/job:localhost/replica:0/task:0"; wi_ = worker_cache_->GetOrCreateWorker(remote_name_); } @@ -82,11 +85,13 @@ class RemoteDeviceTest : public ::testing::Test { TEST_F(RemoteDeviceTest, GetStatus) { // We know what the testlib's fake server does. - EXPECT_EQ(devices_[0]->name(), strings::StrCat(remote_name_, "/cpu:0")); + EXPECT_EQ(devices_[0]->name(), + strings::StrCat(remote_name_, "/device:CPU:0")); EXPECT_EQ(devices_[0]->attributes().device_type(), DeviceType(DEVICE_CPU).type()); EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20); - EXPECT_EQ(devices_[1]->name(), strings::StrCat(remote_name_, "/cpu:1")); + EXPECT_EQ(devices_[1]->name(), + strings::StrCat(remote_name_, "/device:CPU:1")); EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20); } diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 60d7172c2fc..02dcfa86dd8 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -165,6 +165,26 @@ cc_library( "//tensorflow/core/distributed_runtime:worker_cache_partial", "//tensorflow/core/distributed_runtime:worker_interface", "//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client", + "//tensorflow/core/util:env_var", + ], +) + +tf_cc_test( + name = "grpc_worker_cache_test", + size = "small", + srcs = [ + "grpc_worker_cache_test.cc", + ], + deps = [ + ":grpc_worker_cache", + "//tensorflow/c:tf_status_headers", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/distributed_runtime:test_utils", + "//tensorflow/core/distributed_runtime/rpc:grpc_channel", + "//tensorflow/core/platform:env", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", ], ) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 6523d2fb4dd..c0b4d0ef6ec 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -87,23 +87,6 @@ RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) { return new RpcRendezvousMgr(env); } -std::unique_ptr CreateGrpcWorkerEnv() { - int num_cpus = port::NumSchedulableCPUs(); - int64 num_completion_queues; - Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64, - &num_completion_queues); - if (!status.ok()) { - LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status; - } - int64 num_threads; - status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus, - &num_threads); - if (!status.ok()) { - LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status; - } - return absl::make_unique(num_completion_queues, num_threads); -} - } // namespace GrpcServer::GrpcServer(const ServerDef& server_def, Env* env) @@ -275,7 +258,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { return errors::Unknown("Could not start gRPC server"); } // Create the execution environment for the GRPC workers cache. - grpc_worker_env_ = CreateGrpcWorkerEnv(); + grpc_worker_env_.reset(CreateGrpcWorkerEnv()); WorkerCacheInterface* worker_cache; WorkerCacheFactoryOptions worker_cache_factory_options(server_def_); @@ -401,7 +384,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, " differs from expected port ", bound_port_); } *worker_cache = NewGrpcWorkerCacheWithLocalWorker( - channel_cache, worker_impl(), name_prefix, grpc_worker_env_.get()); + channel_cache, grpc_worker_env(), worker_impl(), name_prefix); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 0474c5a517f..1e8fe35b5b4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -137,6 +137,7 @@ class GrpcServer : public ServerInterface { const ServerDef& server_def() const { return server_def_; } GrpcWorker* worker_impl() const { return worker_impl_.get(); } + GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } private: Env* env_; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index 1d75728ddd2..cb4458a2ca4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -15,24 +15,20 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" -#include - #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/worker_cache_logger.h" #include "tensorflow/core/distributed_runtime/worker_cache_partial.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/env_var.h" namespace tensorflow { namespace { -// TODO(ncteisen): consider adding a config var or flag for this -static const size_t kGrpcWorkerCacheThreadCount = 8; -static const size_t kNumCallbackThreads = 10; - class GrpcWorkerCache : public WorkerCachePartial { public: explicit GrpcWorkerCache(std::shared_ptr channel_cache, @@ -43,13 +39,7 @@ class GrpcWorkerCache : public WorkerCachePartial { local_worker_(local_worker), channel_cache_(channel_cache), worker_env_(worker_env), - next_round_robin_assignment_(0) { - if (worker_env_ == nullptr) { - worker_env_ptr_ = absl::make_unique( - kGrpcWorkerCacheThreadCount, kNumCallbackThreads); - worker_env_ = worker_env_ptr_.get(); - } - } + next_round_robin_assignment_(0) {} void ListWorkers(std::vector* workers) const override { channel_cache_->ListWorkers(workers); @@ -118,8 +108,7 @@ class GrpcWorkerCache : public WorkerCachePartial { WorkerInterface* const local_worker_; // Not owned. std::shared_ptr channel_cache_; WorkerCacheLogger logger_; - GrpcWorkerEnv* worker_env_; // Not owned, if worker_env_ptr_ is nullptr. - std::unique_ptr worker_env_ptr_; + GrpcWorkerEnv* worker_env_; // Not owned mutex assignment_mu_; std::unordered_map target_assignments_ @@ -154,13 +143,32 @@ GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() { thread_.reset(); } -WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr cc) { - return new GrpcWorkerCache(cc, nullptr, "", nullptr); +GrpcWorkerEnv* CreateGrpcWorkerEnv() { + int num_cpus = port::NumSchedulableCPUs(); + int64 num_completion_queues; + Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64, + &num_completion_queues); + if (!status.ok()) { + LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status; + } + int64 num_threads; + status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus, + &num_threads); + if (!status.ok()) { + LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status; + } + return new GrpcWorkerEnv(num_completion_queues, num_threads); +} + +WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr cc, + GrpcWorkerEnv* worker_env) { + return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"", + worker_env); } WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker( - std::shared_ptr cc, WorkerInterface* local_worker, - const string& local_target, GrpcWorkerEnv* worker_env) { + std::shared_ptr cc, GrpcWorkerEnv* worker_env, + WorkerInterface* local_worker, const string& local_target) { return new GrpcWorkerCache(cc, local_worker, local_target, worker_env); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h index ef54f3e9b97..2dfbc79aa8a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h @@ -16,8 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ -#include - #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" @@ -62,12 +60,17 @@ class GrpcWorkerEnv { std::vector threads_; }; +// Create a GrpcWorkerEnv instance that can be used as argument to create +// gRPC worker cache. Caller should take the ownership of the returned instance. +GrpcWorkerEnv* CreateGrpcWorkerEnv(); + // The returned WorkerCacheInterface object takes the ownership of "cc". -WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr cc); +WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr cc, + GrpcWorkerEnv* worker_env); WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker( - std::shared_ptr cc, WorkerInterface* local_worker, - const string& local_target, GrpcWorkerEnv* worker_env); + std::shared_ptr cc, GrpcWorkerEnv* worker_env, + WorkerInterface* local_worker, const string& local_target); } // namespace tensorflow #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache_test.cc new file mode 100644 index 00000000000..ff32fa91205 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h" + +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" +#include "tensorflow/core/distributed_runtime/test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/threadpool.h" + +namespace tensorflow { + +TEST(GrpcWorkerCacheTest, NewGrpcWorkerCache) { + GrpcChannelSpec spec; + TF_ASSERT_OK(spec.AddHostPortsJob("worker", {"a:0", "b:1", "c:2"})); + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + auto channel_cache = std::shared_ptr( + NewGrpcChannelCache(spec, channel_func)); + std::unique_ptr grpc_worker_env(CreateGrpcWorkerEnv()); + + // We created a job with 3 tasks. Getting the task 0, 1, 2 should return valid + // worker interfaces, and getting other tasks should return nullptr. + std::unique_ptr worker_cache( + NewGrpcWorkerCache(channel_cache, grpc_worker_env.get())); + WorkerInterface* wi; + wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:0"); + EXPECT_NE(wi, nullptr); + worker_cache->ReleaseWorker("/job:worker/replica:0/task:0", wi); + wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:1"); + EXPECT_NE(wi, nullptr); + worker_cache->ReleaseWorker("/job:worker/replica:0/task:1", wi); + wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:2"); + EXPECT_NE(wi, nullptr); + worker_cache->ReleaseWorker("/job:worker/replica:0/task:2", wi); + wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:3"); + EXPECT_EQ(wi, nullptr); + + // Test creating a worker cache instance with local worker, and getting the + // worker instance with the specified local target. + std::unique_ptr local_wi; + worker_cache.reset(NewGrpcWorkerCacheWithLocalWorker( + channel_cache, grpc_worker_env.get(), local_wi.get(), "local_target")); + wi = worker_cache->GetOrCreateWorker("local_target"); + EXPECT_EQ(wi, local_wi.get()); +} + +TEST(GrpcWorkerCacheTest, DestructWorkerCacheInThreadPool) { + GrpcChannelSpec spec; + TF_ASSERT_OK(spec.AddHostPortsJob("worker", {"a:1", "b:2", "c:3"})); + ChannelCreationFunction channel_func = + ConvertToChannelCreationFunction(NewHostPortGrpcChannel); + auto channel_cache = std::shared_ptr( + NewGrpcChannelCache(spec, channel_func)); + std::unique_ptr grpc_worker_env(CreateGrpcWorkerEnv()); + + // The GrpcWorkerEnv threadpool is used for worker interfaces for gRPC + // completion queue callbacks. Test worker cache destruction inside the + // callbacks that runs in the GrpcWorkerEnv threadpool. + WorkerCacheInterface* worker_cache = + NewGrpcWorkerCache(channel_cache, grpc_worker_env.get()); + thread::ThreadPool* tp = grpc_worker_env->GetThreadPool(); + Notification n; + tp->Schedule([worker_cache, &n] { + delete worker_cache; + n.Notify(); + }); + n.WaitForNotification(); +} + +} // namespace tensorflow