Add support for bidirectional streaming RPCs to TF gRPC infrastructure

This change:
 - Adds an EagerService::StreamingEnqueue RPC method that takes/returns the same
   protos as EagerService::Enqueue, but is a bidirectional streaming method. The
   server returns one response message for each request message. Most of the
   code depends on this assumption.
 - Adds low-level gRPC logic to handle streaming RPCs on top of
   grpc::CompletionQueue. gRPC has a nicer callback based async API, but it is
   experimental and we would need to change a fair amount of existing
   CompletionQueue based code. These low-level classes are fairly generic, but
   hard-code some options that other users might need to change, if/when there
   are other users.

This CL does not include changes to eager runtime to actually use the streaming
RPCs. They are in the next CL.

PiperOrigin-RevId: 253327204
This commit is contained in:
Igor Ganichev 2019-06-14 17:34:15 -07:00 committed by TensorFlower Gardener
parent 19c72dd784
commit 1ae09760d1
10 changed files with 988 additions and 12 deletions

View File

@ -42,6 +42,21 @@ class EagerClient {
CLIENT_METHOD(SendTensor);
#undef CLIENT_METHOD
// Feeds `request` into the request stream of EagerService::StreamingEnqueue.
// `response` will be filled with the response for this `request`. The
// 1-to-1 correspondence between requests and responses is a property
// of the current service implementation. When the response is received,
// `done` is invoked with the current status of the StreamingEnqueue call.
// The status can contain an error because of an earlier request in the
// current streaming call.
// The client initiates a streaming call the first time StreamingEnqueueAsync
// is invoked and keeps it open until some error condition.
// Similarly to the methods above, the request can be deleted as soon as
// StreamingEnqueueAsync returns.
virtual void StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) = 0;
};
// Simple wrapper class that can be used to retrieve EagerClients.

View File

