Implement duplicate checking on Master methods
PiperOrigin-RevId: 232927103
This commit is contained in:
parent
8773ca70dc
commit
243118bbbf
@ -17,7 +17,6 @@ filegroup(
|
|||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")
|
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_cuda_library")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
load("//tensorflow:tensorflow.bzl", "tf_copts")
|
||||||
|
|
||||||
# For platform specific build config
|
# For platform specific build config
|
||||||
@ -298,6 +297,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":call_options",
|
":call_options",
|
||||||
":message_wrappers",
|
":message_wrappers",
|
||||||
|
":request_id",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:master_proto_cc",
|
"//tensorflow/core:master_proto_cc",
|
||||||
],
|
],
|
||||||
@ -311,6 +311,7 @@ cc_library(
|
|||||||
":call_options",
|
":call_options",
|
||||||
":master_env",
|
":master_env",
|
||||||
":master_session",
|
":master_session",
|
||||||
|
":recent_request_ids",
|
||||||
":remote_device",
|
":remote_device",
|
||||||
":worker_cache",
|
":worker_cache",
|
||||||
":worker_interface",
|
":worker_interface",
|
||||||
@ -765,6 +766,7 @@ cc_library(
|
|||||||
srcs = ["recent_request_ids.cc"],
|
srcs = ["recent_request_ids.cc"],
|
||||||
hdrs = ["recent_request_ids.h"],
|
hdrs = ["recent_request_ids.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":message_wrappers",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
],
|
],
|
||||||
|
@ -65,7 +65,8 @@ Master::Master(MasterEnv* env, double session_gc_seconds)
|
|||||||
: env_(env),
|
: env_(env),
|
||||||
last_1000_steps_(1000),
|
last_1000_steps_(1000),
|
||||||
step_count_(0),
|
step_count_(0),
|
||||||
session_gc_seconds_(session_gc_seconds) {
|
session_gc_seconds_(session_gc_seconds),
|
||||||
|
recent_request_ids_(10000) {
|
||||||
// Right now, a master service must be co-located with a device.
|
// Right now, a master service must be co-located with a device.
|
||||||
// Otherwise, fetches do not work.
|
// Otherwise, fetches do not work.
|
||||||
CHECK(!env->local_devices.empty());
|
CHECK(!env->local_devices.empty());
|
||||||
@ -510,6 +511,12 @@ void Master::ExtendSession(const ExtendSessionRequest* req,
|
|||||||
|
|
||||||
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
|
void Master::PartialRunSetup(const PartialRunSetupRequest* req,
|
||||||
PartialRunSetupResponse* resp, MyClosure done) {
|
PartialRunSetupResponse* resp, MyClosure done) {
|
||||||
|
Status s = recent_request_ids_.TrackUnique(req->request_id(),
|
||||||
|
"PartialRunSetup (Master)", *req);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto session = FindMasterSession(req->session_handle());
|
auto session = FindMasterSession(req->session_handle());
|
||||||
if (session == nullptr) {
|
if (session == nullptr) {
|
||||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||||
@ -525,6 +532,12 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req,
|
|||||||
|
|
||||||
void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
|
void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
|
||||||
MutableRunStepResponseWrapper* resp, MyClosure done) {
|
MutableRunStepResponseWrapper* resp, MyClosure done) {
|
||||||
|
Status s = recent_request_ids_.TrackUnique(req->request_id(),
|
||||||
|
"RunStep (Master)", req);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto start_time = env_->env->NowMicros();
|
auto start_time = env_->env->NowMicros();
|
||||||
auto session = FindMasterSession(req->session_handle());
|
auto session = FindMasterSession(req->session_handle());
|
||||||
if (session == nullptr) {
|
if (session == nullptr) {
|
||||||
@ -664,6 +677,12 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp,
|
|||||||
|
|
||||||
void Master::MakeCallable(const MakeCallableRequest* req,
|
void Master::MakeCallable(const MakeCallableRequest* req,
|
||||||
MakeCallableResponse* resp, MyClosure done) {
|
MakeCallableResponse* resp, MyClosure done) {
|
||||||
|
Status s = recent_request_ids_.TrackUnique(req->request_id(),
|
||||||
|
"MakeCallable (Master)", *req);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto session = FindMasterSession(req->session_handle());
|
auto session = FindMasterSession(req->session_handle());
|
||||||
if (session == nullptr) {
|
if (session == nullptr) {
|
||||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||||
@ -681,6 +700,12 @@ void Master::MakeCallable(const MakeCallableRequest* req,
|
|||||||
|
|
||||||
void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
|
void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
|
||||||
RunCallableResponse* resp, MyClosure done) {
|
RunCallableResponse* resp, MyClosure done) {
|
||||||
|
Status s = recent_request_ids_.TrackUnique(req->request_id(),
|
||||||
|
"RunCallable (Master)", *req);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
auto session = FindMasterSession(req->session_handle());
|
auto session = FindMasterSession(req->session_handle());
|
||||||
if (session == nullptr) {
|
if (session == nullptr) {
|
||||||
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
done(errors::Aborted("Session ", req->session_handle(), " is not found."));
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||||
#include "tensorflow/core/distributed_runtime/master_env.h"
|
#include "tensorflow/core/distributed_runtime/master_env.h"
|
||||||
#include "tensorflow/core/distributed_runtime/master_session.h"
|
#include "tensorflow/core/distributed_runtime/master_session.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
|
||||||
#include "tensorflow/core/lib/core/notification.h"
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -95,6 +96,9 @@ class Master {
|
|||||||
// closed automatically.
|
// closed automatically.
|
||||||
const double session_gc_seconds_;
|
const double session_gc_seconds_;
|
||||||
|
|
||||||
|
// Used to track ids for incoming requests so we can detect duplicates.
|
||||||
|
RecentRequestIds recent_request_ids_;
|
||||||
|
|
||||||
// Call CleanupAll on all workers.
|
// Call CleanupAll on all workers.
|
||||||
void CleanupWorkers(const ResetRequest& reset);
|
void CleanupWorkers(const ResetRequest& reset);
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||||
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
|
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/protobuf/master.pb.h"
|
#include "tensorflow/core/protobuf/master.pb.h"
|
||||||
@ -66,7 +67,9 @@ class MasterInterface {
|
|||||||
// The message returned from this method must only be used in a
|
// The message returned from this method must only be used in a
|
||||||
// `RunStep()` call on the same `MasterInterface` instance.
|
// `RunStep()` call on the same `MasterInterface` instance.
|
||||||
virtual MutableRunStepRequestWrapper* CreateRunStepRequest() {
|
virtual MutableRunStepRequestWrapper* CreateRunStepRequest() {
|
||||||
return new MutableProtoRunStepRequest;
|
MutableProtoRunStepRequest* ret = new MutableProtoRunStepRequest;
|
||||||
|
ret->request_.set_request_id(GetUniqueRequestId());
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a response object for use in calls to
|
// Returns a response object for use in calls to
|
||||||
|
@ -97,6 +97,10 @@ bool InMemoryRunStepRequest::store_errors_in_response_body() const {
|
|||||||
return store_errors_in_response_body_;
|
return store_errors_in_response_body_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 InMemoryRunStepRequest::request_id() const {
|
||||||
|
return 0; // no need to track request id for local version.
|
||||||
|
}
|
||||||
|
|
||||||
void InMemoryRunStepRequest::set_store_errors_in_response_body(
|
void InMemoryRunStepRequest::set_store_errors_in_response_body(
|
||||||
bool store_errors) {
|
bool store_errors) {
|
||||||
store_errors_in_response_body_ = store_errors;
|
store_errors_in_response_body_ = store_errors;
|
||||||
@ -210,6 +214,10 @@ void MutableProtoRunStepRequest::set_store_errors_in_response_body(
|
|||||||
request_.set_store_errors_in_response_body(store_errors);
|
request_.set_store_errors_in_response_body(store_errors);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 MutableProtoRunStepRequest::request_id() const {
|
||||||
|
return request_.request_id();
|
||||||
|
}
|
||||||
|
|
||||||
string MutableProtoRunStepRequest::DebugString() const {
|
string MutableProtoRunStepRequest::DebugString() const {
|
||||||
return request_.DebugString();
|
return request_.DebugString();
|
||||||
}
|
}
|
||||||
@ -272,6 +280,8 @@ bool ProtoRunStepRequest::store_errors_in_response_body() const {
|
|||||||
return request_->store_errors_in_response_body();
|
return request_->store_errors_in_response_body();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 ProtoRunStepRequest::request_id() const { return request_->request_id(); }
|
||||||
|
|
||||||
string ProtoRunStepRequest::DebugString() const {
|
string ProtoRunStepRequest::DebugString() const {
|
||||||
return request_->DebugString();
|
return request_->DebugString();
|
||||||
}
|
}
|
||||||
|
@ -87,6 +87,8 @@ class RunStepRequestWrapper {
|
|||||||
// truncate long metadata messages.
|
// truncate long metadata messages.
|
||||||
virtual bool store_errors_in_response_body() const = 0;
|
virtual bool store_errors_in_response_body() const = 0;
|
||||||
|
|
||||||
|
virtual int64 request_id() const = 0;
|
||||||
|
|
||||||
// Returns a human-readable representation of this message for debugging.
|
// Returns a human-readable representation of this message for debugging.
|
||||||
virtual string DebugString() const = 0;
|
virtual string DebugString() const = 0;
|
||||||
|
|
||||||
@ -127,6 +129,7 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper {
|
|||||||
string DebugString() const override;
|
string DebugString() const override;
|
||||||
const RunStepRequest& ToProto() const override;
|
const RunStepRequest& ToProto() const override;
|
||||||
bool store_errors_in_response_body() const override;
|
bool store_errors_in_response_body() const override;
|
||||||
|
int64 request_id() const override;
|
||||||
|
|
||||||
// MutableRunStepRequestWrapper methods.
|
// MutableRunStepRequestWrapper methods.
|
||||||
void set_session_handle(const string& handle) override;
|
void set_session_handle(const string& handle) override;
|
||||||
@ -177,6 +180,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
|
|||||||
string DebugString() const override;
|
string DebugString() const override;
|
||||||
const RunStepRequest& ToProto() const override;
|
const RunStepRequest& ToProto() const override;
|
||||||
bool store_errors_in_response_body() const override;
|
bool store_errors_in_response_body() const override;
|
||||||
|
int64 request_id() const override;
|
||||||
|
|
||||||
// MutableRunStepRequestWrapper methods.
|
// MutableRunStepRequestWrapper methods.
|
||||||
void set_session_handle(const string& handle) override;
|
void set_session_handle(const string& handle) override;
|
||||||
@ -189,6 +193,7 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
RunStepRequest request_;
|
RunStepRequest request_;
|
||||||
|
friend class MasterInterface;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wrapper for immutable RunStep requests that use a non-owned
|
// Wrapper for immutable RunStep requests that use a non-owned
|
||||||
@ -216,6 +221,7 @@ class ProtoRunStepRequest : public RunStepRequestWrapper {
|
|||||||
string DebugString() const override;
|
string DebugString() const override;
|
||||||
const RunStepRequest& ToProto() const override;
|
const RunStepRequest& ToProto() const override;
|
||||||
bool store_errors_in_response_body() const override;
|
bool store_errors_in_response_body() const override;
|
||||||
|
int64 request_id() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const RunStepRequest* const request_; // Not owned.
|
const RunStepRequest* const request_; // Not owned.
|
||||||
|
@ -28,12 +28,10 @@ RecentRequestIds::RecentRequestIds(int num_tracked_request_ids)
|
|||||||
set_.reserve(num_tracked_request_ids);
|
set_.reserve(num_tracked_request_ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RecentRequestIds::TrackUnique(int64 request_id,
|
bool RecentRequestIds::Insert(int64 request_id) {
|
||||||
const string& method_name,
|
|
||||||
const protobuf::Message& request) {
|
|
||||||
if (request_id == 0) {
|
if (request_id == 0) {
|
||||||
// For backwards compatibility, allow all requests with request_id 0.
|
// For backwards compatibility, allow all requests with request_id 0.
|
||||||
return Status::OK();
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
@ -43,9 +41,7 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
|
|||||||
// request_id's age in the circular_buffer_ if it's tracked again. Strict
|
// request_id's age in the circular_buffer_ if it's tracked again. Strict
|
||||||
// LRU is not useful here because returning this error will close the
|
// LRU is not useful here because returning this error will close the
|
||||||
// current Session.
|
// current Session.
|
||||||
return errors::Aborted("The same ", method_name,
|
return false;
|
||||||
" request was received twice. ",
|
|
||||||
request.ShortDebugString());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the oldest request_id from the set_. circular_buffer_ is
|
// Remove the oldest request_id from the set_. circular_buffer_ is
|
||||||
@ -54,7 +50,30 @@ Status RecentRequestIds::TrackUnique(int64 request_id,
|
|||||||
set_.erase(circular_buffer_[next_index_]);
|
set_.erase(circular_buffer_[next_index_]);
|
||||||
circular_buffer_[next_index_] = request_id;
|
circular_buffer_[next_index_] = request_id;
|
||||||
next_index_ = (next_index_ + 1) % circular_buffer_.size();
|
next_index_ = (next_index_ + 1) % circular_buffer_.size();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RecentRequestIds::TrackUnique(int64 request_id,
|
||||||
|
const string& method_name,
|
||||||
|
const protobuf::Message& request) {
|
||||||
|
if (Insert(request_id)) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
} else {
|
||||||
|
return errors::Aborted("The same ", method_name,
|
||||||
|
" request was received twice. ",
|
||||||
|
request.ShortDebugString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Status RecentRequestIds::TrackUnique(int64 request_id,
|
||||||
|
const string& method_name,
|
||||||
|
const RunStepRequestWrapper* wrapper) {
|
||||||
|
if (Insert(request_id)) {
|
||||||
|
return Status::OK();
|
||||||
|
} else {
|
||||||
|
return errors::Aborted("The same ", method_name,
|
||||||
|
" request was received twice. ",
|
||||||
|
wrapper->ToProto().ShortDebugString());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/protobuf.h"
|
#include "tensorflow/core/platform/protobuf.h"
|
||||||
@ -58,8 +59,13 @@ class RecentRequestIds {
|
|||||||
// ShortDebugString are added to returned errors.
|
// ShortDebugString are added to returned errors.
|
||||||
Status TrackUnique(int64 request_id, const string& method_name,
|
Status TrackUnique(int64 request_id, const string& method_name,
|
||||||
const protobuf::Message& request);
|
const protobuf::Message& request);
|
||||||
|
// Overloaded versions of the above function for wrapped protos.
|
||||||
|
Status TrackUnique(int64 request_id, const string& method_name,
|
||||||
|
const RunStepRequestWrapper* wrapper);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
bool Insert(int64 request_id);
|
||||||
|
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
// next_index_ indexes into circular_buffer_, and points to the next storage
|
// next_index_ indexes into circular_buffer_, and points to the next storage
|
||||||
// space to use. When the buffer is full, next_index_ points at the oldest
|
// space to use. When the buffer is full, next_index_ points at the oldest
|
||||||
|
@ -408,6 +408,7 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime:local_master",
|
"//tensorflow/core/distributed_runtime:local_master",
|
||||||
"//tensorflow/core/distributed_runtime:master_interface",
|
"//tensorflow/core/distributed_runtime:master_interface",
|
||||||
"//tensorflow/core/distributed_runtime:message_wrappers",
|
"//tensorflow/core/distributed_runtime:message_wrappers",
|
||||||
|
"//tensorflow/core/distributed_runtime:request_id",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||||
#include "tensorflow/core/distributed_runtime/local_master.h"
|
#include "tensorflow/core/distributed_runtime/local_master.h"
|
||||||
#include "tensorflow/core/distributed_runtime/master_interface.h"
|
#include "tensorflow/core/distributed_runtime/master_interface.h"
|
||||||
|
#include "tensorflow/core/distributed_runtime/request_id.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
|
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
@ -312,6 +313,7 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
|
|||||||
for (const string& target : target_nodes) {
|
for (const string& target : target_nodes) {
|
||||||
req.add_target(target);
|
req.add_target(target);
|
||||||
}
|
}
|
||||||
|
req.set_request_id(GetUniqueRequestId());
|
||||||
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
|
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
|
||||||
TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
|
TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
|
||||||
*handle = resp.partial_run_handle();
|
*handle = resp.partial_run_handle();
|
||||||
@ -408,6 +410,7 @@ Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
|
|||||||
MakeCallableRequest req;
|
MakeCallableRequest req;
|
||||||
TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
|
TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
|
||||||
*req.mutable_options() = callable_options;
|
*req.mutable_options() = callable_options;
|
||||||
|
req.set_request_id(GetUniqueRequestId());
|
||||||
MakeCallableResponse resp;
|
MakeCallableResponse resp;
|
||||||
CallOptions call_options;
|
CallOptions call_options;
|
||||||
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
|
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
|
||||||
@ -423,6 +426,7 @@ Status GrpcSession::RunCallable(CallableHandle handle,
|
|||||||
RunCallableRequest req;
|
RunCallableRequest req;
|
||||||
TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
|
TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
|
||||||
req.set_handle(handle);
|
req.set_handle(handle);
|
||||||
|
req.set_request_id(GetUniqueRequestId());
|
||||||
for (const Tensor& feed : feed_tensors) {
|
for (const Tensor& feed : feed_tensors) {
|
||||||
feed.AsProtoTensorContent(req.mutable_feed()->Add());
|
feed.AsProtoTensorContent(req.mutable_feed()->Add());
|
||||||
}
|
}
|
||||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
|||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
package tensorflow;
|
package tensorflow;
|
||||||
|
|
||||||
option cc_enable_arenas = true;
|
option cc_enable_arenas = true;
|
||||||
option java_outer_classname = "DistributedRuntimeProtos";
|
option java_outer_classname = "DistributedRuntimeProtos";
|
||||||
option java_multiple_files = true;
|
option java_multiple_files = true;
|
||||||
option java_package = "org.tensorflow.distruntime";
|
option java_package = "org.tensorflow.distruntime";
|
||||||
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf";
|
|
||||||
|
// add go_package externally with copybara
|
||||||
import "tensorflow/core/framework/device_attributes.proto";
|
import "tensorflow/core/framework/device_attributes.proto";
|
||||||
import "tensorflow/core/framework/graph.proto";
|
import "tensorflow/core/framework/graph.proto";
|
||||||
import "tensorflow/core/framework/tensor.proto";
|
import "tensorflow/core/framework/tensor.proto";
|
||||||
@ -138,6 +140,11 @@ message RunStepRequest {
|
|||||||
// response body. This is a workaround since the RPC subsystem may
|
// response body. This is a workaround since the RPC subsystem may
|
||||||
// truncate long metadata messages.
|
// truncate long metadata messages.
|
||||||
bool store_errors_in_response_body = 7;
|
bool store_errors_in_response_body = 7;
|
||||||
|
|
||||||
|
// Unique identifier for this request. Every RunStepRequest must
|
||||||
|
// have a unique request_id, and retried RunStepRequest must have
|
||||||
|
// the same request_id. If request_id is zero, retry detection is disabled.
|
||||||
|
int64 request_id = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
message RunStepResponse {
|
message RunStepResponse {
|
||||||
@ -183,6 +190,11 @@ message PartialRunSetupRequest {
|
|||||||
// Target Nodes. A list of node names. The named nodes will be run in future
|
// Target Nodes. A list of node names. The named nodes will be run in future
|
||||||
// steps, but their outputs will not be fetched.
|
// steps, but their outputs will not be fetched.
|
||||||
repeated string target = 4;
|
repeated string target = 4;
|
||||||
|
|
||||||
|
// Unique identifier for this request. Every PartialRunSetupRequest must
|
||||||
|
// have a unique request_id, and retried PartialRunSetupRequest must have
|
||||||
|
// the same request_id. If request_id is zero, retry detection is disabled.
|
||||||
|
int64 request_id = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
message PartialRunSetupResponse {
|
message PartialRunSetupResponse {
|
||||||
@ -204,8 +216,7 @@ message CloseSessionRequest {
|
|||||||
string session_handle = 1;
|
string session_handle = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message CloseSessionResponse {
|
message CloseSessionResponse {}
|
||||||
}
|
|
||||||
|
|
||||||
// Reset() allows misbehaving or slow sessions to be aborted and closed, and
|
// Reset() allows misbehaving or slow sessions to be aborted and closed, and
|
||||||
// causes their resources eventually to be released. Reset() does not wait
|
// causes their resources eventually to be released. Reset() does not wait
|
||||||
@ -237,8 +248,7 @@ message ResetRequest {
|
|||||||
repeated string device_filters = 2;
|
repeated string device_filters = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ResetResponse {
|
message ResetResponse {}
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
//
|
//
|
||||||
@ -279,6 +289,11 @@ message MakeCallableRequest {
|
|||||||
|
|
||||||
// Options that define the behavior of the created callable.
|
// Options that define the behavior of the created callable.
|
||||||
CallableOptions options = 2;
|
CallableOptions options = 2;
|
||||||
|
|
||||||
|
// Unique identifier for this request. Every MakeCallableRequest must
|
||||||
|
// have a unique request_id, and retried MakeCallableRequest must have
|
||||||
|
// the same request_id. If request_id is zero, retry detection is disabled.
|
||||||
|
int64 request_id = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message MakeCallableResponse {
|
message MakeCallableResponse {
|
||||||
@ -303,6 +318,11 @@ message RunCallableRequest {
|
|||||||
// Values of the tensors passed as arguments to the callable, in the order
|
// Values of the tensors passed as arguments to the callable, in the order
|
||||||
// defined in the CallableOptions.feed field passed to MakeCallable.
|
// defined in the CallableOptions.feed field passed to MakeCallable.
|
||||||
repeated TensorProto feed = 3;
|
repeated TensorProto feed = 3;
|
||||||
|
|
||||||
|
// Unique identifier for this request. Every RunCallableRequest must
|
||||||
|
// have a unique request_id, and retried RunCallableRequest must have
|
||||||
|
// the same request_id. If request_id is zero, retry detection is disabled.
|
||||||
|
int64 request_id = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message RunCallableResponse {
|
message RunCallableResponse {
|
||||||
@ -330,5 +350,4 @@ message ReleaseCallableRequest {
|
|||||||
int64 handle = 2;
|
int64 handle = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ReleaseCallableResponse {
|
message ReleaseCallableResponse {}
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user