Bug fix: Need to distinguish run and partial run when caching executors in session.
Change: 121896355
This commit is contained in:
parent
9efc4097cf
commit
6d2623d03c
@ -661,7 +661,8 @@ Status DirectSession::GetOrCreateExecutors(
|
||||
|
||||
const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
|
||||
str_util::Join(outputs_sorted, ","), "/",
|
||||
str_util::Join(tn_sorted, ","));
|
||||
str_util::Join(tn_sorted, ","), "/",
|
||||
run_state_args->is_partial_run);
|
||||
|
||||
// Set the handle.
|
||||
{
|
||||
|
@ -971,6 +971,16 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(steps, len(res))
|
||||
self.assertEqual(2.0, res[-1])
|
||||
|
||||
def testRunAndPartialRun(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(2.0, dtypes.float32)
|
||||
b = a * 2
|
||||
c = b * 3
|
||||
r1 = sess.run([b, c])
|
||||
h = sess.partial_run_setup([b, c], [])
|
||||
r2 = sess.partial_run(h, [b, c])
|
||||
self.assertEqual(r1, r2)
|
||||
|
||||
def testFeedDictKeyException(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(1.0, dtypes.float32, name='a')
|
||||
|
Loading…
Reference in New Issue
Block a user