Move summary ops and keras tensorboard callback to use the new API.

PiperOrigin-RevId: 296541416
Change-Id: Ie6c32d0e24c08ad73b2d647b8c7fb65a8869e801
This commit is contained in:
A. Unique TensorFlower 2020-02-21 17:13:26 -08:00 committed by TensorFlower Gardener
parent 4ce3d12407
commit 88e2109ac8
4 changed files with 166 additions and 25 deletions

View File

@ -141,6 +141,7 @@ py_library(
"//tensorflow/python/keras/distribute:multi_worker_training_state", "//tensorflow/python/keras/distribute:multi_worker_training_state",
"//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/keras/utils:mode_keys", "//tensorflow/python/keras/utils:mode_keys",
"//tensorflow/python/profiler:profiler_v2",
"//tensorflow/tools/docs:doc_controls", "//tensorflow/tools/docs:doc_controls",
], ],
) )
@ -153,8 +154,8 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":backend", ":backend",
"//tensorflow/python/eager:profiler",
"//tensorflow/python/keras/utils:engine_utils", "//tensorflow/python/keras/utils:engine_utils",
"//tensorflow/python/profiler:profiler_v2",
], ],
) )

View File

@ -48,6 +48,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.training import checkpoint_management from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.compat import collections_abc
@ -1575,11 +1576,25 @@ class TensorBoard(Callback):
You can find more information about TensorBoard You can find more information about TensorBoard
[here](https://www.tensorflow.org/get_started/summaries_and_tensorboard). [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
Example: Example (Basic):
```python ```python
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs") tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback]) model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
#run the tensorboard command to view the visualizations # run the tensorboard command to view the visualizations.
```
Example (Profile):
```python
# profile a single batch, e.g. the 5th batch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# run the tensorboard command to view the visualizations in profile plugin.
# profile a range of batches, e.g. from 10 to 20.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='./logs', profile_batch='10,20')
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# run the tensorboard command to view the visualizations in profile plugin.
``` ```
Arguments: Arguments:
@ -1599,11 +1614,14 @@ class TensorBoard(Callback):
callback will write the metrics and losses to TensorBoard every 1000 callback will write the metrics and losses to TensorBoard every 1000
batches. Note that writing too frequently to TensorBoard can slow down batches. Note that writing too frequently to TensorBoard can slow down
your training. your training.
profile_batch: Profile the batch to sample compute characteristics. By profile_batch: Profile the batch(es) to sample compute characteristics.
default, it will profile the second batch. Set profile_batch=0 to profile_batch must be a non-negative integer or a comma separated string
disable profiling. Must run in TensorFlow eager mode. of pair of positive integers. A pair of positive integers signify a
embeddings_freq: frequency (in epochs) at which embedding layers will range of batches to profile. By default, it will profile the second
be visualized. If set to 0, embeddings won't be visualized. batch. Set profile_batch=0 to disable profiling. Must run in TensorFlow
eager mode.
embeddings_freq: frequency (in epochs) at which embedding layers will be
visualized. If set to 0, embeddings won't be visualized.
embeddings_metadata: a dictionary which maps layer name to a file name in embeddings_metadata: a dictionary which maps layer name to a file name in
which metadata for this embedding layer is saved. See the which metadata for this embedding layer is saved. See the
[details]( [details](
@ -1652,8 +1670,8 @@ class TensorBoard(Callback):
self._train_run_name = 'train' self._train_run_name = 'train'
self._validation_run_name = 'validation' self._validation_run_name = 'validation'
self._writers = {} self._writers = {}
self._start_batch, self._stop_batch = self._init_profile_batch(
self._profile_batch = profile_batch profile_batch)
# True when a trace is running. # True when a trace is running.
self._is_tracing = False self._is_tracing = False
@ -1827,10 +1845,49 @@ class TensorBoard(Callback):
else: else:
self._total_batches_seen[writer_name] += 1 self._total_batches_seen[writer_name] += 1
def _init_profile_batch(self, profile_batch):
  """Validate profile_batch value and set the range of batches to profile.

  Arguments:
    profile_batch: The range of batches to profile. Should be a non-negative
      integer, a comma separated string of a pair of positive integers, or a
      list/tuple holding such a pair. A pair of positive integers signifies a
      range of batches to profile.

  Returns:
    A pair of non-negative integers specifying the start and stop batch to
    profile. For a single batch both entries are that batch.

  Raises:
    ValueError: If profile_batch is not a non-negative integer or a pair of
      positive integers whose start does not exceed its stop.
  """
  profile_batch_error_message = (
      'profile_batch must be a non-negative integer or a comma separated '
      'string of pair of positive integers. A pair of positive integers '
      'signify a range of batches to profile. A list or tuple of two '
      'positive integers is also accepted.')

  is_sequence = isinstance(profile_batch, (list, tuple))
  if is_sequence:
    # Route every element through str() so that floats and bools are rejected
    # by int() exactly as they are for the scalar/string form.
    tokens = [str(item) for item in profile_batch]
  else:
    tokens = str(profile_batch).split(',')

  try:
    profile_range = [int(token) for token in tokens]
  except ValueError:
    raise ValueError(profile_batch_error_message)

  if len(profile_range) == 1 and not is_sequence:  # single batch
    start_batch = stop_batch = profile_range[0]
    # 0 is allowed here: it is the documented way to disable profiling.
    if start_batch < 0:
      raise ValueError(profile_batch_error_message)
  elif len(profile_range) == 2:  # (start_batch, stop_batch)
    start_batch, stop_batch = profile_range
    # [0, 0], [-1, 100], [6, 5] are illegal.
    if start_batch <= 0 or start_batch > stop_batch:
      raise ValueError(profile_batch_error_message)
  else:
    raise ValueError(profile_batch_error_message)

  return start_batch, stop_batch
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
self._init_batch_steps() self._init_batch_steps()
if self._profile_batch == 1: if self._start_batch == 1:
summary_ops_v2.trace_on(graph=True, profiler=True) summary_ops_v2.trace_on(graph=True, profiler=False)
profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
self._is_tracing = True self._is_tracing = True
def on_test_begin(self, logs=None): def on_test_begin(self, logs=None):
@ -1845,7 +1902,7 @@ class TensorBoard(Callback):
batch: Integer, index of batch within the current epoch. batch: Integer, index of batch within the current epoch.
logs: Dict. Metric results for this batch. logs: Dict. Metric results for this batch.
""" """
if self.update_freq == 'epoch' and self._profile_batch is None: if self.update_freq == 'epoch' and self._start_batch is None:
return return
# Don't output batch_size and batch number as TensorBoard summaries # Don't output batch_size and batch number as TensorBoard summaries
@ -1857,10 +1914,11 @@ class TensorBoard(Callback):
self._increment_step(self._train_run_name) self._increment_step(self._train_run_name)
if context.executing_eagerly(): if context.executing_eagerly():
if self._is_tracing: if self._is_tracing and math_ops.greater_equal(train_batches,
self._stop_batch):
self._log_trace() self._log_trace()
elif (not self._is_tracing and elif (not self._is_tracing and
math_ops.equal(train_batches, self._profile_batch - 1)): math_ops.equal(train_batches, self._start_batch - 1)):
self._enable_trace() self._enable_trace()
def on_test_batch_end(self, batch, logs=None): def on_test_batch_end(self, batch, logs=None):
@ -1899,7 +1957,8 @@ class TensorBoard(Callback):
def _enable_trace(self): def _enable_trace(self):
if context.executing_eagerly(): if context.executing_eagerly():
summary_ops_v2.trace_on(graph=True, profiler=True) summary_ops_v2.trace_on(graph=True, profiler=False)
profiler.start(logdir=os.path.join(self._log_write_dir, 'train'))
self._is_tracing = True self._is_tracing = True
def _log_trace(self): def _log_trace(self):
@ -1909,10 +1968,8 @@ class TensorBoard(Callback):
summary_ops_v2.always_record_summaries(): summary_ops_v2.always_record_summaries():
# TODO(b/126388999): Remove step info in the summary name. # TODO(b/126388999): Remove step info in the summary name.
step = K.get_value(self._total_batches_seen[self._train_run_name]) step = K.get_value(self._total_batches_seen[self._train_run_name])
summary_ops_v2.trace_export( summary_ops_v2.trace_export(name='batch_%d' % step, step=step)
name='batch_%d' % step, profiler.stop()
step=step,
profiler_outdir=os.path.join(self._log_write_dir, 'train'))
self._is_tracing = False self._is_tracing = False
def _log_metrics(self, logs, prefix, step): def _log_metrics(self, logs, prefix, step):

View File

@ -1805,6 +1805,15 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
experimental_run_tf_function=testing_utils.should_run_tf_function()) experimental_run_tf_function=testing_utils.should_run_tf_function())
return model return model
def _get_trace_file(self, logdir, suffix='.trace'):
  """Return the path of one profiler trace file found under `logdir`.

  Walks `<logdir>/plugins/profile` recursively and returns the first file
  whose name ends with `suffix`.

  Arguments:
    logdir: Directory under which `plugins/profile` is searched.
    suffix: File name ending that identifies a trace file. Defaults to
      '.trace'.

  Returns:
    The joined path of the first matching file yielded by `os.walk`, or None
    when nothing matches (which is also the result when the profile directory
    does not exist, since `os.walk` then yields nothing).
  """
  profile_dir = os.path.join(logdir, 'plugins', 'profile')
  for dirpath, _, filenames in os.walk(profile_dir):
    for filename in filenames:
      if filename.endswith(suffix):
        return os.path.join(dirpath, filename)
  return None
def fitModelAndAssertKerasModelWritten(self, model): def fitModelAndAssertKerasModelWritten(self, model):
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1)) x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
tb_cbk = keras.callbacks.TensorBoard(self.logdir, tb_cbk = keras.callbacks.TensorBoard(self.logdir,
@ -1873,6 +1882,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_1'), _ObservedSummary(logdir=self.train_dir, tag=u'batch_1'),
}, },
) )
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_tagNameWithBatchNum(self): def test_TensorBoard_autoTrace_tagNameWithBatchNum(self):
model = self._get_seq_model() model = self._get_seq_model()
@ -1895,6 +1905,78 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
_ObservedSummary(logdir=self.train_dir, tag=u'batch_2'), _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
}, },
) )
self.assertIsNotNone(self._get_trace_file(logdir=self.train_dir))
def test_TensorBoard_autoTrace_profileBatchRangeSingle(self):
  """A range of '2,2' yields one trace summary tagged batch_2 and a trace file."""
  inputs, targets = np.ones((10, 10, 10, 1)), np.ones((10, 1))
  seq_model = self._get_seq_model()
  callback = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, profile_batch='2,2', write_graph=False)

  seq_model.fit(
      inputs,
      targets,
      batch_size=3,
      epochs=2,
      validation_data=(inputs, targets),
      callbacks=[callback])

  # Trace will be logged once at the batch it stops profiling.
  expected_tensors = {
      _ObservedSummary(logdir=self.train_dir, tag=u'batch_2'),
  }
  observed = list_summaries(self.logdir)
  self.assertEqual(observed.tensors, expected_tensors)

  trace_path = self._get_trace_file(logdir=self.train_dir)
  self.assertIsNotNone(trace_path)
