Adds tf.Session.make_callable()
.
This method returns a Python callable that has the same semantics as `tf.Session.run()`, but can cache some of the work that must be done to map Tensor-like objects to the arguments of the underlying C API function. The initial implementation is optimized for single-`Tensor` and single-`Operation` fetches, and delegates to `tf.Session.run()` for handling feeds. Since most queue runners use a single-`Operation` `run()` call, switch the `tf.train.QueueRunner` implementation to use `make_callable()`. Using this new interface can improve the latency of small steps (measurements from my workstation): * The median time to fetch a 4-byte tensor decreases from 99us to 52us (-47us). * The median time to run a trivial op decreases from 80us to 31us (-49us). Change: 152757301
This commit is contained in:
parent
d34eec7ec3
commit
c0095c9709
tensorflow
python
tools/api/golden
@ -996,6 +996,91 @@ class BaseSession(SessionInterface):
|
||||
results = []
|
||||
return fetch_handler.build_results(self, results)
|
||||
|
||||
def make_callable(self, fetches, feed_list=None):
  """Returns a Python callable that runs a particular step.

  The returned callable takes `len(feed_list)` positional arguments, whose
  types must be compatible feed values for the respective elements of
  `feed_list`. For example, if element `i` of `feed_list` is a `tf.Tensor`,
  the `i`th argument to the returned callable must be a numpy ndarray (or
  something convertible to an ndarray) with matching element type and
  shape. See @{tf.Session.run} for details of the allowable feed key and
  value types.

  The returned callable has the same return type as
  `tf.Session.run(fetches, ...)`: e.g. a numpy ndarray when `fetches` is a
  `tf.Tensor`, or `None` when `fetches` is a `tf.Operation`.

  Args:
    fetches: A value or list of values to fetch. See @{tf.Session.run}
      for details of the allowable fetch types.
    feed_list: (Optional.) A list of `feed_dict` keys. See
      @{tf.Session.run} for details of the allowable feed key types.

  Returns:
    A function that when called will execute the step defined by
    `feed_list` and `fetches` in this session.

  Raises:
    TypeError: If `fetches` or `feed_list` cannot be interpreted
      as arguments to @{tf.Session.run}.
  """
  if feed_list is not None:
    if not isinstance(feed_list, (list, tuple)):
      raise TypeError('`feed_list` must be a list or tuple.')

    # Delegate any non-empty feed lists to the existing `run()` logic.
    # TODO(mrry): Refactor the feed handling logic from
    # `Session._run()` so that we can convert the feeds to a list of
    # strings here.
    def _generic_run(*feed_args):
      return self.run(fetches, feed_dict=dict(zip(feed_list, feed_args)))

    return _generic_run

  # Ensure any changes to the graph are reflected in the runtime.
  # Note that we don't need to do this on subsequent calls to the
  # returned object, because the arguments to `fetches` must already be
  # in the graph.
  self._extend_graph()

  # A fetch handler takes care of mapping the (possibly structured)
  # `fetches` argument to flat lists of fetched tensors and run targets.
  fetch_handler = _FetchHandler(self._graph, fetches, {})
  fetch_names = fetch_handler.fetches()
  target_names = fetch_handler.targets()

  if isinstance(fetches, ops.Operation):
    # Special case for fetching a single operation, because the
    # function will have no return value.
    assert not fetch_names
    assert len(target_names) == 1

    def _single_operation_run():
      with errors.raise_exception_on_not_ok_status() as status:
        tf_session.TF_Run(self._session, None, {}, [], target_names,
                          status, None)

    return _single_operation_run
  elif isinstance(fetches, ops.Tensor):
    # Special case for fetching a single tensor, because the
    # function can return the result of `TF_Run()` directly.
    assert len(fetch_names) == 1
    assert not target_names

    def _single_tensor_run():
      with errors.raise_exception_on_not_ok_status() as status:
        results = tf_session.TF_Run(self._session, None, {}, fetch_names,
                                    [], status, None)
      return results[0]

    return _single_tensor_run
  else:
    # In all other cases, we must use `fetch_handler` to build the
    # results for us.
    def _fetch_handler_run():
      with errors.raise_exception_on_not_ok_status() as status:
        results = tf_session.TF_Run(self._session, None, {}, fetch_names,
                                    target_names, status, None)
      return fetch_handler.build_results(self, results)

    return _fetch_handler_run
