Adds tf.Session.make_callable().
				
					
				
This method returns a Python callable that has the same semantics as `tf.Session.run()`, but can cache some of the work that must be done to map Tensor-like objects to the arguments of the underlying C API function. The initial implementation is optimized for single-`Tensor` and single-`Operation` fetches, and delegates to `tf.Session.run()` for handling feeds. Since most queue runners use a single-`Operation` `run()` call, switch the `tf.train.QueueRunner` implementation to use `make_callable()`. Using this new interface can improve the latency of small steps (measurements from my workstation): * The median time to fetch a 4-byte tensor decreases from 99us to 52us (-47us). * The median time to run a trivial op decreases from 80us to 31us (-49us). Change: 152757301
This commit is contained in:
		
							parent
							
								
									d34eec7ec3
								
							
						
					
					
						commit
						c0095c9709
					
				| @ -996,6 +996,91 @@ class BaseSession(SessionInterface): | ||||
|       results = [] | ||||
|     return fetch_handler.build_results(self, results) | ||||
| 
 | ||||
|   def make_callable(self, fetches, feed_list=None): | ||||
|     """Returns a Python callable that runs a particular step. | ||||
| 
 | ||||
|     The returned callable will take `len(feed_list)` arguments whose types | ||||
|     must be compatible feed values for the respective elements of `feed_list`. | ||||
|     For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th | ||||
|     argument to the returned callable must be a numpy ndarray (or something | ||||
|     convertible to an ndarray) with matching element type and shape. See | ||||
|     @{tf.Session.run} for details of the allowable feed key and value types. | ||||
| 
 | ||||
|     The returned callable will have the same return type as | ||||
|     `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`, | ||||
|     the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`, | ||||
|     it will return `None`. | ||||
| 
 | ||||
|     Args: | ||||
|       fetches: A value or list of values to fetch. See @{tf.Session.run} | ||||
|         for details of the allowable fetch types. | ||||
|       feed_list: (Optional.) A list of `feed_dict` keys. See | ||||
|         @{tf.Session.run} for details of the allowable feed key types. | ||||
| 
 | ||||
|     Returns: | ||||
|       A function that when called will execute the step defined by | ||||
|       `feed_list` and `fetches` in this session. | ||||
| 
 | ||||
|     Raises: | ||||
|       TypeError: If `fetches` or `feed_list` cannot be interpreted | ||||
|         as arguments to @{tf.Session.run}. | ||||
|     """ | ||||
|     if feed_list is not None: | ||||
|       if not isinstance(feed_list, (list, tuple)): | ||||
|         raise TypeError('`feed_list` must be a list or tuple.') | ||||
|       # Delegate any non-empty feed lists to the existing `run()` logic. | ||||
|       # TODO(mrry): Refactor the feed handling logic from | ||||
|       # `Session._run()` so that we can convert the feeds to a list of | ||||
|       # strings here. | ||||
|       def _generic_run(*feed_args): | ||||
|         feed_dict = {feed: feed_val | ||||
|                      for feed, feed_val in zip(feed_list, feed_args)} | ||||
|         return self.run(fetches, feed_dict=feed_dict) | ||||
|       return _generic_run | ||||
| 
 | ||||
|     # Ensure any changes to the graph are reflected in the runtime. | ||||
|     # Note that we don't need to do this on subsequent calls to the | ||||
|     # returned object, because the arguments to `fetches` must already be | ||||
|     # in the graph. | ||||
|     self._extend_graph() | ||||
| 
 | ||||
|     # Create a fetch handler to take care of the structure of fetches. | ||||
|     fetch_handler = _FetchHandler(self._graph, fetches, {}) | ||||
|     fetch_list_as_strings = fetch_handler.fetches() | ||||
|     target_list_as_strings = fetch_handler.targets() | ||||
| 
 | ||||
|     if isinstance(fetches, ops.Operation): | ||||
|       # Special case for fetching a single operation, because the | ||||
|       # function will have no return value. | ||||
|       assert not fetch_list_as_strings | ||||
|       assert len(target_list_as_strings) == 1 | ||||
|       def _single_operation_run(): | ||||
|         with errors.raise_exception_on_not_ok_status() as status: | ||||
|           tf_session.TF_Run(self._session, None, {}, [], | ||||
|                             target_list_as_strings, status, None) | ||||
|       return _single_operation_run | ||||
|     elif isinstance(fetches, ops.Tensor): | ||||
|       # Special case for fetching a single tensor, because the | ||||
|       # function can return the result of `TF_Run()` directly. | ||||
|       assert len(fetch_list_as_strings) == 1 | ||||
|       assert not target_list_as_strings | ||||
|       def _single_tensor_run(): | ||||
|         with errors.raise_exception_on_not_ok_status() as status: | ||||
|           results = tf_session.TF_Run(self._session, None, {}, | ||||
|                                       fetch_list_as_strings, [], status, None) | ||||
|         return results[0] | ||||
|       return _single_tensor_run | ||||
|     else: | ||||
|       # In all other cases, we must use `fetch_handler` to build the | ||||
|       # results for us. | ||||
|       def _fetch_handler_run(): | ||||
|         with errors.raise_exception_on_not_ok_status() as status: | ||||
|           results = tf_session.TF_Run(self._session, None, {}, | ||||
|                                       fetch_list_as_strings, | ||||
|                                       target_list_as_strings, status, None) | ||||
|         return fetch_handler.build_results(self, results) | ||||
|       return _fetch_handler_run | ||||
| 
 | ||||
