Adds tf.Session.make_callable().

This method returns a Python callable that has the same semantics as
`tf.Session.run()`, but can cache some of the work that must be done to map
Tensor-like objects to the arguments of the underlying C API function. The
initial implementation is optimized for single-`Tensor` and single-`Operation`
fetches, and delegates to `tf.Session.run()` for handling feeds. Since most
queue runners use a single-`Operation` `run()` call, this change also switches
the `tf.train.QueueRunner` implementation to use `make_callable()`.

Using this new interface can improve the latency of small steps (measurements
from my workstation):

* The median time to fetch a 4-byte tensor decreases from 99us to 52us (-47us).
* The median time to run a trivial op decreases from 80us to 31us (-49us).

Change: 152757301
This commit is contained in:
parent d34eec7ec3
commit c0095c9709
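For orientation, here is a minimal usage sketch of the API described in the commit message (illustrative TF 1.x graph-mode code, not part of this diff; the variable and callable names are examples): build the callable once, then invoke it repeatedly instead of paying the `Session.run()` fetch-mapping cost on every step.

import tensorflow as tf

# Illustrative graph: a variable and an op that increments it.
v = tf.Variable(0.0)
inc = tf.assign_add(v, 1.0)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Build the callables once; the mapping from fetches to the C API arguments
  # is cached, as described in the commit message.
  run_inc = sess.make_callable(inc.op)    # Operation fetch: returns None.
  read_v = sess.make_callable(v.value())  # Tensor fetch: returns an ndarray.
  for _ in range(1000):
    run_inc()                             # Cheaper per step than sess.run(inc.op).
  print(read_v())                         # 1000.0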
@@ -996,6 +996,91 @@ class BaseSession(SessionInterface):
       results = []
     return fetch_handler.build_results(self, results)
 
+  def make_callable(self, fetches, feed_list=None):
+    """Returns a Python callable that runs a particular step.
+
+    The returned callable will take `len(feed_list)` arguments whose types
+    must be compatible feed values for the respective elements of `feed_list`.
+    For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
+    argument to the returned callable must be a numpy ndarray (or something
+    convertible to an ndarray) with matching element type and shape. See
+    @{tf.Session.run} for details of the allowable feed key and value types.
+
+    The returned callable will have the same return type as
+    `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
+    the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
+    it will return `None`.
+
+    Args:
+      fetches: A value or list of values to fetch. See @{tf.Session.run}
+        for details of the allowable fetch types.
+      feed_list: (Optional.) A list of `feed_dict` keys. See
+        @{tf.Session.run} for details of the allowable feed key types.
+
+    Returns:
+      A function that when called will execute the step defined by
+      `feed_list` and `fetches` in this session.
+
+    Raises:
+      TypeError: If `fetches` or `feed_list` cannot be interpreted
+        as arguments to @{tf.Session.run}.
+    """
+    if feed_list is not None:
+      if not isinstance(feed_list, (list, tuple)):
+        raise TypeError('`feed_list` must be a list or tuple.')
+      # Delegate any non-empty feed lists to the existing `run()` logic.
+      # TODO(mrry): Refactor the feed handling logic from
+      # `Session._run()` so that we can convert the feeds to a list of
+      # strings here.
+      def _generic_run(*feed_args):
+        feed_dict = {feed: feed_val
+                     for feed, feed_val in zip(feed_list, feed_args)}
+        return self.run(fetches, feed_dict=feed_dict)
+      return _generic_run
+
+    # Ensure any changes to the graph are reflected in the runtime.
+    # Note that we don't need to do this on subsequent calls to the
+    # returned object, because the arguments to `fetches` must already be
+    # in the graph.
+    self._extend_graph()
+
+    # Create a fetch handler to take care of the structure of fetches.
+    fetch_handler = _FetchHandler(self._graph, fetches, {})
+    fetch_list_as_strings = fetch_handler.fetches()
+    target_list_as_strings = fetch_handler.targets()
+
+    if isinstance(fetches, ops.Operation):
+      # Special case for fetching a single operation, because the
+      # function will have no return value.
+      assert not fetch_list_as_strings
+      assert len(target_list_as_strings) == 1
+      def _single_operation_run():
+        with errors.raise_exception_on_not_ok_status() as status:
+          tf_session.TF_Run(self._session, None, {}, [],
+                            target_list_as_strings, status, None)
+      return _single_operation_run
+    elif isinstance(fetches, ops.Tensor):
+      # Special case for fetching a single tensor, because the
+      # function can return the result of `TF_Run()` directly.
+      assert len(fetch_list_as_strings) == 1
+      assert not target_list_as_strings
+      def _single_tensor_run():
+        with errors.raise_exception_on_not_ok_status() as status:
+          results = tf_session.TF_Run(self._session, None, {},
+                                      fetch_list_as_strings, [], status, None)
+        return results[0]
+      return _single_tensor_run
+    else:
+      # In all other cases, we must use `fetch_handler` to build the
+      # results for us.
+      def _fetch_handler_run():
+        with errors.raise_exception_on_not_ok_status() as status:
+          results = tf_session.TF_Run(self._session, None, {},
+                                      fetch_list_as_strings,
+                                      target_list_as_strings, status, None)
+        return fetch_handler.build_results(self, results)
+      return _fetch_handler_run
+
   # Captures the name of a node in an error status.
   _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
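For reference, a small sketch of the `feed_list` path documented in the docstring above (illustrative names, not part of the diff). Note that, per the comment in the implementation, a non-empty `feed_list` currently delegates to `Session.run()`, so this form is a convenience rather than a cached fast path in this initial version.

import numpy as np
import tensorflow as tf

# Illustrative graph: one placeholder feed and one fetched tensor.
x = tf.placeholder(tf.float32, shape=[None])
y = x * 2.0

with tf.Session() as sess:
  # The callable takes one positional argument per element of `feed_list`.
  double = sess.make_callable(y, feed_list=[x])
  print(double(np.array([1.0, 2.0], dtype=np.float32)))  # => [2. 4.]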
@@ -92,26 +92,127 @@ class SessionBenchmark(test.Benchmark):
     print("%s %d %f" % (name, size, np.median(times)))
     self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
 
+  def _benchmarkFetchPrebuilt(self, name, target, size, iters):
+    """Runs a microbenchmark to measure the cost of fetching a tensor.
+
+    Reports the median cost of fetching a tensor of `size` * `sizeof(float)`
+    bytes.
+
+    Args:
+      name: A human-readable name for logging the output.
+      target: The session target to use for the benchmark.
+      size: The number of floating-point numbers to be fetched.
+      iters: The number of iterations to perform.
+    """
+    times = []
+    with ops.Graph().as_default():
+      # Define the tensor to be fetched as a variable, to avoid
+      # constant-folding.
+      v = variables.Variable(random_ops.random_normal([size]))
+      with session.Session(target) as sess:
+        sess.run(v.initializer)
+        runner = sess.make_callable(v)
+        runner()  # Warm-up run.
+        for _ in xrange(iters):
+          start_time = time.time()
+          runner()
+          end_time = time.time()
+          times.append(end_time - start_time)
+    print("%s %d %f" % (name, size, np.median(times)))
+    self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
+
+  def _benchmarkRunOp(self, name, target, iters):
+    """Runs a microbenchmark to measure the cost of running an op.
+
+    Reports the median cost of running a trivial (Variable) op.
+
+    Args:
+      name: A human-readable name for logging the output.
+      target: The session target to use for the benchmark.
+      iters: The number of iterations to perform.
+    """
+    times = []
+    with ops.Graph().as_default():
+      # Define the op to be run as a variable, to avoid
+      # constant-folding.
+      v = variables.Variable(random_ops.random_normal([]))
+      with session.Session(target) as sess:
+        sess.run(v.initializer)
+        sess.run(v.op)  # Warm-up run.
+        for _ in xrange(iters):
+          start_time = time.time()
+          sess.run(v.op)
+          end_time = time.time()
+          times.append(end_time - start_time)
+    print("%s %f" % (name, np.median(times)))
+    self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
+
+  def _benchmarkRunOpPrebuilt(self, name, target, iters):
+    """Runs a microbenchmark to measure the cost of running an op.
+
+    Reports the median cost of running a trivial (Variable) op.
+
+    Args:
+      name: A human-readable name for logging the output.
+      target: The session target to use for the benchmark.
+      iters: The number of iterations to perform.
+    """
+    times = []
+    with ops.Graph().as_default():
+      # Define the op to be run as a variable, to avoid
+      # constant-folding.
+      v = variables.Variable(random_ops.random_normal([]))
+      with session.Session(target) as sess:
+        sess.run(v.initializer)
+        runner = sess.make_callable(v.op)
+        runner()  # Warm-up run.
+        for _ in xrange(iters):
+          start_time = time.time()
+          runner()
+          end_time = time.time()
+          times.append(end_time - start_time)
+    print("%s %f" % (name, np.median(times)))
+    self.report_benchmark(iters=1, wall_time=np.median(times), name=name)
+
   def benchmarkGrpcSession(self):
     server = server_lib.Server.create_local_server()
     self._benchmarkFeed("benchmark_session_feed_grpc_4B", server.target, 1,
-                        10000)
+                        30000)
     session.Session.reset(server.target)
-    self._benchmarkFeed("benchmark_session_feed_grpc_4MB", server.target, 1
-                        << 20, 100)
+    self._benchmarkFeed("benchmark_session_feed_grpc_4MB", server.target,
+                        1 << 20, 25000)
     session.Session.reset(server.target)
     self._benchmarkFetch("benchmark_session_fetch_grpc_4B", server.target, 1,
-                         20000)
+                         40000)
     session.Session.reset(server.target)
-    self._benchmarkFetch("benchmark_session_fetch_grpc_4MB", server.target, 1
-                         << 20, 100)
+    self._benchmarkFetch("benchmark_session_fetch_grpc_4MB", server.target,
+                         1 << 20, 20000)
+    session.Session.reset(server.target)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4B",
+                                 server.target, 1, 50000)
+    session.Session.reset(server.target)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4MB",
+                                 server.target, 1 << 20, 50000)
+    session.Session.reset(server.target)
+    self._benchmarkRunOp("benchmark_session_runop_grpc", server.target, 50000)
+    session.Session.reset(server.target)
+    self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_grpc",
+                                 server.target, 100000)
     session.Session.reset(server.target)
 
   def benchmarkDirectSession(self):
-    self._benchmarkFeed("benchmark_session_feed_direct_4B", "", 1, 5000)
-    self._benchmarkFeed("benchmark_session_feed_direct_4MB", "", 1 << 20, 200)
-    self._benchmarkFetch("benchmark_session_fetch_direct_4B", "", 1, 5000)
-    self._benchmarkFetch("benchmark_session_fetch_direct_4MB", "", 1 << 20, 100)
+    self._benchmarkFeed("benchmark_session_feed_direct_4B", "", 1, 80000)
+    self._benchmarkFeed("benchmark_session_feed_direct_4MB", "", 1 << 20, 20000)
+    self._benchmarkFetch("benchmark_session_fetch_direct_4B", "", 1, 100000)
+    self._benchmarkFetch("benchmark_session_fetch_direct_4MB", "", 1 << 20,
+                         20000)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4B",
+                                 "", 1, 200000)
+    self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4MB",
+                                 "", 1 << 20, 200000)
+    self._benchmarkRunOp("benchmark_session_runop_direct", "", 200000)
+    self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_direct", "",
+                                 200000)
 
 
 if __name__ == "__main__":
@@ -193,6 +193,12 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertEqual(42.0, res)
       res = sess.run(a.op)  # An op, not a tensor.
       self.assertEqual(None, res)
+      tensor_runner = sess.make_callable(a)
+      res = tensor_runner()
+      self.assertEqual(42.0, res)
+      op_runner = sess.make_callable(a.op)
+      res = op_runner()
+      self.assertEqual(None, res)
 
   def testFetchSingletonByName(self):
     with session.Session() as sess:
@@ -211,12 +217,11 @@ class SessionTest(test_util.TensorFlowTestCase):
       assign = v.assign([63.0])
       res = sess.run([a, b, c, a.name, assign.op])
       self.assertTrue(isinstance(res, list))
-      self.assertEqual(42.0, res[0])
-      self.assertEqual(None, res[1])
-      self.assertEqual(44.0, res[2])
-      self.assertEqual(42.0, res[3])
-      self.assertEqual(None, res[4])
-      self.assertEqual(63.0, sess.run(v))
+      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
+      list_runner = sess.make_callable([a, b, c, a.name, assign.op])
+      res = list_runner()
+      self.assertTrue(isinstance(res, list))
+      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
 
   def testFetchTuple(self):
     with session.Session() as sess:
@@ -225,10 +230,11 @@ class SessionTest(test_util.TensorFlowTestCase):
       c = constant_op.constant(44.0)
       res = sess.run((a, b, c, a.name))
       self.assertTrue(isinstance(res, tuple))
-      self.assertEqual(42.0, res[0])
-      self.assertEqual(None, res[1])
-      self.assertEqual(44.0, res[2])
-      self.assertEqual(42.0, res[3])
+      self.assertEqual((42.0, None, 44.0, 42.0), res)
+      tuple_runner = sess.make_callable((a, b, c, a.name))
+      res = tuple_runner()
+      self.assertTrue(isinstance(res, tuple))
+      self.assertEqual((42.0, None, 44.0, 42.0), res)
 
   def testFetchNamedTuple(self):
     # pylint: disable=invalid-name
@@ -243,6 +249,12 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertEqual(42.0, res.a)
       self.assertEqual(None, res.b)
      self.assertEqual(44.0, res.c)
+      namedtuple_runner = sess.make_callable(ABC(a, b, c))
+      res = namedtuple_runner()
+      self.assertTrue(isinstance(res, ABC))
+      self.assertEqual(42.0, res.a)
+      self.assertEqual(None, res.b)
+      self.assertEqual(44.0, res.c)
 
   def testFetchDict(self):
     with session.Session() as sess:
@@ -1181,6 +1193,11 @@ class SessionTest(test_util.TensorFlowTestCase):
       self.assertAllEqual(np_array, out_v)
       self.assertAllEqual(np_array, feed_v)
+
+      feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
+      out_v, feed_v = feed_fetch_runner(np_array)
+      self.assertAllEqual(np_array, out_v)
+      self.assertAllEqual(np_array, feed_v)
 
   def testFeedError(self):
     with session.Session() as sess:
       feed_t = array_ops.placeholder(dtype=dtypes.float32)
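As the tests above exercise, structured fetches (lists, tuples, namedtuples) go through the generic fetch handler and keep their structure when returned from the callable. A small illustrative sketch (names are examples, not part of the diff):

import collections
import tensorflow as tf

Pair = collections.namedtuple('Pair', ['a', 'b'])  # illustrative namedtuple

a = tf.constant(42.0)
b = tf.constant(44.0)

with tf.Session() as sess:
  pair_runner = sess.make_callable(Pair(a, b))
  res = pair_runner()                 # Returns a Pair of numpy values.
  print(res.a, res.b)                 # 42.0 44.0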
@@ -227,11 +227,14 @@ class QueueRunner(object):
     """
     decremented = False
     try:
+      # Make a cached callable from the `enqueue_op` to decrease the
+      # Python overhead in the queue-runner loop.
+      enqueue_callable = sess.make_callable(enqueue_op)
       while True:
         if coord and coord.should_stop():
           break
         try:
-          sess.run(enqueue_op)
+          enqueue_callable()
         except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
           # This exception indicates that a queue was closed.
           with self._lock:
@@ -28,6 +28,10 @@ tf_class {
     name: "close"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "make_callable"
+    argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "partial_run"
     argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -28,6 +28,10 @@ tf_class {
     name: "close"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "make_callable"
+    argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "partial_run"
     argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], "