Bug fix: Need to distinguish run and partial run when caching executors in session.
Change: 121896355
This commit is contained in:
parent
9efc4097cf
commit
6d2623d03c
tensorflow
@ -661,7 +661,8 @@ Status DirectSession::GetOrCreateExecutors(
|
|||||||
|
|
||||||
const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
|
const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->",
|
||||||
str_util::Join(outputs_sorted, ","), "/",
|
str_util::Join(outputs_sorted, ","), "/",
|
||||||
str_util::Join(tn_sorted, ","));
|
str_util::Join(tn_sorted, ","), "/",
|
||||||
|
run_state_args->is_partial_run);
|
||||||
|
|
||||||
// Set the handle.
|
// Set the handle.
|
||||||
{
|
{
|
||||||
|
@ -971,6 +971,16 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(steps, len(res))
|
self.assertEqual(steps, len(res))
|
||||||
self.assertEqual(2.0, res[-1])
|
self.assertEqual(2.0, res[-1])
|
||||||
|
|
||||||
|
def testRunAndPartialRun(self):
|
||||||
|
with session.Session() as sess:
|
||||||
|
a = constant_op.constant(2.0, dtypes.float32)
|
||||||
|
b = a * 2
|
||||||
|
c = b * 3
|
||||||
|
r1 = sess.run([b, c])
|
||||||
|
h = sess.partial_run_setup([b, c], [])
|
||||||
|
r2 = sess.partial_run(h, [b, c])
|
||||||
|
self.assertEqual(r1, r2)
|
||||||
|
|
||||||
def testFeedDictKeyException(self):
|
def testFeedDictKeyException(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
a = constant_op.constant(1.0, dtypes.float32, name='a')
|
a = constant_op.constant(1.0, dtypes.float32, name='a')
|
||||||
|
Loading…
Reference in New Issue
Block a user