From d8fbe760f02662db7c79969aec15227fc261f8f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dan=20Man=C3=A9?= Date: Thu, 25 Feb 2016 17:24:24 -0800 Subject: [PATCH] Make TensorBoard purging of orphaned or out-of-order events optional. Now it is controlled by a flag. Default behavior is the same. Will be quite useful for debugging missing data type issues. Change: 115623272 --- .../python/summary/event_accumulator.py | 83 ++++--- .../python/summary/event_accumulator_test.py | 232 +++++++++++++----- .../python/summary/event_multiplexer.py | 21 +- .../scripts/serialize_tensorboard.py | 8 +- tensorflow/tensorboard/tensorboard.py | 8 +- 5 files changed, 243 insertions(+), 109 deletions(-) diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py index 3715ca64cb1..7cac3690b75 100644 --- a/tensorflow/python/summary/event_accumulator.py +++ b/tensorflow/python/summary/event_accumulator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Takes a generator of values, and accumulates them for a frontend.""" from __future__ import absolute_import from __future__ import division @@ -30,8 +29,7 @@ from tensorflow.python.summary.impl import event_file_loader from tensorflow.python.summary.impl import reservoir namedtuple = collections.namedtuple -ScalarEvent = namedtuple('ScalarEvent', - ['wall_time', 'step', 'value']) +ScalarEvent = namedtuple('ScalarEvent', ['wall_time', 'step', 'value']) CompressedHistogramEvent = namedtuple('CompressedHistogramEvent', ['wall_time', 'step', @@ -48,8 +46,8 @@ HistogramValue = namedtuple('HistogramValue', 'bucket_limit', 'bucket']) ImageEvent = namedtuple('ImageEvent', - ['wall_time', 'step', 'encoded_image_string', - 'width', 'height']) + ['wall_time', 'step', 'encoded_image_string', 'width', + 'height']) ## Different types of summary events handled by the event_accumulator SUMMARY_TYPES = ('_scalars', '_histograms', '_compressed_histograms', '_images') @@ -119,8 +117,11 @@ class EventAccumulator(object): @@Images """ - def __init__(self, path, size_guidance=DEFAULT_SIZE_GUIDANCE, - compression_bps=NORMAL_HISTOGRAM_BPS): + def __init__(self, + path, + size_guidance=DEFAULT_SIZE_GUIDANCE, + compression_bps=NORMAL_HISTOGRAM_BPS, + purge_orphaned_data=True): """Construct the `EventAccumulator`. Args: @@ -135,6 +136,8 @@ class EventAccumulator(object): compression_bps: Information on how the `EventAccumulator` should compress histogram data for the `CompressedHistograms` tag (for details see `ProcessCompressedHistogram`). + purge_orphaned_data: Whether to discard any events that were "orphaned" by + a TensorFlow restart. """ sizes = {} for key in DEFAULT_SIZE_GUIDANCE: @@ -149,10 +152,14 @@ class EventAccumulator(object): self._compressed_histograms = reservoir.Reservoir( size=sizes[COMPRESSED_HISTOGRAMS]) self._images = reservoir.Reservoir(size=sizes[IMAGES]) + self._generator_mutex = threading.Lock() self._generator = _GeneratorFromPath(path) - self._activated = False + self._compression_bps = compression_bps + self.purge_orphaned_data = purge_orphaned_data + + self._activated = False self.most_recent_step = -1 self.most_recent_wall_time = -1 self.file_version = None @@ -179,15 +186,7 @@ class EventAccumulator(object): new_file_version)) self.file_version = new_file_version - ## Check if the event happened after a crash, and purge expired tags. 
- if self.file_version and self.file_version >= 2: - ## If the file_version is recent enough, use the SessionLog enum - ## to check for restarts. - self._CheckForRestartAndMaybePurge(event) - else: - ## If there is no file version, default to old logic of checking for - ## out of order steps. - self._CheckForOutOfOrderStepAndMaybePurge(event) + self._MaybePurgeOrphanedData(event) ## Process the event if event.HasField('graph_def'): @@ -307,6 +306,31 @@ class EventAccumulator(object): self._VerifyActivated() return self._images.Items(tag) + def _MaybePurgeOrphanedData(self, event): + """Maybe purge orphaned data due to a TensorFlow crash. + + When TensorFlow crashes at step T+O and restarts at step T, any events + written after step T are now "orphaned" and will be at best misleading if + they are included in TensorBoard. + + This logic attempts to determine if there is orphaned data, and purge it + if it is found. + + Args: + event: The event to use as a reference, to determine if a purge is needed. + """ + if not self.purge_orphaned_data: + return + ## Check if the event happened after a crash, and purge expired tags. + if self.file_version and self.file_version >= 2: + ## If the file_version is recent enough, use the SessionLog enum + ## to check for restarts. + self._CheckForRestartAndMaybePurge(event) + else: + ## If there is no file version, default to old logic of checking for + ## out of order steps. + self._CheckForOutOfOrderStepAndMaybePurge(event) + def _CheckForRestartAndMaybePurge(self, event): """Check and discard expired events using SessionLog.START. @@ -362,16 +386,18 @@ class EventAccumulator(object): Returns: A linearly interpolated value of the histogram weight estimate. """ - if histo_num == 0: return 0 + if histo_num == 0: + return 0 for i, cumsum in enumerate(cumsum_weights): if cumsum >= compression_bps: - cumsum_prev = cumsum_weights[i-1] if i > 0 else 0 + cumsum_prev = cumsum_weights[i - 1] if i > 0 else 0 # Prevent cumsum = 0, cumsum_prev = 0, lerp divide by zero. - if cumsum == cumsum_prev: continue + if cumsum == cumsum_prev: + continue # Calculate the lower bound of interpolation - lhs = bucket_limit[i-1] if (i > 0 and cumsum_prev > 0) else histo_min + lhs = bucket_limit[i - 1] if (i > 0 and cumsum_prev > 0) else histo_min lhs = max(lhs, histo_min) # Calculate the upper bound of interpolation @@ -398,8 +424,9 @@ class EventAccumulator(object): step: Number of steps that have passed histo: proto2 histogram Object """ + def _CumulativeSum(arr): - return [sum(arr[:i+1]) for i in range(len(arr))] + return [sum(arr[:i + 1]) for i in range(len(arr))] # Convert from proto repeated field into a Python list. 
bucket = list(histo.bucket) @@ -443,13 +470,11 @@ class EventAccumulator(object): def _ProcessImage(self, tag, wall_time, step, image): """Processes an image by adding it to accumulated state.""" - event = ImageEvent( - wall_time=wall_time, - step=step, - encoded_image_string=image.encoded_image_string, - width=image.width, - height=image.height - ) + event = ImageEvent(wall_time=wall_time, + step=step, + encoded_image_string=image.encoded_image_string, + width=image.width, + height=image.height) self._images.AddItem(tag, event) def _ProcessScalar(self, tag, wall_time, step, scalar): diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py index ada914bdf39..9a92008a0da 100644 --- a/tensorflow/python/summary/event_accumulator_test.py +++ b/tensorflow/python/summary/event_accumulator_test.py @@ -41,34 +41,50 @@ class _EventGenerator(object): def AddScalar(self, tag, wall_time=0, step=0, value=0): event = tf.Event( - wall_time=wall_time, step=step, - summary=tf.Summary( - value=[tf.Summary.Value(tag=tag, simple_value=value)] - ) - ) + wall_time=wall_time, + step=step, + summary=tf.Summary(value=[tf.Summary.Value(tag=tag, + simple_value=value)])) self.AddEvent(event) - def AddHistogram(self, tag, wall_time=0, step=0, hmin=1, hmax=2, hnum=3, - hsum=4, hsum_squares=5, hbucket_limit=None, hbucket=None): - histo = tf.HistogramProto(min=hmin, max=hmax, num=hnum, sum=hsum, + def AddHistogram(self, + tag, + wall_time=0, + step=0, + hmin=1, + hmax=2, + hnum=3, + hsum=4, + hsum_squares=5, + hbucket_limit=None, + hbucket=None): + histo = tf.HistogramProto(min=hmin, + max=hmax, + num=hnum, + sum=hsum, sum_squares=hsum_squares, bucket_limit=hbucket_limit, bucket=hbucket) - event = tf.Event( - wall_time=wall_time, - step=step, - summary=tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)])) + event = tf.Event(wall_time=wall_time, + step=step, + summary=tf.Summary(value=[tf.Summary.Value(tag=tag, + histo=histo)])) self.AddEvent(event) - def AddImage(self, tag, wall_time=0, step=0, encoded_image_string=b'imgstr', - width=150, height=100): + def AddImage(self, + tag, + wall_time=0, + step=0, + encoded_image_string=b'imgstr', + width=150, + height=100): image = tf.Summary.Image(encoded_image_string=encoded_image_string, - width=width, height=height) - event = tf.Event( - wall_time=wall_time, - step=step, - summary=tf.Summary( - value=[tf.Summary.Value(tag=tag, image=image)])) + width=width, + height=height) + event = tf.Event(wall_time=wall_time, + step=step, + summary=tf.Summary(value=[tf.Summary.Value(tag=tag, + image=image)])) self.AddEvent(event) def AddEvent(self, event): @@ -103,9 +119,11 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): ea.GRAPH: False} self._real_constructor = ea.EventAccumulator self._real_generator = ea._GeneratorFromPath + def _FakeAccumulatorConstructor(generator, *args, **kwargs): ea._GeneratorFromPath = lambda x: generator return self._real_constructor(generator, *args, **kwargs) + ea.EventAccumulator = _FakeAccumulatorConstructor def tearDown(self): @@ -129,13 +147,13 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): gen.AddImage('im2') acc = ea.EventAccumulator(gen) acc.Reload() - self.assertTagsEqual( - acc.Tags(), { - ea.IMAGES: ['im1', 'im2'], - ea.SCALARS: ['s1', 's2'], - ea.HISTOGRAMS: ['hst1', 'hst2'], - ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], - ea.GRAPH: False}) + self.assertTagsEqual(acc.Tags(), { + ea.IMAGES: ['im1', 'im2'], + ea.SCALARS: ['s1', 's2'], + ea.HISTOGRAMS: ['hst1', 
'hst2'], + ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], + ea.GRAPH: False + }) def testReload(self): gen = _EventGenerator() @@ -155,7 +173,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): ea.SCALARS: ['s1', 's2'], ea.HISTOGRAMS: ['hst1', 'hst2'], ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], - ea.GRAPH: False}) + ea.GRAPH: False + }) def testScalars(self): gen = _EventGenerator() @@ -172,18 +191,42 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): gen = _EventGenerator() acc = ea.EventAccumulator(gen) - val1 = ea.HistogramValue(min=1, max=2, num=3, sum=4, sum_squares=5, - bucket_limit=[1, 2, 3], bucket=[0, 3, 0]) - val2 = ea.HistogramValue(min=-2, max=3, num=4, sum=5, sum_squares=6, - bucket_limit=[2, 3, 4], bucket=[1, 3, 0]) + val1 = ea.HistogramValue(min=1, + max=2, + num=3, + sum=4, + sum_squares=5, + bucket_limit=[1, 2, 3], + bucket=[0, 3, 0]) + val2 = ea.HistogramValue(min=-2, + max=3, + num=4, + sum=5, + sum_squares=6, + bucket_limit=[2, 3, 4], + bucket=[1, 3, 0]) hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1) hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2) - gen.AddHistogram('hst1', wall_time=1, step=10, hmin=1, hmax=2, hnum=3, - hsum=4, hsum_squares=5, hbucket_limit=[1, 2, 3], + gen.AddHistogram('hst1', + wall_time=1, + step=10, + hmin=1, + hmax=2, + hnum=3, + hsum=4, + hsum_squares=5, + hbucket_limit=[1, 2, 3], hbucket=[0, 3, 0]) - gen.AddHistogram('hst2', wall_time=2, step=12, hmin=-2, hmax=3, hnum=4, - hsum=5, hsum_squares=6, hbucket_limit=[2, 3, 4], + gen.AddHistogram('hst2', + wall_time=2, + step=12, + hmin=-2, + hmax=3, + hnum=4, + hsum=5, + hsum_squares=6, + hbucket_limit=[2, 3, 4], hbucket=[1, 3, 0]) acc.Reload() self.assertEqual(acc.Histograms('hst1'), [hst1]) @@ -193,17 +236,32 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): gen = _EventGenerator() acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000)) - gen.AddHistogram('hst1', wall_time=1, step=10, hmin=1, hmax=2, hnum=3, - hsum=4, hsum_squares=5, hbucket_limit=[1, 2, 3], + gen.AddHistogram('hst1', + wall_time=1, + step=10, + hmin=1, + hmax=2, + hnum=3, + hsum=4, + hsum_squares=5, + hbucket_limit=[1, 2, 3], hbucket=[0, 3, 0]) - gen.AddHistogram('hst2', wall_time=2, step=12, hmin=-2, hmax=3, hnum=4, - hsum=5, hsum_squares=6, hbucket_limit=[2, 3, 4], + gen.AddHistogram('hst2', + wall_time=2, + step=12, + hmin=-2, + hmax=3, + hnum=4, + hsum=5, + hsum_squares=6, + hbucket_limit=[2, 3, 4], hbucket=[1, 3, 0]) acc.Reload() # Create the expected values after compressing hst1 - expected_vals1 = [ea.CompressedHistogramValue(bp, val) for bp, val in [( - 0, 1.0), (2500, 1.25), (5000, 1.5), (7500, 1.75), (10000, 2.0)]] + expected_vals1 = [ea.CompressedHistogramValue(bp, val) + for bp, val in [(0, 1.0), (2500, 1.25), (5000, 1.5), ( + 7500, 1.75), (10000, 2.0)]] expected_cmphst1 = ea.CompressedHistogramEvent( wall_time=1, step=10, @@ -225,8 +283,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): def testPercentile(self): def AssertExpectedForBps(bps, expected): - output = acc._Percentile( - bps, bucket_limit, cumsum_weights, histo_min, histo_max, histo_num) + output = acc._Percentile(bps, bucket_limit, cumsum_weights, histo_min, + histo_max, histo_num) self.assertAlmostEqual(expected, output) gen = _EventGenerator() @@ -311,14 +369,28 @@ class MockingEventAccumulatorTest(EventAccumulatorTest): def testImages(self): gen = _EventGenerator() acc = ea.EventAccumulator(gen) - im1 = ea.ImageEvent(wall_time=1, step=10, 
encoded_image_string=b'big',
-                        width=400, height=300)
-    im2 = ea.ImageEvent(wall_time=2, step=12, encoded_image_string=b'small',
-                        width=40, height=30)
-    gen.AddImage('im1', wall_time=1, step=10, encoded_image_string=b'big',
-                 width=400, height=300)
-    gen.AddImage('im2', wall_time=2, step=12, encoded_image_string=b'small',
-                 width=40, height=30)
+    im1 = ea.ImageEvent(wall_time=1,
+                        step=10,
+                        encoded_image_string=b'big',
+                        width=400,
+                        height=300)
+    im2 = ea.ImageEvent(wall_time=2,
+                        step=12,
+                        encoded_image_string=b'small',
+                        width=40,
+                        height=30)
+    gen.AddImage('im1',
+                 wall_time=1,
+                 step=10,
+                 encoded_image_string=b'big',
+                 width=400,
+                 height=300)
+    gen.AddImage('im2',
+                 wall_time=2,
+                 step=12,
+                 encoded_image_string=b'small',
+                 width=40,
+                 height=30)
     acc.Reload()
     self.assertEqual(acc.Images('im1'), [im1])
     self.assertEqual(acc.Images('im2'), [im2])
@@ -359,8 +431,7 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
     gen = _EventGenerator()
     acc = ea.EventAccumulator(gen)
     gen.AddScalar('s1', wall_time=1, step=10, value=20)
-    gen.AddEvent(tf.Event(
-        wall_time=2, step=20, file_version='nots2'))
+    gen.AddEvent(tf.Event(wall_time=2, step=20, file_version='nots2'))
     gen.AddScalar('s3', wall_time=3, step=100, value=1)
     gen.AddHistogram('hst1')
     gen.AddImage('im1')
@@ -371,7 +442,8 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
         ea.SCALARS: ['s1', 's3'],
         ea.HISTOGRAMS: ['hst1'],
         ea.COMPRESSED_HISTOGRAMS: ['hst1'],
-        ea.GRAPH: False})
+        ea.GRAPH: False
+    })

   def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self):
     """Tests that events are discarded after a restart is detected.

@@ -404,6 +476,28 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
     ## Check that we have discarded 200 and 300 from s1
     self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])

+  def testOrphanedDataNotDiscardedIfFlagUnset(self):
+    """Tests that events are not discarded if purge_orphaned_data is false.
+    """
+    gen = _EventGenerator()
+    acc = ea.EventAccumulator(gen, purge_orphaned_data=False)
+
+    gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1'))
+    gen.AddScalar('s1', wall_time=1, step=100, value=20)
+    gen.AddScalar('s1', wall_time=1, step=200, value=20)
+    gen.AddScalar('s1', wall_time=1, step=300, value=20)
+    acc.Reload()
+    ## Check that the number of items is what it should be
+    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300])
+
+    gen.AddScalar('s1', wall_time=1, step=101, value=20)
+    gen.AddScalar('s1', wall_time=1, step=201, value=20)
+    gen.AddScalar('s1', wall_time=1, step=301, value=20)
+    acc.Reload()
+    ## Check that we did NOT discard 200 and 300 from s1
+    self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300, 101,
+                                                           201, 301])
+
   def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self):
     """Tests that event discards after restart, only affect the misordered tag.
@@ -486,6 +580,7 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): def testScalarsRealistically(self): """Test accumulator by writing values and then reading them.""" + def FakeScalarSummary(tag, value): value = tf.Summary.Value(tag=tag, simple_value=value) summary = tf.Summary(value=[value]) @@ -504,9 +599,9 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): # Write a bunch of events using the writer for i in xrange(30): summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i*i) - writer.add_summary(summ_id, i*5) - writer.add_summary(summ_sq, i*5) + summ_sq = FakeScalarSummary('sq', i * i) + writer.add_summary(summ_id, i * 5) + writer.add_summary(summ_sq, i * 5) writer.flush() # Verify that we can load those events properly @@ -517,23 +612,24 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): ea.SCALARS: ['id', 'sq'], ea.HISTOGRAMS: [], ea.COMPRESSED_HISTOGRAMS: [], - ea.GRAPH: True}) + ea.GRAPH: True + }) id_events = acc.Scalars('id') sq_events = acc.Scalars('sq') self.assertEqual(30, len(id_events)) self.assertEqual(30, len(sq_events)) for i in xrange(30): - self.assertEqual(i*5, id_events[i].step) - self.assertEqual(i*5, sq_events[i].step) + self.assertEqual(i * 5, id_events[i].step) + self.assertEqual(i * 5, sq_events[i].step) self.assertEqual(i, id_events[i].value) - self.assertEqual(i*i, sq_events[i].value) + self.assertEqual(i * i, sq_events[i].value) # Write a few more events to test incremental reloading for i in xrange(30, 40): summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i*i) - writer.add_summary(summ_id, i*5) - writer.add_summary(summ_sq, i*5) + summ_sq = FakeScalarSummary('sq', i * i) + writer.add_summary(summ_id, i * 5) + writer.add_summary(summ_sq, i * 5) writer.flush() # Verify we can now see all of the data @@ -541,10 +637,10 @@ class RealisticEventAccumulatorTest(EventAccumulatorTest): self.assertEqual(40, len(id_events)) self.assertEqual(40, len(sq_events)) for i in xrange(40): - self.assertEqual(i*5, id_events[i].step) - self.assertEqual(i*5, sq_events[i].step) + self.assertEqual(i * 5, id_events[i].step) + self.assertEqual(i * 5, sq_events[i].step) self.assertEqual(i, id_events[i].value) - self.assertEqual(i*i, sq_events[i].value) + self.assertEqual(i * i, sq_events[i].value) self.assertProtoEquals(graph_def, acc.Graph()) diff --git a/tensorflow/python/summary/event_multiplexer.py b/tensorflow/python/summary/event_multiplexer.py index c02a30926a7..4cdfb815940 100644 --- a/tensorflow/python/summary/event_multiplexer.py +++ b/tensorflow/python/summary/event_multiplexer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Provides an interface for working with multiple event files.""" from __future__ import absolute_import @@ -77,8 +76,10 @@ class EventMultiplexer(object): @@Images """ - def __init__(self, run_path_map=None, - size_guidance=event_accumulator.DEFAULT_SIZE_GUIDANCE): + def __init__(self, + run_path_map=None, + size_guidance=event_accumulator.DEFAULT_SIZE_GUIDANCE, + purge_orphaned_data=True): """Constructor for the `EventMultiplexer`. Args: @@ -88,12 +89,15 @@ class EventMultiplexer(object): size_guidance: A dictionary mapping from `tagType` to the number of items to store for each tag of that type. See `event_ccumulator.EventAccumulator` for details. 
+ purge_orphaned_data: Whether to discard any events that were "orphaned" by + a TensorFlow restart. """ self._accumulators_mutex = threading.Lock() self._accumulators = {} self._paths = {} self._reload_called = False self._size_guidance = size_guidance + self.purge_orphaned_data = purge_orphaned_data if run_path_map is not None: for (run, path) in six.iteritems(run_path_map): self.AddRun(path, run) @@ -129,8 +133,8 @@ class EventMultiplexer(object): logging.warning('Conflict for name %s: old path %s, new path %s', name, self._paths[name], path) logging.info('Constructing EventAccumulator for %s', path) - accumulator = event_accumulator.EventAccumulator(path, - self._size_guidance) + accumulator = event_accumulator.EventAccumulator( + path, self._size_guidance, self.purge_orphaned_data) self._accumulators[name] = accumulator self._paths[name] = path if accumulator: @@ -169,7 +173,7 @@ class EventMultiplexer(object): return # Maybe it hasn't been created yet, fail silently to retry later if not gfile.IsDirectory(path): raise ValueError('AddRunsFromDirectory: path exists and is not a ' - 'directory, %s' % path) + 'directory, %s' % path) for (subdir, _, files) in gfile.Walk(path): if list(filter(event_accumulator.IsTensorFlowEventsFile, files)): @@ -294,10 +298,7 @@ class EventMultiplexer(object): with self._accumulators_mutex: # To avoid nested locks, we construct a copy of the run-accumulator map items = list(six.iteritems(self._accumulators)) - return { - run_name: accumulator.Tags() - for run_name, accumulator in items - } + return {run_name: accumulator.Tags() for run_name, accumulator in items} def _GetAccumulator(self, run): with self._accumulators_mutex: diff --git a/tensorflow/tensorboard/scripts/serialize_tensorboard.py b/tensorflow/tensorboard/scripts/serialize_tensorboard.py index 707aa38c132..82d0cd1be94 100644 --- a/tensorflow/tensorboard/scripts/serialize_tensorboard.py +++ b/tensorflow/tensorboard/scripts/serialize_tensorboard.py @@ -48,6 +48,11 @@ will be written""") tf.flags.DEFINE_boolean('overwrite', False, """Whether to remove and overwrite TARGET if it already exists.""") +tf.flags.DEFINE_boolean( + 'purge_orphaned_data', True, 'Whether to purge data that ' + 'may have been orphaned due to TensorBoard restarts. ' + 'Disabling purge_orphaned_data can be used to debug data ' + 'disappearance') FLAGS = tf.flags.FLAGS BAD_CHARACTERS = "#%&{}\\/<>*? $!'\":@+`|=" @@ -161,7 +166,8 @@ def main(unused_argv=None): PrintAndLog('About to load Multiplexer. This may take some time.') multiplexer = event_multiplexer.EventMultiplexer( - size_guidance=server.TENSORBOARD_SIZE_GUIDANCE) + size_guidance=server.TENSORBOARD_SIZE_GUIDANCE, + purge_orphaned_data=FLAGS.purge_orphaned_data) server.ReloadMultiplexer(multiplexer, path_to_run) PrintAndLog('Multiplexer load finished. Starting TensorBoard server.') diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py index 690ec34c211..24cd993ab89 100644 --- a/tensorflow/tensorboard/tensorboard.py +++ b/tensorflow/tensorboard/tensorboard.py @@ -53,6 +53,11 @@ flags.DEFINE_string('host', '0.0.0.0', 'What host to listen to. Defaults to ' flags.DEFINE_integer('port', 6006, 'What port to serve TensorBoard on.') +flags.DEFINE_boolean('purge_orphaned_data', True, 'Whether to purge data that ' + 'may have been orphaned due to TensorBoard restarts. 
' + 'Disabling purge_orphaned_data can be used to debug data ' + 'disappearance') + FLAGS = flags.FLAGS @@ -73,7 +78,8 @@ def main(unused_argv=None): logging.info('TensorBoard path_to_run is: %s', path_to_run) multiplexer = event_multiplexer.EventMultiplexer( - size_guidance=server.TENSORBOARD_SIZE_GUIDANCE) + size_guidance=server.TENSORBOARD_SIZE_GUIDANCE, + purge_orphaned_data=FLAGS.purge_orphaned_data) server.StartMultiplexerReloadingThread(multiplexer, path_to_run) try: tb_server = server.BuildServer(multiplexer, FLAGS.host, FLAGS.port)
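
Usage sketch (not part of the patch): a minimal example of the new
EventAccumulator argument, assuming the module paths in this change;
'/tmp/logdir' is a placeholder for a real events directory. If TensorFlow
crashes at step T+O and restarts at step T, events above step T are
"orphaned"; the default purge_orphaned_data=True discards them on Reload(),
while False keeps them, which is the missing-data debugging scenario this
change targets.

    from tensorflow.python.summary import event_accumulator

    # Keep orphaned or out-of-order events instead of purging them, e.g. when
    # data seems to disappear from TensorBoard after a TensorFlow restart.
    acc = event_accumulator.EventAccumulator('/tmp/logdir',  # placeholder path
                                             purge_orphaned_data=False)
    acc.Reload()
    # With purging disabled, steps written before a crash (say 200 and 300)
    # survive even after the restarted job re-writes steps 101, 201, 301.
    print(acc.Tags())  # every loaded tag is still present; nothing was purged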
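The EventMultiplexer change threads the same flag through to each per-run
accumulator it constructs, so one argument disables purging for a whole
session. A sketch under the same assumptions (placeholder run name and path):

    from tensorflow.python.summary import event_multiplexer

    multiplexer = event_multiplexer.EventMultiplexer(
        run_path_map={'train': '/tmp/logdir/train'},  # placeholder run/path
        purge_orphaned_data=False)
    multiplexer.Reload()  # each run's accumulator now keeps orphaned events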
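From the command line, the new flag defaults to true, so existing behavior is
unchanged; passing --purge_orphaned_data=false to tensorboard.py (or to
serialize_tensorboard.py) should disable purging, assuming the usual
gflags-style boolean parsing of tf.flags (the exact accepted syntax is not
shown in this patch).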