Implement duplicate checking on Master methods
PiperOrigin-RevId: 232927103
parent 8773ca70dc
commit 243118bbbf
@@ -17,7 +17,6 @@ filegroup(
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
load("//tensorflow:tensorflow.bzl", "tf_copts")

# For platform specific build config
@@ -298,6 +297,7 @@ cc_library(
    deps = [
        ":call_options",
        ":message_wrappers",
        ":request_id",
        "//tensorflow/core:lib",
        "//tensorflow/core:master_proto_cc",
    ],
@@ -311,6 +311,7 @@ cc_library(
        ":call_options",
        ":master_env",
        ":master_session",
        ":recent_request_ids",
        ":remote_device",
        ":worker_cache",
        ":worker_interface",
@@ -765,6 +766,7 @@ cc_library(
    srcs = ["recent_request_ids.cc"],
    hdrs = ["recent_request_ids.h"],
    deps = [
        ":message_wrappers",
        "//tensorflow/core:lib",
        "//tensorflow/core:worker_proto_cc",
    ],

@@ -65,7 +65,8 @@ Master::Master(MasterEnv* env, double session_gc_seconds)
    : env_(env),
      last_1000_steps_(1000),
      step_count_(0),
      session_gc_seconds_(session_gc_seconds) {
      session_gc_seconds_(session_gc_seconds),
      recent_request_ids_(10000) {
  // Right now, a master service must be co-located with a device.
  // Otherwise, fetches do not work.
  CHECK(!env->local_devices.empty());
@@ -510,6 +511,12 @@ void Master::ExtendSession(const ExtendSessionRequest* req,

void Master::PartialRunSetup(const PartialRunSetupRequest* req,
                             PartialRunSetupResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "PartialRunSetup (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
@@ -525,6 +532,12 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req,

void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
                     MutableRunStepResponseWrapper* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "RunStep (Master)", req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto start_time = env_->env->NowMicros();
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
@@ -664,6 +677,12 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp,

void Master::MakeCallable(const MakeCallableRequest* req,
                          MakeCallableResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "MakeCallable (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
@@ -681,6 +700,12 @@ void Master::MakeCallable(const MakeCallableRequest* req,

void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
                         RunCallableResponse* resp, MyClosure done) {
  Status s = recent_request_ids_.TrackUnique(req->request_id(),
                                             "RunCallable (Master)", *req);
  if (!s.ok()) {
    done(s);
    return;
  }
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));

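All of the Master handlers above gain the same early-return guard: check the incoming request_id first, and abort before any session state is touched if that id was seen recently. Below is a minimal self-contained sketch of that idiom; DuplicateGuard, FakeRequest, and HandleRunStep are illustrative stand-ins rather than TensorFlow types, and unlike the real RecentRequestIds this toy tracker never evicts old ids.

// Minimal sketch of the guard added to each Master method: check the
// request's id first and abort the call if it was already seen.
// DuplicateGuard / FakeRequest / HandleRunStep are illustrative only.
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>

struct Status {
  bool ok = true;
  std::string msg;
};

class DuplicateGuard {
 public:
  // Returns an error Status if request_id was seen before; id 0 is exempt,
  // mirroring the "retry detection disabled" convention in master.proto.
  Status TrackUnique(int64_t request_id, const std::string& method) {
    if (request_id == 0 || seen_.insert(request_id).second) return Status{};
    return Status{false, "The same " + method + " request was received twice."};
  }

 private:
  std::unordered_set<int64_t> seen_;  // Unbounded here; the real class evicts.
};

struct FakeRequest {
  int64_t request_id;
};

using Closure = std::function<void(const Status&)>;

void HandleRunStep(DuplicateGuard& guard, const FakeRequest& req, Closure done) {
  Status s = guard.TrackUnique(req.request_id, "RunStep (Master)");
  if (!s.ok) {
    done(s);  // Reject the duplicate before touching any session state.
    return;
  }
  done(Status{});  // ... normal handling would go here ...
}

int main() {
  DuplicateGuard guard;
  auto report = [](const Status& s) {
    std::cout << (s.ok ? std::string("handled") : "aborted: " + s.msg) << "\n";
  };
  HandleRunStep(guard, {42}, report);  // handled
  HandleRunStep(guard, {42}, report);  // aborted: duplicate id
  HandleRunStep(guard, {0}, report);   // handled: id 0 skips detection
}

Failing the retried call with errors::Aborted, as the real handlers above do, makes the duplicate visible to the client instead of silently running the same step twice.
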
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
@@ -95,6 +96,9 @@ class Master {
  // closed automatically.
  const double session_gc_seconds_;

  // Used to track ids for incoming requests so we can detect duplicates.
  RecentRequestIds recent_request_ids_;

  // Call CleanupAll on all workers.
  void CleanupWorkers(const ResetRequest& reset);

@@ -18,6 +18,7 @@ limitations under the License.

#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/master.pb.h"
@@ -66,7 +67,9 @@ class MasterInterface {
  // The message returned from this method must only be used in a
  // `RunStep()` call on the same `MasterInterface` instance.
  virtual MutableRunStepRequestWrapper* CreateRunStepRequest() {
    return new MutableProtoRunStepRequest;
    MutableProtoRunStepRequest* ret = new MutableProtoRunStepRequest;
    ret->request_.set_request_id(GetUniqueRequestId());
    return ret;
  }

  // Returns a response object for use in calls to

@@ -97,6 +97,10 @@ bool InMemoryRunStepRequest::store_errors_in_response_body() const {
  return store_errors_in_response_body_;
}

int64 InMemoryRunStepRequest::request_id() const {
  return 0;  // no need to track request id for local version.
}

void InMemoryRunStepRequest::set_store_errors_in_response_body(
    bool store_errors) {
  store_errors_in_response_body_ = store_errors;
@@ -210,6 +214,10 @@ void MutableProtoRunStepRequest::set_store_errors_in_response_body(
  request_.set_store_errors_in_response_body(store_errors);
}

int64 MutableProtoRunStepRequest::request_id() const {
  return request_.request_id();
}

string MutableProtoRunStepRequest::DebugString() const {
  return request_.DebugString();
}
@@ -272,6 +280,8 @@ bool ProtoRunStepRequest::store_errors_in_response_body() const {
  return request_->store_errors_in_response_body();
}

int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }

string ProtoRunStepRequest::DebugString() const {
  return request_->DebugString();
}

@@ -87,6 +87,8 @@ class RunStepRequestWrapper {
  // truncate long metadata messages.
  virtual bool store_errors_in_response_body() const = 0;

  virtual int64 request_id() const = 0;

  // Returns a human-readable representation of this message for debugging.
  virtual string DebugString() const = 0;

@@ -127,6 +129,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
@@ -177,6 +180,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;

  // MutableRunStepRequestWrapper methods.
  void set_session_handle(const string& handle) override;
@@ -189,6 +193,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {

 private:
  RunStepRequest request_;
  friend class MasterInterface;
};

// Wrapper for immutable RunStep requests that use a non-owned
@@ -216,6 +221,7 @@ class ProtoRunStepRequest : public RunStepRequestWrapper {
  string DebugString() const override;
  const RunStepRequest& ToProto() const override;
  bool store_errors_in_response_body() const override;
  int64 request_id() const override;

 private:
  const RunStepRequest* const request_;  // Not owned.

@@ -28,12 +28,10 @@ RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
  set_.reserve(num_tracked_request_ids);
}

Status RecentRequestIds::TrackUnique(int64 request_id,
                                     const string& method_name,
                                     const protobuf::Message& request) {
bool RecentRequestIds::Insert(int64 request_id) {
  if (request_id == 0) {
    // For backwards compatibility, allow all requests with request_id 0.
    return Status::OK();
    return true;
  }

  mutex_lock l(mu_);
@@ -43,9 +41,7 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
    // request_id's age in the circular_buffer_ if it's tracked again. Strict
    // LRU is not useful here because returning this error will close the
    // current Session.
    return errors::Aborted("The same ", method_name,
                           " request was received twice. ",
                           request.ShortDebugString());
    return false;
  }

  // Remove the oldest request_id from the set_. circular_buffer_ is
@@ -54,7 +50,30 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
  set_.erase(circular_buffer_[next_index_]);
  circular_buffer_[next_index_] = request_id;
  next_index_ = (next_index_ + 1) % circular_buffer_.size();
  return Status::OK();
  return true;
}

Status RecentRequestIds::TrackUnique(int64 request_id,
                                     const string& method_name,
                                     const protobuf::Message& request) {
  if (Insert(request_id)) {
    return Status::OK();
  } else {
    return errors::Aborted("The same ", method_name,
                           " request was received twice. ",
                           request.ShortDebugString());
  }
}

Status RecentRequestIds::TrackUnique(int64 request_id,
                                     const string& method_name,
                                     const RunStepRequestWrapper* wrapper) {
  if (Insert(request_id)) {
    return Status::OK();
  } else {
    return errors::Aborted("The same ", method_name,
                           " request was received twice. ",
                           wrapper->ToProto().ShortDebugString());
  }
}

}  // namespace tensorflow

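The refactor above moves the bookkeeping into a private, bool-returning Insert that both TrackUnique overloads share. As a rough illustration of the underlying data structure, here is a simplified, single-threaded re-statement of the bounded window (a hash set for membership plus a fixed-size circular buffer for eviction); it is a sketch, not the TensorFlow class, and it omits the mutex the real Insert holds.

// Sketch of the bounded duplicate-detection window: a hash set answers
// "seen recently?" while a circular buffer evicts the oldest id once
// num_tracked ids have been inserted. Simplified and single-threaded.
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

class RecentIdWindow {
 public:
  explicit RecentIdWindow(int num_tracked) : circular_buffer_(num_tracked, 0) {
    set_.reserve(num_tracked);
  }

  // Returns false if request_id is already in the window (a duplicate).
  bool Insert(int64_t request_id) {
    if (request_id == 0) return true;  // id 0 is always accepted.
    if (!set_.insert(request_id).second) return false;
    // Evict the oldest tracked id and remember the new one in its slot.
    set_.erase(circular_buffer_[next_index_]);
    circular_buffer_[next_index_] = request_id;
    next_index_ = (next_index_ + 1) % circular_buffer_.size();
    return true;
  }

 private:
  std::unordered_set<int64_t> set_;
  std::vector<int64_t> circular_buffer_;
  size_t next_index_ = 0;
};

int main() {
  RecentIdWindow window(/*num_tracked=*/2);
  std::cout << window.Insert(1) << "\n";  // 1: new id accepted
  std::cout << window.Insert(1) << "\n";  // 0: duplicate within the window
  std::cout << window.Insert(2) << "\n";  // 1: new id accepted
  std::cout << window.Insert(3) << "\n";  // 1: accepted, evicts id 1
  std::cout << window.Insert(1) << "\n";  // 1: id 1 has aged out, accepted again
}

Fixing the window size at construction (10000 ids in Master's case, per the constructor change above) keeps memory constant: a duplicate is only caught if it arrives while its id is still among the most recent N tracked inserts.
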
@@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>

#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -58,8 +59,13 @@ class RecentRequestIds {
  // ShortDebugString are added to returned errors.
  Status TrackUnique(int64 request_id, const string& method_name,
                     const protobuf::Message& request);
  // Overloaded versions of the above function for wrapped protos.
  Status TrackUnique(int64 request_id, const string& method_name,
                     const RunStepRequestWrapper* wrapper);

 private:
  bool Insert(int64 request_id);

  mutex mu_;
  // next_index_ indexes into circular_buffer_, and points to the next storage
  // space to use. When the buffer is full, next_index_ points at the oldest

@@ -408,6 +408,7 @@ cc_library(
        "//tensorflow/core/distributed_runtime:local_master",
        "//tensorflow/core/distributed_runtime:master_interface",
        "//tensorflow/core/distributed_runtime:message_wrappers",
        "//tensorflow/core/distributed_runtime:request_id",
    ],
    alwayslink = 1,
)

@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/local_master.h"
#include "tensorflow/core/distributed_runtime/master_interface.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -312,6 +313,7 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
  for (const string& target : target_nodes) {
    req.add_target(target);
  }
  req.set_request_id(GetUniqueRequestId());
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
  TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
  *handle = resp.partial_run_handle();
@@ -408,6 +410,7 @@ Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
  MakeCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  *req.mutable_options() = callable_options;
  req.set_request_id(GetUniqueRequestId());
  MakeCallableResponse resp;
  CallOptions call_options;
  call_options.SetTimeout(options_.config.operation_timeout_in_ms());
@@ -423,6 +426,7 @@ Status GrpcSession::RunCallable(CallableHandle handle,
  RunCallableRequest req;
  TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
  req.set_handle(handle);
  req.set_request_id(GetUniqueRequestId());
  for (const Tensor& feed : feed_tensors) {
    feed.AsProtoTensorContent(req.mutable_feed()->Add());
  }

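On the client side, GrpcSession (and MasterInterface::CreateRunStepRequest) now stamp outgoing requests with GetUniqueRequestId(), pulled in via the newly included request_id.h. The sketch below only illustrates the contract spelled out in the proto comments: each logical request carries a fresh non-zero id, and a retry of that request reuses the same id so the Master can recognize the repeat. MakeRequestId and FakeRunStepRequest are hypothetical stand-ins, and the random-draw loop is an assumption about how such ids could be generated.

// Illustrates the request_id contract from master.proto: a fresh non-zero
// id per logical request, and the *same* id on a retry of that request.
// MakeRequestId and FakeRunStepRequest are illustrative stand-ins.
#include <cstdint>
#include <iostream>
#include <random>

// Draw a random non-zero 64-bit id (an assumption about how unique ids
// could be generated; this is not the TensorFlow helper itself).
int64_t MakeRequestId() {
  static std::mt19937_64 rng(std::random_device{}());
  int64_t id = 0;
  while (id == 0) id = static_cast<int64_t>(rng());
  return id;
}

struct FakeRunStepRequest {
  int64_t request_id = 0;
};

int main() {
  FakeRunStepRequest req;
  req.request_id = MakeRequestId();  // Stamp once per logical request.

  // Suppose the first attempt times out at the RPC layer; the retry must
  // carry the same request_id so the server-side tracker can spot the repeat.
  FakeRunStepRequest retry = req;  // The copy keeps request_id identical.

  std::cout << "first attempt id: " << req.request_id << "\n";
  std::cout << "retry attempt id: " << retry.request_id << "\n";
  std::cout << "ids match: " << (req.request_id == retry.request_id) << "\n";
}

As the proto comments below note, a client that sets request_id to zero opts out of this retry detection entirely.
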
@@ -16,11 +16,13 @@ limitations under the License.
syntax = "proto3";

package tensorflow;

option cc_enable_arenas = true;
option java_outer_classname = "DistributedRuntimeProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.distruntime";
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";

// add go_package externally with copybara
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor.proto";
@@ -138,6 +140,11 @@ message RunStepRequest {
  // response body. This is a workaround since the RPC subsystem may
  // truncate long metadata messages.
  bool store_errors_in_response_body = 7;

  // Unique identifier for this request. Every RunStepRequest must
  // have a unique request_id, and retried RunStepRequest must have
  // the same request_id. If request_id is zero, retry detection is disabled.
  int64 request_id = 8;
}

message RunStepResponse {
@@ -183,6 +190,11 @@ message PartialRunSetupRequest {
  // Target Nodes. A list of node names. The named nodes will be run in future
  // steps, but their outputs will not be fetched.
  repeated string target = 4;

  // Unique identifier for this request. Every PartialRunSetupRequest must
  // have a unique request_id, and retried PartialRunSetupRequest must have
  // the same request_id. If request_id is zero, retry detection is disabled.
  int64 request_id = 5;
}

message PartialRunSetupResponse {
@@ -204,8 +216,7 @@ message CloseSessionRequest {
  string session_handle = 1;
}

message CloseSessionResponse {
}
message CloseSessionResponse {}

// Reset() allows misbehaving or slow sessions to be aborted and closed, and
// causes their resources eventually to be released. Reset() does not wait
@@ -237,8 +248,7 @@ message ResetRequest {
  repeated string device_filters = 2;
}

message ResetResponse {
}
message ResetResponse {}

////////////////////////////////////////////////////////////////////////////////
//
@@ -279,6 +289,11 @@ message MakeCallableRequest {

  // Options that define the behavior of the created callable.
  CallableOptions options = 2;

  // Unique identifier for this request. Every MakeCallableRequest must
  // have a unique request_id, and retried MakeCallableRequest must have
  // the same request_id. If request_id is zero, retry detection is disabled.
  int64 request_id = 3;
}

message MakeCallableResponse {
@@ -303,6 +318,11 @@ message RunCallableRequest {
  // Values of the tensors passed as arguments to the callable, in the order
  // defined in the CallableOptions.feed field passed to MakeCallable.
  repeated TensorProto feed = 3;

  // Unique identifier for this request. Every RunCallableRequest must
  // have a unique request_id, and retried RunCallableRequest must have
  // the same request_id. If request_id is zero, retry detection is disabled.
  int64 request_id = 4;
}

message RunCallableResponse {
@@ -330,5 +350,4 @@ message ReleaseCallableRequest {
  int64 handle = 2;
}

message ReleaseCallableResponse {
}
message ReleaseCallableResponse {}