Make BaseSession.make_callable work with the C API enabled.

This change also switches session_test.py to use the with_c_api decorator
so we get coverage with and without the C API enabled. This way we know
we're not breaking backwards compatibility with the C API enabled.

PiperOrigin-RevId: 178291189
This commit is contained in:
Skye Wanderman-Milne 2017-12-07 14:01:51 -08:00 committed by TensorFlower Gardener
parent b02eae0997
commit 0509f07cc2
2 changed files with 57 additions and 35 deletions

View File

@ -1160,9 +1160,6 @@ class BaseSession(SessionInterface):
TypeError: If `fetches` or `feed_list` cannot be interpreted TypeError: If `fetches` or `feed_list` cannot be interpreted
as arguments to @{tf.Session.run}. as arguments to @{tf.Session.run}.
""" """
assert not self._created_with_new_api, ('session.make_callable() doesn\'t '
'work with C API')
if feed_list is not None: if feed_list is not None:
if not isinstance(feed_list, (list, tuple)): if not isinstance(feed_list, (list, tuple)):
raise TypeError('`feed_list` must be a list or tuple.') raise TypeError('`feed_list` must be a list or tuple.')
@ -1184,12 +1181,18 @@ class BaseSession(SessionInterface):
# Create a fetch handler to take care of the structure of fetches. # Create a fetch handler to take care of the structure of fetches.
fetch_handler = _FetchHandler(self._graph, fetches, {}) fetch_handler = _FetchHandler(self._graph, fetches, {})
fetch_list_as_strings = _name_list(fetch_handler.fetches()) if self._created_with_new_api:
target_list_as_strings = _name_list(fetch_handler.targets()) # pylint: disable=protected-access
fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
target_list = [op._c_op for op in fetch_handler.targets()]
# pylint: enable=protected-access
else:
fetch_list = _name_list(fetch_handler.fetches())
target_list = _name_list(fetch_handler.targets())
def _callable_template_with_options_and_metadata( def _callable_template_with_options_and_metadata(
fetch_list_as_strings, fetch_list,
target_list_as_strings, target_list,
fetch_handler, fetch_handler,
options=None, options=None,
run_metadata=None): run_metadata=None):
@ -1199,9 +1202,14 @@ class BaseSession(SessionInterface):
run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
try: try:
with errors.raise_exception_on_not_ok_status() as status: with errors.raise_exception_on_not_ok_status() as status:
results = tf_session.TF_Run( if self._created_with_new_api:
self._session, options_ptr, {}, fetch_list_as_strings, results = tf_session.TF_SessionRun_wrapper(
target_list_as_strings, status, run_metadata_ptr) self._session, options_ptr, {}, fetch_list, target_list,
run_metadata_ptr, status)
else:
results = tf_session.TF_Run(
self._session, options_ptr, {}, fetch_list, target_list, status,
run_metadata_ptr)
if fetch_handler: if fetch_handler:
results = fetch_handler.build_results(self, results) results = fetch_handler.build_results(self, results)
else: else:
@ -1218,27 +1226,35 @@ class BaseSession(SessionInterface):
if accept_options: if accept_options:
return functools.partial( return functools.partial(
_callable_template_with_options_and_metadata, fetch_list_as_strings, _callable_template_with_options_and_metadata, fetch_list,
target_list_as_strings, fetch_handler) target_list, fetch_handler)
elif isinstance(fetches, ops.Operation): elif isinstance(fetches, ops.Operation):
# Special case for fetching a single operation, because the # Special case for fetching a single operation, because the
# function will have no return value. # function will have no return value.
assert not fetch_list_as_strings assert not fetch_list
assert len(target_list_as_strings) == 1 assert len(target_list) == 1
def _single_operation_run(): def _single_operation_run():
with errors.raise_exception_on_not_ok_status() as status: with errors.raise_exception_on_not_ok_status() as status:
tf_session.TF_Run(self._session, None, {}, [], if self._created_with_new_api:
target_list_as_strings, status, None) tf_session.TF_SessionRun_wrapper(
self._session, None, {}, [], target_list, None, status)
else:
tf_session.TF_Run(
self._session, None, {}, [], target_list, status, None)
return _single_operation_run return _single_operation_run
elif isinstance(fetches, ops.Tensor): elif isinstance(fetches, ops.Tensor):
# Special case for fetching a single tensor, because the # Special case for fetching a single tensor, because the
# function can return the result of `TF_Run()` directly. # function can return the result of `TF_Run()` directly.
assert len(fetch_list_as_strings) == 1 assert len(fetch_list) == 1
assert not target_list_as_strings assert not target_list
def _single_tensor_run(): def _single_tensor_run():
with errors.raise_exception_on_not_ok_status() as status: with errors.raise_exception_on_not_ok_status() as status:
results = tf_session.TF_Run(self._session, None, {}, if self._created_with_new_api:
fetch_list_as_strings, [], status, None) results = tf_session.TF_SessionRun_wrapper(
self._session, None, {}, fetch_list, [], None, status)
else:
results = tf_session.TF_Run(
self._session, None, {}, fetch_list, [], status, None)
return results[0] return results[0]
return _single_tensor_run return _single_tensor_run
else: else:
@ -1246,9 +1262,12 @@ class BaseSession(SessionInterface):
# results for us. # results for us.
def _fetch_handler_run(): def _fetch_handler_run():
with errors.raise_exception_on_not_ok_status() as status: with errors.raise_exception_on_not_ok_status() as status:
results = tf_session.TF_Run(self._session, None, {}, if self._created_with_new_api:
fetch_list_as_strings, results = tf_session.TF_SessionRun_wrapper(
target_list_as_strings, status, None) self._session, None, {}, fetch_list, target_list, None, status)
else:
results = tf_session.TF_Run(
self._session, None, {}, fetch_list, target_list, status, None)
return fetch_handler.build_results(self, results) return fetch_handler.build_results(self, results)
return _fetch_handler_run return _fetch_handler_run