|
||||
|
||||
# Captures the name of a node in an error status.
|
||||
_NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
|
||||
|
||||
|
@ -92,26 +92,127 @@ class SessionBenchmark(test.Benchmark):
|
||||
print("%s %d %f" % (name, size, np.median(times)))
|
||||
self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
|
||||
|
||||
def _benchmarkFetchPrebuilt(self, name, target, size, iters):
  """Runs a microbenchmark to measure the cost of fetching a tensor.

  Reports the median cost of fetching a tensor of `size` * `sizeof(float)`
  bytes, using a callable prebuilt with `Session.make_callable()`.

  Args:
    name: A human-readable name for logging the output.
    target: The session target to use for the benchmark.
    size: The number of floating-point numbers to be fetched.
    iters: The number of iterations to perform.
  """
  elapsed = []
  with ops.Graph().as_default():
    # Using a variable (rather than a constant) prevents the fetch from
    # being constant-folded away.
    v = variables.Variable(random_ops.random_normal([size]))
    with session.Session(target) as sess:
      sess.run(v.initializer)
      runner = sess.make_callable(v)
      # Warm-up run.
      runner()
      for _ in xrange(iters):
        begin = time.time()
        runner()
        elapsed.append(time.time() - begin)
  median_time = np.median(elapsed)
  print("%s %d %f" % (name, size, median_time))
  self.report_benchmark(iters=1, wall_time=median_time, name=name)
|
||||
|
||||
def _benchmarkRunOp(self, name, target, iters):
  """Runs a microbenchmark to measure the cost of running an op.

  Reports the median cost of a `Session.run()` call that runs a trivial
  (Variable) op.

  Args:
    name: A human-readable name for logging the output.
    target: The session target to use for the benchmark.
    iters: The number of iterations to perform.
  """
  elapsed = []
  with ops.Graph().as_default():
    # Using a variable (rather than a constant) prevents the op from
    # being constant-folded away.
    v = variables.Variable(random_ops.random_normal([]))
    with session.Session(target) as sess:
      sess.run(v.initializer)
      # Warm-up run.
      sess.run(v.op)
      for _ in xrange(iters):
        begin = time.time()
        sess.run(v.op)
        elapsed.append(time.time() - begin)
  median_time = np.median(elapsed)
  print("%s %f" % (name, median_time))
  self.report_benchmark(iters=1, wall_time=median_time, name=name)
|
||||
|
||||
def _benchmarkRunOpPrebuilt(self, name, target, iters):
  """Runs a microbenchmark to measure the cost of running an op.

  Reports the median cost of running a trivial (Variable) op, using a
  callable prebuilt with `Session.make_callable()`.

  Args:
    name: A human-readable name for logging the output.
    target: The session target to use for the benchmark.
    iters: The number of iterations to perform.
  """
  elapsed = []
  with ops.Graph().as_default():
    # Using a variable (rather than a constant) prevents the op from
    # being constant-folded away.
    v = variables.Variable(random_ops.random_normal([]))
    with session.Session(target) as sess:
      sess.run(v.initializer)
      runner = sess.make_callable(v.op)
      # Warm-up run.
      runner()
      for _ in xrange(iters):
        begin = time.time()
        runner()
        elapsed.append(time.time() - begin)
  median_time = np.median(elapsed)
  print("%s %f" % (name, median_time))
  self.report_benchmark(iters=1, wall_time=median_time, name=name)
|
||||
|
||||
def benchmarkGrpcSession(self):
  """Benchmarks feed/fetch/run latency against a local gRPC server.

  NOTE(review): the scraped source contained BOTH the pre-change and
  post-change versions of several calls (e.g. the 4B feed benchmark with
  10000 and 30000 iterations, and duplicated 4MB feed/fetch calls with
  conflicting iteration counts). This body keeps only the updated calls.
  The session is reset between benchmarks so each one starts from a
  fresh server-side session.
  """
  server = server_lib.Server.create_local_server()
  self._benchmarkFeed("benchmark_session_feed_grpc_4B", server.target, 1,
                      30000)
  session.Session.reset(server.target)
  self._benchmarkFeed("benchmark_session_feed_grpc_4MB", server.target,
                      1 << 20, 25000)
  session.Session.reset(server.target)
  self._benchmarkFetch("benchmark_session_fetch_grpc_4B", server.target, 1,
                       40000)
  session.Session.reset(server.target)
  self._benchmarkFetch("benchmark_session_fetch_grpc_4MB", server.target,
                       1 << 20, 20000)
  session.Session.reset(server.target)
  self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4B",
                               server.target, 1, 50000)
  session.Session.reset(server.target)
  self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4MB",
                               server.target, 1 << 20, 50000)
  session.Session.reset(server.target)
  self._benchmarkRunOp("benchmark_session_runop_grpc", server.target, 50000)
  session.Session.reset(server.target)
  self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_grpc",
                               server.target, 100000)
  session.Session.reset(server.target)
