Create Variables to track mini-batches seen in Model.fit / evaluate / predict.

Use these counters in the TensorBoard Callback.

PiperOrigin-RevId: 301459298
Change-Id: I8e92e119ef4cef37c41532d11caefd601d4395a7
Thomas O'Malley authored on 2020-03-17 14:54:56 -07:00; committed by TensorFlower Gardener
parent 1cea2490cb
commit 69565ec400
9 changed files with 305 additions and 284 deletions
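
For orientation, here is a minimal usage sketch of the behaviour this change affects (the log directory and toy data are illustrative, not part of the commit): with update_freq set to a number of batches, the TensorBoard callback writes batch-level scalars whose step now comes from the model's new train counter, which grows across epochs within one fit() call and is reset to zero at the start of the next.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

# Write batch-level scalars every 10 batches; profiling is disabled here.
tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir='/tmp/tb_logs', update_freq=10, profile_batch=0)

model.fit(np.zeros((64, 4)), np.zeros((64, 1)),
          batch_size=8, epochs=2, callbacks=[tensorboard])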


@@ -35,21 +35,19 @@ import six
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distributed_file_utils
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.data_utils import Sequence
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.training import checkpoint_management
@@ -1614,7 +1612,7 @@ class LearningRateScheduler(Callback):
@keras_export('keras.callbacks.TensorBoard', v1=[])
class TensorBoard(Callback):
class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
# pylint: disable=line-too-long
"""Enable visualizations for TensorBoard.
@@ -1676,11 +1674,10 @@ class TensorBoard(Callback):
batches. Note that writing too frequently to TensorBoard can slow down
your training.
profile_batch: Profile the batch(es) to sample compute characteristics.
profile_batch must be a non-negative integer or a comma separated string
of pair of positive integers. A pair of positive integers signify a
range of batches to profile. By default, it will profile the second
batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow
eager mode.
profile_batch must be a non-negative integer or a tuple of integers.
A pair of positive integers signify a range of batches to profile.
By default, it will profile the second batch. Set profile_batch=0
to disable profiling. Must run in TensorFlow eager mode.
embeddings_freq: frequency (in epochs) at which embedding layers will be
visualized. If set to 0, embeddings won't be visualized.
embeddings_metadata: a dictionary which maps layer name to a file name in
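
A quick illustration of the profile_batch forms accepted after this change (the log directories are placeholders); the parsing itself appears in _init_profile_batch further below.

from tensorflow import keras

# Profile only the second batch (the default).
keras.callbacks.TensorBoard('/tmp/logs', profile_batch=2)

# Profile a range of batches, either as a 2-tuple or as the legacy
# comma-separated string form.
keras.callbacks.TensorBoard('/tmp/logs', profile_batch=(5, 10))
keras.callbacks.TensorBoard('/tmp/logs', profile_batch='5,10')

# Disable profiling entirely.
keras.callbacks.TensorBoard('/tmp/logs', profile_batch=0)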
@@ -1713,30 +1710,18 @@ class TensorBoard(Callback):
self.histogram_freq = histogram_freq
self.write_graph = write_graph
self.write_images = write_images
if update_freq == 'batch':
self.update_freq = 1
else:
self.update_freq = update_freq
self.update_freq = 1 if update_freq == 'batch' else update_freq
self.embeddings_freq = embeddings_freq
self.embeddings_metadata = embeddings_metadata
self._init_profile_batch(profile_batch)
self._epoch = 0
self._samples_seen = 0
self._samples_seen_at_last_write = 0
self._current_batch = 0
# A collection of file writers currently in use, to be closed when
# training ends for this callback. Writers are keyed by the
# directory name under the root logdir: e.g., "train" or
# "validation".
self._train_run_name = 'train'
self._validation_run_name = 'validation'
# Lazily initialized in order to avoid creating event files when
# not needed.
self._writers = {}
self._start_batch, self._stop_batch = self._init_profile_batch(
profile_batch)
if self._start_batch > 0:
profiler.warmup() # Improve the profiling accuracy.
# True when a trace is running.
self._is_tracing = False
# Used to restore any existing `SummaryWriter` after training ends.
self._prev_summary_state = []
def _validate_kwargs(self, kwargs):
"""Handle arguments were supported in V1."""
@@ -1768,37 +1753,56 @@ class TensorBoard(Callback):
def set_model(self, model):
"""Sets Keras model and writes graph if specified."""
self.model = model
self._log_write_dir = self._get_log_write_dir()
# In case this callback is used via native Keras, _get_distribution_strategy does not exist.
if hasattr(self.model, '_get_distribution_strategy'):
# TensorBoard callback involves writing a summary file in a
# possibly distributed settings.
self._log_write_dir = distributed_file_utils.write_dirpath(
self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access
else:
self._log_write_dir = self.log_dir
self._train_dir = os.path.join(self._log_write_dir, 'train')
self._train_step = self.model._train_counter # pylint: disable=protected-access
with context.eager_mode():
self._close_writers()
if self.write_graph:
with self._get_writer(self._train_run_name).as_default():
with summary_ops_v2.always_record_summaries():
if not model.run_eagerly:
summary_ops_v2.graph(K.get_graph(), step=0)
self._val_dir = os.path.join(self._log_write_dir, 'validation')
self._val_step = self.model._test_counter # pylint: disable=protected-access
summary_writable = (
self.model._is_graph_network or # pylint: disable=protected-access
self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access
if summary_writable:
summary_ops_v2.keras_model('keras', self.model, step=0)
self._writers = {} # Resets writers.
if self.write_graph:
self._write_keras_model_graph()
if self.embeddings_freq:
self._configure_embeddings()
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
self._prev_summary_recording = summary_state.is_recording
self._prev_summary_writer = summary_state.writer
self._prev_summary_step = summary_state.step
@property
def _train_writer(self):
if 'train' not in self._writers:
self._writers['train'] = summary_ops_v2.create_file_writer_v2(
self._train_dir)
return self._writers['train']
@property
def _val_writer(self):
if 'val' not in self._writers:
self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir)
return self._writers['val']
def _get_log_write_dir(self):
"""For multi-worker, only chief should write, others write to '/tmp'."""
return distributed_file_utils.write_dirpath(self.log_dir,
self.model.distribute_strategy)
def _delete_tmp_write_dir(self):
"""Deletes tmp write directories for multi-worker."""
distributed_file_utils.remove_temp_dirpath(self.log_dir,
self.model.distribute_strategy)
def _write_keras_model_graph(self):
"""Writes Keras graph networks to TensorBoard."""
with self._train_writer.as_default():
with summary_ops_v2.always_record_summaries():
if not self.model.run_eagerly:
summary_ops_v2.graph(K.get_graph(), step=0)
summary_writable = (
self.model._is_graph_network or # pylint: disable=protected-access
self.model.__class__.__name__ == 'Sequential') # pylint: disable=protected-access
if summary_writable:
summary_ops_v2.keras_model('keras', self.model, step=0)
def _configure_embeddings(self):
"""Configure the Projector for embeddings."""
@@ -1839,74 +1843,44 @@ class TensorBoard(Callback):
writer = DummyWriter(self._log_write_dir)
projector.visualize_embeddings(writer, config)
def _close_writers(self):
"""Close all remaining open file writers owned by this callback.
If there are no such file writers, this is a no-op.
"""
with context.eager_mode():
for writer in six.itervalues(self._writers):
writer.close()
self._writers.clear()
def _get_writer(self, writer_name):
"""Get a summary writer for the given subdirectory under the logdir.
A writer will be created if it does not yet exist.
Arguments:
writer_name: The name of the directory for which to create or
retrieve a writer. Should be either `self._train_run_name` or
`self._validation_run_name`.
Returns:
A `SummaryWriter` object.
"""
if writer_name not in self._writers:
path = os.path.join(self._log_write_dir, writer_name)
writer = summary_ops_v2.create_file_writer_v2(path)
self._writers[writer_name] = writer
return self._writers[writer_name]
def _set_default_writer(self, writer_name):
def _push_writer(self, writer, step):
"""Sets the default writer for custom batch-level summaries."""
if self.update_freq == 'epoch':
# Writer is only used for custom summaries, which are written
# batch-by-batch.
return
step = self._total_batches_seen[writer_name]
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
self._prev_summary_state.append({
'is_recording': summary_state.is_recording,
'writer': summary_state.writer,
'step': summary_state.step
})
def _should_record():
return math_ops.equal(step % self.update_freq, 0)
if self.update_freq == 'epoch':
should_record = False
writer = None
else:
should_record = lambda: math_ops.equal(step % self.update_freq, 0)
summary_state.is_recording = should_record
summary_state.writer = writer
# TODO(b/151339474): Fix deadlock when not using .value() here.
summary_ops_v2.set_step(step.value())
def _pop_writer(self):
"""Pops the current writer."""
if self.update_freq == 'epoch':
return
prev_state = self._prev_summary_state.pop()
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
summary_state.is_recording = _should_record
summary_state.writer = self._get_writer(writer_name)
summary_ops_v2.set_step(step)
summary_state.is_recording = prev_state['is_recording']
summary_state.writer = prev_state['writer']
summary_ops_v2.set_step(prev_state['step'])
def _init_batch_steps(self):
"""Create the total batch counters."""
if ops.executing_eagerly_outside_functions():
# Variables are needed for the `step` value of custom tf.summaries
# to be updated inside a tf.function.
self._total_batches_seen = {
self._train_run_name: variables.Variable(0, dtype='int64'),
self._validation_run_name: variables.Variable(0, dtype='int64')
}
else:
# Custom tf.summaries are not supported in legacy graph mode.
self._total_batches_seen = {
self._train_run_name: 0,
self._validation_run_name: 0
}
def _increment_step(self, writer_name):
step = self._total_batches_seen[writer_name]
if isinstance(step, variables.Variable):
step.assign_add(1)
else:
self._total_batches_seen[writer_name] += 1
def _close_writers(self):
for writer in self._writers.values():
writer.close()
def _init_profile_batch(self, profile_batch):
"""Validate profile_batch value and set the range of batches to profile.
@@ -1926,75 +1900,79 @@ class TensorBoard(Callback):
"""
profile_batch_error_message = (
'profile_batch must be a non-negative integer or a comma separated '
'string of pair of positive integers. A pair of positive integers '
'signify a range of batches to profile.')
try:
profile_range = [int(i) for i in str(profile_batch).split(',')]
except ValueError:
raise ValueError(profile_batch_error_message)
if len(profile_range) == 1: # single batch
start_batch, stop_batch = profile_range[0], profile_range[0]
if start_batch < 0:
raise ValueError(profile_batch_error_message)
elif len(profile_range) == 2: # (start_batch, stop_batch)
start_batch, stop_batch = profile_range
# [0, 0], [-1, 100], [6, 5] are illegal.
if start_batch <= 0 or start_batch > stop_batch:
raise ValueError(profile_batch_error_message)
'profile_batch must be a non-negative integer or 2-tuple of positive '
'integers. A pair of positive integers signifies a range of batches '
'to profile. Found: {}'.format(profile_batch))
# Support legacy way of specifying "start,stop" or "start" as str.
if isinstance(profile_batch, six.string_types):
profile_batch = str(profile_batch).split(',')
profile_batch = nest.map_structure(int, profile_batch)
if isinstance(profile_batch, int):
self._start_batch = profile_batch
self._stop_batch = profile_batch
elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
self._start_batch, self._stop_batch = profile_batch
else:
raise ValueError(profile_batch_error_message)
return start_batch, stop_batch
if self._start_batch < 0 or self._stop_batch < self._start_batch:
raise ValueError(profile_batch_error_message)
if self._start_batch > 0:
profiler.warmup() # Improve the profiling accuracy.
# True when a trace is running.
self._is_tracing = False
# Setting `profile_batch=0` disables profiling.
self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
def on_train_begin(self, logs=None):
self._init_batch_steps()
if self._start_batch == 1:
self._enable_trace()
self._push_writer(self._train_writer, self._train_step)
def on_train_end(self, logs=None):
self._pop_writer()
if self._is_tracing:
self._stop_trace()
self._close_writers()
self._delete_tmp_write_dir()
def on_test_begin(self, logs=None):
self._set_default_writer(self._validation_run_name)
self._push_writer(self._val_writer, self._val_step)
def on_test_end(self, logs=None):
self._pop_writer()
def on_train_batch_begin(self, batch, logs=None):
if not self._should_trace:
return
if self._epoch == 0 and batch == self._start_batch:
self._start_trace()
def on_train_batch_end(self, batch, logs=None):
"""Writes scalar summaries for metrics on every training batch.
Performs profiling if current batch is in profiler_batches.
"""Performs profiling if current batch is in profiler_batches.
Arguments:
batch: Integer, index of batch within the current epoch.
logs: Dict. Metric results for this batch.
"""
# TODO(b/150629188): Make TensorBoard callback not use batch hooks
# by default.
if self.update_freq == 'epoch' and self._start_batch is None:
if not self._should_trace:
return
# Don't output batch_size and batch number as TensorBoard summaries
logs = logs or {}
train_batches = self._total_batches_seen[self._train_run_name]
if self.update_freq != 'epoch' and batch % self.update_freq == 0:
self._log_metrics(logs, prefix='batch_', step=train_batches)
self._increment_step(self._train_run_name)
if self._is_tracing:
control_flow_ops.cond(
math_ops.greater_equal(train_batches, self._stop_batch),
lambda: self._log_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda
else:
control_flow_ops.cond(
math_ops.equal(train_batches, self._start_batch - 1),
lambda: self._enable_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda
def on_test_batch_end(self, batch, logs=None):
if self.update_freq == 'epoch':
return
self._increment_step(self._validation_run_name)
if self._is_tracing and batch >= self._stop_batch:
self._stop_trace()
def on_epoch_begin(self, epoch, logs=None):
self._set_default_writer(self._train_run_name)
# Keeps track of epoch for profiling.
self._epoch = epoch
def on_epoch_end(self, epoch, logs=None):
"""Runs metrics and histogram summaries at epoch end."""
self._log_metrics(logs, prefix='epoch_', step=epoch)
self._log_epoch_metrics(epoch, logs)
if self.histogram_freq and epoch % self.histogram_freq == 0:
self._log_weights(epoch)
@@ -2002,124 +1980,57 @@ class TensorBoard(Callback):
if self.embeddings_freq and epoch % self.embeddings_freq == 0:
self._log_embeddings(epoch)
def on_train_end(self, logs=None):
if self._is_tracing:
self._log_trace()
self._close_writers()
summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access
summary_state.is_recording = self._prev_summary_recording
summary_state.writer = self._prev_summary_writer
summary_state.step = self._prev_summary_step
# In case this callback is used via native Keras, _get_distribution_strategy does not exist.
if hasattr(self.model, '_get_distribution_strategy'):
# Safely remove the unneeded temp files.
distributed_file_utils.remove_temp_dirpath(
self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access
def _enable_trace(self):
"""Starts to collect trace graph to TensorBoard.
Collects both trace and graph in eager mode, and trace only in graph mode.
"""
if context.executing_eagerly():
# Graph must be traced in eager mode.
summary_ops_v2.trace_on(graph=True, profiler=False)
profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
def _start_trace(self):
summary_ops_v2.trace_on(graph=True, profiler=False)
profiler.start(logdir=self._train_dir)
self._is_tracing = True
def _enable_trace_return_true(self):
"""Starts to collect trace graph to TensorBoard and returns True.
Returns:
True.
"""
self._enable_trace()
return True
def _log_trace(self):
"""Logs the trace graph to TensorBoard.
Logs both trace and graph in eager mode, and trace only in graph mode.
"""
profiler.stop()
if context.executing_eagerly():
# Graph must be traced in eager mode.
with self._get_writer(self._train_run_name).as_default(), \
summary_ops_v2.always_record_summaries():
def _stop_trace(self, batch=None):
"""Logs the trace graph to TensorBoard."""
if batch is None:
batch = self._stop_batch
with self._train_writer.as_default():
with summary_ops_v2.always_record_summaries():
# TODO(b/126388999): Remove step info in the summary name.
step = K.get_value(self._total_batches_seen[self._train_run_name])
summary_ops_v2.trace_export(name='batch_%d' % step, step=step)
summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch)
profiler.stop()
self._is_tracing = False
def _log_trace_return_true(self):
"""Logs the trace graph to TensorBoard and returns True.
Returns:
True.
"""
self._log_trace()
return True
def _log_metrics(self, logs, prefix, step):
"""Writes metrics out as custom scalar summaries.
def _log_epoch_metrics(self, epoch, logs):
"""Writes epoch metrics out as scalar summaries.
Arguments:
logs: Dict. Keys are scalar summary names, values are NumPy scalars.
prefix: String. The prefix to apply to the scalar summary names.
step: Int. The global step to use for TensorBoard.
epoch: Int. The global step to use for TensorBoard.
logs: Dict. Keys are scalar summary names, values are scalars.
"""
if logs is None:
logs = {}
if not logs:
return
# Group metrics by the name of their associated file writer. Values
# are lists of metrics, as (name, scalar_value) pairs.
logs_by_writer = {
self._train_run_name: [],
self._validation_run_name: [],
}
validation_prefix = 'val_'
for (name, value) in logs.items():
if name in ('batch', 'size', 'num_steps'):
# Scrub non-metric items.
continue
if name.startswith(validation_prefix):
name = name[len(validation_prefix):]
writer_name = self._validation_run_name
else:
writer_name = self._train_run_name
name = prefix + name # assign batch or epoch prefix
logs_by_writer[writer_name].append((name, value))
train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
with context.eager_mode():
with summary_ops_v2.always_record_summaries():
for writer_name in logs_by_writer:
these_logs = logs_by_writer[writer_name]
if not these_logs:
# Don't create a "validation" events file if we don't
# actually have any validation data.
continue
writer = self._get_writer(writer_name)
with writer.as_default():
for (name, value) in these_logs:
summary_ops_v2.scalar(name, value, step=step)
with summary_ops_v2.always_record_summaries():
if train_logs:
with self._train_writer.as_default():
for name, value in train_logs.items():
summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
if val_logs:
with self._val_writer.as_default():
for name, value in val_logs.items():
name = name[4:] # Remove 'val_' prefix.
summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
def _log_weights(self, epoch):
"""Logs the weights of the Model to TensorBoard."""
writer = self._get_writer(self._train_run_name)
with context.eager_mode(), \
writer.as_default(), \
summary_ops_v2.always_record_summaries():
for layer in self.model.layers:
for weight in layer.weights:
weight_name = weight.name.replace(':', '_')
with ops.init_scope():
weight = K.get_value(weight)
summary_ops_v2.histogram(weight_name, weight, step=epoch)
if self.write_images:
self._log_weight_as_image(weight, weight_name, epoch)
writer.flush()
with self._train_writer.as_default():
with summary_ops_v2.always_record_summaries():
for layer in self.model.layers:
for weight in layer.weights:
weight_name = weight.name.replace(':', '_')
summary_ops_v2.histogram(weight_name, weight, step=epoch)
if self.write_images:
self._log_weight_as_image(weight, weight_name, epoch)
self._train_writer.flush()
def _log_weight_as_image(self, weight, weight_name, epoch):
"""Logs a weight as a TensorBoard image."""
@@ -2150,6 +2061,9 @@ class TensorBoard(Callback):
'keras_embedding.ckpt-{}'.format(epoch))
self.model.save_weights(embeddings_ckpt)
def _implements_train_batch_hooks(self):
return not (self._start_batch == 0 and self._stop_batch == 0)
@keras_export('keras.callbacks.ReduceLROnPlateau')
class ReduceLROnPlateau(Callback):

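The _push_writer/_pop_writer pair above saves the process-wide default summary state and temporarily points it at the callback's writer and step variable, so tf.summary calls made during training are recorded to the callback's writer without passing a writer explicitly, and a default step is installed for custom summaries. A standalone sketch of that pattern using only public APIs (the directory, metric name and the every-10-batches condition are illustrative):

import tensorflow as tf

step = tf.Variable(0, dtype='int64')
writer = tf.summary.create_file_writer('/tmp/tb_sketch/train')

# Install a default step; the callback does the equivalent in _push_writer.
tf.summary.experimental.set_step(step)

def write_batch_metrics(loss):
  # No writer or step argument: the defaults installed above are used.
  tf.summary.scalar('batch_loss', loss)

with writer.as_default(), tf.summary.record_if(lambda: step % 10 == 0):
  for batch in range(30):
    step.assign_add(1)
    write_batch_metrics(loss=1.0 / float(batch + 1))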

@@ -2079,17 +2079,19 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
model.fit(
np.zeros((64, 1)),
np.zeros((64, 1)),
batch_size=32,
callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)],
)
# Verifies trace exists in the first train_dir.
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
self.assertIsNotNone(self._get_trace_file(logdir=self.logdir))
model.fit(
np.zeros((64, 1)),
np.zeros((64, 1)),
batch_size=32,
callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)],
)
# Verifies trace exists in the second train_dir.
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
self.assertIsNotNone(self._get_trace_file(logdir=self.logdir))
def test_TensorBoard_autoTrace_profileBatchRange(self):
model = self._get_seq_model()


@@ -39,7 +39,7 @@ from tensorflow.python.util.tf_export import keras_export
@keras_export(v1=['keras.callbacks.TensorBoard'])
class TensorBoard(callbacks.Callback):
class TensorBoard(callbacks.TensorBoard):
# pylint: disable=line-too-long
"""Enable visualizations for TensorBoard.
@@ -127,7 +127,8 @@ class TensorBoard(callbacks.Callback):
embeddings_data=None,
update_freq='epoch',
profile_batch=2):
super(TensorBoard, self).__init__()
# Don't call super's init since it is an eager-only version.
callbacks.Callback.__init__(self)
self.log_dir = log_dir
self.histogram_freq = histogram_freq
if self.histogram_freq and context.executing_eagerly():
@@ -342,6 +343,21 @@ class TensorBoard(callbacks.Callback):
self.writer.add_summary(summary, step)
self.writer.flush()
def on_train_batch_begin(self, batch, logs=None):
if (not self._is_profiling and
self._total_batches_seen == self._profile_batch - 1):
profiler.start(self.log_dir)
self._is_profiling = True
def on_train_batch_end(self, batch, logs=None):
return self.on_batch_end(batch, logs)
def on_test_begin(self, logs=None):
pass
def on_test_end(self, logs=None):
pass
def on_batch_end(self, batch, logs=None):
"""Writes scalar summaries for metrics on every training batch.
@@ -358,18 +374,13 @@ class TensorBoard(callbacks.Callback):
self._write_custom_summaries(self._total_batches_seen, batch_logs)
self._samples_seen_at_last_write = self._samples_seen
self._total_batches_seen += 1
if self._is_profiling:
profiler.stop()
self._is_profiling = False
elif (not self._is_profiling and
self._total_batches_seen == self._profile_batch - 1):
profiler.start(self.log_dir)
self._is_profiling = True
def on_train_begin(self, logs=None):
if self._profile_batch == 1:
profiler.start(self.log_dir)
self._is_profiling = True
pass
def on_epoch_begin(self, epoch, logs=None):
"""Add histogram op to Model eval_function callbacks, reset batch count."""


@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import copy
import itertools
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
@@ -28,6 +29,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import callbacks as callbacks_module
from tensorflow.python.keras import optimizers
@@ -43,6 +45,8 @@ from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.mode_keys import ModeKeys
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.profiler import trace
@@ -161,6 +165,9 @@ class Model(network.Network, version_utils.ModelVersionSelector):
Checkout [guide](https://www.tensorflow.org/guide/keras/overview) for
additional details.
"""
_TF_MODULE_IGNORED_PROPERTIES = frozenset(
itertools.chain(('_train_counter', '_test_counter', '_predict_counter'),
network.Network._TF_MODULE_IGNORED_PROPERTIES)) # pylint: disable=protected-access
def __init__(self, *args, **kwargs):
super(Model, self).__init__(*args, **kwargs)
@@ -186,6 +193,18 @@ class Model(network.Network, version_utils.ModelVersionSelector):
self.compiled_loss = None
self.compiled_metrics = None
self._init_batch_counters()
@trackable.no_automatic_dependency_tracking
def _init_batch_counters(self):
# Untracked Variables, used to keep track of mini-batches seen in `fit`,
# `evaluate`, and `predict`.
agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg)
self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg)
self._predict_counter = variables.Variable(
0, dtype='int64', aggregation=agg)
def get_weights(self):
"""Retrieves the weights of the model.
@@ -499,11 +518,18 @@ class Model(network.Network, version_utils.ModelVersionSelector):
return self.train_function
def train_function(iterator):
"""Runs one call to `self.train_function`."""
def run_step(data):
outputs = self.train_step(data)
self._train_counter.assign_add(1)
return outputs
data = next(iterator)
outputs = self.distribute_strategy.run(
self.train_step, args=(data,))
outputs = self.distribute_strategy.run(run_step, args=(data,))
outputs = reduce_per_replica(
outputs, self.distribute_strategy, reduction='first')
write_scalar_summaries(outputs, step=self._train_counter)
return outputs
if not self.run_eagerly:
@@ -762,6 +788,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
self.stop_training = False
train_function = self.make_train_function()
self._train_counter.assign(0)
callbacks.on_train_begin()
# Handle fault-tolerance for multi-worker.
# TODO(omalleyt): Fix the ordering issues that mean this has to
@@ -872,9 +899,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
return self.test_function
def test_function(iterator):
"""Runs one call to `self.test_function`."""
def run_step(data):
outputs = self.test_step(data)
self._test_counter.assign_add(1)
return outputs
data = next(iterator)
outputs = self.distribute_strategy.run(
self.test_step, args=(data,))
outputs = self.distribute_strategy.run(run_step, args=(data,))
outputs = reduce_per_replica(
outputs, self.distribute_strategy, reduction='first')
return outputs
@@ -1003,6 +1036,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
steps=data_handler.inferred_steps)
test_function = self.make_test_function()
self._test_counter.assign(0)
callbacks.on_test_begin()
for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
self.reset_metrics()
@@ -1075,9 +1109,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
return self.predict_function
def predict_function(iterator):
"""Runs one call to `self.predict_function`."""
def run_step(data):
outputs = self.predict_step(data)
self._predict_counter.assign_add(1)
return outputs
data = next(iterator)
outputs = self.distribute_strategy.run(
self.predict_step, args=(data,))
outputs = self.distribute_strategy.run(run_step, args=(data,))
outputs = reduce_per_replica(
outputs, self.distribute_strategy, reduction='concat')
return outputs
@@ -1192,6 +1232,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
steps=data_handler.inferred_steps)
predict_function = self.make_predict_function()
self._predict_counter.assign(0)
callbacks.on_predict_begin()
for _, iterator in data_handler.enumerate_epochs(): # Single epoch.
with data_handler.catch_stop_iteration():
@@ -1734,3 +1775,13 @@ def _minimize(tape, optimizer, loss, trainable_variables):
all_reduce_sum_gradients=False)
else:
optimizer.apply_gradients(zip(gradients, trainable_variables))
def _is_scalar(x):
return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0
def write_scalar_summaries(logs, step):
for name, value in logs.items():
if _is_scalar(value):
summary_ops_v2.scalar('batch_' + name, value, step=step)
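
The run_step wrappers added above bump the counters inside the function that the distribution strategy executes, so each executed mini-batch is counted exactly once even under a tf.function. A self-contained sketch of that structure with the default single-replica strategy (the reduce_mean loss is a stand-in for Model.train_step):

import tensorflow as tf

strategy = tf.distribute.get_strategy()  # default, single-replica strategy
train_counter = tf.Variable(
    0, dtype='int64',
    aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

@tf.function
def train_function(iterator):
  def run_step(data):
    loss = tf.reduce_mean(data)   # stand-in for Model.train_step
    train_counter.assign_add(1)   # one increment per executed mini-batch
    return {'loss': loss}
  return strategy.run(run_step, args=(next(iterator),))

dataset = tf.data.Dataset.from_tensor_slices(tf.ones((8, 2))).batch(4)
iterator = iter(dataset)
for _ in range(2):
  train_function(iterator)
print(int(train_counter))  # -> 2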


@@ -162,6 +162,9 @@ class Model(training_lib.Model):
self._v1_compile_was_called = False
def _init_batch_counters(self):
pass # Batch counters should not be created in legacy graph mode.
@trackable.no_automatic_dependency_tracking
def _set_strategy(self, strategy):
self._compile_time_distribution_strategy = strategy


@@ -737,6 +737,21 @@ class CustomCallSignatureTests(test.TestCase, parameterized.TestCase):
self.assertLen(new_model.variables, 1)
self.assertLen(new_model.layers, 1)
def test_batch_counters_not_in_variables(self):
class MyModel(keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.layer = keras.layers.Dense(4)
def call(self, obs):
return self.layer(obs)
model = MyModel()
model(np.ones((10, 10)))
self.assertLen(model.variables, 2)
if __name__ == '__main__':
test.main()


@@ -36,6 +36,13 @@ base_layer = lazy_loader.LazyLoader(
base_layer_v1 = lazy_loader.LazyLoader(
"base_layer_v1", globals(),
"tensorflow.python.keras.engine.base_layer_v1")
callbacks = lazy_loader.LazyLoader(
"callbacks", globals(),
"tensorflow.python.keras.callbacks")
callbacks_v1 = lazy_loader.LazyLoader(
"callbacks_v1", globals(),
"tensorflow.python.keras.callbacks_v1")
# pylint: enable=g-inconsistent-quotes
@@ -58,6 +65,21 @@ class LayerVersionSelector(object):
return super(LayerVersionSelector, cls).__new__(cls)
class TensorBoardVersionSelector(object):
"""Chooses between Keras v1 and v2 TensorBoard callback class."""
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
eager_enabled = ops.executing_eagerly_outside_functions()
start_cls = cls
cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
eager_enabled)
if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
# Since the v2 class is not a subclass of the v1 class, __init__ has to
# be called manually.
return cls(*args, **kwargs)
return super(TensorBoardVersionSelector, cls).__new__(cls)
def swap_class(cls, v2_cls, v1_cls, eager_enabled):
"""Swaps in v2_cls or v1_cls depending on graph mode."""
if cls == object:

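TensorBoardVersionSelector above picks between the v1 and v2 callback classes at construction time by swapping the class inside __new__ based on ops.executing_eagerly_outside_functions(). A toy reduction of the same trick (the two stand-in classes and the simpler tf.executing_eagerly() check are illustrative, not the Keras classes):

import tensorflow as tf

class TensorBoardV1Like(object):
  """Stand-in for the graph-mode (v1) callback."""

  def __init__(self, log_dir):
    self.log_dir, self.version = log_dir, 1

class TensorBoardV2Like(object):
  """Stand-in for the eager (v2) callback."""

  def __init__(self, log_dir):
    self.log_dir, self.version = log_dir, 2

class VersionSelector(object):
  """Returns whichever implementation matches the current execution mode."""

  def __new__(cls, *args, **kwargs):
    chosen = TensorBoardV2Like if tf.executing_eagerly() else TensorBoardV1Like
    instance = object.__new__(chosen)
    # The chosen class is unrelated to `cls`, so Python will not call
    # __init__ automatically; invoke it by hand, as the real selector does
    # when it crosses the v1/v2 boundary.
    instance.__init__(*args, **kwargs)
    return instance

tb = VersionSelector('/tmp/logs')
print(type(tb).__name__, tb.version)  # TensorBoardV2Like 2 in eager mode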

@@ -1,7 +1,9 @@
path: "tensorflow.keras.callbacks.TensorBoard"
tf_class {
is_instance: "<class \'tensorflow.python.keras.callbacks_v1.TensorBoard\'>"
is_instance: "<class \'tensorflow.python.keras.callbacks.TensorBoard\'>"
is_instance: "<class \'tensorflow.python.keras.callbacks.Callback\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.TensorBoardVersionSelector\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"


@@ -2,6 +2,7 @@ path: "tensorflow.keras.callbacks.TensorBoard"
tf_class {
is_instance: "<class \'tensorflow.python.keras.callbacks.TensorBoard\'>"
is_instance: "<class \'tensorflow.python.keras.callbacks.Callback\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.TensorBoardVersionSelector\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"