Implement duplicate checking on Master methods

PiperOrigin-RevId: 232927103
This commit is contained in:
Noah Eisen 2019-02-07 12:51:13 -08:00 committed by TensorFlower Gardener
parent 8773ca70dc
commit 243118bbbf
11 changed files with 117 additions and 18 deletions

View File

@ -17,7 +17,6 @@ filegroup(
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
load("//tensorflow:tensorflow.bzl", "tf_copts")
# For platform specific build config
@ -298,6 +297,7 @@ cc_library(
deps = [
":call_options",
":message_wrappers",
":request_id",
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
],
@ -311,6 +311,7 @@ cc_library(
":call_options",
":master_env",
":master_session",
":recent_request_ids",
":remote_device",
":worker_cache",
":worker_interface",
@ -765,6 +766,7 @@ cc_library(
srcs = ["recent_request_ids.cc"],
hdrs = ["recent_request_ids.h"],
deps = [
":message_wrappers",
"//tensorflow/core:lib",
"//tensorflow/core:worker_proto_cc",
],

View File

@ -65,7 +65,8 @@ Master::Master(MasterEnv* env, double session_gc_seconds)
: env_(env),
last_1000_steps_(1000),
step_count_(0),
session_gc_seconds_(session_gc_seconds) {
session_gc_seconds_(session_gc_seconds),
recent_request_ids_(10000) {
// Right now, a master service must be co-located with a device.
// Otherwise, fetches do not work.
CHECK(!env->local_devices.empty());
@ -510,6 +511,12 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp, MyClosure done) {
Status s = recent_request_ids_.TrackUnique(req->request_id(),
"PartialRunSetup (Master)", *req);
if (!s.ok()) {
done(s);
return;
}
auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
@ -525,6 +532,12 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req,
void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
MutableRunStepResponseWrapper* resp, MyClosure done) {
Status s = recent_request_ids_.TrackUnique(req->request_id(),
"RunStep (Master)", req);
if (!s.ok()) {
done(s);
return;
}
auto start_time = env_->env->NowMicros();
auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
@ -664,6 +677,12 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp,
void Master::MakeCallable(const MakeCallableRequest* req,
MakeCallableResponse* resp, MyClosure done) {
Status s = recent_request_ids_.TrackUnique(req->request_id(),
"MakeCallable (Master)", *req);
if (!s.ok()) {
done(s);
return;
}
auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
@ -681,6 +700,12 @@ void Master::MakeCallable(const MakeCallableRequest* req,
void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
RunCallableResponse* resp, MyClosure done) {
Status s = recent_request_ids_.TrackUnique(req->request_id(),
"RunCallable (Master)", *req);
if (!s.ok()) {
done(s);
return;
}
auto session = FindMasterSession(req->session_handle());
if (session == nullptr) {
done(errors::Aborted("Session ", req->session_handle(), " is not found."));

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
@ -95,6 +96,9 @@ class Master {
// closed automatically.
const double session_gc_seconds_;
// Used to track ids for incoming requests so we can detect duplicates.
RecentRequestIds recent_request_ids_;
// Call CleanupAll on all workers.
void CleanupWorkers(const ResetRequest& reset);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
@ -66,7 +67,9 @@ class MasterInterface {
// The message returned from this method must only be used in a
// `RunStep()` call on the same `MasterInterface` instance.
virtual MutableRunStepRequestWrapper* CreateRunStepRequest() {
return new MutableProtoRunStepRequest;
MutableProtoRunStepRequest* ret = new MutableProtoRunStepRequest;
ret->request_.set_request_id(GetUniqueRequestId());
return ret;
}
// Returns a response object for use in calls to

View File

@ -97,6 +97,10 @@ bool InMemoryRunStepRequest::store_errors_in_response_body() const {
return store_errors_in_response_body_;
}
int64 InMemoryRunStepRequest::request_id() const {
return 0; // no need to track request id for local version.
}
void InMemoryRunStepRequest::set_store_errors_in_response_body(
bool store_errors) {
store_errors_in_response_body_ = store_errors;
@ -210,6 +214,10 @@ void MutableProtoRunStepRequest::set_store_errors_in_response_body(
request_.set_store_errors_in_response_body(store_errors);
}
int64 MutableProtoRunStepRequest::request_id() const {
return request_.request_id();
}
string MutableProtoRunStepRequest::DebugString() const {
return request_.DebugString();
}
@ -272,6 +280,8 @@ bool ProtoRunStepRequest::store_errors_in_response_body() const {
return request_->store_errors_in_response_body();
}
int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }
string ProtoRunStepRequest::DebugString() const {
return request_->DebugString();
}

View File

@ -87,6 +87,8 @@ class RunStepRequestWrapper {
// truncate long metadata messages.
virtual bool store_errors_in_response_body() const = 0;
virtual int64 request_id() const = 0;
// Returns a human-readable representation of this message for debugging.
virtual string DebugString() const = 0;
@ -127,6 +129,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
string DebugString() const override;
const RunStepRequest& ToProto() const override;
bool store_errors_in_response_body() const override;
int64 request_id() const override;
// MutableRunStepRequestWrapper methods.
void set_session_handle(const string& handle) override;
@ -177,6 +180,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
string DebugString() const override;
const RunStepRequest& ToProto() const override;
bool store_errors_in_response_body() const override;
int64 request_id() const override;
// MutableRunStepRequestWrapper methods.
void set_session_handle(const string& handle) override;
@ -189,6 +193,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
private:
RunStepRequest request_;
friend class MasterInterface;
};
// Wrapper for immutable RunStep requests that use a non-owned
@ -216,6 +221,7 @@ class ProtoRunStepRequest : public RunStepRequestWrapper {
string DebugString() const override;
const RunStepRequest& ToProto() const override;
bool store_errors_in_response_body() const override;
int64 request_id() const override;
private:
const RunStepRequest* const request_; // Not owned.

View File

@ -28,12 +28,10 @@ RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
set_.reserve(num_tracked_request_ids);
}
Status RecentRequestIds::TrackUnique(int64 request_id,
const string& method_name,
const protobuf::Message& request) {
bool RecentRequestIds::Insert(int64 request_id) {
if (request_id == 0) {
// For backwards compatibility, allow all requests with request_id 0.
return Status::OK();
return true;
}
mutex_lock l(mu_);
@ -43,9 +41,7 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
// request_id's age in the circular_buffer_ if it's tracked again. Strict
// LRU is not useful here because returning this error will close the
// current Session.
return errors::Aborted("The same ", method_name,
" request was received twice. ",
request.ShortDebugString());
return false;
}
// Remove the oldest request_id from the set_. circular_buffer_ is
@ -54,7 +50,30 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
set_.erase(circular_buffer_[next_index_]);
circular_buffer_[next_index_] = request_id;
next_index_ = (next_index_ + 1) % circular_buffer_.size();
return Status::OK();
return true;
}
Status RecentRequestIds::TrackUnique(int64 request_id,
const string& method_name,
const protobuf::Message& request) {
if (Insert(request_id)) {
return Status::OK();
} else {
return errors::Aborted("The same ", method_name,
" request was received twice. ",
request.ShortDebugString());
}
}
Status RecentRequestIds::TrackUnique(int64 request_id,
const string& method_name,
const RunStepRequestWrapper* wrapper) {
if (Insert(request_id)) {
return Status::OK();
} else {
return errors::Aborted("The same ", method_name,
" request was received twice. ",
wrapper->ToProto().ShortDebugString());
}
}
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
@ -58,8 +59,13 @@ class RecentRequestIds {
// ShortDebugString are added to returned errors.
Status TrackUnique(int64 request_id, const string& method_name,
const protobuf::Message& request);
// Overloaded versions of the above function for wrapped protos.
Status TrackUnique(int64 request_id, const string& method_name,
const RunStepRequestWrapper* wrapper);
private:
bool Insert(int64 request_id);
mutex mu_;
// next_index_ indexes into circular_buffer_, and points to the next storage
// space to use. When the buffer is full, next_index_ points at the oldest

View File

@ -408,6 +408,7 @@ cc_library(
"//tensorflow/core/distributed_runtime:local_master",
"//tensorflow/core/distributed_runtime:master_interface",
"//tensorflow/core/distributed_runtime:message_wrappers",
"//tensorflow/core/distributed_runtime:request_id",
],
alwayslink = 1,
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@ -312,6 +313,7 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
for (const string& target : target_nodes) {
req.add_target(target);
}
req.set_request_id(GetUniqueRequestId());
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
*handle = resp.partial_run_handle();
@ -408,6 +410,7 @@ Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
MakeCallableRequest req;
TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
*req.mutable_options() = callable_options;
req.set_request_id(GetUniqueRequestId());
MakeCallableResponse resp;
CallOptions call_options;
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
@ -423,6 +426,7 @@ Status GrpcSession::RunCallable(CallableHandle handle,
RunCallableRequest req;
TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
req.set_handle(handle);
req.set_request_id(GetUniqueRequestId());
for (const Tensor& feed : feed_tensors) {
feed.AsProtoTensorContent(req.mutable_feed()->Add());
}

View File

@ -16,11 +16,13 @@ limitations under the License.
syntax = "proto3";
package tensorflow;
option cc_enable_arenas = true;
option java_outer_classname = "DistributedRuntimeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
// add go_package externally with copybara
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor.proto";
@ -138,6 +140,11 @@ message RunStepRequest {
// response body. This is a workaround since the RPC subsystem may
// truncate long metadata messages.
bool store_errors_in_response_body = 7;
// Unique identifier for this request. Every RunStepRequest must
// have a unique request_id, and retried RunStepRequest must have
// the same request_id. If request_id is zero, retry detection is disabled.
int64 request_id = 8;
}
message RunStepResponse {
@ -183,6 +190,11 @@ message PartialRunSetupRequest {
// Target Nodes. A list of node names. The named nodes will be run in future
// steps, but their outputs will not be fetched.
repeated string target = 4;
// Unique identifier for this request. Every PartialRunSetupRequest must
// have a unique request_id, and retried PartialRunSetupRequest must have
// the same request_id. If request_id is zero, retry detection is disabled.
int64 request_id = 5;
}
message PartialRunSetupResponse {
@ -204,8 +216,7 @@ message CloseSessionRequest {
string session_handle = 1;
}
message CloseSessionResponse {
}
message CloseSessionResponse {}
// Reset() allows misbehaving or slow sessions to be aborted and closed, and
// causes their resources eventually to be released. Reset() does not wait
@ -237,8 +248,7 @@ message ResetRequest {
repeated string device_filters = 2;
}
message ResetResponse {
}
message ResetResponse {}
////////////////////////////////////////////////////////////////////////////////
//
@ -279,6 +289,11 @@ message MakeCallableRequest {
// Options that define the behavior of the created callable.
CallableOptions options = 2;
// Unique identifier for this request. Every MakeCallableRequest must
// have a unique request_id, and retried MakeCallableRequest must have
// the same request_id. If request_id is zero, retry detection is disabled.
int64 request_id = 3;
}
message MakeCallableResponse {
@ -303,6 +318,11 @@ message RunCallableRequest {
// Values of the tensors passed as arguments to the callable, in the order
// defined in the CallableOptions.feed field passed to MakeCallable.
repeated TensorProto feed = 3;
// Unique identifier for this request. Every RunCallableRequest must
// have a unique request_id, and retried RunCallableRequest must have
// the same request_id. If request_id is zero, retry detection is disabled.
int64 request_id = 4;
}
message RunCallableResponse {
@ -330,5 +350,4 @@ message ReleaseCallableRequest {
int64 handle = 2;
}
message ReleaseCallableResponse {
}
message ReleaseCallableResponse {}