diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
index 869fe1496ea..7bfe34b0c95 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc
@@ -27,7 +27,9 @@ namespace eager {
 
 GrpcEagerServiceImpl::GrpcEagerServiceImpl(
     const WorkerEnv* env, ::grpc::ServerBuilder* server_builder)
-    : env_(env), local_impl_(env) {
+    : env_(env),
+      local_impl_(env),
+      enqueue_streaming_thread_(env_->env, "enqueue_streaming_thread", 1) {
   server_builder->RegisterService(&service_);
   cq_ = server_builder->AddCompletionQueue();
 }
diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
index 0d979cd99cd..ae9477049ab 100644
--- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h
@@ -77,38 +77,42 @@ class GrpcEagerServiceImpl : public AsyncServiceInterface {
   // call.
   // StreamingEnqueueHandler gets the request from the `call` and fills the
   // response (also found in `call`) by invoking the local EagerServiceImpl.
-  // The local EagerServiceImpl is invoked in this thread instead of using a
-  // thread-pool as is done for all other methods above. We do this to preserve
-  // request order. The local service can parallelize based on context_id in
-  // request if necessary. Remote contexts are created in async mode by default,
-  // so the local service impl just puts the request on eager executor queue.
+  // The local EagerServiceImpl is invoked in a single-threaded thread pool. We
+  // do this to preserve request order. The local service can parallelize based
+  // on context_id in request if necessary. Remote contexts are created in async
+  // mode by default, so the local service impl just puts the request on eager
+  // executor queue.
   void StreamingEnqueueHandler(
       StreamingCall* call) {
-    // NOTE(fishx): Use the address of StreamingCall as the stream_id since we
-    // reuse the same StreamingCall for multiple requests in the same streaming
-    // connection.
-    Status status =
-        local_impl_.Enqueue(&call->request(), call->mutable_response(),
-                            reinterpret_cast<uint64>(static_cast<void*>(call)));
+    enqueue_streaming_thread_.Schedule([this, call]() {
+      // NOTE(fishx): Use the address of StreamingCall as the stream_id since we
+      // reuse the same StreamingCall for multiple requests in the same
+      // streaming connection.
+      Status status = local_impl_.Enqueue(
+          &call->request(), call->mutable_response(),
+          reinterpret_cast<uint64>(static_cast<void*>(call)));
 
-    if (status.ok()) {
-      VLOG(1) << "local_impl_.Enqueue completed successfully";
-      call->SendResponse();
-    } else {
-      VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
-              << " on request " << call->request().DebugString();
-      call->Finish(ToGrpcStatus(status));
-    }
+      if (status.ok()) {
+        VLOG(1) << "local_impl_.Enqueue completed successfully";
+        call->SendResponse();
+      } else {
+        VLOG(1) << "local_impl_.Enqueue failed with " << status.ToString()
+                << " on request " << call->request().DebugString();
+        call->Finish(ToGrpcStatus(status));
+      }
 
-    // We do not tell gRPC to accept a new StreamingEnqueue request because this
-    // method can be called multiple times for a given streaming call.
-    // The StreamingCall does this per call instead, after a call has been
-    // opened.
+      // We do not tell gRPC to accept a new StreamingEnqueue request because
+      // this method can be called multiple times for a given streaming call.
+      // The StreamingCall does this per call instead, after a call has been
+      // opened.
+    });
   }
 
   const WorkerEnv* const env_;  // Not owned.
   EagerServiceImpl local_impl_;
 
+  // A single-threaded thread pool to handle streaming enqueue rpc request.
+  thread::ThreadPool enqueue_streaming_thread_;
   std::unique_ptr<::grpc::Alarm> shutdown_alarm_;
 
   std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
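For context on the ordering guarantee the patch relies on: a `thread::ThreadPool` constructed with one worker thread drains its queue serially, so closures run in the order they were scheduled. The sketch below is not part of the patch; it is a minimal illustration that assumes the public `tensorflow::thread::ThreadPool` API from `tensorflow/core/lib/core/threadpool.h` (an `Env*`/name/thread-count constructor and `Schedule`) and that the pool destructor waits for scheduled work to finish, as documented in that header. The names `single_worker` and `completed` and the loop bound are made up for the example.

```cpp
#include <vector>

#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

int main() {
  tensorflow::mutex mu;
  std::vector<int> completed;  // Records the order in which closures finish.

  {
    // Same construction pattern as enqueue_streaming_thread_ in the diff:
    // one worker thread, so scheduled closures execute strictly in FIFO order.
    tensorflow::thread::ThreadPool single_worker(
        tensorflow::Env::Default(), "enqueue_streaming_thread",
        /*num_threads=*/1);

    for (int i = 0; i < 8; ++i) {
      single_worker.Schedule([i, &mu, &completed]() {
        tensorflow::mutex_lock l(mu);
        completed.push_back(i);  // With one worker this is always 0, 1, ..., 7.
      });
    }
    // Leaving this scope destroys the pool, which blocks until all scheduled
    // closures have run.
  }

  for (int order : completed) LOG(INFO) << "finished task " << order;
  return 0;
}
```

This mirrors the reasoning in the updated comment: ordering comes from the single-threaded queue, while any parallelism is left to the eager executor, which can key off the context_id in each request.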