TF/gRPC: Use explicit request from receiver to clean RPC cache

This simplifies response cache management on the sender side and avoids the
need for a large response cache: instead of holding responses until they
expire or the cache fills up, the sender drops a cached entry as soon as the
receiver acknowledges it with an explicit MarkRecvFinished request.

PiperOrigin-RevId: 239638065
commit a6d233a581 (parent 90085135c2)
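In outline, the commit replaces time- and size-based cache expiry with an explicit acknowledgement: the sender caches one response per receiver-generated request_id, answers duplicate (retried) requests from that cache, and drops the entry when the receiver sends MarkRecvFinished. Below is a minimal, self-contained C++ sketch of that protocol with simplified stand-in types; it is illustrative only and not code from this commit (the real implementation is GrpcResponseCache, further down).

#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Sender-side cache entry: a simplified stand-in for ResponseCacheEntry.
struct Entry {
  enum class State { ACTIVE, FINISHED } state = State::ACTIVE;
  std::string payload;  // stands in for the cached tensor bytes
  std::vector<std::function<void(const std::string&)>> callbacks;
};

std::map<int64_t, Entry> cache;  // keyed by the receiver-chosen request_id

// Returns false on first delivery (caller must compute the response); returns
// true for retries, which are answered from the cache instead of re-running
// the computation.
bool QueueRequest(int64_t id, std::function<void(const std::string&)> cb) {
  auto [it, inserted] = cache.try_emplace(id);
  if (!inserted && it->second.state == Entry::State::FINISHED) {
    cb(it->second.payload);  // finished: reply immediately from the cache
    return true;
  }
  it->second.callbacks.push_back(std::move(cb));  // queue behind the original
  return !inserted;
}

// Computation done: cache the payload and flush every queued duplicate.
void OnFinished(int64_t id, const std::string& payload) {
  Entry& e = cache[id];
  e.state = Entry::State::FINISHED;
  e.payload = payload;
  for (auto& cb : e.callbacks) cb(payload);
  e.callbacks.clear();
}

int main() {
  auto reply = [](const std::string& p) { std::cout << "reply: " << p << "\n"; };
  QueueRequest(7, reply);         // original request: caller computes
  QueueRequest(7, reply);         // network retry: queued, no recompute
  OnFinished(7, "tensor-bytes");  // both queued replies go out
  QueueRequest(7, reply);         // late retry: served from the cache
  cache.erase(7);                 // MarkRecvFinished ack: entry is dropped
}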
--- a/tensorflow/contrib/gdr/gdr_worker.cc
+++ b/tensorflow/contrib/gdr/gdr_worker.cc
@@ -128,7 +128,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
           StatusCallback copy_ready = [response, done, copy,
                                        is_dead](const Status& s) {
             // The value is now ready to be returned on the wire.
-            grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
+            grpc::EncodeTensorToByteBuffer(is_dead, *copy, false, response);
             done(s);
             delete copy;
           };
@@ -136,7 +136,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
           send_dev_context->CopyDeviceTensorToCPU(
               &val, request->rendezvous_key(), src_dev, copy, copy_ready);
         } else {
-          grpc::EncodeTensorToByteBuffer(is_dead, val, response);
+          grpc::EncodeTensorToByteBuffer(is_dead, val, false, response);
           done(Status::OK());
         }
       }

--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -165,8 +165,9 @@ cc_library(
     srcs = ["grpc_response_cache.cc"],
     hdrs = ["grpc_response_cache.h"],
     deps = [
-        ":grpc_util",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "@com_google_absl//absl/types:optional",
     ],
 )

--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -66,6 +66,7 @@ class GrpcRemoteWorker : public WorkerInterface {
         completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
         instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
         getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
+        markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
         logger_(logger) {}

   ~GrpcRemoteWorker() override {}
@@ -130,12 +131,10 @@ class GrpcRemoteWorker : public WorkerInterface {
     int64 start_usec = Env::Default()->NowMicros();
     // Type-specialized logging for this method.
     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
-    StatusCallback wrapper_done;
-    const StatusCallback* cb_to_use;
-    if (!logging_active) {
-      cb_to_use = &done;  // No additional work to do, so just use done directly
-    } else {
-      wrapper_done = [this, request, response, done, start_usec](Status s) {
+
+    auto callback = [this, request, response, done, start_usec,
+                     logging_active](Status s) {
+      if (logging_active) {
         if (logger_->LoggingActive()) {
           int64 end_usec = Env::Default()->NowMicros();
           int64 step_id = request->step_id();
@@ -159,12 +158,17 @@ class GrpcRemoteWorker : public WorkerInterface {
         }
         VLOG(2) << "done callback, req: " << request->DebugString()
                 << " response " << response->DebugString();
-        done(s);
-      };
-      cb_to_use = &wrapper_done;
-    }
+      }

-    IssueRequest(request, response, recvbuf_, *cb_to_use, call_opts);
+      // Note done() can delete this worker object, so we need to call done()
+      // last.
+      if (response->require_ack()) {
+        IssueMarkRecvFinishedRequest(request->request_id());
+      }
+      done(s);
+    };
+
+    IssueRequest(request, response, recvbuf_, callback, call_opts);
   }

   void CompleteGroupAsync(CallOptions* call_opts,
@@ -194,12 +198,10 @@ class GrpcRemoteWorker : public WorkerInterface {
     int64 start_usec = Env::Default()->NowMicros();
     // Type-specialized logging for this method.
     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
-    StatusCallback wrapper_done;
-    const StatusCallback* cb_to_use;
-    if (!logging_active) {
-      cb_to_use = &done;  // No additional work to do, so just use done directly
-    } else {
-      wrapper_done = [this, request, response, done, start_usec](Status s) {
+
+    auto callback = [this, request, response, done, start_usec,
+                     logging_active](Status s) {
+      if (logging_active) {
         if (logger_->LoggingActive()) {
           int64 end_usec = Env::Default()->NowMicros();
           int64 step_id = request->step_id();
@@ -238,12 +240,17 @@ class GrpcRemoteWorker : public WorkerInterface {
         }
         VLOG(2) << "done callback, req: " << request->DebugString()
                 << " response " << response->metadata().DebugString();
-        done(s);
-      };
-      cb_to_use = &wrapper_done;
-    }
+      }

-    IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
+      // Note done() can delete this worker object, so we need to call done()
+      // last.
+      if (response->metadata().require_ack()) {
+        IssueMarkRecvFinishedRequest(request->request_id());
+      }
+      done(s);
+    };
+
+    IssueRequest(request, response, recvtensor_, callback, call_opts);
   }

   void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
@@ -276,6 +283,16 @@ class GrpcRemoteWorker : public WorkerInterface {
                         callback_threadpool_, max_retries);
   }

+  void IssueMarkRecvFinishedRequest(int64 request_id) {
+    VLOG(2) << "Send MarkRecvFinishedRequest for request " << request_id;
+    MarkRecvFinishedRequest request;
+    request.set_request_id(request_id);
+
+    MarkRecvFinishedResponse* response = new MarkRecvFinishedResponse();
+    auto done = [response](Status status) { delete response; };
+    IssueRequest(&request, response, markrecvfinished_, done);
+  }
+
   // Helper function for initializing the RpcMethod objects below.
   const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }

@@ -299,6 +316,7 @@ class GrpcRemoteWorker : public WorkerInterface {
   const ::grpc::string completegroup_;
   const ::grpc::string instancesource_;
   const ::grpc::string getstepsequence_;
+  const ::grpc::string markrecvfinished_;

   // Support for logging.
   WorkerCacheLogger* logger_;

--- a/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.cc
@@ -14,170 +14,102 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
+#include "absl/types/optional.h"
 #include "tensorflow/core/platform/env.h"

 namespace tensorflow {

-struct WorkerCacheEntry {
-  enum class State {
-    PENDING = 0,
-    ACTIVE = 1,
-    FINISHED = 2,
-  };
-
-  State state = State::PENDING;
-  int64 expires_seconds;
-
-  ::grpc::ByteBuffer response_buf;
-  Status response_status;
-
-  // Additional retries may arrive while a request is still executing. The
-  // callbacks for these calls are queued in `callbacks` and evaluated after
-  // the original request is completed.
-  std::vector<std::pair<RPCResponse, StatusCallback>> callbacks;
-};
-
-void RPCResponse::Encode(::grpc::ByteBuffer* tgt) const {
-  if (buf_ != nullptr) {
-    *tgt = *buf_;
-  } else {
-    CHECK(msg_ != nullptr);
-    ::grpc::Slice slice(msg_->ByteSizeLong());
-    msg_->SerializeWithCachedSizesToArray(
-        const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
-    ::grpc::ByteBuffer tmp(&slice, 1);
-    tgt->Swap(&tmp);
-  }
-}
-
-void RPCResponse::CopyFrom(const ::grpc::ByteBuffer& src) {
-  if (buf_ != nullptr) {
-    *buf_ = src;
-    return;
-  }
-
-  CHECK(msg_ != nullptr);
-  // We create a single slice when encoding protocol messages.
-  std::vector<::grpc::Slice> slices;
-  if (src.Dump(&slices).ok()) {
-    msg_->ParseFromArray(slices[0].begin(), slices[0].size());
-  } else {
-    LOG(ERROR) << "Failed to decode cached buffer.";
-  }
-}
-
-void GrpcResponseCache::LookupOrCompute(const string& key, RPCResponse response,
-                                        ComputeFunc compute_func,
-                                        StatusCallback done_cb) {
-  VLOG(1) << "Lookup " << key;
-  std::shared_ptr<WorkerCacheEntry> req;
-  MaybeCleanup();
-  {
-    mutex_lock m(mu_);
-
-    if (requests_.find(key) != requests_.end()) {
-      req = requests_[key];
-    } else {
-      req.reset(new WorkerCacheEntry);
-      requests_[key] = req;
-    }
-
-    if (req->state == WorkerCacheEntry::State::FINISHED) {
-      if (req->expires_seconds > Env::Default()->NowSeconds()) {
-        VLOG(1) << "Reuse cached response for " << key;
-        response.CopyFrom(req->response_buf);
-        done_cb(req->response_status);
-        return;
-      }
-      VLOG(1) << "Found expired cache entry for " << key;
-      req->state = WorkerCacheEntry::State::PENDING;
-      req->response_buf.Clear();
-    }
-
-    req->callbacks.push_back(std::make_pair(response, done_cb));
-
-    if (req->state == WorkerCacheEntry::State::ACTIVE) {
-      VLOG(1) << "Found active request for " << key
-              << ". Adding entry to response queue.";
-      return;
-    }
-
-    VLOG(2) << "No cache entry for " << key << ", running user computation.";
-    req->state = WorkerCacheEntry::State::ACTIVE;
-    req->expires_seconds = Env::Default()->NowSeconds() + expire_time_seconds_;
-  }
-
-  compute_func([this, key, req, response](Status status) {
-    mutex_lock m(mu_);
-    response.Encode(&req->response_buf);
-    current_bytes_ += req->response_buf.Length();
-
-    req->response_status = status;
-    req->state = WorkerCacheEntry::State::FINISHED;
-
-    VLOG(1) << "Operation for " << key << " finished. "
-            << "Status: " << status << ", " << req->response_buf.Length()
-            << " response bytes, " << req->callbacks.size()
-            << " pending callbacks.";
-    for (auto& cb : req->callbacks) {
-      cb.first.CopyFrom(req->response_buf);
-      cb.second(req->response_status);
-    }
-    req->callbacks.clear();
-  });
-}
-
-// Remove all stale or expired cache entries if the cache is full.
-void GrpcResponseCache::MaybeCleanup() {
-  mutex_lock m(mu_);
-  if (current_bytes_ < max_bytes_) {
-    return;
-  }
-
-  VLOG(1) << "Cleanup: " << current_bytes_ << " -> " << max_bytes_;
-  std::vector<std::pair<string, std::shared_ptr<WorkerCacheEntry>>>
-      ordered_entries;
-  ordered_entries.reserve(requests_.size());
-  for (const auto& p : requests_) {
-    ordered_entries.push_back(std::make_pair(p.first, p.second));
-  }
-
-  std::sort(ordered_entries.begin(), ordered_entries.end(),
-            [](const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& a,
-               const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& b) {
-              return a.second->expires_seconds > b.second->expires_seconds;
-            });
-
-  std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> kept;
-  int64 now = Env::Default()->NowSeconds();
-  int64 bytes_used = 0;
-
-  // Always keep active requests.
-  for (auto& pair : ordered_entries) {
-    if (pair.second->state != WorkerCacheEntry::State::FINISHED) {
-      kept.insert(pair);
-    }
-  }
-
-  // Keep unexpired, finished requests up to half of max_bytes_. This reduces
-  // chances of overfilling the cache when active requests complete and
-  // amortizes cache cleanup cost.
-  for (auto& pair : ordered_entries) {
-    if (pair.second->expires_seconds < now || bytes_used >= max_bytes_ / 2) {
-      break;
-    }
-
-    if (pair.second->state == WorkerCacheEntry::State::FINISHED) {
-      kept.insert(pair);
-      bytes_used += pair.second->response_buf.Length();
-    }
-  }
-
-  VLOG(1) << "Cleaned cache. Bytes used: " << current_bytes_ << " -> "
-          << bytes_used << ". Cache size: " << requests_.size() << " -> "
-          << kept.size();
-  current_bytes_ = bytes_used;
-  std::swap(requests_, kept);
-}
+bool GrpcResponseCache::QueueRequest(int64 request_id, int64 step_id,
+                                     const FinishResponseCB& cb) {
+  VLOG(1) << "GrpcResponseCache Lookup " << request_id;
+
+  mu_.lock();
+
+  ResponseCacheEntry& entry = response_cache_[request_id];
+
+  if (entry.state == ResponseCacheEntry::State::FINISHED) {
+    VLOG(1) << "Reuse cached response for " << request_id;
+    // Make a copy of the ResponseCacheEntry so that we can run FinishResponse
+    // outside the critical section. FinishResponse can be potentially
+    // expensive.
+    auto entry_copy = entry;
+
+    mu_.unlock();
+    entry_copy.FinishResponse(cb);
+    return true;
+  }
+
+  entry.callbacks.emplace_back(cb);
+
+  if (entry.state == ResponseCacheEntry::State::ACTIVE) {
+    VLOG(1) << "Found active request for " << request_id
+            << ". Adding entry to response queue.";
+    mu_.unlock();
+    return true;
+  } else {
+    VLOG(2) << "No cache entry for " << request_id
+            << ", running user computation.";
+    entry.step_id = step_id;
+    entry.state = ResponseCacheEntry::State::ACTIVE;
+    mu_.unlock();
+    return false;
+  }
+}
+
+void GrpcResponseCache::OnRequestFinished(int64 request_id,
+                                          const Tensor& tensor, bool is_dead,
+                                          const Status& status) {
+  absl::optional<ResponseCacheEntry> entry_copy;
+
+  {
+    mutex_lock m(mu_);
+
+    auto it = response_cache_.find(request_id);
+    if (it == response_cache_.end()) {
+      LOG(ERROR) << "Unexpected missing response cache entry for request "
+                 << request_id;
+      return;
+    }
+    ResponseCacheEntry& entry = it->second;
+
+    VLOG(1) << "Operation for " << request_id << " finished. "
+            << "Status: " << status << ", tensor size " << tensor.TotalBytes()
+            << " bytes, " << entry.callbacks.size() << " pending callbacks.";
+
+    entry.tensor = tensor;
+    entry.is_dead = is_dead;
+    entry.response_status = status;
+    entry.state = ResponseCacheEntry::State::FINISHED;
+
+    // We copy the extra work out of the critical section in order to avoid
+    // serializing the work for sending response.
+    entry_copy = entry;
+
+    entry.callbacks.clear();
+  }
+
+  for (auto& cb : entry_copy->callbacks) {
+    entry_copy->FinishResponse(cb);
+  }
+}
+
+void GrpcResponseCache::EraseRequestId(int64 request_id) {
+  mutex_lock m(mu_);
+  response_cache_.erase(request_id);
+}
+
+void GrpcResponseCache::CleanEntriesForStep(int64 step_id) {
+  mutex_lock m(mu_);
+  // Remove all cache entries whose step id is the given step_id
+  for (auto it = response_cache_.begin(), last = response_cache_.end();
+       it != last;) {
+    if (it->second.step_id == step_id) {
+      VLOG(1) << "Erase stale GrpcResponseCache entry " << it->first;
+      it = response_cache_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+}

 }  // namespace tensorflow

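A usage sketch of the new API, mirroring how GrpcWorker drives it in grpc_worker_service.cc below. It assumes a TensorFlow build; `Demo` and the literal ids are illustrative and not part of the commit.

#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"

namespace tensorflow {

void Demo() {
  GrpcResponseCache cache;
  auto reply = [](const Tensor& tensor, bool is_dead, const Status& status) {
    VLOG(1) << "sending response, status: " << status;
  };

  // Original request: QueueRequest returns false and the caller computes.
  if (!cache.QueueRequest(/*request_id=*/7, /*step_id=*/1, reply)) {
    // ... rendezvous completes; fan the result out to all queued callbacks:
    cache.OnRequestFinished(/*request_id=*/7, Tensor(), /*is_dead=*/false,
                            Status::OK());
  }

  // A duplicate of request 7 now returns true and is answered from the cache.
  cache.QueueRequest(7, 1, reply);

  cache.EraseRequestId(7);       // receiver acked via MarkRecvFinished
  cache.CleanEntriesForStep(1);  // or: the step finished without an ack
}

}  // namespace tensorflow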
--- a/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h
@@ -19,71 +19,74 @@ limitations under the License.
 #include <unordered_map>
 #include <vector>

-#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"

 // gRPC response caching. Most WorkerService methods cannot be retried directly
 // as they will fail or deadlock. To enable retrying, we can instead cache
-// responses for a short period of time and reply to duplicate requests from the
-// cache.
+// responses and reply to duplicate requests from the cache. The cache will be
+// cleaned when the MarkRecvFinishedRequest is received from the receiver or
+// the session step is completed.
 namespace tensorflow {

-// Union type to aid caching of either raw buffers (for RecvTensor RPCs) and
-// protocol buffer messages (for all other RPCs).
-class RPCResponse {
- public:
-  explicit RPCResponse() : buf_(nullptr), msg_(nullptr) {}
-  explicit RPCResponse(::grpc::ByteBuffer* b) : buf_(b), msg_(nullptr) {}
-  explicit RPCResponse(protobuf::Message* m) : buf_(nullptr), msg_(m) {}
-
-  // Encode this response into the target buffer.
-  void Encode(::grpc::ByteBuffer* tgt) const;
-
-  // Copy from `src`: if this is a buffer, make a shallow copy.
-  // For protocol messages, parse the response from `src`.
-  void CopyFrom(const ::grpc::ByteBuffer& src);
-
- private:
-  ::grpc::ByteBuffer* buf_;
-  protobuf::Message* msg_;
-};
-
-typedef std::function<void(StatusCallback)> ComputeFunc;
-struct WorkerCacheEntry;
-
 // Track and cache the state of worker service RPCs. An RPC can be in 3 states:
 //
 // * PENDING: this is the first call of the RPC, and it will transition to
 // * ACTIVE: another thread is active processing this RPC
 // * FINISHED: the worker has finished processing the method
-//
-// The response from completed RPCs are LRU cached until either `max_bytes`
-// bytes are in use by the cache or they expire (according to `expire_time`).
+
 class GrpcResponseCache {
  public:
-  GrpcResponseCache(int64 max_bytes, int64 expire_time_seconds)
-      : max_bytes_(max_bytes), expire_time_seconds_(expire_time_seconds) {}
+  using FinishResponseCB = std::function<void(
+      const Tensor& tensor, bool is_dead, const Status& status)>;

-  // Lookup the result for key.
-  // If it is finished, invoke `done_cb` immediately after filling `response`.
-  // If active, done_cb will be invoked when the current call completes.
-  // Otherwise, invoke `compute_func` to fill the cache and invoke done_cb.
-  void LookupOrCompute(const string& key, RPCResponse response,
-                       ComputeFunc compute_func, StatusCallback done_cb);
+  // Add the given request to the cache. If the request is already in the
+  // cache: if it is finished, invoke `cb` immediately; if active, `cb` will
+  // be invoked when the current call completes. In either case, return true.
+  // Otherwise, store the request and `cb` in the cache, and return false.
+  // Note FinishResponseCB is assumed to be thread-safe.
+  bool QueueRequest(int64 request_id, int64 step_id,
+                    const FinishResponseCB& cb);

-  // Remove all stale or expired cache entries if the cache is full.
-  void MaybeCleanup();
+  // Fill the response cache for the given request_id and respond to all
+  // pending requests.
+  void OnRequestFinished(int64 request_id, const Tensor& tensor, bool is_dead,
+                         const Status& status);
+
+  // Erase the cache entry with the given request_id.
+  void EraseRequestId(int64 request_id);
+
+  // Erase cache entries with the given step_id.
+  void CleanEntriesForStep(int64 step_id);

  private:
-  int64 current_bytes_ GUARDED_BY(mu_) = 0;
-  const int64 max_bytes_;
-  const int64 expire_time_seconds_;
+  struct ResponseCacheEntry {
+    enum class State {
+      PENDING = 0,
+      ACTIVE = 1,
+      FINISHED = 2,
+    };
+
+    State state = State::PENDING;
+    int64 step_id = -1;
+    Tensor tensor;
+    bool is_dead = false;
+    Status response_status;
+
+    void FinishResponse(const FinishResponseCB& cb) const {
+      cb(tensor, is_dead, response_status);
+    }
+    std::vector<FinishResponseCB> callbacks;
+  };

-  std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> requests_
-      GUARDED_BY(mu_);
   mutex mu_;
+  // response_cache_ is expected to be small, as entries are cleared
+  // immediately on ack from the receiver.
+  gtl::FlatMap<int64, ResponseCacheEntry> response_cache_ GUARDED_BY(mu_);
 };

 }  // namespace tensorflow

--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc
@@ -135,13 +135,14 @@ static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) {
 #endif
 }

-void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
+void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
                               ::grpc::ByteBuffer* result) {
   const int kLargeTensorBytes = 1024;
   RecvTensorResponse response;
   if (is_dead) {
     response.set_is_dead(is_dead);
   }
+  response.set_require_ack(require_ack);
   response.set_send_start_micros(Env::Default()->NowMicros());
   if (!DataTypeCanUseMemcpy(val.dtype())) {
     // Straightforward but slow path for complicated kinds of tensor data

--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h
@@ -46,7 +46,7 @@ void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto,
 // "val" holds the tensor value to be encoded.
 //
 // Discards original contents of *result.
-void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
+void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
                               ::grpc::ByteBuffer* result);

 }  // namespace grpc

--- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc
@@ -31,7 +31,7 @@ class GrpcTensorCodingTest : public ::testing::Test {
   void Validate(const Tensor& t, bool is_dead) {
     // Check by encoding to a ByteBuffer
     ::grpc::ByteBuffer buf;
-    grpc::EncodeTensorToByteBuffer(is_dead, t, &buf);
+    grpc::EncodeTensorToByteBuffer(is_dead, t, false, &buf);

     // Make a string
     std::vector<::grpc::Slice> slices;

--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -146,6 +146,7 @@ class GrpcWorkerServiceThread {
     SETUP_FOR_REQUEST(RecvBuf, 500, true);
     SETUP_FOR_REQUEST(RunGraph, 100, true);
     SETUP_FOR_REQUEST(CleanupGraph, 100, false);
+    SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

     // TODO(ncteisen): Determine a better policy for enqueuing the
     // appropriate number of each request type.
@@ -221,6 +222,14 @@ class GrpcWorkerServiceThread {
     ENQUEUE_REQUEST(GetStepSequence, true);
   }

+  void MarkRecvFinishedHandler(
+      WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
+    VLOG(1) << "Clean cache entry for request " << call->request.request_id();
+    worker_->RemoveCacheEntryForId(call->request.request_id());
+    call->SendResponse(::grpc::Status::OK);
+    ENQUEUE_REQUEST(MarkRecvFinished, false);
+  }
+
   void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
     Schedule([this, call]() {
       CallOptions* call_opts = new CallOptions;
@@ -229,7 +238,8 @@ class GrpcWorkerServiceThread {
       NonOwnedProtoRunGraphResponse* wrapped_response =
           new NonOwnedProtoRunGraphResponse(&call->response);
       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
-      auto done_cb = [call, call_opts, wrapped_request,
+      worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
+                             [call, call_opts, wrapped_request,
                               wrapped_response](const Status& s) {
         VLOG(1) << "RunGraph::Done";
         if (!s.ok()) {
@@ -240,21 +250,7 @@ class GrpcWorkerServiceThread {
         delete wrapped_request;
         delete wrapped_response;
         call->SendResponse(ToGrpcStatus(s));
-      };
-
-      auto compute_fn = [this, call_opts, wrapped_request,
-                         wrapped_response](StatusCallback done) {
-        worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
-                               done);
-      };
-
-      if (cache_) {
-        string request_key = call->request.ShortDebugString();
-        cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
-                                compute_fn, done_cb);
-      } else {
-        compute_fn(done_cb);
-      }
+      });
     });
     ENQUEUE_REQUEST(RunGraph, true);
   }
@@ -265,27 +261,16 @@ class GrpcWorkerServiceThread {
       CallOptions* call_opts = new CallOptions;
       call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });

-      auto done_cb = [call, call_opts](const Status& s) {
+      worker_->GrpcRecvTensorAsync(
+          call_opts, &call->request, &call->response,
+          [call, call_opts](const Status& s) {
             call->ClearCancelCallback();
             delete call_opts;
             if (!s.ok()) {
               VLOG(1) << "Bad response from RecvTensor:" << s;
             }
             call->SendResponse(ToGrpcStatus(s));
-      };
-
-      auto compute_fn = [this, &call_opts, &call](StatusCallback done) {
-        worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
-                                     done);
-      };
-
-      if (cache_) {
-        string request_key = call->request.ShortDebugString();
-        cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
-                                compute_fn, done_cb);
-      } else {
-        compute_fn(done_cb);
-      }
+          });
     });
     EnqueueRecvTensorRequestRaw();
   }
@@ -377,10 +362,10 @@ class GrpcWorkerService : public AsyncServiceInterface {
                     GrpcWorkerServiceOptions options)
       : is_shutdown_(false) {
     builder->RegisterService(&worker_service_);
-    if (options.response_cache_bytes > 0) {
-      cache_.reset(
-          new GrpcResponseCache(options.response_cache_bytes,
-                                options.response_cache_expires_seconds));
+    // TODO(jingdong): it would be cleaner to move this option to GrpcWorker
+    // since the cache is maintained by GrpcWorker now.
+    if (options.cache_rpc_response) {
+      worker->EnableResponseCache();
     }

     for (int i = 0; i < options.num_serving_threads; i++) {
@@ -437,6 +422,11 @@ GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
           ? config.experimental().recv_buf_max_chunk()
           : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {}

+void GrpcWorker::EnableResponseCache() {
+  VLOG(1) << "Enabling gRPC tensor response cache.";
+  response_cache_ = absl::make_unique<GrpcResponseCache>();
+}
+
 // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
 // buffers for a response object, to avoid extra protocol buffer serialization
 // overhead we generate our response directly into a ::grpc::ByteBuffer object
@@ -444,14 +434,49 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                      const RecvTensorRequest* request,
                                      ::grpc::ByteBuffer* response,
                                      StatusCallback done) {
-  Status s = recent_request_ids_.TrackUnique(
-      request->request_id(), "RecvTensor (GrpcWorker)", *request);
-  if (!s.ok()) {
-    done(s);
+  auto do_response = [this, response, done](const Tensor& tensor, bool is_dead,
+                                            const Status& status) {
+    if (status.ok()) {
+      bool require_ack = (response_cache_ != nullptr);
+      grpc::EncodeTensorToByteBuffer(is_dead, tensor, require_ack, response);
+    }
+    done(status);
+  };
+
+  const int64 request_id = request->request_id();
+  const int64 step_id = request->step_id();
+
+  // If response cache is enabled and the response cache already contains the
+  // request, we delegate this retry request to the response cache. Otherwise,
+  // we add the request to the response cache and start the computation to
+  // retrieve the requested data.
+  if (response_cache_ &&
+      response_cache_->QueueRequest(request_id, step_id, do_response)) {
+    return;
+  }
+
+  auto rendezvous_done = [this, request_id, do_response](const Tensor& tensor,
+                                                         bool is_dead,
+                                                         const Status& status) {
+    if (response_cache_) {
+      // Data is ready. Process all pending requests in the response cache.
+      response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
+    } else {
+      do_response(tensor, is_dead, status);
+    }
+  };
+
+  auto fail = [&rendezvous_done](const Status& status) {
+    rendezvous_done(Tensor(), false, status);
+  };
+
+  Status s = recent_request_ids_.TrackUnique(
+      request_id, "RecvTensor (GrpcWorker)", *request);
+  if (!s.ok()) {
+    fail(s);
     return;
   }

-  const int64 step_id = request->step_id();
   const string& key = request->rendezvous_key();
   TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
   Rendezvous::ParsedKey parsed;
@@ -461,7 +486,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
     s = PrepareRecvTensor(parsed, &src_dev);
   }
   if (!s.ok()) {
-    done(s);
+    fail(s);
     return;
   }

@@ -475,7 +500,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
       [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
   env_->rendezvous_mgr->RecvLocalAsync(
       step_id, parsed,
-      [opts, response, done, src_dev, request](
+      [opts, rendezvous_done, src_dev, request](
           const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& val,
           const bool is_dead) {
@@ -502,25 +527,21 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                     << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
               // "val" is on an accelerator device. Uses the device_context to
               // fill the copy on host.
-              StatusCallback copy_ready = [response, done, copy,
+              StatusCallback copy_ready = [rendezvous_done, copy,
                                            is_dead](const Status& s) {
                 // The value is now ready to be returned on the wire.
-                grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
-                done(s);
+                rendezvous_done(*copy, is_dead, s);
                 delete copy;
               };

               send_dev_context->CopyDeviceTensorToCPU(
                   &val, request->rendezvous_key(), src_dev, copy, copy_ready);
-            } else {
-              grpc::EncodeTensorToByteBuffer(is_dead, val, response);
-              done(Status::OK());
+              return;
             }
           }
-        } else {
-          // !s.ok()
-          done(status);
-        }
+
+        rendezvous_done(val, is_dead, status);
       });
 }
@@ -537,8 +558,9 @@ namespace {
 // RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
 // and remove this function.
 void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
-                            int64 num_bytes, RecvBufResponse* response) {
+                            RecvBufResponse* response) {
   RecvBufRespExtra extra;
+  int64 num_bytes = tensor->TotalBytes();
   const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
   while (num_bytes > 0) {
     int64 bytes =
@@ -553,19 +575,55 @@ void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,

 void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                               RecvBufResponse* response, StatusCallback done) {
+  auto do_response = [this, response, done](const Tensor& tensor, bool is_dead,
+                                            const Status& status) {
+    if (status.ok()) {
+      SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
+    }
+    response->set_send_start_micros(env_->env->NowMicros());
+    response->set_require_ack(response_cache_ != nullptr);
+    done(status);
+  };
+
+  const int64 request_id = request->request_id();
+  const int64 step_id = request->step_id();
+
+  // If response cache is enabled and the response cache already contains the
+  // request, we delegate this retry request to the response cache. Otherwise,
+  // we add the request to the response cache and start the computation to
+  // retrieve the requested data.
+  if (response_cache_ &&
+      response_cache_->QueueRequest(request_id, step_id, do_response)) {
+    return;
+  }
+
+  auto rendezvous_done = [this, request_id, do_response](const Tensor& tensor,
+                                                         const Status& status) {
+    if (response_cache_) {
+      // Data is ready. Process all pending requests in the response cache.
+      response_cache_->OnRequestFinished(request_id, tensor, false, status);
+    } else {
+      do_response(tensor, false, status);
+    }
+  };
+
+  auto fail = [&rendezvous_done](const Status& status) {
+    rendezvous_done(Tensor(), status);
+  };
+
   // This is a generic, low performance implementation appropriate for grpc.
-  Status s = recent_request_ids_.TrackUnique(request->request_id(),
-                                             "RecvBuf (GrpcWorker)", *request);
+  Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
+                                             *request);
   if (!s.ok()) {
-    done(s);
+    fail(s);
     return;
   }
   CollectiveExecutor::Handle ce_handle(
-      env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
+      env_->collective_executor_mgr->FindOrCreate(step_id), true);
   CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
   rma->buf_rendezvous()->ConsumeBuf(
       request->buf_rendezvous_key(),
-      [this, request, response, done](const Status& status,
-                                      BufRendezvous::Hook* hook) {
+      [this, request, rendezvous_done](const Status& status,
+                                       BufRendezvous::Hook* hook) {
         Status s = status;
         if (s.ok()) {
@@ -594,27 +652,17 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                                  hook->prod_value->shape());
             hook->prod_ctx->CopyDeviceTensorToCPU(
                 hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
-                [this, num_bytes, response, done, hook,
-                 cpu_tensor](const Status& s) {
-                  if (s.ok()) {
-                    SetTensorInRecvBufResp(recv_buf_max_chunk_, cpu_tensor,
-                                           num_bytes, response);
-                  }
-                  response->set_send_start_micros(env_->env->NowMicros());
-                  done(s);
+                [hook, cpu_tensor, rendezvous_done](const Status& s) {
+                  rendezvous_done(*cpu_tensor, s);
                   BufRendezvous::DoneWithHook(hook);
                   delete cpu_tensor;
                 });
             return;
-          } else {
-            // Tensor is on CPU.
-            SetTensorInRecvBufResp(recv_buf_max_chunk_, hook->prod_value,
-                                   num_bytes, response);
           }
         }
-        response->set_send_start_micros(env_->env->NowMicros());
-        done(s);
+
+        rendezvous_done(*hook->prod_value, s);
         BufRendezvous::DoneWithHook(hook);
       });
 }
@@ -646,8 +694,25 @@ void GrpcWorker::LoggingAsync(const LoggingRequest* request,
   done(Status::OK());
 }

+void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
+                                   CleanupGraphResponse* response,
+                                   StatusCallback done) {
+  if (response_cache_) {
+    // Cleanup any stale response cache entries for this step. This can occur
+    // if a worker crashes before acking a request.
+    response_cache_->CleanEntriesForStep(request->step_id());
+  }
+  Worker::CleanupGraphAsync(request, response, done);
+}
+
 WorkerEnv* GrpcWorker::env() { return env_; }

+void GrpcWorker::RemoveCacheEntryForId(int64 request_id) {
+  if (response_cache_) {
+    response_cache_->EraseRequestId(request_id);
+  }
+}
+
 std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                           const ConfigProto& config) {
   return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));

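A standalone sketch of the do_response / rendezvous_done / fail pattern used above: every exit path (early failure, cached retry, rendezvous completion) funnels into one callback, so the cache sees each outcome exactly once. Simplified types, not code from this commit.

#include <functional>
#include <iostream>
#include <string>

using DoResponse = std::function<void(const std::string&, bool /*ok*/)>;

void HandleRequest(bool cache_enabled, const DoResponse& do_response) {
  auto rendezvous_done = [&](const std::string& data, bool ok) {
    if (cache_enabled) {
      // Real code calls response_cache_->OnRequestFinished(...) here, which
      // invokes do_response for the original request and for every duplicate
      // queued while the computation ran. Modeled as a single call:
      std::cout << "cache fan-out\n";
    }
    do_response(data, ok);
  };
  auto fail = [&](const std::string& why) { rendezvous_done(why, false); };

  bool request_id_is_duplicate = false;  // stands in for recent_request_ids_
  if (request_id_is_duplicate) return fail("duplicate request id");
  rendezvous_done("tensor-bytes", true);  // success path
}

int main() {
  HandleRequest(true, [](const std::string& d, bool ok) {
    std::cout << (ok ? "ok: " : "error: ") << d << "\n";
  });
}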
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
 #include "tensorflow/core/distributed_runtime/worker.h"
+#include "tensorflow/core/protobuf/worker.pb.h"

 namespace grpc {
 class ByteBuffer;
@@ -33,6 +34,7 @@ class AsyncServiceInterface;
 class ConfigProto;
 struct WorkerEnv;
 struct WorkerSession;
+class GrpcResponseCache;

 class GrpcWorker : public Worker {
  public:
@@ -44,15 +46,24 @@ class GrpcWorker : public Worker {
                                    ::grpc::ByteBuffer* response,
                                    StatusCallback done);

-  virtual void LoggingAsync(const LoggingRequest* request,
-                            LoggingResponse* response, StatusCallback done);
+  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
+                    StatusCallback done) override;

-  virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
-                            RecvBufResponse* response, StatusCallback done);
+  void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
+                    RecvBufResponse* response, StatusCallback done) override;
+
+  void CleanupGraphAsync(const CleanupGraphRequest* request,
+                         CleanupGraphResponse* response,
+                         StatusCallback done) override;

   WorkerEnv* env();

+  void EnableResponseCache();
+
+  void RemoveCacheEntryForId(int64 request_id);
+
  private:
+  std::unique_ptr<GrpcResponseCache> response_cache_;
   const int32 recv_buf_max_chunk_;
 };

@@ -64,8 +75,14 @@ struct GrpcWorkerServiceOptions {
   // default queue depth for a method.
   std::unordered_map<int, int> queue_depth;
   int num_serving_threads = 8;
-  int64 response_cache_bytes = 0;
-  int64 response_cache_expires_seconds = 0;
+
+  // Setting cache_rpc_response to true will enable sender-side caching of
+  // responses for RecvTensorAsync and RecvBufAsync so the receiver can retry
+  // requests. This is only necessary when the network fabric is experiencing
+  // a significant error rate. Without it we'll fail a step on a network
+  // error, while with it we'll be able to complete long steps (like complex
+  // initializations) in the face of some network errors during RecvTensor.
+  bool cache_rpc_response = false;
 };

 // Returns an implementation of WorkerService rpc service.

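Enabling the cache from the service options, per the constructor change above. A sketch only: `ConfigureWorker` is a hypothetical helper, and the server-builder plumbing around it is assumed.

#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"

namespace tensorflow {

void ConfigureWorker(GrpcWorker* worker) {
  GrpcWorkerServiceOptions options;
  options.cache_rpc_response = true;  // replaces response_cache_bytes/expires

  // GrpcWorkerService's constructor now reacts to the flag by calling:
  if (options.cache_rpc_response) {
    worker->EnableResponseCache();
  }
}

}  // namespace tensorflow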
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
@@ -58,6 +58,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
       return "/tensorflow.WorkerService/CompleteInstance";
     case GrpcWorkerMethod::kGetStepSequence:
       return "/tensorflow.WorkerService/GetStepSequence";
+    case GrpcWorkerMethod::kMarkRecvFinished:
+      return "/tensorflow.WorkerService/MarkRecvFinished";
   }
   // Shouldn't be reached.
   LOG(FATAL) << "Invalid id: this line shouldn't be reached.";

--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
@@ -85,10 +85,11 @@ enum class GrpcWorkerMethod {
   kCompleteGroup,
   kCompleteInstance,
   kGetStepSequence,
+  kMarkRecvFinished,
 };

 static const int kGrpcNumWorkerMethods =
-    static_cast<int>(GrpcWorkerMethod::kGetStepSequence) + 1;
+    static_cast<int>(GrpcWorkerMethod::kMarkRecvFinished) + 1;

 const char* GrpcWorkerMethodName(GrpcWorkerMethod id);

--- a/tensorflow/core/distributed_runtime/tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/tensor_coding.cc
@@ -246,7 +246,7 @@ bool TensorResponse::ParseFast(Source* source) {
       case RecvTensorResponse::kIsDeadFieldNumber: {
         uint32 v;
         if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
-        meta_.set_is_dead((v != 0) ? true : false);
+        meta_.set_is_dead(v != 0);
         break;
       }
       case RecvTensorResponse::kSendStartMicrosFieldNumber: {
@@ -261,6 +261,12 @@ bool TensorResponse::ParseFast(Source* source) {
           return false;
         break;
       }
+      case RecvTensorResponse::kRequireAckFieldNumber: {
+        uint32 v;
+        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
+        meta_.set_require_ack(v != 0);
+        break;
+      }
       default: {
         // Unknown tag, so don't handle we can't handle on the fast path
         return false;

--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -362,8 +362,20 @@ message RecvTensorResponse {
   // Optional additional information about how to receive the tensor,
   // e.g. in the event that `RecvTensorRequest.dma_ok` was true.
   google.protobuf.Any transport_options = 4;
+
+  // Whether the receiver should send a MarkRecvFinishedRequest to the sender
+  // to ack the message.
+  bool require_ack = 5;
 }

+// Message for managing the response cache maintained on the sender side.
+// Currently only used by the gRPC worker service.
+message MarkRecvFinishedRequest {
+  int64 request_id = 1;
+}
+
+message MarkRecvFinishedResponse {}
+
 ////////////////////////////////////////////////////////////////////////////////
 //
 // Logging method request/response messages
@@ -490,6 +502,10 @@ message RecvBufResponse {
   google.protobuf.Any transport_options = 4;
   // Optional, for timeline.
   int64 send_start_micros = 5;
+
+  // Whether the receiver should send a MarkRecvFinishedRequest to the sender
+  // to ack the message.
+  bool require_ack = 6;
 }

 ////////////////////////////////////////////////////////////////////////////////

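On the wire, the ack is keyed by the id the receiver generated for its original request, mirroring GrpcRemoteWorker::IssueMarkRecvFinishedRequest above. `MakeAck` is a hypothetical helper using the generated C++ proto API (assumes worker.pb.h); it is not part of this commit.

#include <cstdint>

#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

// Build the ack for a RecvTensorResponse that asked for one. Sending it over
// the /tensorflow.WorkerService/MarkRecvFinished method is left to the stub.
bool MakeAck(const RecvTensorResponse& response, int64_t request_id,
             MarkRecvFinishedRequest* ack) {
  if (!response.require_ack()) return false;
  ack->set_request_id(request_id);  // id from the original RecvTensorRequest
  return true;  // MarkRecvFinishedResponse carries no fields
}

}  // namespace tensorflow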