Add summary.writer before invoking trace_export
PiperOrigin-RevId: 236422330
This commit is contained in:
parent
d63fb26158
commit
876604a93b
@ -1140,7 +1140,7 @@ class TensorBoard(Callback):
|
||||
your training.
|
||||
profile_batch: Profile the batch to sample compute characteristics. By
|
||||
default, it will profile the second batch. Set profile_batch=0 to
|
||||
disable profiling.
|
||||
disable profiling. Profiling requires TensorFlow eager mode.
|
||||
|
||||
Raises:
|
||||
ValueError: If histogram_freq is set and no validation data is provided.
|
||||
@ -1274,16 +1274,10 @@ class TensorBoard(Callback):
|
||||
self._samples_seen_at_last_write = self._samples_seen
|
||||
self._total_batches_seen += 1
|
||||
if self._is_tracing:
|
||||
# TODO(b/126388999): Remove step info in the summary name.
|
||||
summary_ops_v2.trace_export(
|
||||
name='batch_%d' % self._total_batches_seen,
|
||||
step=self._total_batches_seen,
|
||||
profiler_outdir=self.log_dir)
|
||||
self._is_tracing = False
|
||||
self._log_trace()
|
||||
elif (not self._is_tracing and
|
||||
self._total_batches_seen == self._profile_batch - 1):
|
||||
summary_ops_v2.trace_on(graph=True, profiler=True)
|
||||
self._is_tracing = True
|
||||
self._enable_trace()
|
||||
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
"""Runs metrics and histogram summaries at epoch end."""
|
||||
@ -1294,13 +1288,24 @@ class TensorBoard(Callback):
|
||||
self._log_weights(epoch)
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
self._close_writers()
|
||||
if self._is_tracing:
|
||||
# TODO(b/126388999): Remove step info in the summary name.
|
||||
summary_ops_v2.trace_export(
|
||||
name='batch_%d' % self._total_batches_seen,
|
||||
step=self._total_batches_seen,
|
||||
profiler_outdir=self.log_dir)
|
||||
self._log_trace()
|
||||
self._close_writers()
|
||||
|
||||
def _enable_trace(self):
|
||||
if context.executing_eagerly():
|
||||
summary_ops_v2.trace_on(graph=True, profiler=True)
|
||||
self._is_tracing = True
|
||||
|
||||
def _log_trace(self):
|
||||
if context.executing_eagerly():
|
||||
with self._train_writer.as_default(), \
|
||||
summary_ops_v2.always_record_summaries():
|
||||
# TODO(b/126388999): Remove step info in the summary name.
|
||||
summary_ops_v2.trace_export(
|
||||
name='batch_%d' % self._total_batches_seen,
|
||||
step=self._total_batches_seen,
|
||||
profiler_outdir=self.log_dir)
|
||||
self._is_tracing = False
|
||||
|
||||
def _log_metrics(self, logs, prefix, step):
|
||||
|
@ -1236,17 +1236,28 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
||||
|
||||
|
||||
# Note that this test specifies model_type explicitly.
|
||||
class TestTensorBoardV2WriteModelTest(test.TestCase):
|
||||
class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestTensorBoardV2WriteModelTest, self).setUp()
|
||||
super(TestTensorBoardV2NonParameterizedTest, self).setUp()
|
||||
self.logdir = os.path.join(self.get_temp_dir(), 'tb')
|
||||
self.train_dir = os.path.join(self.logdir, 'train')
|
||||
self.validation_dir = os.path.join(self.logdir, 'validation')
|
||||
|
||||
def _get_seq_model(self):
|
||||
model = keras.models.Sequential([
|
||||
keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)),
|
||||
keras.layers.Flatten(),
|
||||
keras.layers.Dense(1),
|
||||
])
|
||||
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||
return model
|
||||
|
||||
def fitModelAndAssertKerasModelWritten(self, model):
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
tb_cbk = keras.callbacks.TensorBoard(self.logdir, write_graph=True)
|
||||
tb_cbk = keras.callbacks.TensorBoard(self.logdir,
|
||||
write_graph=True,
|
||||
profile_batch=0)
|
||||
model.fit(
|
||||
x,
|
||||
y,
|
||||
@ -1289,6 +1300,74 @@ class TestTensorBoardV2WriteModelTest(test.TestCase):
|
||||
model.compile('sgd', 'mse', run_eagerly=False)
|
||||
self.fitModelAndAssertKerasModelWritten(model)
|
||||
|
||||
# TODO(b/126944683): Put parameterization in the class when graph is fixed.
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_TensorBoard_autoTrace(self):
|
||||
model = self._get_seq_model()
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
tb_cbk = keras.callbacks.TensorBoard(
|
||||
self.logdir, histogram_freq=1, profile_batch=1, write_graph=False)
|
||||
|
||||
model.fit(
|
||||
x,
|
||||
y,
|
||||
batch_size=2,
|
||||
epochs=2,
|
||||
validation_data=(x, y),
|
||||
callbacks=[tb_cbk])
|
||||
summary_file = list_summaries(self.logdir)
|
||||
|
||||
self.assertEqual(
|
||||
summary_file.tensors,
|
||||
{
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
|
||||
},
|
||||
)
|
||||
|
||||
# TODO(b/126944683): Put parameterization in the class when graph is fixed.
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
|
||||
model = self._get_seq_model()
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
tb_cbk = keras.callbacks.TensorBoard(
|
||||
self.logdir, histogram_freq=1, profile_batch=2, write_graph=False)
|
||||
|
||||
model.fit(
|
||||
x,
|
||||
y,
|
||||
batch_size=2,
|
||||
epochs=2,
|
||||
validation_data=(x, y),
|
||||
callbacks=[tb_cbk])
|
||||
summary_file = list_summaries(self.logdir)
|
||||
|
||||
self.assertEqual(
|
||||
summary_file.tensors,
|
||||
{
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
||||
},
|
||||
)
|
||||
|
||||
# TODO(b/126944683): Put parameterization in the class when graph is fixed.
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
|
||||
model = self._get_seq_model()
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
tb_cbk = keras.callbacks.TensorBoard(
|
||||
self.logdir, histogram_freq=1, profile_batch=10000, write_graph=False)
|
||||
|
||||
model.fit(
|
||||
x,
|
||||
y,
|
||||
batch_size=2,
|
||||
epochs=2,
|
||||
validation_data=(x, y),
|
||||
callbacks=[tb_cbk])
|
||||
summary_file = list_summaries(self.logdir)
|
||||
|
||||
# Tracing is enabled only at the 10000th batch, which is never reached, so no trace summary should be written.
|
||||
self.assertEmpty(summary_file.tensors)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user