From 69565ec4003902794bc94e10ba5fe9469a0b3ae4 Mon Sep 17 00:00:00 2001
From: Thomas O'Malley
Date: Tue, 17 Mar 2020 14:54:56 -0700
Subject: [PATCH] Create Variables to track mini-batches seen in Model.fit /
 evaluate / predict. Use these counters in the TensorBoard Callback.

PiperOrigin-RevId: 301459298
Change-Id: I8e92e119ef4cef37c41532d11caefd601d4395a7
---
 tensorflow/python/keras/callbacks.py          | 448 +++++++-----------
 tensorflow/python/keras/callbacks_test.py     |   6 +-
 tensorflow/python/keras/callbacks_v1.py       |  29 +-
 tensorflow/python/keras/engine/training.py    |  63 ++-
 tensorflow/python/keras/engine/training_v1.py |   3 +
 .../keras/tests/model_subclassing_test.py     |  15 +
 .../python/keras/utils/version_utils.py       |  22 +
 ...orflow.keras.callbacks.-tensor-board.pbtxt |   2 +
 ...orflow.keras.callbacks.-tensor-board.pbtxt |   1 +
 9 files changed, 305 insertions(+), 284 deletions(-)

diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index bb9e61d01a2..9177d89c67b 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -35,21 +35,19 @@ import six
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import distributed_file_utils
 from tensorflow.python.distribute import multi_worker_util
-from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.distribute import multi_worker_training_state as training_state
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.keras.utils import version_utils
 from tensorflow.python.keras.utils.data_utils import Sequence
 from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2
-from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import profiler_v2 as profiler
 from tensorflow.python.training import checkpoint_management
@@ -1614,7 +1612,7 @@ class LearningRateScheduler(Callback):
 
 
 @keras_export('keras.callbacks.TensorBoard', v1=[])
-class TensorBoard(Callback):
+class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
   # pylint: disable=line-too-long
   """Enable visualizations for TensorBoard.
 
@@ -1676,11 +1674,10 @@ class TensorBoard(Callback):
       batches. Note that writing too frequently to TensorBoard can slow down
       your training.
     profile_batch: Profile the batch(es) to sample compute characteristics.
-      profile_batch must be a non-negative integer or a comma separated string
-      of pair of positive integers. A pair of positive integers signify a
-      range of batches to profile. By default, it will profile the second
-      batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow
-      eager mode.
+      profile_batch must be a non-negative integer or a tuple of integers.
+      A pair of positive integers signifies a range of batches to profile.
+      By default, it will profile the second batch. Set profile_batch=0
+      to disable profiling. Must run in TensorFlow eager mode.
     embeddings_freq: frequency (in epochs) at which embedding layers will be
       visualized. If set to 0, embeddings won't be visualized.
     embeddings_metadata: a dictionary which maps layer name to a file name in
@@ -1713,30 +1710,18 @@ class TensorBoard(Callback):
     self.histogram_freq = histogram_freq
     self.write_graph = write_graph
     self.write_images = write_images
-    if update_freq == 'batch':
-      self.update_freq = 1
-    else:
-      self.update_freq = update_freq
+    self.update_freq = 1 if update_freq == 'batch' else update_freq
     self.embeddings_freq = embeddings_freq
     self.embeddings_metadata = embeddings_metadata
+    self._init_profile_batch(profile_batch)
+    self._epoch = 0
 
-    self._samples_seen = 0
-    self._samples_seen_at_last_write = 0
-    self._current_batch = 0
-
-    # A collection of file writers currently in use, to be closed when
-    # training ends for this callback. Writers are keyed by the
-    # directory name under the root logdir: e.g., "train" or
-    # "validation".
-    self._train_run_name = 'train'
-    self._validation_run_name = 'validation'
+    # Lazily initialized in order to avoid creating event files when
+    # not needed.
     self._writers = {}
-    self._start_batch, self._stop_batch = self._init_profile_batch(
-        profile_batch)
-    if self._start_batch > 0:
-      profiler.warmup()  # Improve the profiling accuracy.
-    # True when a trace is running.
-    self._is_tracing = False
+
+    # Used to restore any existing `SummaryWriter` after training ends.
+    self._prev_summary_state = []
 
   def _validate_kwargs(self, kwargs):
     """Handle arguments were supported in V1."""
@@ -1768,37 +1753,56 @@ class TensorBoard(Callback):
 
   def set_model(self, model):
     """Sets Keras model and writes graph if specified."""
     self.model = model
+    self._log_write_dir = self._get_log_write_dir()
 
-    # In case this callback is used via native Keras, _get_distribution_strategy does not exist.
-    if hasattr(self.model, '_get_distribution_strategy'):
-      # TensorBoard callback involves writing a summary file in a
-      # possibly distributed settings.
-      self._log_write_dir = distributed_file_utils.write_dirpath(
-          self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access
-    else:
-      self._log_write_dir = self.log_dir
+    self._train_dir = os.path.join(self._log_write_dir, 'train')
+    self._train_step = self.model._train_counter  # pylint: disable=protected-access
 
-    with context.eager_mode():
-      self._close_writers()
-      if self.write_graph:
-        with self._get_writer(self._train_run_name).as_default():
-          with summary_ops_v2.always_record_summaries():
-            if not model.run_eagerly:
-              summary_ops_v2.graph(K.get_graph(), step=0)
+    self._val_dir = os.path.join(self._log_write_dir, 'validation')
+    self._val_step = self.model._test_counter  # pylint: disable=protected-access
 
-            summary_writable = (
-                self.model._is_graph_network or  # pylint: disable=protected-access
-                self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
-            if summary_writable:
-              summary_ops_v2.keras_model('keras', self.model, step=0)
+    self._writers = {}  # Resets writers.
+
+    if self.write_graph:
+      self._write_keras_model_graph()
 
     if self.embeddings_freq:
       self._configure_embeddings()
 
-    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    self._prev_summary_recording = summary_state.is_recording
-    self._prev_summary_writer = summary_state.writer
-    self._prev_summary_step = summary_state.step
+  @property
+  def _train_writer(self):
+    if 'train' not in self._writers:
+      self._writers['train'] = summary_ops_v2.create_file_writer_v2(
+          self._train_dir)
+    return self._writers['train']
+
+  @property
+  def _val_writer(self):
+    if 'val' not in self._writers:
+      self._writers['val'] = summary_ops_v2.create_file_writer_v2(self._val_dir)
+    return self._writers['val']
+
+  def _get_log_write_dir(self):
+    """For multi-worker, only chief should write, others write to '/tmp'."""
+    return distributed_file_utils.write_dirpath(self.log_dir,
+                                                self.model.distribute_strategy)
+
+  def _delete_tmp_write_dir(self):
+    """Deletes tmp write directories for multi-worker."""
+    distributed_file_utils.remove_temp_dirpath(self.log_dir,
+                                               self.model.distribute_strategy)
+
+  def _write_keras_model_graph(self):
+    """Writes Keras graph networks to TensorBoard."""
+    with self._train_writer.as_default():
+      with summary_ops_v2.always_record_summaries():
+        if not self.model.run_eagerly:
+          summary_ops_v2.graph(K.get_graph(), step=0)
+
+        summary_writable = (
+            self.model._is_graph_network or  # pylint: disable=protected-access
+            self.model.__class__.__name__ == 'Sequential')  # pylint: disable=protected-access
+        if summary_writable:
+          summary_ops_v2.keras_model('keras', self.model, step=0)
 
   def _configure_embeddings(self):
     """Configure the Projector for embeddings."""
@@ -1839,74 +1843,44 @@ class TensorBoard(Callback):
       writer = DummyWriter(self._log_write_dir)
     projector.visualize_embeddings(writer, config)
 
-  def _close_writers(self):
-    """Close all remaining open file writers owned by this callback.
-
-    If there are no such file writers, this is a no-op.
-    """
-    with context.eager_mode():
-      for writer in six.itervalues(self._writers):
-        writer.close()
-      self._writers.clear()
-
-  def _get_writer(self, writer_name):
-    """Get a summary writer for the given subdirectory under the logdir.
-
-    A writer will be created if it does not yet exist.
-
-    Arguments:
-      writer_name: The name of the directory for which to create or
-        retrieve a writer. Should be either `self._train_run_name` or
-        `self._validation_run_name`.
-
-    Returns:
-      A `SummaryWriter` object.
-    """
-    if writer_name not in self._writers:
-      path = os.path.join(self._log_write_dir, writer_name)
-      writer = summary_ops_v2.create_file_writer_v2(path)
-      self._writers[writer_name] = writer
-    return self._writers[writer_name]
-
-  def _set_default_writer(self, writer_name):
+  def _push_writer(self, writer, step):
     """Sets the default writer for custom batch-level summaries."""
     if self.update_freq == 'epoch':
-      # Writer is only used for custom summaries, which are written
-      # batch-by-batch.
       return
-    step = self._total_batches_seen[writer_name]
+
+    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
+    self._prev_summary_state.append({
+        'is_recording': summary_state.is_recording,
+        'writer': summary_state.writer,
+        'step': summary_state.step
+    })
 
-    def _should_record():
-      return math_ops.equal(step % self.update_freq, 0)
+    if self.update_freq == 'epoch':
+      should_record = False
+      writer = None
+    else:
+      should_record = lambda: math_ops.equal(step % self.update_freq, 0)
+
+    summary_state.is_recording = should_record
+    summary_state.writer = writer
+    # TODO(b/151339474): Fix deadlock when not using .value() here.
+    summary_ops_v2.set_step(step.value())
+
+  def _pop_writer(self):
+    """Pops the current writer."""
+    if self.update_freq == 'epoch':
+      return
+
+    prev_state = self._prev_summary_state.pop()
 
     summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    summary_state.is_recording = _should_record
-    summary_state.writer = self._get_writer(writer_name)
-    summary_ops_v2.set_step(step)
+    summary_state.is_recording = prev_state['is_recording']
+    summary_state.writer = prev_state['writer']
+    summary_ops_v2.set_step(prev_state['step'])
 
-  def _init_batch_steps(self):
-    """Create the total batch counters."""
-    if ops.executing_eagerly_outside_functions():
-      # Variables are needed for the `step` value of custom tf.summaries
-      # to be updated inside a tf.function.
-      self._total_batches_seen = {
-          self._train_run_name: variables.Variable(0, dtype='int64'),
-          self._validation_run_name: variables.Variable(0, dtype='int64')
-      }
-    else:
-      # Custom tf.summaries are not supported in legacy graph mode.
-      self._total_batches_seen = {
-          self._train_run_name: 0,
-          self._validation_run_name: 0
-      }
-
-  def _increment_step(self, writer_name):
-    step = self._total_batches_seen[writer_name]
-    if isinstance(step, variables.Variable):
-      step.assign_add(1)
-    else:
-      self._total_batches_seen[writer_name] += 1
+  def _close_writers(self):
+    for writer in self._writers.values():
+      writer.close()
 
   def _init_profile_batch(self, profile_batch):
     """Validate profile_batch value and set the range of batches to profile.
@@ -1926,75 +1900,79 @@ class TensorBoard(Callback):
     """
     profile_batch_error_message = (
-        'profile_batch must be a non-negative integer or a comma separated '
-        'string of pair of positive integers. A pair of positive integers '
-        'signify a range of batches to profile.')
-    try:
-      profile_range = [int(i) for i in str(profile_batch).split(',')]
-    except ValueError:
-      raise ValueError(profile_batch_error_message)
-    if len(profile_range) == 1:  # single batch
-      start_batch, stop_batch = profile_range[0], profile_range[0]
-      if start_batch < 0:
-        raise ValueError(profile_batch_error_message)
-    elif len(profile_range) == 2:  # (start_batch, stop_batch)
-      start_batch, stop_batch = profile_range
-      # [0, 0], [-1, 100], [6, 5] are illegal.
-      if start_batch <= 0 or start_batch > stop_batch:
-        raise ValueError(profile_batch_error_message)
+        'profile_batch must be a non-negative integer or 2-tuple of positive '
+        'integers. A pair of positive integers signifies a range of batches '
+        'to profile. Found: {}'.format(profile_batch))
+
+    # Support legacy way of specifying "start,stop" or "start" as str.
+    if isinstance(profile_batch, six.string_types):
+      profile_batch = str(profile_batch).split(',')
+      profile_batch = nest.map_structure(int, profile_batch)
+
+    if isinstance(profile_batch, int):
+      self._start_batch = profile_batch
+      self._stop_batch = profile_batch
+    elif isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2:
+      self._start_batch, self._stop_batch = profile_batch
     else:
       raise ValueError(profile_batch_error_message)
-    return start_batch, stop_batch
+
+    if self._start_batch < 0 or self._stop_batch < self._start_batch:
+      raise ValueError(profile_batch_error_message)
+
+    if self._start_batch > 0:
+      profiler.warmup()  # Improve the profiling accuracy.
+
+    # True when a trace is running.
+    self._is_tracing = False
+
+    # Setting `profile_batch=0` disables profiling.
+    self._should_trace = not (self._start_batch == 0 and self._stop_batch == 0)
 
   def on_train_begin(self, logs=None):
-    self._init_batch_steps()
-    if self._start_batch == 1:
-      self._enable_trace()
+    self._push_writer(self._train_writer, self._train_step)
+
+  def on_train_end(self, logs=None):
+    self._pop_writer()
+
+    if self._is_tracing:
+      self._stop_trace()
+
+    self._close_writers()
+    self._delete_tmp_write_dir()
 
   def on_test_begin(self, logs=None):
-    self._set_default_writer(self._validation_run_name)
+    self._push_writer(self._val_writer, self._val_step)
+
+  def on_test_end(self, logs=None):
+    self._pop_writer()
+
+  def on_train_batch_begin(self, batch, logs=None):
+    if not self._should_trace:
+      return
+
+    if self._epoch == 0 and batch == self._start_batch:
+      self._start_trace()
 
   def on_train_batch_end(self, batch, logs=None):
-    """Writes scalar summaries for metrics on every training batch.
-
-    Performs profiling if current batch is in profiler_batches.
+    """Performs profiling if current batch is in profiler_batches.
 
     Arguments:
       batch: Integer, index of batch within the current epoch.
       logs: Dict. Metric results for this batch.
     """
-    # TODO(b/150629188): Make TensorBoard callback not use batch hooks
-    # by default.
-    if self.update_freq == 'epoch' and self._start_batch is None:
+    if not self._should_trace:
       return
 
-    # Don't output batch_size and batch number as TensorBoard summaries
-    logs = logs or {}
-    train_batches = self._total_batches_seen[self._train_run_name]
-    if self.update_freq != 'epoch' and batch % self.update_freq == 0:
-      self._log_metrics(logs, prefix='batch_', step=train_batches)
-
-    self._increment_step(self._train_run_name)
-
-    if self._is_tracing:
-      control_flow_ops.cond(
-          math_ops.greater_equal(train_batches, self._stop_batch),
-          lambda: self._log_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda
-    else:
-      control_flow_ops.cond(
-          math_ops.equal(train_batches, self._start_batch - 1),
-          lambda: self._enable_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda
-
-  def on_test_batch_end(self, batch, logs=None):
-    if self.update_freq == 'epoch':
-      return
-    self._increment_step(self._validation_run_name)
+    if self._is_tracing and batch >= self._stop_batch:
+      self._stop_trace()
 
   def on_epoch_begin(self, epoch, logs=None):
-    self._set_default_writer(self._train_run_name)
+    # Keeps track of epoch for profiling.
+    self._epoch = epoch
 
   def on_epoch_end(self, epoch, logs=None):
     """Runs metrics and histogram summaries at epoch end."""
-    self._log_metrics(logs, prefix='epoch_', step=epoch)
+    self._log_epoch_metrics(epoch, logs)
 
     if self.histogram_freq and epoch % self.histogram_freq == 0:
       self._log_weights(epoch)
@@ -2002,124 +1980,57 @@ class TensorBoard(Callback):
     if self.embeddings_freq and epoch % self.embeddings_freq == 0:
       self._log_embeddings(epoch)
 
-  def on_train_end(self, logs=None):
-    if self._is_tracing:
-      self._log_trace()
-    self._close_writers()
-
-    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
-    summary_state.is_recording = self._prev_summary_recording
-    summary_state.writer = self._prev_summary_writer
-    summary_state.step = self._prev_summary_step
-
-    # In case this callback is used via native Keras, _get_distribution_strategy does not exist.
-    if hasattr(self.model, '_get_distribution_strategy'):
-      # Safely remove the unneeded temp files.
-      distributed_file_utils.remove_temp_dirpath(
-          self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access
-
-  def _enable_trace(self):
-    """Starts to collect trace graph to TensorBoard.
-
-    Collects both trace and graph in eager mode, and trace only in graph mode.
-    """
-    if context.executing_eagerly():
-      # Graph must be traced in eager mode.
-      summary_ops_v2.trace_on(graph=True, profiler=False)
-    profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
+  def _start_trace(self):
+    summary_ops_v2.trace_on(graph=True, profiler=False)
+    profiler.start(logdir=self._train_dir)
     self._is_tracing = True
 
-  def _enable_trace_return_true(self):
-    """Starts to collect trace graph to TensorBoard and returns True.
-
-    Returns:
-      True.
-    """
-    self._enable_trace()
-    return True
-
-  def _log_trace(self):
-    """Logs the trace graph to TensorBoard.
-
-    Logs both trace and graph in eager mode, and trace only in graph mode.
-    """
-    profiler.stop()
-    if context.executing_eagerly():
-      # Graph must be traced in eager mode.
-      with self._get_writer(self._train_run_name).as_default(), \
-          summary_ops_v2.always_record_summaries():
+  def _stop_trace(self, batch=None):
+    """Logs the trace graph to TensorBoard."""
+    if batch is None:
+      batch = self._stop_batch
+    with self._train_writer.as_default():
+      with summary_ops_v2.always_record_summaries():
         # TODO(b/126388999): Remove step info in the summary name.
-        step = K.get_value(self._total_batches_seen[self._train_run_name])
-        summary_ops_v2.trace_export(name='batch_%d' % step, step=step)
+        summary_ops_v2.trace_export(name='batch_%d' % batch, step=batch)
+    profiler.stop()
     self._is_tracing = False
 
-  def _log_trace_return_true(self):
-    """Logs the trace graph to TensorBoard and returns True.
-
-    Returns:
-      True.
-    """
-    self._log_trace()
-    return True
-
-  def _log_metrics(self, logs, prefix, step):
-    """Writes metrics out as custom scalar summaries.
+  def _log_epoch_metrics(self, epoch, logs):
    """Writes epoch metrics out as scalar summaries.
 
     Arguments:
-      logs: Dict. Keys are scalar summary names, values are NumPy scalars.
-      prefix: String. The prefix to apply to the scalar summary names.
-      step: Int. The global step to use for TensorBoard.
+      epoch: Int. The global step to use for TensorBoard.
+      logs: Dict. Keys are scalar summary names, values are scalars.
     """
-    if logs is None:
-      logs = {}
+    if not logs:
+      return
 
-    # Group metrics by the name of their associated file writer. Values
-    # are lists of metrics, as (name, scalar_value) pairs.
-    logs_by_writer = {
-        self._train_run_name: [],
-        self._validation_run_name: [],
-    }
-    validation_prefix = 'val_'
-    for (name, value) in logs.items():
-      if name in ('batch', 'size', 'num_steps'):
-        # Scrub non-metric items.
-        continue
-      if name.startswith(validation_prefix):
-        name = name[len(validation_prefix):]
-        writer_name = self._validation_run_name
-      else:
-        writer_name = self._train_run_name
-      name = prefix + name  # assign batch or epoch prefix
-      logs_by_writer[writer_name].append((name, value))
+    train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
+    val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
 
-    with context.eager_mode():
-      with summary_ops_v2.always_record_summaries():
-        for writer_name in logs_by_writer:
-          these_logs = logs_by_writer[writer_name]
-          if not these_logs:
-            # Don't create a "validation" events file if we don't
-            # actually have any validation data.
-            continue
-          writer = self._get_writer(writer_name)
-          with writer.as_default():
-            for (name, value) in these_logs:
-              summary_ops_v2.scalar(name, value, step=step)
+    with summary_ops_v2.always_record_summaries():
+      if train_logs:
+        with self._train_writer.as_default():
+          for name, value in train_logs.items():
+            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
+      if val_logs:
+        with self._val_writer.as_default():
+          for name, value in val_logs.items():
+            name = name[4:]  # Remove 'val_' prefix.
+            summary_ops_v2.scalar('epoch_' + name, value, step=epoch)
 
   def _log_weights(self, epoch):
     """Logs the weights of the Model to TensorBoard."""
-    writer = self._get_writer(self._train_run_name)
-    with context.eager_mode(), \
-        writer.as_default(), \
-        summary_ops_v2.always_record_summaries():
-      for layer in self.model.layers:
-        for weight in layer.weights:
-          weight_name = weight.name.replace(':', '_')
-          with ops.init_scope():
-            weight = K.get_value(weight)
-          summary_ops_v2.histogram(weight_name, weight, step=epoch)
-          if self.write_images:
-            self._log_weight_as_image(weight, weight_name, epoch)
-      writer.flush()
+    with self._train_writer.as_default():
+      with summary_ops_v2.always_record_summaries():
+        for layer in self.model.layers:
+          for weight in layer.weights:
+            weight_name = weight.name.replace(':', '_')
+            summary_ops_v2.histogram(weight_name, weight, step=epoch)
+            if self.write_images:
+              self._log_weight_as_image(weight, weight_name, epoch)
      self._train_writer.flush()
 
   def _log_weight_as_image(self, weight, weight_name, epoch):
     """Logs a weight as a TensorBoard image."""
@@ -2150,6 +2061,9 @@ class TensorBoard(Callback):
                                     'keras_embedding.ckpt-{}'.format(epoch))
     self.model.save_weights(embeddings_ckpt)
 
+  def _implements_train_batch_hooks(self):
+    return not (self._start_batch == 0 and self._stop_batch == 0)
+
 
 @keras_export('keras.callbacks.ReduceLROnPlateau')
 class ReduceLROnPlateau(Callback):
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index eb62d0b29ee..54f71402177 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -2079,17 +2079,19 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
     model.fit(
         np.zeros((64, 1)),
         np.zeros((64, 1)),
+        batch_size=32,
         callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)],
     )
     # Verifies trace exists in the first train_dir.
-    self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
+    self.assertIsNotNone(self._get_trace_file(logdir=self.logdir))
 
     model.fit(
         np.zeros((64, 1)),
         np.zeros((64, 1)),
+        batch_size=32,
         callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)],
     )
     # Verifies trace exists in the second train_dir.
-    self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
+    self.assertIsNotNone(self._get_trace_file(logdir=self.logdir))
 
   def test_TensorBoard_autoTrace_profileBatchRange(self):
     model = self._get_seq_model()
diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py
index 524e039f597..09af890b76c 100644
--- a/tensorflow/python/keras/callbacks_v1.py
+++ b/tensorflow/python/keras/callbacks_v1.py
@@ -39,7 +39,7 @@ from tensorflow.python.util.tf_export import keras_export
 
 
 @keras_export(v1=['keras.callbacks.TensorBoard'])
-class TensorBoard(callbacks.Callback):
+class TensorBoard(callbacks.TensorBoard):
   # pylint: disable=line-too-long
   """Enable visualizations for TensorBoard.
 
@@ -127,7 +127,8 @@ class TensorBoard(callbacks.Callback):
                embeddings_data=None,
                update_freq='epoch',
                profile_batch=2):
-    super(TensorBoard, self).__init__()
+    # Don't call super's init since it is an eager-only version.
+    callbacks.Callback.__init__(self)
     self.log_dir = log_dir
     self.histogram_freq = histogram_freq
     if self.histogram_freq and context.executing_eagerly():
@@ -342,6 +343,21 @@ class TensorBoard(callbacks.Callback):
         self.writer.add_summary(summary, step)
     self.writer.flush()
 
+  def on_train_batch_begin(self, batch, logs=None):
+    if (not self._is_profiling and
+        self._total_batches_seen == self._profile_batch - 1):
+      profiler.start(self.log_dir)
+      self._is_profiling = True
+
+  def on_train_batch_end(self, batch, logs=None):
+    return self.on_batch_end(batch, logs)
+
+  def on_test_begin(self, logs=None):
+    pass
+
+  def on_test_end(self, logs=None):
+    pass
+
   def on_batch_end(self, batch, logs=None):
     """Writes scalar summaries for metrics on every training batch.
@@ -358,18 +374,13 @@ class TensorBoard(callbacks.Callback):
       self._write_custom_summaries(self._total_batches_seen, batch_logs)
       self._samples_seen_at_last_write = self._samples_seen
     self._total_batches_seen += 1
+
     if self._is_profiling:
       profiler.stop()
       self._is_profiling = False
-    elif (not self._is_profiling and
-          self._total_batches_seen == self._profile_batch - 1):
-      profiler.start(self.log_dir)
-      self._is_profiling = True
 
   def on_train_begin(self, logs=None):
-    if self._profile_batch == 1:
-      profiler.start(self.log_dir)
-      self._is_profiling = True
+    pass
 
   def on_epoch_begin(self, epoch, logs=None):
     """Add histogram op to Model eval_function callbacks, reset batch count."""
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 7dcf10a506c..21361f680da 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
+import itertools
 
 from tensorflow.python.distribute import distribute_coordinator as dc
 from tensorflow.python.distribute import distribute_coordinator_context as dc_context
@@ -28,6 +29,7 @@ from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import monitoring
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import callbacks as callbacks_module
 from tensorflow.python.keras import optimizers
@@ -43,6 +45,8 @@ from tensorflow.python.keras.utils import version_utils
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import summary_ops_v2
+from tensorflow.python.ops import variables
 from tensorflow.python.ops.ragged import ragged_concat_ops
 from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.profiler import trace
@@ -161,6 +165,9 @@ class Model(network.Network, version_utils.ModelVersionSelector):
   Checkout [guide](https://www.tensorflow.org/guide/keras/overview) for
   additional details.
   """
+  _TF_MODULE_IGNORED_PROPERTIES = frozenset(
+      itertools.chain(('_train_counter', '_test_counter', '_predict_counter'),
+                      network.Network._TF_MODULE_IGNORED_PROPERTIES))  # pylint: disable=protected-access
 
   def __init__(self, *args, **kwargs):
     super(Model, self).__init__(*args, **kwargs)
@@ -186,6 +193,18 @@ class Model(network.Network, version_utils.ModelVersionSelector):
     self.compiled_loss = None
     self.compiled_metrics = None
 
+    self._init_batch_counters()
+
+  @trackable.no_automatic_dependency_tracking
+  def _init_batch_counters(self):
+    # Untracked Variables, used to keep track of mini-batches seen in `fit`,
+    # `evaluate`, and `predict`.
+    agg = variables.VariableAggregationV2.ONLY_FIRST_REPLICA
+    self._train_counter = variables.Variable(0, dtype='int64', aggregation=agg)
+    self._test_counter = variables.Variable(0, dtype='int64', aggregation=agg)
+    self._predict_counter = variables.Variable(
+        0, dtype='int64', aggregation=agg)
+
   def get_weights(self):
     """Retrieves the weights of the model.
@@ -499,11 +518,18 @@ class Model(network.Network, version_utils.ModelVersionSelector):
       return self.train_function
 
     def train_function(iterator):
+      """Runs one call to `self.train_function`."""
+
+      def run_step(data):
+        outputs = self.train_step(data)
+        self._train_counter.assign_add(1)
+        return outputs
+
       data = next(iterator)
-      outputs = self.distribute_strategy.run(
-          self.train_step, args=(data,))
+      outputs = self.distribute_strategy.run(run_step, args=(data,))
       outputs = reduce_per_replica(
           outputs, self.distribute_strategy, reduction='first')
+      write_scalar_summaries(outputs, step=self._train_counter)
       return outputs
 
     if not self.run_eagerly:
@@ -762,6 +788,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
       self.stop_training = False
       train_function = self.make_train_function()
+      self._train_counter.assign(0)
       callbacks.on_train_begin()
       # Handle fault-tolerance for multi-worker.
       # TODO(omalleyt): Fix the ordering issues that mean this has to
@@ -872,9 +899,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
       return self.test_function
 
     def test_function(iterator):
+      """Runs one call to `self.test_function`."""
+
+      def run_step(data):
+        outputs = self.test_step(data)
+        self._test_counter.assign_add(1)
+        return outputs
+
       data = next(iterator)
-      outputs = self.distribute_strategy.run(
-          self.test_step, args=(data,))
+      outputs = self.distribute_strategy.run(run_step, args=(data,))
       outputs = reduce_per_replica(
           outputs, self.distribute_strategy, reduction='first')
       return outputs
@@ -1003,6 +1036,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
         steps=data_handler.inferred_steps)
 
     test_function = self.make_test_function()
+    self._test_counter.assign(0)
     callbacks.on_test_begin()
     for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
       self.reset_metrics()
@@ -1075,9 +1109,15 @@ class Model(network.Network, version_utils.ModelVersionSelector):
       return self.predict_function
 
     def predict_function(iterator):
+      """Runs one call to `self.predict_function`."""
+
+      def run_step(data):
+        outputs = self.predict_step(data)
+        self._predict_counter.assign_add(1)
+        return outputs
+
       data = next(iterator)
-      outputs = self.distribute_strategy.run(
-          self.predict_step, args=(data,))
+      outputs = self.distribute_strategy.run(run_step, args=(data,))
       outputs = reduce_per_replica(
           outputs, self.distribute_strategy, reduction='concat')
       return outputs
@@ -1192,6 +1232,7 @@ class Model(network.Network, version_utils.ModelVersionSelector):
         steps=data_handler.inferred_steps)
 
     predict_function = self.make_predict_function()
+    self._predict_counter.assign(0)
     callbacks.on_predict_begin()
     for _, iterator in data_handler.enumerate_epochs():  # Single epoch.
       with data_handler.catch_stop_iteration():
@@ -1734,3 +1775,13 @@ def _minimize(tape, optimizer, loss, trainable_variables):
         all_reduce_sum_gradients=False)
   else:
     optimizer.apply_gradients(zip(gradients, trainable_variables))
+
+
+def _is_scalar(x):
+  return isinstance(x, (ops.Tensor, variables.Variable)) and x.shape.rank == 0
+
+
+def write_scalar_summaries(logs, step):
+  for name, value in logs.items():
+    if _is_scalar(value):
+      summary_ops_v2.scalar('batch_' + name, value, step=step)
diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py
index 1c0fea91337..710f9bf3497 100644
--- a/tensorflow/python/keras/engine/training_v1.py
+++ b/tensorflow/python/keras/engine/training_v1.py
@@ -162,6 +162,9 @@ class Model(training_lib.Model):
 
     self._v1_compile_was_called = False
 
+  def _init_batch_counters(self):
+    pass  # Batch counters should not be created in legacy graph mode.
+
   @trackable.no_automatic_dependency_tracking
   def _set_strategy(self, strategy):
     self._compile_time_distribution_strategy = strategy
diff --git a/tensorflow/python/keras/tests/model_subclassing_test.py b/tensorflow/python/keras/tests/model_subclassing_test.py
index 761f720cea5..5af1148f4f0 100644
--- a/tensorflow/python/keras/tests/model_subclassing_test.py
+++ b/tensorflow/python/keras/tests/model_subclassing_test.py
@@ -737,6 +737,21 @@ class CustomCallSignatureTests(test.TestCase, parameterized.TestCase):
     self.assertLen(new_model.variables, 1)
     self.assertLen(new_model.layers, 1)
 
+  def test_batch_counters_not_in_variables(self):
+
+    class MyModel(keras.Model):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.layer = keras.layers.Dense(4)
+
+      def call(self, obs):
+        return self.layer(obs)
+
+    model = MyModel()
+    model(np.ones((10, 10)))
+    self.assertLen(model.variables, 2)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/utils/version_utils.py b/tensorflow/python/keras/utils/version_utils.py
index cf485e1080d..377f370430c 100644
--- a/tensorflow/python/keras/utils/version_utils.py
+++ b/tensorflow/python/keras/utils/version_utils.py
@@ -36,6 +36,13 @@ base_layer = lazy_loader.LazyLoader(
 base_layer_v1 = lazy_loader.LazyLoader(
     "base_layer_v1", globals(),
     "tensorflow.python.keras.engine.base_layer_v1")
+callbacks = lazy_loader.LazyLoader(
+    "callbacks", globals(),
+    "tensorflow.python.keras.callbacks")
+callbacks_v1 = lazy_loader.LazyLoader(
+    "callbacks_v1", globals(),
+    "tensorflow.python.keras.callbacks_v1")
+
 # pylint: enable=g-inconsistent-quotes
 
 
@@ -58,6 +65,21 @@ class LayerVersionSelector(object):
     return super(LayerVersionSelector, cls).__new__(cls)
 
 
+class TensorBoardVersionSelector(object):
+  """Chooses between Keras v1 and v2 TensorBoard callback class."""
+
+  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
+    eager_enabled = ops.executing_eagerly_outside_functions()
+    start_cls = cls
+    cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
+                     eager_enabled)
+    if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
+      # Since the v2 class is not a subclass of the v1 class, __init__ has to
+      # be called manually.
+      return cls(*args, **kwargs)
+    return super(TensorBoardVersionSelector, cls).__new__(cls)
+
+
 def swap_class(cls, v2_cls, v1_cls, eager_enabled):
   """Swaps in v2_cls or v1_cls depending on graph mode."""
   if cls == object:
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt
index 4504633d4a1..2e0c6c97826 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-tensor-board.pbtxt
@@ -1,7 +1,9 @@
 path: "tensorflow.keras.callbacks.TensorBoard"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.callbacks_v1.TensorBoard\'>"
+  is_instance: "<class \'tensorflow.python.keras.callbacks.TensorBoard\'>"
   is_instance: "<class \'tensorflow.python.keras.callbacks.Callback\'>"
+  is_instance: "<class \'tensorflow.python.keras.utils.version_utils.TensorBoardVersionSelector\'>"
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
index 24385e2722a..51d6901e936 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-tensor-board.pbtxt
@@ -2,6 +2,7 @@ path: "tensorflow.keras.callbacks.TensorBoard"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.callbacks.TensorBoard\'>"
   is_instance: "<class \'tensorflow.python.keras.callbacks.Callback\'>"
+  is_instance: "<class \'tensorflow.python.keras.utils.version_utils.TensorBoardVersionSelector\'>"
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
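
Note on the new counters: Model now owns three untracked int64 Variables, _train_counter, _test_counter, and _predict_counter, reset to zero at the top of fit/evaluate/predict and incremented once per executed mini-batch inside the respective tf.function. A minimal sketch of the observable behavior follows; the counters are private, underscore-prefixed attributes, so treat this as illustration rather than supported API:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

x = np.ones((64, 4), dtype='float32')
y = np.ones((64, 1), dtype='float32')

# Reset to 0 at the start of fit(), then incremented once per mini-batch:
# 64 samples / batch_size 8 = 8 batches per epoch, times 2 epochs = 16.
model.fit(x, y, epochs=2, batch_size=8, verbose=0)
print(int(model._train_counter))  # -> 16

model.evaluate(x, y, batch_size=32, verbose=0)
print(int(model._test_counter))   # -> 2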
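
On the profile_batch change: the argument now accepts a single non-negative integer, a 2-tuple/list of positive integers, or the legacy "start,stop" string, and profile_batch=0 disables profiling (and, via _implements_train_batch_hooks, the callback's batch-level hooks entirely). Per the self._epoch == 0 check in on_train_batch_begin, tracing only starts during the first epoch. A usage sketch with a hypothetical log directory and toy data:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

# Profile batches 2 through 5 of the first epoch; the string '2,5' would
# also be accepted for backwards compatibility.
tb = tf.keras.callbacks.TensorBoard(log_dir='/tmp/tb_logs',
                                    profile_batch=(2, 5))

model.fit(np.ones((64, 4), dtype='float32'),
          np.ones((64, 1), dtype='float32'),
          epochs=2, batch_size=8, callbacks=[tb])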
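
The _push_writer/_pop_writer pair replaces the old per-batch bookkeeping: it saves the current summary state, installs one of the callback's writers as the default, points the global summary step at the model's counter Variable, and restores everything in on_train_end/on_test_end. Below is a rough analogue built only from public tf.summary APIs, not the callback's actual code; the log directory, metric value, and loop are made up for illustration:

import tensorflow as tf

writer = tf.summary.create_file_writer('/tmp/tb_logs/train')
step = tf.Variable(0, dtype=tf.int64)  # plays the role of Model._train_counter
update_freq = 10  # record every 10th batch, as with update_freq=10

# Install a default step and writer, plus a recording condition keyed off
# the step Variable, so summaries written downstream land in the right file.
tf.summary.experimental.set_step(step)
with writer.as_default():
  with tf.summary.record_if(lambda: tf.equal(step % update_freq, 0)):
    for _ in range(100):
      tf.summary.scalar('batch_loss', 0.25)  # picks up the default step
      step.assign_add(1)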