Extend the Keras TensorBoard callback to optionally log the global steps per second observed during training.

PiperOrigin-RevId: 344210886
Change-Id: I36854954709ec29944696629ec67ef453e4dd2a7
Author: A. Unique TensorFlower
Date: 2020-11-25 01:43:42 -08:00
Committed by: TensorFlower Gardener
Parent: 2e9f721453
Commit: cef7da10e1
3 changed files with 59 additions and 2 deletions


@@ -669,7 +669,7 @@ class Callback(object):
epoch: Integer, index of epoch.
logs: Dict, metric results for this training epoch, and for the
validation epoch if validation is performed. Validation result keys
are prefixed with `val_`. For training epoch, the values of the
`Model`'s metrics are returned. Example: `{'loss': 0.2, 'acc': 0.7}`.
"""
@@ -2002,6 +2002,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
can become quite large when write_graph is set to True.
write_images: whether to write model weights to visualize as image in
TensorBoard.
write_steps_per_second: whether to log the training steps per second into
TensorBoard. This supports both epoch and batch frequency logging.
update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
writes the losses and metrics to TensorBoard after each batch. The same
applies for `'epoch'`. If using an integer, let's say `1000`, the
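A minimal usage sketch of the new flag (log directory, model, and data are illustrative, not part of this change; the test below pairs the flag with update_freq=1 to exercise the per-batch path):

    import tensorflow as tf

    # write_steps_per_second is the flag this commit adds; update_freq=1
    # (equivalent to 'batch') also emits the per-batch summaries.
    tb = tf.keras.callbacks.TensorBoard(
        log_dir='./logs',
        write_steps_per_second=True,
        update_freq=1)
    # model.fit(x, y, callbacks=[tb])  # attach like any other Keras callback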
@@ -2097,6 +2099,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
histogram_freq=0,
write_graph=True,
write_images=False,
write_steps_per_second=False,
update_freq='epoch',
profile_batch=2,
embeddings_freq=0,
@@ -2110,12 +2113,16 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
self.histogram_freq = histogram_freq
self.write_graph = write_graph
self.write_images = write_images
self.write_steps_per_second = write_steps_per_second
self.update_freq = 1 if update_freq == 'batch' else update_freq
self.embeddings_freq = embeddings_freq
self.embeddings_metadata = embeddings_metadata
self._init_profile_batch(profile_batch)
self._epoch = 0
self._global_train_batch = 0
self._previous_epoch_iterations = 0
self._train_accumulated_time = 0
self._batch_start_time = 0
# Lazily initialized in order to avoid creating event files when
# not needed.
@@ -2336,6 +2343,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
def on_train_begin(self, logs=None):
self._global_train_batch = 0
self._previous_epoch_iterations = 0
self._train_accumulated_time = 0
self._push_writer(self._train_writer, self._train_step)
def on_train_end(self, logs=None):
@@ -2358,6 +2367,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
def on_train_batch_begin(self, batch, logs=None):
self._global_train_batch += 1
if self.write_steps_per_second:
self._batch_start_time = time.time()
if not self._should_trace:
return
@@ -2368,6 +2379,10 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
if self._should_write_train_graph:
self._write_keras_model_train_graph()
self._should_write_train_graph = False
if self.write_steps_per_second:
batch_run_time = time.time() - self._batch_start_time
self._train_accumulated_time += batch_run_time
summary_ops_v2.scalar('batch_steps_per_second', 1. / batch_run_time)
if not self._should_trace:
return
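The batch-level figure comes from pairing a wall-clock timestamp taken in on_train_batch_begin with the elapsed time measured in on_train_batch_end; each batch is one optimizer step, so the instantaneous rate is simply 1. / batch_run_time, and the elapsed time is also accumulated for the epoch-level average. A self-contained sketch of that bookkeeping (class name and return value are illustrative, not the Keras implementation):

    import time

    class BatchTimer(object):
        """Illustrative stand-in for the callback's batch-timing state."""

        def __init__(self):
            self._batch_start_time = 0.
            self._train_accumulated_time = 0.

        def on_train_batch_begin(self):
            # Timestamp the start of the batch.
            self._batch_start_time = time.time()

        def on_train_batch_end(self):
            # Wall-clock time spent on this single training step.
            batch_run_time = time.time() - self._batch_start_time
            # Accumulated for the epoch-level average computed at epoch end.
            self._train_accumulated_time += batch_run_time
            return 1. / batch_run_time  # instantaneous steps per second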
@@ -2377,6 +2392,9 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
def on_epoch_begin(self, epoch, logs=None):
# Keeps track of epoch for profiling.
self._epoch = epoch
if self.write_steps_per_second:
self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
self._train_accumulated_time = 0
def on_epoch_end(self, epoch, logs=None):
"""Runs metrics and histogram summaries at epoch end."""
@@ -2410,6 +2428,12 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
return logs
def _compute_steps_per_second(self):
current_iteration = self.model.optimizer.iterations.numpy()
steps_per_second = ((current_iteration - self._previous_epoch_iterations) /
(self._train_accumulated_time))
return steps_per_second
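_compute_steps_per_second averages over the whole epoch: the numerator is how many optimizer steps the epoch advanced (the on_epoch_begin hook above snapshots optimizer.iterations and zeroes the accumulated time, so only the current epoch counts), and the denominator is the wall-clock time summed batch by batch. A numeric check of the formula, with invented values:

    # Invented numbers, purely to illustrate the arithmetic.
    current_iteration = 150          # optimizer.iterations at epoch end
    previous_epoch_iterations = 100  # snapshot taken at epoch begin
    train_accumulated_time = 25.0    # seconds summed over the epoch's batches

    steps_per_second = (
        current_iteration - previous_epoch_iterations) / train_accumulated_time
    assert steps_per_second == 2.0   # 50 steps in 25 seconds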
def _log_epoch_metrics(self, epoch, logs):
"""Writes epoch metrics out as scalar summaries.
@@ -2423,6 +2447,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
train_logs = self._collect_learning_rate(train_logs)
if self.write_steps_per_second:
train_logs['steps_per_second'] = self._compute_steps_per_second()
with summary_ops_v2.record_if(True):
if train_logs:
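The epoch hook splits the logs dict on the val_ prefix before injecting the computed rate into the training half; a minimal sketch of that filtering (dict contents invented):

    logs = {'loss': 0.2, 'val_loss': 0.3}  # invented example logs
    train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
    val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
    train_logs['steps_per_second'] = 2.0   # from _compute_steps_per_second
    assert train_logs == {'loss': 0.2, 'steps_per_second': 2.0}

Since the epoch writer prefixes training scalars with epoch_, the value should surface under the epoch_steps_per_second tag that the test below asserts.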


@@ -2029,6 +2029,37 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
},
)
def test_TensorBoard_global_step(self):
model = self._get_model(compile_model=False)
opt = gradient_descent.SGD(learning_rate_schedule.CosineDecay(0.01, 1))
model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
model.fit(
x,
y,
batch_size=2,
epochs=2,
callbacks=[
keras.callbacks.TensorBoard(
self.logdir, update_freq=1, write_steps_per_second=True)
])
summary_file = list_summaries(self.logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
_ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
_ObservedSummary(logdir=self.train_dir, tag='epoch_learning_rate'),
_ObservedSummary(
logdir=self.train_dir, tag='epoch_steps_per_second'),
_ObservedSummary(
logdir=self.train_dir, tag='batch_steps_per_second'),
},
)
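Outside the test harness, the scalars the callback wrote can be spot-checked by scanning its event files; a hedged sketch using tf.compat.v1.train.summary_iterator (the glob pattern is an assumption about the log directory layout):

    import glob
    import tensorflow as tf

    # Assumes the callback logged to ./logs; adjust the pattern as needed.
    for path in glob.glob('./logs/train/events.out.tfevents.*'):
        for event in tf.compat.v1.train.summary_iterator(path):
            for value in event.summary.value:
                if value.tag in ('batch_steps_per_second',
                                 'epoch_steps_per_second'):
                    print(event.step, value.tag)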
def test_TensorBoard_weight_histograms(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))


@@ -6,7 +6,7 @@ tf_class {
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'write_steps_per_second\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
}
member_method {
name: "on_batch_begin"