Pass in GrpcWorkerEnv when creating GrpcWorkerCache.
PiperOrigin-RevId: 314769356 Change-Id: I154786dbba5eeb69d151baed83fc89a9dcd6a989
This commit is contained in:
parent
4815e7515f
commit
48678a1e2d
@ -36,9 +36,11 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
TF_CHECK_OK(spec.AddHostPortsJob("localhost", cluster_->targets()));
|
TF_CHECK_OK(spec.AddHostPortsJob("localhost", cluster_->targets()));
|
||||||
ChannelCreationFunction channel_func =
|
ChannelCreationFunction channel_func =
|
||||||
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
||||||
|
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
|
||||||
|
std::shared_ptr<GrpcChannelCache> channel_cache(
|
||||||
|
NewGrpcChannelCache(spec, channel_func));
|
||||||
std::unique_ptr<WorkerCacheInterface> worker_cache(
|
std::unique_ptr<WorkerCacheInterface> worker_cache(
|
||||||
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache>(
|
NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get()));
|
||||||
NewGrpcChannelCache(spec, channel_func))));
|
|
||||||
|
|
||||||
worker_session_.reset(new WorkerSession(
|
worker_session_.reset(new WorkerSession(
|
||||||
"cluster_test_session", "/job:localhost/replica:0/task:0",
|
"cluster_test_session", "/job:localhost/replica:0/task:0",
|
||||||
@ -110,6 +112,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
std::unique_ptr<test::TestCluster> cluster_;
|
std::unique_ptr<test::TestCluster> cluster_;
|
||||||
std::unique_ptr<WorkerSession> worker_session_;
|
std::unique_ptr<WorkerSession> worker_session_;
|
||||||
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
|
std::unique_ptr<ClusterFunctionLibraryRuntime> cluster_flr_;
|
||||||
|
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) {
|
TEST_F(ClusterFunctionLibraryRuntimeTest, ConstructFunctionGraph) {
|
||||||
|
@ -39,6 +39,7 @@ class RemoteDeviceTest : public ::testing::Test {
|
|||||||
WorkerInterface* wi_;
|
WorkerInterface* wi_;
|
||||||
std::vector<Device*> devices_;
|
std::vector<Device*> devices_;
|
||||||
std::unique_ptr<test::TestCluster> cluster_;
|
std::unique_ptr<test::TestCluster> cluster_;
|
||||||
|
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env_;
|
||||||
|
|
||||||
RemoteDeviceTest() {
|
RemoteDeviceTest() {
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
@ -51,7 +52,9 @@ class RemoteDeviceTest : public ::testing::Test {
|
|||||||
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
||||||
std::shared_ptr<GrpcChannelCache> channel_cache(
|
std::shared_ptr<GrpcChannelCache> channel_cache(
|
||||||
NewGrpcChannelCache(spec, channel_func));
|
NewGrpcChannelCache(spec, channel_func));
|
||||||
worker_cache_.reset(NewGrpcWorkerCache(channel_cache));
|
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
|
||||||
|
worker_cache_.reset(
|
||||||
|
NewGrpcWorkerCache(channel_cache, grpc_worker_env_.get()));
|
||||||
remote_name_ = "/job:localhost/replica:0/task:0";
|
remote_name_ = "/job:localhost/replica:0/task:0";
|
||||||
wi_ = worker_cache_->GetOrCreateWorker(remote_name_);
|
wi_ = worker_cache_->GetOrCreateWorker(remote_name_);
|
||||||
}
|
}
|
||||||
@ -82,11 +85,13 @@ class RemoteDeviceTest : public ::testing::Test {
|
|||||||
|
|
||||||
TEST_F(RemoteDeviceTest, GetStatus) {
|
TEST_F(RemoteDeviceTest, GetStatus) {
|
||||||
// We know what the testlib's fake server does.
|
// We know what the testlib's fake server does.
|
||||||
EXPECT_EQ(devices_[0]->name(), strings::StrCat(remote_name_, "/cpu:0"));
|
EXPECT_EQ(devices_[0]->name(),
|
||||||
|
strings::StrCat(remote_name_, "/device:CPU:0"));
|
||||||
EXPECT_EQ(devices_[0]->attributes().device_type(),
|
EXPECT_EQ(devices_[0]->attributes().device_type(),
|
||||||
DeviceType(DEVICE_CPU).type());
|
DeviceType(DEVICE_CPU).type());
|
||||||
EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20);
|
EXPECT_EQ(devices_[0]->attributes().memory_limit(), 256 << 20);
|
||||||
EXPECT_EQ(devices_[1]->name(), strings::StrCat(remote_name_, "/cpu:1"));
|
EXPECT_EQ(devices_[1]->name(),
|
||||||
|
strings::StrCat(remote_name_, "/device:CPU:1"));
|
||||||
EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20);
|
EXPECT_EQ(devices_[1]->attributes().memory_limit(), 256 << 20);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -165,6 +165,26 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime:worker_cache_partial",
|
"//tensorflow/core/distributed_runtime:worker_cache_partial",
|
||||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||||
|
"//tensorflow/core/util:env_var",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "grpc_worker_cache_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"grpc_worker_cache_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":grpc_worker_cache",
|
||||||
|
"//tensorflow/c:tf_status_headers",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/distributed_runtime:test_utils",
|
||||||
|
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
||||||
|
"//tensorflow/core/platform:env",
|
||||||
|
"//tensorflow/core/platform:status",
|
||||||
|
"//tensorflow/core/platform:strcat",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -87,23 +87,6 @@ RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
|
|||||||
return new RpcRendezvousMgr(env);
|
return new RpcRendezvousMgr(env);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<GrpcWorkerEnv> CreateGrpcWorkerEnv() {
|
|
||||||
int num_cpus = port::NumSchedulableCPUs();
|
|
||||||
int64 num_completion_queues;
|
|
||||||
Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
|
|
||||||
&num_completion_queues);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
|
|
||||||
}
|
|
||||||
int64 num_threads;
|
|
||||||
status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
|
|
||||||
&num_threads);
|
|
||||||
if (!status.ok()) {
|
|
||||||
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
|
|
||||||
}
|
|
||||||
return absl::make_unique<GrpcWorkerEnv>(num_completion_queues, num_threads);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
|
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
|
||||||
@ -275,7 +258,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) {
|
|||||||
return errors::Unknown("Could not start gRPC server");
|
return errors::Unknown("Could not start gRPC server");
|
||||||
}
|
}
|
||||||
// Create the execution environment for the GRPC workers cache.
|
// Create the execution environment for the GRPC workers cache.
|
||||||
grpc_worker_env_ = CreateGrpcWorkerEnv();
|
grpc_worker_env_.reset(CreateGrpcWorkerEnv());
|
||||||
|
|
||||||
WorkerCacheInterface* worker_cache;
|
WorkerCacheInterface* worker_cache;
|
||||||
WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
|
WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
|
||||||
@ -401,7 +384,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
|
|||||||
" differs from expected port ", bound_port_);
|
" differs from expected port ", bound_port_);
|
||||||
}
|
}
|
||||||
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(
|
*worker_cache = NewGrpcWorkerCacheWithLocalWorker(
|
||||||
channel_cache, worker_impl(), name_prefix, grpc_worker_env_.get());
|
channel_cache, grpc_worker_env(), worker_impl(), name_prefix);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -137,6 +137,7 @@ class GrpcServer : public ServerInterface {
|
|||||||
|
|
||||||
const ServerDef& server_def() const { return server_def_; }
|
const ServerDef& server_def() const { return server_def_; }
|
||||||
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
|
GrpcWorker* worker_impl() const { return worker_impl_.get(); }
|
||||||
|
GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Env* env_;
|
Env* env_;
|
||||||
|
@ -15,24 +15,20 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||||
|
#include "tensorflow/core/platform/cpu_info.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
|
#include "tensorflow/core/util/env_var.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// TODO(ncteisen): consider adding a config var or flag for this
|
|
||||||
static const size_t kGrpcWorkerCacheThreadCount = 8;
|
|
||||||
static const size_t kNumCallbackThreads = 10;
|
|
||||||
|
|
||||||
class GrpcWorkerCache : public WorkerCachePartial {
|
class GrpcWorkerCache : public WorkerCachePartial {
|
||||||
public:
|
public:
|
||||||
explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
|
explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
|
||||||
@ -43,13 +39,7 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
|||||||
local_worker_(local_worker),
|
local_worker_(local_worker),
|
||||||
channel_cache_(channel_cache),
|
channel_cache_(channel_cache),
|
||||||
worker_env_(worker_env),
|
worker_env_(worker_env),
|
||||||
next_round_robin_assignment_(0) {
|
next_round_robin_assignment_(0) {}
|
||||||
if (worker_env_ == nullptr) {
|
|
||||||
worker_env_ptr_ = absl::make_unique<GrpcWorkerEnv>(
|
|
||||||
kGrpcWorkerCacheThreadCount, kNumCallbackThreads);
|
|
||||||
worker_env_ = worker_env_ptr_.get();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ListWorkers(std::vector<string>* workers) const override {
|
void ListWorkers(std::vector<string>* workers) const override {
|
||||||
channel_cache_->ListWorkers(workers);
|
channel_cache_->ListWorkers(workers);
|
||||||
@ -118,8 +108,7 @@ class GrpcWorkerCache : public WorkerCachePartial {
|
|||||||
WorkerInterface* const local_worker_; // Not owned.
|
WorkerInterface* const local_worker_; // Not owned.
|
||||||
std::shared_ptr<GrpcChannelCache> channel_cache_;
|
std::shared_ptr<GrpcChannelCache> channel_cache_;
|
||||||
WorkerCacheLogger logger_;
|
WorkerCacheLogger logger_;
|
||||||
GrpcWorkerEnv* worker_env_; // Not owned, if worker_env_ptr_ is nullptr.
|
GrpcWorkerEnv* worker_env_; // Not owned
|
||||||
std::unique_ptr<GrpcWorkerEnv> worker_env_ptr_;
|
|
||||||
|
|
||||||
mutex assignment_mu_;
|
mutex assignment_mu_;
|
||||||
std::unordered_map<std::string, size_t> target_assignments_
|
std::unordered_map<std::string, size_t> target_assignments_
|
||||||
@ -154,13 +143,32 @@ GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
|
|||||||
thread_.reset();
|
thread_.reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc) {
|
GrpcWorkerEnv* CreateGrpcWorkerEnv() {
|
||||||
return new GrpcWorkerCache(cc, nullptr, "", nullptr);
|
int num_cpus = port::NumSchedulableCPUs();
|
||||||
|
int64 num_completion_queues;
|
||||||
|
Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
|
||||||
|
&num_completion_queues);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
|
||||||
|
}
|
||||||
|
int64 num_threads;
|
||||||
|
status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
|
||||||
|
&num_threads);
|
||||||
|
if (!status.ok()) {
|
||||||
|
LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
|
||||||
|
}
|
||||||
|
return new GrpcWorkerEnv(num_completion_queues, num_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
|
||||||
|
GrpcWorkerEnv* worker_env) {
|
||||||
|
return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
|
||||||
|
worker_env);
|
||||||
}
|
}
|
||||||
|
|
||||||
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
|
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
|
||||||
std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
|
std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
|
||||||
const string& local_target, GrpcWorkerEnv* worker_env) {
|
WorkerInterface* local_worker, const string& local_target) {
|
||||||
return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
|
return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,8 +16,6 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
|
||||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||||
@ -62,12 +60,17 @@ class GrpcWorkerEnv {
|
|||||||
std::vector<GrpcWorkerCacheThread> threads_;
|
std::vector<GrpcWorkerCacheThread> threads_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Create a GrpcWorkerEnv instance that can be used as argument to create
|
||||||
|
// gRPC worker cache. Caller should take the ownership of the returned instance.
|
||||||
|
GrpcWorkerEnv* CreateGrpcWorkerEnv();
|
||||||
|
|
||||||
// The returned WorkerCacheInterface object takes the ownership of "cc".
|
// The returned WorkerCacheInterface object takes the ownership of "cc".
|
||||||
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc);
|
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
|
||||||
|
GrpcWorkerEnv* worker_env);
|
||||||
|
|
||||||
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
|
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
|
||||||
std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
|
std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
|
||||||
const string& local_target, GrpcWorkerEnv* worker_env);
|
WorkerInterface* local_worker, const string& local_target);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
|
||||||
|
@ -0,0 +1,87 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||||
|
|
||||||
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/test_utils.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
|
#include "tensorflow/core/platform/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/threadpool.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
TEST(GrpcWorkerCacheTest, NewGrpcWorkerCache) {
|
||||||
|
GrpcChannelSpec spec;
|
||||||
|
TF_ASSERT_OK(spec.AddHostPortsJob("worker", {"a:0", "b:1", "c:2"}));
|
||||||
|
ChannelCreationFunction channel_func =
|
||||||
|
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
||||||
|
auto channel_cache = std::shared_ptr<GrpcChannelCache>(
|
||||||
|
NewGrpcChannelCache(spec, channel_func));
|
||||||
|
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env(CreateGrpcWorkerEnv());
|
||||||
|
|
||||||
|
// We created a job with 3 tasks. Getting the task 0, 1, 2 should return valid
|
||||||
|
// worker interfaces, and getting other tasks should return nullptr.
|
||||||
|
std::unique_ptr<WorkerCacheInterface> worker_cache(
|
||||||
|
NewGrpcWorkerCache(channel_cache, grpc_worker_env.get()));
|
||||||
|
WorkerInterface* wi;
|
||||||
|
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:0");
|
||||||
|
EXPECT_NE(wi, nullptr);
|
||||||
|
worker_cache->ReleaseWorker("/job:worker/replica:0/task:0", wi);
|
||||||
|
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:1");
|
||||||
|
EXPECT_NE(wi, nullptr);
|
||||||
|
worker_cache->ReleaseWorker("/job:worker/replica:0/task:1", wi);
|
||||||
|
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:2");
|
||||||
|
EXPECT_NE(wi, nullptr);
|
||||||
|
worker_cache->ReleaseWorker("/job:worker/replica:0/task:2", wi);
|
||||||
|
wi = worker_cache->GetOrCreateWorker("/job:worker/replica:0/task:3");
|
||||||
|
EXPECT_EQ(wi, nullptr);
|
||||||
|
|
||||||
|
// Test creating a worker cache instance with local worker, and getting the
|
||||||
|
// worker instance with the specified local target.
|
||||||
|
std::unique_ptr<TestWorkerInterface> local_wi;
|
||||||
|
worker_cache.reset(NewGrpcWorkerCacheWithLocalWorker(
|
||||||
|
channel_cache, grpc_worker_env.get(), local_wi.get(), "local_target"));
|
||||||
|
wi = worker_cache->GetOrCreateWorker("local_target");
|
||||||
|
EXPECT_EQ(wi, local_wi.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(GrpcWorkerCacheTest, DestructWorkerCacheInThreadPool) {
|
||||||
|
GrpcChannelSpec spec;
|
||||||
|
TF_ASSERT_OK(spec.AddHostPortsJob("worker", {"a:1", "b:2", "c:3"}));
|
||||||
|
ChannelCreationFunction channel_func =
|
||||||
|
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
|
||||||
|
auto channel_cache = std::shared_ptr<GrpcChannelCache>(
|
||||||
|
NewGrpcChannelCache(spec, channel_func));
|
||||||
|
std::unique_ptr<GrpcWorkerEnv> grpc_worker_env(CreateGrpcWorkerEnv());
|
||||||
|
|
||||||
|
// The GrpcWorkerEnv threadpool is used for worker interfaces for gRPC
|
||||||
|
// completion queue callbacks. Test worker cache destruction inside the
|
||||||
|
// callbacks that runs in the GrpcWorkerEnv threadpool.
|
||||||
|
WorkerCacheInterface* worker_cache =
|
||||||
|
NewGrpcWorkerCache(channel_cache, grpc_worker_env.get());
|
||||||
|
thread::ThreadPool* tp = grpc_worker_env->GetThreadPool();
|
||||||
|
Notification n;
|
||||||
|
tp->Schedule([worker_cache, &n] {
|
||||||
|
delete worker_cache;
|
||||||
|
n.Notify();
|
||||||
|
});
|
||||||
|
n.WaitForNotification();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user