Add summary.writer before invoking trace_export
PiperOrigin-RevId: 236422330
This commit is contained in:
parent
d63fb26158
commit
876604a93b
@ -1140,7 +1140,7 @@ class TensorBoard(Callback):
|
|||||||
your training.
|
your training.
|
||||||
profile_batch: Profile the batch to sample compute characteristics. By
|
profile_batch: Profile the batch to sample compute characteristics. By
|
||||||
default, it will profile the second batch. Set profile_batch=0 to
|
default, it will profile the second batch. Set profile_batch=0 to
|
||||||
disable profiling.
|
disable profiling. Must run in TensorFlow eager mode.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If histogram_freq is set and no validation data is provided.
|
ValueError: If histogram_freq is set and no validation data is provided.
|
||||||
@ -1274,16 +1274,10 @@ class TensorBoard(Callback):
|
|||||||
self._samples_seen_at_last_write = self._samples_seen
|
self._samples_seen_at_last_write = self._samples_seen
|
||||||
self._total_batches_seen += 1
|
self._total_batches_seen += 1
|
||||||
if self._is_tracing:
|
if self._is_tracing:
|
||||||
# TODO(b/126388999): Remove step info in the summary name.
|
self._log_trace()
|
||||||
summary_ops_v2.trace_export(
|
|
||||||
name='batch_%d' % self._total_batches_seen,
|
|
||||||
step=self._total_batches_seen,
|
|
||||||
profiler_outdir=self.log_dir)
|
|
||||||
self._is_tracing = False
|
|
||||||
elif (not self._is_tracing and
|
elif (not self._is_tracing and
|
||||||
self._total_batches_seen == self._profile_batch - 1):
|
self._total_batches_seen == self._profile_batch - 1):
|
||||||
summary_ops_v2.trace_on(graph=True, profiler=True)
|
self._enable_trace()
|
||||||
self._is_tracing = True
|
|
||||||
|
|
||||||
def on_epoch_end(self, epoch, logs=None):
|
def on_epoch_end(self, epoch, logs=None):
|
||||||
"""Runs metrics and histogram summaries at epoch end."""
|
"""Runs metrics and histogram summaries at epoch end."""
|
||||||
@ -1294,8 +1288,19 @@ class TensorBoard(Callback):
|
|||||||
self._log_weights(epoch)
|
self._log_weights(epoch)
|
||||||
|
|
||||||
def on_train_end(self, logs=None):
|
def on_train_end(self, logs=None):
|
||||||
self._close_writers()
|
|
||||||
if self._is_tracing:
|
if self._is_tracing:
|
||||||
|
self._log_trace()
|
||||||
|
self._close_writers()
|
||||||
|
|
||||||
|
def _enable_trace(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
summary_ops_v2.trace_on(graph=True, profiler=True)
|
||||||
|
self._is_tracing = True
|
||||||
|
|
||||||
|
def _log_trace(self):
|
||||||
|
if context.executing_eagerly():
|
||||||
|
with self._train_writer.as_default(), \
|
||||||
|
summary_ops_v2.always_record_summaries():
|
||||||
# TODO(b/126388999): Remove step info in the summary name.
|
# TODO(b/126388999): Remove step info in the summary name.
|
||||||
summary_ops_v2.trace_export(
|
summary_ops_v2.trace_export(
|
||||||
name='batch_%d' % self._total_batches_seen,
|
name='batch_%d' % self._total_batches_seen,
|
||||||
|
@ -1236,17 +1236,28 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
# Note that this test specifies model_type explicitly.
|
# Note that this test specifies model_type explicitly.
|
||||||
class TestTensorBoardV2WriteModelTest(test.TestCase):
|
class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(TestTensorBoardV2WriteModelTest, self).setUp()
|
super(TestTensorBoardV2NonParameterizedTest, self).setUp()
|
||||||
self.logdir = os.path.join(self.get_temp_dir(), 'tb')
|
self.logdir = os.path.join(self.get_temp_dir(), 'tb')
|
||||||
self.train_dir = os.path.join(self.logdir, 'train')
|
self.train_dir = os.path.join(self.logdir, 'train')
|
||||||
self.validation_dir = os.path.join(self.logdir, 'validation')
|
self.validation_dir = os.path.join(self.logdir, 'validation')
|
||||||
|
|
||||||
|
def _get_seq_model(self):
|
||||||
|
model = keras.models.Sequential([
|
||||||
|
keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)),
|
||||||
|
keras.layers.Flatten(),
|
||||||
|
keras.layers.Dense(1),
|
||||||
|
])
|
||||||
|
model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
|
||||||
|
return model
|
||||||
|
|
||||||
def fitModelAndAssertKerasModelWritten(self, model):
|
def fitModelAndAssertKerasModelWritten(self, model):
|
||||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||||
tb_cbk = keras.callbacks.TensorBoard(self.logdir, write_graph=True)
|
tb_cbk = keras.callbacks.TensorBoard(self.logdir,
|
||||||
|
write_graph=True,
|
||||||
|
profile_batch=0)
|
||||||
model.fit(
|
model.fit(
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
@ -1289,6 +1300,74 @@ class TestTensorBoardV2WriteModelTest(test.TestCase):
|
|||||||
model.compile('sgd', 'mse', run_eagerly=False)
|
model.compile('sgd', 'mse', run_eagerly=False)
|
||||||
self.fitModelAndAssertKerasModelWritten(model)
|
self.fitModelAndAssertKerasModelWritten(model)
|
||||||
|
|
||||||
|
# TODO(b/126944683): Put parameterization in the class when graph is fixed.
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
|
def test_TensorBoard_autoTrace(self):
|
||||||
|
model = self._get_seq_model()
|
||||||
|
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||||
|
tb_cbk = keras.callbacks.TensorBoard(
|
||||||
|
self.logdir, histogram_freq=1, profile_batch=1, write_graph=False)
|
||||||
|
|
||||||
|
model.fit(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
batch_size=2,
|
||||||
|
epochs=2,
|
||||||
|
validation_data=(x, y),
|
||||||
|
callbacks=[tb_cbk])
|
||||||
|
summary_file = list_summaries(self.logdir)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
summary_file.tensors,
|
||||||
|
{
|
||||||
|
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(b/126944683): Put parameterization in the class when graph is fixed.
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
|
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
|
||||||
|
model = self._get_seq_model()
|
||||||
|
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||||
|
tb_cbk = keras.callbacks.TensorBoard(
|
||||||
|
self.logdir, histogram_freq=1, profile_batch=2, write_graph=False)
|
||||||
|
|
||||||
|
model.fit(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
batch_size=2,
|
||||||
|
epochs=2,
|
||||||
|
validation_data=(x, y),
|
||||||
|
callbacks=[tb_cbk])
|
||||||
|
summary_file = list_summaries(self.logdir)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
summary_file.tensors,
|
||||||
|
{
|
||||||
|
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(b/126944683): Put parameterization in the class when graph is fixed.
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
|
def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
|
||||||
|
model = self._get_seq_model()
|
||||||
|
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||||
|
tb_cbk = keras.callbacks.TensorBoard(
|
||||||
|
self.logdir, histogram_freq=1, profile_batch=10000, write_graph=False)
|
||||||
|
|
||||||
|
model.fit(
|
||||||
|
x,
|
||||||
|
y,
|
||||||
|
batch_size=2,
|
||||||
|
epochs=2,
|
||||||
|
validation_data=(x, y),
|
||||||
|
callbacks=[tb_cbk])
|
||||||
|
summary_file = list_summaries(self.logdir)
|
||||||
|
|
||||||
|
# Enabled trace only on the 10000th batch, thus it should be empty.
|
||||||
|
self.assertEmpty(summary_file.tensors)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user