Merge pull request #37552 from lgeiger:learning_rate_schedules_logging

PiperOrigin-RevId: 324245198
Change-Id: I3f8ac7f0d632369658450417d7d5dc34de272d21
This commit is contained in:
TensorFlower Gardener 2020-07-31 11:20:38 -07:00
commit 1267e76078
2 changed files with 35 additions and 6 deletions

View File

@ -41,6 +41,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.distribute import worker_training_state
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils import version_utils
@ -2254,6 +2255,12 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
profiler.stop()
self._is_tracing = False
def _collect_learning_rate(self, logs):
lr_schedule = getattr(self.model.optimizer, 'lr', None)
if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule):
logs['learning_rate'] = lr_schedule(self.model.optimizer.iterations)
return logs
def _log_epoch_metrics(self, epoch, logs):
"""Writes epoch metrics out as scalar summaries.
@ -2266,6 +2273,7 @@ class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
train_logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
val_logs = {k: v for k, v in logs.items() if k.startswith('val_')}
train_logs = self._collect_learning_rate(train_logs)
with summary_ops_v2.always_record_summaries():
if train_logs:

View File

@ -1834,18 +1834,16 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
self.train_dir = os.path.join(self.logdir, 'train')
self.validation_dir = os.path.join(self.logdir, 'validation')
def _get_model(self):
def _get_model(self, compile_model=True):
layers = [
keras.layers.Conv2D(8, (3, 3)),
keras.layers.Flatten(),
keras.layers.Dense(1)
]
model = testing_utils.get_model_from_layers(layers, input_shape=(10, 10, 1))
opt = gradient_descent.SGD(learning_rate=0.001)
model.compile(
opt,
'mse',
run_eagerly=testing_utils.should_run_eagerly())
if compile_model:
opt = gradient_descent.SGD(learning_rate=0.001)
model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
return model
def test_TensorBoard_default_logdir(self):
@ -1959,6 +1957,29 @@ class TestTensorBoardV2(keras_parameterized.TestCase):
},
)
def test_TensorBoard_learning_rate_schedules(self):
model = self._get_model(compile_model=False)
opt = gradient_descent.SGD(learning_rate_schedule.CosineDecay(0.01, 1))
model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))
model.fit(
x,
y,
batch_size=2,
epochs=2,
callbacks=[keras.callbacks.TensorBoard(self.logdir)])
summary_file = list_summaries(self.logdir)
self.assertEqual(
summary_file.scalars,
{
_ObservedSummary(logdir=self.train_dir, tag='epoch_loss'),
_ObservedSummary(logdir=self.train_dir, tag='epoch_learning_rate'),
},
)
def test_TensorBoard_weight_histograms(self):
model = self._get_model()
x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))