Make BaseSession.make_callable work with the C API enabled.
This change also switches session_test.py to use the with_c_api decorator so we get coverage with and without the C API enabled. This way we know we're not breaking backwards compatibility with the C API enabled. PiperOrigin-RevId: 178291189
This commit is contained in:
parent
b02eae0997
commit
0509f07cc2
@ -1160,9 +1160,6 @@ class BaseSession(SessionInterface):
|
|||||||
TypeError: If `fetches` or `feed_list` cannot be interpreted
|
TypeError: If `fetches` or `feed_list` cannot be interpreted
|
||||||
as arguments to @{tf.Session.run}.
|
as arguments to @{tf.Session.run}.
|
||||||
"""
|
"""
|
||||||
assert not self._created_with_new_api, ('session.make_callable() doesn\'t '
|
|
||||||
'work with C API')
|
|
||||||
|
|
||||||
if feed_list is not None:
|
if feed_list is not None:
|
||||||
if not isinstance(feed_list, (list, tuple)):
|
if not isinstance(feed_list, (list, tuple)):
|
||||||
raise TypeError('`feed_list` must be a list or tuple.')
|
raise TypeError('`feed_list` must be a list or tuple.')
|
||||||
@ -1184,12 +1181,18 @@ class BaseSession(SessionInterface):
|
|||||||
|
|
||||||
# Create a fetch handler to take care of the structure of fetches.
|
# Create a fetch handler to take care of the structure of fetches.
|
||||||
fetch_handler = _FetchHandler(self._graph, fetches, {})
|
fetch_handler = _FetchHandler(self._graph, fetches, {})
|
||||||
fetch_list_as_strings = _name_list(fetch_handler.fetches())
|
if self._created_with_new_api:
|
||||||
target_list_as_strings = _name_list(fetch_handler.targets())
|
# pylint: disable=protected-access
|
||||||
|
fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
|
||||||
|
target_list = [op._c_op for op in fetch_handler.targets()]
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
else:
|
||||||
|
fetch_list = _name_list(fetch_handler.fetches())
|
||||||
|
target_list = _name_list(fetch_handler.targets())
|
||||||
|
|
||||||
def _callable_template_with_options_and_metadata(
|
def _callable_template_with_options_and_metadata(
|
||||||
fetch_list_as_strings,
|
fetch_list,
|
||||||
target_list_as_strings,
|
target_list,
|
||||||
fetch_handler,
|
fetch_handler,
|
||||||
options=None,
|
options=None,
|
||||||
run_metadata=None):
|
run_metadata=None):
|
||||||
@ -1199,9 +1202,14 @@ class BaseSession(SessionInterface):
|
|||||||
run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
|
run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
|
||||||
try:
|
try:
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
results = tf_session.TF_Run(
|
if self._created_with_new_api:
|
||||||
self._session, options_ptr, {}, fetch_list_as_strings,
|
results = tf_session.TF_SessionRun_wrapper(
|
||||||
target_list_as_strings, status, run_metadata_ptr)
|
self._session, options_ptr, {}, fetch_list, target_list,
|
||||||
|
run_metadata_ptr, status)
|
||||||
|
else:
|
||||||
|
results = tf_session.TF_Run(
|
||||||
|
self._session, options_ptr, {}, fetch_list, target_list, status,
|
||||||
|
run_metadata_ptr)
|
||||||
if fetch_handler:
|
if fetch_handler:
|
||||||
results = fetch_handler.build_results(self, results)
|
results = fetch_handler.build_results(self, results)
|
||||||
else:
|
else:
|
||||||
@ -1218,27 +1226,35 @@ class BaseSession(SessionInterface):
|
|||||||
|
|
||||||
if accept_options:
|
if accept_options:
|
||||||
return functools.partial(
|
return functools.partial(
|
||||||
_callable_template_with_options_and_metadata, fetch_list_as_strings,
|
_callable_template_with_options_and_metadata, fetch_list,
|
||||||
target_list_as_strings, fetch_handler)
|
target_list, fetch_handler)
|
||||||
elif isinstance(fetches, ops.Operation):
|
elif isinstance(fetches, ops.Operation):
|
||||||
# Special case for fetching a single operation, because the
|
# Special case for fetching a single operation, because the
|
||||||
# function will have no return value.
|
# function will have no return value.
|
||||||
assert not fetch_list_as_strings
|
assert not fetch_list
|
||||||
assert len(target_list_as_strings) == 1
|
assert len(target_list) == 1
|
||||||
def _single_operation_run():
|
def _single_operation_run():
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
tf_session.TF_Run(self._session, None, {}, [],
|
if self._created_with_new_api:
|
||||||
target_list_as_strings, status, None)
|
tf_session.TF_SessionRun_wrapper(
|
||||||
|
self._session, None, {}, [], target_list, None, status)
|
||||||
|
else:
|
||||||
|
tf_session.TF_Run(
|
||||||
|
self._session, None, {}, [], target_list, status, None)
|
||||||
return _single_operation_run
|
return _single_operation_run
|
||||||
elif isinstance(fetches, ops.Tensor):
|
elif isinstance(fetches, ops.Tensor):
|
||||||
# Special case for fetching a single tensor, because the
|
# Special case for fetching a single tensor, because the
|
||||||
# function can return the result of `TF_Run()` directly.
|
# function can return the result of `TF_Run()` directly.
|
||||||
assert len(fetch_list_as_strings) == 1
|
assert len(fetch_list) == 1
|
||||||
assert not target_list_as_strings
|
assert not target_list
|
||||||
def _single_tensor_run():
|
def _single_tensor_run():
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
results = tf_session.TF_Run(self._session, None, {},
|
if self._created_with_new_api:
|
||||||
fetch_list_as_strings, [], status, None)
|
results = tf_session.TF_SessionRun_wrapper(
|
||||||
|
self._session, None, {}, fetch_list, [], None, status)
|
||||||
|
else:
|
||||||
|
results = tf_session.TF_Run(
|
||||||
|
self._session, None, {}, fetch_list, [], status, None)
|
||||||
return results[0]
|
return results[0]
|
||||||
return _single_tensor_run
|
return _single_tensor_run
|
||||||
else:
|
else:
|
||||||
@ -1246,9 +1262,12 @@ class BaseSession(SessionInterface):
|
|||||||
# results for us.
|
# results for us.
|
||||||
def _fetch_handler_run():
|
def _fetch_handler_run():
|
||||||
with errors.raise_exception_on_not_ok_status() as status:
|
with errors.raise_exception_on_not_ok_status() as status:
|
||||||
results = tf_session.TF_Run(self._session, None, {},
|
if self._created_with_new_api:
|
||||||
fetch_list_as_strings,
|
results = tf_session.TF_SessionRun_wrapper(
|
||||||
target_list_as_strings, status, None)
|
self._session, None, {}, fetch_list, target_list, None, status)
|
||||||
|
else:
|
||||||
|
results = tf_session.TF_Run(
|
||||||
|
self._session, None, {}, fetch_list, target_list, status, None)
|
||||||
return fetch_handler.build_results(self, results)
|
return fetch_handler.build_results(self, results)
|
||||||
return _fetch_handler_run
|
return _fetch_handler_run
|
||||||
|
|
||||||
|
@ -57,13 +57,13 @@ from tensorflow.python.platform import googletest
|
|||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
|
||||||
ops._USE_C_API = True
|
|
||||||
|
|
||||||
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
|
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
|
||||||
# don't have C++ op registrations on which to attach C++ shape fns.
|
# don't have C++ op registrations on which to attach C++ shape fns.
|
||||||
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
|
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.with_c_api
|
||||||
class SessionTest(test_util.TensorFlowTestCase):
|
class SessionTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
def testUseExistingGraph(self):
|
def testUseExistingGraph(self):
|
||||||
@ -165,8 +165,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
# Run with a bogus handle.
|
# Run with a bogus handle.
|
||||||
s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
|
s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
|
||||||
|
|
||||||
@test_util.disable_c_api # No shape registration for 'ConstructionFails'
|
|
||||||
def testOpConstructionErrorPayload(self):
|
def testOpConstructionErrorPayload(self):
|
||||||
|
if ops._USE_C_API: return # No shape registration for 'ConstructionFails'
|
||||||
|
|
||||||
with session.Session():
|
with session.Session():
|
||||||
failing_op = ops.get_default_graph().create_op(
|
failing_op = ops.get_default_graph().create_op(
|
||||||
'ConstructionFails', [], [], name='f')
|
'ConstructionFails', [], [], name='f')
|
||||||
@ -208,7 +209,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
s.run({'a': a, 'b': None})
|
s.run({'a': a, 'b': None})
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testFetchSingleton(self):
|
def testFetchSingleton(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
a = constant_op.constant(42.0)
|
a = constant_op.constant(42.0)
|
||||||
@ -231,7 +231,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
res = sess.run(a.op) # An op, not a tensor.
|
res = sess.run(a.op) # An op, not a tensor.
|
||||||
self.assertEqual(None, res)
|
self.assertEqual(None, res)
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testFetchList(self):
|
def testFetchList(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
a = constant_op.constant(42.0)
|
a = constant_op.constant(42.0)
|
||||||
@ -247,7 +246,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(isinstance(res, list))
|
self.assertTrue(isinstance(res, list))
|
||||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testFetchTuple(self):
|
def testFetchTuple(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
a = constant_op.constant(42.0)
|
a = constant_op.constant(42.0)
|
||||||
@ -261,7 +259,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(isinstance(res, tuple))
|
self.assertTrue(isinstance(res, tuple))
|
||||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testFetchNamedTuple(self):
|
def testFetchNamedTuple(self):
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
|
ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
|
||||||
@ -1178,7 +1175,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
|
self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
|
||||||
self.assertAllEqual(a2_val, [[1.0, 1.0]])
|
self.assertAllEqual(a2_val, [[1.0, 1.0]])
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testFeedAndFetch(self):
|
def testFeedAndFetch(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
for dtype in [dtypes.float16,
|
for dtype in [dtypes.float16,
|
||||||
@ -1225,7 +1221,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertAllEqual(np_array, out_v)
|
self.assertAllEqual(np_array, out_v)
|
||||||
self.assertAllEqual(np_array, feed_v)
|
self.assertAllEqual(np_array, feed_v)
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testMakeCallableOnTensorWithRunOptions(self):
|
def testMakeCallableOnTensorWithRunOptions(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
a = constant_op.constant(42.0)
|
a = constant_op.constant(42.0)
|
||||||
@ -1238,7 +1233,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(42.0, res)
|
self.assertEqual(42.0, res)
|
||||||
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
|
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testMakeCallableOnOperationWithRunOptions(self):
|
def testMakeCallableOnOperationWithRunOptions(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
a = variables.Variable(42.0)
|
a = variables.Variable(42.0)
|
||||||
@ -1253,7 +1247,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual(43.0, sess.run(a))
|
self.assertEqual(43.0, sess.run(a))
|
||||||
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
|
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
|
||||||
|
|
||||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
|
||||||
def testMakeCallableWithFeedListAndRunOptions(self):
|
def testMakeCallableWithFeedListAndRunOptions(self):
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
ph = array_ops.placeholder(dtypes.float32)
|
ph = array_ops.placeholder(dtypes.float32)
|
||||||
@ -1460,9 +1453,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||||
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
|
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
|
||||||
|
|
||||||
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
|
|
||||||
@test_util.disable_c_api
|
|
||||||
def testFeedShapeCompatibility(self):
|
def testFeedShapeCompatibility(self):
|
||||||
|
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
|
||||||
|
if ops._USE_C_API: return
|
||||||
|
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
|
some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
|
||||||
new_shape = constant_op.constant([2, 2])
|
new_shape = constant_op.constant([2, 2])
|
||||||
@ -1746,6 +1740,15 @@ class SessionTest(test_util.TensorFlowTestCase):
|
|||||||
|
|
||||||
class GraphMutationTest(test_util.TensorFlowTestCase):
|
class GraphMutationTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._original_use_c_api_value = ops._USE_C_API
|
||||||
|
ops._USE_C_API = True
|
||||||
|
super(GraphMutationTest, self).setUp()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
ops._USE_C_API = self._original_use_c_api_value
|
||||||
|
super(GraphMutationTest, self).tearDown()
|
||||||
|
|
||||||
def testUpdateInputAfterRunning(self):
|
def testUpdateInputAfterRunning(self):
|
||||||
with ops.Graph().as_default() as g:
|
with ops.Graph().as_default() as g:
|
||||||
a = constant_op.constant(1.0)
|
a = constant_op.constant(1.0)
|
||||||
|
Loading…
Reference in New Issue
Block a user