View File

@ -57,13 +57,13 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
from tensorflow.python.util import compat from tensorflow.python.util import compat
ops._USE_C_API = True
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they # NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns. # don't have C++ op registrations on which to attach C++ shape fns.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@test_util.with_c_api
class SessionTest(test_util.TensorFlowTestCase): class SessionTest(test_util.TensorFlowTestCase):
def testUseExistingGraph(self): def testUseExistingGraph(self):
@ -165,8 +165,9 @@ class SessionTest(test_util.TensorFlowTestCase):
# Run with a bogus handle. # Run with a bogus handle.
s.partial_run('foo', r1, feed_dict={a: 1, b: 2}) s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
@test_util.disable_c_api # No shape registration for 'ConstructionFails'
def testOpConstructionErrorPayload(self): def testOpConstructionErrorPayload(self):
if ops._USE_C_API: return # No shape registration for 'ConstructionFails'
with session.Session(): with session.Session():
failing_op = ops.get_default_graph().create_op( failing_op = ops.get_default_graph().create_op(
'ConstructionFails', [], [], name='f') 'ConstructionFails', [], [], name='f')
@ -208,7 +209,6 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
s.run({'a': a, 'b': None}) s.run({'a': a, 'b': None})
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testFetchSingleton(self): def testFetchSingleton(self):
with session.Session() as sess: with session.Session() as sess:
a = constant_op.constant(42.0) a = constant_op.constant(42.0)
@ -231,7 +231,6 @@ class SessionTest(test_util.TensorFlowTestCase):
res = sess.run(a.op) # An op, not a tensor. res = sess.run(a.op) # An op, not a tensor.
self.assertEqual(None, res) self.assertEqual(None, res)
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testFetchList(self): def testFetchList(self):
with session.Session() as sess: with session.Session() as sess:
a = constant_op.constant(42.0) a = constant_op.constant(42.0)
@ -247,7 +246,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(res, list)) self.assertTrue(isinstance(res, list))
self.assertEqual([42.0, None, 44.0, 42.0, None], res) self.assertEqual([42.0, None, 44.0, 42.0, None], res)
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testFetchTuple(self): def testFetchTuple(self):
with session.Session() as sess: with session.Session() as sess:
a = constant_op.constant(42.0) a = constant_op.constant(42.0)
@ -261,7 +259,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(res, tuple)) self.assertTrue(isinstance(res, tuple))
self.assertEqual((42.0, None, 44.0, 42.0), res) self.assertEqual((42.0, None, 44.0, 42.0), res)
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testFetchNamedTuple(self): def testFetchNamedTuple(self):
# pylint: disable=invalid-name # pylint: disable=invalid-name
ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
@ -1178,7 +1175,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]]) self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
self.assertAllEqual(a2_val, [[1.0, 1.0]]) self.assertAllEqual(a2_val, [[1.0, 1.0]])
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testFeedAndFetch(self): def testFeedAndFetch(self):
with session.Session() as sess: with session.Session() as sess:
for dtype in [dtypes.float16, for dtype in [dtypes.float16,
@ -1225,7 +1221,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(np_array, out_v) self.assertAllEqual(np_array, out_v)
self.assertAllEqual(np_array, feed_v) self.assertAllEqual(np_array, feed_v)
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testMakeCallableOnTensorWithRunOptions(self): def testMakeCallableOnTensorWithRunOptions(self):
with session.Session() as sess: with session.Session() as sess:
a = constant_op.constant(42.0) a = constant_op.constant(42.0)
@ -1238,7 +1233,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(42.0, res) self.assertEqual(42.0, res)
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testMakeCallableOnOperationWithRunOptions(self): def testMakeCallableOnOperationWithRunOptions(self):
with session.Session() as sess: with session.Session() as sess:
a = variables.Variable(42.0) a = variables.Variable(42.0)
@ -1253,7 +1247,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(43.0, sess.run(a)) self.assertEqual(43.0, sess.run(a))
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0) self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
@test_util.disable_c_api # session.make_callable() doesn't work with C API
def testMakeCallableWithFeedListAndRunOptions(self): def testMakeCallableWithFeedListAndRunOptions(self):
with session.Session() as sess: with session.Session() as sess:
ph = array_ops.placeholder(dtypes.float32) ph = array_ops.placeholder(dtypes.float32)
@ -1460,9 +1453,10 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertTrue(run_metadata.HasField('step_stats')) self.assertTrue(run_metadata.HasField('step_stats'))
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
@test_util.disable_c_api
def testFeedShapeCompatibility(self): def testFeedShapeCompatibility(self):
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
if ops._USE_C_API: return
with session.Session() as sess: with session.Session() as sess:
some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
new_shape = constant_op.constant([2, 2]) new_shape = constant_op.constant([2, 2])
@ -1746,6 +1740,15 @@ class SessionTest(test_util.TensorFlowTestCase):
class GraphMutationTest(test_util.TensorFlowTestCase): class GraphMutationTest(test_util.TensorFlowTestCase):
def setUp(self):
self._original_use_c_api_value = ops._USE_C_API
ops._USE_C_API = True
super(GraphMutationTest, self).setUp()
def tearDown(self):
ops._USE_C_API = self._original_use_c_api_value
super(GraphMutationTest, self).tearDown()
def testUpdateInputAfterRunning(self): def testUpdateInputAfterRunning(self):
with ops.Graph().as_default() as g: with ops.Graph().as_default() as g:
a = constant_op.constant(1.0) a = constant_op.constant(1.0)