Make GrpcEagerClientCache::GetClient thread safe.
PiperOrigin-RevId: 313211894
Change-Id: I3195db70af77816183cf041d024f694c32613164
Commit: 956278ab3d (parent: 51504ec873)
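
At a glance, the fix is the usual guarded-map pattern: take a mutex before touching the client map and annotate the map as guarded by that mutex. Below is a minimal standalone sketch of that idea with made-up names (ClientCache, Client, std::mutex standing in for tensorflow::mutex); it is an illustration only, not the TensorFlow code, which uses mutex_lock and TF_GUARDED_BY as shown in the diff that follows.

// Sketch only: every access to clients_ happens under mu_, so concurrent
// GetClient calls cannot race on the map's find/insert.
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

struct Client {};  // placeholder for the per-target RPC client

class ClientCache {
 public:
  std::shared_ptr<Client> GetClient(const std::string& target) {
    std::lock_guard<std::mutex> lock(mu_);  // serialize lookup and insertion
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      it = clients_.emplace(target, std::make_shared<Client>()).first;
    }
    return it->second;
  }

 private:
  std::mutex mu_;
  // Guarded by mu_; never read or written without holding the lock.
  std::unordered_map<std::string, std::shared_ptr<Client>> clients_;
};

int main() {
  ClientCache cache;
  auto a = cache.GetClient("/job:worker/task:0");
  auto b = cache.GetClient("/job:worker/task:0");
  return a == b ? 0 : 1;  // same target returns the cached client
}
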
--- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD
@@ -1,4 +1,5 @@
 load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
 
 package(
     default_visibility = [
@@ -57,3 +58,21 @@ cc_library(
         tf_grpc_cc_dependency(),
     ],
 )
+
+tf_cc_test(
+    name = "grpc_eager_client_test",
+    size = "small",
+    srcs = [
+        "grpc_eager_client_test.cc",
+    ],
+    deps = [
+        ":grpc_eager_client",
+        "//tensorflow/c:tf_status_headers",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+        "//tensorflow/core/platform:blocking_counter",
+        "//tensorflow/core/platform:status",
+        "//tensorflow/core/platform:strcat",
+    ],
+)
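
With that target in place, the new test can be run on its own; assuming the package sits at tensorflow/core/distributed_runtime/rpc/eager (the path implied by the test's includes), the invocation would be along the lines of:

bazel test //tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client_test
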
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc
@@ -240,6 +240,7 @@ class GrpcEagerClientCache : public EagerClientCache {
 
   Status GetClient(const string& target,
                    core::RefCountPtr<EagerClient>* client) override {
+    mutex_lock l(clients_mu_);
     auto it = clients_.find(target);
     if (it == clients_.end()) {
       tensorflow::SharedGrpcChannelPtr shared =
@@ -281,7 +282,9 @@ class GrpcEagerClientCache : public EagerClientCache {
   }
 
   std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
-  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_;
+  mutable mutex clients_mu_;
+  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
+      TF_GUARDED_BY(clients_mu_);
   std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
 };
 
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
+
+#include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/blocking_counter.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace eager {
+
+TEST(GrpcEagerClientCache, TestGetClientThreadSafety) {
+  GrpcChannelSpec spec;
+  TF_ASSERT_OK(spec.AddHostPortsJob(
+      "worker", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"}));
+  ChannelCreationFunction channel_func =
+      ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
+  auto channel_cache = std::shared_ptr<GrpcChannelCache>(
+      NewGrpcChannelCache(spec, channel_func));
+  std::unique_ptr<EagerClientCache> client_cache(
+      NewGrpcEagerClientCache(channel_cache));
+  const int num_calls = 10;
+  BlockingCounter counter(num_calls);
+
+  for (int i = 0; i < num_calls; i++) {
+    Env::Default()->SchedClosure([&client_cache, i, &counter]() {
+      string target = strings::StrCat("/job:worker/replica:0/task:", i);
+      core::RefCountPtr<EagerClient> eager_client;
+      Status s = client_cache->GetClient(target, &eager_client);
+      // With 6 tasks added to the job, querying client for 0--5 should be OK,
+      // and querying client for 6+ should give invalid argument error.
+      error::Code expected_code = i <= 5 ? error::OK : error::INVALID_ARGUMENT;
+      EXPECT_EQ(expected_code, s.code());
+      counter.DecrementCount();
+    });
+  }
+  counter.Wait();
+}
+
+}  // namespace eager
+}  // namespace tensorflow