Fix "another profiler is running" error when calling model.fit twice in graph mode.

Remove the context.executing_eagerly() check for the profiler so that it also works in graph mode. Use tf.cond when comparing tensors.

Github issue: https://github.com/tensorflow/tensorflow/issues/37543

PiperOrigin-RevId: 300882803
Change-Id: I6ed4f9212984aa896d049fb2726c7868f3d427ea
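For reference, a minimal standalone sketch (not part of the commit; the placeholder, constant, and values below are illustrative) of why a graph-level conditional is needed: once eager execution is disabled, comparisons such as math_ops.greater_equal return symbolic tensors, so a Python `if` on them raises an error, while tf.cond builds both branches into the graph and selects one at run time.

    # Sketch only, assuming the TF 1.x-style graph mode from the bug report.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    step = tf.placeholder(tf.int64, shape=[])
    stop_batch = tf.constant(3, dtype=tf.int64)

    # `if step >= stop_batch:` would raise "Using a tf.Tensor as a Python bool
    # is not allowed" here; tf.cond builds both branches into the graph and
    # selects one when the session runs.
    should_stop = tf.cond(tf.greater_equal(step, stop_batch),
                          lambda: tf.constant(True),
                          lambda: tf.constant(False))

    with tf.Session() as sess:
      print(sess.run(should_stop, feed_dict={step: 5}))  # True
      print(sess.run(should_stop, feed_dict={step: 1}))  # False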
commit 01641aee30 (parent 1646cd9bb9)
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -46,6 +46,7 @@ from tensorflow.python.keras.utils.generic_utils import Progbar
 from tensorflow.python.keras.utils.mode_keys import ModeKeys
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.ops import variables
@@ -1948,9 +1949,7 @@ class TensorBoard(Callback):
   def on_train_begin(self, logs=None):
     self._init_batch_steps()
     if self._start_batch == 1:
-      summary_ops_v2.trace_on(graph=True, profiler=False)
-      profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
-      self._is_tracing = True
+      self._enable_trace()
 
   def on_test_begin(self, logs=None):
     self._set_default_writer(self._validation_run_name)
@@ -1976,14 +1975,14 @@ class TensorBoard(Callback):
       self._log_metrics(logs, prefix='batch_', step=train_batches)
 
     self._increment_step(self._train_run_name)
 
-    if context.executing_eagerly():
-      if self._is_tracing and math_ops.greater_equal(train_batches,
-                                                     self._stop_batch):
-        self._log_trace()
-      elif (not self._is_tracing and
-            math_ops.equal(train_batches, self._start_batch - 1)):
-        self._enable_trace()
+    if self._is_tracing:
+      control_flow_ops.cond(
+          math_ops.greater_equal(train_batches, self._stop_batch),
+          lambda: self._log_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda
+    else:
+      control_flow_ops.cond(
+          math_ops.equal(train_batches, self._start_batch - 1),
+          lambda: self._enable_trace_return_true(), lambda: False)  # pylint: disable=unnecessary-lambda
 
   def on_test_batch_end(self, batch, logs=None):
     if self.update_freq == 'epoch':
@@ -2020,21 +2019,48 @@ class TensorBoard(Callback):
         self.log_dir, self.model._get_distribution_strategy())  # pylint: disable=protected-access
 
   def _enable_trace(self):
+    """Starts to collect trace graph to TensorBoard.
+
+    Collects both trace and graph in eager mode, and trace only in graph mode.
+    """
     if context.executing_eagerly():
+      # Graph must be traced in eager mode.
       summary_ops_v2.trace_on(graph=True, profiler=False)
-      profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
-      self._is_tracing = True
+    profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
+    self._is_tracing = True
 
+  def _enable_trace_return_true(self):
+    """Starts to collect trace graph to TensorBoard and returns True.
+
+    Returns:
+      True.
+    """
+    self._enable_trace()
+    return True
+
   def _log_trace(self):
-    """Logs the trace graph to TensorBoard."""
+    """Logs the trace graph to TensorBoard.
+
+    Logs both trace and graph in eager mode, and trace only in graph mode.
+    """
+    profiler.stop()
     if context.executing_eagerly():
+      # Graph must be traced in eager mode.
       with self._get_writer(self._train_run_name).as_default(), \
           summary_ops_v2.always_record_summaries():
         # TODO(b/126388999): Remove step info in the summary name.
         step = K.get_value(self._total_batches_seen[self._train_run_name])
         summary_ops_v2.trace_export(name='batch_%d' % step, step=step)
-      profiler.stop()
-      self._is_tracing = False
+    self._is_tracing = False
 
+  def _log_trace_return_true(self):
+    """Logs the trace graph to TensorBoard and returns True.
+
+    Returns:
+      True.
+    """
+    self._log_trace()
+    return True
+
   def _log_metrics(self, logs, prefix, step):
     """Writes metrics out as custom scalar summaries.
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -36,6 +36,7 @@ from tensorflow.core.framework import summary_pb2
 from tensorflow.python import keras
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
@@ -2065,6 +2066,31 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
     )
     self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
 
+  # Test case that replicates a Github issue.
+  # https://github.com/tensorflow/tensorflow/issues/37543
+  def test_TensorBoard_autoTrace_profileTwiceGraphMode(self):
+    ops.disable_eager_execution()
+    inp = keras.Input((1,))
+    out = keras.layers.Dense(units=1)(inp)
+    model = keras.Model(inp, out)
+
+    model.compile(gradient_descent.SGD(1), 'mse')
+
+    model.fit(
+        np.zeros((64, 1)),
+        np.zeros((64, 1)),
+        callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=1)],
+    )
+    # Verifies trace exists in the first train_dir.
+    self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
+    model.fit(
+        np.zeros((64, 1)),
+        np.zeros((64, 1)),
+        callbacks=[keras.callbacks.TensorBoard(self.logdir, profile_batch=2)],
+    )
+    # Verifies trace exists in the second train_dir.
+    self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
+
   def test_TensorBoard_autoTrace_profileBatchRange(self):
     model = self._get_seq_model()
     x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))