Pass in GrpcWorkerEnv when creating GrpcWorkerCache.

PiperOrigin-RevId: 314769356
Change-Id: I154786dbba5eeb69d151baed83fc89a9dcd6a989
This commit is contained in:
Haoyu Zhang 2020-06-04 11:30:48 -07:00 committed by TensorFlower Gardener
parent 4815e7515f
commit 48678a1e2d
8 changed files with 158 additions and 48 deletions

View File

@ -36,9 +36,11 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
TF_CHECK_OK(spec.AddHostPortsJob("localhost", cluster_->targets()));
ChannelCreationFunction channel_func =
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
std::shared_ptr<GrpcChannelCache> channel_cache(
NewGrpcChannelCache(spec, channel_func));
std::unique_ptr<WorkerCacheInterface> worker_cache(
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache>(
NewGrpcChannelCache(spec, channel_func))));
NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get()));
worker_session_.reset(new WorkerSession(
"cluster_test_session", "/job:localhost/replica:0/task:0",
@ -110,6 +112,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
std::unique_ptr<test::TestCluster> cluster_;
std::unique_ptr<WorkerSession> worker_session_;
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
};
TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) {

View File

@ -39,6 +39,7 @@ class RemoteDeviceTest : public ::testing::Test {
WorkerInterface* wi_;
std::vector<Device*> devices_;
std::unique_ptr<test::TestCluster> cluster_;
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
RemoteDeviceTest() {
SessionOptions options;
@ -51,7 +52,9 @@ class RemoteDeviceTest : public ::testing::Test {
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
std::shared_ptr<GrpcChannelCache> channel_cache(
NewGrpcChannelCache(spec, channel_func));
worker_cache_.reset(NewGrpcWorkerCache(channel_cache));
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
worker_cache_.reset(
NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get()));
remote_name_ = "/job:localhost/replica:0/task:0";
wi_ = worker_cache_->GetOrCreateWorker(remote_name_);
}
@ -82,11 +85,13 @@ class RemoteDeviceTest : public ::testing::Test {
TEST_F(RemoteDeviceTest, GetStatus) {
// We know what the testlib's fake server does.
EXPECT_EQ(devices_[0]->name(), strings::StrCat(remote_name_, "/cpu:0"));
EXPECT_EQ(devices_[0]->name(),
strings::StrCat(remote_name_, "/device:CPU:0"));
EXPECT_EQ(devices_[0]->attributes().device_type(),
DeviceType(DEVICE_CPU).type());
EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20);
EXPECT_EQ(devices_[1]->name(), strings::StrCat(remote_name_, "/cpu:1"));
EXPECT_EQ(devices_[1]->name(),
strings::StrCat(remote_name_, "/device:CPU:1"));
EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20);
}

View File

@ -165,6 +165,26 @@ cc_library(
"//tensorflow/core/distributed_runtime:worker_cache_partial",
"//tensorflow/core/distributed_runtime:worker_interface",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
"//tensorflow/core/util:env_var",
],
)
tf_cc_test(
name = "grpc_worker_cache_test",
size = "small",
srcs = [
"grpc_worker_cache_test.cc",
],
deps = [
":grpc_worker_cache",
"//tensorflow/c:tf_status_headers",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime:test_utils",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/platform:env",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:strcat",
],
)

View File

@ -87,23 +87,6 @@ RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
return new RpcRendezvousMgr(env);
}
std::unique_ptr<GrpcWorkerEnv> CreateGrpcWorkerEnv() {
int num_cpus = port::NumSchedulableCPUs();
int64 num_completion_queues;
Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
&num_completion_queues);
if (!status.ok()) {
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
}
int64 num_threads;
status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
&num_threads);
if (!status.ok()) {
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
}
return absl::make_unique<GrpcWorkerEnv>(num_completion_queues, num_threads);
}
} // namespace
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
@ -275,7 +258,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
return errors::Unknown("Could not start gRPC server");
}
// Create the execution environment for the GRPC workers cache.
grpc_worker_env_ = CreateGrpcWorkerEnv();
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
WorkerCacheInterface* worker_cache;
WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
@ -401,7 +384,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
" differs from expected port ", bound_port_);
}
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(
channel_cache, worker_impl(), name_prefix, grpc_worker_env_.get());
channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
return Status::OK();
}

View File

@ -137,6 +137,7 @@ class GrpcServer : public ServerInterface {
const ServerDef& server_def() const { return server_def_; }
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); }
private:
Env* env_;

View File

