Partial run support for GRPC runtime.
Tests for distributed partial run added in session_test.py. Change: 142477604
commit a02510a260
parent 35ea93ba20
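
For reference, a minimal client-side sketch of what this change enables: a partial run against a gRPC (distributed) session, mirroring the tests added to session_test.py below. The public tf.* aliases and the local in-process server used here are illustrative, not part of the commit.

    import tensorflow as tf

    # Start a single-process gRPC server and connect a remote session to it.
    server = tf.train.Server.create_local_server()
    with tf.Session(server.target) as sess:
      a = tf.placeholder(tf.float32, shape=[])
      b = tf.placeholder(tf.float32, shape=[])
      c = tf.placeholder(tf.float32, shape=[])
      r1 = tf.add(a, b)
      r2 = tf.mul(r1, c)

      # Declare up front which tensors may be fed and fetched, then run the
      # graph incrementally; intermediate state is kept on the worker.
      h = sess.partial_run_setup([r1, r2], [a, b, c])
      res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})   # 3.0
      res = sess.partial_run(h, r2, feed_dict={c: res * 17})  # 153.0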
@@ -1314,9 +1314,9 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequest* req,
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
});
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
}
return s;
}

@@ -104,6 +104,7 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(CreateSession, true);
ENQUEUE_REQUEST(ExtendSession, false);
for (int i = 0; i < 100; ++i) {
ENQUEUE_REQUEST(PartialRunSetup, false);
ENQUEUE_REQUEST(RunStep, true);
}
ENQUEUE_REQUEST(CloseSession, false);

@@ -158,6 +159,16 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(ExtendSession, false);
}

// RPC handler for setting up a partial run call.
void PartialRunSetupHandler(
MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
master_impl_->PartialRunSetup(&call->request, &call->response,
[call](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(PartialRunSetup, false);
}

// RPC handler for running one step in a session.
void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
CallOptions* call_opts = new CallOptions;

@@ -31,6 +31,7 @@ namespace grpc {
static const char* grpcMasterService_method_names[] = {
"/tensorflow.MasterService/CreateSession",
"/tensorflow.MasterService/ExtendSession",
"/tensorflow.MasterService/PartialRunSetup",
"/tensorflow.MasterService/RunStep",
"/tensorflow.MasterService/CloseSession",
"/tensorflow.MasterService/ListDevices",

@@ -51,13 +52,15 @@ MasterService::Stub::Stub(
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_ExtendSession_(grpcMasterService_method_names[1],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_RunStep_(grpcMasterService_method_names[2],
rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_RunStep_(grpcMasterService_method_names[3],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_CloseSession_(grpcMasterService_method_names[3],
rpcmethod_CloseSession_(grpcMasterService_method_names[4],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_ListDevices_(grpcMasterService_method_names[4],
rpcmethod_ListDevices_(grpcMasterService_method_names[5],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_Reset_(grpcMasterService_method_names[5],
rpcmethod_Reset_(grpcMasterService_method_names[6],
::grpc::RpcMethod::NORMAL_RPC, channel) {}

::grpc::Status MasterService::Stub::CreateSession(

@@ -74,6 +77,13 @@ MasterService::Stub::Stub(
context, request, response);
}

::grpc::Status MasterService::Stub::PartialRunSetup(
::grpc::ClientContext* context, const PartialRunSetupRequest& request,
PartialRunSetupResponse* response) {
return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_PartialRunSetup_,
context, request, response);
}

::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) {

@@ -103,7 +113,7 @@ MasterService::Stub::Stub(
}

MasterService::AsyncService::AsyncService() {
for (int i = 0; i < 6; ++i) {
for (int i = 0; i < 7; ++i) {
AddMethod(new ::grpc::RpcServiceMethod(grpcMasterService_method_names[i],
::grpc::RpcMethod::NORMAL_RPC,
nullptr));

@@ -64,6 +64,9 @@ class MasterService GRPC_FINAL {
virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context,
const ExtendSessionRequest& request,
ExtendSessionResponse* response) = 0;
virtual ::grpc::Status PartialRunSetup(
::grpc::ClientContext* context, const PartialRunSetupRequest& request,
PartialRunSetupResponse* response) = 0;
virtual ::grpc::Status RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) = 0;

@@ -86,6 +89,9 @@ class MasterService GRPC_FINAL {
::grpc::Status ExtendSession(::grpc::ClientContext* context,
const ExtendSessionRequest& request,
ExtendSessionResponse* response) GRPC_OVERRIDE;
::grpc::Status PartialRunSetup(
::grpc::ClientContext* context, const PartialRunSetupRequest& request,
PartialRunSetupResponse* response) GRPC_OVERRIDE;
::grpc::Status RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) GRPC_OVERRIDE;

@@ -103,6 +109,7 @@ class MasterService GRPC_FINAL {
std::shared_ptr< ::grpc::ChannelInterface> channel_;
const ::grpc::RpcMethod rpcmethod_CreateSession_;
const ::grpc::RpcMethod rpcmethod_ExtendSession_;
const ::grpc::RpcMethod rpcmethod_PartialRunSetup_;
const ::grpc::RpcMethod rpcmethod_RunStep_;
const ::grpc::RpcMethod rpcmethod_CloseSession_;
const ::grpc::RpcMethod rpcmethod_ListDevices_;

@@ -132,12 +139,20 @@ class MasterService GRPC_FINAL {
::grpc::Service::RequestAsyncUnary(1, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestPartialRunSetup(
::grpc::ServerContext* context, PartialRunSetupRequest* request,
::grpc::ServerAsyncResponseWriter<PartialRunSetupResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(2, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestRunStep(
::grpc::ServerContext* context, RunStepRequest* request,
::grpc::ServerAsyncResponseWriter<RunStepResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(2, context, request, response,
::grpc::Service::RequestAsyncUnary(3, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestCloseSession(

@@ -145,7 +160,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<CloseSessionResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(3, context, request, response,
::grpc::Service::RequestAsyncUnary(4, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestListDevices(

@@ -153,7 +168,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<ListDevicesResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(4, context, request, response,
::grpc::Service::RequestAsyncUnary(5, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestReset(

@@ -161,7 +176,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<ResetResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(5, context, request, response,
::grpc::Service::RequestAsyncUnary(6, context, request, response,
new_call_cq, notification_cq, tag);
}
};

@@ -52,6 +52,15 @@ class GrpcRemoteMaster : public MasterInterface {
return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response));
}

Status PartialRunSetup(CallOptions* call_options,
const PartialRunSetupRequest* request,
PartialRunSetupResponse* response) override {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->PartialRunSetup(&ctx, *request, response));
}

Status RunStep(CallOptions* call_options, const RunStepRequest* request,
RunStepResponse* response) override {
::grpc::ClientContext ctx;

@@ -162,18 +162,22 @@ Status GrpcSession::Extend(const RunOptions& run_options,
return ExtendImpl(&call_options, graph);
}

Status GrpcSession::Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
Status GrpcSession::RunHelper(
const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
RunMetadata* run_metadata, const string& prun_handle) {
// Convert to proto
RunStepRequest req;
RunStepResponse resp;

*req.mutable_options() = run_options;

if (!prun_handle.empty()) {
req.set_partial_run_handle(prun_handle);
}

for (const auto& it : inputs) {
Tensor input_tensor = it.second;
auto feed = req.add_feed();

@@ -225,6 +229,16 @@ Status GrpcSession::Run(const RunOptions& run_options,
return Status::OK();
}

Status GrpcSession::Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
outputs, run_metadata, /* prun_handle */ "");
}

Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,

@@ -252,14 +266,41 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) {
return errors::Internal("Partial run is not supported for remote session.");
// Convert to proto
PartialRunSetupRequest req;
PartialRunSetupResponse resp;
CallOptions call_options;
{
mutex_lock l(mu_);
if (handle_.empty()) {
return errors::InvalidArgument("A session is not created yet....");
}

req.set_session_handle(handle_);
}
for (const string& feed : input_names) {
req.add_feed(feed);
}
for (const string& fetch : output_names) {
req.add_fetch(fetch);
}
for (const string& target : target_nodes) {
req.add_target(target);
}
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
*handle = resp.partial_run_handle();
return Status::OK();
}

Status GrpcSession::PRun(const string& handle,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) {
return errors::Internal("Partial run is not supported for remote session.");
RunOptions run_options;
run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs,
/* run_metadata */ nullptr, handle);
}

Status GrpcSession::Close() {

@@ -110,6 +110,13 @@ class GrpcSession : public Session {
// The current version of the graph.
int64 current_graph_version_ GUARDED_BY(mu_);

Status RunHelper(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
const string& prun_handle);

Status RunProto(CallOptions* call_options, RunStepRequest* req,
RunStepResponse* resp);

@@ -163,6 +163,47 @@ class GrpcWorkerService : public AsyncServiceInterface {

mutex mu_;
CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
struct PartialRunState {
CancellationManager* cancellation_manager;
Notification executor_done;

explicit PartialRunState(CancellationManager* cm)
: cancellation_manager(cm) {}
};
struct PairHash {
std::size_t operator()(std::pair<string, int> const& p) const {
return Hash64Combine(std::hash<string>()(p.first),
std::hash<int>()(p.second));
}
};
std::unordered_map<std::pair<string, int>, std::unique_ptr<PartialRunState>,
PairHash>
partial_runs_ GUARDED_BY(mu_);

PartialRunState* FindPartialRun(const string& graph_handle, int step_id) {
std::pair<string, int> k(graph_handle, step_id);
PartialRunState* prun_state = nullptr;
mutex_lock l(mu_);
auto it = partial_runs_.find(k);
if (it != partial_runs_.end()) {
prun_state = it->second.get();
}
return prun_state;
}

void InsertPartialRunLocked(const string& graph_handle, int step_id,
PartialRunState* partial_run_state)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::pair<string, int> k(graph_handle, step_id);
partial_runs_.emplace(
std::make_pair(k, std::unique_ptr<PartialRunState>(partial_run_state)));
}

void RemovePartialRun(const string& graph_handle, int step_id) {
std::pair<string, int> k(graph_handle, step_id);
mutex_lock l(mu_);
partial_runs_.erase(partial_runs_.find(k));
}

mutex shutdown_mu_;
bool is_shutdown_ GUARDED_BY(shutdown_mu_);

@@ -225,7 +266,11 @@ class GrpcWorkerService : public AsyncServiceInterface {
}

void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
if (call->request.is_partial()) {
env_->compute_pool->Schedule([this, call]() { DoPartialRunGraph(call); });
} else {
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
}
ENQUEUE_REQUEST(RunGraph, true);
}

@@ -294,10 +339,6 @@ class GrpcWorkerService : public AsyncServiceInterface {

Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in,
GraphMgr::NamedTensors* out) {
if (req.is_partial()) {
return errors::Unimplemented(
"Partial run not implemented for GRPC worker service");
}
if (req.send_size() > 0) {
// TODO(zhifengc): Let the caller decide on which device to
// allocate the tensor.

@@ -386,6 +427,110 @@ class GrpcWorkerService : public AsyncServiceInterface {
});
}

// TODO(suharshs): Add stats collection support to partial run.
void DoPartialRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
const int64 step_id = call->request.step_id();
const string& graph_handle = call->request.graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id);
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(call->request, &in, out);
auto finish = [this, call, out](const Status& s) {
delete out;
call->ClearCancelCallback();
call->SendResponse(ToGrpcStatus(s));
};
if (!s.ok()) {
finish(s);
return;
}

PartialRunState* partial_run_state = FindPartialRun(graph_handle, step_id);

CancellationManager* cm = nullptr;
// If this is a new partial run call we need to create a new cancellation
// manager.
// Otherwise we use the cancellation manager stored in the found partial
// run state.
if (partial_run_state == nullptr) {
cm = new CancellationManager;
} else {
cm = partial_run_state->cancellation_manager;
}

// Before we start doing anything, we set the RPC cancellation.
call->SetCancelCallback([this, cm, step_id]() {
cm->StartCancel();
AbortStep(step_id);
});

// If this is a new partial run request, the request will need to start the
// executors.
if (partial_run_state == nullptr) {
CancellationToken token;
{
mutex_lock l(mu_);
// Insert the new partial run into the partial_runs_ map.
partial_run_state = new PartialRunState(cm);
InsertPartialRunLocked(graph_handle, step_id, partial_run_state);
token = cancellation_manager_->get_cancellation_token();
cancellation_manager_->RegisterCallback(token,
[cm]() { cm->StartCancel(); });
}
env_->graph_mgr->ExecuteAsync(
graph_handle, step_id, call->request.exec_opts(),
nullptr /* collector */, nullptr /* cost_graph */, cm, in,
[this, step_id, graph_handle, token, partial_run_state](Status s) {
{
mutex_lock l(mu_);
cancellation_manager_->DeregisterCallback(token);
}
partial_run_state->executor_done.Notify();
// TODO(suharshs): Propagate the status once we keep state for
// each partial run call.
});
} else {
// Send the partial run's new inputs.
s = env_->graph_mgr->SendInputs(step_id, in);
if (!s.ok()) {
finish(s);
return;
}
}

// Receive the partial run's outputs.
s = env_->graph_mgr->RecvOutputs(step_id, out);
if (!s.ok()) {
finish(s);
return;
}

// Construct and return the resp.
for (const auto& p : *out) {
const string& key = p.first;
const Tensor& val = p.second;
auto* recv = call->response.add_recv();
recv->set_key(key);
// TODO(zhifengc): Deal with gpu -> cpu copy.
TensorProto* proto = recv->mutable_val();
val.AsProtoField(proto);
}

// If this is the last partial run request we must also wait for the entire
// graph execution to be completed.
if (call->request.is_last_partial_run()) {
partial_run_state->executor_done.WaitForNotification();
RemovePartialRun(graph_handle, step_id);
// Before deleting the cancellation manager on the final call, ensure
// that we clear the RPC cancel callback, which has a reference to the
// cancellation manager.
call->ClearCancelCallback();
delete cm;
}

finish(s);
}

// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,

@@ -2463,6 +2463,7 @@ py_test(
":platform_test",
":session",
":state_ops",
":training",
":util",
":variables",
],

@@ -45,6 +45,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

@@ -1322,91 +1323,121 @@ class SessionTest(test_util.TensorFlowTestCase):
sess_2.run(c_1.op)
self.assertEqual(2.0, sess_2.run(c_2))

def testPartialRun(self):
with session.Session() as sess:
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
def runTestPartialRun(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)

h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 17
res = sess.partial_run(h, r2, feed_dict={c: temp})
self.assertEqual(153, res)
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 17
res = sess.partial_run(h, r2, feed_dict={c: temp})
self.assertEqual(153, res)

# Call again on the same graph.
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 18
res = sess.partial_run(h2, r2, feed_dict={c: temp})
self.assertEqual(162, res)
# Call again on the same graph.
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 18
res = sess.partial_run(h2, r2, feed_dict={c: temp})
self.assertEqual(162, res)

def testPartialRunIncomplete(self):
with session.Session() as sess:
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
def runTestPartialRunIncomplete(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)

h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)

def testConcurrentPartialRun(self):
with session.Session() as sess:
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
def runTestConcurrentPartialRun(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)

h1 = sess.partial_run_setup([r1], [a, b, c])
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 19
res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9})
self.assertEqual(66, res)
res = sess.partial_run(h2, r2, feed_dict={c: 7})
self.assertEqual(462, res)
h1 = sess.partial_run_setup([r1], [a, b, c])
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 19
res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9})
self.assertEqual(66, res)
res = sess.partial_run(h2, r2, feed_dict={c: 7})
self.assertEqual(462, res)

def testManyPartialRun(self):
with session.Session() as sess:
steps = 200
inputs = []
outputs = []
a = constant_op.constant(2.0, dtypes.float32)
for i in xrange(steps):
inputs.append(array_ops.placeholder(dtypes.float32, shape=[]))
a = math_ops.mul(a, inputs[i])
outputs.append(a)
def runTestManyPartialRun(self, sess):
steps = 200
inputs = []
outputs = []
a = constant_op.constant(2.0, dtypes.float32)
for i in xrange(steps):
inputs.append(array_ops.placeholder(dtypes.float32, shape=[]))
a = math_ops.mul(a, inputs[i])
outputs.append(a)

h = sess.partial_run_setup(outputs, inputs)
for i in xrange(steps):
res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0})
self.assertEqual(2.0, res)
h = sess.partial_run_setup(outputs, inputs)
for i in xrange(steps):
res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0})
self.assertEqual(2.0, res)

feed_dict = {}
for i in xrange(steps):
feed_dict[inputs[i]] = 1.0
res = sess.run(outputs, feed_dict)
self.assertEqual(steps, len(res))
self.assertEqual(2.0, res[-1])
feed_dict = {}
for i in xrange(steps):
feed_dict[inputs[i]] = 1.0
res = sess.run(outputs, feed_dict)
self.assertEqual(steps, len(res))
self.assertEqual(2.0, res[-1])

def testRunAndPartialRun(self):
with session.Session() as sess:
a = constant_op.constant(2.0, dtypes.float32)
b = a * 2
c = b * 3
r1 = sess.run([b, c])
h = sess.partial_run_setup([b, c], [])
r2 = sess.partial_run(h, [b, c])
self.assertEqual(r1, r2)
def runTestRunAndPartialRun(self, sess):
a = constant_op.constant(2.0, dtypes.float32)
b = a * 2
c = b * 3
r1 = sess.run([b, c])
h = sess.partial_run_setup([b, c], [])
r2 = sess.partial_run(h, [b, c])
self.assertEqual(r1, r2)

def testPartialRunDirect(self):
self.runTestPartialRun(session.Session())

def testPartialRunIncompleteDirect(self):
self.runTestPartialRunIncomplete(session.Session())

def testConcurrentPartialRunDirect(self):
self.runTestConcurrentPartialRun(session.Session())

def testManyPartialRunDirect(self):
self.runTestManyPartialRun(session.Session())

def testRunAndPartialRunDirect(self):
self.runTestRunAndPartialRun(session.Session())

def testPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRun(session.Session(server.target))

def testPartialRunIncompleteDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRunIncomplete(session.Session(server.target))

def testConcurrentPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestConcurrentPartialRun(session.Session(server.target))

def testManyPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestManyPartialRun(session.Session(server.target))

def testRunAndPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestRunAndPartialRun(session.Session(server.target))

def testFeedDictKeyException(self):
with session.Session() as sess: