diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 80b83b7869e..c157f990467 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1239,6 +1239,20 @@ class LearningRateScheduler(Callback): (integer, indexed from 0) and returns a new learning rate as output (float). verbose: int. 0: quiet, 1: update messages. + + ```python + # This function keeps the learning rate at 0.001 for the first ten epochs + # and decreases it exponentially after that. + def scheduler(epoch): + if epoch < 10: + return 0.001 + else: + return 0.001 * tf.math.exp(0.1 * (10 - epoch)) + + callback = tf.keras.callbacks.LearningRateScheduler(scheduler) + model.fit(data, labels, epochs=100, callbacks=[callback], + validation_data=(val_data, val_labels)) + ``` """ def __init__(self, schedule, verbose=0):