Detect and report an error when a WorkerService.RunGraph request is delivered more than once.

Currently, we track request IDs for WorkerService.RecvTensor, because if a duplicate message is delivered for that method the system will deadlock. However, a duplicate WorkerService.RunGraph request can also trigger deadlock. Consider the following simple example:

1. Master M divides a simple graph across workers W_0 and W_1, where W_0 sends a tensor to W_1.
2. M calls W_0.RunGraph and W_1.RunGraph.
3. W_1 calls W_0.RecvTensor to receive the tensor, and receives a response.
4. W_0.RunGraph returns successfully.
5. M retries the call to W_1.RunGraph.
6. W_1 again calls W_0.RecvTensor to receive the tensor, and never receives a response because the sender has completed.

To work around this problem, we extend the duplicate request ID tracking to WorkerService.RunGraph.
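For intuition, here is a hypothetical, self-contained sketch (not the TensorFlow implementation; the class name, capacity, and eviction policy are assumptions) of the kind of bounded duplicate-detection the worker applies: the first delivery of a request id is accepted, and a redelivery is reported as an error instead of re-running the graph and hanging in RecvTensor.

// Hypothetical sketch of duplicate request-id detection; the real tracking
// lives in tensorflow/core/distributed_runtime/recent_request_ids, whose
// internals are not shown in this diff.
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

class BoundedRequestIdSet {
 public:
  explicit BoundedRequestIdSet(int capacity) : ring_(capacity, 0) {}

  // Returns true the first time an id is seen, false on a redelivery.
  // An id of zero disables detection, matching the RunGraphRequest contract.
  bool Insert(int64_t request_id) {
    if (request_id == 0) return true;
    if (!set_.insert(request_id).second) return false;  // duplicate delivery
    set_.erase(ring_[next_]);  // evict the oldest tracked id
    ring_[next_] = request_id;
    next_ = (next_ + 1) % ring_.size();
    return true;
  }

 private:
  std::vector<int64_t> ring_;
  std::unordered_set<int64_t> set_;
  size_t next_ = 0;
};

int main() {
  BoundedRequestIdSet recent(/*capacity=*/100000);
  std::cout << recent.Insert(42) << "\n";  // 1: first delivery, run the graph
  std::cout << recent.Insert(42) << "\n";  // 0: retried RunGraph is rejected with an error
}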

PiperOrigin-RevId: 239327645
Derek Murray 2019-03-19 20:48:05 -07:00 committed by TensorFlower Gardener
parent 144729f9ee
commit cd174ffdb7
12 changed files with 94 additions and 37 deletions

View File

@@ -205,6 +205,7 @@ tf_cuda_library(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:device_tracer",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:recent_request_ids",
],
)
@@ -334,6 +335,7 @@ cc_library(
":call_options",
":master_env",
":message_wrappers",
":request_id",
":scheduler",
":worker_cache",
":worker_interface",

View File

@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/profile_handler.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
@@ -633,6 +634,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts;
c->req->set_store_errors_in_response_body(true);
c->req->set_request_id(GetUniqueRequestId());
// If any feeds are provided, send the feed values together
// in the RunGraph request.
// In the partial case, we only want to include feeds provided in the req.
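The hunk above tags each partition's RunGraph request with a fresh id from GetUniqueRequestId(), declared in the newly included request_id.h. Its body is not part of this diff; conceptually all that is needed is a random, nonzero int64, since zero means retry detection is disabled. A rough sketch, not the actual function:

// Assumed behavior of a unique-request-id generator; the real
// GetUniqueRequestId() in the distributed runtime may differ in details.
#include <cstdint>
#include <random>

int64_t GetUniqueRequestIdSketch() {
  static std::mt19937_64 rng{std::random_device{}()};
  int64_t request_id = 0;
  while (request_id == 0) {                    // 0 is reserved for "no retry detection"
    request_id = static_cast<int64_t>(rng());  // uniform random 64-bit value
  }
  return request_id;
}

The contract documented later in this change requires that a retried RunGraph reuse the same id, so the worker can tell a retry from a genuinely new step.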

View File

@@ -392,6 +392,12 @@ void InMemoryRunGraphRequest::set_store_errors_in_response_body(
store_errors_in_response_body_ = store_errors;
}
int64 InMemoryRunGraphRequest::request_id() const { return request_id_; }
void InMemoryRunGraphRequest::set_request_id(int64 request_id) {
request_id_ = request_id;
}
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) {
proto_version_.reset(new RunGraphRequest);
@@ -412,6 +418,9 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
proto_version_->set_is_partial(is_partial());
proto_version_->set_is_last_partial_run(is_last_partial_run());
}
proto_version_->set_store_errors_in_response_body(
store_errors_in_response_body_);
proto_version_->set_request_id(request_id_);
return *proto_version_;
}
@@ -532,6 +541,14 @@ void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
request_.set_store_errors_in_response_body(store_errors);
}
int64 MutableProtoRunGraphRequest::request_id() const {
return request_.request_id();
}
void MutableProtoRunGraphRequest::set_request_id(int64 request_id) {
request_.set_request_id(request_id);
}
const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
return request_;
}
@@ -589,6 +606,10 @@ bool ProtoRunGraphRequest::store_errors_in_response_body() const {
return request_->store_errors_in_response_body();
}
int64 ProtoRunGraphRequest::request_id() const {
return request_->request_id();
}
const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
return *request_;
}

View File

@@ -87,6 +87,9 @@ class RunStepRequestWrapper {
// truncate long metadata messages.
virtual bool store_errors_in_response_body() const = 0;
// Unique identifier for this request. Every RunStepRequest must have a
// unique request_id, and retried RunStepRequests must have the same
// request_id. If request_id is zero, retry detection is disabled.
virtual int64 request_id() const = 0;
// Returns a human-readable representation of this message for debugging.
@@ -292,6 +295,8 @@ class RunGraphRequestWrapper {
// truncate long metadata messages.
virtual bool store_errors_in_response_body() const = 0;
virtual int64 request_id() const = 0;
// Returns the wrapped data as a protocol buffer message.
virtual const RunGraphRequest& ToProto() const = 0;
};
@@ -320,6 +325,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
virtual void set_is_partial(bool is_partial) = 0;
virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
virtual void set_store_errors_in_response_body(bool store_errors) = 0;
virtual void set_request_id(int64 request_id) = 0;
};
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
@@ -339,6 +345,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
bool is_last_partial_run() const override;
const RunGraphRequest& ToProto() const override;
bool store_errors_in_response_body() const override;
int64 request_id() const override;
// MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
@@ -356,6 +363,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
void set_store_errors_in_response_body(bool store_errors) override;
void set_request_id(int64 request_id) override;
private:
string session_handle_;
@@ -368,6 +376,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
bool is_partial_ = false;
bool is_last_partial_run_ = false;
bool store_errors_in_response_body_ = false;
int64 request_id_ = 0;
// Holds a cached and owned representation of the proto
// representation of this request, if needed, so that `ToProto()`
@@ -395,6 +404,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
bool is_partial() const override;
bool is_last_partial_run() const override;
bool store_errors_in_response_body() const override;
int64 request_id() const override;
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
@@ -413,6 +423,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
void set_store_errors_in_response_body(bool store_errors) override;
void set_request_id(int64 request_id) override;
private:
RunGraphRequest request_;
@@ -436,6 +447,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
bool is_partial() const override;
bool is_last_partial_run() const override;
bool store_errors_in_response_body() const override;
int64 request_id() const override;
const RunGraphRequest& ToProto() const override;
private:

View File

@@ -64,16 +64,5 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
request.ShortDebugString());
}
}
Status RecentRequestIds::TrackUnique(int64 request_id,
const string& method_name,
const RunStepRequestWrapper* wrapper) {
if (Insert(request_id)) {
return Status::OK();
} else {
return errors::Aborted("The same ", method_name,
" request was received twice. ",
wrapper->ToProto().ShortDebugString());
}
}
} // namespace tensorflow

View File

@@ -59,9 +59,10 @@ class RecentRequestIds {
// ShortDebugString are added to returned errors.
Status TrackUnique(int64 request_id, const string& method_name,
const protobuf::Message& request);
// Overloaded versions of the above function for wrapped protos.
// Overloaded version of the above function for wrapped protos.
template <typename RequestWrapper>
Status TrackUnique(int64 request_id, const string& method_name,
const RunStepRequestWrapper* wrapper);
const RequestWrapper* wrapper);
private:
bool Insert(int64 request_id);
@@ -75,6 +76,21 @@ class RecentRequestIds {
std::unordered_set<int64> set_ GUARDED_BY(mu_);
};
// Implementation details
template <typename RequestWrapper>
Status RecentRequestIds::TrackUnique(int64 request_id,
const string& method_name,
const RequestWrapper* wrapper) {
if (Insert(request_id)) {
return Status::OK();
} else {
return errors::Aborted("The same ", method_name,
" request was received twice. ",
wrapper->ToProto().ShortDebugString());
}
}
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
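To make the effect of the templated overload concrete, here is a toy, self-contained analogue (Tracker and FakeRunGraphWrapper are stand-ins invented for this example, not the TensorFlow classes): one TrackUnique body now serves any wrapper type, and the second delivery of the same id comes back as an Aborted-style error.

// Toy analogue of the templated TrackUnique above.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>

struct FakeRunGraphWrapper {
  std::string DebugString() const { return "step_id: 7"; }
};

class Tracker {
 public:
  template <typename RequestWrapper>
  std::string TrackUnique(int64_t request_id, const std::string& method_name,
                          const RequestWrapper* wrapper) {
    if (request_id == 0 || seen_.insert(request_id).second) return "OK";
    return "Aborted: The same " + method_name +
           " request was received twice. " + wrapper->DebugString();
  }

 private:
  std::unordered_set<int64_t> seen_;
};

int main() {
  Tracker tracker;
  FakeRunGraphWrapper req;
  std::cout << tracker.TrackUnique(1234, "RunGraph (Worker)", &req) << "\n";  // OK
  std::cout << tracker.TrackUnique(1234, "RunGraph (Worker)", &req) << "\n";  // Aborted: ...
}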

View File

@@ -189,7 +189,6 @@ tf_cuda_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:graph_mgr",
"//tensorflow/core/distributed_runtime:recent_request_ids",
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"//tensorflow/core/distributed_runtime:worker",
"//tensorflow/core/distributed_runtime:worker_cache",

View File

@@ -432,7 +432,6 @@ class GrpcWorkerService : public AsyncServiceInterface {
GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
: Worker(worker_env),
recent_request_ids_(100000),
recv_buf_max_chunk_(
config.experimental().recv_buf_max_chunk() > 0
? config.experimental().recv_buf_max_chunk()

View File

@@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
@@ -54,7 +53,6 @@ class GrpcWorker : public Worker {
WorkerEnv* env();
private:
RecentRequestIds recent_request_ids_;
const int32 recv_buf_max_chunk_;
};

View File

@@ -28,7 +28,7 @@ limitations under the License.
namespace tensorflow {
Worker::Worker(WorkerEnv* env) : env_(env) {}
Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {}
void Worker::GetStatusAsync(const GetStatusRequest* request,
GetStatusResponse* response, StatusCallback done) {
@@ -156,8 +156,14 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
StatusCallback done) {
const int64 step_id = request->step_id();
TRACEPRINTF("RunGraph: %lld", step_id);
Status s = recent_request_ids_.TrackUnique(request->request_id(),
"RunGraph (Worker)", request);
if (!s.ok()) {
done(s);
return;
}
std::shared_ptr<WorkerSession> session;
Status s;
if (request->create_worker_session_called()) {
s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
&session);
@@ -266,9 +272,14 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
const int64 step_id = request->step_id();
const string& graph_handle = request->graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id);
std::shared_ptr<WorkerSession> session;
Status s = recent_request_ids_.TrackUnique(
request->request_id(), "PartialRunGraph (Worker)", request);
if (!s.ok()) {
done(s);
return;
}
Status s;
std::shared_ptr<WorkerSession> session;
if (request->create_worker_session_called()) {
s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
&session);

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/partial_run_mgr.h"
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
@@ -109,6 +110,7 @@ class Worker : public WorkerInterface {
protected:
WorkerEnv* const env_; // Not owned.
RecentRequestIds recent_request_ids_;
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
Device** src_dev);

View File

@@ -16,16 +16,18 @@ limitations under the License.
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "WorkerProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
// add go_package externally with copybara
import "google/protobuf/any.proto";
import "tensorflow/core/framework/cost_graph.proto";
import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/step_stats.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
@@ -41,8 +43,7 @@ import "tensorflow/core/protobuf/tensorflow_server.proto";
//
////////////////////////////////////////////////////////////////////////////////
message GetStatusRequest {
}
message GetStatusRequest {}
message GetStatusResponse {
repeated DeviceAttributes device_attributes = 1;
@@ -68,8 +69,7 @@ message CreateWorkerSessionRequest {
bool isolate_session_state = 3;
}
message CreateWorkerSessionResponse {
}
message CreateWorkerSessionResponse {}
////////////////////////////////////////////////////////////////////////////////
//
@@ -84,8 +84,7 @@ message DeleteWorkerSessionRequest {
string session_handle = 1;
}
message DeleteWorkerSessionResponse {
}
message DeleteWorkerSessionResponse {}
////////////////////////////////////////////////////////////////////////////////
//
@@ -186,8 +185,7 @@ message CleanupAllRequest {
repeated string container = 1;
}
message CleanupAllResponse {
}
message CleanupAllResponse {}
////////////////////////////////////////////////////////////////////////////////
//
@@ -207,7 +205,7 @@ message ExecutorOpts {
bool record_timeline = 3;
bool record_partition_graphs = 4;
bool report_tensor_allocations_upon_oom = 5;
};
}
message RunGraphRequest {
// session_handle is the master-generated unique id for this session.
@@ -253,7 +251,17 @@ message RunGraphRequest {
// truncate long metadata messages.
bool store_errors_in_response_body = 9;
// Next: 11
// Unique identifier for this request. Every RunGraphRequest must have a
// unique request_id, and retried RunGraphRequests must have the same
// request_id. If request_id is zero, retry detection is disabled.
//
// Retried RunGraphRequests are problematic because they may issue a
// RecvTensor that will have no corresponding sender and will wait forever.
// Workers use request_ids to reject retried RunGraph requests instead of
// waiting forever.
int64 request_id = 11;
// Next: 12
}
message RunGraphResponse {
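As a usage illustration of the request_id field added to RunGraphRequest in the hunk above (the include path and values below are illustrative assumptions; the setters are the standard generated proto accessors): the caller picks a nonzero id once, and a retry must carry the same id rather than minting a new one.

// Sketch of the retry contract from the comment above, using the generated
// C++ API for worker.proto.
#include <cstdint>
#include "tensorflow/core/protobuf/worker.pb.h"

tensorflow::RunGraphRequest MakeRunGraphRequest(int64_t request_id) {
  tensorflow::RunGraphRequest req;
  req.set_graph_handle("example_graph_handle");  // placeholder handle
  req.set_step_id(1);
  req.set_request_id(request_id);  // nonzero; zero would disable retry detection
  return req;
}

// A retry re-sends the same message rather than generating a new id:
//   tensorflow::RunGraphRequest first = MakeRunGraphRequest(/*request_id=*/987654321);
//   tensorflow::RunGraphRequest retried = first;  // same request_id, so the worker reports the duplicate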
@@ -295,8 +303,7 @@ message CleanupGraphRequest {
int64 step_id = 1;
}
message CleanupGraphResponse {
}
message CleanupGraphResponse {}
////////////////////////////////////////////////////////////////////////////////
//
@@ -424,8 +431,7 @@ message TracingRequest {
TraceOpts options = 1;
}
message TracingResponse {
}
message TracingResponse {}
////////////////////////////////////////////////////////////////////////////////
//