Pass in GrpcWorkerEnv when creating GrpcWorkerCache.
PiperOrigin-RevId: 314769356
Change-Id: I154786dbba5eeb69d151baed83fc89a9dcd6a989
This commit is contained in:
parent
4815e7515f
commit
48678a1e2d
@ -36,9 +36,11 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
TF_CHECK_OK(spec.AddHostPortsJob("localhost", cluster_->targets()));
|
||||
ChannelCreationFunction channel_func =
|
||||
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
||||
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
|
||||
std::shared_ptr<GrpcChannelCache> channel_cache(
|
||||
NewGrpcChannelCache(spec, channel_func));
|
||||
std::unique_ptr<WorkerCacheInterface> worker_cache(
|
||||
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache>(
|
||||
NewGrpcChannelCache(spec, channel_func))));
|
||||
NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get()));
|
||||
|
||||
worker_session_.reset(new WorkerSession(
|
||||
"cluster_test_session", "/job:localhost/replica:0/task:0",
|
||||
@ -110,6 +112,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
std::unique_ptr<test::TestCluster> cluster_;
|
||||
std::unique_ptr<WorkerSession> worker_session_;
|
||||
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
|
||||
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
|
||||
};
|
||||
|
||||
TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) {
|
||||
|
@ -39,6 +39,7 @@ class RemoteDeviceTest : public ::testing::Test {
|
||||
WorkerInterface* wi_;
|
||||
std::vector<Device*> devices_;
|
||||
std::unique_ptr<test::TestCluster> cluster_;
|
||||
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
|
||||
|
||||
RemoteDeviceTest() {
|
||||
SessionOptions options;
|
||||
@ -51,7 +52,9 @@ class RemoteDeviceTest : public ::testing::Test {
|
||||
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
||||
std::shared_ptr<GrpcChannelCache> channel_cache(
|
||||
NewGrpcChannelCache(spec, channel_func));
|
||||
worker_cache_.reset(NewGrpcWorkerCache(channel_cache));
|
||||
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
|
||||
worker_cache_.reset(
|
||||
NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get()));
|
||||
remote_name_ = "/job:localhost/replica:0/task:0";
|
||||
wi_ = worker_cache_->GetOrCreateWorker(remote_name_);
|
||||
}
|
||||
@ -82,11 +85,13 @@ class RemoteDeviceTest : public ::testing::Test {
|
||||
|
||||
TEST_F(RemoteDeviceTest, GetStatus) {
|
||||
// We know what the testlib's fake server does.
|
||||
EXPECT_EQ(devices_[0]->name(), strings::StrCat(remote_name_, "/cpu:0"));
|
||||
EXPECT_EQ(devices_[0]->name(),
|
||||
strings::StrCat(remote_name_, "/device:CPU:0"));
|
||||
EXPECT_EQ(devices_[0]->attributes().device_type(),
|
||||
DeviceType(DEVICE_CPU).type());
|
||||
EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20);
|
||||
EXPECT_EQ(devices_[1]->name(), strings::StrCat(remote_name_, "/cpu:1"));
|
||||
EXPECT_EQ(devices_[1]->name(),
|
||||
strings::StrCat(remote_name_, "/device:CPU:1"));
|
||||
EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20);
|
||||
}
|
||||
|
||||
|
@ -165,6 +165,26 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime:worker_cache_partial",
|
||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||
"//tensorflow/core/util:env_var",
|
||||
],
|
||||
)
|
||||
|
||||
# Unit tests for the gRPC worker cache (sources in grpc_worker_cache_test.cc).
tf_cc_test(
    name = "grpc_worker_cache_test",
    size = "small",
    srcs = ["grpc_worker_cache_test.cc"],
    deps = [
        ":grpc_worker_cache",
        "//tensorflow/c:tf_status_headers",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/distributed_runtime:test_utils",
        "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
        "//tensorflow/core/platform:env",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:strcat",
    ],
)
|
||||
|
||||
|
@ -87,23 +87,6 @@ RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
|
||||
return new RpcRendezvousMgr(env);
|
||||
}
|
||||
|
||||
std::unique_ptr<GrpcWorkerEnv> CreateGrpcWorkerEnv() {
|
||||
int num_cpus = port::NumSchedulableCPUs();
|
||||
int64 num_completion_queues;
|
||||
Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
|
||||
&num_completion_queues);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
|
||||
}
|
||||
int64 num_threads;
|
||||
status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
|
||||
&num_threads);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
|
||||
}
|
||||
return absl::make_unique<GrpcWorkerEnv>(num_completion_queues, num_threads);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
|
||||
@ -275,7 +258,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
||||
return errors::Unknown("Could not start gRPC server");
|
||||
}
|
||||
// Create the execution environment for the GRPC workers cache.
|
||||
grpc_worker_env_ = CreateGrpcWorkerEnv();
|
||||
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
|
||||
|
||||
WorkerCacheInterface* worker_cache;
|
||||
WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
|
||||
@ -401,7 +384,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
||||
" differs from expected port ", bound_port_);
|
||||
}
|
||||
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(
|
||||
channel_cache, worker_impl(), name_prefix, grpc_worker_env_.get());
|
||||
channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -137,6 +137,7 @@ class GrpcServer : public ServerInterface {
|
||||
|
||||
const ServerDef& server_def() const { return server_def_; }
|
||||
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
|
||||
// Returns a non-owning pointer to the gRPC worker environment held by this
// server in grpc_worker_env_.
GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); }
|
||||
|
||||
private:
|
||||
Env* env_;
|
||||
|
@ -15,24 +15,20 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/platform/cpu_info.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(ncteisen): consider adding a config var or flag for this
|
||||
static const size_t kGrpcWorkerCacheThreadCount = 8;
|
||||
static const size_t kNumCallbackThreads = 10;
|
||||
|
||||
class GrpcWorkerCache : public WorkerCachePartial {
|
||||
public:
|
||||
explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
|
||||
@ -43,13 +39,7 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
||||
local_worker_(local_worker),
|
||||
channel_cache_(channel_cache),
|
||||
worker_env_(worker_env),
|
||||
next_round_robin_assignment_(0) {
|
||||
if (worker_env_ == nullptr) {
|
||||
worker_env_ptr_ = absl::make_unique<GrpcWorkerEnv>(
|
||||
kGrpcWorkerCacheThreadCount, kNumCallbackThreads);
|
||||
worker_env_ = worker_env_ptr_.get();
|
||||
}
|
||||
}
|
||||
next_round_robin_assignment_(0) {}
|
||||
|
||||
void ListWorkers(std::vector<string>* workers) const override {
|
||||
channel_cache_->ListWorkers(workers);
|
||||
@ -118,8 +108,7 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
||||
WorkerInterface* const local_worker_; // Not owned.
|
||||
std::shared_ptr<GrpcChannelCache> channel_cache_;
|
||||
WorkerCacheLogger logger_;
|
||||
GrpcWorkerEnv* worker_env_; // Not owned, if worker_env_ptr_ is nullptr.
|
||||
std::unique_ptr<GrpcWorkerEnv> worker_env_ptr_;
|
||||
GrpcWorkerEnv* worker_env_; // Not owned
|
||||
|
||||
mutex assignment_mu_;
|
||||
std::unordered_map<std::string, size_t> target_assignments_
|
||||
@ -154,13 +143,32 @@ GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
|
||||
thread_.reset();
|
||||
}
|
||||
|
||||
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc) {
|
||||
return new GrpcWorkerCache(cc, nullptr, "", nullptr);
|
||||
// Creates a GrpcWorkerEnv that can be passed to the gRPC worker cache
// factories. The caller takes ownership of the returned instance.
//
// The number of completion queues is read from TF_GRPC_WORKER_CACHE_QUEUES
// (default 64) and the number of threads from TF_GRPC_WORKER_CACHE_THREADS
// (default port::NumSchedulableCPUs()). A value that fails to parse, or that
// is not positive, is reported with LOG(ERROR) and replaced by its default.
GrpcWorkerEnv* CreateGrpcWorkerEnv() {
  const int64 default_num_completion_queues = 64;
  const int64 default_num_threads = port::NumSchedulableCPUs();

  // Both counts are initialized up front so that they hold a defined value
  // even if the reader below reports an error.
  int64 num_completion_queues = default_num_completion_queues;
  Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES",
                                      default_num_completion_queues,
                                      &num_completion_queues);
  if (!status.ok()) {
    LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
    num_completion_queues = default_num_completion_queues;
  } else if (num_completion_queues <= 0) {
    // NOTE(review): GrpcWorkerEnv appears to take unsigned sizes, so a
    // negative value would wrap to a huge count -- confirm.
    LOG(ERROR) << "TF_GRPC_WORKER_CACHE_QUEUES must be positive, got "
               << num_completion_queues << "; using default "
               << default_num_completion_queues;
    num_completion_queues = default_num_completion_queues;
  }

  int64 num_threads = default_num_threads;
  status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS",
                               default_num_threads, &num_threads);
  if (!status.ok()) {
    LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
    num_threads = default_num_threads;
  } else if (num_threads <= 0) {
    LOG(ERROR) << "TF_GRPC_WORKER_CACHE_THREADS must be positive, got "
               << num_threads << "; using default " << default_num_threads;
    num_threads = default_num_threads;
  }

  return new GrpcWorkerEnv(num_completion_queues, num_threads);
}
|
||||
|
||||
// Factory for a worker cache that has no in-process worker: the cache is
// constructed with a null local worker and an empty local target, so every
// lookup goes through the channel cache `cc`. `worker_env` is handed to the
// cache as a plain pointer.
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
                                         GrpcWorkerEnv* worker_env) {
  WorkerInterface* const no_local_worker = nullptr;
  const char* const no_local_target = "";
  return new GrpcWorkerCache(cc, no_local_worker, no_local_target, worker_env);
}
|
||||
|
||||
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
|
||||
std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
|
||||
const string& local_target, GrpcWorkerEnv* worker_env) {
|
||||
std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
|
||||
WorkerInterface* local_worker, const string& local_target) {
|
||||
return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
|
||||
}
|
||||
|
||||
|
@ -16,8 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
@ -62,12 +60,17 @@ class GrpcWorkerEnv {
|
||||
std::vector<GrpcWorkerCacheThread> threads_;
|
||||
};
|
||||
|
||||
// Create a GrpcWorkerEnv instance that can be used as argument to create
|
||||
// gRPC worker cache. Caller should take the ownership of the returned instance.
|
||||
GrpcWorkerEnv* CreateGrpcWorkerEnv();
|
||||
|
||||
// The returned WorkerCacheInterface object takes the ownership of "cc".
|
||||
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc);
|
||||
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
|
||||
GrpcWorkerEnv* worker_env);
|
||||
|
||||
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
|
||||
std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
|
||||
const string& local_target, GrpcWorkerEnv* worker_env);
|
||||
std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
|
||||
WorkerInterface* local_worker, const string& local_target);
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||
|
@ -0,0 +1,87 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/test_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Checks worker lookup through caches built by both factory functions:
// remote workers are looked up by task name, and the in-process worker is
// looked up by its local target.
TEST(GrpcWorkerCacheTest, NewGrpcWorkerCache) {
  GrpcChannelSpec spec;
  TF_ASSERT_OK(spec.AddHostPortsJob("worker", {"a:0", "b:1", "c:2"}));
  ChannelCreationFunction channel_func =
      ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
  auto channel_cache = std::shared_ptr<GrpcChannelCache>(
      NewGrpcChannelCache(spec, channel_func));
  std::unique_ptr<GrpcWorkerEnv> grpc_worker_env(CreateGrpcWorkerEnv());

  // We created a job with 3 tasks. Getting the task 0, 1, 2 should return valid
  // worker interfaces, and getting other tasks should return nullptr.
  std::unique_ptr<WorkerCacheInterface> worker_cache(
      NewGrpcWorkerCache(channel_cache, grpc_worker_env.get()));
  for (int task = 0; task < 3; ++task) {
    const string target = strings::StrCat("/job:worker/replica:0/task:", task);
    WorkerInterface* remote_wi = worker_cache->GetOrCreateWorker(target);
    EXPECT_NE(remote_wi, nullptr) << target;
    worker_cache->ReleaseWorker(target, remote_wi);
  }
  WorkerInterface* wi =
      worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:3");
  EXPECT_EQ(wi, nullptr);

  // Test creating a worker cache instance with local worker, and getting the
  // worker instance with the specified local target.
  //
  // The local worker must be a real object: with a null pointer the final
  // comparison would be nullptr == nullptr and could not detect a cache that
  // fails to hand back the local worker.
  // NOTE(review): assumes TestWorkerInterface is default-constructible --
  // confirm against test_utils.h.
  std::unique_ptr<TestWorkerInterface> local_wi(new TestWorkerInterface);
  ASSERT_NE(local_wi.get(), nullptr);
  worker_cache.reset(NewGrpcWorkerCacheWithLocalWorker(
      channel_cache, grpc_worker_env.get(), local_wi.get(), "local_target"));
  wi = worker_cache->GetOrCreateWorker("local_target");
  EXPECT_EQ(wi, local_wi.get());

  // Destroy the cache before the local worker it points to goes out of scope.
  worker_cache.reset();
}
|
||||
|
||||
// The GrpcWorkerEnv thread pool runs the gRPC completion queue callbacks of
// worker interfaces. This test deletes a worker cache from a closure scheduled
// on that pool and blocks until the closure reports that the delete returned.
TEST(GrpcWorkerCacheTest, DestructWorkerCacheInThreadPool) {
  GrpcChannelSpec channel_spec;
  TF_ASSERT_OK(channel_spec.AddHostPortsJob("worker", {"a:1", "b:2", "c:3"}));
  ChannelCreationFunction make_channel =
      ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
  std::shared_ptr<GrpcChannelCache> channels(
      NewGrpcChannelCache(channel_spec, make_channel));
  std::unique_ptr<GrpcWorkerEnv> worker_env(CreateGrpcWorkerEnv());

  // Held as a raw pointer on purpose: the scheduled closure is the one that
  // deletes it.
  WorkerCacheInterface* const cache_to_delete =
      NewGrpcWorkerCache(channels, worker_env.get());
  thread::ThreadPool* const pool = worker_env->GetThreadPool();
  Notification deleted;
  pool->Schedule([cache_to_delete, &deleted] {
    delete cache_to_delete;
    deleted.Notify();
  });
  deleted.WaitForNotification();
}
|
||||
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user