Reject retried RecvTensor requests.
Retried RecvTensorRequests are problematic because a RecvTensor with no corresponding sender will wait forever, and the tensor may have been delivered to a previous retry. This change adds a unique request_id to each RecvTensor request, and we check these request_ids against a set of recent request_ids. If a request_id is in the recent set, we reject the RecvTensor request. PiperOrigin-RevId: 182863245
This commit is contained in:
parent
2968447d32
commit
6042b5d267
@ -82,6 +82,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:graph_mgr",
|
||||
"//tensorflow/core/distributed_runtime:recent_request_ids",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:worker",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
@ -103,6 +104,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:request_id",
|
||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
@ -47,6 +48,7 @@ class GdrRecvTensorCall : public BaseRecvTensorCall {
|
||||
recv_args_(recv_args) {
|
||||
req_.set_step_id(step_id);
|
||||
req_.set_rendezvous_key(key.data(), key.size());
|
||||
req_.set_request_id(GetUniqueRequestId());
|
||||
}
|
||||
|
||||
~GdrRecvTensorCall() override {}
|
||||
|
@ -41,17 +41,26 @@ namespace tensorflow {
|
||||
|
||||
GdrWorker::GdrWorker(WorkerEnv* worker_env,
|
||||
RemoteMemoryManager* remote_memory_manager)
|
||||
: GrpcWorker(worker_env), remote_memory_manager_(remote_memory_manager) {}
|
||||
: GrpcWorker(worker_env),
|
||||
remote_memory_manager_(remote_memory_manager),
|
||||
recv_tensor_recent_request_ids_(100000) {}
|
||||
|
||||
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
|
||||
const RecvTensorRequest* request,
|
||||
::grpc::ByteBuffer* response,
|
||||
StatusCallback done) {
|
||||
Status s = recv_tensor_recent_request_ids_.TrackUnique(
|
||||
request->request_id(), "RecvTensor (GdrWorker)", *request);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
|
||||
const int64 step_id = request->step_id();
|
||||
const string& key = request->rendezvous_key();
|
||||
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
s = Rendezvous::ParseKey(key, &parsed);
|
||||
Device* src_dev = nullptr;
|
||||
if (s.ok()) {
|
||||
s = PrepareRecvTensor(parsed, &src_dev);
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/contrib/gdr/gdr_memory_manager.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -38,6 +39,7 @@ class GdrWorker : public GrpcWorker {
|
||||
|
||||
private:
|
||||
RemoteMemoryManager* remote_memory_manager_; // Not owned
|
||||
RecentRequestIds recv_tensor_recent_request_ids_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -71,6 +71,8 @@ cc_library(
|
||||
"//tensorflow/core:protos_cc",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:recent_request_ids",
|
||||
"//tensorflow/core/distributed_runtime:request_id",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
|
@ -33,8 +33,10 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
MPIRendezvousMgr::MPIRendezvousMgr(const WorkerEnv* env)
|
||||
: BaseRendezvousMgr(env), worker_env_2(env), use_optimal_transfer_(false) {
|
||||
|
||||
: BaseRendezvousMgr(env),
|
||||
worker_env_2(env),
|
||||
use_optimal_transfer_(false),
|
||||
recv_tensor_recent_request_ids_(100000) {
|
||||
const char* mpienv = getenv("MPI_OPTIMAL_PATH");
|
||||
if (mpienv && mpienv[0] == '1') {
|
||||
LOG(INFO) << "MPI Optimal copy path enabled (Requires CUDA-Aware MPI when "
|
||||
@ -149,6 +151,8 @@ MPIRemoteRendezvous::~MPIRemoteRendezvous() {}
|
||||
*/
|
||||
void MPIRendezvousMgr::AddRequest(RecvTensorRequest request,
|
||||
const int mpi_dst) {
|
||||
TF_CHECK_OK(recv_tensor_recent_request_ids_.TrackUnique(
|
||||
req.request_id(), "RecvTensor (MPIRendezvousMgr)", req));
|
||||
const int64 step_id = request.step_id();
|
||||
const std::string& key = request.rendezvous_key();
|
||||
Rendezvous::ParsedKey parsed;
|
||||
|
@ -30,10 +30,11 @@ limitations under the License.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "tensorflow/contrib/mpi/mpi_msg.pb.h"
|
||||
#include "tensorflow/contrib/mpi/mpi_utils.h"
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/contrib/mpi/mpi_msg.pb.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
#define TAG_REQTENSOR 1010
|
||||
@ -104,6 +105,7 @@ class MPIRequestTensorCall {
|
||||
void Init(const Rendezvous::ParsedKey& parsed, const int64 step_id) {
|
||||
req_.set_step_id(step_id);
|
||||
req_.set_rendezvous_key(parsed.FullKey().data(), parsed.FullKey().size());
|
||||
req_.set_request_id(GetUniqueRequestId());
|
||||
request_buffer_size_ = req_.ByteSize();
|
||||
// request_buffer_ = new char[request_buffer_size_];
|
||||
// req_.SerializeToArray(request_buffer_, request_buffer_size_);
|
||||
@ -177,6 +179,8 @@ class MPIRendezvousMgr : public BaseRendezvousMgr {
|
||||
std::map<std::string, std::shared_ptr<MPIRequestTensorCall>> recv_tensor_map_
|
||||
GUARDED_BY(mrq_);
|
||||
|
||||
RecentRequestIds recv_tensor_recent_request_ids_;
|
||||
|
||||
void AddRequest(RecvTensorRequest, const int);
|
||||
void MPIBackgroundThread();
|
||||
|
||||
|
@ -556,3 +556,47 @@ tf_cuda_cc_test(
|
||||
"//tensorflow/core/kernels:array",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "request_id",
|
||||
srcs = ["request_id.cc"],
|
||||
hdrs = ["request_id.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "request_id_test",
|
||||
size = "small",
|
||||
srcs = ["request_id_test.cc"],
|
||||
deps = [
|
||||
":request_id",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "recent_request_ids",
|
||||
srcs = ["recent_request_ids.cc"],
|
||||
hdrs = ["recent_request_ids.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "recent_request_ids_test",
|
||||
size = "small",
|
||||
srcs = ["recent_request_ids_test.cc"],
|
||||
deps = [
|
||||
":recent_request_ids",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
],
|
||||
)
|
||||
|
57
tensorflow/core/distributed_runtime/recent_request_ids.cc
Normal file
57
tensorflow/core/distributed_runtime/recent_request_ids.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
||||
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
|
||||
: circular_buffer_(num_tracked_request_ids) {
|
||||
set_.reserve(num_tracked_request_ids);
|
||||
}
|
||||
|
||||
Status RecentRequestIds::TrackUnique(int64 request_id,
|
||||
const string& method_name,
|
||||
const protobuf::Message& request) {
|
||||
mutex_lock l(mu_);
|
||||
if (request_id == 0) {
|
||||
// For backwards compatibility, allow all requests with request_id 0.
|
||||
return Status::OK();
|
||||
}
|
||||
if (set_.count(request_id) > 0) {
|
||||
// Note: RecentRequestIds is not strict LRU because we don't update
|
||||
// request_id's age in the circular_buffer_ if it's tracked again. Strict
|
||||
// LRU is not useful here because returning this error will close the
|
||||
// current Session.
|
||||
return errors::Aborted("The same ", method_name,
|
||||
" request was received twice. ",
|
||||
request.ShortDebugString());
|
||||
}
|
||||
|
||||
// Remove the oldest request_id from the set_. circular_buffer_ is
|
||||
// zero-initialized, and zero is never tracked, so it's safe to do this even
|
||||
// when the buffer is not yet full.
|
||||
set_.erase(circular_buffer_[next_index_]);
|
||||
circular_buffer_[next_index_] = request_id;
|
||||
set_.insert(request_id);
|
||||
next_index_ = (next_index_ + 1) % circular_buffer_.size();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
72
tensorflow/core/distributed_runtime/recent_request_ids.h
Normal file
72
tensorflow/core/distributed_runtime/recent_request_ids.h
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RecentRequestIds tracks recent 64-bit request_ids. When maximum capacity is
|
||||
// reached, the oldest request_id is evicted. Thread safe.
|
||||
//
|
||||
// Some RPCs like RecvTensor are unsafe to retry. For example, RecvTensor pairs
|
||||
// one sender and one receiver, and the receiver waits for the sender's tensor.
|
||||
// Retried RecvTensor requests are problematic, because the original RecvTensor
|
||||
// request may have consumed the sender's tensor, so a retried request might
|
||||
// block forever. RecentRequestIds identifies retried requests, so we can fail
|
||||
// them instead of blocking forever.
|
||||
//
|
||||
// Internally, recent request_ids are stored in two data structures: a set and a
|
||||
// circular buffer. The set is used for efficient lookups, and the circular
|
||||
// buffer tracks the oldest request_id. When the buffer is full, the new
|
||||
// request_id replaces the oldest request_id in the circular buffer, and the
|
||||
// oldest request_id is removed from the set.
|
||||
class RecentRequestIds {
|
||||
public:
|
||||
// num_tracked_request_ids should be much larger than the number of RPCs that
|
||||
// can be received in a small time window. For example, we observed a peak RPC
|
||||
// rate of ~700 RecvTensor RPC/s when training inception v3 on TPUs, so we
|
||||
// currently set num_tracked_request_ids to 100,000 for RecvTensor.
|
||||
RecentRequestIds(int num_tracked_request_ids);
|
||||
|
||||
// Returns OK iff request_id has not been seen in the last
|
||||
// num_tracked_request_ids insertions. For backwards compatibility, this
|
||||
// always returns OK for request_id 0. The method_name and the request's
|
||||
// ShortDebugString are added to returned errors.
|
||||
Status TrackUnique(int64 request_id, const string& method_name,
|
||||
const protobuf::Message& request);
|
||||
|
||||
private:
|
||||
mutex mu_;
|
||||
// next_index_ indexes into circular_buffer_, and points to the next storage
|
||||
// space to use. When the buffer is full, next_index_ points at the oldest
|
||||
// request_id.
|
||||
int next_index_ GUARDED_BY(mu_) = 0;
|
||||
std::vector<int64> circular_buffer_ GUARDED_BY(mu_);
|
||||
gtl::FlatSet<int64> set_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
|
@ -0,0 +1,96 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/protobuf/worker.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status TrackUnique(int64 request_id, RecentRequestIds* recent_request_ids) {
|
||||
RecvTensorRequest request;
|
||||
request.set_request_id(request_id);
|
||||
return recent_request_ids->TrackUnique(request_id, "recent_request_ids_test",
|
||||
request);
|
||||
}
|
||||
|
||||
// request_id 0 is always valid.
|
||||
TEST(RecentRequestIds, Zero) {
|
||||
RecentRequestIds recent_request_ids(1);
|
||||
EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok());
|
||||
EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok());
|
||||
EXPECT_TRUE(TrackUnique(0, &recent_request_ids).ok());
|
||||
}
|
||||
|
||||
TEST(RecentRequestIds, Unordered) {
|
||||
// Capacity for 6 numbers.
|
||||
RecentRequestIds recent_request_ids(6);
|
||||
|
||||
// Some unordered numbers to insert into request_id_set.
|
||||
std::vector<int64> numbers = {53754, 23351, 164101, 7476,
|
||||
162432, 130761, 164102};
|
||||
|
||||
// Insert numbers[0..6) and check that all previously inserted numbers remain
|
||||
// in the set.
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
TF_EXPECT_OK(TrackUnique(numbers[i], &recent_request_ids));
|
||||
|
||||
for (int j = 0; j <= i; ++j) {
|
||||
EXPECT_FALSE(TrackUnique(numbers[j], &recent_request_ids).ok())
|
||||
<< "i=" << i << " j=" << j;
|
||||
}
|
||||
}
|
||||
|
||||
// Insert numbers[6]. Inserting this 7th number should evict the first number
|
||||
// from the set. The set should only contain numbers[1..7).
|
||||
TF_EXPECT_OK(TrackUnique(numbers[6], &recent_request_ids));
|
||||
for (int i = 1; i < 7; ++i) {
|
||||
EXPECT_FALSE(TrackUnique(numbers[i], &recent_request_ids).ok())
|
||||
<< "i=" << i;
|
||||
}
|
||||
|
||||
// Insert numbers[0] again. This should succeed because we just evicted it
|
||||
// from the set.
|
||||
TF_EXPECT_OK(TrackUnique(numbers[0], &recent_request_ids));
|
||||
}
|
||||
|
||||
// Check that the oldest request_id is evicted.
|
||||
void TestOrdered(int num_request_ids) {
|
||||
RecentRequestIds recent_request_ids(num_request_ids);
|
||||
|
||||
// Insert [1..101). The current number and the (num_request_ids - 1) preceding
|
||||
// numbers should still be in the set.
|
||||
for (int i = 1; i < 101; ++i) {
|
||||
TF_EXPECT_OK(TrackUnique(i, &recent_request_ids));
|
||||
|
||||
for (int j = std::max(1, i - num_request_ids + 1); j <= i; ++j) {
|
||||
EXPECT_FALSE(TrackUnique(j, &recent_request_ids).ok())
|
||||
<< "i=" << i << " j=" << j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test eviction with various numbers of buckets.
|
||||
TEST(RecentRequestIds, Ordered2) { TestOrdered(2); }
|
||||
TEST(RecentRequestIds, Ordered3) { TestOrdered(3); }
|
||||
TEST(RecentRequestIds, Ordered4) { TestOrdered(4); }
|
||||
TEST(RecentRequestIds, Ordered5) { TestOrdered(5); }
|
||||
|
||||
} // namespace tensorflow
|
30
tensorflow/core/distributed_runtime/request_id.cc
Normal file
30
tensorflow/core/distributed_runtime/request_id.cc
Normal file
@ -0,0 +1,30 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
int64 GetUniqueRequestId() {
|
||||
int64 request_id = 0;
|
||||
while (request_id == 0) {
|
||||
request_id = random::New64();
|
||||
}
|
||||
return request_id;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
31
tensorflow/core/distributed_runtime/request_id.h
Normal file
31
tensorflow/core/distributed_runtime/request_id.h
Normal file
@ -0,0 +1,31 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_
|
||||
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Returns a request_id for use with RecentRequestIds. This number will not be
|
||||
// zero, and must be unique over RecentRequestIds' window of
|
||||
// num_tracked_request_ids. See recent_request_ids.h for more details.
|
||||
int64 GetUniqueRequestId();
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REQUEST_ID_H_
|
29
tensorflow/core/distributed_runtime/request_id_test.cc
Normal file
29
tensorflow/core/distributed_runtime/request_id_test.cc
Normal file
@ -0,0 +1,29 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Try requesting some request_ids and verify that none are zero.
|
||||
TEST(GetUniqueRequestId, Basic) {
|
||||
for (int i = 0; i < 1000000; ++i) {
|
||||
EXPECT_NE(GetUniqueRequestId(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -186,6 +186,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:worker_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime:graph_mgr",
|
||||
"//tensorflow/core/distributed_runtime:recent_request_ids",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:worker",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
@ -270,6 +271,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:request_id",
|
||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
|
@ -354,7 +354,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
|
||||
} // namespace
|
||||
|
||||
GrpcWorker::GrpcWorker(WorkerEnv* worker_env) : Worker(worker_env) {}
|
||||
GrpcWorker::GrpcWorker(WorkerEnv* worker_env)
|
||||
: Worker(worker_env), recv_tensor_recent_request_ids_(100000) {}
|
||||
|
||||
// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
|
||||
// buffers for a response object, to avoid extra protocol buffer serialization
|
||||
@ -363,11 +364,18 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
|
||||
const RecvTensorRequest* request,
|
||||
::grpc::ByteBuffer* response,
|
||||
StatusCallback done) {
|
||||
Status s = recv_tensor_recent_request_ids_.TrackUnique(
|
||||
request->request_id(), "RecvTensor (GrpcWorker)", *request);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
}
|
||||
|
||||
const int64 step_id = request->step_id();
|
||||
const string& key = request->rendezvous_key();
|
||||
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
||||
s = Rendezvous::ParseKey(key, &parsed);
|
||||
Device* src_dev = nullptr;
|
||||
if (s.ok()) {
|
||||
s = PrepareRecvTensor(parsed, &src_dev);
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker.h"
|
||||
|
||||
namespace grpc {
|
||||
@ -40,6 +41,9 @@ class GrpcWorker : public Worker {
|
||||
StatusCallback done);
|
||||
|
||||
WorkerEnv* env();
|
||||
|
||||
private:
|
||||
RecentRequestIds recv_tensor_recent_request_ids_;
|
||||
};
|
||||
|
||||
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env);
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/common_runtime/process_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
@ -67,6 +68,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
done_ = std::move(done);
|
||||
req_.set_step_id(step_id);
|
||||
req_.set_rendezvous_key(key.data(), key.size());
|
||||
req_.set_request_id(GetUniqueRequestId());
|
||||
}
|
||||
|
||||
void Reset(WorkerCacheInterface* wc) {
|
||||
|
@ -292,7 +292,10 @@ message RecvTensorRequest {
|
||||
// into a RunGraph call on the same WorkerService.
|
||||
int64 step_id = 1;
|
||||
|
||||
// A key that identifies the tensor to be received.
|
||||
// A key identifying the channel to receive tensors from. A RecvTensor request
|
||||
// retrieves one tensor from the channel, but multiple tensors can be sent and
|
||||
// received over the same channel with multiple RecvTensor requests. See
|
||||
// rendezvous.h for details.
|
||||
string rendezvous_key = 2;
|
||||
|
||||
// If true, use an out-of-band DMA mechanism to transfer the
|
||||
@ -307,6 +310,16 @@ message RecvTensorRequest {
|
||||
|
||||
// Optional information needed by the RPC subsystem.
|
||||
google.protobuf.Any transport_options = 6;
|
||||
|
||||
// Unique identifier for this request. Every RecvTensorRequest must have a
|
||||
// unique request_id, and retried RecvTensorRequests must have the same
|
||||
// request_id. If request_id is zero, retry detection is disabled.
|
||||
//
|
||||
// Retried RecvTensorRequests are problematic because a RecvTensor with no
|
||||
// corresponding sender will wait forever, and the tensor may have been
|
||||
// delivered to a previous retry. Workers use request_ids to reject retried
|
||||
// RecvTensor requests instead of waiting forever.
|
||||
int64 request_id = 7;
|
||||
}
|
||||
|
||||
message RecvTensorResponse {
|
||||
|
Loading…
Reference in New Issue
Block a user