def test_TensorBoard_autoTrace_profileBatchRange(self):
  """A range of '1,3' yields one trace summary tagged batch_3 and a trace file."""
  inputs, targets = np.ones((10, 10, 10, 1)), np.ones((10, 1))
  seq_model = self._get_seq_model()
  callback = keras.callbacks.TensorBoard(
      self.logdir, histogram_freq=1, profile_batch='1,3', write_graph=False)

  seq_model.fit(
      inputs,
      targets,
      batch_size=4,
      epochs=2,
      validation_data=(inputs, targets),
      callbacks=[callback])

  # Trace will be logged once at the batch it stops profiling.
  expected_tensors = {
      _ObservedSummary(logdir=self.train_dir, tag=u'batch_3'),
  }
  observed = list_summaries(self.logdir)
  self.assertEqual(observed.tensors, expected_tensors)

  trace_path = self._get_trace_file(logdir=self.train_dir)
  self.assertIsNotNone(trace_path)
def test_TensorBoard_autoTrace_profileInvalidBatchRange(self):
  """Constructing the callback with a malformed profile_batch raises ValueError."""
  # In order: negative start of a range, non-integer stop, start greater than
  # stop, and a negative single batch.
  invalid_values = ('-1,3', '1,None', '6,5', -1)
  for invalid in invalid_values:
    with self.assertRaises(ValueError):
      keras.callbacks.TensorBoard(
          self.logdir,
          histogram_freq=1,
          profile_batch=invalid,
          write_graph=False)
