Distinguish between duplicate feed/fetch and unspecified feed/fetch errors.

Change: 154606429
This commit is contained in:
Suharsh Sivakumar 2017-04-28 17:44:46 -08:00 committed by TensorFlower Gardener
parent ad3c84b58b
commit dae9329b0a
5 changed files with 149 additions and 31 deletions

View File

@ -720,16 +720,21 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
if (it == run_state->pending_inputs.end()) {
return errors::InvalidArgument(
"The feed ", input.first,
" has already been fed or was not specified in partial_run_setup.");
" was not specified in partial_run_setup.");
} else if (it->second) {
return errors::InvalidArgument("The feed ", input.first,
" has already been fed.");
}
}
// Check that this is a new set of fetches that are still pending.
for (const auto& output : output_names) {
auto it = run_state->pending_outputs.find(output);
if (it == run_state->pending_outputs.end()) {
return errors::InvalidArgument(
"The fetch ", output, " was not specified in partial_run_setup.");
} else if (it->second) {
return errors::InvalidArgument("The fetch ", output,
" has already been fetched or was not "
"specified in partial_run_setup.");
" has already been fetched.");
}
}
}
@ -764,14 +769,15 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
<< run_state->status;
}
}
for (const auto& it : inputs) {
run_state->pending_inputs.erase(it.first);
for (const auto& input : inputs) {
auto it = run_state->pending_inputs.find(input.first);
it->second = true;
}
for (const auto& name : output_names) {
run_state->pending_outputs.erase(name);
auto it = run_state->pending_outputs.find(name);
it->second = true;
}
done = (run_state->pending_inputs.size() == 0 &&
run_state->pending_outputs.size() == 0);
done = run_state->PendingDone();
}
if (done) {
WaitForNotification(run_state, cancellation_manager_,
@ -900,11 +906,13 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
{
mutex_lock l(executor_lock_);
for (const string& feed : run_state->pending_inputs) {
TensorId id(ParseTensorName(feed));
for (const auto& input : run_state->pending_inputs) {
// Skip if the feed has already been fed.
if (input.second) continue;
TensorId id(ParseTensorName(input.first));
auto it = name_to_node->find(id.first);
if (it == name_to_node->end()) {
return errors::NotFound("Feed ", feed, ": not found");
return errors::NotFound("Feed ", input.first, ": not found");
}
pending_feeds.insert(id);
}
@ -1351,10 +1359,10 @@ DirectSession::RunState::RunState(
}) {
// Initially all the feeds and fetches are pending.
for (auto& name : pending_input_names) {
pending_inputs.emplace(name);
pending_inputs[name] = false;
}
for (auto& name : pending_output_names) {
pending_outputs.emplace(name);
pending_outputs[name] = false;
}
}
@ -1372,6 +1380,16 @@ DirectSession::RunState::~RunState() {
}
}
bool DirectSession::RunState::PendingDone() const {
for (const auto& it : pending_inputs) {
if (!it.second) return false;
}
for (const auto& it : pending_outputs) {
if (!it.second) return false;
}
return true;
}
void DirectSession::WaitForNotification(RunState* run_state,
CancellationManager* cm,
int64 timeout_in_ms) {

View File

@ -151,8 +151,8 @@ class DirectSession : public Session {
IntraProcessRendezvous* rendez = nullptr;
std::unique_ptr<StepStatsCollector> collector;
Notification executors_done;
std::unordered_set<string> pending_inputs;
std::unordered_set<string> pending_outputs;
std::unordered_map<string, bool> pending_inputs; // true if fed
std::unordered_map<string, bool> pending_outputs; // true if fetched
TensorStore tensor_store;
ScopedStepContainer step_container;
@ -162,6 +162,9 @@ class DirectSession : public Session {
const std::vector<string>& pending_output_names, int64 step_id,
const std::vector<Device*>* devices);
// Returns true if all pending inputs and outputs have been completed.
bool PendingDone() const;
~RunState();
};

View File

@ -803,11 +803,13 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
SimpleGraphExecutionState* execution_state) {
// Build the set of pending feeds that we haven't seen.
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
for (const string& feed : run_state->pending_inputs) {
TensorId id(ParseTensorName(feed));
for (const auto& input : run_state->pending_inputs) {
// Skip if already fed.
if (input.second) continue;
TensorId id(ParseTensorName(input.first));
auto it = name_to_node_.find(id.first);
if (it == name_to_node_.end()) {
return errors::NotFound("Feed ", feed, ": not found");
return errors::NotFound("Feed ", input.first, ": not found");
}
pending_feeds.insert(id);
}
@ -1247,11 +1249,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
// Make sure that this is a new set of feeds that are still pending.
for (size_t i = 0; i < req.num_feeds(); ++i) {
auto it = run_state->pending_inputs.find(req.feed_name(i));
const string& feed = req.feed_name(i);
auto it = run_state->pending_inputs.find(feed);
if (it == run_state->pending_inputs.end()) {
return errors::InvalidArgument(
"The feed ", req.feed_name(i),
" has already been fed or was not specified in partial_run_setup.");
"The feed ", feed, " was not specified in partial_run_setup.");
} else if (it->second) {
return errors::InvalidArgument("The feed ", feed,
" has already been fed.");
}
}
// Check that this is a new set of fetches that are still pending.
@ -1259,9 +1264,11 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
const string& fetch = req.fetch_name(i);
auto it = run_state->pending_outputs.find(fetch);
if (it == run_state->pending_outputs.end()) {
return errors::InvalidArgument(
"The fetch ", fetch, " was not specified in partial_run_setup.");
} else if (it->second) {
return errors::InvalidArgument("The fetch ", fetch,
" had already been fetched or was not "
"specified in partial_run_setup.");
" has already been fetched.");
}
}
@ -1274,13 +1281,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
// Determine if this partial run satisfies all the pending inputs and ouputs.
for (size_t i = 0; i < req.num_feeds(); ++i) {
run_state->pending_inputs.erase(req.feed_name(i));
auto it = run_state->pending_inputs.find(req.feed_name(i));
it->second = true;
}
for (size_t i = 0; i < req.num_fetches(); ++i) {
run_state->pending_outputs.erase(req.fetch_name(i));
auto it = run_state->pending_outputs.find(req.fetch_name(i));
it->second = true;
}
bool is_last_partial_run =
(run_state->pending_inputs.empty() && run_state->pending_outputs.empty());
bool is_last_partial_run = run_state->PendingDone();
Status s = run_state->rcg->RunPartitions(
env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
@ -1418,10 +1426,10 @@ MasterSession::RunState::RunState(const std::vector<string>& input_names,
: rcg(rcg), step_id(step_id), count(count) {
// Initially all the feeds and fetches are pending.
for (auto& name : input_names) {
pending_inputs.emplace(name);
pending_inputs[name] = false;
}
for (auto& name : output_names) {
pending_outputs.emplace(name);
pending_outputs[name] = false;
}
}
@ -1429,4 +1437,14 @@ MasterSession::RunState::~RunState() {
if (rcg) rcg->Unref();
}
bool MasterSession::RunState::PendingDone() const {
for (const auto& it : pending_inputs) {
if (!it.second) return false;
}
for (const auto& it : pending_outputs) {
if (!it.second) return false;
}
return true;
}
} // end namespace tensorflow

View File

@ -141,8 +141,8 @@ class MasterSession : public core::RefCounted {
};
struct RunState {
std::unordered_set<string> pending_inputs;
std::unordered_set<string> pending_outputs;
std::unordered_map<string, bool> pending_inputs; // true if fed
std::unordered_map<string, bool> pending_outputs; // true if fetched
ReffedClientGraph* rcg = nullptr;
uint64 step_id;
int64 count = 0;
@ -154,6 +154,8 @@ class MasterSession : public core::RefCounted {
const std::vector<string>& output_names, ReffedClientGraph* rcg,
const uint64 step_id, const int64 count);
bool PendingDone() const;
~RunState();
};
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_

View File

@ -1431,6 +1431,55 @@ class SessionTest(test_util.TensorFlowTestCase):
'You must feed a value for placeholder'):
sess.partial_run(handle, fetches[0])
def runTestPartialRunUnspecifiedFeed(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
h = sess.partial_run_setup([r1], [a, b])
with self.assertRaisesRegexp(errors.InvalidArgumentError,
'was not specified in partial_run_setup.$'):
sess.partial_run(h, r1, feed_dict={a: 1, b: 2, c: 3})
def runTestPartialRunUnspecifiedFetch(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(a, c)
h = sess.partial_run_setup([r1], [a, b, c])
with self.assertRaisesRegexp(errors.InvalidArgumentError,
'was not specified in partial_run_setup.$'):
sess.partial_run(h, r2, feed_dict={a: 1, c: 3})
def runTestPartialRunAlreadyFed(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(a, c)
h = sess.partial_run_setup([r1, r2], [a, b, c])
sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
'has already been fed.$'):
sess.partial_run(h, r2, feed_dict={a: 1, c: 3})
def runTestPartialRunAlreadyFetched(self, sess):
a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(a, c)
h = sess.partial_run_setup([r1, r2], [a, b, c])
sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
'has already been fetched.$'):
sess.partial_run(h, r1, feed_dict={c: 3})
def testInvalidPartialRunSetup(self):
sess = session.Session()
x = array_ops.placeholder(dtypes.float32, shape=[])
@ -1457,6 +1506,18 @@ class SessionTest(test_util.TensorFlowTestCase):
def testPartialRunMissingPlaceholderFeedExceptionDirect(self):
self.runTestPartialRunMissingPlaceholderFeedException(session.Session())
def testPartialRunUnspecifiedFeedDirect(self):
self.runTestPartialRunUnspecifiedFeed(session.Session())
def testPartialRunUnspecifiedFetchDirect(self):
self.runTestPartialRunUnspecifiedFetch(session.Session())
def testPartialRunAlreadyFedDirect(self):
self.runTestPartialRunAlreadyFed(session.Session())
def testPartialRunAlreadyFetchedDirect(self):
self.runTestPartialRunAlreadyFetched(session.Session())
def testPartialRunDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRun(session.Session(server.target))
@ -1482,6 +1543,22 @@ class SessionTest(test_util.TensorFlowTestCase):
self.runTestPartialRunMissingPlaceholderFeedException(
session.Session(server.target))
def testPartialRunUnspecifiedFeedDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRunUnspecifiedFeed(session.Session(server.target))
def testPartialRunUnspecifiedFetchDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRunUnspecifiedFetch(session.Session(server.target))
def testPartialRunAlreadyFedDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRunAlreadyFed(session.Session(server.target))
def testPartialRunAlreadyFetchedDist(self):
server = server_lib.Server.create_local_server()
self.runTestPartialRunAlreadyFetched(session.Session(server.target))
def testFeedDictKeyException(self):
with session.Session() as sess:
a = constant_op.constant(1.0, dtypes.float32, name='a')