Create Variables to track mini-batches seen in Model.fit / evaluate / predict.
Use these counters in the TensorBoard Callback.

PiperOrigin-RevId: 301459298
Change-Id: I8e92e119ef4cef37c41532d11caefd601d4395a7

This commit is contained in: parent 1cea2490cb · commit 69565ec400
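
In short: `Model` now owns three untracked int64 `Variable` counters
(`_train_counter`, `_test_counter`, `_predict_counter`). Each is reset at the
start of `fit`/`evaluate`/`predict` and incremented once per mini-batch inside
the compiled train/test/predict function, and the v2 TensorBoard callback uses
them as the `step` for batch-level summaries instead of keeping its own
counters. A minimal standalone sketch of the counter mechanics, using the
public `tf.Variable` API rather than the internal `variables` module from the
diff (illustrative only, not the implementation):

    import tensorflow as tf

    # An int64 Variable (not a Python int) is needed so the step can be
    # updated from inside a tf.function; ONLY_FIRST_REPLICA keeps a single
    # authoritative count under a distribution strategy.
    agg = tf.VariableAggregation.ONLY_FIRST_REPLICA
    train_counter = tf.Variable(0, dtype='int64', aggregation=agg)

    @tf.function
    def train_step():
      train_counter.assign_add(1)  # one increment per mini-batch

    train_counter.assign(0)  # reset at the start of each `fit`
    for _ in range(3):
      train_step()
    print(int(train_counter))  # -> 3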

tensorflow/python/keras/callbacks.py:

@@ -35,21 +35,19 @@ import six
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import distributed_file_utils
 from tensorflow.python.distribute import multi_worker_util
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.keras.utils import version_utils
 from tensorflow.python.keras.utils.data_utils import Sequence
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2
-from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import profiler_v2 as profiler
 from tensorflow.python.training import checkpoint_management

@@ -1614,7 +1612,7 @@ class LearningRateScheduler(Callback):


 @keras_export('keras.callbacks.TensorBoard', v1=[])
-class TensorBoard(Callback):
+class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
   # pylint: disable=line-too-long
   """Enable visualizations for TensorBoard.


@@ -1676,11 +1674,10 @@ class TensorBoard(Callback):
       batches. Note that writing too frequently to TensorBoard can slow down
       your training.
     profile_batch: Profile the batch(es) to sample compute characteristics.
-      profile_batch must be a non-negative integer or a comma separated string
-      of pair of positive integers. A pair of positive integers signify a
-      range of batches to profile. By default, it will profile the second
-      batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow
-      eager mode.
+      profile_batch must be a non-negative integer or a tuple of integers.
+      A pair of positive integers signify a range of batches to profile.
+      By default, it will profile the second batch. Set profile_batch=0
+      to disable profiling. Must run in TensorFlow eager mode.
     embeddings_freq: frequency (in epochs) at which embedding layers will be
       visualized. If set to 0, embeddings won't be visualized.
     embeddings_metadata: a dictionary which maps layer name to a file name in
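
Concretely, all of the following forms are intended to work after this change
(a usage sketch; the tuple form is what this commit adds, and the comma
separated string is kept for backwards compatibility):

    from tensorflow import keras

    tb_default = keras.callbacks.TensorBoard('./logs')                        # profiles batch 2
    tb_single = keras.callbacks.TensorBoard('./logs', profile_batch=5)        # batch 5 only
    tb_range = keras.callbacks.TensorBoard('./logs', profile_batch=(10, 20))  # batches 10..20
    tb_legacy = keras.callbacks.TensorBoard('./logs', profile_batch='10,20')  # legacy string form
    tb_off = keras.callbacks.TensorBoard('./logs', profile_batch=0)           # profiling disabled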

@@ -1713,30 +1710,18 @@ class TensorBoard(Callback):
     self.histogram_freq = histogram_freq
     self.write_graph = write_graph
     self.write_images = write_images
-    if update_freq == 'batch':
-      self.update_freq = 1
-    else:
-      self.update_freq = update_freq
+    self.update_freq = 1 if update_freq == 'batch' else update_freq
     self.embeddings_freq = embeddings_freq
     self.embeddings_metadata = embeddings_metadata
+    self._init_profile_batch(profile_batch)
+    self._epoch = 0

-    self._samples_seen = 0
-    self._samples_seen_at_last_write = 0
-    self._current_batch = 0
-
-    # A collection of file writers currently in use, to be closed when
-    # training ends for this callback. Writers are keyed by the
-    # directory name under the root logdir: e.g., "train" or
-    # "validation".
-    self._train_run_name = 'train'
-    self._validation_run_name = 'validation'
     # Lazily initialized in order to avoid creating event files when
     # not needed.
     self._writers = {}
-    self._start_batch, self._stop_batch = self._init_profile_batch(
-        profile_batch)
-    if self._start_batch > 0:
-      profiler.warmup()  # Improve the profiling accuracy.
-    # True when a trace is running.
-    self._is_tracing = False
+
+    # Used to restore any existing `SummaryWriter` after training ends.
+    self._prev_summary_state = []

   def _validate_kwargs(self, kwargs):
     """Handle arguments were supported in V1."""

@@ -1768,22 +1753,49 @@ class TensorBoard(Callback):
   def set_model(self, model):
     """Sets Keras model and writes graph if specified."""
     self.model = model
+    self._log_write_dir = self._get_log_write_dir()

-    # In case this callback is used via native Keras, _get_distribution_strategy does not exist.
-    if hasattr(self.model, '_get_distribution_strategy'):
-      # TensorBoard callback involves writing a summary file in a
-      # possibly distributed settings.
-      self._log_write_dir = distributed_file_utils.write_dirpath(
-          self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access
-    else:
-      self._log_write_dir = self.log_dir
+    self._train_dir = os.path.join(self._log_write_dir, 'train')
+    self._train_step = self.model._train_counter  # pylint: disable=protected-access

-    with context.eager_mode():
-      self._close_writers()
-      if self.write_graph:
-        with self._get_writer(self._train_run_name).as_default():
+    self._val_dir = os.path.join(self._log_write_dir, 'validation')
+    self._val_step = self.model._test_counter  # pylint: disable=protected-access
+
+    self._writers = {}  # Resets writers.
+
+    if self.write_graph:
+      self._write_keras_model_graph()
+    if self.embeddings_freq:
+      self._configure_embeddings()
+
+  @property
+  def _train_writer(self):
+    if 'train' not in self._writers:
+      self._writers['train'] = summary_ops_v2.create_file_writer_v2(
+          self._train_dir)
+    return self._writers['train']
+
+  @property
+  def _val_writer(self):
+    if 'val' not in self._writers:
+      self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir)
+    return self._writers['val']
+
+  def _get_log_write_dir(self):
+    """For multi-worker, only chief should write, others write to '/tmp'."""
+    return distributed_file_utils.write_dirpath(self.log_dir,
+                                                self.model.distribute_strategy)
+
+  def _delete_tmp_write_dir(self):
+    """Deletes tmp write directories for multi-worker."""
+    distributed_file_utils.remove_temp_dirpath(self.log_dir,
+                                               self.model.distribute_strategy)
+
+  def _write_keras_model_graph(self):
+    """Writes Keras graph networks to TensorBoard."""
+    with self._train_writer.as_default():
+      with summary_ops_v2.always_record_summaries():
-          if not model.run_eagerly:
+        if not self.model.run_eagerly:
           summary_ops_v2.graph(K.get_graph(), step=0)

         summary_writable = (

@@ -1792,14 +1804,6 @@ class TensorBoard(Callback):
         if summary_writable:
           summary_ops_v2.keras_model('keras', self.model, step=0)

-    if self.embeddings_freq:
-      self._configure_embeddings()
-
-    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    self._prev_summary_recording = summary_state.is_recording
-    self._prev_summary_writer = summary_state.writer
-    self._prev_summary_step = summary_state.step
-
   def _configure_embeddings(self):
     """Configure the Projector for embeddings."""
     # TODO(omalleyt): Add integration tests.

@@ -1839,74 +1843,44 @@ class TensorBoard(Callback):
       writer = DummyWriter(self._log_write_dir)
       projector.visualize_embeddings(writer, config)

-  def _close_writers(self):
-    """Close all remaining open file writers owned by this callback.
-
-    If there are no such file writers, this is a no-op.
-    """
-    with context.eager_mode():
-      for writer in six.itervalues(self._writers):
-        writer.close()
-      self._writers.clear()
-
-  def _get_writer(self, writer_name):
-    """Get a summary writer for the given subdirectory under the logdir.
-
-    A writer will be created if it does not yet exist.
-
-    Arguments:
-      writer_name: The name of the directory for which to create or
-        retrieve a writer. Should be either `self._train_run_name` or
-        `self._validation_run_name`.
-
-    Returns:
-      A `SummaryWriter` object.
-    """
-    if writer_name not in self._writers:
-      path = os.path.join(self._log_write_dir, writer_name)
-      writer = summary_ops_v2.create_file_writer_v2(path)
-      self._writers[writer_name] = writer
-    return self._writers[writer_name]
-
-  def _set_default_writer(self, writer_name):
+  def _push_writer(self, writer, step):
     """Sets the default writer for custom batch-level summaries."""
     if self.update_freq == 'epoch':
-      # Writer is only used for custom summaries, which are written
-      # batch-by-batch.
       return

-    step = self._total_batches_seen[writer_name]
-
-    def _should_record():
-      return math_ops.equal(step % self.update_freq, 0)
-
     summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    summary_state.is_recording = _should_record
-    summary_state.writer = self._get_writer(writer_name)
-    summary_ops_v2.set_step(step)
+    self._prev_summary_state.append({
+        'is_recording': summary_state.is_recording,
+        'writer': summary_state.writer,
+        'step': summary_state.step
+    })
+
+    if self.update_freq == 'epoch':
+      should_record = False
+      writer = None
+    else:
+      should_record = lambda: math_ops.equal(step % self.update_freq, 0)
+
+    summary_state.is_recording = should_record
+    summary_state.writer = writer
+    # TODO(b/151339474): Fix deadlock when not using .value() here.
+    summary_ops_v2.set_step(step.value())
+
+  def _pop_writer(self):
+    """Pops the current writer."""
+    if self.update_freq == 'epoch':
+      return
+
+    prev_state = self._prev_summary_state.pop()
+
+    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
+    summary_state.is_recording = prev_state['is_recording']
+    summary_state.writer = prev_state['writer']
+    summary_ops_v2.set_step(prev_state['step'])

-  def _init_batch_steps(self):
-    """Create the total batch counters."""
-    if ops.executing_eagerly_outside_functions():
-      # Variables are needed for the `step` value of custom tf.summaries
-      # to be updated inside a tf.function.
-      self._total_batches_seen = {
-          self._train_run_name: variables.Variable(0, dtype='int64'),
-          self._validation_run_name: variables.Variable(0, dtype='int64')
-      }
-    else:
-      # Custom tf.summaries are not supported in legacy graph mode.
-      self._total_batches_seen = {
-          self._train_run_name: 0,
-          self._validation_run_name: 0
-      }
-
-  def _increment_step(self, writer_name):
-    step = self._total_batches_seen[writer_name]
-    if isinstance(step, variables.Variable):
-      step.assign_add(1)
-    else:
-      self._total_batches_seen[writer_name] += 1
+  def _close_writers(self):
+    for writer in self._writers.values():
+      writer.close()

   def _init_profile_batch(self, profile_batch):
     """Validate profile_batch value and set the range of batches to profile.
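
`_push_writer`/`_pop_writer` above swap TensorFlow's global summary state so
that `tf.summary` calls made while training (including ones inside the
compiled train function) land in this callback's writer at the model's batch
counter, then restore whatever state was active before. A rough standalone
sketch of the same idea using only public APIs (hypothetical names, not the
callback's code):

    import tensorflow as tf

    writer = tf.summary.create_file_writer('./logs/train')
    step = tf.Variable(0, dtype='int64')
    update_freq = 10  # mirror update_freq=N: record every N batches

    # Rough equivalent of _push_writer: install a default writer, a record
    # condition, and a step; summaries written inside inherit all three.
    with writer.as_default(), \
        tf.summary.record_if(lambda: tf.equal(step % update_freq, 0)):
      for _ in range(100):
        tf.summary.scalar('batch_loss', 0.1, step=step)
        step.assign_add(1)
    # Leaving the `with` blocks plays the role of _pop_writer.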

@@ -1926,75 +1900,79 @@ class TensorBoard(Callback):

    """
    profile_batch_error_message = (
-        'profile_batch must be a non-negative integer or a comma separated '
-        'string of pair of positive integers. A pair of positive integers '
-        'signify a range of batches to profile.')
-    try:
-      profile_range = [int(i) for i in str(profile_batch).split(',')]
-    except ValueError:
-      raise ValueError(profile_batch_error_message)
-    if len(profile_range) == 1:  # single batch
-      start_batch, stop_batch = profile_range[0], profile_range[0]
-      if start_batch < 0:
-        raise ValueError(profile_batch_error_message)
-    elif len(profile_range) == 2:  # (start_batch, stop_batch)
-      start_batch, stop_batch = profile_range
-      # [0, 0], [-1, 100], [6, 5] are illegal.
-      if start_batch <= 0 or start_batch > stop_batch:
-        raise ValueError(profile_batch_error_message)
+        'profile_batch must be a non-negative integer or 2-tuple of positive '
+        'integers. A pair of positive integers signifies a range of batches '
+        'to profile. Found: {}'.format(profile_batch))
+
+    # Support legacy way of specifying "start,stop" or "start" as str.
+    if isinstance(profile_batch, six.string_types):
+      profile_batch = str(profile_batch).split(',')
+      profile_batch = nest.map_structure(int, profile_batch)
+
+    if isinstance(profile_batch, int):
+      self._start_batch = profile_batch
+      self._stop_batch = profile_batch
+    elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
+      self._start_batch, self._stop_batch = profile_batch
     else:
       raise ValueError(profile_batch_error_message)
-    return start_batch, stop_batch

+    if self._start_batch < 0 or self._stop_batch < self._start_batch:
+      raise ValueError(profile_batch_error_message)
+
+    if self._start_batch > 0:
+      profiler.warmup()  # Improve the profiling accuracy.
+    # True when a trace is running.
+    self._is_tracing = False
+
+    # Setting `profile_batch=0` disables profiling.
+    self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)

   def on_train_begin(self, logs=None):
-    self._init_batch_steps()
-    if self._start_batch == 1:
-      self._enable_trace()
+    self._push_writer(self._train_writer, self._train_step)
+
+  def on_train_end(self, logs=None):
+    self._pop_writer()
+
+    if self._is_tracing:
+      self._stop_trace()
+
+    self._close_writers()
+    self._delete_tmp_write_dir()

   def on_test_begin(self, logs=None):
-    self._set_default_writer(self._validation_run_name)
+    self._push_writer(self._val_writer, self._val_step)
+
+  def on_test_end(self, logs=None):
+    self._pop_writer()
+
+  def on_train_batch_begin(self, batch, logs=None):
+    if not self._should_trace:
+      return
+
+    if self._epoch == 0 and batch == self._start_batch:
+      self._start_trace()

   def on_train_batch_end(self, batch, logs=None):
-    """Writes scalar summaries for metrics on every training batch.
-
-    Performs profiling if current batch is in profiler_batches.
+    """Performs profiling if current batch is in profiler_batches.

     Arguments:
       batch: Integer, index of batch within the current epoch.
       logs: Dict. Metric results for this batch.
     """
-    # TODO(b/150629188): Make TensorBoard callback not use batch hooks
-    # by default.
-    if self.update_freq == 'epoch' and self._start_batch is None:
+    if not self._should_trace:
       return

-    # Don't output batch_size and batch number as TensorBoard summaries
-    logs = logs or {}
-    train_batches = self._total_batches_seen[self._train_run_name]
-    if self.update_freq != 'epoch' and batch % self.update_freq == 0:
-      self._log_metrics(logs, prefix='batch_', step=train_batches)
-
-    self._increment_step(self._train_run_name)
-    if self._is_tracing:
-      control_flow_ops.cond(
-          math_ops.greater_equal(train_batches, self._stop_batch),
-          lambda: self._log_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda
-    else:
-      control_flow_ops.cond(
-          math_ops.equal(train_batches, self._start_batch - 1),
-          lambda: self._enable_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda
-
-  def on_test_batch_end(self, batch, logs=None):
-    if self.update_freq == 'epoch':
-      return
-    self._increment_step(self._validation_run_name)
+    if self._is_tracing and batch >= self._stop_batch:
+      self._stop_trace()

   def on_epoch_begin(self, epoch, logs=None):
-    self._set_default_writer(self._train_run_name)
+    # Keeps track of epoch for profiling.
+    self._epoch = epoch

   def on_epoch_end(self, epoch, logs=None):
     """Runs metrics and histogram summaries at epoch end."""
-    self._log_metrics(logs, prefix='epoch_', step=epoch)
+    self._log_epoch_metrics(epoch, logs)

     if self.histogram_freq and epoch % self.histogram_freq == 0:
       self._log_weights(epoch)
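
Taken together, the accepted `profile_batch` forms reduce to a small
normalization routine. A condensed, self-contained restatement of the logic
above (`parse_profile_batch` is a hypothetical helper, not the source):

    def parse_profile_batch(profile_batch):
      """Normalizes profile_batch to an inclusive (start, stop) range."""
      if isinstance(profile_batch, str):  # legacy 'start,stop' form
        profile_batch = [int(v) for v in profile_batch.split(',')]
      if isinstance(profile_batch, int):
        start, stop = profile_batch, profile_batch
      elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
        start, stop = profile_batch
      else:
        raise ValueError(
            'profile_batch must be a non-negative integer or 2-tuple of '
            'positive integers. Found: {}'.format(profile_batch))
      if start < 0 or stop < start:
        raise ValueError('Invalid range: ({}, {})'.format(start, stop))
      return start, stop

    assert parse_profile_batch(2) == (2, 2)
    assert parse_profile_batch('10,20') == (10, 20)
    assert parse_profile_batch([10, 20]) == (10, 20)
    assert parse_profile_batch(0) == (0, 0)  # profiling disabled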

@@ -2002,124 +1980,57 @@ class TensorBoard(Callback):
     if self.embeddings_freq and epoch % self.embeddings_freq == 0:
       self._log_embeddings(epoch)

-  def on_train_end(self, logs=None):
-    if self._is_tracing:
-      self._log_trace()
-    self._close_writers()
-
-    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    summary_state.is_recording = self._prev_summary_recording
-    summary_state.writer = self._prev_summary_writer
-    summary_state.step = self._prev_summary_step
-
-    # In case this callback is used via native Keras, _get_distribution_strategy does not exist.
-    if hasattr(self.model, '_get_distribution_strategy'):
-      # Safely remove the unneeded temp files.
-      distributed_file_utils.remove_temp_dirpath(
-          self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access
-
-  def _enable_trace(self):
-    """Starts to collect trace graph to TensorBoard.
-
-    Collects both trace and graph in eager mode, and trace only in graph mode.
-    """
-    if context.executing_eagerly():
-      # Graph must be traced in eager mode.
-      summary_ops_v2.trace_on(graph=True, profiler=False)
-    profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
+  def _start_trace(self):
+    summary_ops_v2.trace_on(graph=True, profiler=False)
+    profiler.start(logdir=self._train_dir)
     self._is_tracing = True

-  def _enable_trace_return_true(self):
-    """Starts to collect trace graph to TensorBoard and returns True.
-
-    Returns:
-      True.
-    """
-    self._enable_trace()
-    return True
-
-  def _log_trace(self):
-    """Logs the trace graph to TensorBoard.
-
-    Logs both trace and graph in eager mode, and trace only in graph mode.
-    """
-    profiler.stop()
-    if context.executing_eagerly():
-      # Graph must be traced in eager mode.
-      with self._get_writer(self._train_run_name).as_default(), \
-          summary_ops_v2.always_record_summaries():
-        # TODO(b/126388999): Remove step info in the summary name.
-        step = K.get_value(self._total_batches_seen[self._train_run_name])
-        summary_ops_v2.trace_export(name='batch_%d' % step, step=step)
-    self._is_tracing = False
-
-  def _log_trace_return_true(self):
-    """Logs the trace graph to TensorBoard and returns True.
-
-    Returns:
-      True.
-    """
-    self._log_trace()
-    return True
+  def _stop_trace(self, batch=None):
+    """Logs the trace graph to TensorBoard."""
+    if batch is None:
+      batch = self._stop_batch
+    with self._train_writer.as_default():
+      with summary_ops_v2.always_record_summaries():
+        # TODO(b/126388999): Remove step info in the summary name.
+        summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch)
+    profiler.stop()
+    self._is_tracing = False

-  def _log_metrics(self, logs, prefix, step):
-    """Writes metrics out as custom scalar summaries.
+  def _log_epoch_metrics(self, epoch, logs):
+    """Writes epoch metrics out as scalar summaries.

     Arguments:
-      logs: Dict. Keys are scalar summary names, values are NumPy scalars.
-      prefix: String. The prefix to apply to the scalar summary names.
-      step: Int. The global step to use for TensorBoard.
+      epoch: Int. The global step to use for TensorBoard.
+      logs: Dict. Keys are scalar summary names, values are scalars.
     """
-    if logs is None:
-      logs = {}
+    if not logs:
+      return

-    # Group metrics by the name of their associated file writer. Values
-    # are lists of metrics, as (name, scalar_value) pairs.
-    logs_by_writer = {
-        self._train_run_name: [],
-        self._validation_run_name: [],
-    }
-    validation_prefix = 'val_'
-    for (name, value) in logs.items():
-      if name in ('batch', 'size', 'num_steps'):
-        # Scrub non-metric items.
-        continue
-      if name.startswith(validation_prefix):
-        name = name[len(validation_prefix):]
-        writer_name = self._validation_run_name
-      else:
-        writer_name = self._train_run_name
-      name = prefix + name  # assign batch or epoch prefix
-      logs_by_writer[writer_name].append((name, value))
+    train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
+    val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}

-    with context.eager_mode():
-      with summary_ops_v2.always_record_summaries():
-        for writer_name in logs_by_writer:
-          these_logs = logs_by_writer[writer_name]
-          if not these_logs:
-            # Don't create a "validation" events file if we don't
-            # actually have any validation data.
-            continue
-          writer = self._get_writer(writer_name)
-          with writer.as_default():
-            for (name, value) in these_logs:
-              summary_ops_v2.scalar(name, value, step=step)
+    with summary_ops_v2.always_record_summaries():
+      if train_logs:
+        with self._train_writer.as_default():
+          for name, value in train_logs.items():
+            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
+      if val_logs:
+        with self._val_writer.as_default():
+          for name, value in val_logs.items():
+            name = name[4:]  # Remove 'val_' prefix.
+            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)

   def _log_weights(self, epoch):
     """Logs the weights of the Model to TensorBoard."""
-    writer = self._get_writer(self._train_run_name)
-    with context.eager_mode(), \
-        writer.as_default(), \
-        summary_ops_v2.always_record_summaries():
-      for layer in self.model.layers:
-        for weight in layer.weights:
-          weight_name = weight.name.replace(':', '_')
-          with ops.init_scope():
-            weight = K.get_value(weight)
-          summary_ops_v2.histogram(weight_name, weight, step=epoch)
-          if self.write_images:
-            self._log_weight_as_image(weight, weight_name, epoch)
-      writer.flush()
+    with self._train_writer.as_default():
+      with summary_ops_v2.always_record_summaries():
+        for layer in self.model.layers:
+          for weight in layer.weights:
+            weight_name = weight.name.replace(':', '_')
+            summary_ops_v2.histogram(weight_name, weight, step=epoch)
+            if self.write_images:
+              self._log_weight_as_image(weight, weight_name, epoch)
+        self._train_writer.flush()

   def _log_weight_as_image(self, weight, weight_name, epoch):
     """Logs a weight as a TensorBoard image."""

@@ -2150,6 +2061,9 @@ class TensorBoard(Callback):
         'keras_embedding.ckpt-{}'.format(epoch))
     self.model.save_weights(embeddings_ckpt)

+  def _implements_train_batch_hooks(self):
+    return not (self._start_batch == 0 and self._stop_batch == 0)
+

 @keras_export('keras.callbacks.ReduceLROnPlateau')
 class ReduceLROnPlateau(Callback):

tensorflow/python/keras/callbacks_test.py:

@@ -2079,17 +2079,19 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
     model.fit(
         np.zeros((64, 1)),
         np.zeros((64, 1)),
         batch_size=32,
         callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)],
     )
     # Verifies trace exists in the first train_dir.
-    self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
+    self.assertIsNotNone(self._get_trace_file(logdir=self.logdir))
     model.fit(
         np.zeros((64, 1)),
         np.zeros((64, 1)),
         batch_size=32,
         callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)],
     )
     # Verifies trace exists in the second train_dir.
-    self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
+    self.assertIsNotNone(self._get_trace_file(logdir=self.logdir))

   def test_TensorBoard_autoTrace_profileBatchRange(self):
     model = self._get_seq_model()

tensorflow/python/keras/callbacks_v1.py:

@@ -39,7 +39,7 @@ from tensorflow.python.util.tf_export import keras_export


 @keras_export(v1=['keras.callbacks.TensorBoard'])
-class TensorBoard(callbacks.Callback):
+class TensorBoard(callbacks.TensorBoard):
   # pylint: disable=line-too-long
   """Enable visualizations for TensorBoard.


@@ -127,7 +127,8 @@ class TensorBoard(callbacks.Callback):
                embeddings_data=None,
                update_freq='epoch',
                profile_batch=2):
-    super(TensorBoard, self).__init__()
+    # Don't call super's init since it is an eager-only version.
+    callbacks.Callback.__init__(self)
     self.log_dir = log_dir
     self.histogram_freq = histogram_freq
     if self.histogram_freq and context.executing_eagerly():

@@ -342,6 +343,21 @@
       self.writer.add_summary(summary, step)
     self.writer.flush()

+  def on_train_batch_begin(self, batch, logs=None):
+    if (not self._is_profiling and
+        self._total_batches_seen == self._profile_batch - 1):
+      profiler.start(self.log_dir)
+      self._is_profiling = True
+
+  def on_train_batch_end(self, batch, logs=None):
+    return self.on_batch_end(batch, logs)
+
+  def on_test_begin(self, logs=None):
+    pass
+
+  def on_test_end(self, logs=None):
+    pass
+
   def on_batch_end(self, batch, logs=None):
     """Writes scalar summaries for metrics on every training batch.


@@ -358,18 +374,13 @@
       self._write_custom_summaries(self._total_batches_seen, batch_logs)
       self._samples_seen_at_last_write = self._samples_seen
     self._total_batches_seen += 1
+
     if self._is_profiling:
       profiler.stop()
       self._is_profiling = False
-    elif (not self._is_profiling and
-          self._total_batches_seen == self._profile_batch - 1):
-      profiler.start(self.log_dir)
-      self._is_profiling = True

   def on_train_begin(self, logs=None):
-    if self._profile_batch == 1:
-      profiler.start(self.log_dir)
-      self._is_profiling = True
+    pass

   def on_epoch_begin(self, epoch, logs=None):
     """Add histogram op to Model eval_function callbacks, reset batch count."""
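
The `profiler.start`/`profiler.stop` pair used here comes from the
`profiler_v2` module; in later releases the same functionality is exported
publicly as `tf.profiler.experimental` (hedged equivalence, version
dependent):

    import tensorflow as tf

    # Profile a span of work; results land under the TensorBoard logdir.
    tf.profiler.experimental.start('./logs')
    # ... run the batches to be profiled ...
    tf.profiler.experimental.stop()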

tensorflow/python/keras/engine/training.py:

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import copy
+import itertools

 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context

@@ -28,6 +29,7 @@ from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import monitoring
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import callbacks as callbacks_module
 from tensorflow.python.keras import optimizers

@@ -43,6 +45,8 @@ from tensorflow.python.keras.utils import version_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_concat_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.profiler import trace

@@ -161,6 +165,9 @@ class Model(network.Network, version_utils.ModelVersionSelector):
   Checkout [guide](https://www.tensorflow.org/guide/keras/overview) for
   additional details.
   """
+  _TF_MODULE_IGNORED_PROPERTIES = frozenset(
+      itertools.chain(('_train_counter', '_test_counter', '_predict_counter'),
+                      network.Network._TF_MODULE_IGNORED_PROPERTIES))  # pylint: disable=protected-access

   def __init__(self, *args, **kwargs):
     super(Model, self).__init__(*args, **kwargs)

@@ -186,6 +193,18 @@ class Model(network.Network, version_utils.ModelVersionSelector):
     self.compiled_loss = None
     self.compiled_metrics = None

+    self._init_batch_counters()
+
+  @trackable.no_automatic_dependency_tracking
+  def _init_batch_counters(self):
+    # Untracked Variables, used to keep track of mini-batches seen in `fit`,
+    # `evaluate`, and `predict`.
+    agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
+    self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg)
+    self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg)
+    self._predict_counter = variables.Variable(
+        0, dtype='int64', aggregation=agg)
+
   def get_weights(self):
     """Retrieves the weights of the model.

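
Because each counter is reset when its public entry point starts and bumped
once per executed mini-batch, its value after a run equals the number of
batches seen. A hedged sketch of observing this (`_train_counter` is the
internal attribute introduced by this diff, not stable API):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile('sgd', 'mse')
    # 64 samples / batch_size 32 = 2 batches per epoch; 2 epochs = 4 batches.
    model.fit(np.zeros((64, 1)), np.zeros((64, 1)),
              batch_size=32, epochs=2, verbose=0)
    print(int(model._train_counter))  # expected: 4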

@@ -499,11 +518,18 @@
       return self.train_function

     def train_function(iterator):
       """Runs one call to `self.train_function`."""
+
+      def run_step(data):
+        outputs = self.train_step(data)
+        self._train_counter.assign_add(1)
+        return outputs
+
       data = next(iterator)
-      outputs = self.distribute_strategy.run(
-          self.train_step, args=(data,))
+      outputs = self.distribute_strategy.run(run_step, args=(data,))
       outputs = reduce_per_replica(
           outputs, self.distribute_strategy, reduction='first')
+      write_scalar_summaries(outputs, step=self._train_counter)
       return outputs

     if not self.run_eagerly:
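
The shape of this change is the same for fit/evaluate/predict: wrap the
per-batch step so the counter is bumped exactly once per executed mini-batch,
then hand the wrapper to `strategy.run`. A minimal sketch of the pattern
outside Keras (illustrative names; assumes TF 2.2+, where `Strategy.run`
exists):

    import tensorflow as tf

    strategy = tf.distribute.get_strategy()  # default no-op strategy
    counter = tf.Variable(
        0, dtype='int64',
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    @tf.function
    def one_batch(data):
      def run_step(x):
        counter.assign_add(1)  # counted once per mini-batch, in-graph
        return x * 2.0
      return strategy.run(run_step, args=(data,))

    one_batch(tf.constant(1.0))
    one_batch(tf.constant(2.0))
    print(int(counter))  # -> 2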

@@ -762,6 +788,7 @@

       self.stop_training = False
       train_function = self.make_train_function()
+      self._train_counter.assign(0)
       callbacks.on_train_begin()
       # Handle fault-tolerance for multi-worker.
       # TODO(omalleyt): Fix the ordering issues that mean this has to

@@ -872,9 +899,15 @@
       return self.test_function

     def test_function(iterator):
       """Runs one call to `self.test_function`."""
+
+      def run_step(data):
+        outputs = self.test_step(data)
+        self._test_counter.assign_add(1)
+        return outputs
+
       data = next(iterator)
-      outputs = self.distribute_strategy.run(
-          self.test_step, args=(data,))
+      outputs = self.distribute_strategy.run(run_step, args=(data,))
       outputs = reduce_per_replica(
           outputs, self.distribute_strategy, reduction='first')
       return outputs

@@ -1003,6 +1036,7 @@
           steps=data_handler.inferred_steps)

       test_function = self.make_test_function()
+      self._test_counter.assign(0)
       callbacks.on_test_begin()
       for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
         self.reset_metrics()

@@ -1075,9 +1109,15 @@
       return self.predict_function

     def predict_function(iterator):
       """Runs one call to `self.predict_function`."""
+
+      def run_step(data):
+        outputs = self.predict_step(data)
+        self._predict_counter.assign_add(1)
+        return outputs
+
       data = next(iterator)
-      outputs = self.distribute_strategy.run(
-          self.predict_step, args=(data,))
+      outputs = self.distribute_strategy.run(run_step, args=(data,))
       outputs = reduce_per_replica(
           outputs, self.distribute_strategy, reduction='concat')
       return outputs

@@ -1192,6 +1232,7 @@
           steps=data_handler.inferred_steps)

       predict_function = self.make_predict_function()
+      self._predict_counter.assign(0)
       callbacks.on_predict_begin()
       for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
         with data_handler.catch_stop_iteration():

@@ -1734,3 +1775,13 @@ def _minimize(tape, optimizer, loss, trainable_variables):
         all_reduce_sum_gradients=False)
   else:
     optimizer.apply_gradients(zip(gradients, trainable_variables))
+
+
+def _is_scalar(x):
+  return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0
+
+
+def write_scalar_summaries(logs, step):
+  for name, value in logs.items():
+    if _is_scalar(value):
+      summary_ops_v2.scalar('batch_' + name, value, step=step)

tensorflow/python/keras/engine/training_v1.py:

@@ -162,6 +162,9 @@ class Model(training_lib.Model):

     self._v1_compile_was_called = False

+  def _init_batch_counters(self):
+    pass  # Batch counters should not be created in legacy graph mode.
+
   @trackable.no_automatic_dependency_tracking
   def _set_strategy(self, strategy):
     self._compile_time_distribution_strategy = strategy

@@ -737,6 +737,21 @@ class CustomCallSignatureTests(test.TestCase, parameterized.TestCase):
     self.assertLen(new_model.variables, 1)
     self.assertLen(new_model.layers, 1)

+  def test_batch_counters_not_in_variables(self):
+
+    class MyModel(keras.Model):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.layer = keras.layers.Dense(4)
+
+      def call(self, obs):
+        return self.layer(obs)
+
+    model = MyModel()
+    model(np.ones((10, 10)))
+    self.assertLen(model.variables, 2)
+

 if __name__ == '__main__':
   test.main()

tensorflow/python/keras/utils/version_utils.py:

@@ -36,6 +36,13 @@ base_layer = lazy_loader.LazyLoader(
 base_layer_v1 = lazy_loader.LazyLoader(
     "base_layer_v1", globals(),
     "tensorflow.python.keras.engine.base_layer_v1")
+callbacks = lazy_loader.LazyLoader(
+    "callbacks", globals(),
+    "tensorflow.python.keras.callbacks")
+callbacks_v1 = lazy_loader.LazyLoader(
+    "callbacks_v1", globals(),
+    "tensorflow.python.keras.callbacks_v1")


 # pylint: enable=g-inconsistent-quotes

@@ -58,6 +65,21 @@ class LayerVersionSelector(object):
     return super(LayerVersionSelector, cls).__new__(cls)


+class TensorBoardVersionSelector(object):
+  """Chooses between Keras v1 and v2 TensorBoard callback class."""
+
+  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
+    eager_enabled = ops.executing_eagerly_outside_functions()
+    start_cls = cls
+    cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
+                     eager_enabled)
+    if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
+      # Since the v2 class is not a subclass of the v1 class, __init__ has to
+      # be called manually.
+      return cls(*args, **kwargs)
+    return super(TensorBoardVersionSelector, cls).__new__(cls)
+
+
 def swap_class(cls, v2_cls, v1_cls, eager_enabled):
   """Swaps in v2_cls or v1_cls depending on graph mode."""
   if cls == object:
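
The selector works because `__new__` may return an instance of a different
class; when that class is unrelated to the requested one, Python skips
`__init__`, which is why the code above calls `cls(*args, **kwargs)` itself in
that case. A stripped-down sketch of the pattern (hypothetical classes, not
the Keras code):

    class V1Impl(object):
      def __init__(self, name):
        self.name = name

    class V2Impl(object):
      def __init__(self, name):
        self.name = name

    class Selector(object):
      def __new__(cls, name, use_v2=True):
        chosen = V2Impl if use_v2 else V1Impl
        # `chosen` is not a subclass of Selector, so Python will not run
        # __init__ on the returned object; construct it fully here instead.
        return chosen(name)

    obj = Selector('x', use_v2=False)
    print(type(obj).__name__, obj.name)  # V1Impl x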

tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt:

@@ -1,7 +1,9 @@
 path: "tensorflow.keras.callbacks.TensorBoard"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.callbacks_v1.TensorBoard\'>"
+  is_instance: "<class \'tensorflow.python.keras.callbacks.TensorBoard\'>"
   is_instance: "<class \'tensorflow.python.keras.callbacks.Callback\'>"
+  is_instance: "<class \'tensorflow.python.keras.utils.version_utils.TensorBoardVersionSelector\'>"
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"

tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt:

@@ -2,6 +2,7 @@ path: "tensorflow.keras.callbacks.TensorBoard"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.callbacks.TensorBoard\'>"
   is_instance: "<class \'tensorflow.python.keras.callbacks.Callback\'>"
+  is_instance: "<class \'tensorflow.python.keras.utils.version_utils.TensorBoardVersionSelector\'>"
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"