def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self): def test_TensorBoard_autoTrace_profile_batch_largerThanBatchCount(self):
model = self._get_seq_model() model = self._get_seq_model()
@ -1913,6 +1995,7 @@ class TestTensorBoardV2NonParameterizedTest(keras_parameterized.TestCase):
# Enabled trace only on the 10000th batch, thus it should be empty. # Enabled trace only on the 10000th batch, thus it should be empty.
self.assertEmpty(summary_file.tensors) self.assertEmpty(summary_file.tensors)
self.assertIsNone(self._get_trace_file(logdir=self.train_dir))
class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase): class MostRecentlyModifiedFileMatchingPatternTest(test.TestCase):

View File

@ -24,7 +24,6 @@ import os
import numpy as np import numpy as np
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import profiler
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks from tensorflow.python.keras import callbacks
@ -33,6 +32,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2 from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.summary import summary as tf_summary from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.training import saver from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -359,16 +359,16 @@ class TensorBoard(callbacks.Callback):
self._samples_seen_at_last_write = self._samples_seen self._samples_seen_at_last_write = self._samples_seen
self._total_batches_seen += 1 self._total_batches_seen += 1
if self._is_profiling: if self._is_profiling:
profiler.save(self.log_dir, profiler.stop()) profiler.stop()
self._is_profiling = False self._is_profiling = False
elif (not self._is_profiling and elif (not self._is_profiling and
self._total_batches_seen == self._profile_batch - 1): self._total_batches_seen == self._profile_batch - 1):
profiler.start() profiler.start(self.log_dir)
self._is_profiling = True self._is_profiling = True
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
if self._profile_batch == 1: if self._profile_batch == 1:
profiler.start() profiler.start(self.log_dir)
self._is_profiling = True self._is_profiling = True
def on_epoch_begin(self, epoch, logs=None): def on_epoch_begin(self, epoch, logs=None):
@ -452,6 +452,6 @@ class TensorBoard(callbacks.Callback):
def on_train_end(self, logs=None): def on_train_end(self, logs=None):
if self._is_profiling: if self._is_profiling:
profiler.save(self.log_dir, profiler.stop()) profiler.stop()
self._is_profiling = False self._is_profiling = False
self.writer.close() self.writer.close()