diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 7c89061ab2c..19c1b646103 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -141,6 +141,7 @@ py_library( "//tensorflow/python/keras/distribute:multi_worker_training_state", "//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:mode_keys", + "//tensorflow/python/profiler:profiler_v2", "//tensorflow/tools/docs:doc_controls", ], ) @@ -153,8 +154,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":backend", - "//tensorflow/python/eager:profiler", "//tensorflow/python/keras/utils:engine_utils", + "//tensorflow/python/profiler:profiler_v2", ], ) diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 5fae5eb9218..8651cf27375 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -48,6 +48,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.training import checkpoint_management from tensorflow.python.util import nest from tensorflow.python.util.compat import collections_abc @@ -1575,11 +1576,25 @@ class TensorBoard(Callback): You can find more information about TensorBoard [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). - Example: + Example (Basic): ```python tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) - #run the tensorboard command to view the visualizations + # run the tensorboard command to view the visualizations. + ``` + Example (Profile): + ```python + # profile a single batch, e.g. the 5th batch. 
+ tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir='./logs', profile_batch=5) + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + # run the tensorboard command to view the visualizations in profile plugin. + + # profile a range of batches, e.g. from 10 to 20. + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir='./logs', profile_batch='10,20') + model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) + # run the tensorboard command to view the visualizations in profile plugin. ``` Arguments: @@ -1599,11 +1614,14 @@ class TensorBoard(Callback): callback will write the metrics and losses to TensorBoard every 1000 batches. Note that writing too frequently to TensorBoard can slow down your training. - profile_batch: Profile the batch to sample compute characteristics. By - default, it will profile the second batch. Set profile_batch=0 to - disable profiling. Must run in TensorFlow eager mode. - embeddings_freq: frequency (in epochs) at which embedding layers will - be visualized. If set to 0, embeddings won't be visualized. + profile_batch: Profile the batch(es) to sample compute characteristics. + profile_batch must be a non-negative integer or a comma separated string + of pair of positive integers. A pair of positive integers signify a + range of batches to profile. By default, it will profile the second + batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow + eager mode. + embeddings_freq: frequency (in epochs) at which embedding layers will be + visualized. If set to 0, embeddings won't be visualized. embeddings_metadata: a dictionary which maps layer name to a file name in which metadata for this embedding layer is saved. 
See the [details]( @@ -1652,8 +1670,8 @@ class TensorBoard(Callback): self._train_run_name = 'train' self._validation_run_name = 'validation' self._writers = {} - - self._profile_batch = profile_batch + self._start_batch, self._stop_batch = self._init_profile_batch( + profile_batch) # True when a trace is running. self._is_tracing = False @@ -1827,10 +1845,49 @@ class TensorBoard(Callback): else: self._total_batches_seen[writer_name] += 1 + def _init_profile_batch(self, profile_batch): + """Validate profile_batch value and set the range of batches to profile. + + Arguments: + profile_batch: The range of batches to profile. Should be a non-negative + integer or a comma separated string of pair of positive integers. A pair + of positive integers signify a range of batches to profile. + + Returns: + A pair of non-negative integers specifying the start and stop batch to + profile. + + Raises: + ValueError: If profile_batch is not an integer or a comma separated pair + of positive integers. + + """ + profile_batch_error_message = ( + 'profile_batch must be a non-negative integer or a comma separated ' + 'string of pair of positive integers. A pair of positive integers ' + 'signify a range of batches to profile.') + try: + profile_range = [int(i) for i in str(profile_batch).split(',')] + except ValueError: + raise ValueError(profile_batch_error_message) + if len(profile_range) == 1: # single batch + start_batch, stop_batch = profile_range[0], profile_range[0] + if start_batch < 0: + raise ValueError(profile_batch_error_message) + elif len(profile_range) == 2: # (start_batch, stop_batch) + start_batch, stop_batch = profile_range + # [0, 0], [-1, 100], [6, 5] are illegal. 
+ if start_batch <= 0 or start_batch > stop_batch: + raise ValueError(profile_batch_error_message) + else: + raise ValueError(profile_batch_error_message) + return start_batch, stop_batch + def on_train_begin(self, logs=None): self._init_batch_steps() - if self._profile_batch == 1: - summary_ops_v2.trace_on(graph=True, profiler=True) + if self._start_batch == 1: + summary_ops_v2.trace_on(graph=True, profiler=False) + profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) self._is_tracing = True def on_test_begin(self, logs=None): @@ -1845,7 +1902,7 @@ class TensorBoard(Callback): batch: Integer, index of batch within the current epoch. logs: Dict. Metric results for this batch. """ - if self.update_freq == 'epoch' and self._profile_batch is None: + if self.update_freq == 'epoch' and self._start_batch is None: return # Don't output batch_size and batch number as TensorBoard summaries @@ -1857,10 +1914,11 @@ class TensorBoard(Callback): self._increment_step(self._train_run_name) if context.executing_eagerly(): - if self._is_tracing: + if self._is_tracing and math_ops.greater_equal(train_batches, + self._stop_batch): self._log_trace() elif (not self._is_tracing and - math_ops.equal(train_batches, self._profile_batch - 1)): + math_ops.equal(train_batches, self._start_batch - 1)): self._enable_trace() def on_test_batch_end(self, batch, logs=None): @@ -1899,7 +1957,8 @@ class TensorBoard(Callback): def _enable_trace(self): if context.executing_eagerly(): - summary_ops_v2.trace_on(graph=True, profiler=True) + summary_ops_v2.trace_on(graph=True, profiler=False) + profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) self._is_tracing = True def _log_trace(self): @@ -1909,10 +1968,8 @@ class TensorBoard(Callback): summary_ops_v2.always_record_summaries(): # TODO(b/126388999): Remove step info in the summary name. 
step = K.get_value(self._total_batches_seen[self._train_run_name]) - summary_ops_v2.trace_export( - name='batch_%d' % step, - step=step, - profiler_outdir=os.path.join(self._log_write_dir, 'train')) + summary_ops_v2.trace_export(name='batch_%d' % step, step=step) + profiler.stop() self._is_tracing = False def _log_metrics(self, logs, prefix, step): diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index bf6d8cda6f2..34f0138560c 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -1805,6 +1805,15 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): experimental_run_tf_function=testing_utils.should_run_tf_function()) return model + def _get_trace_file(self, logdir): + profile_dir = os.path.join(logdir, 'plugins', 'profile') + for (dirpath, dirnames, filenames) in os.walk(profile_dir): + del dirnames # unused + for filename in filenames: + if filename.endswith('.trace'): + return os.path.join(dirpath, filename) + return None + def fitModelAndAssertKerasModelWritten(self, model): x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) tb_cbk = keras.callbacks.TensorBoard(self.logdir, @@ -1873,6 +1882,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'), }, ) + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) def test_TensorBoard_autoTrace_tagNameWithBatchNum(self): model = self._get_seq_model() @@ -1895,6 +1905,78 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), }, ) + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + + def test_TensorBoard_autoTrace_profileBatchRangeSingle(self): + model = self._get_seq_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, 
profile_batch='2,2', write_graph=False) + + model.fit( + x, + y, + batch_size=3, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk]) + summary_file = list_summaries(self.logdir) + + self.assertEqual( + summary_file.tensors, + { + # Trace will be logged once at the batch it stops profiling. + _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), + }, + ) + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + + def test_TensorBoard_autoTrace_profileBatchRange(self): + model = self._get_seq_model() + x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) + tb_cbk = keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch='1,3', write_graph=False) + + model.fit( + x, + y, + batch_size=4, + epochs=2, + validation_data=(x, y), + callbacks=[tb_cbk]) + summary_file = list_summaries(self.logdir) + + self.assertEqual( + summary_file.tensors, + { + # Trace will be logged once at the batch it stops profiling. + _ObservedSummary(logdir=self.train_dir, tag=u'batch_3'), + }, + ) + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + + def test_TensorBoard_autoTrace_profileInvalidBatchRange(self): + with self.assertRaises(ValueError): + keras.callbacks.TensorBoard( + self.logdir, + histogram_freq=1, + profile_batch='-1,3', + write_graph=False) + + with self.assertRaises(ValueError): + keras.callbacks.TensorBoard( + self.logdir, + histogram_freq=1, + profile_batch='1,None', + write_graph=False) + + with self.assertRaises(ValueError): + keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch='6,5', write_graph=False) + + with self.assertRaises(ValueError): + keras.callbacks.TensorBoard( + self.logdir, histogram_freq=1, profile_batch=-1, write_graph=False) def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self): model = self._get_seq_model() @@ -1913,6 +1995,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): # Enabled trace only on the 10000th batch, thus it should 
be empty. self.assertEmpty(summary_file.tensors) + self.assertIsNone(self._get_trace_file(logdir=self.train_dir)) class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase): diff --git a/tensorflow/python/keras/callbacks_v1.py b/tensorflow/python/keras/callbacks_v1.py index db0d2b9f4b5..524e039f597 100644 --- a/tensorflow/python/keras/callbacks_v1.py +++ b/tensorflow/python/keras/callbacks_v1.py @@ -24,7 +24,6 @@ import os import numpy as np from tensorflow.python.eager import context -from tensorflow.python.eager import profiler from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks @@ -33,6 +32,7 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.summary import summary as tf_summary from tensorflow.python.training import saver from tensorflow.python.util.tf_export import keras_export @@ -359,16 +359,16 @@ class TensorBoard(callbacks.Callback): self._samples_seen_at_last_write = self._samples_seen self._total_batches_seen += 1 if self._is_profiling: - profiler.save(self.log_dir, profiler.stop()) + profiler.stop() self._is_profiling = False elif (not self._is_profiling and self._total_batches_seen == self._profile_batch - 1): - profiler.start() + profiler.start(self.log_dir) self._is_profiling = True def on_train_begin(self, logs=None): if self._profile_batch == 1: - profiler.start() + profiler.start(self.log_dir) self._is_profiling = True def on_epoch_begin(self, epoch, logs=None): @@ -452,6 +452,6 @@ class TensorBoard(callbacks.Callback): def on_train_end(self, logs=None): if self._is_profiling: - profiler.save(self.log_dir, profiler.stop()) + profiler.stop() self._is_profiling = False self.writer.close()