diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 759c36ad72e..017bef99ce8 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1160,9 +1160,6 @@ class BaseSession(SessionInterface): TypeError: If `fetches` or `feed_list` cannot be interpreted as arguments to @{tf.Session.run}. """ - assert not self._created_with_new_api, ('session.make_callable() doesn\'t ' - 'work with C API') - if feed_list is not None: if not isinstance(feed_list, (list, tuple)): raise TypeError('`feed_list` must be a list or tuple.') @@ -1184,12 +1181,18 @@ class BaseSession(SessionInterface): # Create a fetch handler to take care of the structure of fetches. fetch_handler = _FetchHandler(self._graph, fetches, {}) - fetch_list_as_strings = _name_list(fetch_handler.fetches()) - target_list_as_strings = _name_list(fetch_handler.targets()) + if self._created_with_new_api: + # pylint: disable=protected-access + fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()] + target_list = [op._c_op for op in fetch_handler.targets()] + # pylint: enable=protected-access + else: + fetch_list = _name_list(fetch_handler.fetches()) + target_list = _name_list(fetch_handler.targets()) def _callable_template_with_options_and_metadata( - fetch_list_as_strings, - target_list_as_strings, + fetch_list, + target_list, fetch_handler, options=None, run_metadata=None): @@ -1199,9 +1202,14 @@ class BaseSession(SessionInterface): run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None try: with errors.raise_exception_on_not_ok_status() as status: - results = tf_session.TF_Run( - self._session, options_ptr, {}, fetch_list_as_strings, - target_list_as_strings, status, run_metadata_ptr) + if self._created_with_new_api: + results = tf_session.TF_SessionRun_wrapper( + self._session, options_ptr, {}, fetch_list, target_list, + run_metadata_ptr, status) + else: + results = tf_session.TF_Run( + self._session, options_ptr, {}, fetch_list, target_list, status, + run_metadata_ptr) if fetch_handler: results = fetch_handler.build_results(self, results) else: @@ -1218,27 +1226,35 @@ class BaseSession(SessionInterface): if accept_options: return functools.partial( - _callable_template_with_options_and_metadata, fetch_list_as_strings, - target_list_as_strings, fetch_handler) + _callable_template_with_options_and_metadata, fetch_list, + target_list, fetch_handler) elif isinstance(fetches, ops.Operation): # Special case for fetching a single operation, because the # function will have no return value. - assert not fetch_list_as_strings - assert len(target_list_as_strings) == 1 + assert not fetch_list + assert len(target_list) == 1 def _single_operation_run(): with errors.raise_exception_on_not_ok_status() as status: - tf_session.TF_Run(self._session, None, {}, [], - target_list_as_strings, status, None) + if self._created_with_new_api: + tf_session.TF_SessionRun_wrapper( + self._session, None, {}, [], target_list, None, status) + else: + tf_session.TF_Run( + self._session, None, {}, [], target_list, status, None) return _single_operation_run elif isinstance(fetches, ops.Tensor): # Special case for fetching a single tensor, because the # function can return the result of `TF_Run()` directly. - assert len(fetch_list_as_strings) == 1 - assert not target_list_as_strings + assert len(fetch_list) == 1 + assert not target_list def _single_tensor_run(): with errors.raise_exception_on_not_ok_status() as status: - results = tf_session.TF_Run(self._session, None, {}, - fetch_list_as_strings, [], status, None) + if self._created_with_new_api: + results = tf_session.TF_SessionRun_wrapper( + self._session, None, {}, fetch_list, [], None, status) + else: + results = tf_session.TF_Run( + self._session, None, {}, fetch_list, [], status, None) return results[0] return _single_tensor_run else: @@ -1246,9 +1262,12 @@ class BaseSession(SessionInterface): # results for us. def _fetch_handler_run(): with errors.raise_exception_on_not_ok_status() as status: - results = tf_session.TF_Run(self._session, None, {}, - fetch_list_as_strings, - target_list_as_strings, status, None) + if self._created_with_new_api: + results = tf_session.TF_SessionRun_wrapper( + self._session, None, {}, fetch_list, target_list, None, status) + else: + results = tf_session.TF_Run( + self._session, None, {}, fetch_list, target_list, status, None) return fetch_handler.build_results(self, results) return _fetch_handler_run diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e4545d287b7..3da03a7b0fb 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -57,13 +57,13 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import server_lib from tensorflow.python.util import compat -ops._USE_C_API = True # NOTE(mrry): Dummy shape registration for ops used in the tests, since they # don't have C++ op registrations on which to attach C++ shape fns. ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) +@test_util.with_c_api class SessionTest(test_util.TensorFlowTestCase): def testUseExistingGraph(self): @@ -165,8 +165,9 @@ class SessionTest(test_util.TensorFlowTestCase): # Run with a bogus handle. s.partial_run('foo', r1, feed_dict={a: 1, b: 2}) - @test_util.disable_c_api # No shape registration for 'ConstructionFails' def testOpConstructionErrorPayload(self): + if ops._USE_C_API: return # No shape registration for 'ConstructionFails' + with session.Session(): failing_op = ops.get_default_graph().create_op( 'ConstructionFails', [], [], name='f') @@ -208,7 +209,6 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaises(TypeError): s.run({'a': a, 'b': None}) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testFetchSingleton(self): with session.Session() as sess: a = constant_op.constant(42.0) @@ -231,7 +231,6 @@ class SessionTest(test_util.TensorFlowTestCase): res = sess.run(a.op) # An op, not a tensor. self.assertEqual(None, res) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testFetchList(self): with session.Session() as sess: a = constant_op.constant(42.0) @@ -247,7 +246,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(res, list)) self.assertEqual([42.0, None, 44.0, 42.0, None], res) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testFetchTuple(self): with session.Session() as sess: a = constant_op.constant(42.0) @@ -261,7 +259,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(res, tuple)) self.assertEqual((42.0, None, 44.0, 42.0), res) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testFetchNamedTuple(self): # pylint: disable=invalid-name ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) @@ -1178,7 +1175,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]]) self.assertAllEqual(a2_val, [[1.0, 1.0]]) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testFeedAndFetch(self): with session.Session() as sess: for dtype in [dtypes.float16, @@ -1225,7 +1221,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual(np_array, out_v) self.assertAllEqual(np_array, feed_v) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testMakeCallableOnTensorWithRunOptions(self): with session.Session() as sess: a = constant_op.constant(42.0) @@ -1238,7 +1233,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(42.0, res) self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testMakeCallableOnOperationWithRunOptions(self): with session.Session() as sess: a = variables.Variable(42.0) @@ -1253,7 +1247,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(43.0, sess.run(a)) self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) - @test_util.disable_c_api # session.make_callable() doesn't work with C API def testMakeCallableWithFeedListAndRunOptions(self): with session.Session() as sess: ph = array_ops.placeholder(dtypes.float32) @@ -1460,9 +1453,10 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertTrue(run_metadata.HasField('step_stats')) self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) - # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable. - @test_util.disable_c_api def testFeedShapeCompatibility(self): + # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable. + if ops._USE_C_API: return + with session.Session() as sess: some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) new_shape = constant_op.constant([2, 2]) @@ -1746,6 +1740,15 @@ class SessionTest(test_util.TensorFlowTestCase): class GraphMutationTest(test_util.TensorFlowTestCase): + def setUp(self): + self._original_use_c_api_value = ops._USE_C_API + ops._USE_C_API = True + super(GraphMutationTest, self).setUp() + + def tearDown(self): + ops._USE_C_API = self._original_use_c_api_value + super(GraphMutationTest, self).tearDown() + def testUpdateInputAfterRunning(self): with ops.Graph().as_default() as g: a = constant_op.constant(1.0)