Make the TensorFlow debugger plugin respond with health pills.

The /health_pills handler now responds with a JSON-ified mapping between node name and a list of health pill events for a specified run. Towards that end, added an optional run POST  parameter to that route.
Change: 146767197
This commit is contained in:
A. Unique TensorFlower 2017-02-07 03:03:33 -08:00 committed by TensorFlower Gardener
parent 764507ea71
commit 1bed3d798e
5 changed files with 193 additions and 18 deletions

View File

@ -693,7 +693,9 @@ class EventAccumulator(object):
output_slot: The output slot for this health pill.
elements: An ND array of 12 floats. The elements of the health pill.
"""
# Key by the node name for fast retrieval of health pills by node name.
# Key by the node name for fast retrieval of health pills by node name. The
# array is cast to a list so that it is JSON-able. The debugger data plugin
# serves a JSON response.
self._health_pills.AddItem(
node_name,
HealthPillEvent(
@ -701,7 +703,7 @@ class EventAccumulator(object):
step=step,
node_name=node_name,
output_slot=output_slot,
value=elements))
value=list(elements)))
def _Purge(self, event, by_tags):
"""Purge all events that have occurred after the given event.step.

View File

@ -251,7 +251,7 @@ class EventMultiplexer(object):
return accumulator.Scalars(tag)
def HealthPills(self, run, node_name):
"""Retrieve the scalar events associated with a run and node name.
"""Retrieve the health pill events associated with a run and node name.
Args:
run: A string name of the run for which health pills are retrieved.

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import json
from werkzeug import wrappers
@ -32,9 +33,15 @@ PLUGIN_PREFIX_ROUTE = 'debugger'
# HTTP routes.
_HEALTH_PILLS_ROUTE = '/health_pills'
# The POST key value of the HEALTH_PILLS_ROUTE for a JSON list of node names.
# The POST key of HEALTH_PILLS_ROUTE for a JSON list of node names.
_NODE_NAMES_POST_KEY = 'node_names'
# The POST key of HEALTH_PILLS_ROUTE for the run to retrieve health pills for.
_RUN_POST_KEY = 'run'
# The default run to retrieve health pills for.
_DEFAULT_RUN = '.'
class DebuggerPlugin(base_plugin.TBPlugin):
"""TensorFlow Debugger plugin. Receives requests for debugger-related data.
@ -43,6 +50,14 @@ class DebuggerPlugin(base_plugin.TBPlugin):
values.
"""
def __init__(self, event_multiplexer):
"""Constructs a plugin for serving TensorFlow debugger data.
Args:
event_multiplexer: Organizes data from events files.
"""
self._event_multiplexer = event_multiplexer
def get_plugin_apps(self, unused_run_paths, unused_logdir):
"""Obtains a mapping between routes and handlers.
@ -62,10 +77,6 @@ class DebuggerPlugin(base_plugin.TBPlugin):
def _serve_health_pills_handler(self, request):
"""A (wrapped) werkzeug handler for serving health pills.
NOTE(chizeng): This handler is currently not useful and does not behave as
expected. It currently merely responds with the provided list of op names
(instead of the health pills for those ops). Very soon, that will change.
We defer to another method for actually performing the main logic because
the @wrappers.Request.application decorator makes this logic hard to access
in tests.
@ -82,14 +93,26 @@ class DebuggerPlugin(base_plugin.TBPlugin):
"""Responds with health pills.
Accepts POST requests and responds with health pills. Specifically, the
handler expects a "node_names" POST data key. The value of that key should
be a JSON-ified list of node names for which the client would like to
request health pills. This data is sent via POST instead of GET because URL
length is limited.
handler expects a required "node_names" and an optional "run" POST data key.
The value of the "node_names" key should be a JSON-ified list of node names
for which the client would like to request health pills. The value of the
"run" key (which defaults to ".") should be the run to retrieve health pills
for. This data is sent via POST (not GET) because URL length is limited.
This handler responds with a JSON-ified object mapping from node names to a
list of HealthPillEvents. Node names for which there are no health pills to
be found are excluded from the mapping.
list of health pill event objects, each of which has these properties.
{
'wall_time': float,
'step': int,
'node_name': string,
'output_slot': int,
# A list of 12 floats that summarizes the elements of the tensor.
'value': float[],
}
Node names for which there are no health pills to be found are excluded from
the mapping.
Args:
request: The request issued by the client for health pills.
@ -125,5 +148,20 @@ class DebuggerPlugin(base_plugin.TBPlugin):
'%s is not a JSON list of node names:', jsonified_node_names)
return wrappers.Response(status=400)
# TODO(chizeng): Actually respond with the health pills per node name.
return http_util.Respond(request, node_names, mimetype='application/json')
mapping = collections.defaultdict(list)
run = request.form.get(_RUN_POST_KEY, _DEFAULT_RUN)
for node_name in node_names:
try:
pill_events = self._event_multiplexer.HealthPills(run, node_name)
for pill_event in pill_events:
mapping[node_name].append({
'wall_time': pill_event[0],
'step': pill_event[1],
'node_name': pill_event[2],
'output_slot': pill_event[3],
'value': pill_event[4],
})
except KeyError:
logging.info('No health pills found for node %s.', node_name)
return http_util.Respond(request, mapping, 'application/json')

View File

@ -22,6 +22,7 @@ import collections
import json
from tensorflow.python.platform import test
from tensorflow.python.summary import event_accumulator
from tensorflow.tensorboard.plugins.debugger import plugin as debugger_plugin
@ -42,14 +43,148 @@ class FakeRequest(object):
self.method = method
self.form = post_data
# http_util.Respond requires a headers property.
self.headers = {}
class FakeEventMultiplexer(object):
"""A fake event multiplexer we can populate with custom health pills."""
def __init__(self, run_to_node_name_to_health_pills):
"""Constructs a fake event multiplexer.
Args:
run_to_node_name_to_health_pills: A dict mapping run to a dict mapping
node name to a list of health pills.
"""
self._run_to_node_name_to_health_pills = run_to_node_name_to_health_pills
def HealthPills(self, run, node_name):
"""Retrieve the health pill events associated with a run and node name.
Args:
run: A string name of the run for which health pills are retrieved.
node_name: A string name of the node for which health pills are retrieved.
Raises:
KeyError: If the run is not found, or the node name is not available for
the given run.
Returns:
An array of strings (that substitute for
event_accumulator.HealthPillEvents) that represent health pills.
"""
return self._run_to_node_name_to_health_pills[run][node_name]
class DebuggerPluginTest(test.TestCase):
def setUp(self):
self.debugger_plugin = debugger_plugin.DebuggerPlugin()
self.fake_event_multiplexer = FakeEventMultiplexer({
'.': {
'layers/Matmul': [
event_accumulator.HealthPillEvent(
wall_time=42,
step=2,
node_name='layers/Matmul',
output_slot=0,
value=[1, 2, 3]),
event_accumulator.HealthPillEvent(
wall_time=43,
step=3,
node_name='layers/Matmul',
output_slot=1,
value=[4, 5, 6]),
],
'logits/Add': [
event_accumulator.HealthPillEvent(
wall_time=1337,
step=7,
node_name='logits/Add',
output_slot=0,
value=[7, 8, 9]),
event_accumulator.HealthPillEvent(
wall_time=1338,
step=8,
node_name='logits/Add',
output_slot=0,
value=[10, 11, 12]),
],
},
'run_foo': {
'layers/Variable': [
event_accumulator.HealthPillEvent(
wall_time=4242,
step=42,
node_name='layers/Variable',
output_slot=0,
value=[13, 14, 15]),
],
},
})
self.debugger_plugin = debugger_plugin.DebuggerPlugin(
self.fake_event_multiplexer)
self.unused_run_paths = {}
self.unused_logdir = '/logdir'
def _DeserializeResponse(self, byte_content):
"""Deserializes byte content that is a JSON encoding.
Args:
byte_content: The byte content of a JSON response.
Returns:
The deserialized python object.
"""
return json.loads(byte_content.decode('utf-8'))
def testRequestHealthPillsForRunFoo(self):
"""Tests that the plugin produces health pills for a specified run."""
request = FakeRequest('POST', {
'node_names': json.dumps(['layers/Variable', 'unavailable_node']),
'run': 'run_foo',
})
response = self.debugger_plugin._serve_health_pills_helper(request)
self.assertEqual(200, response.status_code)
self.assertDictEqual({
'layers/Variable': [{
'wall_time': 4242,
'step': 42,
'node_name': 'layers/Variable',
'output_slot': 0,
'value': [13, 14, 15],
}],
}, self._DeserializeResponse(response.get_data()))
def testRequestHealthPillsForDefaultRun(self):
"""Tests that the plugin produces health pills for the default '.' run."""
# Do not provide a 'run' parameter in POST data.
request = FakeRequest('POST', {
'node_names': json.dumps(['logits/Add', 'unavailable_node']),
})
response = self.debugger_plugin._serve_health_pills_helper(request)
self.assertEqual(200, response.status_code)
# The health pills for 'layers/Matmul' should not be included since the
# request excluded that node name.
self.assertDictEqual({
'logits/Add': [
{
'wall_time': 1337,
'step': 7,
'node_name': 'logits/Add',
'output_slot': 0,
'value': [7, 8, 9],
},
{
'wall_time': 1338,
'step': 8,
'node_name': 'logits/Add',
'output_slot': 0,
'value': [10, 11, 12],
},
],
}, self._DeserializeResponse(response.get_data()))
def testHealthPillsRouteProvided(self):
"""Tests that the plugin offers the route for requesting health pills."""
apps = self.debugger_plugin.get_plugin_apps(self.unused_run_paths,

View File

@ -122,7 +122,7 @@ class Server(object):
purge_orphaned_data=FLAGS.purge_orphaned_data)
plugins = {
debugger_plugin.PLUGIN_PREFIX_ROUTE:
debugger_plugin.DebuggerPlugin(),
debugger_plugin.DebuggerPlugin(multiplexer),
projector_plugin.PLUGIN_PREFIX_ROUTE:
projector_plugin.ProjectorPlugin(),
}