Move summary ops and Keras TensorBoard callback to use the new API.

PiperOrigin-RevId: 296541416
Change-Id: Ie6c32d0e24c08ad73b2d647b8c7fb65a8869e801
A. Unique TensorFlower 2020-02-21 17:13:26 -08:00 committed by TensorFlower Gardener
parent 4ce3d12407
commit 88e2109ac8
4 changed files with 166 additions and 25 deletions

tensorflow/python/keras/BUILD

@@ -141,6 +141,7 @@ py_library(
"//tensorflow/python/keras/distribute:multi_worker_training_state",
"//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/keras/utils:mode_keys",
"//tensorflow/python/profiler:profiler_v2",
"//tensorflow/tools/docs:doc_controls",
],
)
@@ -153,8 +154,8 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":backend",
"//tensorflow/python/eager:profiler",
"//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/profiler:profiler_v2",
],
)

tensorflow/python/keras/callbacks.py

@@ -48,6 +48,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
@@ -1575,11 +1576,25 @@ class TensorBoard(Callback):
You can find more information about TensorBoard
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
Example:
Example (Basic):
```python
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
#run the tensorboard command to view the visualizations
# run the tensorboard command to view the visualizations.
```
Example (Profile):
```python
# profile a single batch, e.g. the 5th batch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# run the tensorboard command to view the visualizations in the profile
# plugin.

# profile a range of batches, e.g. from 10 to 20.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch='10,20')
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# run the tensorboard command to view the visualizations in the profile
# plugin.
```
Arguments:
@@ -1599,11 +1614,14 @@ class TensorBoard(Callback):
callback will write the metrics and losses to TensorBoard every 1000
batches. Note that writing too frequently to TensorBoard can slow down
your training.
profile_batch: Profile the batch to sample compute characteristics. By
default, it will profile the second batch. Set profile_batch=0 to
disable profiling. Must run in TensorFlow eager mode.
embeddings_freq: frequency (in epochs) at which embedding layers will
be visualized. If set to 0, embeddings won't be visualized.
profile_batch: Profile the batch(es) to sample compute characteristics.
profile_batch must be a non-negative integer or a comma-separated string
of a pair of positive integers, where the pair signifies a range of
batches to profile. By default, the second batch is profiled. Set
profile_batch=0 to disable profiling. Must run in TensorFlow eager mode.
embeddings_freq: frequency (in epochs) at which embedding layers will be
visualized. If set to 0, embeddings won't be visualized.
embeddings_metadata: a dictionary which maps layer name to a file name in
which metadata for this embedding layer is saved. See the
[details](
@@ -1652,8 +1670,8 @@ class TensorBoard(Callback):
self._train_run_name = 'train'
self._validation_run_name = 'validation'
self._writers = {}
self._profile_batch = profile_batch
self._start_batch, self._stop_batch = self._init_profile_batch(
profile_batch)
# True when a trace is running.
self._is_tracing = False
@@ -1827,10 +1845,49 @@ class TensorBoard(Callback):
else:
self._total_batches_seen[writer_name] += 1
def _init_profile_batch(self, profile_batch):
"""Validate profile_batch value and set the range of batches to profile.
Arguments:
profile_batch: The range of batches to profile. Should be a non-negative
integer or a comma-separated string of a pair of positive integers. A
pair of positive integers signifies a range of batches to profile.
Returns:
A pair of non-negative integers specifying the start and stop batches to
profile.
Raises:
ValueError: If profile_batch is not an integer or a comma-separated pair
of positive integers.
"""
profile_batch_error_message = (
'profile_batch must be a non-negative integer or a comma-separated '
'string of a pair of positive integers. A pair of positive integers '
'signifies a range of batches to profile.')
try:
profile_range = [int(i) for i in str(profile_batch).split(',')]
except ValueError:
raise ValueError(profile_batch_error_message)
if len(profile_range) == 1: # single batch
start_batch, stop_batch = profile_range[0], profile_range[0]
if start_batch < 0:
raise ValueError(profile_batch_error_message)
elif len(profile_range) == 2: # (start_batch, stop_batch)
start_batch, stop_batch = profile_range
# [0, 0], [-1, 100], [6, 5] are illegal.
if start_batch <= 0 or start_batch > stop_batch:
raise ValueError(profile_batch_error_message)
else:
raise ValueError(profile_batch_error_message)
return start_batch, stop_batch
def on_train_begin(self, logs=None):
self._init_batch_steps()
if self._profile_batch == 1:
summary_ops_v2.trace_on(graph=True, profiler=True)
if self._start_batch == 1:
summary_ops_v2.trace_on(graph=True, profiler=False)
profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
self._is_tracing = True
def on_test_begin(self, logs=None):
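For reference, the validation rules above can be exercised in isolation. The following is a minimal sketch; `parse_profile_batch` is a hypothetical stand-alone mirror of the private `_init_profile_batch` method, shown only to illustrate which values are accepted:
```python
# Hypothetical mirror of _init_profile_batch's parsing rules, for
# illustration only; not part of the TensorFlow API.
def parse_profile_batch(profile_batch):
  error = ('profile_batch must be a non-negative integer or a comma-separated '
           'string of a pair of positive integers.')
  try:
    profile_range = [int(i) for i in str(profile_batch).split(',')]
  except ValueError:
    raise ValueError(error)
  if len(profile_range) == 1:  # A single batch.
    start_batch = stop_batch = profile_range[0]
    if start_batch < 0:
      raise ValueError(error)
  elif len(profile_range) == 2:  # A (start_batch, stop_batch) range.
    start_batch, stop_batch = profile_range
    if start_batch <= 0 or start_batch > stop_batch:  # e.g. '0,0', '6,5'.
      raise ValueError(error)
  else:
    raise ValueError(error)
  return start_batch, stop_batch

assert parse_profile_batch(2) == (2, 2)          # the default: second batch
assert parse_profile_batch('10,20') == (10, 20)  # batches 10 through 20
assert parse_profile_batch(0) == (0, 0)          # profiling disabled
```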
@@ -1845,7 +1902,7 @@ class TensorBoard(Callback):
batch: Integer, index of batch within the current epoch.
logs: Dict. Metric results for this batch.
"""
if self.update_freq == 'epoch' and self._profile_batch is None:
if self.update_freq == 'epoch' and self._start_batch is None:
return
# Don't output batch_size and batch number as TensorBoard summaries
@@ -1857,10 +1914,11 @@ class TensorBoard(Callback):
self._increment_step(self._train_run_name)
if context.executing_eagerly():
if self._is_tracing:
if self._is_tracing and math_ops.greater_equal(train_batches,
self._stop_batch):
self._log_trace()
elif (not self._is_tracing and
math_ops.equal(train_batches, self._profile_batch - 1)):
math_ops.equal(train_batches, self._start_batch - 1)):
self._enable_trace()
def on_test_batch_end(self, batch, logs=None):
@@ -1899,7 +1957,8 @@ class TensorBoard(Callback):
def _enable_trace(self):
if context.executing_eagerly():
summary_ops_v2.trace_on(graph=True, profiler=True)
summary_ops_v2.trace_on(graph=True, profiler=False)
profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
self._is_tracing = True
def _log_trace(self):
@@ -1909,10 +1968,8 @@ class TensorBoard(Callback):
summary_ops_v2.always_record_summaries():
# TODO(b/126388999): Remove step info in the summary name.
step = K.get_value(self._total_batches_seen[self._train_run_name])
summary_ops_v2.trace_export(
name='batch_%d' % step,
step=step,
profiler_outdir=os.path.join(self._log_write_dir, 'train'))
summary_ops_v2.trace_export(name='batch_%d' % step, step=step)
profiler.stop()
self._is_tracing = False
def _log_metrics(self, logs, prefix, step):
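Taken together, the callbacks.py changes split tracing from profiling: trace_on() no longer drives the profiler, and the profiler is started and stopped around the traced batches with an explicit logdir. A minimal sketch of the same flow using public TF2 symbols (tf.profiler.experimental is the exported form of the profiler_v2 module used here; run_one_batch is a hypothetical placeholder):
```python
import tensorflow as tf

logdir = './logs/train'
writer = tf.summary.create_file_writer(logdir)

# Tracing and profiling are now started independently; the profiler gets
# its output directory up front instead of at export time.
tf.summary.trace_on(graph=True, profiler=False)
tf.profiler.experimental.start(logdir)

run_one_batch()  # hypothetical placeholder for the profiled work

with writer.as_default():
  # No profiler_outdir argument: the profiler writes to the logdir it was
  # started with when stop() is called.
  tf.summary.trace_export(name='batch_0', step=0)
tf.profiler.experimental.stop()
```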

tensorflow/python/keras/callbacks_test.py

@@ -1805,6 +1805,15 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
experimental_run_tf_function=testing_utils.should_run_tf_function())
return model
def _get_trace_file(self, logdir):
profile_dir = os.path.join(logdir, 'plugins', 'profile')
for (dirpath, dirnames, filenames) in os.walk(profile_dir):
del dirnames # unused
for filename in filenames:
if filename.endswith('.trace'):
return os.path.join(dirpath, filename)
return None
def fitModelAndAssertKerasModelWritten(self, model):
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
tb_cbk = keras.callbacks.TensorBoard(self.logdir,
@@ -1873,6 +1882,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
model = self._get_seq_model()
@@ -1895,6 +1905,78 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_profileBatchRangeSingle(self):
model = self._get_seq_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
tb_cbk = keras.callbacks.TensorBoard(
self.logdir, histogram_freq=1, profile_batch='2,2', write_graph=False)
model.fit(
x,
y,
batch_size=3,
epochs=2,
validation_data=(x, y),
callbacks=[tb_cbk])
summary_file = list_summaries(self.logdir)
self.assertEqual(
summary_file.tensors,
{
# Trace will be logged once at the batch it stops profiling.
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_profileBatchRange(self):
model = self._get_seq_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
tb_cbk = keras.callbacks.TensorBoard(
self.logdir, histogram_freq=1, profile_batch='1,3', write_graph=False)
model.fit(
x,
y,
batch_size=4,
epochs=2,
validation_data=(x, y),
callbacks=[tb_cbk])
summary_file = list_summaries(self.logdir)
self.assertEqual(
summary_file.tensors,
{
# Trace will be logged once at the batch it stops profiling.
_ObservedSummary(logdir=self.train_dir, tag=u'batch_3'),
},
)
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_profileInvalidBatchRange(self):
with self.assertRaises(ValueError):
keras.callbacks.TensorBoard(
self.logdir,
histogram_freq=1,
profile_batch='-1,3',
write_graph=False)
with self.assertRaises(ValueError):
keras.callbacks.TensorBoard(
self.logdir,
histogram_freq=1,
profile_batch='1,None',
write_graph=False)
with self.assertRaises(ValueError):
keras.callbacks.TensorBoard(
self.logdir, histogram_freq=1, profile_batch='6,5', write_graph=False)
with self.assertRaises(ValueError):
keras.callbacks.TensorBoard(
self.logdir, histogram_freq=1, profile_batch=-1, write_graph=False)
def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
model = self._get_seq_model()
@@ -1913,6 +1995,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
# The trace is enabled only on the 10000th batch, which is never reached,
# so the summary file should be empty.
self.assertEmpty(summary_file.tensors)
self.assertIsNone(self._get_trace_file(logdir=self.train_dir))
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):

tensorflow/python/keras/callbacks_v1.py

@@ -24,7 +24,6 @@ import os
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import profiler
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
@@ -33,6 +32,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import keras_export
@@ -359,16 +359,16 @@ class TensorBoard(callbacks.Callback):
self._samples_seen_at_last_write = self._samples_seen
self._total_batches_seen += 1
if self._is_profiling:
profiler.save(self.log_dir, profiler.stop())
profiler.stop()
self._is_profiling = False
elif (not self._is_profiling and
self._total_batches_seen == self._profile_batch - 1):
profiler.start()
profiler.start(self.log_dir)
self._is_profiling = True
def on_train_begin(self, logs=None):
if self._profile_batch == 1:
profiler.start()
profiler.start(self.log_dir)
self._is_profiling = True
def on_epoch_begin(self, epoch, logs=None):
@@ -452,6 +452,6 @@ class TensorBoard(callbacks.Callback):
def on_train_end(self, logs=None):
if self._is_profiling:
profiler.save(self.log_dir, profiler.stop())
profiler.stop()
self._is_profiling = False
self.writer.close()
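The callbacks_v1.py changes above are a direct API migration: the old eager profiler returned a serialized trace from stop() that then had to be saved explicitly, while profiler_v2 takes the log directory at start() and writes the profile on stop(). Roughly, as a sketch contrasting the two module paths used in the diff:
```python
log_dir = './logs'

# Old API: stop() returns the serialized trace, which must then be saved.
from tensorflow.python.eager import profiler as old_profiler
old_profiler.start()
result = old_profiler.stop()
old_profiler.save(log_dir, result)

# New API: the logdir is given up front, and stop() writes the profile
# there directly, so there is no separate save step.
from tensorflow.python.profiler import profiler_v2 as profiler
profiler.start(log_dir)
profiler.stop()
```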