|   # Captures the name of a node in an error status. | ||||
|   _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =') | ||||
| 
 | ||||
|  | ||||
| @ -92,26 +92,127 @@ class SessionBenchmark(test.Benchmark): | ||||
|     print("%s %d %f" % (name, size, np.median(times))) | ||||
|     self.report_benchmark(iters=1, wall_time=np.median(times), name=name) | ||||
| 
 | ||||
|   def _benchmarkFetchPrebuilt(self, name, target, size, iters): | ||||
|     """Runs a microbenchmark to measure the cost of fetching a tensor. | ||||
| 
 | ||||
|     Reports the median cost of fetching a tensor of `size` * `sizeof(float)` | ||||
|     bytes. | ||||
| 
 | ||||
|     Args: | ||||
|       name: A human-readable name for logging the output. | ||||
|       target: The session target to use for the benchmark. | ||||
|       size: The number of floating-point numbers to be fetched. | ||||
|       iters: The number of iterations to perform. | ||||
|     """ | ||||
|     times = [] | ||||
|     with ops.Graph().as_default(): | ||||
|       # Define the tensor to be fetched as a variable, to avoid | ||||
|       # constant-folding. | ||||
|       v = variables.Variable(random_ops.random_normal([size])) | ||||
|       with session.Session(target) as sess: | ||||
|         sess.run(v.initializer) | ||||
|         runner = sess.make_callable(v) | ||||
|         runner()  # Warm-up run. | ||||
|         for _ in xrange(iters): | ||||
|           start_time = time.time() | ||||
|           runner() | ||||
|           end_time = time.time() | ||||
|           times.append(end_time - start_time) | ||||
|     print("%s %d %f" % (name, size, np.median(times))) | ||||
|     self.report_benchmark(iters=1, wall_time=np.median(times), name=name) | ||||
| 
 | ||||
|   def _benchmarkRunOp(self, name, target, iters): | ||||
|     """Runs a microbenchmark to measure the cost of running an op. | ||||
| 
 | ||||
|     Reports the median cost of running a trivial (Variable) op. | ||||
| 
 | ||||
|     Args: | ||||
|       name: A human-readable name for logging the output. | ||||
|       target: The session target to use for the benchmark. | ||||
|       iters: The number of iterations to perform. | ||||
|     """ | ||||
|     times = [] | ||||
|     with ops.Graph().as_default(): | ||||
|       # Define the op to be run as a variable, to avoid | ||||
|       # constant-folding. | ||||
|       v = variables.Variable(random_ops.random_normal([])) | ||||
|       with session.Session(target) as sess: | ||||
|         sess.run(v.initializer) | ||||
|         sess.run(v.op)  # Warm-up run. | ||||
|         for _ in xrange(iters): | ||||
|           start_time = time.time() | ||||
|           sess.run(v.op) | ||||
|           end_time = time.time() | ||||
|           times.append(end_time - start_time) | ||||
|     print("%s %f" % (name, np.median(times))) | ||||
|     self.report_benchmark(iters=1, wall_time=np.median(times), name=name) | ||||
| 
 | ||||
|   def _benchmarkRunOpPrebuilt(self, name, target, iters): | ||||
|     """Runs a microbenchmark to measure the cost of running an op. | ||||
| 
 | ||||
|     Reports the median cost of running a trivial (Variable) op. | ||||
| 
 | ||||
|     Args: | ||||
|       name: A human-readable name for logging the output. | ||||
|       target: The session target to use for the benchmark. | ||||
|       iters: The number of iterations to perform. | ||||
|     """ | ||||
|     times = [] | ||||
|     with ops.Graph().as_default(): | ||||
|       # Define the op to be run as a variable, to avoid | ||||
|       # constant-folding. | ||||
|       v = variables.Variable(random_ops.random_normal([])) | ||||
|       with session.Session(target) as sess: | ||||
|         sess.run(v.initializer) | ||||
|         runner = sess.make_callable(v.op) | ||||
|         runner()  # Warm-up run. | ||||
|         for _ in xrange(iters): | ||||
|           start_time = time.time() | ||||
|           runner() | ||||
|           end_time = time.time() | ||||
|           times.append(end_time - start_time) | ||||
|     print("%s %f" % (name, np.median(times))) | ||||
|     self.report_benchmark(iters=1, wall_time=np.median(times), name=name) | ||||
| 
 | ||||
|   def benchmarkGrpcSession(self): | ||||
|     server = server_lib.Server.create_local_server() | ||||
|     self._benchmarkFeed("benchmark_session_feed_grpc_4B", server.target, 1, | ||||
|                         10000) | ||||
|                         30000) | ||||
|     session.Session.reset(server.target) | ||||
|     self._benchmarkFeed("benchmark_session_feed_grpc_4MB", server.target, 1 | ||||
|                         << 20, 100) | ||||
|     self._benchmarkFeed("benchmark_session_feed_grpc_4MB", server.target, | ||||
|                         1 << 20, 25000) | ||||
|     session.Session.reset(server.target) | ||||
|     self._benchmarkFetch("benchmark_session_fetch_grpc_4B", server.target, 1, | ||||
|                          20000) | ||||
|                          40000) | ||||
|     session.Session.reset(server.target) | ||||
|     self._benchmarkFetch("benchmark_session_fetch_grpc_4MB", server.target, 1 | ||||
|                          << 20, 100) | ||||
|     self._benchmarkFetch("benchmark_session_fetch_grpc_4MB", server.target, | ||||
|                          1 << 20, 20000) | ||||
|     session.Session.reset(server.target) | ||||
|     self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4B", | ||||
|                                  server.target, 1, 50000) | ||||
|     session.Session.reset(server.target) | ||||
|     self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_grpc_4MB", | ||||
|                                  server.target, 1 << 20, 50000) | ||||
|     session.Session.reset(server.target) | ||||
|     self._benchmarkRunOp("benchmark_session_runop_grpc", server.target, 50000) | ||||
|     session.Session.reset(server.target) | ||||
|     self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_grpc", | ||||
|                                  server.target, 100000) | ||||
|     session.Session.reset(server.target) | ||||
| 
 | ||||
|   def benchmarkDirectSession(self): | ||||
|     self._benchmarkFeed("benchmark_session_feed_direct_4B", "", 1, 5000) | ||||
|     self._benchmarkFeed("benchmark_session_feed_direct_4MB", "", 1 << 20, 200) | ||||
|     self._benchmarkFetch("benchmark_session_fetch_direct_4B", "", 1, 5000) | ||||
|     self._benchmarkFetch("benchmark_session_fetch_direct_4MB", "", 1 << 20, 100) | ||||
|     self._benchmarkFeed("benchmark_session_feed_direct_4B", "", 1, 80000) | ||||
|     self._benchmarkFeed("benchmark_session_feed_direct_4MB", "", 1 << 20, 20000) | ||||
|     self._benchmarkFetch("benchmark_session_fetch_direct_4B", "", 1, 100000) | ||||
|     self._benchmarkFetch("benchmark_session_fetch_direct_4MB", "", 1 << 20, | ||||
|                          20000) | ||||
|     self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4B", | ||||
|                                  "", 1, 200000) | ||||
|     self._benchmarkFetchPrebuilt("benchmark_session_fetchprebuilt_direct_4MB", | ||||
|                                  "", 1 << 20, 200000) | ||||
|     self._benchmarkRunOp("benchmark_session_runop_direct", "", 200000) | ||||
|     self._benchmarkRunOpPrebuilt("benchmark_session_runopprebuilt_direct", "", | ||||
|                                  200000) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|  | ||||
| @ -193,6 +193,12 @@ class SessionTest(test_util.TensorFlowTestCase): | ||||
|       self.assertEqual(42.0, res) | ||||
|       res = sess.run(a.op)  # An op, not a tensor. | ||||
|       self.assertEqual(None, res) | ||||
|       tensor_runner = sess.make_callable(a) | ||||
|       res = tensor_runner() | ||||
|       self.assertEqual(42.0, res) | ||||
|       op_runner = sess.make_callable(a.op) | ||||
|       res = op_runner() | ||||
|       self.assertEqual(None, res) | ||||
| 
 | ||||
|   def testFetchSingletonByName(self): | ||||
|     with session.Session() as sess: | ||||
| @ -211,12 +217,11 @@ class SessionTest(test_util.TensorFlowTestCase): | ||||
|       assign = v.assign([63.0]) | ||||
|       res = sess.run([a, b, c, a.name, assign.op]) | ||||
|       self.assertTrue(isinstance(res, list)) | ||||
|       self.assertEqual(42.0, res[0]) | ||||
|       self.assertEqual(None, res[1]) | ||||
|       self.assertEqual(44.0, res[2]) | ||||
|       self.assertEqual(42.0, res[3]) | ||||
|       self.assertEqual(None, res[4]) | ||||
|       self.assertEqual(63.0, sess.run(v)) | ||||
|       self.assertEqual([42.0, None, 44.0, 42.0, None], res) | ||||
|       list_runner = sess.make_callable([a, b, c, a.name, assign.op]) | ||||
|       res = list_runner() | ||||
|       self.assertTrue(isinstance(res, list)) | ||||
|       self.assertEqual([42.0, None, 44.0, 42.0, None], res) | ||||
| 
 | ||||
|   def testFetchTuple(self): | ||||
|     with session.Session() as sess: | ||||
| @ -225,10 +230,11 @@ class SessionTest(test_util.TensorFlowTestCase): | ||||
|       c = constant_op.constant(44.0) | ||||
|       res = sess.run((a, b, c, a.name)) | ||||
|       self.assertTrue(isinstance(res, tuple)) | ||||
|       self.assertEqual(42.0, res[0]) | ||||
|       self.assertEqual(None, res[1]) | ||||
|       self.assertEqual(44.0, res[2]) | ||||
|       self.assertEqual(42.0, res[3]) | ||||
|       self.assertEqual((42.0, None, 44.0, 42.0), res) | ||||
|       tuple_runner = sess.make_callable((a, b, c, a.name)) | ||||
|       res = tuple_runner() | ||||
|       self.assertTrue(isinstance(res, tuple)) | ||||
|       self.assertEqual((42.0, None, 44.0, 42.0), res) | ||||
| 
 | ||||
|   def testFetchNamedTuple(self): | ||||
|     # pylint: disable=invalid-name | ||||
| @ -243,6 +249,12 @@ class SessionTest(test_util.TensorFlowTestCase): | ||||
|       self.assertEqual(42.0, res.a) | ||||
|       self.assertEqual(None, res.b) | ||||
|       self.assertEqual(44.0, res.c) | ||||
|       namedtuple_runner = sess.make_callable(ABC(a, b, c)) | ||||
|       res = namedtuple_runner() | ||||
|       self.assertTrue(isinstance(res, ABC)) | ||||
|       self.assertEqual(42.0, res.a) | ||||
|       self.assertEqual(None, res.b) | ||||
|       self.assertEqual(44.0, res.c) | ||||
| 
 | ||||
|   def testFetchDict(self): | ||||
|     with session.Session() as sess: | ||||
| @ -1181,6 +1193,11 @@ class SessionTest(test_util.TensorFlowTestCase): | ||||
|           self.assertAllEqual(np_array, out_v) | ||||
|           self.assertAllEqual(np_array, feed_v) | ||||
| 
 | ||||
|           feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t]) | ||||
|           out_v, feed_v = feed_fetch_runner(np_array) | ||||
|           self.assertAllEqual(np_array, out_v) | ||||
|           self.assertAllEqual(np_array, feed_v) | ||||
| 
 | ||||
|   def testFeedError(self): | ||||
|     with session.Session() as sess: | ||||
|       feed_t = array_ops.placeholder(dtype=dtypes.float32) | ||||
|  | ||||
| @ -227,11 +227,14 @@ class QueueRunner(object): | ||||
|     """ | ||||
|     decremented = False | ||||
|     try: | ||||
|       # Make a cached callable from the `enqueue_op` to decrease the | ||||
|       # Python overhead in the queue-runner loop. | ||||
|       enqueue_callable = sess.make_callable(enqueue_op) | ||||
|       while True: | ||||
|         if coord and coord.should_stop(): | ||||
|           break | ||||
|         try: | ||||
|           sess.run(enqueue_op) | ||||
|           enqueue_callable() | ||||
|         except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception | ||||
|           # This exception indicates that a queue was closed. | ||||
|           with self._lock: | ||||
|  | ||||
| @ -28,6 +28,10 @@ tf_class { | ||||
|     name: "close" | ||||
|     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" | ||||
|   } | ||||
|   member_method { | ||||
|     name: "make_callable" | ||||
|     argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " | ||||
|   } | ||||
|   member_method { | ||||
|     name: "partial_run" | ||||
|     argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " | ||||
|  | ||||
| @ -28,6 +28,10 @@ tf_class { | ||||
|     name: "close" | ||||
|     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" | ||||
|   } | ||||
|   member_method { | ||||
|     name: "make_callable" | ||||
|     argspec: "args=[\'self\', \'fetches\', \'feed_list\'], varargs=None, keywords=None, defaults=[\'None\'], " | ||||
|   } | ||||
|   member_method { | ||||
|     name: "partial_run" | ||||
|     argspec: "args=[\'self\', \'handle\', \'fetches\', \'feed_dict\'], varargs=None, keywords=None, defaults=[\'None\'], " | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user