Adds `tf.Session.make_callable()`.

This method returns a Python callable that has the same semantics as
`tf.Session.run()`, but can cache some of the work that must be done
to map Tensor-like objects to the arguments of the underlying C API
function.
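
For example, a minimal usage sketch (the variable and its shape here are illustrative, not part of this change):

    import tensorflow as tf

    v = tf.Variable(tf.random_normal([4]))

    with tf.Session() as sess:
      sess.run(v.initializer)
      # Built once; each call behaves like `sess.run(v)` but skips the
      # per-call work of mapping the fetch to C API arguments.
      fetch_v = sess.make_callable(v)
      result = fetch_v()  # A numpy ndarray, as with `sess.run(v)`.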

The initial implementation is optimized for single-`Tensor` and
single-`Operation` fetches, and delegates to `tf.Session.run()` for
handling feeds. Since most queue runners use a single-`Operation`
`run()` call, this change also switches the `tf.train.QueueRunner`
implementation to use `make_callable()`.

Using this new interface can improve the latency of small steps (measurements from my workstation):

* The median time to fetch a 4-byte tensor decreases from 99us to 52us (-47us).
* The median time to run a trivial op decreases from 80us to 31us (-49us).
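
These medians come from microbenchmarks like the `SessionBenchmark` methods added below; a standalone sketch of that measurement loop (illustrative only, not the benchmark itself):

    import time
    import numpy as np
    import tensorflow as tf

    v = tf.Variable(tf.random_normal([1]))  # One float32: a 4-byte fetch.

    with tf.Session() as sess:
      sess.run(v.initializer)
      runner = sess.make_callable(v)
      runner()  # Warm-up run.
      times = []
      for _ in range(10000):
        start = time.time()
        runner()
        times.append(time.time() - start)
      print("median fetch time: %f s" % np.median(times))
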
Change: 152757301
Derek Murray, 2017-04-10 16:27:56 -08:00 (committed by TensorFlower Gardener)
parent d34eec7ec3, commit c0095c9709
6 changed files with 235 additions and 21 deletions


@@ -996,6 +996,91 @@ class BaseSession(SessionInterface):
       results = []
     return fetch_handler.build_results(self, results)
 
+  def make_callable(self, fetches, feed_list=None):
+    """Returns a Python callable that runs a particular step.
+
+    The returned callable will take `len(feed_list)` arguments whose types
+    must be compatible feed values for the respective elements of `feed_list`.
+    For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
+    argument to the returned callable must be a numpy ndarray (or something
+    convertible to an ndarray) with matching element type and shape. See
+    @{tf.Session.run} for details of the allowable feed key and value types.
+
+    The returned callable will have the same return type as
+    `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
+    the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
+    it will return `None`.
+
+    Args:
+      fetches: A value or list of values to fetch. See @{tf.Session.run}
+        for details of the allowable fetch types.
+      feed_list: (Optional.) A list of `feed_dict` keys. See
+        @{tf.Session.run} for details of the allowable feed key types.
+
+    Returns:
+      A function that when called will execute the step defined by
+      `feed_list` and `fetches` in this session.
+
+    Raises:
+      TypeError: If `fetches` or `feed_list` cannot be interpreted
+        as arguments to @{tf.Session.run}.
+    """
+    if feed_list is not None:
+      if not isinstance(feed_list, (list, tuple)):
+        raise TypeError('`feed_list` must be a list or tuple.')
+      # Delegate any non-empty feed lists to the existing `run()` logic.
+      # TODO(mrry): Refactor the feed handling logic from
+      # `Session._run()` so that we can convert the feeds to a list of
+      # strings here.
+      def _generic_run(*feed_args):
+        feed_dict = {feed: feed_val
+                     for feed, feed_val in zip(feed_list, feed_args)}
+        return self.run(fetches, feed_dict=feed_dict)
+      return _generic_run
+
+    # Ensure any changes to the graph are reflected in the runtime.
+    # Note that we don't need to do this on subsequent calls to the
+    # returned object, because the arguments to `fetches` must already be
+    # in the graph.
+    self._extend_graph()
+
+    # Create a fetch handler to take care of the structure of fetches.
+    fetch_handler = _FetchHandler(self._graph, fetches, {})
+    fetch_list_as_strings = fetch_handler.fetches()
+    target_list_as_strings = fetch_handler.targets()
+
+    if isinstance(fetches, ops.Operation):
+      # Special case for fetching a single operation, because the
+      # function will have no return value.
+      assert not fetch_list_as_strings
+      assert len(target_list_as_strings) == 1
+      def _single_operation_run():
+        with errors.raise_exception_on_not_ok_status() as status:
+          tf_session.TF_Run(self._session, None, {}, [],
+                            target_list_as_strings, status, None)
+      return _single_operation_run
+    elif isinstance(fetches, ops.Tensor):
+      # Special case for fetching a single tensor, because the
+      # function can return the result of `TF_Run()` directly.
+      assert len(fetch_list_as_strings) == 1
+      assert not target_list_as_strings
+      def _single_tensor_run():
+        with errors.raise_exception_on_not_ok_status() as status:
+          results = tf_session.TF_Run(self._session, None, {},
+                                      fetch_list_as_strings, [], status, None)
+        return results[0]
+      return _single_tensor_run
+    else:
+      # In all other cases, we must use `fetch_handler` to build the
+      # results for us.
+      def _fetch_handler_run():
+        with errors.raise_exception_on_not_ok_status() as status:
+          results = tf_session.TF_Run(self._session, None, {},
+                                      fetch_list_as_strings,
+                                      target_list_as_strings, status, None)
+        return fetch_handler.build_results(self, results)
+      return _fetch_handler_run
+
   # Captures the name of a node in an error status.
   _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
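
Note that the cached `TF_Run` fast path above applies only when `feed_list` is empty: a non-empty `feed_list` falls back to `_generic_run`, which zips the positional arguments into a `feed_dict` and delegates to `self.run()`. A minimal sketch of the fallback from the caller's side (the placeholder graph here is illustrative, not from this commit):

    import numpy as np
    import tensorflow as tf

    x = tf.placeholder(tf.float32, shape=[2])
    y = x * 2.0

    with tf.Session() as sess:
      doubler = sess.make_callable(y, feed_list=[x])
      # Positional arguments are matched up with `feed_list` to build a
      # feed_dict, then passed to `sess.run()`; no caching on this path yet.
      print(doubler(np.array([1.0, 2.0], dtype=np.float32)))  # => [2. 4.]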


@@ -92,26 +92,127 @@ class SessionBenchmark(test.Benchmark):
     print("%s %d %f" % (name, size, np.median(times)))
     self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
 
+  def _benchmarkFetchPrebuilt(self, name, target, size, iters):
+    """Runs a microbenchmark to measure the cost of fetching a tensor.
+
+    Reports the median cost of fetching a tensor of `size` * `sizeof(float)`
+    bytes.
+
+    Args:
+      name: A human-readable name for logging the output.
+      target: The session target to use for the benchmark.
+      size: The number of floating-point numbers to be fetched.
+      iters: The number of iterations to perform.
+    """
+    times = []
+    with ops.Graph().as_default():
+      # Define the tensor to be fetched as a variable, to avoid
+      # constant-folding.
+      v = variables.Variable(random_ops.random_normal([size]))
+      with session.Session(target) as sess:
+        sess.run(v.initializer)
+        runner = sess.make_callable(v)
+        runner()  # Warm-up run.
+        for _ in xrange(iters):
+          start_time = time.time()
+          runner()
+          end_time = time.time()
+          times.append(end_time - start_time)
+    print("%s %d %f" % (name, size, np.median(times)))
+    self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
+
+  def _benchmarkRunOp(self, name, target, iters):
+    """Runs a microbenchmark to measure the cost of running an op.
+
+    Reports the median cost of running a trivial (Variable) op.
+
+    Args:
+      name: A human-readable name for logging the output.
+      target: The session target to use for the benchmark.
+      iters: The number of iterations to perform.
+    """
+    times = []
+    with ops.Graph().as_default():
+      # Define the op to be run as a variable, to avoid
+      # constant-folding.
+      v = variables.Variable(random_ops.random_normal([]))
+      with session.Session(target) as sess:
+        sess.run(v.initializer)
+        sess.run(v.op)  # Warm-up run.
+        for _ in xrange(iters):
+          start_time = time.time()
+          sess.run(v.op)
+          end_time = time.time()
+          times.append(end_time - start_time)
+    print("%s %f" % (name, np.median(times)))
+    self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
+
+  def _benchmarkRunOpPrebuilt(self, name, target, iters):
+    """Runs a microbenchmark to measure the cost of running an op.
+
+    Reports the median cost of running a trivial (Variable) op.
+
+    Args:
+      name: A human-readable name for logging the output.
+      target: The session target to use for the benchmark.
+      iters: The number of iterations to perform.
+    """
+    times = []
+    with ops.Graph().as_default():
+      # Define the op to be run as a variable, to avoid
+      # constant-folding.
+      v = variables.Variable(random_ops.random_normal([]))
+      with session.Session(target) as sess:
+        sess.run(v.initializer)
+        runner = sess.make_callable(v.op)
+        runner()  # Warm-up run.
+        for _ in xrange(iters):
+          start_time = time.time()
+          runner()
+          end_time = time.time()
+          times.append(end_time - start_time)
+    print("%s %f" % (name, np.median(times)))
+    self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
+
   def benchmarkGrpcSession(self):
     server = server_lib.Server.create_local_server()
     self._benchmarkFeed("benchmark_session_feed_grpc_4B", server.target, 1,
-                        10000)
+                        30000)
     session.Session.reset(server.target)
-    self._benchmarkFeed("benchmark_session_feed_grpc_4MB", server.target, 1
-                        << 20, 100)
+    self._benchmarkFeed("benchmark_session_feed_grpc_4MB", server.target,
+                        1 << 20, 25000)
     session.Session.reset(server.target)
     self._benchmarkFetch("benchmark_session_fetch_grpc_4B", server.target, 1,
-                         20000)
+                         40000)
     session.Session.reset(server.target)
-    self._benchmarkFetch("benchmark_session_fetch_grpc_4MB", server.target, 1
-                         << 20, 100)
+    self._benchmarkFetch("benchmark_session_fetch_grpc_4MB", server.target,
+                         1 << 20, 20000)
     session.Session.reset(server.target)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4B",
+                                 server.target, 1, 50000)
+    session.Session.reset(server.target)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4MB",
+                                 server.target, 1 << 20, 50000)
+    session.Session.reset(server.target)
+    self._benchmarkRunOp("benchmark_session_runop_grpc", server.target, 50000)
+    session.Session.reset(server.target)
+    self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_grpc",
+                                 server.target, 100000)
+    session.Session.reset(server.target)
 
   def benchmarkDirectSession(self):
-    self._benchmarkFeed("benchmark_session_feed_direct_4B", "", 1, 5000)
-    self._benchmarkFeed("benchmark_session_feed_direct_4MB", "", 1 << 20, 200)
-    self._benchmarkFetch("benchmark_session_fetch_direct_4B", "", 1, 5000)
-    self._benchmarkFetch("benchmark_session_fetch_direct_4MB", "", 1 << 20, 100)
+    self._benchmarkFeed("benchmark_session_feed_direct_4B", "", 1, 80000)
+    self._benchmarkFeed("benchmark_session_feed_direct_4MB", "", 1 << 20, 20000)
+    self._benchmarkFetch("benchmark_session_fetch_direct_4B", "", 1, 100000)
+    self._benchmarkFetch("benchmark_session_fetch_direct_4MB", "", 1 << 20,
+                         20000)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4B",
+                                 "", 1, 200000)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4MB",
+                                 "", 1 << 20, 200000)
+    self._benchmarkRunOp("benchmark_session_runop_direct", "", 200000)
+    self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_direct", "",
+                                 200000)
 
 if __name__ == "__main__":


@@ -193,6 +193,12 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertEqual(42.0, res)
       res = sess.run(a.op)  # An op, not a tensor.
       self.assertEqual(None, res)
+      tensor_runner = sess.make_callable(a)
+      res = tensor_runner()
+      self.assertEqual(42.0, res)
+      op_runner = sess.make_callable(a.op)
+      res = op_runner()
+      self.assertEqual(None, res)
 
   def testFetchSingletonByName(self):
     with session.Session() as sess:
@@ -211,12 +217,11 @@ class SessionTest(test_util.TensorFlowTestCase):
       assign = v.assign([63.0])
       res = sess.run([a, b, c, a.name, assign.op])
       self.assertTrue(isinstance(res, list))
-      self.assertEqual(42.0, res[0])
-      self.assertEqual(None, res[1])
-      self.assertEqual(44.0, res[2])
-      self.assertEqual(42.0, res[3])
-      self.assertEqual(None, res[4])
-      self.assertEqual(63.0, sess.run(v))
+      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
+      list_runner = sess.make_callable([a, b, c, a.name, assign.op])
+      res = list_runner()
+      self.assertTrue(isinstance(res, list))
+      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
 
   def testFetchTuple(self):
     with session.Session() as sess:
@@ -225,10 +230,11 @@
       c = constant_op.constant(44.0)
       res = sess.run((a, b, c, a.name))
       self.assertTrue(isinstance(res, tuple))
-      self.assertEqual(42.0, res[0])
-      self.assertEqual(None, res[1])
-      self.assertEqual(44.0, res[2])
-      self.assertEqual(42.0, res[3])
+      self.assertEqual((42.0, None, 44.0, 42.0), res)
+      tuple_runner = sess.make_callable((a, b, c, a.name))
+      res = tuple_runner()
+      self.assertTrue(isinstance(res, tuple))
+      self.assertEqual((42.0, None, 44.0, 42.0), res)
 
   def testFetchNamedTuple(self):
     # pylint: disable=invalid-name
@@ -243,6 +249,12 @@
       self.assertEqual(42.0, res.a)
       self.assertEqual(None, res.b)
      self.assertEqual(44.0, res.c)
+      namedtuple_runner = sess.make_callable(ABC(a, b, c))
+      res = namedtuple_runner()
+      self.assertTrue(isinstance(res, ABC))
+      self.assertEqual(42.0, res.a)
+      self.assertEqual(None, res.b)
+      self.assertEqual(44.0, res.c)
 
   def testFetchDict(self):
     with session.Session() as sess:
@@ -1181,6 +1193,11 @@
          self.assertAllEqual(np_array, out_v)
          self.assertAllEqual(np_array, feed_v)
+          feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
+          out_v, feed_v = feed_fetch_runner(np_array)
+          self.assertAllEqual(np_array, out_v)
+          self.assertAllEqual(np_array, feed_v)
 
   def testFeedError(self):
     with session.Session() as sess:
       feed_t = array_ops.placeholder(dtype=dtypes.float32)


@@ -227,11 +227,14 @@ class QueueRunner(object):
     """
     decremented = False
     try:
+      # Make a cached callable from the `enqueue_op` to decrease the
+      # Python overhead in the queue-runner loop.
+      enqueue_callable = sess.make_callable(enqueue_op)
       while True:
         if coord and coord.should_stop():
           break
         try:
-          sess.run(enqueue_op)
+          enqueue_callable()
         except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
           # This exception indicates that a queue was closed.
           with self._lock:
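
The same hoisting applies to any hot loop that repeatedly runs one fixed op. A sketch with a hypothetical queue, not taken from this commit:

    import tensorflow as tf

    q = tf.FIFOQueue(capacity=10, dtypes=[tf.float32])
    enqueue_op = q.enqueue([tf.random_normal([])])

    with tf.Session() as sess:
      # As in the QueueRunner change above: build the callable once,
      # outside the loop, then invoke it on every iteration.
      enqueue_callable = sess.make_callable(enqueue_op)
      for _ in range(10):
        enqueue_callable()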


@@ -28,6 +28,10 @@ tf_class {
     name: "close"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "make_callable"
+    argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "partial_run"
     argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], "


@@ -28,6 +28,10 @@ tf_class {
     name: "close"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "make_callable"
+    argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "partial_run"
     argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], "