Make GrpcEagerClientCache::GetClient thread safe.
PiperOrigin-RevId: 313211894 Change-Id: I3195db70af77816183cf041d024f694c32613164
This commit is contained in:
parent
51504ec873
commit
956278ab3d
@ -1,4 +1,5 @@
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -57,3 +58,21 @@ cc_library(
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
# Small test target built from grpc_eager_client_test.cc; it links against the
# local :grpc_eager_client library plus the test/platform helpers listed below.
tf_cc_test(
    name = "grpc_eager_client_test",
    size = "small",
    srcs = ["grpc_eager_client_test.cc"],
    deps = [
        ":grpc_eager_client",
        "//tensorflow/c:tf_status_headers",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
        "//tensorflow/core/platform:blocking_counter",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/platform:strcat",
    ],
)
|
||||
|
@ -240,6 +240,7 @@ class GrpcEagerClientCache : public EagerClientCache {
|
||||
|
||||
Status GetClient(const string& target,
|
||||
core::RefCountPtr<EagerClient>* client) override {
|
||||
mutex_lock l(clients_mu_);
|
||||
auto it = clients_.find(target);
|
||||
if (it == clients_.end()) {
|
||||
tensorflow::SharedGrpcChannelPtr shared =
|
||||
@ -281,7 +282,9 @@ class GrpcEagerClientCache : public EagerClientCache {
|
||||
}
|
||||
|
||||
std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
|
||||
std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_;
|
||||
mutable mutex clients_mu_;
|
||||
std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
|
||||
TF_GUARDED_BY(clients_mu_);
|
||||
std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
|
||||
};
|
||||
|
||||
|
@ -0,0 +1,58 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
|
||||
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/blocking_counter.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace eager {
|
||||
|
||||
// Issues GetClient() for ten different task targets from closures scheduled on
// the default Env, all sharing one EagerClientCache, and checks the status
// code each call returns. The test body blocks until every closure has
// reported in, so the by-reference captures stay valid for their whole run.
TEST(GrpcEagerClientCache, TestGetClientThreadSafety) {
  // A single "worker" job with six tasks (indices 0 through 5).
  GrpcChannelSpec spec;
  TF_ASSERT_OK(spec.AddHostPortsJob(
      "worker", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"}));

  ChannelCreationFunction channel_func =
      ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
  std::shared_ptr<GrpcChannelCache> channel_cache(
      NewGrpcChannelCache(spec, channel_func));
  std::unique_ptr<EagerClientCache> client_cache(
      NewGrpcEagerClientCache(channel_cache));

  const int num_calls = 10;
  BlockingCounter counter(num_calls);

  for (int i = 0; i < num_calls; i++) {
    Env::Default()->SchedClosure([&client_cache, i, &counter]() {
      const string task_name =
          strings::StrCat("/job:worker/replica:0/task:", i);
      core::RefCountPtr<EagerClient> client;
      const Status status = client_cache->GetClient(task_name, &client);

      // Only tasks 0..5 exist in the job: lookups for those are expected to
      // succeed, while task indices 6 and above are expected to be rejected
      // with an invalid argument error.
      const error::Code expected_code =
          (i > 5) ? error::INVALID_ARGUMENT : error::OK;
      EXPECT_EQ(expected_code, status.code());

      // Signal completion so the waiting test body can proceed.
      counter.DecrementCount();
    });
  }

  // Do not leave the test until all scheduled closures have finished.
  counter.Wait();
}
|
||||
|
||||
} // namespace eager
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user