Make the TensorFlow debugger plugin respond with health pills.
The /health_pills handler now responds with a JSON-ified mapping between node name and a list of health pill events for a specified run. Towards that end, added an optional run POST parameter to that route. Change: 146767197
This commit is contained in:
parent
764507ea71
commit
1bed3d798e
tensorflow
python/summary
tensorboard
@ -693,7 +693,9 @@ class EventAccumulator(object):
|
||||
output_slot: The output slot for this health pill.
|
||||
elements: An ND array of 12 floats. The elements of the health pill.
|
||||
"""
|
||||
# Key by the node name for fast retrieval of health pills by node name.
|
||||
# Key by the node name for fast retrieval of health pills by node name. The
|
||||
# array is cast to a list so that it is JSON-able. The debugger data plugin
|
||||
# serves a JSON response.
|
||||
self._health_pills.AddItem(
|
||||
node_name,
|
||||
HealthPillEvent(
|
||||
@ -701,7 +703,7 @@ class EventAccumulator(object):
|
||||
step=step,
|
||||
node_name=node_name,
|
||||
output_slot=output_slot,
|
||||
value=elements))
|
||||
value=list(elements)))
|
||||
|
||||
def _Purge(self, event, by_tags):
|
||||
"""Purge all events that have occurred after the given event.step.
|
||||
|
@ -251,7 +251,7 @@ class EventMultiplexer(object):
|
||||
return accumulator.Scalars(tag)
|
||||
|
||||
def HealthPills(self, run, node_name):
|
||||
"""Retrieve the scalar events associated with a run and node name.
|
||||
"""Retrieve the health pill events associated with a run and node name.
|
||||
|
||||
Args:
|
||||
run: A string name of the run for which health pills are retrieved.
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
|
||||
from werkzeug import wrappers
|
||||
@ -32,9 +33,15 @@ PLUGIN_PREFIX_ROUTE = 'debugger'
|
||||
# HTTP routes.
|
||||
_HEALTH_PILLS_ROUTE = '/health_pills'
|
||||
|
||||
# The POST key value of the HEALTH_PILLS_ROUTE for a JSON list of node names.
|
||||
# The POST key of HEALTH_PILLS_ROUTE for a JSON list of node names.
|
||||
_NODE_NAMES_POST_KEY = 'node_names'
|
||||
|
||||
# The POST key of HEALTH_PILLS_ROUTE for the run to retrieve health pills for.
|
||||
_RUN_POST_KEY = 'run'
|
||||
|
||||
# The default run to retrieve health pills for.
|
||||
_DEFAULT_RUN = '.'
|
||||
|
||||
|
||||
class DebuggerPlugin(base_plugin.TBPlugin):
|
||||
"""TensorFlow Debugger plugin. Receives requests for debugger-related data.
|
||||
@ -43,6 +50,14 @@ class DebuggerPlugin(base_plugin.TBPlugin):
|
||||
values.
|
||||
"""
|
||||
|
||||
def __init__(self, event_multiplexer):
|
||||
"""Constructs a plugin for serving TensorFlow debugger data.
|
||||
|
||||
Args:
|
||||
event_multiplexer: Organizes data from events files.
|
||||
"""
|
||||
self._event_multiplexer = event_multiplexer
|
||||
|
||||
def get_plugin_apps(self, unused_run_paths, unused_logdir):
|
||||
"""Obtains a mapping between routes and handlers.
|
||||
|
||||
@ -62,10 +77,6 @@ class DebuggerPlugin(base_plugin.TBPlugin):
|
||||
def _serve_health_pills_handler(self, request):
|
||||
"""A (wrapped) werkzeug handler for serving health pills.
|
||||
|
||||
NOTE(chizeng): This handler is currently not useful and does not behave as
|
||||
expected. It currently merely responds with the provided list of op names
|
||||
(instead of the health pills for those ops). Very soon, that will change.
|
||||
|
||||
We defer to another method for actually performing the main logic because
|
||||
the @wrappers.Request.application decorator makes this logic hard to access
|
||||
in tests.
|
||||
@ -82,14 +93,26 @@ class DebuggerPlugin(base_plugin.TBPlugin):
|
||||
"""Responds with health pills.
|
||||
|
||||
Accepts POST requests and responds with health pills. Specifically, the
|
||||
handler expects a "node_names" POST data key. The value of that key should
|
||||
be a JSON-ified list of node names for which the client would like to
|
||||
request health pills. This data is sent via POST instead of GET because URL
|
||||
length is limited.
|
||||
handler expects a required "node_names" and an optional "run" POST data key.
|
||||
The value of the "node_names" key should be a JSON-ified list of node names
|
||||
for which the client would like to request health pills. The value of the
|
||||
"run" key (which defaults to ".") should be the run to retrieve health pills
|
||||
for. This data is sent via POST (not GET) because URL length is limited.
|
||||
|
||||
This handler responds with a JSON-ified object mapping from node names to a
|
||||
list of HealthPillEvents. Node names for which there are no health pills to
|
||||
be found are excluded from the mapping.
|
||||
list of health pill event objects, each of which has these properties.
|
||||
|
||||
{
|
||||
'wall_time': float,
|
||||
'step': int,
|
||||
'node_name': string,
|
||||
'output_slot': int,
|
||||
# A list of 12 floats that summarizes the elements of the tensor.
|
||||
'value': float[],
|
||||
}
|
||||
|
||||
Node names for which there are no health pills to be found are excluded from
|
||||
the mapping.
|
||||
|
||||
Args:
|
||||
request: The request issued by the client for health pills.
|
||||
@ -125,5 +148,20 @@ class DebuggerPlugin(base_plugin.TBPlugin):
|
||||
'%s is not a JSON list of node names:', jsonified_node_names)
|
||||
return wrappers.Response(status=400)
|
||||
|
||||
# TODO(chizeng): Actually respond with the health pills per node name.
|
||||
return http_util.Respond(request, node_names, mimetype='application/json')
|
||||
mapping = collections.defaultdict(list)
|
||||
run = request.form.get(_RUN_POST_KEY, _DEFAULT_RUN)
|
||||
for node_name in node_names:
|
||||
try:
|
||||
pill_events = self._event_multiplexer.HealthPills(run, node_name)
|
||||
for pill_event in pill_events:
|
||||
mapping[node_name].append({
|
||||
'wall_time': pill_event[0],
|
||||
'step': pill_event[1],
|
||||
'node_name': pill_event[2],
|
||||
'output_slot': pill_event[3],
|
||||
'value': pill_event[4],
|
||||
})
|
||||
except KeyError:
|
||||
logging.info('No health pills found for node %s.', node_name)
|
||||
|
||||
return http_util.Respond(request, mapping, 'application/json')
|
||||
|
@ -22,6 +22,7 @@ import collections
|
||||
import json
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.summary import event_accumulator
|
||||
from tensorflow.tensorboard.plugins.debugger import plugin as debugger_plugin
|
||||
|
||||
|
||||
@ -42,14 +43,148 @@ class FakeRequest(object):
|
||||
self.method = method
|
||||
self.form = post_data
|
||||
|
||||
# http_util.Respond requires a headers property.
|
||||
self.headers = {}
|
||||
|
||||
|
||||
class FakeEventMultiplexer(object):
|
||||
"""A fake event multiplexer we can populate with custom health pills."""
|
||||
|
||||
def __init__(self, run_to_node_name_to_health_pills):
|
||||
"""Constructs a fake event multiplexer.
|
||||
|
||||
Args:
|
||||
run_to_node_name_to_health_pills: A dict mapping run to a dict mapping
|
||||
node name to a list of health pills.
|
||||
"""
|
||||
self._run_to_node_name_to_health_pills = run_to_node_name_to_health_pills
|
||||
|
||||
def HealthPills(self, run, node_name):
|
||||
"""Retrieve the health pill events associated with a run and node name.
|
||||
|
||||
Args:
|
||||
run: A string name of the run for which health pills are retrieved.
|
||||
node_name: A string name of the node for which health pills are retrieved.
|
||||
|
||||
Raises:
|
||||
KeyError: If the run is not found, or the node name is not available for
|
||||
the given run.
|
||||
|
||||
Returns:
|
||||
An array of strings (that substitute for
|
||||
event_accumulator.HealthPillEvents) that represent health pills.
|
||||
"""
|
||||
return self._run_to_node_name_to_health_pills[run][node_name]
|
||||
|
||||
|
||||
class DebuggerPluginTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.debugger_plugin = debugger_plugin.DebuggerPlugin()
|
||||
self.fake_event_multiplexer = FakeEventMultiplexer({
|
||||
'.': {
|
||||
'layers/Matmul': [
|
||||
event_accumulator.HealthPillEvent(
|
||||
wall_time=42,
|
||||
step=2,
|
||||
node_name='layers/Matmul',
|
||||
output_slot=0,
|
||||
value=[1, 2, 3]),
|
||||
event_accumulator.HealthPillEvent(
|
||||
wall_time=43,
|
||||
step=3,
|
||||
node_name='layers/Matmul',
|
||||
output_slot=1,
|
||||
value=[4, 5, 6]),
|
||||
],
|
||||
'logits/Add': [
|
||||
event_accumulator.HealthPillEvent(
|
||||
wall_time=1337,
|
||||
step=7,
|
||||
node_name='logits/Add',
|
||||
output_slot=0,
|
||||
value=[7, 8, 9]),
|
||||
event_accumulator.HealthPillEvent(
|
||||
wall_time=1338,
|
||||
step=8,
|
||||
node_name='logits/Add',
|
||||
output_slot=0,
|
||||
value=[10, 11, 12]),
|
||||
],
|
||||
},
|
||||
'run_foo': {
|
||||
'layers/Variable': [
|
||||
event_accumulator.HealthPillEvent(
|
||||
wall_time=4242,
|
||||
step=42,
|
||||
node_name='layers/Variable',
|
||||
output_slot=0,
|
||||
value=[13, 14, 15]),
|
||||
],
|
||||
},
|
||||
})
|
||||
self.debugger_plugin = debugger_plugin.DebuggerPlugin(
|
||||
self.fake_event_multiplexer)
|
||||
self.unused_run_paths = {}
|
||||
self.unused_logdir = '/logdir'
|
||||
|
||||
def _DeserializeResponse(self, byte_content):
|
||||
"""Deserializes byte content that is a JSON encoding.
|
||||
|
||||
Args:
|
||||
byte_content: The byte content of a JSON response.
|
||||
|
||||
Returns:
|
||||
The deserialized python object.
|
||||
"""
|
||||
return json.loads(byte_content.decode('utf-8'))
|
||||
|
||||
def testRequestHealthPillsForRunFoo(self):
|
||||
"""Tests that the plugin produces health pills for a specified run."""
|
||||
request = FakeRequest('POST', {
|
||||
'node_names': json.dumps(['layers/Variable', 'unavailable_node']),
|
||||
'run': 'run_foo',
|
||||
})
|
||||
response = self.debugger_plugin._serve_health_pills_helper(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertDictEqual({
|
||||
'layers/Variable': [{
|
||||
'wall_time': 4242,
|
||||
'step': 42,
|
||||
'node_name': 'layers/Variable',
|
||||
'output_slot': 0,
|
||||
'value': [13, 14, 15],
|
||||
}],
|
||||
}, self._DeserializeResponse(response.get_data()))
|
||||
|
||||
def testRequestHealthPillsForDefaultRun(self):
|
||||
"""Tests that the plugin produces health pills for the default '.' run."""
|
||||
# Do not provide a 'run' parameter in POST data.
|
||||
request = FakeRequest('POST', {
|
||||
'node_names': json.dumps(['logits/Add', 'unavailable_node']),
|
||||
})
|
||||
response = self.debugger_plugin._serve_health_pills_helper(request)
|
||||
self.assertEqual(200, response.status_code)
|
||||
# The health pills for 'layers/Matmul' should not be included since the
|
||||
# request excluded that node name.
|
||||
self.assertDictEqual({
|
||||
'logits/Add': [
|
||||
{
|
||||
'wall_time': 1337,
|
||||
'step': 7,
|
||||
'node_name': 'logits/Add',
|
||||
'output_slot': 0,
|
||||
'value': [7, 8, 9],
|
||||
},
|
||||
{
|
||||
'wall_time': 1338,
|
||||
'step': 8,
|
||||
'node_name': 'logits/Add',
|
||||
'output_slot': 0,
|
||||
'value': [10, 11, 12],
|
||||
},
|
||||
],
|
||||
}, self._DeserializeResponse(response.get_data()))
|
||||
|
||||
def testHealthPillsRouteProvided(self):
|
||||
"""Tests that the plugin offers the route for requesting health pills."""
|
||||
apps = self.debugger_plugin.get_plugin_apps(self.unused_run_paths,
|
||||
|
@ -122,7 +122,7 @@ class Server(object):
|
||||
purge_orphaned_data=FLAGS.purge_orphaned_data)
|
||||
plugins = {
|
||||
debugger_plugin.PLUGIN_PREFIX_ROUTE:
|
||||
debugger_plugin.DebuggerPlugin(),
|
||||
debugger_plugin.DebuggerPlugin(multiplexer),
|
||||
projector_plugin.PLUGIN_PREFIX_ROUTE:
|
||||
projector_plugin.ProjectorPlugin(),
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user