@ -15,24 +15,20 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/env_var.h"
namespace tensorflow {
namespace {
// TODO(ncteisen): consider adding a config var or flag for this
static const size_t kGrpcWorkerCacheThreadCount = 8;
static const size_t kNumCallbackThreads = 10;
class GrpcWorkerCache : public WorkerCachePartial {
public:
explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
@ -43,13 +39,7 @@ class GrpcWorkerCache : public WorkerCachePartial {
local_worker_(local_worker),
channel_cache_(channel_cache),
worker_env_(worker_env),
next_round_robin_assignment_(0) {
if (worker_env_ == nullptr) {
worker_env_ptr_ = absl::make_unique<GrpcWorkerEnv>(
kGrpcWorkerCacheThreadCount, kNumCallbackThreads);
worker_env_ = worker_env_ptr_.get();
}
}
next_round_robin_assignment_(0) {}
void ListWorkers(std::vector<string>* workers) const override {
channel_cache_->ListWorkers(workers);
@ -118,8 +108,7 @@ class GrpcWorkerCache : public WorkerCachePartial {
WorkerInterface* const local_worker_; // Not owned.
std::shared_ptr<GrpcChannelCache> channel_cache_;
WorkerCacheLogger logger_;
GrpcWorkerEnv* worker_env_; // Not owned, if worker_env_ptr_ is nullptr.
std::unique_ptr<GrpcWorkerEnv> worker_env_ptr_;
GrpcWorkerEnv* worker_env_; // Not owned
mutex assignment_mu_;
std::unordered_map<std::string, size_t> target_assignments_
@ -154,13 +143,32 @@ GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
thread_.reset();
}
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc) {
return new GrpcWorkerCache(cc, nullptr, "", nullptr);
GrpcWorkerEnv* CreateGrpcWorkerEnv() {
int num_cpus = port::NumSchedulableCPUs();
int64 num_completion_queues;
Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
&num_completion_queues);
if (!status.ok()) {
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
}
int64 num_threads;
status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
&num_threads);
if (!status.ok()) {
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
}
return new GrpcWorkerEnv(num_completion_queues, num_threads);
}
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
GrpcWorkerEnv* worker_env) {
return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
worker_env);
}
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
const string& local_target, GrpcWorkerEnv* worker_env) {
std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
WorkerInterface* local_worker, const string& local_target) {
return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
}

View File

@ -16,8 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
#include <memory>
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
@ -62,12 +60,17 @@ class GrpcWorkerEnv {
std::vector<GrpcWorkerCacheThread> threads_;
};
// Create a GrpcWorkerEnv instance that can be used as argument to create
// gRPC worker cache. Caller should take the ownership of the returned instance.
GrpcWorkerEnv* CreateGrpcWorkerEnv();
// The returned WorkerCacheInterface object takes the ownership of "cc".
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc);
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
GrpcWorkerEnv* worker_env);
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
const string& local_target, GrpcWorkerEnv* worker_env);
std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
WorkerInterface* local_worker, const string& local_target);
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_

View File

@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/threadpool.h"
namespace tensorflow {
TEST(GrpcWorkerCacheTest, NewGrpcWorkerCache) {
GrpcChannelSpec spec;
TF_ASSERT_OK(spec.AddHostPortsJob("worker", {"a:0", "b:1", "c:2"}));
ChannelCreationFunction channel_func =
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
auto channel_cache = std::shared_ptr<GrpcChannelCache>(
NewGrpcChannelCache(spec, channel_func));
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env(CreateGrpcWorkerEnv());
// We created a job with 3 tasks. Getting the task 0, 1, 2 should return valid
// worker interfaces, and getting other tasks should return nullptr.
std::unique_ptr<WorkerCacheInterface> worker_cache(
NewGrpcWorkerCache(channel_cache, grpc_worker_env.get()));
WorkerInterface* wi;
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:0");
EXPECT_NE(wi, nullptr);
worker_cache->ReleaseWorker("/job:worker/replica:0/task:0", wi);
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:1");
EXPECT_NE(wi, nullptr);
worker_cache->ReleaseWorker("/job:worker/replica:0/task:1", wi);
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:2");
EXPECT_NE(wi, nullptr);
worker_cache->ReleaseWorker("/job:worker/replica:0/task:2", wi);
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:3");
EXPECT_EQ(wi, nullptr);
// Test creating a worker cache instance with local worker, and getting the
// worker instance with the specified local target.
std::unique_ptr<TestWorkerInterface> local_wi;
worker_cache.reset(NewGrpcWorkerCacheWithLocalWorker(
channel_cache, grpc_worker_env.get(), local_wi.get(), "local_target"));
wi = worker_cache->GetOrCreateWorker("local_target");
EXPECT_EQ(wi, local_wi.get());
}
TEST(GrpcWorkerCacheTest, DestructWorkerCacheInThreadPool) {
GrpcChannelSpec spec;
TF_ASSERT_OK(spec.AddHostPortsJob("worker", {"a:1", "b:2", "c:3"}));
ChannelCreationFunction channel_func =
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
auto channel_cache = std::shared_ptr<GrpcChannelCache>(
NewGrpcChannelCache(spec, channel_func));
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env(CreateGrpcWorkerEnv());
// The GrpcWorkerEnv threadpool is used for worker interfaces for gRPC
// completion queue callbacks. Test worker cache destruction inside the
// callbacks that runs in the GrpcWorkerEnv threadpool.
WorkerCacheInterface* worker_cache =
NewGrpcWorkerCache(channel_cache, grpc_worker_env.get());
thread::ThreadPool* tp = grpc_worker_env->GetThreadPool();
Notification n;
tp->Schedule([worker_cache, &n] {
delete worker_cache;
n.Notify();
});
n.WaitForNotification();
}
} // namespace tensorflow