@ -64,15 +64,17 @@ cc_library(
cc_library(
name = "grpc_state",
srcs = [],
srcs = ["grpc_state.cc"],
hdrs = ["grpc_state.h"],
deps = [
":grpc_client_cq_tag",
":grpc_util",
"//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:tensor_coding",
"@com_google_absl//absl/strings:str_format",
],
)

View File

@ -47,15 +47,49 @@ class GrpcEagerClient : public EagerClient {
CLIENT_METHOD(Enqueue);
CLIENT_METHOD(WaitQueueDone);
CLIENT_METHOD(KeepAlive);
CLIENT_METHOD(CloseContext);
CLIENT_METHOD(RegisterFunction);
CLIENT_METHOD(SendTensor);
#undef CLIENT_METHOD
void CloseContextAsync(const CloseContextRequest* request,
CloseContextResponse* response,
StatusCallback done) override {
new RPCState<protobuf::Message>(
&stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
response, std::move(done), nullptr, nullptr);
if (enqueue_dispatchers_.find(request->context_id()) !=
enqueue_dispatchers_.end()) {
enqueue_dispatchers_.erase(request->context_id());
} else {
LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
<< " does not seems to exist.";
}
}
void StreamingEnqueueAsync(const EnqueueRequest* request,
EnqueueResponse* response,
StatusCallback done) override {
auto it = enqueue_dispatchers_.find(request->context_id());
if (enqueue_dispatchers_.find(request->context_id()) ==
enqueue_dispatchers_.end()) {
auto it_and_bool = enqueue_dispatchers_.emplace(
std::piecewise_construct,
std::forward_as_tuple(request->context_id()),
std::forward_as_tuple(
&stub_, cq_, "/tensorflow.eager.EagerService/StreamingEnqueue"));
it = it_and_bool.first;
}
it->second.SendNextRequest(*request, response, std::move(done));
}
private:
::grpc::GenericStub stub_;
::grpc::CompletionQueue* cq_;
std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
enqueue_dispatchers_;
};
class GrpcEagerClientCache : public EagerClientCache {

View File

@ -50,6 +50,14 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
ENQUEUE_REQUEST(SendTensor);
#undef ENQUEUE_REQUEST
// Request a StreamingEnqueue call.
ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
grpc::EagerService::AsyncService,
EnqueueRequest, EnqueueResponse>::
EnqueueRequest(&service_, cq_.get(),
&grpc::EagerService::AsyncService::RequestStreamingEnqueue,
&GrpcEagerServiceImpl::StreamingEnqueueHandler);
void* tag; // Matches the operation started against this cq_.
bool ok;
@ -58,8 +66,8 @@ void GrpcEagerServiceImpl::HandleRPCsLoop() {
// The queue is shutting down.
break;
}
UntypedCall<GrpcEagerServiceImpl>::Tag* callback_tag =
static_cast<UntypedCall<GrpcEagerServiceImpl>::Tag*>(tag);
GrpcCallTag<GrpcEagerServiceImpl>* callback_tag =
static_cast<GrpcCallTag<GrpcEagerServiceImpl>*>(tag);
if (callback_tag) {
callback_tag->OnCompleted(this, ok);

View File

@ -34,6 +34,11 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
template <class RequestMessage, class ResponseMessage>
using EagerCall = Call<GrpcEagerServiceImpl, grpc::EagerService::AsyncService,
RequestMessage, ResponseMessage>;
template <class RequestMessage, class ResponseMessage>
using StreamingCall =
ServerBidirectionalStreamingCall<GrpcEagerServiceImpl,
grpc::EagerService::AsyncService,
RequestMessage, ResponseMessage>;
GrpcEagerServiceImpl(const WorkerEnv* env,
::grpc::ServerBuilder* server_builder);
@ -64,6 +69,35 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
HANDLER(SendTensor);
#undef HANDLER
// Called when a new request has been received as part of a StreamingEnqueue
// call.
// StreamingEnqueueHandler gets the request from the `call` and fills the
// response (also found in `call`) by invoking the local EagerServiceImpl.
// The local EagerServiceImpl is invoked in this thread instead of using a
// thread-pool as is done for all other methods above. We do this to preserve
// request order. The local service can parallelize based on context_id in
// request if necessary. Remote contexts are created in async mode by default,
// so the local service impl just puts the request on eager executor queue.
void StreamingEnqueueHandler(
StreamingCall<EnqueueRequest, EnqueueResponse>* call) {
Status status =
local_impl_.Enqueue(&call->request(), call->mutable_response());
if (status.ok()) {
VLOG(1) << "local_impl_.Enqueue completed successfully";
call->SendResponse();
} else {
VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
<< " on request " << call->request().DebugString();
call->Finish(ToGrpcStatus(status));
}
// We do not tell gRPC to accept a new StreamingEnqueue request because this
// method can be called multiple times for a given streaming call.
// The StreamingCall does this per call instead, after a call has been
// opened.
}
const WorkerEnv* const env_; // Not owned.
EagerServiceImpl local_impl_;

View File

@ -16,13 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "grpcpp/grpcpp.h"
#include "grpcpp/impl/codegen/service_type.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
@ -70,6 +69,16 @@ namespace tensorflow {
//
// 4. When the response has been sent, the tag is returned from
// `cq_->Next()`, and the call object is deleted.
//
template <class Service>
class GrpcCallTag {
public:
virtual ~GrpcCallTag() {}
// Calls the callback associated with this tag.
virtual void OnCompleted(Service* service, bool ok) = 0;
};
// Represents a pending request with unknown message types.
template <class Service>
@ -98,7 +107,7 @@ class UntypedCall : public core::RefCounted {
// Associates a tag in a `::grpc::CompletionQueue` with a callback
// for an incoming RPC. An active Tag owns a reference on the corresponding
// Call object.
class Tag {
class Tag : public GrpcCallTag<Service> {
public:
// One enum value per supported callback.
enum Callback { kRequestReceived, kResponseSent, kCancelled };
@ -108,7 +117,7 @@ class UntypedCall : public core::RefCounted {
// Calls the callback associated with this tag.
//
// The callback takes ownership of `this->call_`.
void OnCompleted(Service* service, bool ok) {
void OnCompleted(Service* service, bool ok) override {
switch (callback_) {
case kRequestReceived:
call_->RequestReceived(service, ok);
@ -263,6 +272,242 @@ class Call : public UntypedCall<Service> {
std::function<void()> cancel_callback_ GUARDED_BY(mu_);
};
// Lifetime of a server-side bidirectional streaming call:
// - The call is created in the static EnqueueRequest method. It transfers
// ownership to the kCallOpen tag pushed onto the completion queue.
// - If kCallOpen completes successfully, a read is requested and the
// kRequestReceived tag takes ownership of the call. If kCallOpen fails,
// e.g. server is shutdown, no further requests are pushed and the call is
// destroyed (at the end of Tag::OnCompleted).
// - When the first request is received, we Ref() the call and invoke the
// handler method thereby transferring ownership to the handler method.
// The handler is responsible for calling SendResponse() or Finish() on this
// call.
// - If the handler calls Finish(), e.g. the request was invalid, Finish()
// transfers ownership from the handler to the kServerFinished tag that
// it pushes on the completion queue. The ownership is transferred because
// the ref count is not incremented before putting the tag on the queue.
// - If the handler calls SendResponse(), SendResponse() transfers ownership
// to the kResponseSent tag.
// - When kResponseSent completes, we request a new read, which owns the call
// now.
// - When the next request is received, it is handled the same way as the first
// request.
//
// Because we request a read only after the write is sent, we can safely reuse
// the same request and response messages for the whole call.
template <class Service>
class ServerUntypedBidirectionalStreamingCall : public core::RefCounted {
public:
virtual void RequestReceived(Service* service) = 0;
// Enqueues a request on the completion queue to read the next request.
virtual void CallOpen() = 0;
virtual void RequestRead() = 0;
// Associates a tag in a `::grpc::CompletionQueue` with a callback.
// An active Tag owns a reference on the corresponding Call object.
class Tag : public GrpcCallTag<Service> {
public:
// One enum value per supported callback.
enum class TagType {
kCallOpen,
kRequestReceived,
kResponseSent,
kServerFinished,
};
Tag(ServerUntypedBidirectionalStreamingCall* call, TagType cb)
: call_(call), callback_(cb) {}
// Calls the callback associated with this tag and Unrefs this->call_.
void OnCompleted(Service* service, bool ok) override {
switch (callback_) {
case TagType::kCallOpen:
// Non-ok value indicates that the server has been shutdown before we
// received a message for this call type. We do nothing to let this
// call object be destroyed and avoid enqueuing request for another
// call.
if (ok) {
call_->CallOpen();
}
break;
case TagType::kRequestReceived:
// Non-ok value from completion queue here means that we will not
// receive any more messages from the client, e.g. the client called
// WritesDone. There is nothing we need to do in this case. The call
// will be Unref'ed and deleted. If the client wants to open a new
// call, we have already enqueued a request for a new call in CallOpen
// above.
if (ok) {
call_->RequestReceived(service);
}
break;
case TagType::kResponseSent:
if (ok) {
// The obvious place to request a read would be at the end of
// RequestReceived(). Unfortunately, this can result in multiple
// outstanding write requests in the completion queue. This is
// currently not supported by gRPC, which requires at most one
// outstanding write request in the completion queue.
// Requesting a read here, in ResponseSent, works because at
// this point, the completion queue has no write requests
// (kResponseSent happens when a write completes).
// This might be synchronizing the processing more than strictly
// necessary, but is probably fine because, AFAICT from gRPC docs,
// the write request completes as soon as it can be written to
// outgoing buffer.
call_->RequestRead();
}
// ok == false means that the response is not going on the wire
// because the call is already dead (i.e., canceled, deadline
// expired, other side dropped the channel, etc). Since the call is
// dead, there is nothing for us to do, we just let the call be
// deleted.
break;
case TagType::kServerFinished:
// Whether our finish request is successful or not (whether it went
// on the wire towards the client), there is nothing for us to do.
// In the current implementation, there can be no read or write
// requests in the completion queue (see the comment in kResponseSent)
// above. Even if there were pending requests, they would complete
// with a non-ok status, we would not do anything, and let the call be
// deleted.
break;
}
call_->Unref(); // Ref acquired when tag was handed to grpc.
}
private:
ServerUntypedBidirectionalStreamingCall* const
call_; // `this` owns one reference.
TagType callback_;
};
};
// Represents a pending call with known request and response message
// types, and a known request-handling method.
// Common usage pattern is to have a single thread waiting on events from
// completion queue and calling Tag::OnCompleted(), which invokes methods
// on this.
// This implementation assumes that the server will generate a single response
// message for each request message. More precisely, this class expects that
// each time it invokes handle_request_function_, the service implementation
// will either call SendResponse or Finish exactly once.
// Not thread-safe.
template <class Service, class GrpcService, class RequestMessage,
class ResponseMessage>
class ServerBidirectionalStreamingCall
: public ServerUntypedBidirectionalStreamingCall<Service> {
public:
// Represents the generic signature of a generated
// `GrpcService::RequestFoo()` method, where `Foo` is the name of an
// RPC method.
using EnqueueFunction = void (GrpcService::*)(
::grpc::ServerContext*,
::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage>*,
::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);
// Represents the generic signature of a `Service::HandleFoo()`
// method, where `Foo` is the name of an RPC method.
using HandleRequestFunction = void (Service::*)(
ServerBidirectionalStreamingCall<Service, GrpcService, RequestMessage,
ResponseMessage>*);
ServerBidirectionalStreamingCall(
HandleRequestFunction handle_request_function, GrpcService* grpc_service,
::grpc::ServerCompletionQueue* cq, EnqueueFunction enqueue_function)
: handle_request_function_(handle_request_function),
stream_(&ctx_),
grpc_service_(grpc_service),
cq_(cq),
enqueue_function_(enqueue_function) {}
void CallOpen() override {
// Let gRPC know that we can accept another call.
ServerBidirectionalStreamingCall<
Service, GrpcService, RequestMessage,
ResponseMessage>::EnqueueRequest(grpc_service_, cq_, enqueue_function_,
handle_request_function_);
RequestRead();
}
void RequestRead() override {
this->Ref();
request_.Clear();
stream_.Read(&request_, &request_received_tag_);
}
void RequestReceived(Service* service) override {
this->Ref();
// Request handling should result in a call to SendResponse or Finish.
(service->*handle_request_function_)(this);
}
void SendResponse() {
// Transferring ownership of this to the response_sent_tag_.
stream_.Write(response_, &response_sent_tag_);
// stream_.Write does not save references to response_. We are free to muck
// around with it as soon as Write returns.
// We clear the response_ to prepare it for the next response.
response_.Clear();
}
void Finish(::grpc::Status status) {
// Transferring ownership of this to the server_finished_tag_.
stream_.Finish(status, &server_finished_tag_);
}
// Enqueues a new request for the given service on the given
// completion queue, using the given `enqueue_function`.
//
// The request will be handled by the given `handle_request_function`.
static void EnqueueRequest(GrpcService* grpc_service,
::grpc::ServerCompletionQueue* cq,
EnqueueFunction enqueue_function,
HandleRequestFunction handle_request_function) {
auto call =
new ServerBidirectionalStreamingCall<Service, GrpcService,
RequestMessage, ResponseMessage>(
handle_request_function, grpc_service, cq, enqueue_function);
// Initial ref for call handed to grpc; released in Tag callback.
(grpc_service->*enqueue_function)(&call->ctx_, &call->stream_, cq, cq,
&call->call_open_tag_);
}
const RequestMessage& request() const { return request_; }
ResponseMessage* mutable_response() { return &response_; }
private:
// Request and response messages are reused for each request/response exchange
// between the client and the server.
RequestMessage request_;
ResponseMessage response_;
::grpc::ServerContext ctx_;
HandleRequestFunction handle_request_function_;
::grpc::ServerAsyncReaderWriter<ResponseMessage, RequestMessage> stream_;
// Used as void* completion markers from grpc to indicate different
// events of interest for a ServerBidirectionalStreamingCall.
typedef typename ServerUntypedBidirectionalStreamingCall<Service>::Tag Tag;
// At most one tag of each kind may be given to gRPC at any one time.
// Beyond semantic sanity, this is needed to ensure proper ref counting
// of this call object.
Tag call_open_tag_{this, Tag::TagType::kCallOpen};
Tag request_received_tag_{this, Tag::TagType::kRequestReceived};
Tag response_sent_tag_{this, Tag::TagType::kResponseSent};
Tag server_finished_tag_{this, Tag::TagType::kServerFinished};
// These fields are used only to spawn another instance of this to accept
// more streaming calls.
GrpcService* grpc_service_;
::grpc::ServerCompletionQueue* cq_;
EnqueueFunction enqueue_function_;
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_

View File

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
@ -32,7 +31,7 @@ class GrpcClientCQTag {
virtual ~GrpcClientCQTag() {}
// OnCompleted is invoked when the RPC has finished.
// Implementations of OnCompleted must delete *this.
// Implementations of OnCompleted can delete *this.
virtual void OnCompleted(bool ok) = 0;
private:

View File

@ -0,0 +1,211 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "absl/strings/str_format.h"
namespace tensorflow {
const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type) {
switch (tag_type) {
case UntypedStreamingRPCState::Tag::TagType::kCallStarted:
return "kCallStarted";
case UntypedStreamingRPCState::Tag::TagType::kRequestWriteCompleted:
return "kRequestWriteCompleted";
case UntypedStreamingRPCState::Tag::TagType::kResponseReadCommpleted:
return "kResponseReadCommpleted";
}
}
UntypedStreamingRPCState::Tag::Tag(UntypedStreamingRPCState* streaming_state,
Tag::TagType type)
: streaming_state_(streaming_state), type_(type) {}
void UntypedStreamingRPCState::Tag::OnCompleted(bool ok) {
switch (type_) {
case TagType::kCallStarted:
streaming_state_->CallStarted(ok);
break;
case TagType::kRequestWriteCompleted:
streaming_state_->RequestWriteCompleted(ok);
break;
case TagType::kResponseReadCommpleted:
streaming_state_->ResponseReadCompleted(ok);
break;
}
streaming_state_->Unref(); // Ref acquired when tag was handed to grpc.
}
void Exchange::Complete(Status status) {
if (status.ok()) {
if (!GrpcMaybeParseProto(&response_buf_, response_)) {
status.Update(errors::Internal("could not parse rpc response"));
}
}
cb_(status);
}
std::ostream& operator<<(std::ostream& os, const Exchange::State& state) {
os << ToString(state);
return os;
}
const char* ToString(Exchange::State state) {
switch (state) {
case Exchange::State::kExchangeCreated:
return "ExchangeCreated";
case Exchange::State::kRequestWriteIssued:
return "RequestWriteIssued";
case Exchange::State::kRequestWriteCompleted:
return "RequestWriteCompleted";
case Exchange::State::kResponseReadIssued:
return "ResponseReadIssued";
}
}
string Exchange::DebugString() const {
return absl::StrFormat("%p@%s", this, ToString(state_));
}
void ExchangeQueue::Emplace(const ::grpc::ByteBuffer& request_buf,
protobuf::Message* response, StatusCallback cb) {
exchanges_.emplace(exchanges_.end(), request_buf, response, std::move(cb));
}
Exchange* ExchangeQueue::GetReadyForRequestWriting() {
CheckInvariants();
if (!call_started_) {
return nullptr;
}
// TODO(iga): Optimize to avoid linear search.
for (Exchange& e : exchanges_) {
if (e.state() == Exchange::State::kExchangeCreated) {
return &e;
} else if (e.state() == Exchange::State::kRequestWriteIssued) {
return nullptr;
}
}
return nullptr;
}
Exchange* ExchangeQueue::GetReadyForResponseReading() {
CheckInvariants();
if (!call_started_) {
// We should never ask for response reading when call has not
// been started, but it does not hurt to defensively check here anyway.
return nullptr;
}
if (exchanges_.empty()) {
return nullptr;
}
Exchange& e = exchanges_[0];
if (e.state() == Exchange::State::kRequestWriteCompleted) {
return &e;
}
return nullptr;
}
void ExchangeQueue::MarkRequestWriteCompleted() {
CheckInvariants();
// TODO(iga): Optimize to avoid linear search.
for (Exchange& e : exchanges_) {
if (e.state() == Exchange::State::kRequestWriteIssued) {
e.MarkRequestWriteCompleted();
}
}
CheckInvariants();
}
Exchange& ExchangeQueue::GetFront() {
CheckInvariants();
return exchanges_.front();
}
void ExchangeQueue::PopFront() {
CheckInvariants();
exchanges_.pop_front();
}
string ExchangeQueue::DebugString() const {
return absl::StrJoin(exchanges_, ", ", [](string* out, const Exchange& e) {
out->append(e.DebugString());
});
}
void ExchangeQueue::Swap(ExchangeQueue* other) {
exchanges_.swap(other->exchanges_);
std::swap(call_started_, other->call_started_);
}
void ExchangeQueue::CompleteAll(Status status) {
for (Exchange& exchange : exchanges_) {
exchange.Complete(status);
}
}
namespace {
std::set<std::pair<Exchange::State, Exchange::State>>*
GetPossibleTransitions() {
std::set<std::pair<Exchange::State, Exchange::State>>* s =
new std::set<std::pair<Exchange::State, Exchange::State>>();
// Regular state transitions
s->emplace(Exchange::State::kExchangeCreated,
Exchange::State::kRequestWriteIssued);
s->emplace(Exchange::State::kRequestWriteIssued,
Exchange::State::kRequestWriteCompleted);
s->emplace(Exchange::State::kRequestWriteCompleted,
Exchange::State::kResponseReadIssued);
// Self transitions. Possible when several exchanges can be in
// the same state.
s->emplace(Exchange::State::kExchangeCreated,
Exchange::State::kExchangeCreated);
s->emplace(Exchange::State::kRequestWriteCompleted,
Exchange::State::kRequestWriteCompleted);
// Skip transitions. Possible when there are no exchanges in a
// certain state.
s->emplace(Exchange::State::kExchangeCreated,
Exchange::State::kRequestWriteCompleted);
s->emplace(Exchange::State::kExchangeCreated,
Exchange::State::kResponseReadIssued);
s->emplace(Exchange::State::kRequestWriteIssued,
Exchange::State::kResponseReadIssued);
return s;
}
} // namespace
void ExchangeQueue::CheckInvariants() {
static std::set<std::pair<Exchange::State, Exchange::State>>*
possible_transitions = GetPossibleTransitions();
if (!VLOG_IS_ON(5)) {
return;
}
for (int i = 1; i < exchanges_.size(); ++i) {
const Exchange& e0 = exchanges_[i - 1];
const Exchange& e1 = exchanges_[i];
// The first exchange in the pair is the one that arrived later and is
// behind in processing.
auto p = std::make_pair(e1.state(), e0.state());
if (possible_transitions->find(p) == possible_transitions->end()) {
LOG(FATAL)
<< "Found an impossible state transition in the exchange queue: "
<< p.first << " -> " << p.second;
}
}
}
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
#include <queue>
#include <utility>
#include "grpcpp/generic/generic_stub.h"
@ -24,9 +25,11 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
namespace tensorflow {
@ -178,6 +181,423 @@ class RPCState : public GrpcClientCQTag {
bool fail_fast_;
};
// Represents state associated with one streaming RPC call.
// Similarly to above, we extract the methods of StreamingRPCState that don't
// need to be templated into this abstract class.
// Currently, *StreamingRPCState does not support client closing the call as
// there is no use case for it - current clients keep the streaming call open
// as long as possible. If/when the need arises, support can be added
// by calling GenericClientAsyncReaderWriter::WritesDone with a new tag
// TagType::kClientFinished and handling the completion in a new callback.
class UntypedStreamingRPCState : public core::RefCounted {
public:
virtual void CallStarted(bool ok) = 0;
virtual void RequestWriteCompleted(bool ok) = 0;
virtual void ResponseReadCompleted(bool ok) = 0;
virtual string DebugString() const = 0;
class Tag : public GrpcClientCQTag {
public:
// One enum value per supported callback.
enum class TagType {
kCallStarted,
kRequestWriteCompleted,
kResponseReadCommpleted,
};
Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);
// Calls the callback associated with this tag and Unrefs
// `this->streaming_state_`.
void OnCompleted(bool ok) override;
private:
// OnCompleted() consumes on reference each time it is called.
UntypedStreamingRPCState* const streaming_state_;
const Tag::TagType type_;
};
};
const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type);
// Represents a single request/response exchange between client and the server.
// A single streaming call contains a sequence of exchanges. Besides the
// messages, exchange contains:
// - the user callback to invoke when exchange completes (response is received
// or an error occurs).
// - The current state of the exchange.
class Exchange {
public:
enum class State {
kExchangeCreated,
kRequestWriteIssued,
kRequestWriteCompleted,
kResponseReadIssued,
};
Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response,
StatusCallback cb)
: state_(State::kExchangeCreated),
request_buf_(request_buf),
response_(response),
cb_(std::move(cb)) {}
const ::grpc::ByteBuffer& request_buf() { return request_buf_; }
::grpc::ByteBuffer* response_buf() { return &response_buf_; }
void MarkRequestWriteIssued() {
DCHECK(state_ == State::kExchangeCreated);
state_ = State::kRequestWriteIssued;
}
void MarkRequestWriteCompleted() {
DCHECK(state_ == State::kRequestWriteIssued);
state_ = State::kRequestWriteCompleted;
}
void MarkResponseReadIssued() {
DCHECK(state_ == State::kRequestWriteCompleted);
state_ = State::kResponseReadIssued;
}
// If `status` is success, completes this exchange by parsing the
// response_buf_ and invoking cb_ with Status::OK(). Else, invokes the
// callback with `status`.
void Complete(Status status);
const State& state() const { return state_; }
string DebugString() const;
private:
State state_;
::grpc::ByteBuffer request_buf_;
::grpc::ByteBuffer response_buf_;
protobuf::Message* response_;
StatusCallback cb_;
};
const char* ToString(Exchange::State s);
std::ostream& operator<<(std::ostream& os, const Exchange::State& state);
// Represents a queue of exchanges.
// When a client sends a new request a new exchange is created and added to the
// end of the queue. Completed exchanges are popped from the front of the queue.
// An explicit exchange queue is needed to brdige the client, which can send new
// requests at any time, with gRPC infrastructure, which can handle a single
// read and a single write request at a time.
//
// As the exchange progresses (request sending initiated, request sending
// completed, response reading initiated) the queue helps to make sure that the
// right operation is issued on the right exchange at the right time.
//
// To satisfy gRPC constraints, the states of exchanges must be as follows
// starting from the front of the queue:
// - 0 or 1 exchange in kResponseReadIssued state
// - 0 or more exchanges in kRequestWriteCompleted state
// - 0 or 1 exchange in kRequestWriteIssued state
// - 0 or more exchanges in kExchangeCreated state
//
// Thread-compatible.
class ExchangeQueue {
public:
// Creates a new exchange and adds it to the end of the queue.
void Emplace(const ::grpc::ByteBuffer& request_buf,
protobuf::Message* response, StatusCallback cb);
// Returns an exchange for which we can initiated request writing, if any.
// Returns nullptr if there is no such exchange.
Exchange* GetReadyForRequestWriting();
// Returns an exchange for which we can initiate response reading, if any.
// Returns nullptr if there is no such exchange.
Exchange* GetReadyForResponseReading();
// Changes the state of the exchange that is current in kRequestWriteIssued
// state to kRequestWriteCompleted state.
// REQUIRES: There is an exhange in kRequestWriteIssued state.
void MarkRequestWriteCompleted();
// Returns the exchange at the front of the queue.
// REQUIRES: ExchangeQueue is not empty.
Exchange& GetFront();
// Removes the exchange at the front of the queue.
// REQUIRES: ExchangeQueue is not empty.
void PopFront();
// Returns a string containing addresses and states of all exchanges in this
// queue.
string DebugString() const;
// Swaps the contents of this and `other`.
void Swap(ExchangeQueue* other);
// Completes all exchanges in this with `status`.
void CompleteAll(Status status);
void CallStarted() { call_started_ = true; }
private:
// Does nothing by default. Turn on VLOG(5) to enable.
// Checks that this ExchangeQueue is in a valid state.
// Kills the process if not.
void CheckInvariants();
// We can't process any exchanges until the call has started.
bool call_started_ = false;
// std::queue is based on std::deque by default. std::deque provides
// fairly strong iterator stability.
std::deque<Exchange> exchanges_;
}; // namespace tensorflow
// Represents state associated with one streaming RPC call.
// Thread-safe
template <class Response>
class StreamingRPCState : public UntypedStreamingRPCState {
public:
// Default behavior is to set fail_fast = False and handle timeouts
// manually.
StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
const std::shared_ptr<::grpc::ClientContext>& context)
: context_(context), call_(std::move(call)), call_done_(false) {
Ref();
call_->StartCall(&call_started_tag_);
}
// Attempts to send the next request. `done` is invoked when
// `response` has been filled with the data from the server, or if there
// is an error. `done` can be invoked before SendNextRequest returns.
// Return `true` if the call is alive and the `done` callback has or
// will be invoked. If the call is dead, returns `false`. `done` callback
// will not be invoked in this case.
// REQUIRES: The call has been started, i.e. WaitForCallStarted() has
// returned.
bool SendNextRequest(const protobuf::Message& request, Response* response,
const StatusCallback& done) {
::grpc::ByteBuffer request_buf;
::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf);
if (!s.ok()) {
Status status = FromGrpcStatus(s);
LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
<< status.ToString();
done(status);
return true;
}
mutex_lock l(mu_);
if (call_done_) {
// `done` is not invoked intentionally.
return false;
}
exchanges_.Emplace(request_buf, response, done);
MaybeIssueRequestWriteLocked();
return true;
}
void CallStarted(bool ok) override {
mutex_lock l(mu_);
if (!ok) {
call_done_ = true;
return;
}
exchanges_.CallStarted();
// Now that the call has started, we can write our first request, if any.
MaybeIssueRequestWriteLocked();
}
void RequestWriteCompleted(bool ok) override {
mu_.lock();
if (call_done_) {
mu_.unlock();
return;
}
if (!ok) {
// unlocks mu_
MarkDoneAndCompleteExchanges();
return;
}
exchanges_.MarkRequestWriteCompleted();
MaybeIssueResponseReadLocked();
MaybeIssueRequestWriteLocked();
mu_.unlock();
}
void ResponseReadCompleted(bool ok) override {
mu_.lock();
if (call_done_) {
mu_.unlock();
return;
}
if (!ok) {
// unlocks mu_
MarkDoneAndCompleteExchanges();
return;
}
// Complete the exchange without holding the lock because user's
// callback can call back into this RPC code resulting in a deadlock.
// No other thread can pop this exchange while we release the lock because
// this is the only method that pops exchanges and it is called from a
// single thread that waits on completion queue events.
Exchange* e;
e = &exchanges_.GetFront();
mu_.unlock();
e->Complete(Status::OK());
{
mutex_lock l(mu_);
exchanges_.PopFront();
MaybeIssueResponseReadLocked();
}
}
string DebugString() const override {
mutex_lock l(mu_);
return exchanges_.DebugString();
}
private:
void MarkDoneAndCompleteExchanges() EXCLUSIVE_LOCKS_REQUIRED(mu_)
UNLOCK_FUNCTION(mu_) {
call_done_ = true;
Status status = errors::Unknown("gRPC streaming call has ended: ",
context_->debug_error_string());
// Swap the exchanges_ into a temporary ExchangeQueue so that we can
// complete all exchanges without holding mu_ in case user callback
// reach back into this. This should be impossible now, but safer for
// the future.
ExchangeQueue queue;
exchanges_.Swap(&queue);
mu_.unlock();
queue.CompleteAll(status);
}
void MaybeIssueRequestWriteLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
Exchange* exchange = exchanges_.GetReadyForRequestWriting();
if (exchange == nullptr) {
// There are no queued exchanges, there is already an outstanding write,
// or there are no just created exchanges.
return;
}
exchange->MarkRequestWriteIssued();
Ref();
call_->Write(exchange->request_buf(), &request_write_completed_tag_);
}
void MaybeIssueResponseReadLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
Exchange* exchange = exchanges_.GetReadyForResponseReading();
if (exchange == nullptr) {
return;
}
exchange->MarkResponseReadIssued();
Ref();
call_->Read(exchange->response_buf(), &response_read_completed_tag_);
}
// Holds state for a single request/response exchange between the client
// and the server.
typedef typename UntypedStreamingRPCState::Tag Tag;
// Order of context_ and call_ is important because context_ must outlive
// call_.
const std::shared_ptr<const ::grpc::ClientContext> context_;
std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
mutable mutex mu_;
ExchangeQueue exchanges_ GUARDED_BY(mu_);
bool call_done_ GUARDED_BY(mu_);
// We can get away with having single instances of these tags per
// StreamingRPCState because we make sure (as gRPC requires) that
// there is at most one outstanding Read and at most one outstanding Write
// in the completion queue.
// Tags are immutable. No need to guard them.
Tag call_started_tag_{this, Tag::TagType::kCallStarted};
Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCommpleted};
};
// Creates streaming calls and dispatches requests to them.
// In the common case, the client would create a StreamingRPCDispatcher for
// each bidirectional streaming RPC it might want to make. The first time, it
// calls SendNextRequest, a streaming call is initiated and the request is
// sent within this call. Initiation of the call blocks the client. If there are
// no errors, subsequent calls to SendNextRequest would use the already active
// call. If there was an error, the call object will be destroyed after all
// the callbacks for outstanding requests have been invoked. The next call to
// SendNextRequest will initiate a new call.
//
// Callbacks that are part of the same call, are invoked in the order they were
// provided, but callbacks across calls (a failed and a new one) can be invoked
// in any order.
//
// Thread-safe.
template <class Response>
class StreamingRPCDispatcher {
public:
StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
const ::grpc::string& method)
: stub_(stub), cq_(cq), method_(method) {}
// Attempts to send the next request. If there is no active streaming call,
// starts one and sends the request on top of it. `done` is invoked when
// `response` has been filled with the data from the server, or if there
// is an error. `done` can be invoked before SendNextRequest returns.
void SendNextRequest(const protobuf::Message& request, Response* response,
StatusCallback done) {
mutex_lock l(mu_);
if (state_ == nullptr) {
CreateStreamingState();
}
bool is_call_alive = state_->SendNextRequest(request, response, done);
if (is_call_alive) {
return;
}
// The attempt to send failed because the call was dead, create a new
// call and try again. When the call is dead SendNextRequest does not call
// `done`.
CreateStreamingState();
is_call_alive = state_->SendNextRequest(request, response, done);
if (!is_call_alive) {
// Consider retrying to create and start a call few more times.
done(errors::Unknown("gRPC call failed right after it was created"));
}
}
private:
void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
// ClientContext cannot be reused across calls.
context_ = std::make_shared<::grpc::ClientContext>();
// Don't immediately fail StartCall if the channel is not ready. Wait for
// the channel to become ready.
context_->set_wait_for_ready(true);
std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call =
std::move(stub_->PrepareCall(context_.get(), method_, cq_));
state_.reset(new StreamingRPCState<Response>(std::move(call), context_));
}
mutable mutex mu_;
// Both are thread-safe
::grpc::GenericStub* const stub_;
::grpc::CompletionQueue* const cq_;
// Does not need synchronization since it is constant.
const ::grpc::string method_;
std::shared_ptr<::grpc::ClientContext> context_ GUARDED_BY(mu_);
core::RefCountPtr<StreamingRPCState<Response>> state_ GUARDED_BY(mu_);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_

View File

@ -176,6 +176,14 @@ service EagerService {
// future calls to Enqueue.
rpc Enqueue(EnqueueRequest) returns (EnqueueResponse);
// A streaming version of Enqueue.
// Current server implementation sends one response per received request.
// The benefit for using a streaming version is that subsequent requests
// can be sent without waiting for a response to the previous request. This
// synchronization is required in the regular Enqueue call because gRPC does
// not guarantee to preserve request order.
rpc StreamingEnqueue(stream EnqueueRequest) returns (stream EnqueueResponse);
// Takes a set of op IDs and waits until those ops are done. Returns any error
// in the stream so far.
rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);