Distinguish between duplicate feed/fetch and unspecified feed/fetch errors.
Change: 154606429
This commit is contained in:
parent
ad3c84b58b
commit
dae9329b0a
@ -720,16 +720,21 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
|
||||
if (it == run_state->pending_inputs.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"The feed ", input.first,
|
||||
" has already been fed or was not specified in partial_run_setup.");
|
||||
" was not specified in partial_run_setup.");
|
||||
} else if (it->second) {
|
||||
return errors::InvalidArgument("The feed ", input.first,
|
||||
" has already been fed.");
|
||||
}
|
||||
}
|
||||
// Check that this is a new set of fetches that are still pending.
|
||||
for (const auto& output : output_names) {
|
||||
auto it = run_state->pending_outputs.find(output);
|
||||
if (it == run_state->pending_outputs.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"The fetch ", output, " was not specified in partial_run_setup.");
|
||||
} else if (it->second) {
|
||||
return errors::InvalidArgument("The fetch ", output,
|
||||
" has already been fetched or was not "
|
||||
"specified in partial_run_setup.");
|
||||
" has already been fetched.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -764,14 +769,15 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
|
||||
<< run_state->status;
|
||||
}
|
||||
}
|
||||
for (const auto& it : inputs) {
|
||||
run_state->pending_inputs.erase(it.first);
|
||||
for (const auto& input : inputs) {
|
||||
auto it = run_state->pending_inputs.find(input.first);
|
||||
it->second = true;
|
||||
}
|
||||
for (const auto& name : output_names) {
|
||||
run_state->pending_outputs.erase(name);
|
||||
auto it = run_state->pending_outputs.find(name);
|
||||
it->second = true;
|
||||
}
|
||||
done = (run_state->pending_inputs.size() == 0 &&
|
||||
run_state->pending_outputs.size() == 0);
|
||||
done = run_state->PendingDone();
|
||||
}
|
||||
if (done) {
|
||||
WaitForNotification(run_state, cancellation_manager_,
|
||||
@ -900,11 +906,13 @@ Status DirectSession::CheckFetch(const NamedTensorList& feeds,
|
||||
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
|
||||
{
|
||||
mutex_lock l(executor_lock_);
|
||||
for (const string& feed : run_state->pending_inputs) {
|
||||
TensorId id(ParseTensorName(feed));
|
||||
for (const auto& input : run_state->pending_inputs) {
|
||||
// Skip if the feed has already been fed.
|
||||
if (input.second) continue;
|
||||
TensorId id(ParseTensorName(input.first));
|
||||
auto it = name_to_node->find(id.first);
|
||||
if (it == name_to_node->end()) {
|
||||
return errors::NotFound("Feed ", feed, ": not found");
|
||||
return errors::NotFound("Feed ", input.first, ": not found");
|
||||
}
|
||||
pending_feeds.insert(id);
|
||||
}
|
||||
@ -1351,10 +1359,10 @@ DirectSession::RunState::RunState(
|
||||
}) {
|
||||
// Initially all the feeds and fetches are pending.
|
||||
for (auto& name : pending_input_names) {
|
||||
pending_inputs.emplace(name);
|
||||
pending_inputs[name] = false;
|
||||
}
|
||||
for (auto& name : pending_output_names) {
|
||||
pending_outputs.emplace(name);
|
||||
pending_outputs[name] = false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1372,6 +1380,16 @@ DirectSession::RunState::~RunState() {
|
||||
}
|
||||
}
|
||||
|
||||
bool DirectSession::RunState::PendingDone() const {
|
||||
for (const auto& it : pending_inputs) {
|
||||
if (!it.second) return false;
|
||||
}
|
||||
for (const auto& it : pending_outputs) {
|
||||
if (!it.second) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void DirectSession::WaitForNotification(RunState* run_state,
|
||||
CancellationManager* cm,
|
||||
int64 timeout_in_ms) {
|
||||
|
@ -151,8 +151,8 @@ class DirectSession : public Session {
|
||||
IntraProcessRendezvous* rendez = nullptr;
|
||||
std::unique_ptr<StepStatsCollector> collector;
|
||||
Notification executors_done;
|
||||
std::unordered_set<string> pending_inputs;
|
||||
std::unordered_set<string> pending_outputs;
|
||||
std::unordered_map<string, bool> pending_inputs; // true if fed
|
||||
std::unordered_map<string, bool> pending_outputs; // true if fetched
|
||||
TensorStore tensor_store;
|
||||
ScopedStepContainer step_container;
|
||||
|
||||
@ -162,6 +162,9 @@ class DirectSession : public Session {
|
||||
const std::vector<string>& pending_output_names, int64 step_id,
|
||||
const std::vector<Device*>* devices);
|
||||
|
||||
// Returns true if all pending inputs and outputs have been completed.
|
||||
bool PendingDone() const;
|
||||
|
||||
~RunState();
|
||||
};
|
||||
|
||||
|
@ -803,11 +803,13 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
|
||||
SimpleGraphExecutionState* execution_state) {
|
||||
// Build the set of pending feeds that we haven't seen.
|
||||
std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
|
||||
for (const string& feed : run_state->pending_inputs) {
|
||||
TensorId id(ParseTensorName(feed));
|
||||
for (const auto& input : run_state->pending_inputs) {
|
||||
// Skip if already fed.
|
||||
if (input.second) continue;
|
||||
TensorId id(ParseTensorName(input.first));
|
||||
auto it = name_to_node_.find(id.first);
|
||||
if (it == name_to_node_.end()) {
|
||||
return errors::NotFound("Feed ", feed, ": not found");
|
||||
return errors::NotFound("Feed ", input.first, ": not found");
|
||||
}
|
||||
pending_feeds.insert(id);
|
||||
}
|
||||
@ -1247,11 +1249,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
|
||||
|
||||
// Make sure that this is a new set of feeds that are still pending.
|
||||
for (size_t i = 0; i < req.num_feeds(); ++i) {
|
||||
auto it = run_state->pending_inputs.find(req.feed_name(i));
|
||||
const string& feed = req.feed_name(i);
|
||||
auto it = run_state->pending_inputs.find(feed);
|
||||
if (it == run_state->pending_inputs.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"The feed ", req.feed_name(i),
|
||||
" has already been fed or was not specified in partial_run_setup.");
|
||||
"The feed ", feed, " was not specified in partial_run_setup.");
|
||||
} else if (it->second) {
|
||||
return errors::InvalidArgument("The feed ", feed,
|
||||
" has already been fed.");
|
||||
}
|
||||
}
|
||||
// Check that this is a new set of fetches that are still pending.
|
||||
@ -1259,9 +1264,11 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
|
||||
const string& fetch = req.fetch_name(i);
|
||||
auto it = run_state->pending_outputs.find(fetch);
|
||||
if (it == run_state->pending_outputs.end()) {
|
||||
return errors::InvalidArgument(
|
||||
"The fetch ", fetch, " was not specified in partial_run_setup.");
|
||||
} else if (it->second) {
|
||||
return errors::InvalidArgument("The fetch ", fetch,
|
||||
" had already been fetched or was not "
|
||||
"specified in partial_run_setup.");
|
||||
" has already been fetched.");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1274,13 +1281,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
|
||||
|
||||
// Determine if this partial run satisfies all the pending inputs and outputs.
|
||||
for (size_t i = 0; i < req.num_feeds(); ++i) {
|
||||
run_state->pending_inputs.erase(req.feed_name(i));
|
||||
auto it = run_state->pending_inputs.find(req.feed_name(i));
|
||||
it->second = true;
|
||||
}
|
||||
for (size_t i = 0; i < req.num_fetches(); ++i) {
|
||||
run_state->pending_outputs.erase(req.fetch_name(i));
|
||||
auto it = run_state->pending_outputs.find(req.fetch_name(i));
|
||||
it->second = true;
|
||||
}
|
||||
bool is_last_partial_run =
|
||||
(run_state->pending_inputs.empty() && run_state->pending_outputs.empty());
|
||||
bool is_last_partial_run = run_state->PendingDone();
|
||||
|
||||
Status s = run_state->rcg->RunPartitions(
|
||||
env_, run_state->step_id, run_state->count, &run_state->pss, opts, req,
|
||||
@ -1418,10 +1426,10 @@ MasterSession::RunState::RunState(const std::vector<string>& input_names,
|
||||
: rcg(rcg), step_id(step_id), count(count) {
|
||||
// Initially all the feeds and fetches are pending.
|
||||
for (auto& name : input_names) {
|
||||
pending_inputs.emplace(name);
|
||||
pending_inputs[name] = false;
|
||||
}
|
||||
for (auto& name : output_names) {
|
||||
pending_outputs.emplace(name);
|
||||
pending_outputs[name] = false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1429,4 +1437,14 @@ MasterSession::RunState::~RunState() {
|
||||
if (rcg) rcg->Unref();
|
||||
}
|
||||
|
||||
bool MasterSession::RunState::PendingDone() const {
|
||||
for (const auto& it : pending_inputs) {
|
||||
if (!it.second) return false;
|
||||
}
|
||||
for (const auto& it : pending_outputs) {
|
||||
if (!it.second) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -141,8 +141,8 @@ class MasterSession : public core::RefCounted {
|
||||
};
|
||||
|
||||
struct RunState {
|
||||
std::unordered_set<string> pending_inputs;
|
||||
std::unordered_set<string> pending_outputs;
|
||||
std::unordered_map<string, bool> pending_inputs; // true if fed
|
||||
std::unordered_map<string, bool> pending_outputs; // true if fetched
|
||||
ReffedClientGraph* rcg = nullptr;
|
||||
uint64 step_id;
|
||||
int64 count = 0;
|
||||
@ -154,6 +154,8 @@ class MasterSession : public core::RefCounted {
|
||||
const std::vector<string>& output_names, ReffedClientGraph* rcg,
|
||||
const uint64 step_id, const int64 count);
|
||||
|
||||
bool PendingDone() const;
|
||||
|
||||
~RunState();
|
||||
};
|
||||
std::unordered_map<string, std::unique_ptr<RunState>> partial_runs_
|
||||
|
@ -1431,6 +1431,55 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
'You must feed a value for placeholder'):
|
||||
sess.partial_run(handle, fetches[0])
|
||||
|
||||
def runTestPartialRunUnspecifiedFeed(self, sess):
  """Checks that feeding a tensor omitted from partial_run_setup is rejected.

  Args:
    sess: The session on which the partial run is set up and executed.
  """
  ph_a = array_ops.placeholder(dtypes.float32, shape=[])
  ph_b = array_ops.placeholder(dtypes.float32, shape=[])
  # This placeholder is intentionally left out of the setup call below.
  ph_unlisted = array_ops.placeholder(dtypes.float32, shape=[])
  total = math_ops.add(ph_a, ph_b)

  handle = sess.partial_run_setup([total], [ph_a, ph_b])
  expected_message = 'was not specified in partial_run_setup.$'
  feeds = {ph_a: 1, ph_b: 2, ph_unlisted: 3}
  with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_message):
    sess.partial_run(handle, total, feed_dict=feeds)
|
||||
|
||||
def runTestPartialRunUnspecifiedFetch(self, sess):
  """Checks that fetching a tensor omitted from partial_run_setup is rejected.

  Args:
    sess: The session on which the partial run is set up and executed.
  """
  ph_a = array_ops.placeholder(dtypes.float32, shape=[])
  ph_b = array_ops.placeholder(dtypes.float32, shape=[])
  ph_c = array_ops.placeholder(dtypes.float32, shape=[])
  listed_fetch = math_ops.add(ph_a, ph_b)
  # Never passed to partial_run_setup, so requesting it must fail.
  unlisted_fetch = math_ops.multiply(ph_a, ph_c)

  handle = sess.partial_run_setup([listed_fetch], [ph_a, ph_b, ph_c])
  expected_message = 'was not specified in partial_run_setup.$'
  feeds = {ph_a: 1, ph_c: 3}
  with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_message):
    sess.partial_run(handle, unlisted_fetch, feed_dict=feeds)
|
||||
|
||||
def runTestPartialRunAlreadyFed(self, sess):
  """Checks that feeding the same tensor in two partial runs is rejected.

  Args:
    sess: The session on which the partial run is set up and executed.
  """
  ph_a = array_ops.placeholder(dtypes.float32, shape=[])
  ph_b = array_ops.placeholder(dtypes.float32, shape=[])
  ph_c = array_ops.placeholder(dtypes.float32, shape=[])
  total = math_ops.add(ph_a, ph_b)
  product = math_ops.multiply(ph_a, ph_c)

  handle = sess.partial_run_setup([total, product], [ph_a, ph_b, ph_c])
  # The first step supplies ph_a and ph_b.
  sess.partial_run(handle, total, feed_dict={ph_a: 1, ph_b: 2})

  # The second step supplies ph_a again, which must raise.
  expected_message = 'has already been fed.$'
  repeated_feeds = {ph_a: 1, ph_c: 3}
  with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_message):
    sess.partial_run(handle, product, feed_dict=repeated_feeds)
|
||||
|
||||
def runTestPartialRunAlreadyFetched(self, sess):
  """Checks that fetching the same tensor in two partial runs is rejected.

  Args:
    sess: The session on which the partial run is set up and executed.
  """
  ph_a = array_ops.placeholder(dtypes.float32, shape=[])
  ph_b = array_ops.placeholder(dtypes.float32, shape=[])
  ph_c = array_ops.placeholder(dtypes.float32, shape=[])
  total = math_ops.add(ph_a, ph_b)
  product = math_ops.multiply(ph_a, ph_c)

  handle = sess.partial_run_setup([total, product], [ph_a, ph_b, ph_c])
  # The first step fetches `total`.
  sess.partial_run(handle, total, feed_dict={ph_a: 1, ph_b: 2})

  # The second step requests `total` again, which must raise.
  expected_message = 'has already been fetched.$'
  with self.assertRaisesRegexp(errors.InvalidArgumentError, expected_message):
    sess.partial_run(handle, total, feed_dict={ph_c: 3})
|
||||
|
||||
def testInvalidPartialRunSetup(self):
|
||||
sess = session.Session()
|
||||
x = array_ops.placeholder(dtypes.float32, shape=[])
|
||||
@ -1457,6 +1506,18 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
def testPartialRunMissingPlaceholderFeedExceptionDirect(self):
  """Runs the missing-placeholder-feed check on a session with no target."""
  direct_session = session.Session()
  self.runTestPartialRunMissingPlaceholderFeedException(direct_session)
|
||||
|
||||
def testPartialRunUnspecifiedFeedDirect(self):
  """Runs the unspecified-feed check on a session created with no target."""
  direct_session = session.Session()
  self.runTestPartialRunUnspecifiedFeed(direct_session)
|
||||
|
||||
def testPartialRunUnspecifiedFetchDirect(self):
  """Runs the unspecified-fetch check on a session created with no target."""
  direct_session = session.Session()
  self.runTestPartialRunUnspecifiedFetch(direct_session)
|
||||
|
||||
def testPartialRunAlreadyFedDirect(self):
  """Runs the already-fed check on a session created with no target."""
  direct_session = session.Session()
  self.runTestPartialRunAlreadyFed(direct_session)
|
||||
|
||||
def testPartialRunAlreadyFetchedDirect(self):
  """Runs the already-fetched check on a session created with no target."""
  direct_session = session.Session()
  self.runTestPartialRunAlreadyFetched(direct_session)
|
||||
|
||||
def testPartialRunDist(self):
|
||||
server = server_lib.Server.create_local_server()
|
||||
self.runTestPartialRun(session.Session(server.target))
|
||||
@ -1482,6 +1543,22 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.runTestPartialRunMissingPlaceholderFeedException(
|
||||
session.Session(server.target))
|
||||
|
||||
def testPartialRunUnspecifiedFeedDist(self):
  """Runs the unspecified-feed check on a session bound to a local server."""
  local_server = server_lib.Server.create_local_server()
  remote_session = session.Session(local_server.target)
  self.runTestPartialRunUnspecifiedFeed(remote_session)
|
||||
|
||||
def testPartialRunUnspecifiedFetchDist(self):
  """Runs the unspecified-fetch check on a session bound to a local server."""
  local_server = server_lib.Server.create_local_server()
  remote_session = session.Session(local_server.target)
  self.runTestPartialRunUnspecifiedFetch(remote_session)
|
||||
|
||||
def testPartialRunAlreadyFedDist(self):
  """Runs the already-fed check on a session bound to a local server."""
  local_server = server_lib.Server.create_local_server()
  remote_session = session.Session(local_server.target)
  self.runTestPartialRunAlreadyFed(remote_session)
|
||||
|
||||
def testPartialRunAlreadyFetchedDist(self):
  """Runs the already-fetched check on a session bound to a local server."""
  local_server = server_lib.Server.create_local_server()
  remote_session = session.Session(local_server.target)
  self.runTestPartialRunAlreadyFetched(remote_session)
|
||||
|
||||
def testFeedDictKeyException(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(1.0, dtypes.float32, name='a')
|
||||
|
Loading…
x
Reference in New Issue
Block a user