TF/gRPC: Use explicit request from receiver to clean RPC cache

This simplifies the response cache management on the sender side and avoids the
need for large response cache size.

PiperOrigin-RevId: 239638065
This commit is contained in:
Jing Dong 2019-03-21 11:36:47 -07:00 committed by TensorFlower Gardener
parent 90085135c2
commit a6d233a581
14 changed files with 368 additions and 306 deletions

View File

@ -128,7 +128,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
StatusCallback copy_ready = [response, done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
grpc::EncodeTensorToByteBuffer(is_dead, *copy, false, response);
done(s);
delete copy;
};
@ -136,7 +136,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
send_dev_context->CopyDeviceTensorToCPU(
&val, request->rendezvous_key(), src_dev, copy, copy_ready);
} else {
grpc::EncodeTensorToByteBuffer(is_dead, val, response);
grpc::EncodeTensorToByteBuffer(is_dead, val, false, response);
done(Status::OK());
}
}

View File

@ -165,8 +165,9 @@ cc_library(
srcs = ["grpc_response_cache.cc"],
hdrs = ["grpc_response_cache.h"],
deps = [
":grpc_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -66,6 +66,7 @@ class GrpcRemoteWorker : public WorkerInterface {
completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
logger_(logger) {}
~GrpcRemoteWorker() override {}
@ -130,12 +131,10 @@ class GrpcRemoteWorker : public WorkerInterface {
int64 start_usec = Env::Default()->NowMicros();
// Type-specialized logging for this method.
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
StatusCallback wrapper_done;
const StatusCallback* cb_to_use;
if (!logging_active) {
cb_to_use = &done; // No additional work to do, so just use done directly
} else {
wrapper_done = [this, request, response, done, start_usec](Status s) {
auto callback = [this, request, response, done, start_usec,
logging_active](Status s) {
if (logging_active) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
@ -159,12 +158,17 @@ class GrpcRemoteWorker : public WorkerInterface {
}
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->DebugString();
done(s);
};
cb_to_use = &wrapper_done;
}
}
IssueRequest(request, response, recvbuf_, *cb_to_use, call_opts);
// Note done() can delete this worker object, so we need to call done()
// last.
if (response->require_ack()) {
IssueMarkRecvFinishedRequest(request->request_id());
}
done(s);
};
IssueRequest(request, response, recvbuf_, callback, call_opts);
}
void CompleteGroupAsync(CallOptions* call_opts,
@ -194,12 +198,10 @@ class GrpcRemoteWorker : public WorkerInterface {
int64 start_usec = Env::Default()->NowMicros();
// Type-specialized logging for this method.
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
StatusCallback wrapper_done;
const StatusCallback* cb_to_use;
if (!logging_active) {
cb_to_use = &done; // No additional work to do, so just use done directly
} else {
wrapper_done = [this, request, response, done, start_usec](Status s) {
auto callback = [this, request, response, done, start_usec,
logging_active](Status s) {
if (logging_active) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
@ -238,12 +240,17 @@ class GrpcRemoteWorker : public WorkerInterface {
}
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->metadata().DebugString();
done(s);
};
cb_to_use = &wrapper_done;
}
}
IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
// Note done() can delete this worker object, so we need to call done()
// last.
if (response->metadata().require_ack()) {
IssueMarkRecvFinishedRequest(request->request_id());
}
done(s);
};
IssueRequest(request, response, recvtensor_, callback, call_opts);
}
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
@ -276,6 +283,16 @@ class GrpcRemoteWorker : public WorkerInterface {
callback_threadpool_, max_retries);
}
// Sends a MarkRecvFinishedRequest carrying `request_id` through
// IssueRequest(). The response object is heap-allocated here and freed by the
// completion callback, which ignores the status it is given.
void IssueMarkRecvFinishedRequest(int64 request_id) {
  VLOG(2) << "Send MarkRecvFinishedRequest for request " << request_id;

  // NOTE(review): `ack` lives on this stack frame; this presumes
  // IssueRequest() is done reading it before returning -- confirm.
  MarkRecvFinishedRequest ack;
  ack.set_request_id(request_id);

  auto* ack_response = new MarkRecvFinishedResponse();
  auto release_response = [ack_response](Status /*status*/) {
    delete ack_response;
  };
  IssueRequest(&ack, ack_response, markrecvfinished_, release_response);
}
// Helper function for initializing the RpcMethod objects below.
const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
@ -299,6 +316,7 @@ class GrpcRemoteWorker : public WorkerInterface {
const ::grpc::string completegroup_;
const ::grpc::string instancesource_;
const ::grpc::string getstepsequence_;
const ::grpc::string markrecvfinished_;
// Support for logging.
WorkerCacheLogger* logger_;

View File

@ -14,170 +14,102 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
struct WorkerCacheEntry {
enum class State {
PENDING = 0,
ACTIVE = 1,
FINISHED = 2,
};
bool GrpcResponseCache::QueueRequest(int64 request_id, int64 step_id,
const FinishResponseCB& cb) {
VLOG(1) << "GrpcResponseCache Lookup " << request_id;
State state = State::PENDING;
int64 expires_seconds;
mu_.lock();
::grpc::ByteBuffer response_buf;
Status response_status;
ResponseCacheEntry& entry = response_cache_[request_id];
// Additional retries may arrive while a request is still executing. The
// callbacks for these calls are queued in `callbacks` and evaluated after
// the original request is completed.
std::vector<std::pair<RPCResponse, StatusCallback>> callbacks;
};
if (entry.state == ResponseCacheEntry::State::FINISHED) {
VLOG(1) << "Reuse cached response for " << request_id;
// Make a copy of the ResponseCacheEntry so that we can run FinishResponse
// outside the critical section. FinishResponse can be potentially
// expensive.
auto entry_copy = entry;
void RPCResponse::Encode(::grpc::ByteBuffer* tgt) const {
if (buf_ != nullptr) {
*tgt = *buf_;
mu_.unlock();
entry_copy.FinishResponse(cb);
return true;
}
entry.callbacks.emplace_back(cb);
if (entry.state == ResponseCacheEntry::State::ACTIVE) {
VLOG(1) << "Found active request for " << request_id
<< ". Adding entry to response queue.";
mu_.unlock();
return true;
} else {
CHECK(msg_ != nullptr);
::grpc::Slice slice(msg_->ByteSizeLong());
msg_->SerializeWithCachedSizesToArray(
const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
::grpc::ByteBuffer tmp(&slice, 1);
tgt->Swap(&tmp);
VLOG(2) << "No cache entry for " << request_id
<< ", running user computation.";
entry.step_id = step_id;
entry.state = ResponseCacheEntry::State::ACTIVE;
mu_.unlock();
return false;
}
}
void RPCResponse::CopyFrom(const ::grpc::ByteBuffer& src) {
if (buf_ != nullptr) {
*buf_ = src;
return;
}
void GrpcResponseCache::OnRequestFinished(int64 request_id,
const Tensor& tensor, bool is_dead,
const Status& status) {
absl::optional<ResponseCacheEntry> entry_copy;
CHECK(msg_ != nullptr);
// We create a single slice when encoding protocol messages.
std::vector<::grpc::Slice> slices;
if (src.Dump(&slices).ok()) {
msg_->ParseFromArray(slices[0].begin(), slices[0].size());
} else {
LOG(ERROR) << "Failed to decode cached buffer.";
}
}
void GrpcResponseCache::LookupOrCompute(const string& key, RPCResponse response,
ComputeFunc compute_func,
StatusCallback done_cb) {
VLOG(1) << "Lookup " << key;
std::shared_ptr<WorkerCacheEntry> req;
MaybeCleanup();
{
mutex_lock m(mu_);
if (requests_.find(key) != requests_.end()) {
req = requests_[key];
} else {
req.reset(new WorkerCacheEntry);
requests_[key] = req;
}
if (req->state == WorkerCacheEntry::State::FINISHED) {
if (req->expires_seconds > Env::Default()->NowSeconds()) {
VLOG(1) << "Reuse cached response for " << key;
response.CopyFrom(req->response_buf);
done_cb(req->response_status);
return;
}
VLOG(1) << "Found expired cache entry for " << key;
req->state = WorkerCacheEntry::State::PENDING;
req->response_buf.Clear();
}
req->callbacks.push_back(std::make_pair(response, done_cb));
if (req->state == WorkerCacheEntry::State::ACTIVE) {
VLOG(1) << "Found active request for " << key
<< ". Adding entry to response queue.";
auto it = response_cache_.find(request_id);
if (it == response_cache_.end()) {
LOG(ERROR) << "Unexpected missing response cache entry for request "
<< request_id;
return;
}
ResponseCacheEntry& entry = it->second;
VLOG(2) << "No cache entry for " << key << ", running user computation.";
req->state = WorkerCacheEntry::State::ACTIVE;
req->expires_seconds = Env::Default()->NowSeconds() + expire_time_seconds_;
VLOG(1) << "Operation for " << request_id << " finished. "
<< "Status: " << status << ", tensor size " << tensor.TotalBytes()
<< " bytes, " << entry.callbacks.size() << " pending callbacks.";
entry.tensor = tensor;
entry.is_dead = is_dead;
entry.response_status = status;
entry.state = ResponseCacheEntry::State::FINISHED;
// We copy the extra work out of the critical section in order to avoid
// serializing the work for sending response.
entry_copy = entry;
entry.callbacks.clear();
}
compute_func([this, key, req, response](Status status) {
mutex_lock m(mu_);
response.Encode(&req->response_buf);
current_bytes_ += req->response_buf.Length();
req->response_status = status;
req->state = WorkerCacheEntry::State::FINISHED;
VLOG(1) << "Operation for " << key << " finished. "
<< "Status: " << status << ", " << req->response_buf.Length()
<< " response bytes, " << req->callbacks.size()
<< " pending callbacks.";
for (auto& cb : req->callbacks) {
cb.first.CopyFrom(req->response_buf);
cb.second(req->response_status);
}
req->callbacks.clear();
});
for (auto& cb : entry_copy->callbacks) {
entry_copy->FinishResponse(cb);
}
}
// Remove all stale or expired cache entries if the cache is full.
void GrpcResponseCache::MaybeCleanup() {
void GrpcResponseCache::EraseRequestId(int64 request_id) {
mutex_lock m(mu_);
if (current_bytes_ < max_bytes_) {
return;
}
response_cache_.erase(request_id);
}
VLOG(1) << "Cleanup: " << current_bytes_ << " -> " << max_bytes_;
std::vector<std::pair<string, std::shared_ptr<WorkerCacheEntry>>>
ordered_entries;
ordered_entries.reserve(requests_.size());
for (const auto& p : requests_) {
ordered_entries.push_back(std::make_pair(p.first, p.second));
}
std::sort(ordered_entries.begin(), ordered_entries.end(),
[](const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& a,
const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& b) {
return a.second->expires_seconds > b.second->expires_seconds;
});
std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> kept;
int64 now = Env::Default()->NowSeconds();
int64 bytes_used = 0;
// Always keep active requests.
for (auto& pair : ordered_entries) {
if (pair.second->state != WorkerCacheEntry::State::FINISHED) {
kept.insert(pair);
void GrpcResponseCache::CleanEntriesForStep(int64 step_id) {
mutex_lock m(mu_);
// Remove all cache entries whose step id is the given step_id
for (auto it = response_cache_.begin(), last = response_cache_.end();
it != last;) {
if (it->second.step_id == step_id) {
VLOG(1) << "Erase stale GrpcResponseCache entry " << it->first;
it = response_cache_.erase(it);
} else {
++it;
}
}
// Keep unexpired, finished requests up to half of max_bytes_. This reduces
// chances of overfilling the cache when active requests complete and
// amortizes cache cleanup cost.
for (auto& pair : ordered_entries) {
if (pair.second->expires_seconds < now || bytes_used >= max_bytes_ / 2) {
break;
}
if (pair.second->state == WorkerCacheEntry::State::FINISHED) {
kept.insert(pair);
bytes_used += pair.second->response_buf.Length();
}
}
VLOG(1) << "Cleaned cache. Bytes used: " << current_bytes_ << " -> "
<< bytes_used << ". Cache size: " << requests_.size() << " -> "
<< kept.size();
current_bytes_ = bytes_used;
std::swap(requests_, kept);
}
} // namespace tensorflow

View File

@ -19,71 +19,74 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
// gRPC response caching. Most WorkerService methods cannot be retried directly
// as they will fail or deadlock. To enable retrying, we can instead cache
// responses for a short period of time and reply to duplicate requests from the
// cache.
// responses and reply to duplicate requests from the cache. The cache will be
// cleaned when the MarkRecvFinishedRequest is received from the receiver or the
// session step is completed.
namespace tensorflow {
// Union type to aid caching of either raw buffers (for RecvTensor RPCs) and
// protocol buffer messages (for all other RPCs).
class RPCResponse {
public:
explicit RPCResponse() : buf_(nullptr), msg_(nullptr) {}
explicit RPCResponse(::grpc::ByteBuffer* b) : buf_(b), msg_(nullptr) {}
explicit RPCResponse(protobuf::Message* m) : buf_(nullptr), msg_(m) {}
// Encode this response into the target buffer.
void Encode(::grpc::ByteBuffer* tgt) const;
// Copy from `src`: if this is a buffer, make a shallow copy.
// For protocol messages, parse the response from `src`.
void CopyFrom(const ::grpc::ByteBuffer& src);
private:
::grpc::ByteBuffer* buf_;
protobuf::Message* msg_;
};
typedef std::function<void(StatusCallback)> ComputeFunc;
struct WorkerCacheEntry;
// Track and cache the state of worker service RPCs. An RPC can be in 3 states:
//
// * PENDING: this is the first call of the RPC, and it will transition to
// * ACTIVE: another thread is active processing this RPC
// * FINISHED: the worker has finished processing the method
//
// The response from completed RPCs are LRU cached until either `max_bytes`
// bytes are in use by the cache or they expire (according to `expire_time`).
class GrpcResponseCache {
public:
GrpcResponseCache(int64 max_bytes, int64 expire_time_seconds)
: max_bytes_(max_bytes), expire_time_seconds_(expire_time_seconds) {}
using FinishResponseCB = std::function<void(
const Tensor& tensor, bool is_dead, const Status& status)>;
// Lookup the result for key.
// If it is finished, invoke `done_cb` immediately after filling `response`.
// If active, done_db will be invoked when the current call completes.
// Otherwise, invoke `compute_func` to fill the cache and invoke done_cb.
void LookupOrCompute(const string& key, RPCResponse response,
ComputeFunc compute_func, StatusCallback done_cb);
// Add the given request to the cache.
// If the request is already in the cache:
//   * if it is finished, invoke `cb` immediately;
//   * if it is active, `cb` will be invoked when the current call completes.
// In either case, return true.
// Otherwise, store the request and `cb` in the cache, and return false.
// Note that FinishResponseCB is assumed to be thread-safe.
bool QueueRequest(int64 request_id, int64 step_id,
const FinishResponseCB& cb);
// Remove all stale or expired cache entries if the cache is full.
void MaybeCleanup();
// Fill the response cache for the given request_id and respond to all
// pending request.
void OnRequestFinished(int64 request_id, const Tensor& tensor, bool is_dead,
const Status& status);
// Erase the cache entry with the given request_id
void EraseRequestId(int64 request_id);
// Erase cache entries with the given step_id
void CleanEntriesForStep(int64 step_id);
private:
int64 current_bytes_ GUARDED_BY(mu_) = 0;
const int64 max_bytes_;
const int64 expire_time_seconds_;
// Cached state and result of a single request; entries are stored in
// `response_cache_`, keyed by request id.
struct ResponseCacheEntry {
// Lifecycle of a cached request.
enum class State {
PENDING = 0,
ACTIVE = 1,
FINISHED = 2,
};
// A default-constructed entry starts out PENDING.
State state = State::PENDING;
// Step the request belongs to; -1 until one is recorded.
int64 step_id = -1;
// Result values that FinishResponse() hands to a callback.
Tensor tensor;
bool is_dead = false;
Status response_status;
// Invokes `cb` with the tensor, dead flag and status stored in this entry.
void FinishResponse(const FinishResponseCB& cb) const {
cb(tensor, is_dead, response_status);
}
// Callbacks queued on this entry, waiting for its result.
std::vector<FinishResponseCB> callbacks;
};
std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> requests_
GUARDED_BY(mu_);
mutex mu_;
// response_cache_ is expected to be small, as entries are cleared immediately
// on ack from the receiver.
gtl::FlatMap<int64, ResponseCacheEntry> response_cache_ GUARDED_BY(mu_);
};
} // namespace tensorflow

View File

@ -135,13 +135,14 @@ static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) {
#endif
}
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
::grpc::ByteBuffer* result) {
const int kLargeTensorBytes = 1024;
RecvTensorResponse response;
if (is_dead) {
response.set_is_dead(is_dead);
}
response.set_require_ack(require_ack);
response.set_send_start_micros(Env::Default()->NowMicros());
if (!DataTypeCanUseMemcpy(val.dtype())) {
// Straightforward but slow path for complicated kinds of tensor data

View File

@ -46,7 +46,7 @@ void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto,
// "val" holds the tensor value to be encoded.
//
// Discards original contents of *result.
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
::grpc::ByteBuffer* result);
} // namespace grpc

View File

@ -31,7 +31,7 @@ class GrpcTensorCodingTest : public ::testing::Test {
void Validate(const Tensor& t, bool is_dead) {
// Check by encoding to a ByteBuffer
::grpc::ByteBuffer buf;
grpc::EncodeTensorToByteBuffer(is_dead, t, &buf);
grpc::EncodeTensorToByteBuffer(is_dead, t, false, &buf);
// Make a string
std::vector<::grpc::Slice> slices;

View File

@ -146,6 +146,7 @@ class GrpcWorkerServiceThread {
SETUP_FOR_REQUEST(RecvBuf, 500, true);
SETUP_FOR_REQUEST(RunGraph, 100, true);
SETUP_FOR_REQUEST(CleanupGraph, 100, false);
SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
// TODO(ncteisen): Determine a better policy for enqueuing the
// appropriate number of each request type.
@ -221,6 +222,14 @@ class GrpcWorkerServiceThread {
ENQUEUE_REQUEST(GetStepSequence, true);
}
// Handles a MarkRecvFinished RPC: asks the worker to drop the cache entry for
// the request id named in the call, replies with OK, and then re-enqueues a
// MarkRecvFinished request slot.
void MarkRecvFinishedHandler(
    WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
  const auto acked_request_id = call->request.request_id();
  VLOG(1) << "Clean cache entry for request " << acked_request_id;
  worker_->RemoveCacheEntryForId(acked_request_id);
  call->SendResponse(::grpc::Status::OK);
  ENQUEUE_REQUEST(MarkRecvFinished, false);
}
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
@ -229,32 +238,19 @@ class GrpcWorkerServiceThread {
NonOwnedProtoRunGraphResponse* wrapped_response =
new NonOwnedProtoRunGraphResponse(&call->response);
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
auto done_cb = [call, call_opts, wrapped_request,
wrapped_response](const Status& s) {
VLOG(1) << "RunGraph::Done";
if (!s.ok()) {
VLOG(1) << "Bad response from RunGraph:" << s;
}
call->ClearCancelCallback();
delete call_opts;
delete wrapped_request;
delete wrapped_response;
call->SendResponse(ToGrpcStatus(s));
};
auto compute_fn = [this, call_opts, wrapped_request,
wrapped_response](StatusCallback done) {
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
done);
};
if (cache_) {
string request_key = call->request.ShortDebugString();
cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
compute_fn, done_cb);
} else {
compute_fn(done_cb);
}
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
[call, call_opts, wrapped_request,
wrapped_response](const Status& s) {
VLOG(1) << "RunGraph::Done";
if (!s.ok()) {
VLOG(1) << "Bad response from RunGraph:" << s;
}
call->ClearCancelCallback();
delete call_opts;
delete wrapped_request;
delete wrapped_response;
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(RunGraph, true);
}
@ -265,27 +261,16 @@ class GrpcWorkerServiceThread {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
auto done_cb = [call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(1) << "Bad response from RecvTensor:" << s;
}
call->SendResponse(ToGrpcStatus(s));
};
auto compute_fn = [this, &call_opts, &call](StatusCallback done) {
worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
done);
};
if (cache_) {
string request_key = call->request.ShortDebugString();
cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
compute_fn, done_cb);
} else {
compute_fn(done_cb);
}
worker_->GrpcRecvTensorAsync(
call_opts, &call->request, &call->response,
[call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(1) << "Bad response from RecvTensor:" << s;
}
call->SendResponse(ToGrpcStatus(s));
});
});
EnqueueRecvTensorRequestRaw();
}
@ -377,10 +362,10 @@ class GrpcWorkerService : public AsyncServiceInterface {
GrpcWorkerServiceOptions options)
: is_shutdown_(false) {
builder->RegisterService(&worker_service_);
if (options.response_cache_bytes > 0) {
cache_.reset(
new GrpcResponseCache(options.response_cache_bytes,
options.response_cache_expires_seconds));
// TODO(jingdong): it would be cleaner to move this option to GrpcWorker
// since the cache is maintained by GrpcWorker now.
if (options.cache_rpc_response) {
worker->EnableResponseCache();
}
for (int i = 0; i < options.num_serving_threads; i++) {
@ -437,6 +422,11 @@ GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
? config.experimental().recv_buf_max_chunk()
: (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {}
// Turns on response caching for this worker by installing a newly
// constructed GrpcResponseCache in `response_cache_`.
void GrpcWorker::EnableResponseCache() {
  VLOG(1) << "Enabling gRPC tensor response cache.";
  auto fresh_cache = absl::make_unique<GrpcResponseCache>();
  response_cache_ = std::move(fresh_cache);
}
// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for a response object, to avoid extra protocol buffer serialization
// overhead we generate our response directly into a ::grpc::ByteBuffer object
@ -444,14 +434,49 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
Status s = recent_request_ids_.TrackUnique(
request->request_id(), "RecvTensor (GrpcWorker)", *request);
if (!s.ok()) {
done(s);
auto do_response = [this, response, done](const Tensor& tensor, bool is_dead,
const Status& status) {
if (status.ok()) {
bool require_ack = (response_cache_ != nullptr);
grpc::EncodeTensorToByteBuffer(is_dead, tensor, require_ack, response);
}
done(status);
};
const int64 request_id = request->request_id();
const int64 step_id = request->step_id();
// If response cache is enabled and the response cache already contains the
// request, we delegate this retry request to the response cache. Otherwise,
// we add the request to the response cache and start the computation to
// retrieve the requested data.
if (response_cache_ &&
response_cache_->QueueRequest(request_id, step_id, do_response)) {
return;
}
auto rendezvous_done = [this, request_id, do_response](const Tensor& tensor,
bool is_dead,
const Status& status) {
if (response_cache_) {
// Data is ready. Process all pending requests in the response cache.
response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
} else {
do_response(tensor, is_dead, status);
}
};
auto fail = [&rendezvous_done](const Status& status) {
rendezvous_done(Tensor(), false, status);
};
Status s = recent_request_ids_.TrackUnique(
request_id, "RecvTensor (GrpcWorker)", *request);
if (!s.ok()) {
fail(s);
return;
}
const int64 step_id = request->step_id();
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
@ -461,7 +486,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
s = PrepareRecvTensor(parsed, &src_dev);
}
if (!s.ok()) {
done(s);
fail(s);
return;
}
@ -475,7 +500,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
[step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[opts, response, done, src_dev, request](
[opts, rendezvous_done, src_dev, request](
const Status& status, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& val,
const bool is_dead) {
@ -502,25 +527,21 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on an accelerator device. Uses the device_context to
// fill the copy on host.
StatusCallback copy_ready = [response, done, copy,
StatusCallback copy_ready = [rendezvous_done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
done(s);
rendezvous_done(*copy, is_dead, s);
delete copy;
};
send_dev_context->CopyDeviceTensorToCPU(
&val, request->rendezvous_key(), src_dev, copy, copy_ready);
} else {
grpc::EncodeTensorToByteBuffer(is_dead, val, response);
done(Status::OK());
return;
}
}
} else {
// !s.ok()
done(status);
}
rendezvous_done(val, is_dead, status);
});
}
@ -537,8 +558,9 @@ namespace {
// RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
// and remove this function.
void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
int64 num_bytes, RecvBufResponse* response) {
RecvBufResponse* response) {
RecvBufRespExtra extra;
int64 num_bytes = tensor->TotalBytes();
const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
while (num_bytes > 0) {
int64 bytes =
@ -553,20 +575,56 @@ void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) {
auto do_response = [this, response, done](const Tensor& tensor, bool is_dead,
const Status& status) {
if (status.ok()) {
SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
}
response->set_send_start_micros(env_->env->NowMicros());
response->set_require_ack(response_cache_ != nullptr);
done(status);
};
const int64 request_id = request->request_id();
const int64 step_id = request->step_id();
// If response cache is enabled and the response cache already contains the
// request, we delegate this retry request to the response cache. Otherwise,
// we add the request to the response cache and start the computation to
// retrieve the requested data.
if (response_cache_ &&
response_cache_->QueueRequest(request_id, step_id, do_response)) {
return;
}
auto rendezvous_done = [this, request_id, do_response](const Tensor& tensor,
const Status& status) {
if (response_cache_) {
// Data is ready. Process all pending requests in the response cache.
response_cache_->OnRequestFinished(request_id, tensor, false, status);
} else {
do_response(tensor, false, status);
}
};
auto fail = [&rendezvous_done](const Status& status) {
rendezvous_done(Tensor(), status);
};
// This is a generic, low performance implementation appropriate for grpc.
Status s = recent_request_ids_.TrackUnique(request->request_id(),
"RecvBuf (GrpcWorker)", *request);
Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
*request);
if (!s.ok()) {
done(s);
fail(s);
return;
}
CollectiveExecutor::Handle ce_handle(
env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
env_->collective_executor_mgr->FindOrCreate(step_id), true);
CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
rma->buf_rendezvous()->ConsumeBuf(
request->buf_rendezvous_key(),
[this, request, response, done](const Status& status,
BufRendezvous::Hook* hook) {
[this, request, rendezvous_done](const Status& status,
BufRendezvous::Hook* hook) {
Status s = status;
if (s.ok()) {
if (!DMAHelper::CanUseDMA(hook->prod_value)) {
@ -594,27 +652,17 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
hook->prod_value->shape());
hook->prod_ctx->CopyDeviceTensorToCPU(
hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
[this, num_bytes, response, done, hook,
cpu_tensor](const Status& s) {
if (s.ok()) {
SetTensorInRecvBufResp(recv_buf_max_chunk_, cpu_tensor,
num_bytes, response);
}
response->set_send_start_micros(env_->env->NowMicros());
done(s);
[hook, cpu_tensor, rendezvous_done](const Status& s) {
rendezvous_done(*cpu_tensor, s);
BufRendezvous::DoneWithHook(hook);
delete cpu_tensor;
});
return;
}
} else {
// Tensor is on CPU.
SetTensorInRecvBufResp(recv_buf_max_chunk_, hook->prod_value,
num_bytes, response);
}
}
response->set_send_start_micros(env_->env->NowMicros());
done(s);
rendezvous_done(*hook->prod_value, s);
BufRendezvous::DoneWithHook(hook);
});
}
@ -646,8 +694,25 @@ void GrpcWorker::LoggingAsync(const LoggingRequest* request,
done(Status::OK());
}
// Drops any response cache entries recorded for the request's step (when the
// cache is enabled) and then forwards the call to Worker::CleanupGraphAsync.
void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
                                   CleanupGraphResponse* response,
                                   StatusCallback done) {
  if (response_cache_ != nullptr) {
    // Entries for this step can still be present if a worker crashed before
    // acking a request; purge them now so they do not linger.
    const auto finished_step_id = request->step_id();
    response_cache_->CleanEntriesForStep(finished_step_id);
  }
  Worker::CleanupGraphAsync(request, response, done);
}
WorkerEnv* GrpcWorker::env() { return env_; }
// Erases the response cache entry for `request_id`. Does nothing when the
// response cache has not been enabled.
void GrpcWorker::RemoveCacheEntryForId(int64 request_id) {
  if (response_cache_ == nullptr) {
    return;
  }
  response_cache_->EraseRequestId(request_id);
}
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
const ConfigProto& config) {
return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace grpc {
class ByteBuffer;
@ -33,6 +34,7 @@ class AsyncServiceInterface;
class ConfigProto;
struct WorkerEnv;
struct WorkerSession;
class GrpcResponseCache;
class GrpcWorker : public Worker {
public:
@ -44,15 +46,24 @@ class GrpcWorker : public Worker {
::grpc::ByteBuffer* response,
StatusCallback done);
virtual void LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done);
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
StatusCallback done) override;
virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done);
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) override;
void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) override;
WorkerEnv* env();
void EnableResponseCache();
void RemoveCacheEntryForId(int64 request_id);
private:
std::unique_ptr<GrpcResponseCache> response_cache_;
const int32 recv_buf_max_chunk_;
};
@ -64,8 +75,14 @@ struct GrpcWorkerServiceOptions {
// default queue depth for a method.
std::unordered_map<int, int> queue_depth;
int num_serving_threads = 8;
int64 response_cache_bytes = 0;
int64 response_cache_expires_seconds = 0;
// Setting cache_rpc_response to true will enable sender-side caching of
// responses for RecvTensorAsync and RecvBufAsync to allow the receiver to
// retry requests. This is only necessary when the network fabric is
// experiencing a significant error rate. Without it we'll fail a step on a
// network error, while with it we'll be able to complete long steps (like
// complex initializations) in the face of some network errors during RecvTensor.
bool cache_rpc_response = false;
};
// Returns an implementation of WorkerService rpc service.

