Bug fix: Need to distinguish run and partial run when caching executors in session.

Change: 121896355
This commit is contained in:
Yuan Yu 2016-05-09 16:22:23 -08:00 committed by TensorFlower Gardener
parent 9efc4097cf
commit 6d2623d03c
2 changed files with 12 additions and 1 deletions

View File

@ -661,7 +661,8 @@ Status DirectSession::GetOrCreateExecutors(
const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
str_util::Join(outputs_sorted, ","), "/",
str_util::Join(tn_sorted, ","));
str_util::Join(tn_sorted, ","), "/",
run_state_args->is_partial_run);
// Set the handle.
{

View File

@ -971,6 +971,16 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(steps, len(res))
self.assertEqual(2.0, res[-1])
def testRunAndPartialRun(self):
with session.Session() as sess:
a = constant_op.constant(2.0, dtypes.float32)
b = a * 2
c = b * 3
r1 = sess.run([b, c])
h = sess.partial_run_setup([b, c], [])
r2 = sess.partial_run(h, [b, c])
self.assertEqual(r1, r2)
def testFeedDictKeyException(self):
with session.Session() as sess:
a = constant_op.constant(1.0, dtypes.float32, name='a')