Make GrpcEagerClientCache::GetClient thread safe.

PiperOrigin-RevId: 313211894
Change-Id: I3195db70af77816183cf041d024f694c32613164
This commit is contained in:
Haoyu Zhang 2020-05-26 09:59:28 -07:00 committed by TensorFlower Gardener
parent 51504ec873
commit 956278ab3d
3 changed files with 81 additions and 1 deletions

View File

@ -1,4 +1,5 @@
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency") load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package( package(
default_visibility = [ default_visibility = [
@ -57,3 +58,21 @@ cc_library(
tf_grpc_cc_dependency(), tf_grpc_cc_dependency(),
], ],
) )
tf_cc_test(
name = "grpc_eager_client_test",
size = "small",
srcs = [
"grpc_eager_client_test.cc",
],
deps = [
":grpc_eager_client",
"//tensorflow/c:tf_status_headers",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"//tensorflow/core/platform:blocking_counter",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:strcat",
],
)

View File

@ -240,6 +240,7 @@ class GrpcEagerClientCache : public EagerClientCache {
Status GetClient(const string& target, Status GetClient(const string& target,
core::RefCountPtr<EagerClient>* client) override { core::RefCountPtr<EagerClient>* client) override {
mutex_lock l(clients_mu_);
auto it = clients_.find(target); auto it = clients_.find(target);
if (it == clients_.end()) { if (it == clients_.end()) {
tensorflow::SharedGrpcChannelPtr shared = tensorflow::SharedGrpcChannelPtr shared =
@ -281,7 +282,9 @@ class GrpcEagerClientCache : public EagerClientCache {
} }
std::shared_ptr<tensorflow::GrpcChannelCache> cache_; std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_; mutable mutex clients_mu_;
std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
TF_GUARDED_BY(clients_mu_);
std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_; std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
}; };

View File

@ -0,0 +1,58 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/blocking_counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace eager {
TEST(GrpcEagerClientCache, TestGetClientThreadSafety) {
GrpcChannelSpec spec;
TF_ASSERT_OK(spec.AddHostPortsJob(
"worker", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"}));
ChannelCreationFunction channel_func =
ConvertToChannelCreationFunction(NewHostPortGrpcChannel);
auto channel_cache = std::shared_ptr<GrpcChannelCache>(
NewGrpcChannelCache(spec, channel_func));
std::unique_ptr<EagerClientCache> client_cache(
NewGrpcEagerClientCache(channel_cache));
const int num_calls = 10;
BlockingCounter counter(num_calls);
for (int i = 0; i < num_calls; i++) {
Env::Default()->SchedClosure([&client_cache, i, &counter]() {
string target = strings::StrCat("/job:worker/replica:0/task:", i);
core::RefCountPtr<EagerClient> eager_client;
Status s = client_cache->GetClient(target, &eager_client);
// With 6 tasks added to the job, querying client for 0--5 should be OK,
// and querying client for 6+ should give invalid argument error.
error::Code expected_code = i <= 5 ? error::OK : error::INVALID_ARGUMENT;
EXPECT_EQ(expected_code, s.code());
counter.DecrementCount();
});
}
counter.Wait();
}
} // namespace eager
} // namespace tensorflow