diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 22337e8843e..5cb6bc753cc 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1140,7 +1140,7 @@ class TensorBoard(Callback): your training. profile_batch: Profile the batch to sample compute characteristics. By default, it will profile the second batch. Set profile_batch=0 to - disable profiling. + disable profiling. Must run in TensorFlow eager mode. Raises: ValueError: If histogram_freq is set and no validation data is provided. @@ -1274,16 +1274,10 @@ class TensorBoard(Callback): self._samples_seen_at_last_write = self._samples_seen self._total_batches_seen += 1 if self._is_tracing: - # TODO(b/126388999): Remove step info in the summary name. - summary_ops_v2.trace_export( - name='batch_%d' % self._total_batches_seen, - step=self._total_batches_seen, - profiler_outdir=self.log_dir) - self._is_tracing = False + self._log_trace() elif (not self._is_tracing and self._total_batches_seen == self._profile_batch - 1): - summary_ops_v2.trace_on(graph=True, profiler=True) - self._is_tracing = True + self._enable_trace() def on_epoch_end(self, epoch, logs=None): """Runs metrics and histogram summaries at epoch end.""" @@ -1294,13 +1288,24 @@ class TensorBoard(Callback): self._log_weights(epoch) def on_train_end(self, logs=None): - self._close_writers() if self._is_tracing: - # TODO(b/126388999): Remove step info in the summary name. - summary_ops_v2.trace_export( - name='batch_%d' % self._total_batches_seen, - step=self._total_batches_seen, - profiler_outdir=self.log_dir) + self._log_trace() + self._close_writers() + + def _enable_trace(self): + if context.executing_eagerly(): + summary_ops_v2.trace_on(graph=True, profiler=True) + self._is_tracing = True + + def _log_trace(self): + if context.executing_eagerly(): + with self._train_writer.as_default(), \ + summary_ops_v2.always_record_summaries(): + # TODO(b/126388999): Remove step info in the summary name. + summary_ops_v2.trace_export( + name='batch_%d' % self._total_batches_seen, + step=self._total_batches_seen, + profiler_outdir=self.log_dir) self._is_tracing = False def _log_metrics(self, logs, prefix, step): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 42713451811..b475c29c652 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -1236,17 +1236,28 @@ class TestTensorBoardV2(keras_parameterized.TestCase): # Note that this test specifies model_type explicitly. -class TestTensorBoardV2WriteModelTest(test.TestCase): +class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): def setUp(self): - super(TestTensorBoardV2WriteModelTest, self).setUp() + super(TestTensorBoardV2NonParameterizedTest, self).setUp() self.logdir = os.path.join(self.get_temp_dir(), 'tb') self.train_dir = os.path.join(self.logdir, 'train') self.validation_dir = os.path.join(self.logdir, 'validation') + def _get_seq_model(self): + model = keras.models.Sequential([ + keras.layers.Conv2D(8, (3, 3), input_shape=(10, 10, 1)), + keras.layers.Flatten(), + keras.layers.Dense(1), + ]) + model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly()) + return model + def fitModelAndAssertKerasModelWritten(self, model): x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) - tb_cbk = keras.callbacks.TensorBoard(self.logdir, write_graph=True) + tb_cbk = keras.callbacks.TensorBoard(self.logdir, + write_graph=True, + profile_batch=0) model.fit( x, y, @@ -1289,6 +1300,74 @@ class TestTensorBoardV2WriteModelTest(test.TestCase): model.compile('sgd', 'mse', run_eagerly=False) self.fitModelAndAssertKerasModelWritten(model) + # TODO(b/126944683): Put parameterization in the class when graph is fixed. + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_TensorBoard_autoTrace(self): + model = self._get_seq_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch=1, write_graph=False) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk]) + summary_file = list_summaries(self.logdir) + + self.assertEqual( + summary_file.tensors, + { + _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'), + }, + ) + + # TODO(b/126944683): Put parameterization in the class when graph is fixed. + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_TensorBoard_autoTrace_tagNameWithBatchNum(self): + model = self._get_seq_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch=2, write_graph=False) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk]) + summary_file = list_summaries(self.logdir) + + self.assertEqual( + summary_file.tensors, + { + _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), + }, + ) + + # TODO(b/126944683): Put parameterization in the class when graph is fixed. + @keras_parameterized.run_all_keras_modes(always_skip_v1=True) + def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self): + model = self._get_seq_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch=10000, write_graph=False) + + model.fit( + x, + y, + batch_size=2, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk]) + summary_file = list_summaries(self.logdir) + + # Enabled trace only on the 10000th batch, thus it should be empty. + self.assertEmpty(summary_file.tensors) + if __name__ == '__main__': test.main()