diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py index 7db9fae5f3b..6b01f314505 100644 --- a/tensorflow/python/keras/callbacks.py +++ b/tensorflow/python/keras/callbacks.py @@ -1172,15 +1172,15 @@ class LearningRateScheduler(Callback): ```python # Firstly, let's create a function which # maps a given epoch to a learning rate. - # This function would keep the learning rate - # constant at 0.001 for the first ten epochs and - # let it decrease exponentially after that. + # This function keeps the learning rate + # at 0.001 for the first ten epochs and + # decreases it exponentially after that. def scheduler(epoch): if epoch < 10: return 0.001 else: return 0.001 * tf.math.exp(0.1 * (10 - epoch)) - # Next, we need to set up the callback and train the model. + callback = tf.keras.callbacks.LearningRateScheduler(scheduler) model.fit(data, labels, epochs=100, callbacks=[callback], validation_data=(val_data, val_labels))