Partial run support for GRPC runtime.

Tests for distributed partial run added in session_test.py.
Change: 142477604
Suharsh Sivakumar 2016-12-19 13:02:43 -08:00 committed by TensorFlower Gardener
parent 35ea93ba20
commit a02510a260
10 changed files with 369 additions and 99 deletions
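For context, a minimal sketch of what this commit enables at the Python level, adapted from the new *Dist tests in session_test.py below (an in-process server stands in for a real cluster; the graph itself is illustrative):

from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import server_lib

# Partial run now works against a gRPC master, not only DirectSession.
server = server_lib.Server.create_local_server()
sess = session.Session(server.target)

a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)

# One setup call declares all feeds and fetches; each partial_run then
# feeds and fetches a subset while intermediate state stays alive.
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})   # 3.0
res = sess.partial_run(h, r2, feed_dict={c: res * 17})  # 153.0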

View File

@@ -1314,9 +1314,9 @@ Status MasterSession::DoPartialRun(CallOptions* opts, const RunStepRequest* req,
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
});
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
}
return s;
}

View File

@@ -104,6 +104,7 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(CreateSession, true);
ENQUEUE_REQUEST(ExtendSession, false);
for (int i = 0; i < 100; ++i) {
ENQUEUE_REQUEST(PartialRunSetup, false);
ENQUEUE_REQUEST(RunStep, true);
}
ENQUEUE_REQUEST(CloseSession, false);
@@ -158,6 +159,16 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(ExtendSession, false);
}
// RPC handler for setting up a partial run call.
void PartialRunSetupHandler(
MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
master_impl_->PartialRunSetup(&call->request, &call->response,
[call](const Status& status) {
call->SendResponse(ToGrpcStatus(status));
});
ENQUEUE_REQUEST(PartialRunSetup, false);
}
// RPC handler for running one step in a session.
void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
CallOptions* call_opts = new CallOptions;

View File

@@ -31,6 +31,7 @@ namespace grpc {
static const char* grpcMasterService_method_names[] = {
"/tensorflow.MasterService/CreateSession",
"/tensorflow.MasterService/ExtendSession",
"/tensorflow.MasterService/PartialRunSetup",
"/tensorflow.MasterService/RunStep",
"/tensorflow.MasterService/CloseSession",
"/tensorflow.MasterService/ListDevices",
@@ -51,13 +52,15 @@ MasterService::Stub::Stub(
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_ExtendSession_(grpcMasterService_method_names[1],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_RunStep_(grpcMasterService_method_names[2],
rpcmethod_PartialRunSetup_(grpcMasterService_method_names[2],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_RunStep_(grpcMasterService_method_names[3],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_CloseSession_(grpcMasterService_method_names[3],
rpcmethod_CloseSession_(grpcMasterService_method_names[4],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_ListDevices_(grpcMasterService_method_names[4],
rpcmethod_ListDevices_(grpcMasterService_method_names[5],
::grpc::RpcMethod::NORMAL_RPC, channel),
rpcmethod_Reset_(grpcMasterService_method_names[5],
rpcmethod_Reset_(grpcMasterService_method_names[6],
::grpc::RpcMethod::NORMAL_RPC, channel) {}
::grpc::Status MasterService::Stub::CreateSession(
@@ -74,6 +77,13 @@ MasterService::Stub::Stub(
context, request, response);
}
::grpc::Status MasterService::Stub::PartialRunSetup(
::grpc::ClientContext* context, const PartialRunSetupRequest& request,
PartialRunSetupResponse* response) {
return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_PartialRunSetup_,
context, request, response);
}
::grpc::Status MasterService::Stub::RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) {
@@ -103,7 +113,7 @@ MasterService::Stub::Stub(
}
MasterService::AsyncService::AsyncService() {
for (int i = 0; i < 6; ++i) {
for (int i = 0; i < 7; ++i) {
AddMethod(new ::grpc::RpcServiceMethod(grpcMasterService_method_names[i],
::grpc::RpcMethod::NORMAL_RPC,
nullptr));

View File

@@ -64,6 +64,9 @@ class MasterService GRPC_FINAL {
virtual ::grpc::Status ExtendSession(::grpc::ClientContext* context,
const ExtendSessionRequest& request,
ExtendSessionResponse* response) = 0;
virtual ::grpc::Status PartialRunSetup(
::grpc::ClientContext* context, const PartialRunSetupRequest& request,
PartialRunSetupResponse* response) = 0;
virtual ::grpc::Status RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) = 0;
@@ -86,6 +89,9 @@ class MasterService GRPC_FINAL {
::grpc::Status ExtendSession(::grpc::ClientContext* context,
const ExtendSessionRequest& request,
ExtendSessionResponse* response) GRPC_OVERRIDE;
::grpc::Status PartialRunSetup(
::grpc::ClientContext* context, const PartialRunSetupRequest& request,
PartialRunSetupResponse* response) GRPC_OVERRIDE;
::grpc::Status RunStep(::grpc::ClientContext* context,
const RunStepRequest& request,
RunStepResponse* response) GRPC_OVERRIDE;
@@ -103,6 +109,7 @@ class MasterService GRPC_FINAL {
std::shared_ptr< ::grpc::ChannelInterface> channel_;
const ::grpc::RpcMethod rpcmethod_CreateSession_;
const ::grpc::RpcMethod rpcmethod_ExtendSession_;
const ::grpc::RpcMethod rpcmethod_PartialRunSetup_;
const ::grpc::RpcMethod rpcmethod_RunStep_;
const ::grpc::RpcMethod rpcmethod_CloseSession_;
const ::grpc::RpcMethod rpcmethod_ListDevices_;
@@ -132,12 +139,20 @@ class MasterService GRPC_FINAL {
::grpc::Service::RequestAsyncUnary(1, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestPartialRunSetup(
::grpc::ServerContext* context, PartialRunSetupRequest* request,
::grpc::ServerAsyncResponseWriter<PartialRunSetupResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(2, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestRunStep(
::grpc::ServerContext* context, RunStepRequest* request,
::grpc::ServerAsyncResponseWriter<RunStepResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(2, context, request, response,
::grpc::Service::RequestAsyncUnary(3, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestCloseSession(
@@ -145,7 +160,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<CloseSessionResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(3, context, request, response,
::grpc::Service::RequestAsyncUnary(4, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestListDevices(
@@ -153,7 +168,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<ListDevicesResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(4, context, request, response,
::grpc::Service::RequestAsyncUnary(5, context, request, response,
new_call_cq, notification_cq, tag);
}
void RequestReset(
@@ -161,7 +176,7 @@ class MasterService GRPC_FINAL {
::grpc::ServerAsyncResponseWriter<ResetResponse>* response,
::grpc::CompletionQueue* new_call_cq,
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
::grpc::Service::RequestAsyncUnary(5, context, request, response,
::grpc::Service::RequestAsyncUnary(6, context, request, response,
new_call_cq, notification_cq, tag);
}
};

View File

@@ -52,6 +52,15 @@ class GrpcRemoteMaster : public MasterInterface {
return FromGrpcStatus(stub_->ExtendSession(&ctx, *request, response));
}
Status PartialRunSetup(CallOptions* call_options,
const PartialRunSetupRequest* request,
PartialRunSetupResponse* response) override {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->PartialRunSetup(&ctx, *request, response));
}
Status RunStep(CallOptions* call_options, const RunStepRequest* request,
RunStepResponse* response) override {
::grpc::ClientContext ctx;

View File

@@ -162,18 +162,22 @@ Status GrpcSession::Extend(const RunOptions& run_options,
return ExtendImpl(&call_options, graph);
}
Status GrpcSession::Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
Status GrpcSession::RunHelper(
const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
RunMetadata* run_metadata, const string& prun_handle) {
// Convert to proto
RunStepRequest req;
RunStepResponse resp;
*req.mutable_options() = run_options;
if (!prun_handle.empty()) {
req.set_partial_run_handle(prun_handle);
}
for (const auto& it : inputs) {
Tensor input_tensor = it.second;
auto feed = req.add_feed();
@@ -225,6 +229,16 @@ Status GrpcSession::Run(const RunOptions& run_options,
return Status::OK();
}
Status GrpcSession::Run(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs,
RunMetadata* run_metadata) {
return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
outputs, run_metadata, /* prun_handle */ "");
}
Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
@@ -252,14 +266,41 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
string* handle) {
return errors::Internal("Partial run is not supported for remote session.");
// Convert to proto
PartialRunSetupRequest req;
PartialRunSetupResponse resp;
CallOptions call_options;
{
mutex_lock l(mu_);
if (handle_.empty()) {
return errors::InvalidArgument("A session is not created yet....");
}
req.set_session_handle(handle_);
}
for (const string& feed : input_names) {
req.add_feed(feed);
}
for (const string& fetch : output_names) {
req.add_fetch(fetch);
}
for (const string& target : target_nodes) {
req.add_target(target);
}
call_options.SetTimeout(options_.config.operation_timeout_in_ms());
TF_RETURN_IF_ERROR(master_->PartialRunSetup(&call_options, &req, &resp));
*handle = resp.partial_run_handle();
return Status::OK();
}
Status GrpcSession::PRun(const string& handle,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& output_names,
std::vector<Tensor>* outputs) {
return errors::Internal("Partial run is not supported for remote session.");
RunOptions run_options;
run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms());
return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs,
/* run_metadata */ nullptr, handle);
}
Status GrpcSession::Close() {
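With PRunSetup and PRun implemented above, the generic tensorflow::Session partial-run interface now behaves the same whether the session is in-process or backed by a gRPC master. A minimal C++ sketch, assuming a graph with feed tensors a/b/c and fetches r1/r2 has already been created in the session (the tensor names and scalar shapes are illustrative, not part of this commit):

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/public/session.h"

tensorflow::Status RunPartially(tensorflow::Session* session) {
  // Declare every feed and fetch up front; the master returns a handle.
  std::string handle;
  tensorflow::Status s = session->PRunSetup(
      {"a:0", "b:0", "c:0"}, {"r1:0", "r2:0"}, /* target_nodes */ {}, &handle);
  if (!s.ok()) return s;

  // First step: feed a and b, fetch the intermediate r1.
  tensorflow::Tensor a(tensorflow::DT_FLOAT, tensorflow::TensorShape());
  tensorflow::Tensor b(tensorflow::DT_FLOAT, tensorflow::TensorShape());
  a.scalar<float>()() = 1.0f;
  b.scalar<float>()() = 2.0f;
  std::vector<tensorflow::Tensor> outputs;
  s = session->PRun(handle, {{"a:0", a}, {"b:0", b}}, {"r1:0"}, &outputs);
  if (!s.ok()) return s;

  // Second step on the same handle: feed c (derived from r1), fetch r2.
  tensorflow::Tensor c(tensorflow::DT_FLOAT, tensorflow::TensorShape());
  c.scalar<float>()() = outputs[0].scalar<float>()() * 17.0f;
  outputs.clear();
  return session->PRun(handle, {{"c:0", c}}, {"r2:0"}, &outputs);
}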

View File

@@ -110,6 +110,13 @@ class GrpcSession : public Session {
// The current version of the graph.
int64 current_graph_version_ GUARDED_BY(mu_);
Status RunHelper(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
const std::vector<string>& target_node_names,
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
const string& prun_handle);
Status RunProto(CallOptions* call_options, RunStepRequest* req,
RunStepResponse* resp);

View File

@@ -163,6 +163,47 @@ class GrpcWorkerService : public AsyncServiceInterface {
mutex mu_;
CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
struct PartialRunState {
CancellationManager* cancellation_manager;
Notification executor_done;
explicit PartialRunState(CancellationManager* cm)
: cancellation_manager(cm) {}
};
struct PairHash {
std::size_t operator()(std::pair<string, int> const& p) const {
return Hash64Combine(std::hash<string>()(p.first),
std::hash<int>()(p.second));
}
};
std::unordered_map<std::pair<string, int>, std::unique_ptr<PartialRunState>,
PairHash>
partial_runs_ GUARDED_BY(mu_);
PartialRunState* FindPartialRun(const string& graph_handle, int step_id) {
std::pair<string, int> k(graph_handle, step_id);
PartialRunState* prun_state = nullptr;
mutex_lock l(mu_);
auto it = partial_runs_.find(k);
if (it != partial_runs_.end()) {
prun_state = it->second.get();
}
return prun_state;
}
void InsertPartialRunLocked(const string& graph_handle, int step_id,
PartialRunState* partial_run_state)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::pair<string, int> k(graph_handle, step_id);
partial_runs_.emplace(
std::make_pair(k, std::unique_ptr<PartialRunState>(partial_run_state)));
}
void RemovePartialRun(const string& graph_handle, int step_id) {
std::pair<string, int> k(graph_handle, step_id);
mutex_lock l(mu_);
partial_runs_.erase(partial_runs_.find(k));
}
mutex shutdown_mu_;
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
@@ -225,7 +266,11 @@ class GrpcWorkerService : public AsyncServiceInterface {
}
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
if (call->request.is_partial()) {
env_->compute_pool->Schedule([this, call]() { DoPartialRunGraph(call); });
} else {
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
}
ENQUEUE_REQUEST(RunGraph, true);
}
@@ -294,10 +339,6 @@ class GrpcWorkerService : public AsyncServiceInterface {
Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in,
GraphMgr::NamedTensors* out) {
if (req.is_partial()) {
return errors::Unimplemented(
"Partial run not implemented for GRPC worker service");
}
if (req.send_size() > 0) {
// TODO(zhifengc): Let the caller decide on which device to
// allocate the tensor.
@@ -386,6 +427,110 @@ class GrpcWorkerService : public AsyncServiceInterface {
});
}
// TODO(suharshs): Add stats collection support to partial run.
void DoPartialRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
const int64 step_id = call->request.step_id();
const string& graph_handle = call->request.graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id);
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(call->request, &in, out);
auto finish = [this, call, out](const Status& s) {
delete out;
call->ClearCancelCallback();
call->SendResponse(ToGrpcStatus(s));
};
if (!s.ok()) {
finish(s);
return;
}
PartialRunState* partial_run_state = FindPartialRun(graph_handle, step_id);
CancellationManager* cm = nullptr;
// If this is a new partial run call we need to create a new cancellation
// manager.
// Otherwise we use the cancellation manager stored in the found partial
// run state.
if (partial_run_state == nullptr) {
cm = new CancellationManager;
} else {
cm = partial_run_state->cancellation_manager;
}
// Before we start doing anything, we set the RPC cancellation.
call->SetCancelCallback([this, cm, step_id]() {
cm->StartCancel();
AbortStep(step_id);
});
// If this is a new partial run request, the request will need to start the
// executors.
if (partial_run_state == nullptr) {
CancellationToken token;
{
mutex_lock l(mu_);
// Insert the new partial run into the partial_runs_ map.
partial_run_state = new PartialRunState(cm);
InsertPartialRunLocked(graph_handle, step_id, partial_run_state);
token = cancellation_manager_->get_cancellation_token();
cancellation_manager_->RegisterCallback(token,
[cm]() { cm->StartCancel(); });
}
env_->graph_mgr->ExecuteAsync(
graph_handle, step_id, call->request.exec_opts(),
nullptr /* collector */, nullptr /* cost_graph */, cm, in,
[this, step_id, graph_handle, token, partial_run_state](Status s) {
{
mutex_lock l(mu_);
cancellation_manager_->DeregisterCallback(token);
}
partial_run_state->executor_done.Notify();
// TODO(suharshs): Propagate the status once we keep state for
// each partial run call.
});
} else {
// Send the partial run's new inputs.
s = env_->graph_mgr->SendInputs(step_id, in);
if (!s.ok()) {
finish(s);
return;
}
}
// Receive the partial run's outputs.
s = env_->graph_mgr->RecvOutputs(step_id, out);
if (!s.ok()) {
finish(s);
return;
}
// Construct and return the resp.
for (const auto& p : *out) {
const string& key = p.first;
const Tensor& val = p.second;
auto* recv = call->response.add_recv();
recv->set_key(key);
// TODO(zhifengc): Deal with gpu -> cpu copy.
TensorProto* proto = recv->mutable_val();
val.AsProtoField(proto);
}
// If this is the last partial run request we must also wait for the entire
// graph execution to be completed.
if (call->request.is_last_partial_run()) {
partial_run_state->executor_done.WaitForNotification();
RemovePartialRun(graph_handle, step_id);
// Before deleting the cancellation manager on the final call, ensure
// that we clear the RPC cancel callback, which has a reference to the
// cancellation manager.
call->ClearCancelCallback();
delete cm;
}
finish(s);
}
// Helper for RecvTensor. Validates "key" and returns the source
// device in "*src_dev".
Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
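The worker above keys in-flight partial-run state by (graph_handle, step_id) and guards the map with the service mutex. A self-contained sketch of that bookkeeping pattern using only standard-library types (a simple hash combine stands in for the internal Hash64Combine, and a plain flag for the Notification; this is an illustration, not the TensorFlow implementation):

#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

struct PartialRunState {
  bool executor_done = false;  // stands in for the Notification in the diff
};

// Hash for the (graph_handle, step_id) key.
struct PairHash {
  std::size_t operator()(const std::pair<std::string, int>& p) const {
    std::size_t h1 = std::hash<std::string>()(p.first);
    std::size_t h2 = std::hash<int>()(p.second);
    return h1 ^ (h2 << 1);
  }
};

class PartialRunTable {
 public:
  // Returns existing state for the key, creating it on first use.
  PartialRunState* FindOrCreate(const std::string& graph_handle, int step_id) {
    std::lock_guard<std::mutex> l(mu_);
    std::pair<std::string, int> key(graph_handle, step_id);
    auto& slot = runs_[key];
    if (!slot) slot.reset(new PartialRunState);
    return slot.get();
  }
  // Called once the last partial-run step for this key has finished.
  void Remove(const std::string& graph_handle, int step_id) {
    std::lock_guard<std::mutex> l(mu_);
    runs_.erase(std::pair<std::string, int>(graph_handle, step_id));
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::pair<std::string, int>,
                     std::unique_ptr<PartialRunState>, PairHash>
      runs_;
};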

View File

@@ -2463,6 +2463,7 @@ py_test(
":platform_test",
":session",
":state_ops",
":training",
":util",
":variables",
],

View File

@@ -45,6 +45,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
@@ -1322,91 +1323,121 @@ class SessionTest(test_util.TensorFlowTestCase):
sess_2.run(c_1.op)
self.assertEqual(2.0, sess_2.run(c_2))
def testPartialRun(self):
with session.Session() as sess:
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
def runTestPartialRun(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 17
res = sess.partial_run(h, r2, feed_dict={c: temp})
self.assertEqual(153, res)
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 17
res = sess.partial_run(h, r2, feed_dict={c: temp})
self.assertEqual(153, res)
# Call again on the same graph.
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 18
res = sess.partial_run(h2, r2, feed_dict={c: temp})
self.assertEqual(162, res)
# Call again on the same graph.
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h2, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 18
res = sess.partial_run(h2, r2, feed_dict={c: temp})
self.assertEqual(162, res)
def testPartialRunIncomplete(self):
with session.Session() as sess:
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
def runTestPartialRunIncomplete(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
def testConcurrentPartialRun(self):
with session.Session() as sess:
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
def runTestConcurrentPartialRun(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.mul(r1, c)
h1 = sess.partial_run_setup([r1], [a, b, c])
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 19
res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9})
self.assertEqual(66, res)
res = sess.partial_run(h2, r2, feed_dict={c: 7})
self.assertEqual(462, res)
h1 = sess.partial_run_setup([r1], [a, b, c])
h2 = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h1, r1, feed_dict={a: 1, b: 2})
self.assertEqual(3, res)
temp = res * 19
res = sess.partial_run(h2, r1, feed_dict={a: temp, b: 9})
self.assertEqual(66, res)
res = sess.partial_run(h2, r2, feed_dict={c: 7})
self.assertEqual(462, res)
def testManyPartialRun(self):
with session.Session() as sess:
steps = 200
inputs = []
outputs = []
a = constant_op.constant(2.0, dtypes.float32)
for i in xrange(steps):
inputs.append(array_ops.placeholder(dtypes.float32, shape=[]))
a = math_ops.mul(a, inputs[i])
outputs.append(a)
def runTestManyPartialRun(self, sess):
steps = 200
inputs = []
outputs = []
a = constant_op.constant(2.0, dtypes.float32)
for i in xrange(steps):
inputs.append(array_ops.placeholder(dtypes.float32, shape=[]))
a = math_ops.mul(a, inputs[i])
outputs.append(a)
h = sess.partial_run_setup(outputs, inputs)
for i in xrange(steps):
res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0})
self.assertEqual(2.0, res)
h = sess.partial_run_setup(outputs, inputs)
for i in xrange(steps):
res = sess.partial_run(h, outputs[i], feed_dict={inputs[i]: 1.0})
self.assertEqual(2.0, res)
feed_dict = {}
for i in xrange(steps):
feed_dict[inputs[i]] = 1.0
res = sess.run(outputs, feed_dict)
self.assertEqual(steps, len(res))
self.assertEqual(2.0, res[-1])
feed_dict = {}
for i in xrange(steps):
feed_dict[inputs[i]] = 1.0
res = sess.run(outputs, feed_dict)
self.assertEqual(steps, len(res))
self.assertEqual(2.0, res[-1])
def testRunAndPartialRun(self):
with session.Session() as sess:
a = constant_op.constant(2.0, dtypes.float32)
b = a * 2
c = b * 3
r1 = sess.run([b, c])
h = sess.partial_run_setup([b, c], [])
r2 = sess.partial_run(h, [b, c])
self.assertEqual(r1, r2)
def runTestRunAndPartialRun(self, sess):
a = constant_op.constant(2.0, dtypes.float32)
b = a * 2
c = b * 3
r1 = sess.run([b, c])
h = sess.partial_run_setup([b, c], [])
r2 = sess.partial_run(h, [b, c])
self.assertEqual(r1, r2)
def testPartialRunDirect(self):
self.runTestPartialRun(session.Session())
def testPartialRunIncompleteDirect(self):
self.runTestPartialRunIncomplete(session.Session())
def testConcurrentPartialRunDirect(self):
self.runTestConcurrentPartialRun(session.Session())
def testManyPartialRunDirect(self):
self.runTestManyPartialRun(session.Session())
def testRunAndPartialRunDirect(self):
self.runTestRunAndPartialRun(session.Session())
def testPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRun(session.Session(server.target))
def testPartialRunIncompleteDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRunIncomplete(session.Session(server.target))
def testConcurrentPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestConcurrentPartialRun(session.Session(server.target))
def testManyPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestManyPartialRun(session.Session(server.target))
def testRunAndPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestRunAndPartialRun(session.Session(server.target))
def testFeedDictKeyException(self):
with session.Session() as sess: