Fix performance issues with the Embedding Projector plugin.

- Set up data handlers in the class constructor instead of on every GET request.
- Use a single instance of each plugin, instead of creating a new instance on every GET request.
- Memoize readers and other data, and share them across the multiple threads of the TB server. Previously the number of checkpoint readers equaled the number of GET requests — oops.
- Checkpoints can be big (on the order of 1 GB) and reading them takes a while. Only start reading checkpoints when someone explicitly opens the embeddings tab. Previously the checkpoint was read on TB startup, which slowed down TensorBoard.
Change: 138187941
This commit is contained in:
Dan Smilkov 2016-11-04 06:34:23 -08:00 committed by TensorFlower Gardener
parent d150ca80da
commit 59daeca315
4 changed files with 111 additions and 108 deletions

View File

@ -24,13 +24,11 @@ from __future__ import division
from __future__ import print_function
import csv
import functools
import imghdr
import json
import mimetypes
import os
import six
import six as _six
from six import StringIO
from six.moves import BaseHTTPServer
from six.moves import urllib
@ -101,8 +99,41 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
def __init__(self, multiplexer, logdir, *args):
  """Creates the handler and builds the route dispatch table once.

  Args:
    multiplexer: Event multiplexer; must expose RunPaths(), which is
      forwarded to plugins when their handlers are registered.
    logdir: The TensorBoard log directory, forwarded to plugins.
    *args: Positional arguments passed through to
      BaseHTTPServer.BaseHTTPRequestHandler.__init__.
  """
  self._multiplexer = multiplexer
  self._logdir = logdir
  # Build the dispatch table before calling the base-class constructor:
  # BaseHTTPRequestHandler.__init__ handles the request, so the handlers
  # must already be registered by then.
  self._setup_data_handlers()
  BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args)
def _setup_data_handlers(self):
  """Builds the route -> handler dispatch table used to serve GET requests.

  Called once from the constructor (rather than per request) so the table
  is constructed a single time. Core TensorBoard routes are registered
  first; then each registered plugin's routes are mounted under
  DATA_PREFIX + PLUGIN_PREFIX + '/<plugin name>'.
  """
  self.data_handlers = {
      DATA_PREFIX + LOGDIR_ROUTE: self._serve_logdir,
      DATA_PREFIX + SCALARS_ROUTE: self._serve_scalars,
      DATA_PREFIX + GRAPH_ROUTE: self._serve_graph,
      DATA_PREFIX + RUN_METADATA_ROUTE: self._serve_run_metadata,
      DATA_PREFIX + HISTOGRAMS_ROUTE: self._serve_histograms,
      DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE:
          self._serve_compressed_histograms,
      DATA_PREFIX + IMAGES_ROUTE: self._serve_images,
      DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE: self._serve_image,
      DATA_PREFIX + AUDIO_ROUTE: self._serve_audio,
      DATA_PREFIX + INDIVIDUAL_AUDIO_ROUTE: self._serve_individual_audio,
      DATA_PREFIX + RUNS_ROUTE: self._serve_runs,
      '/app.js': self._serve_js
  }
  # Serve the routes from the registered plugins using their name as the route
  # prefix. For example if plugin z has two routes /a and /b, they will be
  # served as /data/plugin/z/a and /data/plugin/z/b.
  # Iterate items() directly instead of looking each name up again by key.
  for name, plugin in REGISTERED_PLUGINS.items():
    try:
      plugin_handlers = plugin.get_plugin_handlers(
          self._multiplexer.RunPaths(), self._logdir)
    except Exception as e:  # pylint: disable=broad-except
      # A broken plugin must not take down TensorBoard; log and move on.
      logging.warning('Plugin %s failed. Exception: %s', name, str(e))
      continue
    for route, handler in plugin_handlers.items():
      path = DATA_PREFIX + PLUGIN_PREFIX + '/' + name + route
      # Bind this request handler as the first argument so plugin handlers
      # can use respond() and friends.
      self.data_handlers[path] = functools.partial(handler, self)
def respond(self, *args, **kwargs):
  """Delegates to http.Respond, passing this handler plus all arguments."""
  http.Respond(self, *args, **kwargs)
@ -436,6 +467,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
def _serve_static_file(self, path):
"""Serves the static file located at the given path.
Args:
path: The path of the static file, relative to the tensorboard/ directory.
"""
@ -483,39 +515,6 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
if clean_path.endswith('/'):
clean_path = clean_path[:-1]
data_handlers = {
DATA_PREFIX + LOGDIR_ROUTE: self._serve_logdir,
DATA_PREFIX + SCALARS_ROUTE: self._serve_scalars,
DATA_PREFIX + GRAPH_ROUTE: self._serve_graph,
DATA_PREFIX + RUN_METADATA_ROUTE: self._serve_run_metadata,
DATA_PREFIX + HISTOGRAMS_ROUTE: self._serve_histograms,
DATA_PREFIX + COMPRESSED_HISTOGRAMS_ROUTE:
self._serve_compressed_histograms,
DATA_PREFIX + IMAGES_ROUTE: self._serve_images,
DATA_PREFIX + INDIVIDUAL_IMAGE_ROUTE: self._serve_image,
DATA_PREFIX + AUDIO_ROUTE: self._serve_audio,
DATA_PREFIX + INDIVIDUAL_AUDIO_ROUTE: self._serve_individual_audio,
DATA_PREFIX + RUNS_ROUTE: self._serve_runs,
'/app.js': self._serve_js
}
# Serve the routes from the registered plugins using their name as the route
# prefix. For example if plugin z has two routes /a and /b, they will be
# served as /data/plugin/z/a and /data/plugin/z/b.
for name in REGISTERED_PLUGINS:
try:
plugin = REGISTERED_PLUGINS[name]()
# Initialize the plugin by passing the main http handler.
plugin.initialize(self)
plugin_handlers = plugin.get_plugin_handlers(
self._multiplexer.RunPaths(), self._logdir)
except Exception as e: # pylint: disable=broad-except
logging.warning('Plugin %s failed. Exception: %s', name, str(e))
continue
for route, handler in six.iteritems(plugin_handlers):
path = DATA_PREFIX + PLUGIN_PREFIX + '/' + name + route
data_handlers[path] = handler
query_params = urlparse.parse_qs(parsed_url.query)
# parse_qs returns a list of values for each key; we're only interested in
# the first.
@ -528,8 +527,8 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
return
query_params[key] = query_params[key][0]
if clean_path in data_handlers:
data_handlers[clean_path](query_params)
if clean_path in self.data_handlers:
self.data_handlers[clean_path](query_params)
elif clean_path in TAB_ROUTES:
self._serve_index(query_params)
else:

View File

@ -19,4 +19,4 @@ from __future__ import print_function
from tensorflow.tensorboard.plugins.projector.plugin import ProjectorPlugin
# Map of registered plugins in TensorBoard.
REGISTERED_PLUGINS = {'projector': ProjectorPlugin}
REGISTERED_PLUGINS = {'projector': ProjectorPlugin()}

View File

@ -30,15 +30,6 @@ class TBPlugin(object):
"""TensorBoard plugin interface. Every plugin must extend from this class."""
__metaclass__ = ABCMeta
def initialize(self, handler):
"""Initializes the plugin.
Args:
handler: The tensorboard http handler that has methods that are used
by plugins such as serving json or gzip response.
"""
self.handler = handler
@abstractmethod
def get_plugin_handlers(self, run_paths, logdir):
"""Returns a set of http handlers that the plugin implements.

View File

@ -64,13 +64,18 @@ def _read_tensor_file(fpath):
class ProjectorPlugin(TBPlugin):
"""Embedding projector."""
def get_plugin_handlers(self, run_paths, logdir):
self.configs, self.config_fpaths = self._read_config_files(run_paths,
logdir)
def __init__(self):
self._handlers = None
self.readers = {}
self._augment_configs_with_checkpoint_info()
self.run_paths = None
self.logdir = None
self._configs = None
self.old_num_run_paths = None
return {
def get_plugin_handlers(self, run_paths, logdir):
self.run_paths = run_paths
self.logdir = logdir
self._handlers = {
RUNS_ROUTE: self._serve_runs,
CONFIG_ROUTE: self._serve_config,
TENSOR_ROUTE: self._serve_tensor,
@ -78,9 +83,26 @@ class ProjectorPlugin(TBPlugin):
BOOKMARKS_ROUTE: self._serve_bookmarks,
SPRITE_IMAGE_ROUTE: self._serve_sprite_image
}
return self._handlers
@property
def configs(self):
  """Returns a map of run name to `ProjectorConfig` proto for that run.

  The config files are read lazily: only on first access, and re-read when
  the set of run paths changes size. This keeps (potentially large)
  checkpoint reads off the TensorBoard startup path.
  NOTE(review): the lazy re-read is not guarded by a lock; presumably a
  benign race if two server threads hit this at once — confirm.
  """
  if self._run_paths_changed():
    self._configs, self.config_fpaths = self._read_config_files(
        self.run_paths, self.logdir)
    self._augment_configs_with_checkpoint_info()
  return self._configs
def _run_paths_changed(self):
  """Returns True if the number of run paths changed since the last call.

  Also records the new count as the baseline for the next call. Note this
  compares only the *number* of runs, not their identities, so a same-size
  replacement of runs would go undetected.

  Returns:
    True if the count of run paths differs from the previously seen count,
    False otherwise.
  """
  # len() works on the mapping directly; no need to materialize the key
  # list as the old `len(list(self.run_paths.keys()))` did.
  num_run_paths = len(self.run_paths)
  if num_run_paths != self.old_num_run_paths:
    self.old_num_run_paths = num_run_paths
    return True
  return False
def _augment_configs_with_checkpoint_info(self):
for run, config in self.configs.items():
for run, config in self._configs.items():
# Find the size of the embeddings that are associated with a tensor file.
for embedding in config.embeddings:
if embedding.tensor_path and not embedding.tensor_shape:
@ -111,18 +133,18 @@ class ProjectorPlugin(TBPlugin):
# Remove configs that do not have any valid (2D) tensors.
runs_to_remove = []
for run, config in self.configs.items():
for run, config in self._configs.items():
if not config.embeddings:
runs_to_remove.append(run)
for run in runs_to_remove:
del self.configs[run]
del self._configs[run]
del self.config_fpaths[run]
def _read_config_files(self, run_paths, logdir):
def _read_config_files(self, run_paths, summary_logdir):
# If there are no summary event files, the projector can still work,
# thus treating the `logdir` as the model checkpoint directory.
if not run_paths:
run_paths['.'] = logdir
run_paths['.'] = summary_logdir
configs = {}
config_fpaths = {}
@ -164,7 +186,7 @@ class ProjectorPlugin(TBPlugin):
if run in self.readers:
return self.readers[run]
config = self.configs[run]
config = self._configs[run]
reader = None
if config.model_checkpoint_path:
try:
@ -201,48 +223,45 @@ class ProjectorPlugin(TBPlugin):
return info
return None
def _serve_runs(self, query_params):
def _serve_runs(self, request, query_params):
"""Returns a list of runs that have embeddings."""
self.handler.respond(list(self.configs.keys()), 'application/json')
request.respond(list(self.configs.keys()), 'application/json')
def _serve_config(self, query_params):
def _serve_config(self, request, query_params):
run = query_params.get('run')
if run is None:
self.handler.respond('query parameter "run" is required',
'text/plain', 400)
request.respond('query parameter "run" is required', 'text/plain', 400)
return
if run not in self.configs:
self.handler.respond('Unknown run: %s' % run, 'text/plain', 400)
request.respond('Unknown run: %s' % run, 'text/plain', 400)
return
config = self.configs[run]
self.handler.respond(json_format.MessageToJson(config), 'application/json')
request.respond(json_format.MessageToJson(config), 'application/json')
def _serve_metadata(self, query_params):
def _serve_metadata(self, request, query_params):
run = query_params.get('run')
if run is None:
self.handler.respond('query parameter "run" is required',
'text/plain', 400)
request.respond('query parameter "run" is required', 'text/plain', 400)
return
name = query_params.get('name')
if name is None:
self.handler.respond('query parameter "name" is required',
'text/plain', 400)
request.respond('query parameter "name" is required', 'text/plain', 400)
return
if run not in self.configs:
self.handler.respond('Unknown run: %s' % run, 'text/plain', 400)
request.respond('Unknown run: %s' % run, 'text/plain', 400)
return
config = self.configs[run]
fpath = self._get_metadata_file_for_tensor(name, config)
if not fpath:
self.handler.respond(
request.respond(
'No metadata file found for tensor %s in the config file %s' %
(name, self.config_fpaths[run]), 'text/plain', 400)
return
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
self.handler.respond('%s is not a file' % fpath, 'text/plain', 400)
request.respond('%s is not a file' % fpath, 'text/plain', 400)
return
num_header_rows = 0
@ -256,23 +275,21 @@ class ProjectorPlugin(TBPlugin):
num_header_rows = 1
if len(lines) >= LIMIT_NUM_POINTS + num_header_rows:
break
self.handler.respond(''.join(lines), 'text/plain')
request.respond(''.join(lines), 'text/plain')
def _serve_tensor(self, query_params):
def _serve_tensor(self, request, query_params):
run = query_params.get('run')
if run is None:
self.handler.respond('query parameter "run" is required',
'text/plain', 400)
request.respond('query parameter "run" is required', 'text/plain', 400)
return
name = query_params.get('name')
if name is None:
self.handler.respond('query parameter "name" is required',
'text/plain', 400)
request.respond('query parameter "name" is required', 'text/plain', 400)
return
if run not in self.configs:
self.handler.respond('Unknown run: %s' % run, 'text/plain', 400)
request.respond('Unknown run: %s' % run, 'text/plain', 400)
return
reader = self._get_reader_for_run(run)
@ -282,19 +299,19 @@ class ProjectorPlugin(TBPlugin):
# See if there is a tensor file in the config.
embedding = self._get_embedding(name, config)
if not embedding or not embedding.tensor_path:
self.handler.respond('Tensor %s has no tensor_path in the config' %
name, 'text/plain', 400)
request.respond('Tensor %s has no tensor_path in the config' %
name, 'text/plain', 400)
return
if not file_io.file_exists(embedding.tensor_path):
self.handler.respond('Tensor file %s does not exist' %
embedding.tensor_path, 'text/plain', 400)
request.respond('Tensor file %s does not exist' %
embedding.tensor_path, 'text/plain', 400)
return
tensor = _read_tensor_file(embedding.tensor_path)
else:
if not reader.has_tensor(name):
self.handler.respond('Tensor %s not found in checkpoint dir %s' %
(name, config.model_checkpoint_path),
'text/plain', 400)
request.respond('Tensor %s not found in checkpoint dir %s' %
(name, config.model_checkpoint_path),
'text/plain', 400)
return
tensor = reader.get_tensor(name)
@ -302,75 +319,71 @@ class ProjectorPlugin(TBPlugin):
tensor = tensor[:LIMIT_NUM_POINTS]
# Stream it as TSV.
tsv = '\n'.join(['\t'.join([str(val) for val in row]) for row in tensor])
self.handler.respond(tsv, 'text/tab-separated-values')
request.respond(tsv, 'text/tab-separated-values')
def _serve_bookmarks(self, query_params):
def _serve_bookmarks(self, request, query_params):
run = query_params.get('run')
if not run:
self.handler.respond('query parameter "run" is required', 'text/plain',
400)
request.respond('query parameter "run" is required', 'text/plain', 400)
return
name = query_params.get('name')
if name is None:
self.handler.respond('query parameter "name" is required', 'text/plain',
400)
request.respond('query parameter "name" is required', 'text/plain', 400)
return
if run not in self.configs:
self.handler.respond('Unknown run: %s' % run, 'text/plain', 400)
request.respond('Unknown run: %s' % run, 'text/plain', 400)
return
config = self.configs[run]
fpath = self._get_bookmarks_file_for_tensor(name, config)
if not fpath:
self.handler.respond(
request.respond(
'No bookmarks file found for tensor %s in the config file %s' %
(name, self.config_fpaths[run]), 'text/plain', 400)
return
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
self.handler.respond('%s is not a file' % fpath, 'text/plain', 400)
request.respond('%s is not a file' % fpath, 'text/plain', 400)
return
bookmarks_json = None
with file_io.FileIO(fpath, 'r') as f:
bookmarks_json = f.read()
self.handler.respond(bookmarks_json, 'application/json')
request.respond(bookmarks_json, 'application/json')
def _serve_sprite_image(self, query_params):
def _serve_sprite_image(self, request, query_params):
run = query_params.get('run')
if not run:
self.handler.respond('query parameter "run" is required', 'text/plain',
400)
request.respond('query parameter "run" is required', 'text/plain', 400)
return
name = query_params.get('name')
if name is None:
self.handler.respond('query parameter "name" is required', 'text/plain',
400)
request.respond('query parameter "name" is required', 'text/plain', 400)
return
if run not in self.configs:
self.handler.respond('Unknown run: %s' % run, 'text/plain', 400)
request.respond('Unknown run: %s' % run, 'text/plain', 400)
return
config = self.configs[run]
embedding_info = self._get_embedding(name, config)
if not embedding_info or not embedding_info.sprite.image_path:
self.handler.respond(
request.respond(
'No sprite image file found for tensor %s in the config file %s' %
(name, self.config_fpaths[run]), 'text/plain', 400)
return
fpath = embedding_info.sprite.image_path
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
self.handler.respond('%s does not exist or is directory' % fpath,
'text/plain', 400)
request.respond(
'%s does not exist or is directory' % fpath, 'text/plain', 400)
return
f = file_io.FileIO(fpath, 'r')
encoded_image_string = f.read()
f.close()
image_type = imghdr.what(None, encoded_image_string)
mime_type = _IMGHDR_TO_MIMETYPE.get(image_type, _DEFAULT_IMAGE_MIMETYPE)
self.handler.respond(encoded_image_string, mime_type)
request.respond(encoded_image_string, mime_type)