diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 8d33bb22978..7db9fae5f3b 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1168,6 +1168,23 @@ class LearningRateScheduler(Callback): (integer, indexed from 0) and returns a new learning rate as output (float). verbose: int. 0: quiet, 1: update messages. + + ```python + # Firstly, let's create a function which + # maps a given epoch to a learning rate. + # This function would keep the learning rate + # constant at 0.001 for the first ten epochs and + # let it decrease exponentially after that. + def scheduler(epoch): + if epoch < 10: + return 0.001 + else: + return 0.001 * tf.math.exp(0.1 * (10 - epoch)) + # Next, we need to set up the callback and train the model. + callback = tf.keras.callbacks.LearningRateScheduler(scheduler) + model.fit(data, labels, epochs=100, callbacks=[callback], + validation_data=(val_data, val_labels)) + ``` """ def __init__(self, schedule, verbose=0):