Add support for bidirectional streaming RPCs to TF gRPC infrastructure
This change:

- Adds an EagerService::StreamingEnqueue RPC method that takes/returns the same
  protos as EagerService::Enqueue, but is a bidirectional streaming method. The
  server returns one response message for each request message. Most of the
  code depends on this assumption.
- Adds low-level gRPC logic to handle streaming RPCs on top of
  grpc::CompletionQueue. gRPC has a nicer callback-based async API, but it is
  experimental and we would need to change a fair amount of existing
  CompletionQueue-based code. These low-level classes are fairly generic, but
  hard-code some options that other users might need to change, if/when there
  are other users.

This CL does not include the changes to the eager runtime that actually use the
streaming RPCs. They are in the next CL.

PiperOrigin-RevId: 253327204
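For reference, the service definition this change adds (quoted from the eager_service.proto hunk at the end of this diff):

    rpc StreamingEnqueue(stream EnqueueRequest) returns (stream EnqueueResponse);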
parent 19c72dd784
commit 1ae09760d1
@@ -42,6 +42,21 @@ class EagerClient {
  CLIENT_METHOD(SendTensor);

#undef CLIENT_METHOD

  // Feeds `request` into the request stream of EagerService::StreamingEnqueue.
  // `response` will be filled with the response for this `request`. The
  // 1-to-1 correspondence between requests and responses is a property
  // of the current service implementation. When the response is received,
  // `done` is invoked with the current status of the StreamingEnqueue call.
  // The status can contain an error because of an earlier request in the
  // current streaming call.
  // The client initiates a streaming call the first time StreamingEnqueueAsync
  // is invoked and keeps it open until some error condition.
  // Similarly to the methods above, the request can be deleted as soon as
  // StreamingEnqueueAsync returns.
  virtual void StreamingEnqueueAsync(const EnqueueRequest* request,
                                     EnqueueResponse* response,
                                     StatusCallback done) = 0;
};
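A minimal sketch of a call site for this interface (the setup around it is illustrative, not part of this change; it assumes an `EagerClient* client` and that the caller keeps `response` alive until `done` runs):

    // Hypothetical caller. The first invocation opens the streaming call;
    // later invocations reuse it until an error closes the stream.
    EnqueueRequest request;  // filled in by the caller
    auto* response = new EnqueueResponse;
    client->StreamingEnqueueAsync(
        &request, response, [response](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "StreamingEnqueue failed: " << s.ToString();
          }
          delete response;
        });
    // Per the comment above, `request` may be deleted as soon as
    // StreamingEnqueueAsync returns; `response` must stay alive until `done`.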

// Simple wrapper class that can be used to retrieve EagerClients.
@@ -64,15 +64,17 @@ cc_library(

cc_library(
    name = "grpc_state",
-   srcs = [],
+   srcs = ["grpc_state.cc"],
    hdrs = ["grpc_state.h"],
    deps = [
        ":grpc_client_cq_tag",
        ":grpc_util",
        "//tensorflow:grpc++",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/distributed_runtime:call_options",
        "//tensorflow/core/distributed_runtime:tensor_coding",
        "@com_google_absl//absl/strings:str_format",
    ],
)

@@ -47,15 +47,49 @@ class GrpcEagerClient : public EagerClient {
  CLIENT_METHOD(Enqueue);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);
  CLIENT_METHOD(CloseContext);
  CLIENT_METHOD(RegisterFunction);
  CLIENT_METHOD(SendTensor);

#undef CLIENT_METHOD

  void CloseContextAsync(const CloseContextRequest* request,
                         CloseContextResponse* response,
                         StatusCallback done) override {
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
        response, std::move(done), nullptr, nullptr);

    if (enqueue_dispatchers_.find(request->context_id()) !=
        enqueue_dispatchers_.end()) {
      enqueue_dispatchers_.erase(request->context_id());
    } else {
      LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
                 << " does not seem to exist.";
    }
  }

  void StreamingEnqueueAsync(const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    auto it = enqueue_dispatchers_.find(request->context_id());
    if (it == enqueue_dispatchers_.end()) {
      auto it_and_bool = enqueue_dispatchers_.emplace(
          std::piecewise_construct,
          std::forward_as_tuple(request->context_id()),
          std::forward_as_tuple(
              &stub_, cq_, "/tensorflow.eager.EagerService/StreamingEnqueue"));
      it = it_and_bool.first;
    }

    it->second.SendNextRequest(*request, response, std::move(done));
  }

 private:
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
  std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
      enqueue_dispatchers_;
};

class GrpcEagerClientCache : public EagerClientCache {

@@ -50,6 +50,14 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
  ENQUEUE_REQUEST(SendTensor);
#undef ENQUEUE_REQUEST

  // Request a StreamingEnqueue call.
  ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
                                   grpc::EagerService::AsyncService,
                                   EnqueueRequest, EnqueueResponse>::
      EnqueueRequest(&service_, cq_.get(),
                     &grpc::EagerService::AsyncService::RequestStreamingEnqueue,
                     &GrpcEagerServiceImpl::StreamingEnqueueHandler);

  void* tag;  // Matches the operation started against this cq_.
  bool ok;

@@ -58,8 +66,8 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
      // The queue is shutting down.
      break;
    }
-   UntypedCall<GrpcEagerServiceImpl>::Tag* callback_tag =
-       static_cast<UntypedCall<GrpcEagerServiceImpl>::Tag*>(tag);
+   GrpcCallTag<GrpcEagerServiceImpl>* callback_tag =
+       static_cast<GrpcCallTag<GrpcEagerServiceImpl>*>(tag);

    if (callback_tag) {
      callback_tag->OnCompleted(this, ok);

@@ -34,6 +34,11 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
  template <class RequestMessage, class ResponseMessage>
  using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
                         RequestMessage, ResponseMessage>;
  template <class RequestMessage, class ResponseMessage>
  using StreamingCall =
      ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
                                       grpc::EagerService::AsyncService,
                                       RequestMessage, ResponseMessage>;

  GrpcEagerServiceImpl(const WorkerEnv* env,
                       ::grpc::ServerBuilder* server_builder);
@@ -64,6 +69,35 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
  HANDLER(SendTensor);
#undef HANDLER

  // Called when a new request has been received as part of a StreamingEnqueue
  // call.
  // StreamingEnqueueHandler gets the request from the `call` and fills the
  // response (also found in `call`) by invoking the local EagerServiceImpl.
  // The local EagerServiceImpl is invoked in this thread instead of using a
  // thread pool, as is done for all other methods above. We do this to
  // preserve request order. The local service can parallelize based on the
  // context_id in the request if necessary. Remote contexts are created in
  // async mode by default, so the local service impl just puts the request on
  // the eager executor queue.
  void StreamingEnqueueHandler(
      StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
    Status status =
        local_impl_.Enqueue(&call->request(), call->mutable_response());

    if (status.ok()) {
      VLOG(1) << "local_impl_.Enqueue completed successfully";
      call->SendResponse();
    } else {
      VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
              << " on request " << call->request().DebugString();
      call->Finish(ToGrpcStatus(status));
    }

    // We do not tell gRPC to accept a new StreamingEnqueue request because
    // this method can be called multiple times for a given streaming call.
    // The StreamingCall does this once per call instead, after a call has
    // been opened.
  }

  const WorkerEnv* const env_;  // Not owned.
  EagerServiceImpl local_impl_;

@@ -16,13 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_

- #include "tensorflow/core/lib/core/refcount.h"
- #include "tensorflow/core/platform/macros.h"
- #include "tensorflow/core/platform/mutex.h"
-
#include "grpcpp/grpcpp.h"
#include "grpcpp/impl/codegen/service_type.h"
#include "grpcpp/server_builder.h"
+ #include "tensorflow/core/lib/core/refcount.h"
+ #include "tensorflow/core/platform/macros.h"
+ #include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

@@ -70,6 +69,16 @@ namespace tensorflow {
//
// 4. When the response has been sent, the tag is returned from
//    `cq_->Next()`, and the call object is deleted.
//

template <class Service>
class GrpcCallTag {
 public:
  virtual ~GrpcCallTag() {}

  // Calls the callback associated with this tag.
  virtual void OnCompleted(Service* service, bool ok) = 0;
};

// Represents a pending request with unknown message types.
template <class Service>
@@ -98,7 +107,7 @@ class UntypedCall : public core::RefCounted {
  // Associates a tag in a `::grpc::CompletionQueue` with a callback
  // for an incoming RPC. An active Tag owns a reference on the corresponding
  // Call object.
- class Tag {
+ class Tag : public GrpcCallTag<Service> {
   public:
    // One enum value per supported callback.
    enum Callback { kRequestReceived, kResponseSent, kCancelled };
@@ -108,7 +117,7 @@ class UntypedCall : public core::RefCounted {
    // Calls the callback associated with this tag.
    //
    // The callback takes ownership of `this->call_`.
-   void OnCompleted(Service* service, bool ok) {
+   void OnCompleted(Service* service, bool ok) override {
      switch (callback_) {
        case kRequestReceived:
          call_->RequestReceived(service, ok);
@@ -263,6 +272,242 @@ class Call : public UntypedCall<Service> {
  std::function<void()> cancel_callback_ GUARDED_BY(mu_);
};

// Lifetime of a server-side bidirectional streaming call:
// - The call is created in the static EnqueueRequest method. It transfers
//   ownership to the kCallOpen tag pushed onto the completion queue.
// - If kCallOpen completes successfully, a read is requested and the
//   kRequestReceived tag takes ownership of the call. If kCallOpen fails,
//   e.g. the server is shut down, no further requests are pushed and the
//   call is destroyed (at the end of Tag::OnCompleted).
// - When the first request is received, we Ref() the call and invoke the
//   handler method, thereby transferring ownership to the handler method.
//   The handler is responsible for calling SendResponse() or Finish() on this
//   call.
// - If the handler calls Finish(), e.g. the request was invalid, Finish()
//   transfers ownership from the handler to the kServerFinished tag that
//   it pushes on the completion queue. The ownership is transferred because
//   the ref count is not incremented before putting the tag on the queue.
// - If the handler calls SendResponse(), SendResponse() transfers ownership
//   to the kResponseSent tag.
// - When kResponseSent completes, we request a new read, which now owns the
//   call.
// - When the next request is received, it is handled the same way as the
//   first request.
//
// Because we request a read only after the write is sent, we can safely reuse
// the same request and response messages for the whole call.
template <class Service>
class ServerUntypedBidirectionalStreamingCall : public core::RefCounted {
 public:
  virtual void RequestReceived(Service* service) = 0;

  // Enqueues a request on the completion queue to read the next request.
  virtual void CallOpen() = 0;

  virtual void RequestRead() = 0;

  // Associates a tag in a `::grpc::CompletionQueue` with a callback.
  // An active Tag owns a reference on the corresponding Call object.
  class Tag : public GrpcCallTag<Service> {
   public:
    // One enum value per supported callback.
    enum class TagType {
      kCallOpen,
      kRequestReceived,
      kResponseSent,
      kServerFinished,
    };

    Tag(ServerUntypedBidirectionalStreamingCall* call, TagType cb)
        : call_(call), callback_(cb) {}

    // Calls the callback associated with this tag and Unrefs this->call_.
    void OnCompleted(Service* service, bool ok) override {
      switch (callback_) {
        case TagType::kCallOpen:
          // A non-ok value indicates that the server has been shut down
          // before we received a message for this call type. We do nothing,
          // letting this call object be destroyed, and avoid enqueuing a
          // request for another call.
          if (ok) {
            call_->CallOpen();
          }
          break;
        case TagType::kRequestReceived:
          // A non-ok value from the completion queue here means that we will
          // not receive any more messages from the client, e.g. the client
          // called WritesDone. There is nothing we need to do in this case.
          // The call will be Unref'ed and deleted. If the client wants to
          // open a new call, we have already enqueued a request for a new
          // call in CallOpen above.
          if (ok) {
            call_->RequestReceived(service);
          }
          break;
        case TagType::kResponseSent:
          if (ok) {
            // The obvious place to request a read would be at the end of
            // RequestReceived(). Unfortunately, this can result in multiple
            // outstanding write requests in the completion queue. This is
            // currently not supported by gRPC, which requires at most one
            // outstanding write request in the completion queue.
            // Requesting a read here, in ResponseSent, works because at
            // this point, the completion queue has no write requests
            // (kResponseSent happens when a write completes).
            // This might be synchronizing the processing more than strictly
            // necessary, but is probably fine because, AFAICT from gRPC docs,
            // the write request completes as soon as it can be written to
            // the outgoing buffer.
            call_->RequestRead();
          }
          // ok == false means that the response is not going on the wire
          // because the call is already dead (i.e., canceled, deadline
          // expired, other side dropped the channel, etc). Since the call is
          // dead, there is nothing for us to do; we just let the call be
          // deleted.
          break;
        case TagType::kServerFinished:
          // Whether or not our finish request is successful (i.e. whether it
          // went on the wire towards the client), there is nothing for us to
          // do. In the current implementation, there can be no read or write
          // requests in the completion queue (see the comment in
          // kResponseSent above). Even if there were pending requests, they
          // would complete with a non-ok status, we would not do anything,
          // and would let the call be deleted.
          break;
      }
      call_->Unref();  // Ref acquired when tag was handed to grpc.
    }

   private:
    ServerUntypedBidirectionalStreamingCall* const
        call_;  // `this` owns one reference.
    TagType callback_;
  };
};

// Represents a pending call with known request and response message
// types, and a known request-handling method.
// The common usage pattern is to have a single thread waiting on events from
// the completion queue and calling Tag::OnCompleted(), which invokes methods
// on this.
// This implementation assumes that the server will generate a single response
// message for each request message. More precisely, this class expects that
// each time it invokes handle_request_function_, the service implementation
// will either call SendResponse or Finish exactly once.
// Not thread-safe.
template <class Service, class GrpcService, class RequestMessage,
          class ResponseMessage>
class ServerBidirectionalStreamingCall
    : public ServerUntypedBidirectionalStreamingCall<Service> {
 public:
  // Represents the generic signature of a generated
  // `GrpcService::RequestFoo()` method, where `Foo` is the name of an
  // RPC method.
  using EnqueueFunction = void (GrpcService::*)(
      ::grpc::ServerContext*,
      ::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage>*,
      ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);

  // Represents the generic signature of a `Service::HandleFoo()`
  // method, where `Foo` is the name of an RPC method.
  using HandleRequestFunction = void (Service::*)(
      ServerBidirectionalStreamingCall<Service, GrpcService, RequestMessage,
                                       ResponseMessage>*);

  ServerBidirectionalStreamingCall(
      HandleRequestFunction handle_request_function, GrpcService* grpc_service,
      ::grpc::ServerCompletionQueue* cq, EnqueueFunction enqueue_function)
      : handle_request_function_(handle_request_function),
        stream_(&ctx_),
        grpc_service_(grpc_service),
        cq_(cq),
        enqueue_function_(enqueue_function) {}

  void CallOpen() override {
    // Let gRPC know that we can accept another call.
    ServerBidirectionalStreamingCall<
        Service, GrpcService, RequestMessage,
        ResponseMessage>::EnqueueRequest(grpc_service_, cq_, enqueue_function_,
                                         handle_request_function_);
    RequestRead();
  }

  void RequestRead() override {
    this->Ref();
    request_.Clear();
    stream_.Read(&request_, &request_received_tag_);
  }

  void RequestReceived(Service* service) override {
    this->Ref();
    // Request handling should result in a call to SendResponse or Finish.
    (service->*handle_request_function_)(this);
  }

  void SendResponse() {
    // Transfers ownership of this to the response_sent_tag_.
    stream_.Write(response_, &response_sent_tag_);
    // stream_.Write does not save references to response_. We are free to
    // muck around with it as soon as Write returns.
    // We clear response_ to prepare it for the next response.
    response_.Clear();
  }

  void Finish(::grpc::Status status) {
    // Transfers ownership of this to the server_finished_tag_.
    stream_.Finish(status, &server_finished_tag_);
  }

  // Enqueues a new request for the given service on the given
  // completion queue, using the given `enqueue_function`.
  //
  // The request will be handled by the given `handle_request_function`.
  static void EnqueueRequest(GrpcService* grpc_service,
                             ::grpc::ServerCompletionQueue* cq,
                             EnqueueFunction enqueue_function,
                             HandleRequestFunction handle_request_function) {
    auto call =
        new ServerBidirectionalStreamingCall<Service, GrpcService,
                                             RequestMessage, ResponseMessage>(
            handle_request_function, grpc_service, cq, enqueue_function);

    // The initial ref on `call` is handed to grpc; it is released in the Tag
    // callback.
    (grpc_service->*enqueue_function)(&call->ctx_, &call->stream_, cq, cq,
                                      &call->call_open_tag_);
  }

  const RequestMessage& request() const { return request_; }
  ResponseMessage* mutable_response() { return &response_; }

 private:
  // The request and response messages are reused for each request/response
  // exchange between the client and the server.
  RequestMessage request_;
  ResponseMessage response_;
  ::grpc::ServerContext ctx_;

  HandleRequestFunction handle_request_function_;
  ::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage> stream_;

  // Used as void* completion markers from grpc to indicate different
  // events of interest for a ServerBidirectionalStreamingCall.
  typedef typename ServerUntypedBidirectionalStreamingCall<Service>::Tag Tag;
  // At most one tag of each kind may be given to gRPC at any one time.
  // Beyond semantic sanity, this is needed to ensure proper ref counting
  // of this call object.
  Tag call_open_tag_{this, Tag::TagType::kCallOpen};
  Tag request_received_tag_{this, Tag::TagType::kRequestReceived};
  Tag response_sent_tag_{this, Tag::TagType::kResponseSent};
  Tag server_finished_tag_{this, Tag::TagType::kServerFinished};

  // These fields are used only to spawn another instance of this to accept
  // more streaming calls.
  GrpcService* grpc_service_;
  ::grpc::ServerCompletionQueue* cq_;
  EnqueueFunction enqueue_function_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_

@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_

#include "grpcpp/grpcpp.h"

#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
@@ -32,7 +31,7 @@ class GrpcClientCQTag {
  virtual ~GrpcClientCQTag() {}

  // OnCompleted is invoked when the RPC has finished.
- // Implementations of OnCompleted must delete *this.
+ // Implementations of OnCompleted can delete *this.
  virtual void OnCompleted(bool ok) = 0;

 private:

tensorflow/core/distributed_runtime/rpc/grpc_state.cc (new file, 211 lines)
@@ -0,0 +1,211 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"

#include "absl/strings/str_format.h"

namespace tensorflow {

const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type) {
  switch (tag_type) {
    case UntypedStreamingRPCState::Tag::TagType::kCallStarted:
      return "kCallStarted";
    case UntypedStreamingRPCState::Tag::TagType::kRequestWriteCompleted:
      return "kRequestWriteCompleted";
    case UntypedStreamingRPCState::Tag::TagType::kResponseReadCompleted:
      return "kResponseReadCompleted";
  }
}

UntypedStreamingRPCState::Tag::Tag(UntypedStreamingRPCState* streaming_state,
                                   Tag::TagType type)
    : streaming_state_(streaming_state), type_(type) {}

void UntypedStreamingRPCState::Tag::OnCompleted(bool ok) {
  switch (type_) {
    case TagType::kCallStarted:
      streaming_state_->CallStarted(ok);
      break;
    case TagType::kRequestWriteCompleted:
      streaming_state_->RequestWriteCompleted(ok);
      break;
    case TagType::kResponseReadCompleted:
      streaming_state_->ResponseReadCompleted(ok);
      break;
  }
  streaming_state_->Unref();  // Ref acquired when tag was handed to grpc.
}

void Exchange::Complete(Status status) {
  if (status.ok()) {
    if (!GrpcMaybeParseProto(&response_buf_, response_)) {
      status.Update(errors::Internal("could not parse rpc response"));
    }
  }
  cb_(status);
}

std::ostream& operator<<(std::ostream& os, const Exchange::State& state) {
  os << ToString(state);
  return os;
}

const char* ToString(Exchange::State state) {
  switch (state) {
    case Exchange::State::kExchangeCreated:
      return "ExchangeCreated";
    case Exchange::State::kRequestWriteIssued:
      return "RequestWriteIssued";
    case Exchange::State::kRequestWriteCompleted:
      return "RequestWriteCompleted";
    case Exchange::State::kResponseReadIssued:
      return "ResponseReadIssued";
  }
}

string Exchange::DebugString() const {
  return absl::StrFormat("%p@%s", this, ToString(state_));
}

void ExchangeQueue::Emplace(const ::grpc::ByteBuffer& request_buf,
                            protobuf::Message* response, StatusCallback cb) {
  exchanges_.emplace(exchanges_.end(), request_buf, response, std::move(cb));
}

Exchange* ExchangeQueue::GetReadyForRequestWriting() {
  CheckInvariants();
  if (!call_started_) {
    return nullptr;
  }

  // TODO(iga): Optimize to avoid linear search.
  for (Exchange& e : exchanges_) {
    if (e.state() == Exchange::State::kExchangeCreated) {
      return &e;
    } else if (e.state() == Exchange::State::kRequestWriteIssued) {
      return nullptr;
    }
  }
  return nullptr;
}

Exchange* ExchangeQueue::GetReadyForResponseReading() {
  CheckInvariants();
  if (!call_started_) {
    // We should never ask for response reading when the call has not
    // been started, but it does not hurt to defensively check here anyway.
    return nullptr;
  }
  if (exchanges_.empty()) {
    return nullptr;
  }
  Exchange& e = exchanges_[0];
  if (e.state() == Exchange::State::kRequestWriteCompleted) {
    return &e;
  }
  return nullptr;
}

void ExchangeQueue::MarkRequestWriteCompleted() {
  CheckInvariants();
  // TODO(iga): Optimize to avoid linear search.
  for (Exchange& e : exchanges_) {
    if (e.state() == Exchange::State::kRequestWriteIssued) {
      e.MarkRequestWriteCompleted();
    }
  }
  CheckInvariants();
}

Exchange& ExchangeQueue::GetFront() {
  CheckInvariants();
  return exchanges_.front();
}

void ExchangeQueue::PopFront() {
  CheckInvariants();
  exchanges_.pop_front();
}

string ExchangeQueue::DebugString() const {
  return absl::StrJoin(exchanges_, ", ", [](string* out, const Exchange& e) {
    out->append(e.DebugString());
  });
}

void ExchangeQueue::Swap(ExchangeQueue* other) {
  exchanges_.swap(other->exchanges_);
  std::swap(call_started_, other->call_started_);
}

void ExchangeQueue::CompleteAll(Status status) {
  for (Exchange& exchange : exchanges_) {
    exchange.Complete(status);
  }
}

namespace {
std::set<std::pair<Exchange::State, Exchange::State>>*
GetPossibleTransitions() {
  std::set<std::pair<Exchange::State, Exchange::State>>* s =
      new std::set<std::pair<Exchange::State, Exchange::State>>();
  // Regular state transitions
  s->emplace(Exchange::State::kExchangeCreated,
             Exchange::State::kRequestWriteIssued);
  s->emplace(Exchange::State::kRequestWriteIssued,
             Exchange::State::kRequestWriteCompleted);
  s->emplace(Exchange::State::kRequestWriteCompleted,
             Exchange::State::kResponseReadIssued);
  // Self transitions. Possible when several exchanges can be in
  // the same state.
  s->emplace(Exchange::State::kExchangeCreated,
             Exchange::State::kExchangeCreated);
  s->emplace(Exchange::State::kRequestWriteCompleted,
             Exchange::State::kRequestWriteCompleted);
  // Skip transitions. Possible when there are no exchanges in a
  // certain state.
  s->emplace(Exchange::State::kExchangeCreated,
             Exchange::State::kRequestWriteCompleted);
  s->emplace(Exchange::State::kExchangeCreated,
             Exchange::State::kResponseReadIssued);
  s->emplace(Exchange::State::kRequestWriteIssued,
             Exchange::State::kResponseReadIssued);
  return s;
}
}  // namespace

void ExchangeQueue::CheckInvariants() {
  static std::set<std::pair<Exchange::State, Exchange::State>>*
      possible_transitions = GetPossibleTransitions();

  if (!VLOG_IS_ON(5)) {
    return;
  }

  for (int i = 1; i < exchanges_.size(); ++i) {
    const Exchange& e0 = exchanges_[i - 1];
    const Exchange& e1 = exchanges_[i];
    // The first exchange in the pair is the one that arrived later and is
    // behind in processing.
    auto p = std::make_pair(e1.state(), e0.state());
    if (possible_transitions->find(p) == possible_transitions->end()) {
      LOG(FATAL)
          << "Found an impossible state transition in the exchange queue: "
          << p.first << " -> " << p.second;
    }
  }
}

}  // namespace tensorflow
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_

#include <queue>
#include <utility>

#include "grpcpp/generic/generic_stub.h"
@@ -24,9 +25,11 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"

namespace tensorflow {
@@ -178,6 +181,423 @@ class RPCState : public GrpcClientCQTag {
  bool fail_fast_;
};

// Represents state associated with one streaming RPC call.
// Similarly to above, we extract the methods of StreamingRPCState that don't
// need to be templated into this abstract class.
// Currently, *StreamingRPCState does not support the client closing the call,
// as there is no use case for it: current clients keep the streaming call
// open as long as possible. If/when the need arises, support can be added by
// calling GenericClientAsyncReaderWriter::WritesDone with a new tag
// TagType::kClientFinished and handling the completion in a new callback.
class UntypedStreamingRPCState : public core::RefCounted {
 public:
  virtual void CallStarted(bool ok) = 0;
  virtual void RequestWriteCompleted(bool ok) = 0;
  virtual void ResponseReadCompleted(bool ok) = 0;

  virtual string DebugString() const = 0;

  class Tag : public GrpcClientCQTag {
   public:
    // One enum value per supported callback.
    enum class TagType {
      kCallStarted,
      kRequestWriteCompleted,
      kResponseReadCompleted,
    };

    Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);

    // Calls the callback associated with this tag and Unrefs
    // `this->streaming_state_`.
    void OnCompleted(bool ok) override;

   private:
    // OnCompleted() consumes one reference each time it is called.
    UntypedStreamingRPCState* const streaming_state_;
    const Tag::TagType type_;
  };
};

const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type);

// Represents a single request/response exchange between the client and the
// server. A single streaming call contains a sequence of exchanges. Besides
// the messages, an exchange contains:
// - the user callback to invoke when the exchange completes (the response is
//   received or an error occurs).
// - the current state of the exchange.
class Exchange {
 public:
  enum class State {
    kExchangeCreated,
    kRequestWriteIssued,
    kRequestWriteCompleted,
    kResponseReadIssued,
  };

  Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response,
           StatusCallback cb)
      : state_(State::kExchangeCreated),
        request_buf_(request_buf),
        response_(response),
        cb_(std::move(cb)) {}

  const ::grpc::ByteBuffer& request_buf() { return request_buf_; }
  ::grpc::ByteBuffer* response_buf() { return &response_buf_; }

  void MarkRequestWriteIssued() {
    DCHECK(state_ == State::kExchangeCreated);
    state_ = State::kRequestWriteIssued;
  }
  void MarkRequestWriteCompleted() {
    DCHECK(state_ == State::kRequestWriteIssued);
    state_ = State::kRequestWriteCompleted;
  }
  void MarkResponseReadIssued() {
    DCHECK(state_ == State::kRequestWriteCompleted);
    state_ = State::kResponseReadIssued;
  }

  // If `status` is success, completes this exchange by parsing the
  // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the
  // callback with `status`.
  void Complete(Status status);

  const State& state() const { return state_; }

  string DebugString() const;

 private:
  State state_;
  ::grpc::ByteBuffer request_buf_;
  ::grpc::ByteBuffer response_buf_;
  protobuf::Message* response_;
  StatusCallback cb_;
};

const char* ToString(Exchange::State s);

std::ostream& operator<<(std::ostream& os, const Exchange::State& state);

// Represents a queue of exchanges.
// When a client sends a new request, a new exchange is created and added to
// the end of the queue. Completed exchanges are popped from the front of the
// queue. An explicit exchange queue is needed to bridge the client, which can
// send new requests at any time, with the gRPC infrastructure, which can
// handle a single read and a single write request at a time.
//
// As the exchange progresses (request sending initiated, request sending
// completed, response reading initiated) the queue helps to make sure that the
// right operation is issued on the right exchange at the right time.
//
// To satisfy gRPC constraints, the states of exchanges must be as follows,
// starting from the front of the queue:
// - 0 or 1 exchange in kResponseReadIssued state
// - 0 or more exchanges in kRequestWriteCompleted state
// - 0 or 1 exchange in kRequestWriteIssued state
// - 0 or more exchanges in kExchangeCreated state
//
// Thread-compatible.
class ExchangeQueue {
 public:
  // Creates a new exchange and adds it to the end of the queue.
  void Emplace(const ::grpc::ByteBuffer& request_buf,
               protobuf::Message* response, StatusCallback cb);

  // Returns an exchange for which we can initiate request writing, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForRequestWriting();

  // Returns an exchange for which we can initiate response reading, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForResponseReading();

  // Changes the state of the exchange that is currently in
  // kRequestWriteIssued state to kRequestWriteCompleted state.
  // REQUIRES: There is an exchange in kRequestWriteIssued state.
  void MarkRequestWriteCompleted();

  // Returns the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  Exchange& GetFront();

  // Removes the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  void PopFront();

  // Returns a string containing addresses and states of all exchanges in this
  // queue.
  string DebugString() const;

  // Swaps the contents of this and `other`.
  void Swap(ExchangeQueue* other);

  // Completes all exchanges in this with `status`.
  void CompleteAll(Status status);

  void CallStarted() { call_started_ = true; }

 private:
  // Checks that this ExchangeQueue is in a valid state; kills the process if
  // not. Does nothing unless VLOG(5) is on.
  void CheckInvariants();

  // We can't process any exchanges until the call has started.
  bool call_started_ = false;

  // std::queue is based on std::deque by default. std::deque provides
  // fairly strong iterator stability.
  std::deque<Exchange> exchanges_;
};
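To make the queue discipline concrete, the sketch below walks one exchange through the states listed above. The driver loop, buffers, and the use of EnqueueResponse are illustrative, not part of this change; only the ExchangeQueue/Exchange methods shown are real, and the snippet assumes it is built inside the TensorFlow tree:

    ExchangeQueue q;
    q.CallStarted();  // nothing can be written before the call starts

    ::grpc::ByteBuffer request_buf;  // assume a serialized request
    EnqueueResponse response;        // any protobuf::Message works here
    q.Emplace(request_buf, &response, [](const Status& s) { /* user cb */ });

    // kExchangeCreated -> kRequestWriteIssued: at most one write in flight.
    if (Exchange* e = q.GetReadyForRequestWriting()) {
      e->MarkRequestWriteIssued();
      // ... call_->Write(e->request_buf(), write_tag) would go here ...
    }
    // When the write tag completes:
    q.MarkRequestWriteCompleted();

    // kRequestWriteCompleted -> kResponseReadIssued: only the front exchange.
    if (Exchange* e = q.GetReadyForResponseReading()) {
      e->MarkResponseReadIssued();
      // ... call_->Read(e->response_buf(), read_tag) would go here ...
    }
    // When the read tag completes: parse, invoke the callback, pop.
    q.GetFront().Complete(Status::OK());
    q.PopFront();

This is exactly the sequencing that StreamingRPCState, below, implements with real gRPC tags.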

// Represents state associated with one streaming RPC call.
// Thread-safe.
template <class Response>
class StreamingRPCState : public UntypedStreamingRPCState {
 public:
  // The default behavior is to set fail_fast = false and handle timeouts
  // manually.
  StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
                    const std::shared_ptr<::grpc::ClientContext>& context)
      : context_(context), call_(std::move(call)), call_done_(false) {
    Ref();
    call_->StartCall(&call_started_tag_);
  }

  // Attempts to send the next request. `done` is invoked when
  // `response` has been filled with the data from the server, or if there
  // is an error. `done` can be invoked before SendNextRequest returns.
  // Returns `true` if the call is alive and the `done` callback has been or
  // will be invoked. If the call is dead, returns `false`; the `done`
  // callback will not be invoked in that case.
  bool SendNextRequest(const protobuf::Message& request, Response* response,
                       const StatusCallback& done) {
    ::grpc::ByteBuffer request_buf;
    ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf);
    if (!s.ok()) {
      Status status = FromGrpcStatus(s);
      LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
                 << status.ToString();
      done(status);
      return true;
    }

    mutex_lock l(mu_);
    if (call_done_) {
      // We intentionally do not invoke `done` here; see the comment above.
      return false;
    }
    exchanges_.Emplace(request_buf, response, done);
    MaybeIssueRequestWriteLocked();
    return true;
  }

  void CallStarted(bool ok) override {
    mutex_lock l(mu_);
    if (!ok) {
      call_done_ = true;
      return;
    }
    exchanges_.CallStarted();
    // Now that the call has started, we can write our first request, if any.
    MaybeIssueRequestWriteLocked();
  }

  void RequestWriteCompleted(bool ok) override {
    mu_.lock();
    if (call_done_) {
      mu_.unlock();
      return;
    }
    if (!ok) {
      // Unlocks mu_.
      MarkDoneAndCompleteExchanges();
      return;
    }

    exchanges_.MarkRequestWriteCompleted();
    MaybeIssueResponseReadLocked();
    MaybeIssueRequestWriteLocked();
    mu_.unlock();
  }

  void ResponseReadCompleted(bool ok) override {
    mu_.lock();
    if (call_done_) {
      mu_.unlock();
      return;
    }
    if (!ok) {
      // Unlocks mu_.
      MarkDoneAndCompleteExchanges();
      return;
    }

    // Complete the exchange without holding the lock because the user's
    // callback can call back into this RPC code, resulting in a deadlock.
    // No other thread can pop this exchange while we release the lock because
    // this is the only method that pops exchanges and it is called from a
    // single thread that waits on completion queue events.
    Exchange* e = &exchanges_.GetFront();
    mu_.unlock();

    e->Complete(Status::OK());

    {
      mutex_lock l(mu_);
      exchanges_.PopFront();
      MaybeIssueResponseReadLocked();
    }
  }

  string DebugString() const override {
    mutex_lock l(mu_);
    return exchanges_.DebugString();
  }

 private:
  void MarkDoneAndCompleteExchanges() EXCLUSIVE_LOCKS_REQUIRED(mu_)
      UNLOCK_FUNCTION(mu_) {
    call_done_ = true;
    Status status = errors::Unknown("gRPC streaming call has ended: ",
                                    context_->debug_error_string());
    // Swap exchanges_ into a temporary ExchangeQueue so that we can
    // complete all exchanges without holding mu_ in case user callbacks
    // reach back into this. This should be impossible now, but it is safer
    // for the future.
    ExchangeQueue queue;
    exchanges_.Swap(&queue);
    mu_.unlock();
    queue.CompleteAll(status);
  }

  void MaybeIssueRequestWriteLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Exchange* exchange = exchanges_.GetReadyForRequestWriting();
    if (exchange == nullptr) {
      // There are no queued exchanges, there is already an outstanding write,
      // or there are no newly created exchanges.
      return;
    }
    exchange->MarkRequestWriteIssued();
    Ref();
    call_->Write(exchange->request_buf(), &request_write_completed_tag_);
  }

  void MaybeIssueResponseReadLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Exchange* exchange = exchanges_.GetReadyForResponseReading();
    if (exchange == nullptr) {
      return;
    }
    exchange->MarkResponseReadIssued();
    Ref();
    call_->Read(exchange->response_buf(), &response_read_completed_tag_);
  }

  // Used as void* completion markers from grpc to indicate different
  // events of interest for a StreamingRPCState.
  typedef typename UntypedStreamingRPCState::Tag Tag;

  // The order of context_ and call_ is important because context_ must
  // outlive call_.
  const std::shared_ptr<const ::grpc::ClientContext> context_;
  std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;

  mutable mutex mu_;
  ExchangeQueue exchanges_ GUARDED_BY(mu_);
  bool call_done_ GUARDED_BY(mu_);

  // We can get away with having single instances of these tags per
  // StreamingRPCState because we make sure (as gRPC requires) that
  // there is at most one outstanding Read and at most one outstanding Write
  // in the completion queue.
  // Tags are immutable. No need to guard them.
  Tag call_started_tag_{this, Tag::TagType::kCallStarted};
  Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
  Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted};
};

// Creates streaming calls and dispatches requests to them.
// In the common case, the client would create a StreamingRPCDispatcher for
// each bidirectional streaming RPC it might want to make. The first time it
// calls SendNextRequest, a streaming call is initiated and the request is
// sent within this call. Initiation of the call blocks the client. If there
// are no errors, subsequent calls to SendNextRequest use the already active
// call. If there was an error, the call object will be destroyed after all
// the callbacks for outstanding requests have been invoked. The next call to
// SendNextRequest will initiate a new call.
//
// Callbacks that are part of the same call are invoked in the order they were
// provided, but callbacks across calls (a failed one and a new one) can be
// invoked in any order.
//
// Thread-safe.
template <class Response>
class StreamingRPCDispatcher {
 public:
  StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
                         const ::grpc::string& method)
      : stub_(stub), cq_(cq), method_(method) {}

  // Attempts to send the next request. If there is no active streaming call,
  // starts one and sends the request on top of it. `done` is invoked when
  // `response` has been filled with the data from the server, or if there
  // is an error. `done` can be invoked before SendNextRequest returns.
  void SendNextRequest(const protobuf::Message& request, Response* response,
                       StatusCallback done) {
    mutex_lock l(mu_);
    if (state_ == nullptr) {
      CreateStreamingState();
    }

    bool is_call_alive = state_->SendNextRequest(request, response, done);
    if (is_call_alive) {
      return;
    }

    // The attempt to send failed because the call was dead; create a new
    // call and try again. When the call is dead, SendNextRequest does not
    // call `done`.
    CreateStreamingState();

    is_call_alive = state_->SendNextRequest(request, response, done);
    if (!is_call_alive) {
      // Consider retrying to create and start a call a few more times.
      done(errors::Unknown("gRPC call failed right after it was created"));
    }
  }

 private:
  void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // ClientContext cannot be reused across calls.
    context_ = std::make_shared<::grpc::ClientContext>();
    // Don't immediately fail StartCall if the channel is not ready. Wait for
    // the channel to become ready.
    context_->set_wait_for_ready(true);

    std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call =
        stub_->PrepareCall(context_.get(), method_, cq_);

    state_.reset(new StreamingRPCState<Response>(std::move(call), context_));
  }

  mutable mutex mu_;

  // Both are thread-safe.
  ::grpc::GenericStub* const stub_;
  ::grpc::CompletionQueue* const cq_;

  // Does not need synchronization since it is constant.
  const ::grpc::string method_;

  std::shared_ptr<::grpc::ClientContext> context_ GUARDED_BY(mu_);
  core::RefCountPtr<StreamingRPCState<Response>> state_ GUARDED_BY(mu_);
};
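The GrpcEagerClient change earlier in this diff is the first user of this class; schematically, usage looks like the following (the channel and the thread polling the completion queue are assumed to exist; the queue must be drained by a thread that casts each tag to GrpcClientCQTag and calls OnCompleted on it):

    ::grpc::GenericStub stub(channel);  // `channel` is a pre-built gRPC channel
    ::grpc::CompletionQueue cq;         // polled by a separate driver thread
    StreamingRPCDispatcher<EnqueueResponse> dispatcher(
        &stub, &cq, "/tensorflow.eager.EagerService/StreamingEnqueue");

    EnqueueRequest request;
    EnqueueResponse response;
    dispatcher.SendNextRequest(request, &response, [](const Status& s) {
      // Callbacks within one call arrive in request order; if the call
      // dies, the next SendNextRequest transparently starts a new one.
    });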

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_

@@ -176,6 +176,14 @@ service EagerService {
  // future calls to Enqueue.
  rpc Enqueue(EnqueueRequest) returns (EnqueueResponse);

  // A streaming version of Enqueue.
  // The current server implementation sends one response per received request.
  // The benefit of using a streaming version is that subsequent requests can
  // be sent without waiting for a response to the previous request. This
  // synchronization is required in the regular Enqueue call because gRPC does
  // not guarantee to preserve request order.
  rpc StreamingEnqueue(stream EnqueueRequest) returns (stream EnqueueResponse);

  // Takes a set of op IDs and waits until those ops are done. Returns any
  // error in the stream so far.
  rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);