diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 50852d8aa4b..bb9e61d01a2 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -46,6 +46,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar from tensorflow.python.keras.utils.mode_keys import ModeKeys from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import variables @@ -1948,9 +1949,7 @@ class TensorBoard(Callback): def on_train_begin(self, logs=None): self._init_batch_steps() if self._start_batch == 1: - summary_ops_v2.trace_on(graph=True, profiler=False) - profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) - self._is_tracing = True + self._enable_trace() def on_test_begin(self, logs=None): self._set_default_writer(self._validation_run_name) @@ -1976,14 +1975,14 @@ class TensorBoard(Callback): self._log_metrics(logs, prefix='batch_', step=train_batches) self._increment_step(self._train_run_name) - - if context.executing_eagerly(): - if self._is_tracing and math_ops.greater_equal(train_batches, - self._stop_batch): - self._log_trace() - elif (not self._is_tracing and - math_ops.equal(train_batches, self._start_batch - 1)): - self._enable_trace() + if self._is_tracing: + control_flow_ops.cond( + math_ops.greater_equal(train_batches, self._stop_batch), + lambda: self._log_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda + else: + control_flow_ops.cond( + math_ops.equal(train_batches, self._start_batch - 1), + lambda: self._enable_trace_return_true(), lambda: False) # pylint: disable=unnecessary-lambda def on_test_batch_end(self, batch, logs=None): if self.update_freq == 'epoch': @@ -2020,21 +2019,48 @@ class TensorBoard(Callback): self.log_dir, self.model._get_distribution_strategy()) # pylint: disable=protected-access def _enable_trace(self): + """Starts to collect trace graph to TensorBoard. + + Collects both trace and graph in eager mode, and trace only in graph mode. + """ if context.executing_eagerly(): + # Graph must be traced in eager mode. summary_ops_v2.trace_on(graph=True, profiler=False) - profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) - self._is_tracing = True + profiler.start(logdir=os.path.join(self._log_write_dir, 'train')) + self._is_tracing = True + + def _enable_trace_return_true(self): + """Starts to collect trace graph to TensorBoard and returns True. + + Returns: + True. + """ + self._enable_trace() + return True def _log_trace(self): - """Logs the trace graph to TensorBoard.""" + """Logs the trace graph to TensorBoard. + + Logs both trace and graph in eager mode, and trace only in graph mode. + """ + profiler.stop() if context.executing_eagerly(): + # Graph must be traced in eager mode. with self._get_writer(self._train_run_name).as_default(), \ summary_ops_v2.always_record_summaries(): # TODO(b/126388999): Remove step info in the summary name. step = K.get_value(self._total_batches_seen[self._train_run_name]) summary_ops_v2.trace_export(name='batch_%d' % step, step=step) - profiler.stop() - self._is_tracing = False + self._is_tracing = False + + def _log_trace_return_true(self): + """Logs the trace graph to TensorBoard and returns True. + + Returns: + True. + """ + self._log_trace() + return True def _log_metrics(self, logs, prefix, step): """Writes metrics out as custom scalar summaries. diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py index 98ed5842802..eb62d0b29ee 100644 --- a/tensorflow/python/keras/callbacks_test.py +++ b/tensorflow/python/keras/callbacks_test.py @@ -36,6 +36,7 @@ from tensorflow.core.framework import summary_pb2 from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils @@ -2065,6 +2066,31 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase): ) self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + # Test case that replicates a Github issue. + # https://github.com/tensorflow/tensorflow/issues/37543 + def test_TensorBoard_autoTrace_profileTwiceGraphMode(self): + ops.disable_eager_execution() + inp = keras.Input((1,)) + out = keras.layers.Dense(units=1)(inp) + model = keras.Model(inp, out) + + model.compile(gradient_descent.SGD(1), 'mse') + + model.fit( + np.zeros((64, 1)), + np.zeros((64, 1)), + callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)], + ) + # Verifies trace exists in the first train_dir. + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + model.fit( + np.zeros((64, 1)), + np.zeros((64, 1)), + callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)], + ) + # Verifies trace exists in the second train_dir. + self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir)) + def test_TensorBoard_autoTrace_profileBatchRange(self): model = self._get_seq_model() x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))