Detect and report an error when a WorkerService.RunGraph request is delivered more than once.
Currently, we track request IDs for WorkerService.RecvTensor, because if a duplicate message is delivered for that method the system will deadlock. However, a duplicate WorkerService.RunGraph request can also trigger deadlock. Consider the following simple example:

1. Master M divides a simple graph across workers W_0 and W_1, where W_0 sends a tensor to W_1.
2. M calls W_0.RunGraph and W_1.RunGraph.
3. W_1 calls W_0.RecvTensor to receive the tensor, and receives a response.
4. W_0.RunGraph returns successfully.
5. M retries the call to W_1.RunGraph.
6. W_1 again calls W_0.RecvTensor to receive the tensor, and never receives a response because the sender has completed.

To work around this problem, we extend the duplicate request ID tracking to WorkerService.RunGraph.

PiperOrigin-RevId: 239327645
This commit is contained in:
parent 144729f9ee
commit cd174ffdb7
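Before the diffs, a minimal standalone sketch of the mechanism (loose stand-ins for the TensorFlow pieces, not the patch itself): the master stamps each RunGraphRequest with a fresh nonzero request id, and the worker remembers recently seen ids so a retried delivery is rejected with an error instead of re-running the step and issuing a RecvTensor whose sender has already finished.

// Simplified, standalone model of the duplicate-request tracking this change
// extends from RecvTensor to RunGraph. The names below are loose stand-ins for
// the TensorFlow ones (RecentRequestIds, GetUniqueRequestId, Worker::DoRunGraph),
// not the real implementations: the real tracker keeps a bounded circular
// buffer of ids, and the real handler is asynchronous.
#include <cstdint>
#include <iostream>
#include <mutex>
#include <random>
#include <unordered_set>

// Remembers request ids it has seen so a retried delivery can be rejected
// instead of deadlocking the worker. (Unbounded here for simplicity.)
class RecentRequestIds {
 public:
  // Returns true if the id is new, or zero (which disables retry detection);
  // returns false if the id was already seen, i.e. the request is a retry.
  bool TrackUnique(int64_t request_id) {
    if (request_id == 0) return true;
    std::lock_guard<std::mutex> lock(mu_);
    return seen_.insert(request_id).second;
  }

 private:
  std::mutex mu_;
  std::unordered_set<int64_t> seen_;
};

// Master side: every RunGraphRequest is stamped with a fresh nonzero id; a
// retried request reuses the same id, which is what the worker detects.
int64_t GetUniqueRequestId() {
  static std::mt19937_64 rng{std::random_device{}()};
  int64_t id = 0;
  while (id == 0) id = static_cast<int64_t>(rng());
  return id;
}

// Worker side: the check performed before running the step. On a duplicate it
// reports an error rather than re-issuing RecvTensor calls whose sender has
// already completed.
void HandleRunGraph(RecentRequestIds& ids, int64_t request_id) {
  if (!ids.TrackUnique(request_id)) {
    std::cout << "Aborted: the same RunGraph request was received twice "
              << "(request_id=" << request_id << ")\n";
    return;
  }
  std::cout << "Running graph for request_id=" << request_id << "\n";
}

int main() {
  RecentRequestIds ids;
  const int64_t id = GetUniqueRequestId();
  HandleRunGraph(ids, id);  // First delivery: the step runs.
  HandleRunGraph(ids, id);  // Retried delivery: rejected, no deadlock.
  return 0;
}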
@@ -205,6 +205,7 @@ tf_cuda_library(
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:device_tracer",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core/distributed_runtime:recent_request_ids",
     ],
 )
 

@@ -334,6 +335,7 @@ cc_library(
         ":call_options",
         ":master_env",
         ":message_wrappers",
+        ":request_id",
         ":scheduler",
         ":worker_cache",
         ":worker_interface",

@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/profile_handler.h"
 #include "tensorflow/core/common_runtime/stats_publisher_interface.h"
 #include "tensorflow/core/debug/debug_graph_utils.h"
+#include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/scheduler.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/distributed_runtime/worker_interface.h"

@@ -633,6 +634,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
     c->req->set_step_id(step_id);
     *c->req->mutable_exec_opts() = exec_opts;
     c->req->set_store_errors_in_response_body(true);
+    c->req->set_request_id(GetUniqueRequestId());
     // If any feeds are provided, send the feed values together
     // in the RunGraph request.
     // In the partial case, we only want to include feeds provided in the req.

@@ -392,6 +392,12 @@ void InMemoryRunGraphRequest::set_store_errors_in_response_body(
   store_errors_in_response_body_ = store_errors;
 }
 
+int64 InMemoryRunGraphRequest::request_id() const { return request_id_; }
+
+void InMemoryRunGraphRequest::set_request_id(int64 request_id) {
+  request_id_ = request_id;
+}
+
 const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
   if (!proto_version_) {
     proto_version_.reset(new RunGraphRequest);

@@ -412,6 +418,9 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
     proto_version_->set_is_partial(is_partial());
     proto_version_->set_is_last_partial_run(is_last_partial_run());
   }
+  proto_version_->set_store_errors_in_response_body(
+      store_errors_in_response_body_);
+  proto_version_->set_request_id(request_id_);
   return *proto_version_;
 }
 

@@ -532,6 +541,14 @@ void MutableProtoRunGraphRequest::set_store_errors_in_response_body(
   request_.set_store_errors_in_response_body(store_errors);
 }
 
+int64 MutableProtoRunGraphRequest::request_id() const {
+  return request_.request_id();
+}
+
+void MutableProtoRunGraphRequest::set_request_id(int64 request_id) {
+  request_.set_request_id(request_id);
+}
+
 const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
   return request_;
 }

@@ -589,6 +606,10 @@ bool ProtoRunGraphRequest::store_errors_in_response_body() const {
   return request_->store_errors_in_response_body();
 }
 
+int64 ProtoRunGraphRequest::request_id() const {
+  return request_->request_id();
+}
+
 const RunGraphRequest& ProtoRunGraphRequest::ToProto() const {
   return *request_;
 }

@@ -87,6 +87,9 @@ class RunStepRequestWrapper {
   // truncate long metadata messages.
   virtual bool store_errors_in_response_body() const = 0;
 
+  // Unique identifier for this request. Every RunGraphRequest must have a
+  // unique request_id, and retried RunGraphRequests must have the same
+  // request_id. If request_id is zero, retry detection is disabled.
   virtual int64 request_id() const = 0;
 
   // Returns a human-readable representation of this message for debugging.

@@ -292,6 +295,8 @@ class RunGraphRequestWrapper {
   // truncate long metadata messages.
   virtual bool store_errors_in_response_body() const = 0;
 
+  virtual int64 request_id() const = 0;
+
   // Returns the wrapped data as a protocol buffer message.
   virtual const RunGraphRequest& ToProto() const = 0;
 };

@@ -320,6 +325,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
   virtual void set_is_partial(bool is_partial) = 0;
   virtual void set_is_last_partial_run(bool is_last_partial_run) = 0;
   virtual void set_store_errors_in_response_body(bool store_errors) = 0;
+  virtual void set_request_id(int64 request_id) = 0;
 };
 
 class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {

@@ -339,6 +345,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
   bool is_last_partial_run() const override;
   const RunGraphRequest& ToProto() const override;
   bool store_errors_in_response_body() const override;
+  int64 request_id() const override;
 
   // MutableRunGraphRequestWrapper methods.
   void set_session_handle(const string& handle) override;

@@ -356,6 +363,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
   void set_is_partial(bool is_partial) override;
   void set_is_last_partial_run(bool is_last_partial_run) override;
   void set_store_errors_in_response_body(bool store_errors) override;
+  void set_request_id(int64 request_id) override;
 
  private:
   string session_handle_;

@@ -368,6 +376,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
   bool is_partial_ = false;
   bool is_last_partial_run_ = false;
   bool store_errors_in_response_body_ = false;
+  int64 request_id_ = 0;
 
   // Holds a cached and owned representation of the proto
   // representation of this request, if needed, so that `ToProto()`

@@ -395,6 +404,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
   bool is_partial() const override;
   bool is_last_partial_run() const override;
   bool store_errors_in_response_body() const override;
+  int64 request_id() const override;
   const RunGraphRequest& ToProto() const override;
 
   // MutableRunGraphRequestWrapper methods.

@@ -413,6 +423,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
   void set_is_partial(bool is_partial) override;
   void set_is_last_partial_run(bool is_last_partial_run) override;
   void set_store_errors_in_response_body(bool store_errors) override;
+  void set_request_id(int64 request_id) override;
 
  private:
   RunGraphRequest request_;

@@ -436,6 +447,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
   bool is_partial() const override;
   bool is_last_partial_run() const override;
   bool store_errors_in_response_body() const override;
+  int64 request_id() const override;
   const RunGraphRequest& ToProto() const override;
 
  private:

@@ -64,16 +64,5 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
                            request.ShortDebugString());
   }
 }
-Status RecentRequestIds::TrackUnique(int64 request_id,
-                                     const string& method_name,
-                                     const RunStepRequestWrapper* wrapper) {
-  if (Insert(request_id)) {
-    return Status::OK();
-  } else {
-    return errors::Aborted("The same ", method_name,
-                           " request was received twice. ",
-                           wrapper->ToProto().ShortDebugString());
-  }
-}
 
 }  // namespace tensorflow

@@ -59,9 +59,10 @@ class RecentRequestIds {
   // ShortDebugString are added to returned errors.
   Status TrackUnique(int64 request_id, const string& method_name,
                      const protobuf::Message& request);
-  // Overloaded versions of the above function for wrapped protos.
+  // Overloaded version of the above function for wrapped protos.
+  template <typename RequestWrapper>
   Status TrackUnique(int64 request_id, const string& method_name,
-                     const RunStepRequestWrapper* wrapper);
+                     const RequestWrapper* wrapper);
 
  private:
   bool Insert(int64 request_id);

@@ -75,6 +76,21 @@ class RecentRequestIds {
   std::unordered_set<int64> set_ GUARDED_BY(mu_);
 };
 
+// Implementation details
+
+template <typename RequestWrapper>
+Status RecentRequestIds::TrackUnique(int64 request_id,
+                                     const string& method_name,
+                                     const RequestWrapper* wrapper) {
+  if (Insert(request_id)) {
+    return Status::OK();
+  } else {
+    return errors::Aborted("The same ", method_name,
+                           " request was received twice. ",
+                           wrapper->ToProto().ShortDebugString());
+  }
+}
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RECENT_REQUEST_IDS_H_
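The hunks above replace the hand-written RunStepRequestWrapper* overload of TrackUnique in recent_request_ids.cc with a header template. A rough illustration of that design choice, using hypothetical stand-in types rather than the TensorFlow headers: any wrapper that exposes ToProto() on a message with a ShortDebugString()-style method can share the single template, so the RunGraph wrappers need no second overload.

// Illustrative sketch (hypothetical stand-in types, not the TensorFlow classes)
// of why the wrapper overload becomes a template: both RunStepRequestWrapper
// and RunGraphRequestWrapper expose ToProto(), so one template covers both.
#include <iostream>
#include <string>

// Stand-in "protos" that only provide ShortDebugString().
struct FakeRunStepRequest {
  std::string ShortDebugString() const { return "run_step {}"; }
};
struct FakeRunGraphRequest {
  std::string ShortDebugString() const { return "run_graph {}"; }
};

// Stand-in wrappers, mirroring the ToProto() part of the wrapper interfaces.
struct FakeRunStepRequestWrapper {
  FakeRunStepRequest proto;
  const FakeRunStepRequest& ToProto() const { return proto; }
};
struct FakeRunGraphRequestWrapper {
  FakeRunGraphRequest proto;
  const FakeRunGraphRequest& ToProto() const { return proto; }
};

// One template handles any wrapper type with a ToProto() method, which is why
// the second hand-written overload can be deleted from the .cc file.
template <typename RequestWrapper>
std::string DescribeDuplicate(const std::string& method_name,
                              const RequestWrapper* wrapper) {
  return "The same " + method_name + " request was received twice. " +
         wrapper->ToProto().ShortDebugString();
}

int main() {
  FakeRunStepRequestWrapper run_step;
  FakeRunGraphRequestWrapper run_graph;
  std::cout << DescribeDuplicate("RunStep", &run_step) << "\n";
  std::cout << DescribeDuplicate("RunGraph", &run_graph) << "\n";
}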
@@ -189,7 +189,6 @@ tf_cuda_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:worker_proto_cc",
         "//tensorflow/core/distributed_runtime:graph_mgr",
-        "//tensorflow/core/distributed_runtime:recent_request_ids",
         "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
        "//tensorflow/core/distributed_runtime:worker",
         "//tensorflow/core/distributed_runtime:worker_cache",

@@ -432,7 +432,6 @@ class GrpcWorkerService : public AsyncServiceInterface {
 
 GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
     : Worker(worker_env),
-      recent_request_ids_(100000),
       recv_buf_max_chunk_(
           config.experimental().recv_buf_max_chunk() > 0
               ? config.experimental().recv_buf_max_chunk()

@@ -18,7 +18,6 @@ limitations under the License.
 
 #include <memory>
 #include <unordered_map>
-#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
 #include "tensorflow/core/distributed_runtime/worker.h"

@@ -54,7 +53,6 @@ class GrpcWorker : public Worker {
   WorkerEnv* env();
 
  private:
-  RecentRequestIds recent_request_ids_;
   const int32 recv_buf_max_chunk_;
 };
 

@@ -28,7 +28,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-Worker::Worker(WorkerEnv* env) : env_(env) {}
+Worker::Worker(WorkerEnv* env) : env_(env), recent_request_ids_(100000) {}
 
 void Worker::GetStatusAsync(const GetStatusRequest* request,
                             GetStatusResponse* response, StatusCallback done) {

@@ -156,8 +156,14 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                         StatusCallback done) {
   const int64 step_id = request->step_id();
   TRACEPRINTF("RunGraph: %lld", step_id);
+  Status s = recent_request_ids_.TrackUnique(request->request_id(),
+                                             "RunGraph (Worker)", request);
+  if (!s.ok()) {
+    done(s);
+    return;
+  }
+
   std::shared_ptr<WorkerSession> session;
-  Status s;
   if (request->create_worker_session_called()) {
     s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                    &session);

@@ -266,9 +272,14 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
   const int64 step_id = request->step_id();
   const string& graph_handle = request->graph_handle();
   TRACEPRINTF("PartialRunGraph: %lld", step_id);
-  std::shared_ptr<WorkerSession> session;
+  Status s = recent_request_ids_.TrackUnique(
+      request->request_id(), "PartialRunGraph (Worker)", request);
+  if (!s.ok()) {
+    done(s);
+    return;
+  }
 
-  Status s;
+  std::shared_ptr<WorkerSession> session;
   if (request->create_worker_session_called()) {
     s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                    &session);

@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
 #include "tensorflow/core/distributed_runtime/partial_run_mgr.h"
+#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
 #include "tensorflow/core/distributed_runtime/session_mgr.h"
 #include "tensorflow/core/distributed_runtime/worker_interface.h"
 

@@ -109,6 +110,7 @@ class Worker : public WorkerInterface {
 
  protected:
   WorkerEnv* const env_;  // Not owned.
+  RecentRequestIds recent_request_ids_;
 
   Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
                            Device** src_dev);
@@ -16,16 +16,18 @@ limitations under the License.
 syntax = "proto3";
+
 package tensorflow;
+
 option cc_enable_arenas = true;
 option java_outer_classname = "WorkerProtos";
 option java_multiple_files = true;
 option java_package = "org.tensorflow.distruntime";
+option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
 
-// add go_package externally with copybara
 import "google/protobuf/any.proto";
 import "tensorflow/core/framework/cost_graph.proto";
-import "tensorflow/core/framework/step_stats.proto";
 import "tensorflow/core/framework/device_attributes.proto";
 import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/step_stats.proto";
 import "tensorflow/core/framework/tensor.proto";
 import "tensorflow/core/framework/tensor_shape.proto";
 import "tensorflow/core/framework/types.proto";
@@ -41,8 +43,7 @@ import "tensorflow/core/protobuf/tensorflow_server.proto";
 //
 ////////////////////////////////////////////////////////////////////////////////
 
-message GetStatusRequest {
-}
+message GetStatusRequest {}
 
 message GetStatusResponse {
   repeated DeviceAttributes device_attributes = 1;

@@ -68,8 +69,7 @@ message CreateWorkerSessionRequest {
   bool isolate_session_state = 3;
 }
 
-message CreateWorkerSessionResponse {
-}
+message CreateWorkerSessionResponse {}
 
 ////////////////////////////////////////////////////////////////////////////////
 //

@@ -84,8 +84,7 @@ message DeleteWorkerSessionRequest {
   string session_handle = 1;
 }
 
-message DeleteWorkerSessionResponse {
-}
+message DeleteWorkerSessionResponse {}
 
 ////////////////////////////////////////////////////////////////////////////////
 //

@@ -186,8 +185,7 @@ message CleanupAllRequest {
   repeated string container = 1;
 }
 
-message CleanupAllResponse {
-}
+message CleanupAllResponse {}
 
 ////////////////////////////////////////////////////////////////////////////////
 //

@@ -207,7 +205,7 @@ message ExecutorOpts {
   bool record_timeline = 3;
   bool record_partition_graphs = 4;
   bool report_tensor_allocations_upon_oom = 5;
-};
+}
 
 message RunGraphRequest {
   // session_handle is the master-generated unique id for this session.

@@ -253,7 +251,17 @@ message RunGraphRequest {
   // truncate long metadata messages.
   bool store_errors_in_response_body = 9;
 
-  // Next: 11
+  // Unique identifier for this request. Every RunGraphRequest must have a
+  // unique request_id, and retried RunGraphRequests must have the same
+  // request_id. If request_id is zero, retry detection is disabled.
+  //
+  // Retried RunGraphRequests are problematic because they may issue a
+  // RecvTensor that will have no corresponding sender and will wait forever.
+  // Workers use request_ids to reject retried RunGraph requests instead of
+  // waiting forever.
+  int64 request_id = 11;
+
+  // Next: 12
 }
 
 message RunGraphResponse {

@@ -295,8 +303,7 @@ message CleanupGraphRequest {
   int64 step_id = 1;
 }
 
-message CleanupGraphResponse {
-}
+message CleanupGraphResponse {}
 
 ////////////////////////////////////////////////////////////////////////////////
 //

@@ -424,8 +431,7 @@ message TracingRequest {
   TraceOpts options = 1;
 }
 
-message TracingResponse {
-}
+message TracingResponse {}
 
 ////////////////////////////////////////////////////////////////////////////////
 //