|
||||
|
||||
def benchmarkDirectSession(self):
  """Benchmarks feed/fetch/run latency against an in-process session.

  NOTE(review): the scraped source contained BOTH the pre-change calls
  (5000/200/5000/100 iterations) and the post-change calls
  (80000/20000/100000/20000 iterations plus the new prebuilt-callable
  benchmarks). This body keeps only the updated calls.
  """
  self._benchmarkFeed("benchmark_session_feed_direct_4B", "", 1, 80000)
  self._benchmarkFeed("benchmark_session_feed_direct_4MB", "", 1 << 20, 20000)
  self._benchmarkFetch("benchmark_session_fetch_direct_4B", "", 1, 100000)
  self._benchmarkFetch("benchmark_session_fetch_direct_4MB", "", 1 << 20,
                       20000)
  self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4B",
                               "", 1, 200000)
  self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4MB",
                               "", 1 << 20, 200000)
  self._benchmarkRunOp("benchmark_session_runop_direct", "", 200000)
  self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_direct", "",
                               200000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -193,6 +193,12 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(42.0, res)
|
||||
res = sess.run(a.op) # An op, not a tensor.
|
||||
self.assertEqual(None, res)
|
||||
tensor_runner = sess.make_callable(a)
|
||||
res = tensor_runner()
|
||||
self.assertEqual(42.0, res)
|
||||
op_runner = sess.make_callable(a.op)
|
||||
res = op_runner()
|
||||
self.assertEqual(None, res)
|
||||
|
||||
def testFetchSingletonByName(self):
|
||||
with session.Session() as sess:
|
||||
@ -211,12 +217,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
assign = v.assign([63.0])
|
||||
res = sess.run([a, b, c, a.name, assign.op])
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertEqual(42.0, res[0])
|
||||
self.assertEqual(None, res[1])
|
||||
self.assertEqual(44.0, res[2])
|
||||
self.assertEqual(42.0, res[3])
|
||||
self.assertEqual(None, res[4])
|
||||
self.assertEqual(63.0, sess.run(v))
|
||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||
list_runner = sess.make_callable([a, b, c, a.name, assign.op])
|
||||
res = list_runner()
|
||||
self.assertTrue(isinstance(res, list))
|
||||
self.assertEqual([42.0, None, 44.0, 42.0, None], res)
|
||||
|
||||
def testFetchTuple(self):
|
||||
with session.Session() as sess:
|
||||
@ -225,10 +230,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
c = constant_op.constant(44.0)
|
||||
res = sess.run((a, b, c, a.name))
|
||||
self.assertTrue(isinstance(res, tuple))
|
||||
self.assertEqual(42.0, res[0])
|
||||
self.assertEqual(None, res[1])
|
||||
self.assertEqual(44.0, res[2])
|
||||
self.assertEqual(42.0, res[3])
|
||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||
tuple_runner = sess.make_callable((a, b, c, a.name))
|
||||
res = tuple_runner()
|
||||
self.assertTrue(isinstance(res, tuple))
|
||||
self.assertEqual((42.0, None, 44.0, 42.0), res)
|
||||
|
||||
def testFetchNamedTuple(self):
|
||||
# pylint: disable=invalid-name
|
||||
@ -243,6 +249,12 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(42.0, res.a)
|
||||
self.assertEqual(None, res.b)
|
||||
self.assertEqual(44.0, res.c)
|
||||
namedtuple_runner = sess.make_callable(ABC(a, b, c))
|
||||
res = namedtuple_runner()
|
||||
self.assertTrue(isinstance(res, ABC))
|
||||
self.assertEqual(42.0, res.a)
|
||||
self.assertEqual(None, res.b)
|
||||
self.assertEqual(44.0, res.c)
|
||||
|
||||
def testFetchDict(self):
|
||||
with session.Session() as sess:
|
||||
@ -1181,6 +1193,11 @@ class SessionTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllEqual(np_array, out_v)
|
||||
self.assertAllEqual(np_array, feed_v)
|
||||
|
||||
feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
|
||||
out_v, feed_v = feed_fetch_runner(np_array)
|
||||
self.assertAllEqual(np_array, out_v)
|
||||
self.assertAllEqual(np_array, feed_v)
|
||||
|
||||
def testFeedError(self):
|
||||
with session.Session() as sess:
|
||||
feed_t = array_ops.placeholder(dtype=dtypes.float32)
|
||||
|
@ -227,11 +227,14 @@ class QueueRunner(object):
|
||||
"""
|
||||
decremented = False
|
||||
try:
|
||||
# Make a cached callable from the `enqueue_op` to decrease the
|
||||
# Python overhead in the queue-runner loop.
|
||||
enqueue_callable = sess.make_callable(enqueue_op)
|
||||
while True:
|
||||
if coord and coord.should_stop():
|
||||
break
|
||||
try:
|
||||
sess.run(enqueue_op)
|
||||
enqueue_callable()
|
||||
except self._queue_closed_exception_types: # pylint: disable=catching-non-exception
|
||||
# This exception indicates that a queue was closed.
|
||||
with self._lock:
|
||||
|
@ -28,6 +28,10 @@ tf_class {
|
||||
name: "close"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_callable"
|
||||
argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "partial_run"
|
||||
argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -28,6 +28,10 @@ tf_class {
|
||||
name: "close"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_callable"
|
||||
argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "partial_run"
|
||||
argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user