Move summary ops and keras tensorboard callback to use the new API.
PiperOrigin-RevId: 296541416 Change-Id: Ie6c32d0e24c08ad73b2d647b8c7fb65a8869e801
This commit is contained in:
parent
4ce3d12407
commit
88e2109ac8
@ -141,6 +141,7 @@ py_library(
|
||||
"//tensorflow/python/keras/distribute:multi_worker_training_state",
|
||||
"//tensorflow/python/keras/utils:engine_utils",
|
||||
"//tensorflow/python/keras/utils:mode_keys",
|
||||
"//tensorflow/python/profiler:profiler_v2",
|
||||
"//tensorflow/tools/docs:doc_controls",
|
||||
],
|
||||
)
|
||||
@ -153,8 +154,8 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":backend",
|
||||
"//tensorflow/python/eager:profiler",
|
||||
"//tensorflow/python/keras/utils:engine_utils",
|
||||
"//tensorflow/python/profiler:profiler_v2",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -48,6 +48,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
@ -1575,11 +1576,25 @@ class TensorBoard(Callback):
|
||||
You can find more information about TensorBoard
|
||||
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
|
||||
|
||||
Example:
|
||||
Example (Basic):
|
||||
```python
|
||||
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
#run the tensorboard command to view the visualizations
|
||||
# run the tensorboard command to view the visualizations.
|
||||
```
|
||||
Example (Profile):
|
||||
```python
|
||||
# profile a single batch, e.g. the 5th batch.
|
||||
tensorboard_callback =
|
||||
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch=5)
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
# run the tensorboard command to view the visualizations in profile plugin.
|
||||
|
||||
# profile a range of batches, e.g. from 10 to 20.
|
||||
tensorboard_callback =
|
||||
tf.keras.callbacks.TensorBoard(log_dir='./logs', profile_batch='10,20')
|
||||
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
|
||||
# run the tensorboard command to view the visualizations in profile plugin.
|
||||
```
|
||||
|
||||
Arguments:
|
||||
@ -1599,11 +1614,14 @@ class TensorBoard(Callback):
|
||||
callback will write the metrics and losses to TensorBoard every 1000
|
||||
batches. Note that writing too frequently to TensorBoard can slow down
|
||||
your training.
|
||||
profile_batch: Profile the batch to sample compute characteristics. By
|
||||
default, it will profile the second batch. Set profile_batch=0 to
|
||||
disable profiling. Must run in TensorFlow eager mode.
|
||||
embeddings_freq: frequency (in epochs) at which embedding layers will
|
||||
be visualized. If set to 0, embeddings won't be visualized.
|
||||
profile_batch: Profile the batch(es) to sample compute characteristics.
|
||||
profile_batch must be a non-negative integer or a comma separated string
|
||||
of a pair of positive integers. A pair of positive integers signifies a
|
||||
range of batches to profile. By default, it will profile the second
|
||||
batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow
|
||||
eager mode.
|
||||
embeddings_freq: frequency (in epochs) at which embedding layers will be
|
||||
visualized. If set to 0, embeddings won't be visualized.
|
||||
embeddings_metadata: a dictionary which maps layer name to a file name in
|
||||
which metadata for this embedding layer is saved. See the
|
||||
[details](
|
||||
@ -1652,8 +1670,8 @@ class TensorBoard(Callback):
|
||||
self._train_run_name = 'train'
|
||||
self._validation_run_name = 'validation'
|
||||
self._writers = {}
|
||||
|
||||
self._profile_batch = profile_batch
|
||||
self._start_batch, self._stop_batch = self._init_profile_batch(
|
||||
profile_batch)
|
||||
# True when a trace is running.
|
||||
self._is_tracing = False
|
||||
|
||||
@ -1827,10 +1845,49 @@ class TensorBoard(Callback):
|
||||
else:
|
||||
self._total_batches_seen[writer_name] += 1
|
||||
|
||||
def _init_profile_batch(self, profile_batch):
  """Validates `profile_batch` and resolves it to a start/stop batch pair.

  Arguments:
    profile_batch: Either a single non-negative integer naming the one batch
      to profile, or a comma-separated string such as `'10,20'` holding a
      pair of positive integers `start,stop` with `start <= stop`.

  Returns:
    A `(start_batch, stop_batch)` tuple of integers. When a single value is
    given, both entries are that value (so `0` yields `(0, 0)`).

  Raises:
    ValueError: If `profile_batch` cannot be parsed as one integer or as a
      comma-separated pair of integers, if a single value is negative, or if
      a pair has a non-positive start or a start greater than its stop.
  """
  profile_batch_error_message = (
      'profile_batch must be a non-negative integer or a comma separated '
      'string of pair of positive integers. A pair of positive integers '
      'signify a range of batches to profile.')

  # Going through str() lets plain integers and 'start,stop' strings share a
  # single parsing path.
  try:
    profile_range = [int(i) for i in str(profile_batch).split(',')]
  except ValueError:
    raise ValueError(profile_batch_error_message)

  if len(profile_range) == 1:  # Single batch.
    start_batch = stop_batch = profile_range[0]
    is_valid = start_batch >= 0
  elif len(profile_range) == 2:  # (start_batch, stop_batch)
    start_batch, stop_batch = profile_range
    # [0, 0], [-1, 100], [6, 5] are illegal.
    is_valid = 0 < start_batch <= stop_batch
  else:
    is_valid = False

  # One exit point for every rejected shape or value.
  if not is_valid:
    raise ValueError(profile_batch_error_message)
  return start_batch, stop_batch
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
self._init_batch_steps()
|
||||
if self._profile_batch == 1:
|
||||
summary_ops_v2.trace_on(graph=True, profiler=True)
|
||||
if self._start_batch == 1:
|
||||
summary_ops_v2.trace_on(graph=True, profiler=False)
|
||||
profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
|
||||
self._is_tracing = True
|
||||
|
||||
def on_test_begin(self, logs=None):
|
||||
@ -1845,7 +1902,7 @@ class TensorBoard(Callback):
|
||||
batch: Integer, index of batch within the current epoch.
|
||||
logs: Dict. Metric results for this batch.
|
||||
"""
|
||||
if self.update_freq == 'epoch' and self._profile_batch is None:
|
||||
if self.update_freq == 'epoch' and self._start_batch is None:
|
||||
return
|
||||
|
||||
# Don't output batch_size and batch number as TensorBoard summaries
|
||||
@ -1857,10 +1914,11 @@ class TensorBoard(Callback):
|
||||
self._increment_step(self._train_run_name)
|
||||
|
||||
if context.executing_eagerly():
|
||||
if self._is_tracing:
|
||||
if self._is_tracing and math_ops.greater_equal(train_batches,
|
||||
self._stop_batch):
|
||||
self._log_trace()
|
||||
elif (not self._is_tracing and
|
||||
math_ops.equal(train_batches, self._profile_batch - 1)):
|
||||
math_ops.equal(train_batches, self._start_batch - 1)):
|
||||
self._enable_trace()
|
||||
|
||||
def on_test_batch_end(self, batch, logs=None):
|
||||
@ -1899,7 +1957,8 @@ class TensorBoard(Callback):
|
||||
|
||||
def _enable_trace(self):
|
||||
if context.executing_eagerly():
|
||||
summary_ops_v2.trace_on(graph=True, profiler=True)
|
||||
summary_ops_v2.trace_on(graph=True, profiler=False)
|
||||
profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
|
||||
self._is_tracing = True
|
||||
|
||||
def _log_trace(self):
|
||||
@ -1909,10 +1968,8 @@ class TensorBoard(Callback):
|
||||
summary_ops_v2.always_record_summaries():
|
||||
# TODO(b/126388999): Remove step info in the summary name.
|
||||
step = K.get_value(self._total_batches_seen[self._train_run_name])
|
||||
summary_ops_v2.trace_export(
|
||||
name='batch_%d' % step,
|
||||
step=step,
|
||||
profiler_outdir=os.path.join(self._log_write_dir, 'train'))
|
||||
summary_ops_v2.trace_export(name='batch_%d' % step, step=step)
|
||||
profiler.stop()
|
||||
self._is_tracing = False
|
||||
|
||||
def _log_metrics(self, logs, prefix, step):
|
||||
|
@ -1805,6 +1805,15 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
experimental_run_tf_function=testing_utils.should_run_tf_function())
|
||||
return model
|
||||
|
||||
def _get_trace_file(self, logdir):
  """Finds a profiler trace file written beneath `logdir`.

  Arguments:
    logdir: Directory whose `plugins/profile` subtree is searched
      recursively.

  Returns:
    The path of the first file encountered whose name ends with `.trace`,
    or `None` when the subtree is missing or holds no such file.
  """
  search_root = os.path.join(logdir, 'plugins', 'profile')
  for current_dir, _, names in os.walk(search_root):
    matches = [name for name in names if name.endswith('.trace')]
    if matches:
      return os.path.join(current_dir, matches[0])
  return None
|
||||
|
||||
def fitModelAndAssertKerasModelWritten(self, model):
|
||||
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
|
||||
tb_cbk = keras.callbacks.TensorBoard(self.logdir,
|
||||
@ -1873,6 +1882,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
|
||||
},
|
||||
)
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
||||
|
||||
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
|
||||
model = self._get_seq_model()
|
||||
@ -1895,6 +1905,78 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
|
||||
},
|
||||
)
|
||||
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
|
||||
|
||||
def test_TensorBoard_autoTrace_profileBatchRangeSingle(self):
  """A degenerate range ('2,2') yields one trace summary tagged batch_2."""
  model = self._get_seq_model()
  features = np.ones((10, 10, 10, 1))
  labels = np.ones((10, 1))
  tensorboard = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, profile_batch='2,2', write_graph=False)

  fit_kwargs = dict(
      batch_size=3,
      epochs=2,
      validation_data=(features, labels),
      callbacks=[tensorboard])
  model.fit(features, labels, **fit_kwargs)

  observed = list_summaries(self.logdir)
  expected_tensors = {
      # Trace will be logged once at the batch it stops profiling.
      _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
  }
  self.assertEqual(observed.tensors, expected_tensors)

  # The profiler must also have left a trace file in the train directory.
  trace_path = self._get_trace_file(logdir=self.train_dir)
  self.assertIsNotNone(trace_path)
|
||||
|
||||
def test_TensorBoard_autoTrace_profileBatchRange(self):
  """Profiling batches '1,3' yields one trace summary tagged batch_3."""
  model = self._get_seq_model()
  features = np.ones((10, 10, 10, 1))
  labels = np.ones((10, 1))
  tensorboard = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, profile_batch='1,3', write_graph=False)

  fit_kwargs = dict(
      batch_size=4,
      epochs=2,
      validation_data=(features, labels),
      callbacks=[tensorboard])
  model.fit(features, labels, **fit_kwargs)

  observed = list_summaries(self.logdir)
  expected_tensors = {
      # Trace will be logged once at the batch it stops profiling.
      _ObservedSummary(logdir=self.train_dir, tag=u'batch_3'),
  }
  self.assertEqual(observed.tensors, expected_tensors)

  # The profiler must also have left a trace file in the train directory.
  trace_path = self._get_trace_file(logdir=self.train_dir)
  self.assertIsNotNone(trace_path)
|
||||
|
||||
def test_TensorBoard_autoTrace_profileInvalidBatchRange(self):
  """Constructing the callback with a malformed profile_batch must raise."""
  # Negative start, non-integer stop, start greater than stop, and a
  # negative single batch, checked in that order.
  invalid_values = ('-1,3', '1,None', '6,5', -1)
  for invalid_value in invalid_values:
    with self.assertRaises(ValueError):
      keras.callbacks.TensorBoard(
          self.logdir,
          histogram_freq=1,
          profile_batch=invalid_value,
          write_graph=False)
|
||||
|
||||
def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
|
||||
model = self._get_seq_model()
|
||||
@ -1913,6 +1995,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
|
||||
|
||||
# Enabled trace only on the 10000th batch, thus it should be empty.
|
||||
self.assertEmpty(summary_file.tensors)
|
||||
self.assertIsNone(self._get_trace_file(logdir=self.train_dir))
|
||||
|
||||
|
||||
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):
|
||||
|
@ -24,7 +24,6 @@ import os
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import profiler
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import callbacks
|
||||
@ -33,6 +32,7 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import summary_ops_v2
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
from tensorflow.python.summary import summary as tf_summary
|
||||
from tensorflow.python.training import saver
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -359,16 +359,16 @@ class TensorBoard(callbacks.Callback):
|
||||
self._samples_seen_at_last_write = self._samples_seen
|
||||
self._total_batches_seen += 1
|
||||
if self._is_profiling:
|
||||
profiler.save(self.log_dir, profiler.stop())
|
||||
profiler.stop()
|
||||
self._is_profiling = False
|
||||
elif (not self._is_profiling and
|
||||
self._total_batches_seen == self._profile_batch - 1):
|
||||
profiler.start()
|
||||
profiler.start(self.log_dir)
|
||||
self._is_profiling = True
|
||||
|
||||
def on_train_begin(self, logs=None):
|
||||
if self._profile_batch == 1:
|
||||
profiler.start()
|
||||
profiler.start(self.log_dir)
|
||||
self._is_profiling = True
|
||||
|
||||
def on_epoch_begin(self, epoch, logs=None):
|
||||
@ -452,6 +452,6 @@ class TensorBoard(callbacks.Callback):
|
||||
|
||||
def on_train_end(self, logs=None):
|
||||
if self._is_profiling:
|
||||
profiler.save(self.log_dir, profiler.stop())
|
||||
profiler.stop()
|
||||
self._is_profiling = False
|
||||
self.writer.close()
|
||||
|
Loading…
x
Reference in New Issue
Block a user