View File

@ -58,6 +58,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
return "/tensorflow.WorkerService/CompleteInstance";
case GrpcWorkerMethod::kGetStepSequence:
return "/tensorflow.WorkerService/GetStepSequence";
case GrpcWorkerMethod::kMarkRecvFinished:
return "/tensorflow.WorkerService/MarkRecvFinished";
}
// Shouldn't be reached.
LOG(FATAL) << "Invalid id: this line shouldn't be reached.";

View File

@ -85,10 +85,11 @@ enum class GrpcWorkerMethod {
kCompleteGroup,
kCompleteInstance,
kGetStepSequence,
kMarkRecvFinished,
};
static const int kGrpcNumWorkerMethods =
static_cast<int>(GrpcWorkerMethod::kGetStepSequence) + 1;
static_cast<int>(GrpcWorkerMethod::kMarkRecvFinished) + 1;
const char* GrpcWorkerMethodName(GrpcWorkerMethod id);

View File

@ -246,7 +246,7 @@ bool TensorResponse::ParseFast(Source* source) {
case RecvTensorResponse::kIsDeadFieldNumber: {
uint32 v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
meta_.set_is_dead((v != 0) ? true : false);
meta_.set_is_dead(v != 0);
break;
}
case RecvTensorResponse::kSendStartMicrosFieldNumber: {
@ -261,6 +261,12 @@ bool TensorResponse::ParseFast(Source* source) {
return false;
break;
}
case RecvTensorResponse::kRequireAckFieldNumber: {
uint32 v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
meta_.set_require_ack(v != 0);
break;
}
default: {
// Unknown tag: we can't handle it on the fast path.
return false;

View File

@ -362,8 +362,20 @@ message RecvTensorResponse {
// Optional additional information about how to receive the tensor,
// e.g. in the event that `RecvTensorRequest.dma_ok` was true.
google.protobuf.Any transport_options = 4;
// Whether the receiver should send a MarkRecvFinishedRequest to the sender
// to ack the message.
bool require_ack = 5;
}
// Message for managing the response cache maintained on the sender side.
// Currently only used by the gRPC worker service.
message MarkRecvFinishedRequest {
int64 request_id = 1;
}
message MarkRecvFinishedResponse {}
////////////////////////////////////////////////////////////////////////////////
//
// Logging method request/response messages
@ -490,6 +502,10 @@ message RecvBufResponse {
google.protobuf.Any transport_options = 4;
// Optional, for timeline.
int64 send_start_micros = 5;
// Whether the receiver should send a MarkRecvFinishedRequest to the sender
// to ack the message.
bool require_ack = 6;
}
////////////////////////////////////////////////////////////////////////////////