Extend the Keras TensorBoard callback to optionally log the global steps per second observed during training.
PiperOrigin-RevId: 344210886
Change-Id: I36854954709ec29944696629ec67ef453e4dd2a7
parent 2e9f721453
commit cef7da10e1
Changed paths:
  tensorflow/python/keras
  tensorflow/tools/api/golden/v2
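For orientation before the diff: `write_steps_per_second` is a new opt-in
constructor argument on `tf.keras.callbacks.TensorBoard`. A minimal usage
sketch (the model, data, and `logs` directory are illustrative, not part of
this change):

import tensorflow as tf

# Toy model and data, purely illustrative.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir='logs',
    update_freq=1,                # also log at batch frequency
    write_steps_per_second=True)  # the flag added by this change

x, y = tf.ones((32, 4)), tf.ones((32, 1))
model.fit(x, y, batch_size=8, epochs=2, callbacks=[tensorboard])
# TensorBoard's Scalars tab then shows 'epoch_steps_per_second' and, because
# update_freq=1, 'batch_steps_per_second'.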
@@ -669,7 +669,7 @@ class Callback(object):
         epoch: Integer, index of epoch.
         logs: Dict, metric results for this training epoch, and for the
           validation epoch if validation is performed. Validation result keys
           are prefixed with `val_`. For training epoch, the values of the
           `Model`'s metrics are returned. Example: `{'loss': 0.2, 'acc': 0.7}`.
     """
@@ -2002,6 +2002,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
         can become quite large when write_graph is set to True.
       write_images: whether to write model weights to visualize as image in
         TensorBoard.
+      write_steps_per_second: whether to log the training steps per second into
+        TensorBoard. This supports both epoch and batch frequency logging.
       update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
         writes the losses and metrics to TensorBoard after each batch. The same
         applies for `'epoch'`. If using an integer, let's say `1000`, the
@@ -2097,6 +2099,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
                histogram_freq=0,
                write_graph=True,
                write_images=False,
+               write_steps_per_second=False,
                update_freq='epoch',
                profile_batch=2,
                embeddings_freq=0,
@@ -2110,12 +2113,16 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
     self.histogram_freq = histogram_freq
     self.write_graph = write_graph
     self.write_images = write_images
+    self.write_steps_per_second = write_steps_per_second
     self.update_freq = 1 if update_freq == 'batch' else update_freq
     self.embeddings_freq = embeddings_freq
     self.embeddings_metadata = embeddings_metadata
     self._init_profile_batch(profile_batch)
     self._epoch = 0
     self._global_train_batch = 0
+    self._previous_epoch_iterations = 0
+    self._train_accumulated_time = 0
+    self._batch_start_time = 0

     # Lazily initialized in order to avoid creating event files when
     # not needed.
@@ -2336,6 +2343,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):

   def on_train_begin(self, logs=None):
     self._global_train_batch = 0
+    self._previous_epoch_iterations = 0
+    self._train_accumulated_time = 0
     self._push_writer(self._train_writer, self._train_step)

   def on_train_end(self, logs=None):
@@ -2358,6 +2367,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):

   def on_train_batch_begin(self, batch, logs=None):
     self._global_train_batch += 1
+    if self.write_steps_per_second:
+      self._batch_start_time = time.time()
     if not self._should_trace:
       return

@@ -2368,6 +2379,10 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
     if self._should_write_train_graph:
       self._write_keras_model_train_graph()
       self._should_write_train_graph = False
+    if self.write_steps_per_second:
+      batch_run_time = time.time() - self._batch_start_time
+      self._train_accumulated_time += batch_run_time
+      summary_ops_v2.scalar('batch_steps_per_second', 1. / batch_run_time)
     if not self._should_trace:
       return

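The batch-frequency signal written above is just the reciprocal of one train
step's wall-clock time, measured between `on_train_batch_begin` and
`on_train_batch_end`. A standalone sketch of the same bookkeeping (the
`StepTimer` name is illustrative, not part of this change):

import time

class StepTimer(object):
  """Mirrors the callback's per-batch timing."""

  def __init__(self):
    self.accumulated_time = 0.0  # counterpart of _train_accumulated_time
    self._batch_start_time = 0.0

  def batch_begin(self):
    self._batch_start_time = time.time()

  def batch_end(self):
    batch_run_time = time.time() - self._batch_start_time
    self.accumulated_time += batch_run_time
    return 1. / batch_run_time  # instantaneous batch_steps_per_second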
@@ -2377,6 +2392,9 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
   def on_epoch_begin(self, epoch, logs=None):
     # Keeps track of epoch for profiling.
     self._epoch = epoch
+    if self.write_steps_per_second:
+      self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
+      self._train_accumulated_time = 0

   def on_epoch_end(self, epoch, logs=None):
     """Runs metrics and histogram summaries at epoch end."""
@@ -2410,6 +2428,12 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
       logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
     return logs

+  def _compute_steps_per_second(self):
+    current_iteration = self.model.optimizer.iterations.numpy()
+    steps_per_second = ((current_iteration - self._previous_epoch_iterations) /
+                        (self._train_accumulated_time))
+    return steps_per_second
+
   def _log_epoch_metrics(self, epoch, logs):
     """Writes epoch metrics out as scalar summaries.

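The epoch-frequency value divides the optimizer iterations completed during
the epoch by the accumulated per-batch run time, so time spent between batches
(data loading, other callbacks) does not dilute the rate. A hypothetical
worked example of the arithmetic in `_compute_steps_per_second`:

# Hypothetical numbers: an epoch that ran 5 optimizer steps whose batch
# bodies took 0.25 s of wall-clock time in total.
previous_epoch_iterations = 100  # optimizer.iterations at epoch begin
current_iteration = 105          # optimizer.iterations at epoch end
train_accumulated_time = 0.25    # sum of per-batch run times, in seconds

steps_per_second = ((current_iteration - previous_epoch_iterations) /
                    train_accumulated_time)
print(steps_per_second)  # 20.0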
@@ -2423,6 +2447,8 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
     train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
     val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
     train_logs = self._collect_learning_rate(train_logs)
+    if self.write_steps_per_second:
+      train_logs['steps_per_second'] = self._compute_steps_per_second()

     with summary_ops_v2.record_if(True):
       if train_logs:
@@ -2029,6 +2029,37 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
         },
     )

+  def test_TensorBoard_global_step(self):
+    model = self._get_model(compile_model=False)
+    opt = gradient_descent.SGD(learning_rate_schedule.CosineDecay(0.01, 1))
+    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+
+    x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
+
+    model.fit(
+        x,
+        y,
+        batch_size=2,
+        epochs=2,
+        callbacks=[
+            keras.callbacks.TensorBoard(
+                self.logdir, update_freq=1, write_steps_per_second=True)
+        ])
+
+    summary_file = list_summaries(self.logdir)
+    self.assertEqual(
+        summary_file.scalars,
+        {
+            _ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
+            _ObservedSummary(logdir=self.train_dir, tag='batch_loss'),
+            _ObservedSummary(logdir=self.train_dir, tag='epoch_learning_rate'),
+            _ObservedSummary(
+                logdir=self.train_dir, tag='epoch_steps_per_second'),
+            _ObservedSummary(
+                logdir=self.train_dir, tag='batch_steps_per_second'),
+        },
+    )
+
   def test_TensorBoard_weight_histograms(self):
     model = self._get_model()
     x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
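Outside the test harness, the same tags can be checked with TensorBoard's
event-file reader; a sketch, assuming the `logs/train` directory produced by
the usage example near the top of this page:

from tensorboard.backend.event_processing import event_accumulator

acc = event_accumulator.EventAccumulator('logs/train')
acc.Reload()  # parse the event files on disk
print(acc.Tags()['scalars'])
# Expected to include 'epoch_steps_per_second' and, with batch-frequency
# logging enabled, 'batch_steps_per_second'.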
@@ -6,7 +6,7 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
+    argspec: "args=[\'self\', \'log_dir\', \'histogram_freq\', \'write_graph\', \'write_images\', \'write_steps_per_second\', \'update_freq\', \'profile_batch\', \'embeddings_freq\', \'embeddings_metadata\'], varargs=None, keywords=kwargs, defaults=[\'logs\', \'0\', \'True\', \'False\', \'False\', \'epoch\', \'2\', \'0\', \'None\'], "
   }
   member_method {
     name: "on_batch_begin"