TF/gRPC: Use explicit request from receiver to clean RPC cache
This simplifies response cache management on the sender side and avoids the need for a large response cache.
PiperOrigin-RevId: 239638065
parent 90085135c2
commit a6d233a581
Changed paths: tensorflow/contrib/gdr, tensorflow/core
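The change replaces the size/expiry-bounded response cache with one keyed by request id and emptied by an explicit MarkRecvFinished ack from the receiver (or by step cleanup). The following is a minimal standalone sketch of that lifecycle, not TensorFlow code: types are simplified (std::string stands in for the tensor payload, no step ids or statuses) and all names here are illustrative.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified response cache mirroring the flow this commit
// introduces: first arrival computes, retries wait or reuse the cached
// result, and the receiver's ack erases the entry.
class ResponseCache {
 public:
  using DoneCB = std::function<void(const std::string& payload)>;

  // Returns true if the request is already active or finished; in that case
  // `cb` is (or will be) answered from the cache and the caller must not
  // recompute. Returns false for the first arrival.
  bool QueueRequest(int64_t request_id, const DoneCB& cb) {
    std::lock_guard<std::mutex> l(mu_);
    Entry& e = entries_[request_id];
    if (e.finished) {
      cb(e.payload);
      return true;
    }
    e.callbacks.push_back(cb);
    if (e.active) return true;
    e.active = true;
    return false;
  }

  // Fills the cache entry and answers every pending duplicate of request_id.
  void OnRequestFinished(int64_t request_id, const std::string& payload) {
    std::vector<DoneCB> cbs;
    {
      std::lock_guard<std::mutex> l(mu_);
      Entry& e = entries_[request_id];
      e.payload = payload;
      e.finished = true;
      cbs.swap(e.callbacks);
    }
    for (const auto& cb : cbs) cb(payload);
  }

  // Called on the receiver's ack: drop the entry, so the cache stays small
  // without needing a byte budget or expiry sweep.
  void EraseRequestId(int64_t request_id) {
    std::lock_guard<std::mutex> l(mu_);
    entries_.erase(request_id);
  }

 private:
  struct Entry {
    bool active = false;
    bool finished = false;
    std::string payload;
    std::vector<DoneCB> callbacks;
  };
  std::mutex mu_;
  std::unordered_map<int64_t, Entry> entries_;
};

int main() {
  ResponseCache cache;
  auto reply = [](const std::string& p) { std::cout << "reply: " << p << "\n"; };
  if (!cache.QueueRequest(/*request_id=*/7, reply)) {
    // First arrival: do the real work, then publish it to the cache.
    cache.OnRequestFinished(7, "tensor bytes");
  }
  cache.QueueRequest(7, reply);  // A retry is served from the cache.
  cache.EraseRequestId(7);       // Receiver ack cleans the entry.
}
```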
@@ -128,7 +128,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
StatusCallback copy_ready = [response, done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
grpc::EncodeTensorToByteBuffer(is_dead, *copy, false, response);
done(s);
delete copy;
};
@@ -136,7 +136,7 @@ void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
send_dev_context->CopyDeviceTensorToCPU(
&val, request->rendezvous_key(), src_dev, copy, copy_ready);
} else {
grpc::EncodeTensorToByteBuffer(is_dead, val, response);
grpc::EncodeTensorToByteBuffer(is_dead, val, false, response);
done(Status::OK());
}
}
@@ -165,8 +165,9 @@ cc_library(
srcs = ["grpc_response_cache.cc"],
hdrs = ["grpc_response_cache.h"],
deps = [
":grpc_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
],
)
@@ -66,6 +66,7 @@ class GrpcRemoteWorker : public WorkerInterface {
completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
logger_(logger) {}
~GrpcRemoteWorker() override {}
@@ -130,12 +131,10 @@ class GrpcRemoteWorker : public WorkerInterface {
int64 start_usec = Env::Default()->NowMicros();
// Type-specialized logging for this method.
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
StatusCallback wrapper_done;
const StatusCallback* cb_to_use;
if (!logging_active) {
cb_to_use = &done; // No additional work to do, so just use done directly
} else {
wrapper_done = [this, request, response, done, start_usec](Status s) {
auto callback = [this, request, response, done, start_usec,
logging_active](Status s) {
if (logging_active) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
@@ -159,12 +158,17 @@ class GrpcRemoteWorker : public WorkerInterface {
}
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->DebugString();
done(s);
};
cb_to_use = &wrapper_done;
}
}
IssueRequest(request, response, recvbuf_, *cb_to_use, call_opts);
// Note done() can delete this worker object, so we need to call done()
// last.
if (response->require_ack()) {
IssueMarkRecvFinishedRequest(request->request_id());
}
done(s);
};
IssueRequest(request, response, recvbuf_, callback, call_opts);
}
void CompleteGroupAsync(CallOptions* call_opts,
@@ -194,12 +198,10 @@ class GrpcRemoteWorker : public WorkerInterface {
int64 start_usec = Env::Default()->NowMicros();
// Type-specialized logging for this method.
bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
StatusCallback wrapper_done;
const StatusCallback* cb_to_use;
if (!logging_active) {
cb_to_use = &done; // No additional work to do, so just use done directly
} else {
wrapper_done = [this, request, response, done, start_usec](Status s) {
auto callback = [this, request, response, done, start_usec,
logging_active](Status s) {
if (logging_active) {
if (logger_->LoggingActive()) {
int64 end_usec = Env::Default()->NowMicros();
int64 step_id = request->step_id();
@@ -238,12 +240,17 @@ class GrpcRemoteWorker : public WorkerInterface {
}
VLOG(2) << "done callback, req: " << request->DebugString()
<< " response " << response->metadata().DebugString();
done(s);
};
cb_to_use = &wrapper_done;
}
}
IssueRequest(request, response, recvtensor_, *cb_to_use, call_opts);
// Note done() can delete this worker object, so we need to call done()
// last.
if (response->metadata().require_ack()) {
IssueMarkRecvFinishedRequest(request->request_id());
}
done(s);
};
IssueRequest(request, response, recvtensor_, callback, call_opts);
}
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
@@ -276,6 +283,16 @@ class GrpcRemoteWorker : public WorkerInterface {
callback_threadpool_, max_retries);
}
void IssueMarkRecvFinishedRequest(int64 request_id) {
VLOG(2) << "Send MarkRecvFinishedRequest for request " << request_id;
MarkRecvFinishedRequest request;
request.set_request_id(request_id);
MarkRecvFinishedResponse* response = new MarkRecvFinishedResponse();
auto done = [response](Status status) { delete response; };
IssueRequest(&request, response, markrecvfinished_, done);
}
// Helper function for initializing the RpcMethod objects below.
const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
@@ -299,6 +316,7 @@ class GrpcRemoteWorker : public WorkerInterface {
const ::grpc::string completegroup_;
const ::grpc::string instancesource_;
const ::grpc::string getstepsequence_;
const ::grpc::string markrecvfinished_;
// Support for logging.
WorkerCacheLogger* logger_;
@@ -14,170 +14,102 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "absl/types/optional.h"
#include "tensorflow/core/platform/env.h"
namespace tensorflow {
struct WorkerCacheEntry {
enum class State {
PENDING = 0,
ACTIVE = 1,
FINISHED = 2,
};
bool GrpcResponseCache::QueueRequest(int64 request_id, int64 step_id,
const FinishResponseCB& cb) {
VLOG(1) << "GrpcResponseCache Lookup " << request_id;
State state = State::PENDING;
int64 expires_seconds;
mu_.lock();
::grpc::ByteBuffer response_buf;
Status response_status;
ResponseCacheEntry& entry = response_cache_[request_id];
// Additional retries may arrive while a request is still executing. The
// callbacks for these calls are queued in `callbacks` and evaluated after
// the original request is completed.
std::vector<std::pair<RPCResponse, StatusCallback>> callbacks;
};
if (entry.state == ResponseCacheEntry::State::FINISHED) {
VLOG(1) << "Reuse cached response for " << request_id;
// Make a copy of the ResponseCacheEntry so that we can run FinishResponse
// outside the critical section. FinishResponse can be potentially
// expensive.
auto entry_copy = entry;
void RPCResponse::Encode(::grpc::ByteBuffer* tgt) const {
if (buf_ != nullptr) {
*tgt = *buf_;
mu_.unlock();
entry_copy.FinishResponse(cb);
return true;
}
entry.callbacks.emplace_back(cb);
if (entry.state == ResponseCacheEntry::State::ACTIVE) {
VLOG(1) << "Found active request for " << request_id
<< ". Adding entry to response queue.";
mu_.unlock();
return true;
} else {
CHECK(msg_ != nullptr);
::grpc::Slice slice(msg_->ByteSizeLong());
msg_->SerializeWithCachedSizesToArray(
const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
::grpc::ByteBuffer tmp(&slice, 1);
tgt->Swap(&tmp);
VLOG(2) << "No cache entry for " << request_id
<< ", running user computation.";
entry.step_id = step_id;
entry.state = ResponseCacheEntry::State::ACTIVE;
mu_.unlock();
return false;
}
}
void RPCResponse::CopyFrom(const ::grpc::ByteBuffer& src) {
if (buf_ != nullptr) {
*buf_ = src;
return;
}
void GrpcResponseCache::OnRequestFinished(int64 request_id,
const Tensor& tensor, bool is_dead,
const Status& status) {
absl::optional<ResponseCacheEntry> entry_copy;
CHECK(msg_ != nullptr);
// We create a single slice when encoding protocol messages.
std::vector<::grpc::Slice> slices;
if (src.Dump(&slices).ok()) {
msg_->ParseFromArray(slices[0].begin(), slices[0].size());
} else {
LOG(ERROR) << "Failed to decode cached buffer.";
}
}
void GrpcResponseCache::LookupOrCompute(const string& key, RPCResponse response,
ComputeFunc compute_func,
StatusCallback done_cb) {
VLOG(1) << "Lookup " << key;
std::shared_ptr<WorkerCacheEntry> req;
MaybeCleanup();
{
mutex_lock m(mu_);
if (requests_.find(key) != requests_.end()) {
req = requests_[key];
} else {
req.reset(new WorkerCacheEntry);
requests_[key] = req;
}
if (req->state == WorkerCacheEntry::State::FINISHED) {
if (req->expires_seconds > Env::Default()->NowSeconds()) {
VLOG(1) << "Reuse cached response for " << key;
response.CopyFrom(req->response_buf);
done_cb(req->response_status);
return;
}
VLOG(1) << "Found expired cache entry for " << key;
req->state = WorkerCacheEntry::State::PENDING;
req->response_buf.Clear();
}
req->callbacks.push_back(std::make_pair(response, done_cb));
if (req->state == WorkerCacheEntry::State::ACTIVE) {
VLOG(1) << "Found active request for " << key
<< ". Adding entry to response queue.";
auto it = response_cache_.find(request_id);
if (it == response_cache_.end()) {
LOG(ERROR) << "Unexpected missing response cache entry for request "
<< request_id;
return;
}
ResponseCacheEntry& entry = it->second;
VLOG(2) << "No cache entry for " << key << ", running user computation.";
req->state = WorkerCacheEntry::State::ACTIVE;
req->expires_seconds = Env::Default()->NowSeconds() + expire_time_seconds_;
VLOG(1) << "Operation for " << request_id << " finished. "
<< "Status: " << status << ", tensor size " << tensor.TotalBytes()
<< " bytes, " << entry.callbacks.size() << " pending callbacks.";
entry.tensor = tensor;
entry.is_dead = is_dead;
entry.response_status = status;
entry.state = ResponseCacheEntry::State::FINISHED;
// We copy the extra work out of the critical section in order to avoid
// serializing the work for sending response.
entry_copy = entry;
entry.callbacks.clear();
}
compute_func([this, key, req, response](Status status) {
mutex_lock m(mu_);
response.Encode(&req->response_buf);
current_bytes_ += req->response_buf.Length();
req->response_status = status;
req->state = WorkerCacheEntry::State::FINISHED;
VLOG(1) << "Operation for " << key << " finished. "
<< "Status: " << status << ", " << req->response_buf.Length()
<< " response bytes, " << req->callbacks.size()
<< " pending callbacks.";
for (auto& cb : req->callbacks) {
cb.first.CopyFrom(req->response_buf);
cb.second(req->response_status);
}
req->callbacks.clear();
});
for (auto& cb : entry_copy->callbacks) {
entry_copy->FinishResponse(cb);
}
}
// Remove all stale or expired cache entries if the cache is full.
void GrpcResponseCache::MaybeCleanup() {
void GrpcResponseCache::EraseRequestId(int64 request_id) {
mutex_lock m(mu_);
if (current_bytes_ < max_bytes_) {
return;
}
response_cache_.erase(request_id);
}
VLOG(1) << "Cleanup: " << current_bytes_ << " -> " << max_bytes_;
std::vector<std::pair<string, std::shared_ptr<WorkerCacheEntry>>>
ordered_entries;
ordered_entries.reserve(requests_.size());
for (const auto& p : requests_) {
ordered_entries.push_back(std::make_pair(p.first, p.second));
}
std::sort(ordered_entries.begin(), ordered_entries.end(),
[](const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& a,
const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& b) {
return a.second->expires_seconds > b.second->expires_seconds;
});
std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> kept;
int64 now = Env::Default()->NowSeconds();
int64 bytes_used = 0;
// Always keep active requests.
for (auto& pair : ordered_entries) {
if (pair.second->state != WorkerCacheEntry::State::FINISHED) {
kept.insert(pair);
void GrpcResponseCache::CleanEntriesForStep(int64 step_id) {
mutex_lock m(mu_);
// Remove all cache entries whose step id is the given step_id
for (auto it = response_cache_.begin(), last = response_cache_.end();
it != last;) {
if (it->second.step_id == step_id) {
VLOG(1) << "Erase stale GrpcResponseCache entry " << it->first;
it = response_cache_.erase(it);
} else {
++it;
}
}
// Keep unexpired, finished requests up to half of max_bytes_. This reduces
// chances of overfilling the cache when active requests complete and
// amortizes cache cleanup cost.
for (auto& pair : ordered_entries) {
if (pair.second->expires_seconds < now || bytes_used >= max_bytes_ / 2) {
break;
}
if (pair.second->state == WorkerCacheEntry::State::FINISHED) {
kept.insert(pair);
bytes_used += pair.second->response_buf.Length();
}
}
VLOG(1) << "Cleaned cache. Bytes used: " << current_bytes_ << " -> "
<< bytes_used << ". Cache size: " << requests_.size() << " -> "
<< kept.size();
current_bytes_ = bytes_used;
std::swap(requests_, kept);
}
} // namespace tensorflow
@@ -19,71 +19,74 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
// gRPC response caching. Most WorkerService methods cannot be retried directly
// as they will fail or deadlock. To enable retrying, we can instead cache
// responses for a short period of time and reply to duplicate requests from the
// cache.
// responses and reply to duplicate requests from the cache. The cache will be
// cleaned when the MarkRecvFinishedRequest is received from the receiver or the
// session step is completed.
namespace tensorflow {
// Union type to aid caching of either raw buffers (for RecvTensor RPCs) and
// protocol buffer messages (for all other RPCs).
class RPCResponse {
public:
explicit RPCResponse() : buf_(nullptr), msg_(nullptr) {}
explicit RPCResponse(::grpc::ByteBuffer* b) : buf_(b), msg_(nullptr) {}
explicit RPCResponse(protobuf::Message* m) : buf_(nullptr), msg_(m) {}
// Encode this response into the target buffer.
void Encode(::grpc::ByteBuffer* tgt) const;
// Copy from `src`: if this is a buffer, make a shallow copy.
// For protocol messages, parse the response from `src`.
void CopyFrom(const ::grpc::ByteBuffer& src);
private:
::grpc::ByteBuffer* buf_;
protobuf::Message* msg_;
};
typedef std::function<void(StatusCallback)> ComputeFunc;
struct WorkerCacheEntry;
// Track and cache the state of worker service RPCs. An RPC can be in 3 states:
//
// * PENDING: this is the first call of the RPC, and it will transition to
// * ACTIVE: another thread is active processing this RPC
// * FINISHED: the worker has finished processing the method
//
// The response from completed RPCs are LRU cached until either `max_bytes`
// bytes are in use by the cache or they expire (according to `expire_time`).
class GrpcResponseCache {
public:
GrpcResponseCache(int64 max_bytes, int64 expire_time_seconds)
: max_bytes_(max_bytes), expire_time_seconds_(expire_time_seconds) {}
using FinishResponseCB = std::function<void(
const Tensor& tensor, bool is_dead, const Status& status)>;
// Lookup the result for key.
// If it is finished, invoke `done_cb` immediately after filling `response`.
// If active, done_db will be invoked when the current call completes.
// Otherwise, invoke `compute_func` to fill the cache and invoke done_cb.
void LookupOrCompute(const string& key, RPCResponse response,
ComputeFunc compute_func, StatusCallback done_cb);
// Add the given request to the cache.
// If the request is in the cache,
// If it is finished, invoke `cb` immediately
// If active, cb will be invoked when the current call completes.
// In either case, return true.
// Otherwise, store the request and cb in the cache, and return false.
// Note FinishResponseCB is assumed to be thread-safe.
bool QueueRequest(int64 request_id, int64 step_id,
const FinishResponseCB& cb);
// Remove all stale or expired cache entries if the cache is full.
void MaybeCleanup();
// Fill the response cache for the given request_id and respond to all
// pending request.
void OnRequestFinished(int64 request_id, const Tensor& tensor, bool is_dead,
const Status& status);
// Erase the cache entry with the given request_id
void EraseRequestId(int64 request_id);
// Erase cache entries with the given step_id
void CleanEntriesForStep(int64 step_id);
private:
int64 current_bytes_ GUARDED_BY(mu_) = 0;
const int64 max_bytes_;
const int64 expire_time_seconds_;
struct ResponseCacheEntry {
enum class State {
PENDING = 0,
ACTIVE = 1,
FINISHED = 2,
};
State state = State::PENDING;
int64 step_id = -1;
Tensor tensor;
bool is_dead = false;
Status response_status;
void FinishResponse(const FinishResponseCB& cb) const {
cb(tensor, is_dead, response_status);
}
std::vector<FinishResponseCB> callbacks;
};
std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> requests_
GUARDED_BY(mu_);
mutex mu_;
// response_cache_ is expected to be small, as entries are cleared immediately
// on ack from the receiver.
gtl::FlatMap<int64, ResponseCacheEntry> response_cache_ GUARDED_BY(mu_);
};
} // namespace tensorflow
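The header above documents the new contract: QueueRequest returns true when a retry can be served (now or later) from the cache, and OnRequestFinished publishes the result to every queued callback. The fragment below is an illustrative usage sketch only, assuming the TensorFlow build tree; the function name HandleRecv and the Compute parameter are hypothetical stand-ins for the worker's real rendezvous lookup.

```cpp
#include <functional>

#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical handler showing the intended call pattern for the new cache.
void HandleRecv(
    GrpcResponseCache* cache, int64 request_id, int64 step_id,
    GrpcResponseCache::FinishResponseCB respond,
    std::function<void(GrpcResponseCache::FinishResponseCB)> Compute) {
  // A duplicate (retry) is either answered from the cache immediately or
  // queued until the original computation finishes.
  if (cache->QueueRequest(request_id, step_id, respond)) return;

  // First arrival: run the computation, then publish the result so the
  // original request and any queued retries are all answered.
  Compute([cache, request_id](const Tensor& t, bool is_dead, const Status& s) {
    cache->OnRequestFinished(request_id, t, is_dead, s);
  });
}

}  // namespace tensorflow
```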
@@ -135,13 +135,14 @@ static void EncodeSkeleton(const Tensor& val, io::ProtoEncodeHelper* e) {
#endif
}
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
::grpc::ByteBuffer* result) {
const int kLargeTensorBytes = 1024;
RecvTensorResponse response;
if (is_dead) {
response.set_is_dead(is_dead);
}
response.set_require_ack(require_ack);
response.set_send_start_micros(Env::Default()->NowMicros());
if (!DataTypeCanUseMemcpy(val.dtype())) {
// Straightforward but slow path for complicated kinds of tensor data
@@ -46,7 +46,7 @@ void EncodeRecvTensorResponseToByteBuffer(const RecvTensorResponse& proto,
// "val" holds the tensor value to be encoded.
//
// Discards original contents of *result.
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack,
::grpc::ByteBuffer* result);
} // namespace grpc
@@ -31,7 +31,7 @@ class GrpcTensorCodingTest : public ::testing::Test {
void Validate(const Tensor& t, bool is_dead) {
// Check by encoding to a ByteBuffer
::grpc::ByteBuffer buf;
grpc::EncodeTensorToByteBuffer(is_dead, t, &buf);
grpc::EncodeTensorToByteBuffer(is_dead, t, false, &buf);
// Make a string
std::vector<::grpc::Slice> slices;
@@ -146,6 +146,7 @@ class GrpcWorkerServiceThread {
SETUP_FOR_REQUEST(RecvBuf, 500, true);
SETUP_FOR_REQUEST(RunGraph, 100, true);
SETUP_FOR_REQUEST(CleanupGraph, 100, false);
SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
// TODO(ncteisen): Determine a better policy for enqueuing the
// appropriate number of each request type.
@@ -221,6 +222,14 @@ class GrpcWorkerServiceThread {
ENQUEUE_REQUEST(GetStepSequence, true);
}
void MarkRecvFinishedHandler(
WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
VLOG(1) << "Clean cache entry for request " << call->request.request_id();
worker_->RemoveCacheEntryForId(call->request.request_id());
call->SendResponse(::grpc::Status::OK);
ENQUEUE_REQUEST(MarkRecvFinished, false);
}
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
@@ -229,32 +238,19 @@ class GrpcWorkerServiceThread {
NonOwnedProtoRunGraphResponse* wrapped_response =
new NonOwnedProtoRunGraphResponse(&call->response);
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
auto done_cb = [call, call_opts, wrapped_request,
wrapped_response](const Status& s) {
VLOG(1) << "RunGraph::Done";
if (!s.ok()) {
VLOG(1) << "Bad response from RunGraph:" << s;
}
call->ClearCancelCallback();
delete call_opts;
delete wrapped_request;
delete wrapped_response;
call->SendResponse(ToGrpcStatus(s));
};
auto compute_fn = [this, call_opts, wrapped_request,
wrapped_response](StatusCallback done) {
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
done);
};
if (cache_) {
string request_key = call->request.ShortDebugString();
cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
compute_fn, done_cb);
} else {
compute_fn(done_cb);
}
worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
[call, call_opts, wrapped_request,
wrapped_response](const Status& s) {
VLOG(1) << "RunGraph::Done";
if (!s.ok()) {
VLOG(1) << "Bad response from RunGraph:" << s;
}
call->ClearCancelCallback();
delete call_opts;
delete wrapped_request;
delete wrapped_response;
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(RunGraph, true);
}
@@ -265,27 +261,16 @@ class GrpcWorkerServiceThread {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
auto done_cb = [call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(1) << "Bad response from RecvTensor:" << s;
}
call->SendResponse(ToGrpcStatus(s));
};
auto compute_fn = [this, &call_opts, &call](StatusCallback done) {
worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
done);
};
if (cache_) {
string request_key = call->request.ShortDebugString();
cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
compute_fn, done_cb);
} else {
compute_fn(done_cb);
}
worker_->GrpcRecvTensorAsync(
call_opts, &call->request, &call->response,
[call, call_opts](const Status& s) {
call->ClearCancelCallback();
delete call_opts;
if (!s.ok()) {
VLOG(1) << "Bad response from RecvTensor:" << s;
}
call->SendResponse(ToGrpcStatus(s));
});
});
EnqueueRecvTensorRequestRaw();
}
@@ -377,10 +362,10 @@ class GrpcWorkerService : public AsyncServiceInterface {
GrpcWorkerServiceOptions options)
: is_shutdown_(false) {
builder->RegisterService(&worker_service_);
if (options.response_cache_bytes > 0) {
cache_.reset(
new GrpcResponseCache(options.response_cache_bytes,
options.response_cache_expires_seconds));
// TODO(jingdong): it would be cleaner to move this option to GrpcWorker
// since the cache is maintained by GrpcWorker now.
if (options.cache_rpc_response) {
worker->EnableResponseCache();
}
for (int i = 0; i < options.num_serving_threads; i++) {
@@ -437,6 +422,11 @@ GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
? config.experimental().recv_buf_max_chunk()
: (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {}
void GrpcWorker::EnableResponseCache() {
VLOG(1) << "Enabling gRPC tensor response cache.";
response_cache_ = absl::make_unique<GrpcResponseCache>();
}
// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for a response object, to avoid extra protocol buffer serialization
// overhead we generate our response directly into a ::grpc::ByteBuffer object
@@ -444,14 +434,49 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done) {
Status s = recent_request_ids_.TrackUnique(
request->request_id(), "RecvTensor (GrpcWorker)", *request);
if (!s.ok()) {
done(s);
auto do_response = [this, response, done](const Tensor& tensor, bool is_dead,
const Status& status) {
if (status.ok()) {
bool require_ack = (response_cache_ != nullptr);
grpc::EncodeTensorToByteBuffer(is_dead, tensor, require_ack, response);
}
done(status);
};
const int64 request_id = request->request_id();
const int64 step_id = request->step_id();
// If response cache is enabled and the response cache already contains the
// request, we delegate this retry request to the response cache. Otherwise,
// we add the request to the response cache and start the computation to
// retrieve the requested data.
if (response_cache_ &&
response_cache_->QueueRequest(request_id, step_id, do_response)) {
return;
}
auto rendezvous_done = [this, request_id, do_response](const Tensor& tensor,
bool is_dead,
const Status& status) {
if (response_cache_) {
// Data is ready. Process all pending requests in the response cache.
response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
} else {
do_response(tensor, is_dead, status);
}
};
auto fail = [&rendezvous_done](const Status& status) {
rendezvous_done(Tensor(), false, status);
};
Status s = recent_request_ids_.TrackUnique(
request_id, "RecvTensor (GrpcWorker)", *request);
if (!s.ok()) {
fail(s);
return;
}
const int64 step_id = request->step_id();
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
@@ -461,7 +486,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
s = PrepareRecvTensor(parsed, &src_dev);
}
if (!s.ok()) {
done(s);
fail(s);
return;
}
@@ -475,7 +500,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
[step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[opts, response, done, src_dev, request](
[opts, rendezvous_done, src_dev, request](
const Status& status, const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args, const Tensor& val,
const bool is_dead) {
@@ -502,25 +527,21 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on an accelerator device. Uses the device_context to
// fill the copy on host.
StatusCallback copy_ready = [response, done, copy,
StatusCallback copy_ready = [rendezvous_done, copy,
is_dead](const Status& s) {
// The value is now ready to be returned on the wire.
grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
done(s);
rendezvous_done(*copy, is_dead, s);
delete copy;
};
send_dev_context->CopyDeviceTensorToCPU(
&val, request->rendezvous_key(), src_dev, copy, copy_ready);
} else {
grpc::EncodeTensorToByteBuffer(is_dead, val, response);
done(Status::OK());
return;
}
}
} else {
// !s.ok()
done(status);
}
rendezvous_done(val, is_dead, status);
});
}
@@ -537,8 +558,9 @@ namespace {
// RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
// and remove this function.
void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
int64 num_bytes, RecvBufResponse* response) {
RecvBufResponse* response) {
RecvBufRespExtra extra;
int64 num_bytes = tensor->TotalBytes();
const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
while (num_bytes > 0) {
int64 bytes =
@@ -553,20 +575,56 @@ void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) {
auto do_response = [this, response, done](const Tensor& tensor, bool is_dead,
const Status& status) {
if (status.ok()) {
SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
}
response->set_send_start_micros(env_->env->NowMicros());
response->set_require_ack(response_cache_ != nullptr);
done(status);
};
const int64 request_id = request->request_id();
const int64 step_id = request->step_id();
// If response cache is enabled and the response cache already contains the
// request, we delegate this retry request to the response cache. Otherwise,
// we add the request to the response cache and start the computation to
// retrieve the requested data.
if (response_cache_ &&
response_cache_->QueueRequest(request_id, step_id, do_response)) {
return;
}
auto rendezvous_done = [this, request_id, do_response](const Tensor& tensor,
const Status& status) {
if (response_cache_) {
// Data is ready. Process all pending requests in the response cache.
response_cache_->OnRequestFinished(request_id, tensor, false, status);
} else {
do_response(tensor, false, status);
}
};
auto fail = [&rendezvous_done](const Status& status) {
rendezvous_done(Tensor(), status);
};
// This is a generic, low performance implementation appropriate for grpc.
Status s = recent_request_ids_.TrackUnique(request->request_id(),
"RecvBuf (GrpcWorker)", *request);
Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
*request);
if (!s.ok()) {
done(s);
fail(s);
return;
}
CollectiveExecutor::Handle ce_handle(
env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
env_->collective_executor_mgr->FindOrCreate(step_id), true);
CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
rma->buf_rendezvous()->ConsumeBuf(
request->buf_rendezvous_key(),
[this, request, response, done](const Status& status,
BufRendezvous::Hook* hook) {
[this, request, rendezvous_done](const Status& status,
BufRendezvous::Hook* hook) {
Status s = status;
if (s.ok()) {
if (!DMAHelper::CanUseDMA(hook->prod_value)) {
@@ -594,27 +652,17 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
hook->prod_value->shape());
hook->prod_ctx->CopyDeviceTensorToCPU(
hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
[this, num_bytes, response, done, hook,
cpu_tensor](const Status& s) {
if (s.ok()) {
SetTensorInRecvBufResp(recv_buf_max_chunk_, cpu_tensor,
num_bytes, response);
}
response->set_send_start_micros(env_->env->NowMicros());
done(s);
[hook, cpu_tensor, rendezvous_done](const Status& s) {
rendezvous_done(*cpu_tensor, s);
BufRendezvous::DoneWithHook(hook);
delete cpu_tensor;
});
return;
}
} else {
// Tensor is on CPU.
SetTensorInRecvBufResp(recv_buf_max_chunk_, hook->prod_value,
num_bytes, response);
}
}
response->set_send_start_micros(env_->env->NowMicros());
done(s);
rendezvous_done(*hook->prod_value, s);
BufRendezvous::DoneWithHook(hook);
});
}
@@ -646,8 +694,25 @@ void GrpcWorker::LoggingAsync(const LoggingRequest* request,
done(Status::OK());
}
void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) {
if (response_cache_) {
// Cleanup any stale response cache entries for this step. This can occur if
// a worker crashes before acking a request.
response_cache_->CleanEntriesForStep(request->step_id());
}
Worker::CleanupGraphAsync(request, response, done);
}
WorkerEnv* GrpcWorker::env() { return env_; }
void GrpcWorker::RemoveCacheEntryForId(int64 request_id) {
if (response_cache_) {
response_cache_->EraseRequestId(request_id);
}
}
std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
const ConfigProto& config) {
return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace grpc {
class ByteBuffer;
@@ -33,6 +34,7 @@ class AsyncServiceInterface;
class ConfigProto;
struct WorkerEnv;
struct WorkerSession;
class GrpcResponseCache;
class GrpcWorker : public Worker {
public:
@@ -44,15 +46,24 @@ class GrpcWorker : public Worker {
::grpc::ByteBuffer* response,
StatusCallback done);
virtual void LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done);
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
StatusCallback done) override;
virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done);
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) override;
void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) override;
WorkerEnv* env();
void EnableResponseCache();
void RemoveCacheEntryForId(int64 request_id);
private:
std::unique_ptr<GrpcResponseCache> response_cache_;
const int32 recv_buf_max_chunk_;
};
@@ -64,8 +75,14 @@ struct GrpcWorkerServiceOptions {
// default queue depth for a method.
std::unordered_map<int, int> queue_depth;
int num_serving_threads = 8;
int64 response_cache_bytes = 0;
int64 response_cache_expires_seconds = 0;
// Setting cache_rpc_response to true will enable sender side caching of
// response for RecvTensorAsync and RecvBufAsync to allow receiver to retry
// requests . This is only necessary when the network fabric is experiencing a
// significant error rate. Without it we'll fail a step on an network error,
// while with it we'll be able to complete long steps (like complex
// initializations) in the face of some network errors during RecvTensor.
bool cache_rpc_response = false;
};
// Returns an implementation of WorkerService rpc service.
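As the options comment above explains, the byte/expiry knobs are replaced by a single opt-in flag. A hedged sketch of how a server might fill in these options follows; the helper function name is hypothetical and the surrounding server setup (worker creation, service registration) is elided.

```cpp
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"

namespace tensorflow {

// Hypothetical helper: build options that opt in to sender-side caching of
// RecvTensor/RecvBuf responses. When cache_rpc_response is set, the service
// calls GrpcWorker::EnableResponseCache() on the worker it serves.
GrpcWorkerServiceOptions MakeCachedWorkerServiceOptions() {
  GrpcWorkerServiceOptions options;
  options.num_serving_threads = 8;
  options.cache_rpc_response = true;
  return options;
}

}  // namespace tensorflow
```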
@@ -58,6 +58,8 @@ const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
return "/tensorflow.WorkerService/CompleteInstance";
case GrpcWorkerMethod::kGetStepSequence:
return "/tensorflow.WorkerService/GetStepSequence";
case GrpcWorkerMethod::kMarkRecvFinished:
return "/tensorflow.WorkerService/MarkRecvFinished";
}
// Shouldn't be reached.
LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
@@ -85,10 +85,11 @@ enum class GrpcWorkerMethod {
kCompleteGroup,
kCompleteInstance,
kGetStepSequence,
kMarkRecvFinished,
};
static const int kGrpcNumWorkerMethods =
static_cast<int>(GrpcWorkerMethod::kGetStepSequence) + 1;
static_cast<int>(GrpcWorkerMethod::kMarkRecvFinished) + 1;
const char* GrpcWorkerMethodName(GrpcWorkerMethod id);
@@ -246,7 +246,7 @@ bool TensorResponse::ParseFast(Source* source) {
case RecvTensorResponse::kIsDeadFieldNumber: {
uint32 v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
meta_.set_is_dead((v != 0) ? true : false);
meta_.set_is_dead(v != 0);
break;
}
case RecvTensorResponse::kSendStartMicrosFieldNumber: {
@@ -261,6 +261,12 @@ bool TensorResponse::ParseFast(Source* source) {
return false;
break;
}
case RecvTensorResponse::kRequireAckFieldNumber: {
uint32 v;
if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
meta_.set_require_ack(v != 0);
break;
}
default: {
// Unknown tag, so don't handle we can't handle on the fast path
return false;
@@ -362,8 +362,20 @@ message RecvTensorResponse {
// Optional additional information about how to receive the tensor,
// e.g. in the event that `RecvTensorRequest.dma_ok` was true.
google.protobuf.Any transport_options = 4;

// Whether the receiver should send a MarkRecvFinishedRequest to the sender
// to ack the message.
bool require_ack = 5;
}

// Message for managing the response cache maintained on the sender side.
// Currently only used by the gRPC worker service.
message MarkRecvFinishedRequest {
int64 request_id = 1;
}

message MarkRecvFinishedResponse {}

////////////////////////////////////////////////////////////////////////////////
//
// Logging method request/response messages
@@ -490,6 +502,10 @@ message RecvBufResponse {
google.protobuf.Any transport_options = 4;
// Optional, for timeline.
int64 send_start_micros = 5;

// Whether the receiver should send a MarkRecvFinishedRequest to the sender
// to ack the message.
bool require_ack = 6;
}

////////////////////////////////////////////////////////////////////////////////