From 6d2623d03c4f4ad25169140fbcd0183a8188be51 Mon Sep 17 00:00:00 2001 From: Yuan Yu Date: Mon, 9 May 2016 16:22:23 -0800 Subject: [PATCH] Bug fix: Need to distinguish run and partial run when caching executors in session. Change: 121896355 --- tensorflow/core/common_runtime/direct_session.cc | 3 ++- tensorflow/python/client/session_test.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 05ca2e87026..0c3ddaf2f7b 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -661,7 +661,8 @@ Status DirectSession::GetOrCreateExecutors( const string key = strings::StrCat(str_util::Join(inputs_sorted, ","), "->", str_util::Join(outputs_sorted, ","), "/", - str_util::Join(tn_sorted, ",")); + str_util::Join(tn_sorted, ","), "/", + run_state_args->is_partial_run); // Set the handle. { diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 59c9afca43c..d60883464d0 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -971,6 +971,16 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(steps, len(res)) self.assertEqual(2.0, res[-1]) + def testRunAndPartialRun(self): + with session.Session() as sess: + a = constant_op.constant(2.0, dtypes.float32) + b = a * 2 + c = b * 3 + r1 = sess.run([b, c]) + h = sess.partial_run_setup([b, c], []) + r2 = sess.partial_run(h, [b, c]) + self.assertEqual(r1, r2) + def testFeedDictKeyException(self): with session.Session() as sess: a = constant_op.constant(1.0, dtypes.float32, name='a')