Add summary.writer before invoking trace_export

PiperOrigin-RevId: 236422330
Stephan Lee 2019-03-01 20:38:19 -08:00 committed by TensorFlower Gardener
parent d63fb26158
commit 876604a93b
2 changed files with 102 additions and 18 deletions
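For context, here is a minimal standalone sketch (not part of this commit; hypothetical logdir, public TF 2.x eager API) of the failure mode being fixed: trace_export writes through the default summary writer, so calling it with no writer active silently drops the trace.

    import tensorflow as tf

    @tf.function
    def square(x):
      return x * x

    logdir = '/tmp/logdir'  # hypothetical path
    writer = tf.summary.create_file_writer(logdir)

    tf.summary.trace_on(graph=True, profiler=True)
    square(tf.constant(2.0))

    # Without an active default writer (writer.as_default()), the exported
    # trace has nowhere to go. This commit wraps the Keras callback's
    # trace_export call in the train writer's context the same way.
    with writer.as_default():
      tf.summary.trace_export(name='square_trace', step=0,
                              profiler_outdir=logdir)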


@@ -1140,7 +1140,7 @@ class TensorBoard(Callback):
       your training.
     profile_batch: Profile the batch to sample compute characteristics. By
       default, it will profile the second batch. Set profile_batch=0 to
-      disable profiling.
+      disable profiling. Must run in TensorFlow eager mode.
 
   Raises:
     ValueError: If histogram_freq is set and no validation data is provided.
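As a usage sketch (hypothetical logdir and toy model, not part of this diff), disabling the profiler from the callback looks like:

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile('sgd', 'mse')
    x, y = np.ones((8, 4)), np.ones((8, 1))

    # profile_batch=0 turns profiling off entirely; the default (2)
    # profiles the second batch and requires eager execution.
    tb_cbk = tf.keras.callbacks.TensorBoard('/tmp/logdir', profile_batch=0)
    model.fit(x, y, batch_size=2, callbacks=[tb_cbk])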
@@ -1274,16 +1274,10 @@ class TensorBoard(Callback):
         self._samples_seen_at_last_write = self._samples_seen
     self._total_batches_seen += 1
     if self._is_tracing:
-      # TODO(b/126388999): Remove step info in the summary name.
-      summary_ops_v2.trace_export(
-          name='batch_%d' % self._total_batches_seen,
-          step=self._total_batches_seen,
-          profiler_outdir=self.log_dir)
-      self._is_tracing = False
+      self._log_trace()
     elif (not self._is_tracing and
           self._total_batches_seen == self._profile_batch - 1):
-      summary_ops_v2.trace_on(graph=True, profiler=True)
-      self._is_tracing = True
+      self._enable_trace()
 
   def on_epoch_end(self, epoch, logs=None):
     """Runs metrics and histogram summaries at epoch end."""
@@ -1294,13 +1288,24 @@ class TensorBoard(Callback):
       self._log_weights(epoch)
 
   def on_train_end(self, logs=None):
-    self._close_writers()
     if self._is_tracing:
-      # TODO(b/126388999): Remove step info in the summary name.
-      summary_ops_v2.trace_export(
-          name='batch_%d' % self._total_batches_seen,
-          step=self._total_batches_seen,
-          profiler_outdir=self.log_dir)
+      self._log_trace()
+    self._close_writers()
+
+  def _enable_trace(self):
+    if context.executing_eagerly():
+      summary_ops_v2.trace_on(graph=True, profiler=True)
+      self._is_tracing = True
+
+  def _log_trace(self):
+    if context.executing_eagerly():
+      with self._train_writer.as_default(), \
+          summary_ops_v2.always_record_summaries():
+        # TODO(b/126388999): Remove step info in the summary name.
+        summary_ops_v2.trace_export(
+            name='batch_%d' % self._total_batches_seen,
+            step=self._total_batches_seen,
+            profiler_outdir=self.log_dir)
+      self._is_tracing = False
 
   def _log_metrics(self, logs, prefix, step):


@@ -1236,17 +1236,28 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
 # Note that this test specifies model_type explicitly.
-class TestTensorBoardV2WriteModelTest(test.TestCase):
+class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
 
   def setUp(self):
-    super(TestTensorBoardV2WriteModelTest, self).setUp()
+    super(TestTensorBoardV2NonParameterizedTest, self).setUp()
     self.logdir = os.path.join(self.get_temp_dir(), 'tb')
     self.train_dir = os.path.join(self.logdir, 'train')
     self.validation_dir = os.path.join(self.logdir, 'validation')
 
+  def _get_seq_model(self):
+    model = keras.models.Sequential([
+        keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)),
+        keras.layers.Flatten(),
+        keras.layers.Dense(1),
+    ])
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    return model
+
   def fitModelAndAssertKerasModelWritten(self, model):
     x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
-    tb_cbk = keras.callbacks.TensorBoard(self.logdir, write_graph=True)
+    tb_cbk = keras.callbacks.TensorBoard(self.logdir,
+                                         write_graph=True,
+                                         profile_batch=0)
     model.fit(
         x,
         y,
@@ -1289,6 +1300,74 @@ class TestTensorBoardV2WriteModelTest(test.TestCase):
     model.compile('sgd', 'mse', run_eagerly=False)
     self.fitModelAndAssertKerasModelWritten(model)
 
+  # TODO(b/126944683): Put parameterization in the class when graph is fixed.
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_TensorBoard_autoTrace(self):
+    model = self._get_seq_model()
+    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
+    tb_cbk = keras.callbacks.TensorBoard(
+        self.logdir, histogram_freq=1, profile_batch=1, write_graph=False)
+
+    model.fit(
+        x,
+        y,
+        batch_size=2,
+        epochs=2,
+        validation_data=(x, y),
+        callbacks=[tb_cbk])
+    summary_file = list_summaries(self.logdir)
+
+    self.assertEqual(
+        summary_file.tensors,
+        {
+            _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
+        },
+    )
+
+  # TODO(b/126944683): Put parameterization in the class when graph is fixed.
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
+    model = self._get_seq_model()
+    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
+    tb_cbk = keras.callbacks.TensorBoard(
+        self.logdir, histogram_freq=1, profile_batch=2, write_graph=False)
+
+    model.fit(
+        x,
+        y,
+        batch_size=2,
+        epochs=2,
+        validation_data=(x, y),
+        callbacks=[tb_cbk])
+    summary_file = list_summaries(self.logdir)
+
+    self.assertEqual(
+        summary_file.tensors,
+        {
+            _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
+        },
+    )
+
+  # TODO(b/126944683): Put parameterization in the class when graph is fixed.
+  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
+  def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
+    model = self._get_seq_model()
+    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
+    tb_cbk = keras.callbacks.TensorBoard(
+        self.logdir, histogram_freq=1, profile_batch=10000, write_graph=False)
+
+    model.fit(
+        x,
+        y,
+        batch_size=2,
+        epochs=2,
+        validation_data=(x, y),
+        callbacks=[tb_cbk])
+    summary_file = list_summaries(self.logdir)
+
+    # Enabled trace only on the 10000th batch, thus it should be empty.
+    self.assertEmpty(summary_file.tensors)
 
 
 if __name__ == '__main__':
   test.main()
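To make the test expectations concrete, here is a standalone sketch of the callback's trace scheduling (a simplification written for this note, not the shipped implementation): trace_on fires one batch before profile_batch, and trace_export fires on the following batch, tagging it with the running batch count. This is why the exported tag is 'batch_<profile_batch>' and why profile_batch=10000 exports nothing in a 10-batch run.

    def traced_tags(total_batches, profile_batch):
      """Simulates on_batch_end: enable tracing one batch early, then
      export (and tag) on the next batch."""
      tags, is_tracing, batches_seen = [], False, 0
      for _ in range(total_batches):
        batches_seen += 1
        if is_tracing:
          tags.append('batch_%d' % batches_seen)  # trace_export step
          is_tracing = False
        elif batches_seen == profile_batch - 1:
          is_tracing = True                       # trace_on step
      return tags

    assert traced_tags(10, 2) == ['batch_2']  # tagNameWithBatchNum case
    assert traced_tags(10, 10000) == []       # largerThanBatchCount case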