Make BaseSession.make_callable work with the C API enabled.
This change also switches session_test.py to use the with_c_api decorator so we get coverage with and without the C API enabled. This way we know we're not breaking backwards compatibility with the C API enabled. PiperOrigin-RevId: 178291189
This commit is contained in:
parent
b02eae0997
commit
0509f07cc2
@ -1160,9 +1160,6 @@ class BaseSession(SessionInterface):
|
||||
TypeError: If `fetches` or `feed_list` cannot be interpreted
|
||||
as arguments to @{tf.Session.run}.
|
||||
"""
|
||||
assert not self._created_with_new_api, ('session.make_callable() doesn\'t '
|
||||
'work with C API')
|
||||
|
||||
if feed_list is not None:
|
||||
if not isinstance(feed_list, (list, tuple)):
|
||||
raise TypeError('`feed_list` must be a list or tuple.')
|
||||
@ -1184,12 +1181,18 @@ class BaseSession(SessionInterface):
|
||||
|
||||
# Create a fetch handler to take care of the structure of fetches.
|
||||
fetch_handler = _FetchHandler(self._graph, fetches, {})
|
||||
fetch_list_as_strings = _name_list(fetch_handler.fetches())
|
||||
target_list_as_strings = _name_list(fetch_handler.targets())
|
||||
if self._created_with_new_api:
|
||||
# pylint: disable=protected-access
|
||||
fetch_list = [t._as_tf_output() for t in fetch_handler.fetches()]
|
||||
target_list = [op._c_op for op in fetch_handler.targets()]
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
fetch_list = _name_list(fetch_handler.fetches())
|
||||
target_list = _name_list(fetch_handler.targets())
|
||||
|
||||
def _callable_template_with_options_and_metadata(
|
||||
fetch_list_as_strings,
|
||||
target_list_as_strings,
|
||||
fetch_list,
|
||||
target_list,
|
||||
fetch_handler,
|
||||
options=None,
|
||||
run_metadata=None):
|
||||
@ -1199,9 +1202,14 @@ class BaseSession(SessionInterface):
|
||||
run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
|
||||
try:
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
if self._created_with_new_api:
|
||||
results = tf_session.TF_SessionRun_wrapper(
|
||||
self._session, options_ptr, {}, fetch_list, target_list,
|
||||
run_metadata_ptr, status)
|
||||
else:
|
||||
results = tf_session.TF_Run(
|
||||
self._session, options_ptr, {}, fetch_list_as_strings,
|
||||
target_list_as_strings, status, run_metadata_ptr)
|
||||
self._session, options_ptr, {}, fetch_list, target_list, status,
|
||||
run_metadata_ptr)
|
||||
if fetch_handler:
|
||||
results = fetch_handler.build_results(self, results)
|
||||
else:
|
||||
@ -1218,27 +1226,35 @@ class BaseSession(SessionInterface):
|
||||
|
||||
if accept_options:
|
||||
return functools.partial(
|
||||
_callable_template_with_options_and_metadata, fetch_list_as_strings,
|
||||
target_list_as_strings, fetch_handler)
|
||||
_callable_template_with_options_and_metadata, fetch_list,
|
||||
target_list, fetch_handler)
|
||||
elif isinstance(fetches, ops.Operation):
|
||||
# Special case for fetching a single operation, because the
|
||||
# function will have no return value.
|
||||
assert not fetch_list_as_strings
|
||||
assert len(target_list_as_strings) == 1
|
||||
assert not fetch_list
|
||||
assert len(target_list) == 1
|
||||
def _single_operation_run():
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
tf_session.TF_Run(self._session, None, {}, [],
|
||||
target_list_as_strings, status, None)
|
||||
if self._created_with_new_api:
|
||||
tf_session.TF_SessionRun_wrapper(
|
||||
self._session, None, {}, [], target_list, None, status)
|
||||
else:
|
||||
tf_session.TF_Run(
|
||||
self._session, None, {}, [], target_list, status, None)
|
||||
return _single_operation_run
|
||||
elif isinstance(fetches, ops.Tensor):
|
||||
# Special case for fetching a single tensor, because the
|
||||
# function can return the result of `TF_Run()` directly.
|
||||
assert len(fetch_list_as_strings) == 1
|
||||
assert not target_list_as_strings
|
||||
assert len(fetch_list) == 1
|
||||
assert not target_list
|
||||
def _single_tensor_run():
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
results = tf_session.TF_Run(self._session, None, {},
|
||||
fetch_list_as_strings, [], status, None)
|
||||
if self._created_with_new_api:
|
||||
results = tf_session.TF_SessionRun_wrapper(
|
||||
self._session, None, {}, fetch_list, [], None, status)
|
||||
else:
|
||||
results = tf_session.TF_Run(
|
||||
self._session, None, {}, fetch_list, [], status, None)
|
||||
return results[0]
|
||||
return _single_tensor_run
|
||||
else:
|
||||
@ -1246,9 +1262,12 @@ class BaseSession(SessionInterface):
|
||||
# results for us.
|
||||
def _fetch_handler_run():
|
||||
with errors.raise_exception_on_not_ok_status() as status:
|
||||
results = tf_session.TF_Run(self._session, None, {},
|
||||
fetch_list_as_strings,
|
||||
target_list_as_strings, status, None)
|
||||
if self._created_with_new_api:
|
||||
results = tf_session.TF_SessionRun_wrapper(
|
||||
self._session, None, {}, fetch_list, target_list, None, status)
|
||||
else:
|
||||
results = tf_session.TF_Run(
|
||||
self._session, None, {}, fetch_list, target_list, status, None)
|
||||
return fetch_handler.build_results(self, results)
|
||||
return _fetch_handler_run
|
||||
|
||||
|
@ -57,13 +57,13 @@ from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
ops._USE_C_API = True
|
||||
|
||||
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
|
||||
# don't have C++ op registrations on which to attach C++ shape fns.
|
||||
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
|
||||
|
||||
|
||||
@test_util.with_c_api
|
||||
class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testUseExistingGraph(self):
|
||||
@ -165,8 +165,9 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
# Run with a bogus handle.
|
||||
s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
|
||||
|
||||
@test_util.disable_c_api # No shape registration for 'ConstructionFails'
|
||||
def testOpConstructionErrorPayload(self):
|
||||
if ops._USE_C_API: return # No shape registration for 'ConstructionFails'
|
||||
|
||||
with session.Session():
|
||||
failing_op = ops.get_default_graph().create_op(
|
||||
'ConstructionFails', [], [], name='f')
|
||||
@ -208,7 +209,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
s.run({'a': a, 'b': None})
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testFetchSingleton(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(42.0)
|
||||
@ -231,7 +231,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
res = sess.run(a.op) # An op, not a tensor.
|
||||
self.assertEqual(None, res)
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testFetchList(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(42.0)
|
||||
@ -247,7 +246,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testFetchTuple(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(42.0)
|
||||
@ -261,7 +259,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(isinstance(res, tuple))
|
||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testFetchNamedTuple(self):
|
||||
# pylint: disable=invalid-name
|
||||
ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
|
||||
@ -1178,7 +1175,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
|
||||
self.assertAllEqual(a2_val, [[1.0, 1.0]])
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testFeedAndFetch(self):
|
||||
with session.Session() as sess:
|
||||
for dtype in [dtypes.float16,
|
||||
@ -1225,7 +1221,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(np_array, out_v)
|
||||
self.assertAllEqual(np_array, feed_v)
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testMakeCallableOnTensorWithRunOptions(self):
|
||||
with session.Session() as sess:
|
||||
a = constant_op.constant(42.0)
|
||||
@ -1238,7 +1233,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(42.0, res)
|
||||
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testMakeCallableOnOperationWithRunOptions(self):
|
||||
with session.Session() as sess:
|
||||
a = variables.Variable(42.0)
|
||||
@ -1253,7 +1247,6 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(43.0, sess.run(a))
|
||||
self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
|
||||
|
||||
@test_util.disable_c_api # session.make_callable() doesn't work with C API
|
||||
def testMakeCallableWithFeedListAndRunOptions(self):
|
||||
with session.Session() as sess:
|
||||
ph = array_ops.placeholder(dtypes.float32)
|
||||
@ -1460,9 +1453,10 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertTrue(run_metadata.HasField('step_stats'))
|
||||
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
|
||||
|
||||
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
|
||||
@test_util.disable_c_api
|
||||
def testFeedShapeCompatibility(self):
|
||||
# TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
|
||||
if ops._USE_C_API: return
|
||||
|
||||
with session.Session() as sess:
|
||||
some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
|
||||
new_shape = constant_op.constant([2, 2])
|
||||
@ -1746,6 +1740,15 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
class GraphMutationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._original_use_c_api_value = ops._USE_C_API
|
||||
ops._USE_C_API = True
|
||||
super(GraphMutationTest, self).setUp()
|
||||
|
||||
def tearDown(self):
|
||||
ops._USE_C_API = self._original_use_c_api_value
|
||||
super(GraphMutationTest, self).tearDown()
|
||||
|
||||
def testUpdateInputAfterRunning(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
a = constant_op.constant(1.0)
|
||||
|
Loading…
Reference in New Issue
Block a user