From 234c738bcd513f5747925c87efb9ed4175a2aa5e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Wed, 30 Jan 2019 10:42:44 -0800 Subject: [PATCH] This CL introduces serializable/deserializable learning rate decay schedules for the Keras v2 optimizers. PiperOrigin-RevId: 231623483 --- tensorflow/python/BUILD | 2 +- tensorflow/python/keras/BUILD | 14 +- tensorflow/python/keras/optimizer_v2/BUILD | 29 + .../python/keras/optimizer_v2/adagrad_test.py | 47 + .../python/keras/optimizer_v2/adam_test.py | 50 + .../optimizer_v2/gradient_descent_test.py | 84 +- .../optimizer_v2/learning_rate_schedule.py | 1031 +++++++++++++++++ .../learning_rate_schedule_test.py | 527 +++++++++ tensorflow/python/keras/optimizer_v2/nadam.py | 7 + .../python/keras/optimizer_v2/optimizer_v2.py | 19 +- .../keras/optimizer_v2/optimizer_v2_test.py | 35 + .../python/keras/optimizer_v2/rmsprop_test.py | 73 ++ .../tools/api/generator/api_init_files.bzl | 1 + .../tools/api/generator/api_init_files_v1.bzl | 1 + .../python/training/learning_rate_decay.py | 90 +- .../python/training/learning_rate_decay_v2.py | 898 -------------- .../training/learning_rate_decay_v2_test.py | 497 -------- ...low.keras.experimental.-cosine-decay.pbtxt | 18 + .../v1/tensorflow.keras.experimental.pbtxt | 4 + .../v1/tensorflow.keras.optimizers.pbtxt | 4 + ...imizers.schedules.-exponential-decay.pbtxt | 18 + ...mizers.schedules.-inverse-time-decay.pbtxt | 18 + ...rs.schedules.-learning-rate-schedule.pbtxt | 16 + ....schedules.-piecewise-constant-decay.pbtxt | 18 + ...timizers.schedules.-polynomial-decay.pbtxt | 18 + ...ensorflow.keras.optimizers.schedules.pbtxt | 31 + ....experimental.-cosine-decay-restarts.pbtxt | 18 + ...low.keras.experimental.-cosine-decay.pbtxt | 18 + ...as.experimental.-linear-cosine-decay.pbtxt | 18 + ...erimental.-noisy-linear-cosine-decay.pbtxt | 18 + .../v2/tensorflow.keras.experimental.pbtxt | 16 + .../v2/tensorflow.keras.optimizers.pbtxt | 4 + ...imizers.schedules.-exponential-decay.pbtxt | 18 + ...mizers.schedules.-inverse-time-decay.pbtxt | 18 + ...rs.schedules.-learning-rate-schedule.pbtxt | 16 + ....schedules.-piecewise-constant-decay.pbtxt | 18 + ...timizers.schedules.-polynomial-decay.pbtxt | 18 + ...ensorflow.keras.optimizers.schedules.pbtxt | 31 + .../api/golden/v2/tensorflow.train.pbtxt | 36 - tensorflow/tools/compatibility/renames_v2.py | 9 + 40 files changed, 2298 insertions(+), 1508 deletions(-) create mode 100644 tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py create mode 100644 tensorflow/python/keras/optimizer_v2/learning_rate_schedule_test.py delete mode 100644 tensorflow/python/training/learning_rate_decay_v2.py delete mode 100644 tensorflow/python/training/learning_rate_decay_v2_test.py create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-cosine-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.pbtxt create mode 
100644 tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay-restarts.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-cosine-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-noisy-linear-cosine-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.pbtxt diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 813040177d3..109beace58b 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3722,6 +3722,7 @@ py_library( "//tensorflow/python/distribute:reduce_util", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", + "//tensorflow/python/keras/optimizer_v2:learning_rate_schedule", "//tensorflow/python/ops/losses", "//tensorflow/python/training/checkpointable:base", "//tensorflow/python/training/checkpointable:util", @@ -4807,7 +4808,6 @@ cuda_py_tests( "training/ftrl_test.py", "training/gradient_descent_test.py", "training/learning_rate_decay_test.py", - "training/learning_rate_decay_v2_test.py", "training/momentum_test.py", "training/optimizer_test.py", "training/proximal_adagrad_test.py", diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index b37bce90574..ad895d21dcb 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -310,7 +310,6 @@ py_library( "layers/recurrent.py", "layers/serialization.py", "layers/wrappers.py", - "utils/generic_utils.py", "utils/kernelized_utils.py", "utils/layer_utils.py", "utils/tf_utils.py", @@ -318,6 +317,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":engine", + ":generic_utils", "//tensorflow/python:array_ops", "//tensorflow/python:cudnn_rnn_ops_gen", "//tensorflow/python:dtypes", @@ -339,6 +339,18 @@ py_library( ], ) +py_library( + name = "generic_utils", + srcs = [ + "utils/generic_utils.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "integration_test", size = "medium", diff --git a/tensorflow/python/keras/optimizer_v2/BUILD b/tensorflow/python/keras/optimizer_v2/BUILD index 45afe2a134c..88f2521d5e8 100644 --- a/tensorflow/python/keras/optimizer_v2/BUILD +++ b/tensorflow/python/keras/optimizer_v2/BUILD @@ -25,6 +25,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":learning_rate_schedule", "//tensorflow/python:control_flow_ops", "//tensorflow/python:distribute", "//tensorflow/python:framework", @@ -39,6 +40,21 @@ py_library( ], ) +py_library( + name = "learning_rate_schedule", + srcs = [ + "learning_rate_schedule.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + 
"//tensorflow/python/keras:generic_utils", + ], +) + cuda_py_test( name = "adagrad_test", size = "medium", @@ -197,6 +213,19 @@ py_test( ], ) +py_test( + name = "learning_rate_schedule_test", + size = "small", + srcs = ["learning_rate_schedule_test.py"], + deps = [ + ":optimizer_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "rmsprop_test", size = "medium", diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py index 864aefaf70d..9c8d3ff8a4e 100644 --- a/tensorflow/python/keras/optimizer_v2/adagrad_test.py +++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.optimizer_v2 import adagrad +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -160,6 +161,52 @@ class AdagradOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + def testBasicWithLearningRateInverseTimeDecay(self): + for dtype in [dtypes.float32, dtypes.float64]: + with self.cached_session(): + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = 3.0 + decay = 0.5 + lr_schedule = learning_rate_schedule.InverseTimeDecay( + learning_rate, decay_steps=1.0, decay_rate=decay) + + ada_opt = adagrad.Adagrad(lr_schedule) + + accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + + if not context.executing_eagerly(): + ada_update = ada_opt.apply_gradients( + zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + # Fetch params to validate initial values + v0_val, v1_val = self.evaluate([var0, var1]) + self.assertAllClose([1.0, 2.0], v0_val) + self.assertAllClose([3.0, 4.0], v1_val) + + # Run 3 steps of adagrad + for t in range(3): + if not context.executing_eagerly(): + self.evaluate(ada_update) + else: + ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + lr_np = learning_rate / (1 + decay * t) + var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np, + grads0_np, lr_np) + var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np, + grads1_np, lr_np) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + @test_util.run_deprecated_v1 def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py index 7918c09b7e0..761b6a0854d 100644 --- a/tensorflow/python/keras/optimizer_v2/adam_test.py +++ 
b/tensorflow/python/keras/optimizer_v2/adam_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import optimizers from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops @@ -399,6 +400,55 @@ class AdamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + @test_util.run_deprecated_v1 + def testBasicWithLearningRateInverseTimeDecay(self): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + learning_rate = 0.001 + decay = 0.5 + lr_schedule = learning_rate_schedule.InverseTimeDecay( + learning_rate, decay_steps=1.0, decay_rate=decay) + beta_1 = 0.9 + beta_2 = 0.999 + epsilon = 1e-7 + + opt = adam.Adam( + learning_rate=lr_schedule, + beta_1=beta_1, + beta_2=beta_2, + epsilon=epsilon) + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + self.evaluate(variables.global_variables_initializer()) + # Run 3 steps of Adam + for t in range(3): + self.evaluate(update) + + lr_np = learning_rate / (1 + decay * t) + + var0_np, m0, v0 = adam_update_numpy( + var0_np, grads0_np, t, m0, v0, lr=lr_np) + var1_np, m1, v1 = adam_update_numpy( + var1_np, grads1_np, t, m1, v1, lr=lr_np) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + @test_util.run_deprecated_v1 def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: diff --git a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py index 333a6f288ea..6bd56372b9a 100644 --- a/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py +++ b/tensorflow/python/keras/optimizer_v2/gradient_descent_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops @@ -57,42 +58,61 @@ class GradientDescentOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], self.evaluate(var1)) + def _test_basic_sgd_with_learning_rate_decay(self, sgd, dtype): + var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) + var1 = 
resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + if not context.executing_eagerly(): + sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + # Run 2 steps of sgd + if not context.executing_eagerly(): + self.evaluate(sgd_op) + else: + sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) + # Validate updated params + self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], + self.evaluate(var0)) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + self.evaluate(var1)) + + if not context.executing_eagerly(): + self.evaluate(sgd_op) + else: + sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) + # Validate updated params + self.assertAllCloseAccordingToType( + [1.0 - 3.0 * 0.1 - 2.0 * 0.1, 2.0 - 3.0 * 0.1 - 2.0 * 0.1], + self.evaluate(var0)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * 0.01 - 2.0 * 0.01, 4.0 - 3.0 * 0.01 - 2.0 * 0.01], + self.evaluate(var1)) + @test_util.run_in_graph_and_eager_modes def testBasicWithLearningRateDecay(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.cached_session(): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) - grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) - learning_rate = 3.0 - decay = 0.5 - sgd = gradient_descent.SGD(learning_rate=learning_rate, decay=decay) - if not context.executing_eagerly(): - sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) - self.evaluate(variables.global_variables_initializer()) - # Run 2 steps of sgd - if not context.executing_eagerly(): - self.evaluate(sgd_op) - else: - sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) - # Validate updated params - self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], - self.evaluate(var0)) - self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], - self.evaluate(var1)) + learning_rate = 3.0 + decay = 0.5 + sgd = gradient_descent.SGD(learning_rate=learning_rate, decay=decay) + self._test_basic_sgd_with_learning_rate_decay(sgd, dtype) - if not context.executing_eagerly(): - self.evaluate(sgd_op) - else: - sgd.apply_gradients(zip([grads0, grads1], [var0, var1])) - # Validate updated params - self.assertAllCloseAccordingToType( - [1.0 - 3.0 * 0.1 - 2.0 * 0.1, 2.0 - 3.0 * 0.1 - 2.0 * 0.1], - self.evaluate(var0)) - self.assertAllCloseAccordingToType( - [3.0 - 3.0 * 0.01 - 2.0 * 0.01, 4.0 - 3.0 * 0.01 - 2.0 * 0.01], - self.evaluate(var1)) + @test_util.run_in_graph_and_eager_modes + def testBasicWithLearningRateInverseTimeDecay(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + learning_rate = learning_rate_schedule.InverseTimeDecay( + 3.0, decay_steps=1.0, decay_rate=0.5) + sgd = gradient_descent.SGD(learning_rate=learning_rate) + self._test_basic_sgd_with_learning_rate_decay(sgd, dtype) + + @test_util.run_in_graph_and_eager_modes + def testBasicWithLearningRateInverseTimeDecaySerializeAndDeserialize(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + learning_rate = learning_rate_schedule.InverseTimeDecay( + 3.0, decay_steps=1.0, decay_rate=0.5) + sgd = gradient_descent.SGD(learning_rate=learning_rate) + sgd = 
gradient_descent.SGD.from_config(sgd.get_config()) + self._test_basic_sgd_with_learning_rate_decay(sgd, dtype) @test_util.run_in_graph_and_eager_modes def testBasicCallableParams(self): diff --git a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py new file mode 100644 index 00000000000..a182d74e56e --- /dev/null +++ b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py @@ -0,0 +1,1031 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various learning rate decay functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import math + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.util.tf_export import keras_export + + +@keras_export("keras.optimizers.schedules.LearningRateSchedule") +class LearningRateSchedule(object): + """A serializable learning rate decay schedule. + + `LearningRateSchedule`s can be passed in as the learning rate of optimizers in + `tf.keras.optimizers`. They can be serialized and deserialized using + `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + """ + + @abc.abstractmethod + def __call__(self, step): + raise NotImplementedError("Learning rate schedule must override __call__") + + @abc.abstractmethod + def get_config(self): + raise NotImplementedError("Learning rate schedule must override get_config") + + @classmethod + def from_config(cls, config): + """Instantiates a `LearningRateSchedule` from its config. + + Args: + config: Output of `get_config()`. + + Returns: + A `LearningRateSchedule` instance. + """ + return cls(**config) + + +@keras_export("keras.optimizers.schedules.ExponentialDecay") +class ExponentialDecay(LearningRateSchedule): + """A LearningRateSchedule that uses an exponential decay schedule.""" + + def __init__( + self, + initial_learning_rate, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies exponential decay to the learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This schedule applies an exponential decay function + to an optimizer step, given a provided initial learning rate. + + The schedule a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. 
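For orientation, the `LearningRateSchedule` interface introduced above can back arbitrary user-defined schedules; a minimal sketch follows (the `MyLinearWarmup` class and its parameters are illustrative only, assuming the `tf.keras.optimizers.schedules` export path declared by this file):

```python
import tensorflow as tf


class MyLinearWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Illustrative schedule: linear warmup to a target rate, then constant."""

  def __init__(self, target_rate, warmup_steps, name=None):
    super(MyLinearWarmup, self).__init__()
    self.target_rate = target_rate
    self.warmup_steps = warmup_steps
    self.name = name

  def __call__(self, step):
    # Ramp linearly from 0 to target_rate over warmup_steps, then hold.
    step = tf.cast(step, tf.float32)
    warmup_steps = tf.cast(self.warmup_steps, tf.float32)
    return self.target_rate * tf.minimum(step / warmup_steps, 1.0)

  def get_config(self):
    # Plain Python values only, so serialize()/deserialize() can round-trip.
    return {"target_rate": self.target_rate,
            "warmup_steps": self.warmup_steps,
            "name": self.name}


# Round-trip through the helpers added at the bottom of this file.
config = tf.keras.optimizers.schedules.serialize(MyLinearWarmup(0.001, 1000))
restored = tf.keras.optimizers.schedules.deserialize(
    config, custom_objects={"MyLinearWarmup": MyLinearWarmup})
```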
+ It is computed as: + + ```python + def decayed_learning_rate(step): + return initial_learning_rate * decay_rate ^ (step / decay_steps) + ``` + + If the argument `staircase` is `True`, then `step / decay_steps` is + an integer division and the decayed learning rate follows a + staircase function. + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. + Example: When fitting a Keras model, decay every 100000 steps with a base + of 0.96: + + ```python + initial_learning_rate = 0.1 + lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay( + initial_learning_rate, + decay_steps=100000, + decay_rate=0.96, + staircase=True) + + model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule), + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(data, labels, epochs=5) + ``` + + The learning rate schedule is also serializable and deserializable using + `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Must be positive. See the decay computation above. + decay_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The decay rate. + staircase: Boolean. If `True` decay the learning rate at discrete + intervals + name: String. Optional name of the operation. Defaults to + 'ExponentialDecay'. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as `initial_learning_rate`. + """ + super(ExponentialDecay, self).__init__() + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.decay_rate = decay_rate + self.staircase = staircase + self.name = name + + def __call__(self, step): + with ops.name_scope( + self.name, "ExponentialDecay", + [self.initial_learning_rate, step, self.decay_steps, self.decay_rate] + ) as name: + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate, name="initial_learning_rate") + dtype = initial_learning_rate.dtype + decay_steps = math_ops.cast(self.decay_steps, dtype) + decay_rate = math_ops.cast(self.decay_rate, dtype) + + global_step_recomp = math_ops.cast(step, dtype) + p = global_step_recomp / decay_steps + if self.staircase: + p = math_ops.floor(p) + return math_ops.multiply( + initial_learning_rate, math_ops.pow(decay_rate, p), name=name) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "decay_rate": self.decay_rate, + "staircase": self.staircase, + "name": self.name + } + + +@keras_export("keras.optimizers.schedules.PiecewiseConstantDecay") +class PiecewiseConstantDecay(LearningRateSchedule): + """A LearningRateSchedule that uses a piecewise constant decay schedule.""" + + def __init__( + self, + boundaries, + values, + name=None): + """Piecewise constant from boundaries and interval values. + + The function returns a 1-arg callable to compute the piecewise constant + when passed the current optimizer step. This can be useful for changing the + learning rate value across different invocations of optimizer functions. + + Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 + for the next 10000 steps, and 0.1 for any additional steps. 
+ + ```python + step = tf.Variable(0, trainable=False) + boundaries = [100000, 110000] + values = [1.0, 0.5, 0.1] + learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay( + boundaries, values) + + # Later, whenever we perform an optimization step, we pass in the step. + learning_rate = learning_rate_fn(step) + ``` + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + + Args: + boundaries: A list of `Tensor`s or `int`s or `float`s with strictly + increasing entries, and with all elements having the same type as the + optimizer step. + values: A list of `Tensor`s or `float`s or `int`s that specifies the + values for the intervals defined by `boundaries`. It should have one + more element than `boundaries`, and all elements should have the same + type. + name: A string. Optional name of the operation. Defaults to + 'PiecewiseConstant'. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as the boundary tensors. + + The output of the 1-arg function that takes the `step` + is `values[0]` when `step <= boundaries[0]`, + `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`, ..., + and values[-1] when `step > boundaries[-1]`. + + Raises: + ValueError: if types of all `values` do not match or + the number of elements in the lists does not match. + """ + super(PiecewiseConstantDecay, self).__init__() + + if len(boundaries) != len(values) - 1: + raise ValueError( + "The length of boundaries should be 1 less than the length of values") + + self.boundaries = boundaries + self.values = values + self.name = name + + def __call__(self, step): + with ops.name_scope(self.name, "PiecewiseConstant", + [step, self.boundaries, self.values, self.name]): + boundaries = ops.convert_n_to_tensor(self.boundaries) + values = ops.convert_n_to_tensor(self.values) + x_recomp = ops.convert_to_tensor(step) + # Avoid explicit conversion to x's dtype. This could result in faulty + # comparisons, for example if floats are converted to integers. + for i, b in enumerate(boundaries): + if b.dtype.base_dtype != x_recomp.dtype.base_dtype: + # We can promote int32 boundaries to int64 without loss of precision. + # This covers the most common case where the user passes in boundaries + # as an array of Python integers. + if (b.dtype.base_dtype == dtypes.int32 and + x_recomp.dtype.base_dtype == dtypes.int64): + b = math_ops.cast(b, x_recomp.dtype.base_dtype) + boundaries[i] = b + else: + raise ValueError( + "Boundaries (%s) must have the same dtype as x (%s)." % + (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) + # TODO(rdipietro): Ensure that boundaries' elements strictly increases. + for v in values[1:]: + if v.dtype.base_dtype != values[0].dtype.base_dtype: + raise ValueError( + "Values must have elements all with the same dtype (%s vs %s)." % + (values[0].dtype.base_dtype, v.dtype.base_dtype)) + pred_fn_pairs = [] + pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0])) + pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1])) + for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): + # Need to bind v here; can do this with lambda v=v: ... 
+ pred = (x_recomp > low) & (x_recomp <= high) + pred_fn_pairs.append((pred, lambda v=v: v)) + + # The default isn't needed here because our conditions are mutually + # exclusive and exhaustive, but tf.case requires it. + default = lambda: values[0] + return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + def get_config(self): + return { + "boundaries": self.boundaries, + "values": self.values, + "name": self.name + } + + +@keras_export("keras.optimizers.schedules.PolynomialDecay") +class PolynomialDecay(LearningRateSchedule): + """A LearningRateSchedule that uses a polynomial decay schedule.""" + + def __init__( + self, + initial_learning_rate, + decay_steps, + end_learning_rate=0.0001, + power=1.0, + cycle=False, + name=None): + """Applies a polynomial decay to the learning rate. + + It is commonly observed that a monotonically decreasing learning rate, whose + degree of change is carefully chosen, results in a better performing model. + This schedule applies a polynomial decay function to an optimizer step, + given a provided `initial_learning_rate`, to reach an `end_learning_rate` + in the given `decay_steps`. + + It requires a `step` value to compute the decayed learning rate. You + can just pass a TensorFlow variable that you increment at each training + step. + + The schedule is a 1-arg callable that produces a decayed learning rate + when passed the current optimizer step. This can be useful for changing the + learning rate value across different invocations of optimizer functions. + It is computed as: + + ```python + def decayed_learning_rate(step): + step = min(step, decay_steps) + return ((initial_learning_rate - end_learning_rate) * + (1 - step / decay_steps) ^ (power) + ) + end_learning_rate + ``` + + If `cycle` is True then a multiple of `decay_steps` is used, the first one + that is bigger than `step`. + + ```python + def decayed_learning_rate(step): + decay_steps = decay_steps * ceil(step / decay_steps) + return ((initial_learning_rate - end_learning_rate) * + (1 - step / decay_steps) ^ (power) + ) + end_learning_rate + ``` + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. + Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using + sqrt (i.e. power=0.5): + + ```python + ... + starter_learning_rate = 0.1 + end_learning_rate = 0.01 + decay_steps = 10000 + learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay( + starter_learning_rate, + decay_steps, + end_learning_rate, + power=0.5) + + model.compile(optimizer=tf.keras.optimizers.SGD( + learning_rate=learning_rate_fn), + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(data, labels, epochs=5) + ``` + + The learning rate schedule is also serializable and deserializable using + `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Must be positive. See the decay computation above. + end_learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The minimal end learning rate. + power: A scalar `float32` or `float64` `Tensor` or a + Python number. The power of the polynomial. Defaults to linear, 1.0. + cycle: A boolean, whether or not it should cycle beyond decay_steps. + name: String. Optional name of the operation. Defaults to + 'PolynomialDecay'. 
+ + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as `initial_learning_rate`. + """ + super(PolynomialDecay, self).__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.end_learning_rate = end_learning_rate + self.power = power + self.cycle = cycle + self.name = name + + def __call__(self, step): + with ops.name_scope( + self.name, "PolynomialDecay", + [self.initial_learning_rate, step, self.decay_steps, + self.end_learning_rate, self.power] + ) as name: + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate, name="initial_learning_rate") + dtype = initial_learning_rate.dtype + end_learning_rate = math_ops.cast(self.end_learning_rate, dtype) + power = math_ops.cast(self.power, dtype) + + global_step_recomp = math_ops.cast(step, dtype) + decay_steps_recomp = math_ops.cast(self.decay_steps, dtype) + if self.cycle: + # Find the first multiple of decay_steps that is bigger than + # global_step. If global_step is zero set the multiplier to 1 + multiplier = control_flow_ops.cond( + math_ops.equal(global_step_recomp, 0), lambda: 1.0, + lambda: math_ops.ceil(global_step_recomp / self.decay_steps)) + decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier) + else: + # Make sure that the global_step used is not bigger than decay_steps. + global_step_recomp = math_ops.minimum(global_step_recomp, + self.decay_steps) + + p = math_ops.div(global_step_recomp, decay_steps_recomp) + return math_ops.add( + math_ops.multiply(initial_learning_rate - end_learning_rate, + math_ops.pow(1 - p, power)), + end_learning_rate, + name=name) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "end_learning_rate": self.end_learning_rate, + "power": self.power, + "cycle": self.cycle, + "name": self.name + } + + +@keras_export("keras.optimizers.schedules.InverseTimeDecay") +class InverseTimeDecay(LearningRateSchedule): + """A LearningRateSchedule that uses an inverse time decay schedule.""" + + def __init__( + self, + initial_learning_rate, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies inverse time decay to the initial learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This schedule applies the inverse decay function + to an optimizer step, given a provided initial learning rate. + It requires a `step` value to compute the decayed learning rate. You can + just pass a TensorFlow variable that you increment at each training step. + + The schedule a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + It is computed as: + + ```python + def decayed_learning_rate(step): + return initial_learning_rate / (1 + decay_rate * step / decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + def decayed_learning_rate(step): + return initial_learning_rate / (1 + decay_rate * floor(step / decay_step)) + ``` + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. + Example: Fit a Keras model when decaying 1/t with a rate of 0.5: + + ```python + ... 
+ initial_learning_rate = 0.1 + decay_steps = 1.0 + decay_rate = 0.5 + learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay( + initial_learning_rate, global_step, decay_steps, decay_rate) + + model.compile(optimizer=tf.keras.optimizers.SGD( + learning_rate=learning_rate_fn), + loss='sparse_categorical_crossentropy', + metrics=['accuracy']) + + model.fit(data, labels, epochs=5) + ``` + + Args: + initial_learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + decay_steps: How often to apply decay. + decay_rate: A Python number. The decay rate. + staircase: Whether to apply decay in a discrete staircase, as opposed to + continuous, fashion. + name: String. Optional name of the operation. Defaults to + 'InverseTimeDecay'. + + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as `initial_learning_rate`. + """ + super(InverseTimeDecay, self).__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.decay_rate = decay_rate + self.staircase = staircase + self.name = name + + def __call__(self, step): + with ops.name_scope(self.name, "InverseTimeDecay", + [self.initial_learning_rate, step, self.decay_rate] + ) as name: + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate, name="initial_learning_rate") + dtype = initial_learning_rate.dtype + decay_steps = math_ops.cast(self.decay_steps, dtype) + decay_rate = math_ops.cast(self.decay_rate, dtype) + + global_step_recomp = math_ops.cast(step, dtype) + p = global_step_recomp / decay_steps + if self.staircase: + p = math_ops.floor(p) + const = math_ops.cast(constant_op.constant(1), dtype) + denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) + return math_ops.div(initial_learning_rate, denom, name=name) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "decay_rate": self.decay_rate, + "staircase": self.staircase, + "name": self.name + } + + +@keras_export("keras.experimental.CosineDecay") +class CosineDecay(LearningRateSchedule): + """A LearningRateSchedule that uses a cosine decay schedule.""" + + def __init__( + self, + initial_learning_rate, + decay_steps, + alpha=0.0, + name=None): + """Applies cosine decay to the learning rate. + + See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This schedule applies a cosine decay function + to an optimizer step, given a provided initial learning rate. + It requires a `step` value to compute the decayed learning rate. You can + just pass a TensorFlow variable that you increment at each training step. + + The schedule a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. 
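A pure-Python rendering of the inverse time decay formula above, with illustrative values matching the 1/t example (decay rate 0.5, decay_steps of 1.0):

```python
import math

initial_learning_rate = 0.1
decay_steps = 1.0
decay_rate = 0.5


def decayed_learning_rate(step, staircase=False):
  # Mirrors InverseTimeDecay: lr / (1 + decay_rate * step / decay_steps),
  # flooring step / decay_steps first when staircase=True.
  p = step / decay_steps
  if staircase:
    p = math.floor(p)
  return initial_learning_rate / (1 + decay_rate * p)


for step in range(4):
  print(step, decayed_learning_rate(step))  # 0.1, ~0.0667, 0.05, 0.04
```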
+ It is computed as: + + ```python + def decayed_learning_rate(step): + step = min(step, decay_steps) + cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps)) + decayed = (1 - alpha) * cosine_decay + alpha + return initial_learning_rate * decayed + ``` + + Example usage: + ```python + decay_steps = 1000 + lr_decayed_fn = tf.keras.experimental.CosineDecay( + initial_learning_rate, global_step, decay_steps) + ``` + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A scalar `float32` or `float64` Tensor or a + Python number. The initial learning rate. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + alpha: A scalar `float32` or `float64` Tensor or a Python number. + Minimum learning rate value as a fraction of initial_learning_rate. + name: String. Optional name of the operation. Defaults to 'CosineDecay'. + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as `initial_learning_rate`. + """ + super(CosineDecay, self).__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.alpha = alpha + self.name = name + + def __call__(self, step): + with ops.name_scope(self.name, "CosineDecay", + [self.initial_learning_rate, step]): + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate, name="initial_learning_rate") + dtype = initial_learning_rate.dtype + decay_steps = math_ops.cast(self.decay_steps, dtype) + + global_step_recomp = math_ops.cast(step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + completed_fraction = global_step_recomp / decay_steps + cosine_decayed = 0.5 * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + + decayed = (1 - self.alpha) * cosine_decayed + self.alpha + return math_ops.multiply(initial_learning_rate, decayed) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "alpha": self.alpha, + "name": self.name + } + + +@keras_export("keras.experimental.CosineDecayRestarts", + v1=[]) +class CosineDecayRestarts(LearningRateSchedule): + """A LearningRateSchedule that uses a cosine decay schedule with restarts.""" + + def __init__( + self, + initial_learning_rate, + first_decay_steps, + t_mul=2.0, + m_mul=1.0, + alpha=0.0, + name=None): + """Applies cosine decay with restarts to the learning rate. + + See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This schedule applies a cosine decay function with + restarts to an optimizer step, given a provided initial learning rate. + It requires a `step` value to compute the decayed learning rate. You can + just pass a TensorFlow variable that you increment at each training step. + + The schedule a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. 
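A pure-Python check of the cosine decay formula above, with illustrative values (not taken from this patch):

```python
import math

initial_learning_rate = 0.1
decay_steps = 1000
alpha = 0.0


def decayed_learning_rate(step):
  # Mirrors CosineDecay: cosine ramp from the initial rate down to
  # alpha * initial_learning_rate over decay_steps, constant afterwards.
  completed = min(step, decay_steps) / float(decay_steps)
  cosine_decay = 0.5 * (1 + math.cos(math.pi * completed))
  decayed = (1 - alpha) * cosine_decay + alpha
  return initial_learning_rate * decayed


print(decayed_learning_rate(0))     # 0.1
print(decayed_learning_rate(500))   # ~0.05, halfway through the cosine ramp
print(decayed_learning_rate(2000))  # 0.0, clipped at decay_steps (alpha == 0)
```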
+ + The learning rate multiplier first decays + from 1 to `alpha` for `first_decay_steps` steps. Then, a warm + restart is performed. Each new warm restart runs for `t_mul` times more + steps and with `m_mul` times smaller initial learning rate. + + Example usage: + ```python + first_decay_steps = 1000 + lr_decayed_fn = ( + tf.keras.experimental.CosineDecayRestarts( + initial_learning_rate, + global_step, + first_decay_steps)) + ``` + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python + number. The initial learning rate. + first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python + number. Number of steps to decay over. + t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Used to derive the number of iterations in the i-th period + m_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Used to derive the initial learning rate of the i-th period: + alpha: A scalar `float32` or `float64` Tensor or a Python number. + Minimum learning rate value as a fraction of the initial_learning_rate. + name: String. Optional name of the operation. Defaults to 'SGDRDecay'. + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as `initial_learning_rate`. + Raises: + ValueError: if `global_step` is not supplied. + """ + super(CosineDecayRestarts, self).__init__() + + self.initial_learning_rate = initial_learning_rate + self.first_decay_steps = first_decay_steps + self._t_mul = t_mul + self._m_mul = m_mul + self.alpha = alpha + self.name = name + + def __call__(self, step): + with ops.name_scope(self.name, "SGDRDecay", + [self.initial_learning_rate, step] + ) as name: + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate, name="initial_learning_rate") + dtype = initial_learning_rate.dtype + first_decay_steps = math_ops.cast(self.first_decay_steps, dtype) + alpha = math_ops.cast(self.alpha, dtype) + t_mul = math_ops.cast(self._t_mul, dtype) + m_mul = math_ops.cast(self._m_mul, dtype) + + global_step_recomp = math_ops.cast(step, dtype) + completed_fraction = global_step_recomp / first_decay_steps + + def compute_step(completed_fraction, geometric=False): + """Helper for `cond` operation.""" + if geometric: + i_restart = math_ops.floor( + math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / + math_ops.log(t_mul)) + + sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) + completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart + + else: + i_restart = math_ops.floor(completed_fraction) + completed_fraction -= i_restart + + return i_restart, completed_fraction + + i_restart, completed_fraction = control_flow_ops.cond( + math_ops.equal(t_mul, 1.0), + lambda: compute_step(completed_fraction, geometric=False), + lambda: compute_step(completed_fraction, geometric=True)) + + m_fac = m_mul**i_restart + cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + decayed = (1 - alpha) * cosine_decayed + alpha + + return math_ops.multiply(initial_learning_rate, decayed, name=name) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + 
"first_decay_steps": self.first_decay_steps, + "t_mul": self._t_mul, + "m_mul": self._m_mul, + "alpha": self.alpha, + "name": self.name + } + + +@keras_export("keras.experimental.LinearCosineDecay", + v1=[]) +class LinearCosineDecay(LearningRateSchedule): + """A LearningRateSchedule that uses a linear cosine decay schedule.""" + + def __init__( + self, + initial_learning_rate, + decay_steps, + num_periods=0.5, + alpha=0.0, + beta=0.001, + name=None): + """Applies linear cosine decay to the learning rate. + + See [Bello et al., ICML2017] Neural Optimizer Search with RL. + https://arxiv.org/abs/1709.07417 + + For the idea of warm starts here controlled by `num_periods`, + see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + Note that linear cosine decay is more aggressive than cosine decay and + larger initial learning rates can typically be used. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This schedule applies a linear cosine decay + function to an optimizer step, given a provided initial learning rate. + It requires a `step` value to compute the decayed learning rate. You can + just pass a TensorFlow variable that you increment at each training step. + + The schedule a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. + It is computed as: + + ```python + def decayed_learning_rate(step): + step = min(step, decay_steps) + linear_decay = (decay_steps - step) / decay_steps + cosine_decay = 0.5 * ( + 1 + cos(pi * 2 * num_periods * step / decay_steps)) + decayed = (alpha + linear_decay) * cosine_decay + beta + return initial_learning_rate * decayed + ``` + + Example usage: + ```python + decay_steps = 1000 + lr_decayed_fn = ( + tf.keras.experimental.LinearCosineDecay( + initial_learning_rate, global_step, decay_steps)) + ``` + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python + number. The initial learning rate. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + num_periods: Number of periods in the cosine part of the decay. + See computation above. + alpha: See computation above. + beta: See computation above. + name: String. Optional name of the operation. Defaults to + 'LinearCosineDecay'. + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as `initial_learning_rate`. 
+ """ + super(LinearCosineDecay, self).__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.num_periods = num_periods + self.alpha = alpha + self.beta = beta + self.name = name + + def __call__(self, step): + with ops.name_scope(self.name, "LinearCosineDecay", + [self.initial_learning_rate, step]) as name: + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate, name="initial_learning_rate") + dtype = initial_learning_rate.dtype + decay_steps = math_ops.cast(self.decay_steps, dtype) + num_periods = math_ops.cast(self.num_periods, dtype) + alpha = math_ops.cast(self.alpha, dtype) + beta = math_ops.cast(self.beta, dtype) + + global_step_recomp = math_ops.cast(step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + + linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta + return math_ops.multiply(initial_learning_rate, linear_cosine_decayed, + name=name) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + "decay_steps": self.decay_steps, + "num_periods": self.num_periods, + "alpha": self.alpha, + "beta": self.beta, + "name": self.name + } + + +@keras_export("keras.experimental.NoisyLinearCosineDecay", + v1=[]) +class NoisyLinearCosineDecay(LearningRateSchedule): + """A LearningRateSchedule that uses a noisy linear cosine decay schedule.""" + + def __init__( + self, + initial_learning_rate, + decay_steps, + initial_variance=1.0, + variance_decay=0.55, + num_periods=0.5, + alpha=0.0, + beta=0.001, + name=None): + """Applies noisy linear cosine decay to the learning rate. + + See [Bello et al., ICML2017] Neural Optimizer Search with RL. + https://arxiv.org/abs/1709.07417 + + For the idea of warm starts here controlled by `num_periods`, + see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + Note that linear cosine decay is more aggressive than cosine decay and + larger initial learning rates can typically be used. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This schedule applies a noisy linear cosine decay + function to an optimizer step, given a provided initial learning rate. + It requires a `step` value to compute the decayed learning rate. You can + just pass a TensorFlow variable that you increment at each training step. + + The schedule a 1-arg callable that produces a decayed learning + rate when passed the current optimizer step. This can be useful for changing + the learning rate value across different invocations of optimizer functions. 
+ It is computed as: + + ```python + def decayed_learning_rate(step): + step = min(step, decay_steps) + linear_decay = (decay_steps - step) / decay_steps) + cosine_decay = 0.5 * ( + 1 + cos(pi * 2 * num_periods * step / decay_steps)) + decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta + return initial_learning_rate * decayed + ``` + where eps_t is 0-centered gaussian noise with variance + initial_variance / (1 + global_step) ** variance_decay + + Example usage: + ```python + decay_steps = 1000 + lr_decayed_fn = ( + tf.keras.experimental.NoisyLinearCosineDecay( + initial_learning_rate, global_step, decay_steps)) + ``` + + You can pass this schedule directly into a `tf.keras.optimizers.Optimizer` + as the learning rate. The learning rate schedule is also serializable and + deserializable using `tf.keras.optimizers.schedules.serialize` and + `tf.keras.optimizers.schedules.deserialize`. + + Args: + initial_learning_rate: A scalar `float32` or `float64` Tensor or a Python + number. The initial learning rate. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + initial_variance: initial variance for the noise. See computation above. + variance_decay: decay for the noise's variance. See computation above. + num_periods: Number of periods in the cosine part of the decay. + See computation above. + alpha: See computation above. + beta: See computation above. + name: String. Optional name of the operation. Defaults to + 'NoisyLinearCosineDecay'. + Returns: + A 1-arg callable learning rate schedule that takes the current optimizer + step and outputs the decayed learning rate, a scalar `Tensor` of the same + type as `initial_learning_rate`. + """ + super(NoisyLinearCosineDecay, self).__init__() + + self.initial_learning_rate = initial_learning_rate + self.decay_steps = decay_steps + self.initial_variance = initial_variance + self.variance_decay = variance_decay + self.num_periods = num_periods + self.alpha = alpha + self.beta = beta + self.name = name + + def __call__(self, step): + with ops.name_scope(self.name, "NoisyLinearCosineDecay", + [self.initial_learning_rate, step]) as name: + initial_learning_rate = ops.convert_to_tensor( + self.initial_learning_rate, name="initial_learning_rate") + dtype = initial_learning_rate.dtype + decay_steps = math_ops.cast(self.decay_steps, dtype) + initial_variance = math_ops.cast(self.initial_variance, dtype) + variance_decay = math_ops.cast(self.variance_decay, dtype) + num_periods = math_ops.cast(self.num_periods, dtype) + alpha = math_ops.cast(self.alpha, dtype) + beta = math_ops.cast(self.beta, dtype) + + global_step_recomp = math_ops.cast(step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + variance = initial_variance / ( + math_ops.pow(1.0 + global_step_recomp, variance_decay)) + std = math_ops.sqrt(variance) + noisy_linear_decayed = ( + linear_decayed + random_ops.random_normal( + linear_decayed.shape, stddev=std)) + + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + noisy_linear_cosine_decayed = ( + (alpha + noisy_linear_decayed) * cosine_decayed + beta) + + return math_ops.multiply( + initial_learning_rate, noisy_linear_cosine_decayed, name=name) + + def get_config(self): + return { + "initial_learning_rate": self.initial_learning_rate, + 
"decay_steps": self.decay_steps, + "initial_variance": self.initial_variance, + "variance_decay": self.variance_decay, + "num_periods": self.num_periods, + "alpha": self.alpha, + "beta": self.beta, + "name": self.name + } + + +@keras_export("keras.optimizers.schedules.serialize") +def serialize(learning_rate_schedule): + return generic_utils.serialize_keras_object(learning_rate_schedule) + + +@keras_export("keras.optimizers.schedules.deserialize") +def deserialize(config, custom_objects=None): + return generic_utils.deserialize_keras_object( + config, + module_objects=globals(), + custom_objects=custom_objects, + printable_module_name="decay") diff --git a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule_test.py b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule_test.py new file mode 100644 index 00000000000..87b97fa76ca --- /dev/null +++ b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule_test.py @@ -0,0 +1,527 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional test for learning rate decay.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +from absl.testing import parameterized + +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule +# Import resource_variable_ops for the variables-to-tensor implicit conversion. 
+from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +def _maybe_serialized(lr_decay, serialize_and_deserialize): + if serialize_and_deserialize: + serialized = learning_rate_schedule.serialize(lr_decay) + return learning_rate_schedule.deserialize(serialized) + else: + return lr_decay + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class LRDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testContinuous(self, serialize): + self.evaluate(variables.global_variables_initializer()) + step = 5 + decayed_lr = learning_rate_schedule.ExponentialDecay(0.05, 10, 0.96) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = .05 * 0.96**(5.0 / 10.0) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testStaircase(self, serialize): + if context.executing_eagerly(): + step = resource_variable_ops.ResourceVariable(0) + self.evaluate(variables.global_variables_initializer()) + decayed_lr = learning_rate_schedule.ExponentialDecay( + .1, 3, 0.96, staircase=True) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + + # No change to learning rate due to staircase + expected = .1 + self.evaluate(step.assign(1)) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + expected = .1 + self.evaluate(step.assign(2)) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + # Decayed learning rate + expected = .1 * 0.96 ** (100 // 3) + self.evaluate(step.assign(100)) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_deprecated_v1 + def testVariables(self, serialize): + step = variables.Variable(1) + assign_1 = step.assign(1) + assign_2 = step.assign(2) + assign_100 = step.assign(100) + decayed_lr = learning_rate_schedule.ExponentialDecay( + .1, 3, 0.96, staircase=True) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + + self.evaluate(variables.global_variables_initializer()) + # No change to learning rate + self.evaluate(assign_1.op) + self.assertAllClose(self.evaluate(decayed_lr(step)), .1, 1e-6) + self.evaluate(assign_2.op) + self.assertAllClose(self.evaluate(decayed_lr(step)), .1, 1e-6) + # Decayed learning rate + self.evaluate(assign_100.op) + expected = .1 * 0.96**(100 // 3) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testPiecewiseConstant(self, serialize): + x = resource_variable_ops.ResourceVariable(-999) + decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( + [100, 110, 120], [1.0, 0.1, 0.01, 0.001]) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + + self.evaluate(variables.global_variables_initializer()) + + self.assertAllClose(self.evaluate(decayed_lr(x)), 1.0, 1e-6) + self.evaluate(x.assign(100)) + self.assertAllClose(self.evaluate(decayed_lr(x)), 1.0, 1e-6) + self.evaluate(x.assign(105)) + self.assertAllClose(self.evaluate(decayed_lr(x)), 0.1, 1e-6) + self.evaluate(x.assign(110)) + self.assertAllClose(self.evaluate(decayed_lr(x)), 0.1, 1e-6) + self.evaluate(x.assign(120)) + self.assertAllClose(self.evaluate(decayed_lr(x)), 0.01, 1e-6) + self.evaluate(x.assign(999)) + self.assertAllClose(self.evaluate(decayed_lr(x)), 0.001, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def 
testPiecewiseConstantEdgeCases(self, serialize): + x_int = resource_variable_ops.ResourceVariable( + 0, dtype=variables.dtypes.int32) + boundaries, values = [-1.0, 1.0], [1, 2, 3] + with self.assertRaises(ValueError): + decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( + boundaries, values) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + decayed_lr(x_int) + + x = resource_variable_ops.ResourceVariable(0.0) + boundaries, values = [-1.0, 1.0], [1.0, 2, 3] + with self.assertRaises(ValueError): + decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( + boundaries, values) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + decayed_lr(x) + + # Test casting boundaries from int32 to int64. + x_int64 = resource_variable_ops.ResourceVariable( + 0, dtype=variables.dtypes.int64) + boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] + decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( + boundaries, values) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.4, 1e-6) + self.evaluate(x_int64.assign(1)) + self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.4, 1e-6) + self.evaluate(x_int64.assign(2)) + self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.5, 1e-6) + self.evaluate(x_int64.assign(3)) + self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.6, 1e-6) + self.evaluate(x_int64.assign(4)) + self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.7, 1e-6) + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class LinearDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testHalfWay(self, serialize): + step = 5 + lr = 0.05 + end_lr = 0.0 + decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = lr * 0.5 + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testEnd(self, serialize): + step = 10 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testHalfWayWithEnd(self, serialize): + step = 5 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = (lr + end_lr) * 0.5 + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEnd(self, serialize): + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEndWithCycle(self, serialize): + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_schedule.PolynomialDecay( + lr, 10, end_lr, cycle=True) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = (lr - end_lr) * 0.25 + end_lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + +@parameterized.named_parameters( + ("NotSerialized", False), + 
("Serialized", True)) +class SqrtDecayTestV2(test_util.TensorFlowTestCase, + parameterized.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testHalfWay(self, serialize): + step = 5 + lr = 0.05 + end_lr = 0.0 + power = 0.5 + decayed_lr = learning_rate_schedule.PolynomialDecay( + lr, 10, end_lr, power=power) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = lr * 0.5**power + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testEnd(self, serialize): + step = 10 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_schedule.PolynomialDecay( + lr, 10, end_lr, power=power) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testHalfWayWithEnd(self, serialize): + step = 5 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_schedule.PolynomialDecay( + lr, 10, end_lr, power=power) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = (lr - end_lr) * 0.5**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEnd(self, serialize): + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_schedule.PolynomialDecay( + lr, 10, end_lr, power=power) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEndWithCycle(self, serialize): + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_schedule.PolynomialDecay( + lr, 10, end_lr, power=power, cycle=True) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = (lr - end_lr) * 0.25**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class PolynomialDecayTestV2(test_util.TensorFlowTestCase, + parameterized.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testBeginWithCycle(self, serialize): + lr = 0.001 + decay_steps = 10 + step = 0 + decayed_lr = learning_rate_schedule.PolynomialDecay( + lr, decay_steps, cycle=True) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = lr + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class InverseDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testDecay(self, serialize): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_schedule.InverseTimeDecay(initial_lr, k, + decay_rate) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + @test_util.run_in_graph_and_eager_modes + def testStaircase(self, serialize): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_schedule.InverseTimeDecay( + initial_lr, k, 
decay_rate, staircase=True) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class CosineDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase): + + def np_cosine_decay(self, step, decay_steps, alpha=0.0): + step = min(step, decay_steps) + completed_fraction = step / decay_steps + decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) + return (1.0 - alpha) * decay + alpha + + @test_util.run_in_graph_and_eager_modes + def testDecay(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.CosineDecay(initial_lr, + num_training_steps) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testAlpha(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + alpha = 0.1 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.CosineDecay(initial_lr, + num_training_steps, + alpha) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_cosine_decay(step, num_training_steps, alpha) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase, + parameterized.TestCase): + + def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0, + alpha=0.0): + fac = 1.0 + while step >= decay_steps: + step -= decay_steps + decay_steps *= t_mul + fac *= m_mul + + completed_fraction = step / decay_steps + decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) + return (1.0 - alpha) * decay + alpha + + @test_util.run_in_graph_and_eager_modes + def testDecay(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.CosineDecayRestarts( + initial_lr, num_training_steps) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_cosine_decay_restarts(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testAlpha(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + alpha = 0.1 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.CosineDecayRestarts( + initial_lr, num_training_steps, alpha=alpha) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, alpha=alpha) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testMMul(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + m_mul = 0.9 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.CosineDecayRestarts( + initial_lr, num_training_steps, m_mul=m_mul) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, m_mul=m_mul) + 
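To make the `np_cosine_decay_restarts` reference above concrete (an illustrative trace only, not part of the test): with the defaults `t_mul=2.0` and `m_mul=1.0`, a step of 1500 against a first period of 1000 steps lands 500 steps into a second period that is 2000 steps long, so the cosine is evaluated at a completed fraction of 0.25.

```python
import math

# Trace np_cosine_decay_restarts for step=1500, decay_steps=1000 (defaults).
step, decay_steps, t_mul, m_mul, alpha = 1500, 1000, 2.0, 1.0, 0.0
fac = 1.0
while step >= decay_steps:
  step -= decay_steps    # 1500 -> 500
  decay_steps *= t_mul   # 1000 -> 2000
  fac *= m_mul           # stays 1.0
completed_fraction = step / decay_steps  # 500 / 2000 = 0.25
decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
print((1.0 - alpha) * decay + alpha)     # ~0.854
```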
self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testTMul(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + t_mul = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.CosineDecayRestarts( + initial_lr, num_training_steps, t_mul=t_mul) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, t_mul=t_mul) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class LinearCosineDecayTestV2(test_util.TensorFlowTestCase, + parameterized.TestCase): + + def np_linear_cosine_decay(self, + step, + decay_steps, + alpha=0.0, + beta=0.001, + num_periods=0.5): + step = min(step, decay_steps) + linear_decayed = float(decay_steps - step) / decay_steps + fraction = 2.0 * num_periods * step / float(decay_steps) + cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction)) + return (alpha + linear_decayed) * cosine_decayed + beta + + @test_util.run_in_graph_and_eager_modes + def testDefaultDecay(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.LinearCosineDecay( + initial_lr, num_training_steps) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_linear_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testNonDefaultDecay(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_schedule.LinearCosineDecay( + initial_lr, + num_training_steps, + alpha=0.1, + beta=1e-4, + num_periods=5) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + expected = self.np_linear_cosine_decay( + step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5) + self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6) + + +@parameterized.named_parameters( + ("NotSerialized", False), + ("Serialized", True)) +class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase, + parameterized.TestCase): + + @test_util.run_in_graph_and_eager_modes + def testDefaultNoisyLinearCosine(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + # No numerical check because of noise + decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay( + initial_lr, num_training_steps) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + # Cannot be deterministically tested + self.evaluate(decayed_lr(step)) + + @test_util.run_in_graph_and_eager_modes + def testNonDefaultNoisyLinearCosine(self, serialize): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + # No numerical check because of noise + decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay( + initial_lr, + num_training_steps, + initial_variance=0.5, + variance_decay=0.1, + alpha=0.1, + beta=1e-4, + num_periods=5) + decayed_lr = _maybe_serialized(decayed_lr, serialize) + # Cannot be deterministically tested + self.evaluate(decayed_lr(step)) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/keras/optimizer_v2/nadam.py b/tensorflow/python/keras/optimizer_v2/nadam.py index d515f987251..77a897124be 100644 --- a/tensorflow/python/keras/optimizer_v2/nadam.py +++ 
b/tensorflow/python/keras/optimizer_v2/nadam.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.keras import backend_config +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -86,6 +87,12 @@ class Nadam(optimizer_v2.OptimizerV2): # Backwards compatiblity with keras NAdam optimizer. kwargs['decay'] = kwargs.pop('schedule_decay', 0.004) + learning_rate = kwargs.get('lr', learning_rate) + if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule): + raise ValueError('The Nadam optimizer does not support ' + 'tf.keras.optimizers.LearningRateSchedules as the ' + 'learning rate.') + if epsilon is None: epsilon = backend_config.epsilon() super(Nadam, self).__init__(name, **kwargs) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 9a62d168550..bf6dcaad12c 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import ops from tensorflow.python.keras import backend from tensorflow.python.keras import initializers from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import gradients @@ -452,8 +453,11 @@ class OptimizerV2(checkpointable.Checkpointable): self._hyper[name] = value else: prev_value = self._hyper[name] - if callable(prev_value) or isinstance(prev_value, - (ops.Tensor, int, float)): + if (callable(prev_value) + or isinstance(prev_value, + (ops.Tensor, int, float, + learning_rate_schedule.LearningRateSchedule)) + or isinstance(value, learning_rate_schedule.LearningRateSchedule)): self._hyper[name] = value else: backend.set_value(self._hyper[name], value) @@ -462,6 +466,8 @@ class OptimizerV2(checkpointable.Checkpointable): if not self._hypers_created: self._create_hypers() value = self._hyper[name] + if isinstance(value, learning_rate_schedule.LearningRateSchedule): + return value if callable(value): value = value() if dtype: @@ -575,6 +581,9 @@ class OptimizerV2(checkpointable.Checkpointable): def _decayed_lr(self, var_dtype): """Get decayed learning rate as a Tensor with dtype=var_dtype.""" lr_t = self._get_hyper("learning_rate", var_dtype) + if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule): + local_step = math_ops.cast(self.iterations, var_dtype) + lr_t = math_ops.cast(lr_t(local_step), var_dtype) if self._initial_decay > 0.: local_step = math_ops.cast(self.iterations, var_dtype) decay_t = self._get_hyper("decay", var_dtype) @@ -619,11 +628,17 @@ class OptimizerV2(checkpointable.Checkpointable): """ if "lr" in config: config["learning_rate"] = config.pop("lr") + if "learning_rate" in config: + if isinstance(config["learning_rate"], dict): + config["learning_rate"] = learning_rate_schedule.deserialize( + config["learning_rate"]) return cls(**config) def _serialize_hyperparameter(self, hyperparameter_name): """Serialize a hyperparameter that can be a float, callable, or Tensor.""" value = self._hyper[hyperparameter_name] + if isinstance(value, learning_rate_schedule.LearningRateSchedule): + return 
learning_rate_schedule.serialize(value) if callable(value): return value() if isinstance(value, (ops.Tensor, tf_variables.Variable, diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py index 57bd9439edf..2d4c1827167 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py @@ -41,6 +41,7 @@ from tensorflow.python.keras.optimizer_v2 import adagrad from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.keras.optimizer_v2 import adamax from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.keras.optimizer_v2 import nadam from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.optimizer_v2 import rmsprop @@ -113,6 +114,13 @@ class OptimizerTest(test.TestCase): # var1 = [0., 1.] - 0.5 * [3, 3] self.assertAllClose([-1.5, -0.5], self.evaluate(var1)) + sgd.learning_rate = learning_rate_schedule.InverseTimeDecay( + 0.5, decay_steps=1.0, decay_rate=0.5) + if context.executing_eagerly(): + sgd.minimize(loss, [var0, var1]) + else: + self.evaluate(opt_op) + @test_util.run_in_graph_and_eager_modes def testPrecomputedGradient(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: @@ -281,6 +289,33 @@ class OptimizerTest(test.TestCase): self.evaluate(variables.global_variables_initializer()) self.assertEqual(self.evaluate(lr), self.evaluate(lr3)) + @test_util.run_in_graph_and_eager_modes + def testConfigWithLearningRateDecay(self): + with self.cached_session(): + decay_schedule = learning_rate_schedule.InverseTimeDecay( + 0.5, decay_steps=1.0, decay_rate=0.1) + step = 10 + opt = gradient_descent.SGD(decay_schedule) + config = opt.get_config() + opt2 = gradient_descent.SGD.from_config(config) + # assert both are equal float values. + self.assertAllEqual( + decay_schedule(step), + opt._get_hyper('learning_rate')(step)) + self.assertAllEqual( + decay_schedule(step), + opt2._get_hyper('learning_rate')(step)) + var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32) + loss = lambda: 3 * var0 + # learning rate variable created when calling minimize. 
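The config test above is the user-facing payoff of this CL: an optimizer constructed with a schedule carries the serialized schedule inside its own config and can be rebuilt from it. A minimal sketch with the public API (assuming eager execution; the values mirror the test rather than additional TF surface):

```python
import tensorflow as tf

# An SGD optimizer whose learning rate is a schedule rather than a float.
schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.5, decay_steps=1.0, decay_rate=0.1)
opt = tf.keras.optimizers.SGD(learning_rate=schedule)

# get_config() embeds the serialized schedule; from_config() restores it.
restored = tf.keras.optimizers.SGD.from_config(opt.get_config())

# At step 10 the decayed rate is 0.5 / (1 + 0.1 * 10 / 1.0) = 0.25.
print(float(schedule(10)))
```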
+ opt.minimize(loss, [var0]) + self.evaluate(variables.global_variables_initializer()) + config = opt.get_config() + opt3 = gradient_descent.SGD.from_config(config) + self.assertAllEqual( + self.evaluate(opt._get_hyper('learning_rate')(step)), + opt3._get_hyper('learning_rate')(step)) + @test_util.run_in_graph_and_eager_modes def testGradClipValue(self): with self.cached_session(): diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py index a9ddc2155a6..ab805266762 100644 --- a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py +++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule from tensorflow.python.keras.optimizer_v2 import rmsprop from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops @@ -244,6 +245,78 @@ class RMSpropOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + @test_util.run_deprecated_v1 + def testDenseWithLearningRateInverseTimeDecay(self): + var0_np = np.array([1.0, 2.0]) + grads0_np = np.array([0.1, 0.2]) + var1_np = np.array([3.0, 4.0]) + grads1_np = np.array([0.01, 0.2]) + + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + learning_rate = 0.01 + rho = 0.9 + momentum = 0.0 + epsilon = 1e-7 + centered = False + decay = 0.5 + lr_schedule = learning_rate_schedule.InverseTimeDecay( + learning_rate, decay_steps=1.0, decay_rate=decay) + opt = rmsprop.RMSprop( + learning_rate=lr_schedule, + rho=rho, + momentum=momentum, + epsilon=epsilon, + centered=centered) + + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + self.evaluate(variables.global_variables_initializer()) + + rms0 = opt.get_slot(var0, "rms") + self.assertTrue(rms0 is not None) + rms1 = opt.get_slot(var1, "rms") + self.assertTrue(rms1 is not None) + if momentum > 0.: + mom0 = opt.get_slot(var0, "momentum") + mom1 = opt.get_slot(var1, "momentum") + else: + mom0 = None + mom1 = None + + mg0_np = np.array([0.0, 0.0]) + mg1_np = np.array([0.0, 0.0]) + rms0_np = np.array([0.0, 0.0]) + rms1_np = np.array([0.0, 0.0]) + mom0_np = np.array([0.0, 0.0]) + mom1_np = np.array([0.0, 0.0]) + + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 4 steps of RMSprop + for t in range(2): + self.evaluate(update) + + lr = learning_rate / (1 + decay * t) + var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy( + var0_np, grads0_np, mg0_np, rms0_np, mom0_np, lr, rho, momentum, + epsilon, centered) + var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy( + var1_np, grads1_np, mg1_np, rms1_np, mom1_np, lr, rho, momentum, + epsilon, centered) + + # Validate updated params + self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0)) + self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1)) + if momentum > 0.: + self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0)) + self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1)) + 
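A note on the numpy reference used in the RMSprop test above (illustrative): because the test passes `decay_steps=1.0`, the schedule evaluated at iteration `t` reduces to `learning_rate / (1 + decay_rate * t)`, which is exactly the `lr` the loop recomputes before each `_rmsprop_update_numpy` call. A quick check, assuming eager execution:

```python
import tensorflow as tf

# Same configuration as the test: learning_rate=0.01, decay=0.5, decay_steps=1.0.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    0.01, decay_steps=1.0, decay_rate=0.5)
for t in range(2):
  print(t, float(lr_schedule(t)))  # 0 -> 0.01, 1 -> ~0.0066667
```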
self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + @test_util.run_deprecated_v1 def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.float32, dtypes.float64]: diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index d18776eab70..245380287a2 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -84,6 +84,7 @@ KERAS_API_INIT_FILES = [ "keras/metrics/__init__.py", "keras/models/__init__.py", "keras/optimizers/__init__.py", + "keras/optimizers/schedules/__init__.py", "keras/preprocessing/__init__.py", "keras/preprocessing/image/__init__.py", "keras/preprocessing/sequence/__init__.py", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index 5213d5eed85..c8c7c51b406 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -107,6 +107,7 @@ KERAS_API_INIT_FILES_V1 = [ "keras/metrics/__init__.py", "keras/models/__init__.py", "keras/optimizers/__init__.py", + "keras/optimizers/schedules/__init__.py", "keras/preprocessing/__init__.py", "keras/preprocessing/image/__init__.py", "keras/preprocessing/sequence/__init__.py", diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index c52e89db1f4..ab9d923bedc 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -17,8 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from tensorflow.python.eager import context -from tensorflow.python.training import learning_rate_decay_v2 +from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule +from tensorflow.python.ops import math_ops from tensorflow.python.util.tf_export import tf_export @@ -88,15 +91,15 @@ def exponential_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - decayed_lr = learning_rate_decay_v2.exponential_decay(learning_rate, - global_step, - decay_steps, - decay_rate, - staircase=staircase, - name=name) + decayed_lr = learning_rate_schedule.ExponentialDecay(learning_rate, + decay_steps, + decay_rate, + staircase=staircase, + name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr @@ -143,11 +146,12 @@ def piecewise_constant(x, boundaries, values, name=None): the learning rate value across different invocations of optimizer functions. @end_compatibility """ - decayed_lr = learning_rate_decay_v2.piecewise_constant(x, boundaries, values, - name=name) + decayed_lr = learning_rate_schedule.PiecewiseConstantDecay( + boundaries, values, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(x) + else: + decayed_lr = functools.partial(decayed_lr, x) return decayed_lr @@ -236,9 +240,8 @@ def polynomial_decay(learning_rate, the learning rate value across different invocations of optimizer functions. 
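Each rewritten v1 wrapper above follows the same delegation pattern: build a v2 schedule object, then either evaluate it on `global_step` right away (graph mode) or hand back a callable that re-evaluates on every invocation (eager mode). Schematically, as a sketch of the pattern rather than the literal TF source:

```python
import functools


def _delegate_to_schedule(schedule, global_step, executing_eagerly):
  """Mimics how the v1 decay wrappers adapt a v2 LearningRateSchedule."""
  if not executing_eagerly:
    # Graph mode: evaluate once; the returned tensor tracks global_step.
    return schedule(global_step)
  # Eager mode: return a no-arg callable that re-reads global_step per call.
  return functools.partial(schedule, global_step)
```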
@end_compatibility """ - decayed_lr = learning_rate_decay_v2.polynomial_decay( + decayed_lr = learning_rate_schedule.PolynomialDecay( learning_rate, - global_step, decay_steps, end_learning_rate=end_learning_rate, power=power, @@ -246,8 +249,9 @@ def polynomial_decay(learning_rate, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr @@ -323,13 +327,15 @@ def natural_exp_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - decayed_lr = learning_rate_decay_v2.natural_exp_decay( - learning_rate, global_step, decay_steps, decay_rate, staircase=staircase, + natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate)) + decayed_lr = learning_rate_schedule.ExponentialDecay( + learning_rate, decay_steps, natural_exp_rate, staircase=staircase, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr @@ -405,17 +411,17 @@ def inverse_time_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - decayed_lr = learning_rate_decay_v2.inverse_time_decay( + decayed_lr = learning_rate_schedule.InverseTimeDecay( learning_rate, - global_step, decay_steps, decay_rate, staircase=staircase, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr @@ -468,12 +474,13 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): the learning rate value across different invocations of optimizer functions. @end_compatibility """ - decayed_lr = learning_rate_decay_v2.cosine_decay( - learning_rate, global_step, decay_steps, alpha=alpha, name=name) + decayed_lr = learning_rate_schedule.CosineDecay( + learning_rate, decay_steps, alpha=alpha, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr @@ -535,9 +542,8 @@ def cosine_decay_restarts(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( + decayed_lr = learning_rate_schedule.CosineDecayRestarts( learning_rate, - global_step, first_decay_steps, t_mul=t_mul, m_mul=m_mul, @@ -545,8 +551,9 @@ def cosine_decay_restarts(learning_rate, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr @@ -617,9 +624,8 @@ def linear_cosine_decay(learning_rate, the learning rate value across different invocations of optimizer functions. 
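One mapping above deserves a short justification: `natural_exp_decay` is re-expressed through `ExponentialDecay` by pre-computing `exp(-decay_rate)` as the decay base, since `lr * exp(-decay_rate) ** (step / decay_steps)` is algebraically the same as `lr * exp(-decay_rate * step / decay_steps)`. A quick numeric check with illustrative values:

```python
import math

lr, decay_rate, decay_steps, step = 0.1, 0.5, 5.0, 7.0
natural = lr * math.exp(-decay_rate * step / decay_steps)
via_exponential_decay = lr * math.exp(-decay_rate) ** (step / decay_steps)
assert abs(natural - via_exponential_decay) < 1e-12  # ~0.0497 either way
```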
@end_compatibility """ - decayed_lr = learning_rate_decay_v2.linear_cosine_decay( + decayed_lr = learning_rate_schedule.LinearCosineDecay( learning_rate, - global_step, decay_steps, num_periods=num_periods, alpha=alpha, @@ -627,8 +633,9 @@ def linear_cosine_decay(learning_rate, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr @@ -707,8 +714,8 @@ def noisy_linear_cosine_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay( - learning_rate, global_step, + decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay( + learning_rate, decay_steps, initial_variance=initial_variance, variance_decay=variance_decay, @@ -718,6 +725,7 @@ def noisy_linear_cosine_decay(learning_rate, name=name) if not context.executing_eagerly(): - decayed_lr = decayed_lr() - + decayed_lr = decayed_lr(global_step) + else: + decayed_lr = functools.partial(decayed_lr, global_step) return decayed_lr diff --git a/tensorflow/python/training/learning_rate_decay_v2.py b/tensorflow/python/training/learning_rate_decay_v2.py deleted file mode 100644 index eb69feb17d3..00000000000 --- a/tensorflow/python/training/learning_rate_decay_v2.py +++ /dev/null @@ -1,898 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Various learning rate decay functions.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import math - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.util.tf_export import tf_export - - -@tf_export("train.exponential_decay", v1=[]) -def exponential_decay(learning_rate, - global_step, - decay_steps, - decay_rate, - staircase=False, - name=None): - """Applies exponential decay to the learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies an exponential decay function - to a provided initial learning rate. It requires a `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns a no-arg function that produces the decayed learning - rate. This can be useful for changing the learning rate value across - different invocations of optimizer functions. 
- It is computed as: - - ```python - decayed_learning_rate = learning_rate * - decay_rate ^ (global_step / decay_steps) - ``` - - If the argument `staircase` is `True`, then `global_step / decay_steps` is an - integer division and the decayed learning rate follows a staircase function. - - Example: decay every 100000 steps with a base of 0.96: - - ```python - ... - global_step = tf.Variable(0, trainable=False) - starter_learning_rate = 0.1 - learning_rate_fn = tf.train.exponential_decay(starter_learning_rate, - global_step, 100000, 0.96, - staircase=True) - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.train.GradientDescentOptimizer(learning_rate_fn) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a - Python number. The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. - Global step to use for the decay computation. Must not be negative. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. - Must be positive. See the decay computation above. - decay_rate: A scalar `float32` or `float64` `Tensor` or a - Python number. The decay rate. - staircase: Boolean. If `True` decay the learning rate at discrete intervals - name: String. Optional name of the operation. Defaults to - 'ExponentialDecay'. - - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - - Raises: - ValueError: if `global_step` is not supplied. - """ - if global_step is None: - raise ValueError("global_step is required for exponential_decay.") - def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, - staircase, name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope( - name, "ExponentialDecay", - [learning_rate, global_step, decay_steps, decay_rate]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - decay_rate = math_ops.cast(decay_rate, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - p = global_step_recomp / decay_steps - if staircase: - p = math_ops.floor(p) - return math_ops.multiply( - learning_rate, math_ops.pow(decay_rate, p), name=name) - - return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, - decay_rate, staircase, name) - - -@tf_export("train.piecewise_constant_decay", v1=[]) -def piecewise_constant(x, boundaries, values, name=None): - """Piecewise constant from boundaries and interval values. - - This function returns a no-arg callable to compute the piecewise constant. - This can be useful for changing the learning rate value across - different invocations of optimizer functions. - - Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 - for the next 10000 steps, and 0.1 for any additional steps. - - ```python - global_step = tf.Variable(0, trainable=False) - boundaries = [100000, 110000] - values = [1.0, 0.5, 0.1] - learning_rate_fn = tf.train.piecewise_constant(global_step, boundaries, - values) - learning_rate = learning_rate_fn() - - # Later, whenever we perform an optimization step, we increment global_step. - ``` - - Args: - x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`, - `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`. 
- boundaries: A list of `Tensor`s or `int`s or `float`s with strictly - increasing entries, and with all elements having the same type as `x`. - values: A list of `Tensor`s or `float`s or `int`s that specifies the values - for the intervals defined by `boundaries`. It should have one more element - than `boundaries`, and all elements should have the same type. - name: A string. Optional name of the operation. Defaults to - 'PiecewiseConstant'. - - Returns: - A no-arg function that outputs a 0-D Tensor. The output of the no-arg - function is `values[0]` when `x <= boundaries[0]`, - `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ..., - and values[-1] when `x > boundaries[-1]`. - - Raises: - ValueError: if types of `x` and `boundaries` do not match, or types of all - `values` do not match or - the number of elements in the lists does not match. - """ - if len(boundaries) != len(values) - 1: - raise ValueError( - "The length of boundaries should be 1 less than the length of values") - def decayed_lr(x, boundaries, values, name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope(name, "PiecewiseConstant", - [x, boundaries, values, name]) as name: - boundaries = ops.convert_n_to_tensor(boundaries) - values = ops.convert_n_to_tensor(values) - x_recomp = ops.convert_to_tensor(x) - # Avoid explicit conversion to x's dtype. This could result in faulty - # comparisons, for example if floats are converted to integers. - for i, b in enumerate(boundaries): - if b.dtype.base_dtype != x_recomp.dtype.base_dtype: - # We can promote int32 boundaries to int64 without loss of precision. - # This covers the most common case where the user passes in boundaries - # as an array of Python integers. - if (b.dtype.base_dtype == dtypes.int32 and - x_recomp.dtype.base_dtype == dtypes.int64): - b = math_ops.cast(b, x_recomp.dtype.base_dtype) - boundaries[i] = b - else: - raise ValueError( - "Boundaries (%s) must have the same dtype as x (%s)." % - (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) - # TODO(rdipietro): Ensure that boundaries' elements strictly increases. - for v in values[1:]: - if v.dtype.base_dtype != values[0].dtype.base_dtype: - raise ValueError( - "Values must have elements all with the same dtype (%s vs %s)." % - (values[0].dtype.base_dtype, v.dtype.base_dtype)) - pred_fn_pairs = [] - pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0])) - pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1])) - for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): - # Need to bind v here; can do this with lambda v=v: ... - pred = (x_recomp > low) & (x_recomp <= high) - pred_fn_pairs.append((pred, lambda v=v: v)) - - # The default isn't needed here because our conditions are mutually - # exclusive and exhaustive, but tf.case requires it. - default = lambda: values[0] - return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) - - return functools.partial(decayed_lr, x, boundaries, values, name) - - -@tf_export("train.polynomial_decay", v1=[]) -def polynomial_decay(learning_rate, - global_step, - decay_steps, - end_learning_rate=0.0001, - power=1.0, - cycle=False, - name=None): - """Applies a polynomial decay to the learning rate. - - It is commonly observed that a monotonically decreasing learning rate, whose - degree of change is carefully chosen, results in a better performing model. 
- This function applies a polynomial decay function to a provided initial - `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`. - - It requires a `global_step` value to compute the decayed learning rate. You - can just pass a TensorFlow variable that you increment at each training step. - - The function returns a no-arg callable that outputs the decayed learning - rate. This can be useful for changing the learning rate value across - different invocations of optimizer functions. It is computed as: - - ```python - global_step = min(global_step, decay_steps) - decayed_learning_rate = (learning_rate - end_learning_rate) * - (1 - global_step / decay_steps) ^ (power) + - end_learning_rate - - ``` - - If `cycle` is True then a multiple of `decay_steps` is used, the first one - that is bigger than `global_steps`. - - ```python - decay_steps = decay_steps * ceil(global_step / decay_steps) - decayed_learning_rate_fn = (learning_rate - end_learning_rate) * - (1 - global_step / decay_steps) ^ (power) + - end_learning_rate - decayed_learning_rate = decayed_learning_rate_fn() - - ``` - - Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5): - - ```python - ... - global_step = tf.Variable(0, trainable=False) - starter_learning_rate = 0.1 - end_learning_rate = 0.01 - decay_steps = 10000 - learning_rate_fn = tf.train.polynomial_decay(starter_learning_rate, - global_step, decay_steps, - end_learning_rate, - power=0.5) - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.train.GradientDescentOptimizer(learning_rate_fn) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a - Python number. The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. - Global step to use for the decay computation. Must not be negative. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. - Must be positive. See the decay computation above. - end_learning_rate: A scalar `float32` or `float64` `Tensor` or a - Python number. The minimal end learning rate. - power: A scalar `float32` or `float64` `Tensor` or a - Python number. The power of the polynomial. Defaults to linear, 1.0. - cycle: A boolean, whether or not it should cycle beyond decay_steps. - name: String. Optional name of the operation. Defaults to - 'PolynomialDecay'. - - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - - Raises: - ValueError: if `global_step` is not supplied. - """ - if global_step is None: - raise ValueError("global_step is required for polynomial_decay.") - def decayed_lr(learning_rate, global_step, decay_steps, end_learning_rate, - power, cycle, name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope( - name, "PolynomialDecay", - [learning_rate, global_step, decay_steps, end_learning_rate, power] - ) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - end_learning_rate = math_ops.cast(end_learning_rate, dtype) - power = math_ops.cast(power, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - decay_steps_recomp = math_ops.cast(decay_steps, dtype) - if cycle: - # Find the first multiple of decay_steps that is bigger than - # global_step. 
If global_step is zero set the multiplier to 1 - multiplier = control_flow_ops.cond( - math_ops.equal(global_step_recomp, 0), lambda: 1.0, - lambda: math_ops.ceil(global_step_recomp / decay_steps)) - decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier) - else: - # Make sure that the global_step used is not bigger than decay_steps. - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - - p = math_ops.div(global_step_recomp, decay_steps_recomp) - return math_ops.add( - math_ops.multiply(learning_rate - end_learning_rate, - math_ops.pow(1 - p, power)), - end_learning_rate, - name=name) - - return functools.partial( - decayed_lr, learning_rate, global_step, decay_steps, end_learning_rate, - power, cycle, name) - - -@tf_export("train.natural_exp_decay", v1=[]) -def natural_exp_decay(learning_rate, - global_step, - decay_steps, - decay_rate, - staircase=False, - name=None): - """Applies natural exponential decay to the initial learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies an exponential decay function - to a provided initial learning rate. It requires an `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns a no-arg callable that produces the decayed learning - rate. This can be useful for changing the learning rate value across - different invocations of optimizer functions. It is computed as: - - ```python - decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / - decay_step) - ``` - - or, if `staircase` is `True`, as: - - ```python - decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step / - decay_step)) - ``` - - Example: decay exponentially with a base of 0.96: - - ```python - ... - global_step = tf.Variable(0, trainable=False) - learning_rate = 0.1 - decay_steps = 5 - k = 0.5 - learning_rate_fn = tf.train.natural_exp_decay(learning_rate, global_step, - decay_steps, k) - - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.train.GradientDescentOptimizer(learning_rate_fn) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a - Python number. The initial learning rate. - global_step: A Python number. - Global step to use for the decay computation. Must not be negative. - decay_steps: How often to apply decay. - decay_rate: A Python number. The decay rate. - staircase: Whether to apply decay in a discrete staircase, as opposed to - continuous, fashion. - name: String. Optional name of the operation. Defaults to - 'ExponentialTimeDecay'. - - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - - Raises: - ValueError: if `global_step` is not supplied. 
- """ - if global_step is None: - raise ValueError("global_step is required for natural_exp_decay.") - def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase, - name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope(name, "NaturalExpDecay", - [learning_rate, global_step, decay_rate]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - decay_rate = math_ops.cast(decay_rate, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - p = global_step_recomp / decay_steps - if staircase: - p = math_ops.floor(p) - exponent = math_ops.exp( - math_ops.multiply(math_ops.negative(decay_rate), p)) - return math_ops.multiply(learning_rate, exponent, name=name) - - return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, - decay_rate, staircase, name) - - -@tf_export("train.inverse_time_decay", v1=[]) -def inverse_time_decay(learning_rate, - global_step, - decay_steps, - decay_rate, - staircase=False, - name=None): - """Applies inverse time decay to the initial learning rate. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies an inverse decay function - to a provided initial learning rate. It requires an `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns a no-arg callable that produces the decayed learning - rate. This can be useful for changing the learning rate value across - different invocations of optimizer functions. It is computed as: - - ```python - decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / - decay_step) - ``` - - or, if `staircase` is `True`, as: - - ```python - decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / - decay_step)) - ``` - - Example: decay 1/t with a rate of 0.5: - - ```python - ... - global_step = tf.Variable(0, trainable=False) - learning_rate = 0.1 - decay_steps = 1.0 - decay_rate = 0.5 - learning_rate_fn = tf.train.inverse_time_decay(learning_rate, global_step, - decay_steps, decay_rate) - - # Passing global_step to minimize() will increment it at each step. - learning_step = ( - tf.train.GradientDescentOptimizer(learning_rate_fn) - .minimize(...my loss..., global_step=global_step) - ) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` `Tensor` or a - Python number. The initial learning rate. - global_step: A Python number. - Global step to use for the decay computation. Must not be negative. - decay_steps: How often to apply decay. - decay_rate: A Python number. The decay rate. - staircase: Whether to apply decay in a discrete staircase, as opposed to - continuous, fashion. - name: String. Optional name of the operation. Defaults to - 'InverseTimeDecay'. - - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - - Raises: - ValueError: if `global_step` is not supplied. 
- """ - if global_step is None: - raise ValueError("global_step is required for inverse_time_decay.") - def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase, - name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope(name, "InverseTimeDecay", - [learning_rate, global_step, decay_rate]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - decay_rate = math_ops.cast(decay_rate, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - p = global_step_recomp / decay_steps - if staircase: - p = math_ops.floor(p) - const = math_ops.cast(constant_op.constant(1), dtype) - denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) - return math_ops.div(learning_rate, denom, name=name) - - return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, - decay_rate, staircase, name) - - -@tf_export("train.cosine_decay", v1=[]) -def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, - name=None): - """Applies cosine decay to the learning rate. - - See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent - with Warm Restarts. https://arxiv.org/abs/1608.03983 - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a cosine decay function - to a provided initial learning rate. It requires a `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns a no-arg callable that produces the decayed learning - rate. This can be useful for changing the learning rate value across - different invocations of optimizer functions. It is computed as: - - ```python - global_step = min(global_step, decay_steps) - cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps)) - decayed = (1 - alpha) * cosine_decay + alpha - decayed_learning_rate = learning_rate * decayed - ``` - - Example usage: - ```python - decay_steps = 1000 - lr_decayed_fn = tf.train.cosine_decay(learning_rate, global_step, decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. - Global step to use for the decay computation. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. - Number of steps to decay over. - alpha: A scalar `float32` or `float64` Tensor or a Python number. - Minimum learning rate value as a fraction of learning_rate. - name: String. Optional name of the operation. Defaults to 'CosineDecay'. - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - Raises: - ValueError: if `global_step` is not supplied. 
- """ - if global_step is None: - raise ValueError("cosine decay requires global_step") - def decayed_lr(learning_rate, global_step, decay_steps, alpha, name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope(name, "CosineDecay", - [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - completed_fraction = global_step_recomp / decay_steps - cosine_decayed = 0.5 * (1.0 + math_ops.cos( - constant_op.constant(math.pi) * completed_fraction)) - - decayed = (1 - alpha) * cosine_decayed + alpha - return math_ops.multiply(learning_rate, decayed) - - return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, - alpha, name) - - -@tf_export("train.cosine_decay_restarts", v1=[]) -def cosine_decay_restarts(learning_rate, - global_step, - first_decay_steps, - t_mul=2.0, - m_mul=1.0, - alpha=0.0, - name=None): - """Applies cosine decay with restarts to the learning rate. - - See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent - with Warm Restarts. https://arxiv.org/abs/1608.03983 - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a cosine decay function with - restarts to a provided initial learning rate. It requires a `global_step` - value to compute the decayed learning rate. You can just pass a TensorFlow - variable that you increment at each training step. - - The function returns a no-arg callable that produces the decayed learning - rate while taking into account possible warm restarts. This can be useful for - changing the learning rate value across different invocations of optimizer - functions. - - The learning rate multiplier first decays - from 1 to `alpha` for `first_decay_steps` steps. Then, a warm - restart is performed. Each new warm restart runs for `t_mul` times more steps - and with `m_mul` times smaller initial learning rate. - - Example usage: - ```python - first_decay_steps = 1000 - lr_decayed_fn = tf.train.cosine_decay_restarts(learning_rate, global_step, - first_decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. - Global step to use for the decay computation. - first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. - Number of steps to decay over. - t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. - Used to derive the number of iterations in the i-th period - m_mul: A scalar `float32` or `float64` `Tensor` or a Python number. - Used to derive the initial learning rate of the i-th period: - alpha: A scalar `float32` or `float64` Tensor or a Python number. - Minimum learning rate value as a fraction of the learning_rate. - name: String. Optional name of the operation. Defaults to 'SGDRDecay'. - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - - Raises: - ValueError: if `global_step` is not supplied. 
- """ - if global_step is None: - raise ValueError("cosine decay restarts requires global_step") - def decayed_lr(learning_rate, global_step, first_decay_steps, t_mul, m_mul, - alpha, name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step] - ) as name: - learning_rate = ops.convert_to_tensor( - learning_rate, name="initial_learning_rate") - dtype = learning_rate.dtype - first_decay_steps = math_ops.cast(first_decay_steps, dtype) - alpha = math_ops.cast(alpha, dtype) - t_mul = math_ops.cast(t_mul, dtype) - m_mul = math_ops.cast(m_mul, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - completed_fraction = global_step_recomp / first_decay_steps - - def compute_step(completed_fraction, geometric=False): - """Helper for `cond` operation.""" - if geometric: - i_restart = math_ops.floor( - math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / - math_ops.log(t_mul)) - - sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) - completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart - - else: - i_restart = math_ops.floor(completed_fraction) - completed_fraction -= i_restart - - return i_restart, completed_fraction - - i_restart, completed_fraction = control_flow_ops.cond( - math_ops.equal(t_mul, 1.0), - lambda: compute_step(completed_fraction, geometric=False), - lambda: compute_step(completed_fraction, geometric=True)) - - m_fac = m_mul**i_restart - cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( - constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha - - return math_ops.multiply(learning_rate, decayed, name=name) - - return functools.partial(decayed_lr, learning_rate, global_step, - first_decay_steps, t_mul, m_mul, alpha, name) - - -@tf_export("train.linear_cosine_decay", v1=[]) -def linear_cosine_decay(learning_rate, - global_step, - decay_steps, - num_periods=0.5, - alpha=0.0, - beta=0.001, - name=None): - """Applies linear cosine decay to the learning rate. - - See [Bello et al., ICML2017] Neural Optimizer Search with RL. - https://arxiv.org/abs/1709.07417 - - For the idea of warm starts here controlled by `num_periods`, - see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent - with Warm Restarts. https://arxiv.org/abs/1608.03983 - - Note that linear cosine decay is more aggressive than cosine decay and - larger initial learning rates can typically be used. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a linear cosine decay function - to a provided initial learning rate. It requires a `global_step` value to - compute the decayed learning rate. You can just pass a TensorFlow variable - that you increment at each training step. - - The function returns a no-arg callable that produces the decayed learning - rate. This can be useful for changing the learning rate value across - different invocations of optimizer functions. 
It is computed as: - - ```python - global_step = min(global_step, decay_steps) - linear_decay = (decay_steps - global_step) / decay_steps) - cosine_decay = 0.5 * ( - 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) - decayed = (alpha + linear_decay) * cosine_decay + beta - decayed_learning_rate = learning_rate * decayed - ``` - - Example usage: - ```python - decay_steps = 1000 - lr_decayed_fn = tf.train.linear_cosine_decay(learning_rate, global_step, - decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. - Global step to use for the decay computation. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. - Number of steps to decay over. - num_periods: Number of periods in the cosine part of the decay. - See computation above. - alpha: See computation above. - beta: See computation above. - name: String. Optional name of the operation. Defaults to - 'LinearCosineDecay'. - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - Raises: - ValueError: if `global_step` is not supplied. - """ - if global_step is None: - raise ValueError("linear cosine decay requires global_step") - def decayed_lr(learning_rate, global_step, decay_steps, num_periods, alpha, - beta, name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope(name, "LinearCosineDecay", - [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - num_periods = math_ops.cast(num_periods, dtype) - alpha = math_ops.cast(alpha, dtype) - beta = math_ops.cast(beta, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - linear_decayed = (decay_steps - global_step_recomp) / decay_steps - completed_fraction = global_step_recomp / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) - - linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta - return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) - - return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, - num_periods, alpha, beta, name) - - -@tf_export("train.noisy_linear_cosine_decay", v1=[]) -def noisy_linear_cosine_decay(learning_rate, - global_step, - decay_steps, - initial_variance=1.0, - variance_decay=0.55, - num_periods=0.5, - alpha=0.0, - beta=0.001, - name=None): - """Applies noisy linear cosine decay to the learning rate. - - See [Bello et al., ICML2017] Neural Optimizer Search with RL. - https://arxiv.org/abs/1709.07417 - - For the idea of warm starts here controlled by `num_periods`, - see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent - with Warm Restarts. https://arxiv.org/abs/1608.03983 - - Note that linear cosine decay is more aggressive than cosine decay and - larger initial learning rates can typically be used. - - When training a model, it is often recommended to lower the learning rate as - the training progresses. This function applies a noisy linear - cosine decay function to a provided initial learning rate. - It requires a `global_step` value to compute the decayed learning rate. 
- You can just pass a TensorFlow variable that you increment at each - training step. - - The function returns a no-arg callable that produces the decayed learning - rate. This can be useful for changing the learning rate value across - different invocations of optimizer functions. It is computed as: - - ```python - global_step = min(global_step, decay_steps) - linear_decay = (decay_steps - global_step) / decay_steps) - cosine_decay = 0.5 * ( - 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) - decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta - decayed_learning_rate = learning_rate * decayed - ``` - where eps_t is 0-centered gaussian noise with variance - initial_variance / (1 + global_step) ** variance_decay - - Example usage: - ```python - decay_steps = 1000 - lr_decayed_fn = tf.train.noisy_linear_cosine_decay(learning_rate, global_step, - decay_steps) - ``` - - Args: - learning_rate: A scalar `float32` or `float64` Tensor or a Python number. - The initial learning rate. - global_step: A scalar `int32` or `int64` `Tensor` or a Python number. - Global step to use for the decay computation. - decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. - Number of steps to decay over. - initial_variance: initial variance for the noise. See computation above. - variance_decay: decay for the noise's variance. See computation above. - num_periods: Number of periods in the cosine part of the decay. - See computation above. - alpha: See computation above. - beta: See computation above. - name: String. Optional name of the operation. Defaults to - 'NoisyLinearCosineDecay'. - Returns: - A no-arg function that outputs the decayed learning rate, a scalar `Tensor` - of the same type as `learning_rate`. - Raises: - ValueError: if `global_step` is not supplied. 
- """ - if global_step is None: - raise ValueError("noisy linear cosine decay requires global_step") - def decayed_lr(learning_rate, global_step, decay_steps, initial_variance, - variance_decay, num_periods, alpha, beta, name): - """Helper to recompute learning rate; most helpful in eager-mode.""" - with ops.name_scope(name, "NoisyLinearCosineDecay", - [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - initial_variance = math_ops.cast(initial_variance, dtype) - variance_decay = math_ops.cast(variance_decay, dtype) - num_periods = math_ops.cast(num_periods, dtype) - alpha = math_ops.cast(alpha, dtype) - beta = math_ops.cast(beta, dtype) - - global_step_recomp = math_ops.cast(global_step, dtype) - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - linear_decayed = (decay_steps - global_step_recomp) / decay_steps - variance = initial_variance / ( - math_ops.pow(1.0 + global_step_recomp, variance_decay)) - std = math_ops.sqrt(variance) - noisy_linear_decayed = ( - linear_decayed + random_ops.random_normal( - linear_decayed.shape, stddev=std)) - - completed_fraction = global_step_recomp / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) - noisy_linear_cosine_decayed = ( - (alpha + noisy_linear_decayed) * cosine_decayed + beta) - - return math_ops.multiply( - learning_rate, noisy_linear_cosine_decayed, name=name) - - return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, - initial_variance, variance_decay, num_periods, alpha, - beta, name) diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py deleted file mode 100644 index cb96773e299..00000000000 --- a/tensorflow/python/training/learning_rate_decay_v2_test.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Functional test for learning rate decay.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math - -from tensorflow.python.eager import context -from tensorflow.python.framework import test_util -# Import resource_variable_ops for the variables-to-tensor implicit conversion. 
-from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import -from tensorflow.python.ops import variables -from tensorflow.python.platform import googletest -from tensorflow.python.training import learning_rate_decay_v2 - - -class LRDecayTestV2(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testContinuous(self): - self.evaluate(variables.global_variables_initializer()) - step = 5 - decayed_lr = learning_rate_decay_v2.exponential_decay(0.05, step, 10, 0.96) - expected = .05 * 0.96**(5.0 / 10.0) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testStaircase(self): - if context.executing_eagerly(): - step = resource_variable_ops.ResourceVariable(0) - self.evaluate(variables.global_variables_initializer()) - decayed_lr = learning_rate_decay_v2.exponential_decay( - .1, step, 3, 0.96, staircase=True) - - # No change to learning rate due to staircase - expected = .1 - self.evaluate(step.assign(1)) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - expected = .1 - self.evaluate(step.assign(2)) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - # Decayed learning rate - expected = .1 * 0.96 ** (100 // 3) - self.evaluate(step.assign(100)) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_deprecated_v1 - def testVariables(self): - step = variables.Variable(1) - assign_1 = step.assign(1) - assign_2 = step.assign(2) - assign_100 = step.assign(100) - decayed_lr = learning_rate_decay_v2.exponential_decay( - .1, step, 3, 0.96, staircase=True) - self.evaluate(variables.global_variables_initializer()) - # No change to learning rate - self.evaluate(assign_1.op) - self.assertAllClose(self.evaluate(decayed_lr()), .1, 1e-6) - self.evaluate(assign_2.op) - self.assertAllClose(self.evaluate(decayed_lr()), .1, 1e-6) - # Decayed learning rate - self.evaluate(assign_100.op) - expected = .1 * 0.96**(100 // 3) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testPiecewiseConstant(self): - x = resource_variable_ops.ResourceVariable(-999) - decayed_lr = learning_rate_decay_v2.piecewise_constant( - x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001]) - - self.evaluate(variables.global_variables_initializer()) - - self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6) - self.evaluate(x.assign(100)) - self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6) - self.evaluate(x.assign(105)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6) - self.evaluate(x.assign(110)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6) - self.evaluate(x.assign(120)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.01, 1e-6) - self.evaluate(x.assign(999)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.001, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testPiecewiseConstantEdgeCases(self): - x_int = resource_variable_ops.ResourceVariable( - 0, dtype=variables.dtypes.int32) - boundaries, values = [-1.0, 1.0], [1, 2, 3] - with self.assertRaises(ValueError): - decayed_lr = learning_rate_decay_v2.piecewise_constant( - x_int, boundaries, values) - decayed_lr() - - x = resource_variable_ops.ResourceVariable(0.0) - boundaries, values = [-1.0, 1.0], [1.0, 2, 3] - with self.assertRaises(ValueError): - decayed_lr = learning_rate_decay_v2.piecewise_constant( - x, boundaries, values)() - decayed_lr() - - # Test that ref types are valid. 
- if not context.executing_eagerly(): - x = variables.Variable(0.0) - x_ref = x.op.outputs[0] # float32_ref tensor should be accepted - boundaries, values = [1.0, 2.0], [1, 2, 3] - learning_rate_decay_v2.piecewise_constant(x_ref, boundaries, values) - - # Test casting boundaries from int32 to int64. - x_int64 = resource_variable_ops.ResourceVariable( - 0, dtype=variables.dtypes.int64) - boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] - decayed_lr = learning_rate_decay_v2.piecewise_constant( - x_int64, boundaries, values) - - self.evaluate(variables.global_variables_initializer()) - self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6) - self.evaluate(x_int64.assign(1)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6) - self.evaluate(x_int64.assign(2)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.5, 1e-6) - self.evaluate(x_int64.assign(3)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.6, 1e-6) - self.evaluate(x_int64.assign(4)) - self.assertAllClose(self.evaluate(decayed_lr()), 0.7, 1e-6) - - -class LinearDecayTestV2(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testHalfWay(self): - step = 5 - lr = 0.05 - end_lr = 0.0 - decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) - expected = lr * 0.5 - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testEnd(self): - step = 10 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testHalfWayWithEnd(self): - step = 5 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) - expected = (lr + end_lr) * 0.5 - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testBeyondEnd(self): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) - expected = end_lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testBeyondEndWithCycle(self): - step = 15 - lr = 0.05 - end_lr = 0.001 - decayed_lr = learning_rate_decay_v2.polynomial_decay( - lr, step, 10, end_lr, cycle=True) - expected = (lr - end_lr) * 0.25 + end_lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - -class SqrtDecayTestV2(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testHalfWay(self): - step = 5 - lr = 0.05 - end_lr = 0.0 - power = 0.5 - decayed_lr = learning_rate_decay_v2.polynomial_decay( - lr, step, 10, end_lr, power=power) - expected = lr * 0.5**power - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testEnd(self): - step = 10 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay_v2.polynomial_decay( - lr, step, 10, end_lr, power=power) - expected = end_lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testHalfWayWithEnd(self): - step = 5 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay_v2.polynomial_decay( - lr, step, 10, end_lr, power=power) - expected = (lr - end_lr) * 0.5**power + end_lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - 
@test_util.run_in_graph_and_eager_modes - def testBeyondEnd(self): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay_v2.polynomial_decay( - lr, step, 10, end_lr, power=power) - expected = end_lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testBeyondEndWithCycle(self): - step = 15 - lr = 0.05 - end_lr = 0.001 - power = 0.5 - decayed_lr = learning_rate_decay_v2.polynomial_decay( - lr, step, 10, end_lr, power=power, cycle=True) - expected = (lr - end_lr) * 0.25**power + end_lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - -class PolynomialDecayTestV2(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testBeginWithCycle(self): - lr = 0.001 - decay_steps = 10 - step = 0 - decayed_lr = learning_rate_decay_v2.polynomial_decay( - lr, step, decay_steps, cycle=True) - expected = lr - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - -class ExponentialDecayTestV2(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testDecay(self): - initial_lr = 0.1 - k = 10 - decay_rate = 0.96 - step = resource_variable_ops.ResourceVariable(0) - decayed_lr = learning_rate_decay_v2.natural_exp_decay(initial_lr, step, k, - decay_rate) - - self.evaluate(variables.global_variables_initializer()) - for i in range(k + 1): - expected = initial_lr * math.exp(-i / k * decay_rate) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - self.evaluate(step.assign_add(1)) - - @test_util.run_in_graph_and_eager_modes - def testStaircase(self): - initial_lr = 0.1 - k = 10 - decay_rate = 0.96 - step = resource_variable_ops.ResourceVariable(0) - decayed_lr = learning_rate_decay_v2.natural_exp_decay( - initial_lr, step, k, decay_rate, staircase=True) - - self.evaluate(variables.global_variables_initializer()) - for i in range(k + 1): - expected = initial_lr * math.exp(-decay_rate * (i // k)) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - self.evaluate(step.assign_add(1)) - - -class InverseDecayTestV2(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testDecay(self): - initial_lr = 0.1 - k = 10 - decay_rate = 0.96 - step = resource_variable_ops.ResourceVariable(0) - decayed_lr = learning_rate_decay_v2.inverse_time_decay(initial_lr, step, k, - decay_rate) - - self.evaluate(variables.global_variables_initializer()) - for i in range(k + 1): - expected = initial_lr / (1 + i / k * decay_rate) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - self.evaluate(step.assign_add(1)) - - @test_util.run_in_graph_and_eager_modes - def testStaircase(self): - initial_lr = 0.1 - k = 10 - decay_rate = 0.96 - step = resource_variable_ops.ResourceVariable(0) - decayed_lr = learning_rate_decay_v2.inverse_time_decay( - initial_lr, step, k, decay_rate, staircase=True) - - self.evaluate(variables.global_variables_initializer()) - for i in range(k + 1): - expected = initial_lr / (1 + decay_rate * (i // k)) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - self.evaluate(step.assign_add(1)) - - -class CosineDecayTestV2(test_util.TensorFlowTestCase): - - def np_cosine_decay(self, step, decay_steps, alpha=0.0): - step = min(step, decay_steps) - completed_fraction = step / decay_steps - decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) - return (1.0 - alpha) * decay + alpha - - @test_util.run_in_graph_and_eager_modes - def testDecay(self): - 
num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step, - num_training_steps) - expected = self.np_cosine_decay(step, num_training_steps) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testAlpha(self): - num_training_steps = 1000 - initial_lr = 1.0 - alpha = 0.1 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step, - num_training_steps, - alpha) - expected = self.np_cosine_decay(step, num_training_steps, alpha) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - -class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase): - - def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0, - alpha=0.0): - fac = 1.0 - while step >= decay_steps: - step -= decay_steps - decay_steps *= t_mul - fac *= m_mul - - completed_fraction = step / decay_steps - decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) - return (1.0 - alpha) * decay + alpha - - @test_util.run_in_graph_and_eager_modes - def testDecay(self): - num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( - initial_lr, step, num_training_steps) - expected = self.np_cosine_decay_restarts(step, num_training_steps) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testAlpha(self): - num_training_steps = 1000 - initial_lr = 1.0 - alpha = 0.1 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( - initial_lr, step, num_training_steps, alpha=alpha) - expected = self.np_cosine_decay_restarts( - step, num_training_steps, alpha=alpha) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testMMul(self): - num_training_steps = 1000 - initial_lr = 1.0 - m_mul = 0.9 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( - initial_lr, step, num_training_steps, m_mul=m_mul) - expected = self.np_cosine_decay_restarts( - step, num_training_steps, m_mul=m_mul) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testTMul(self): - num_training_steps = 1000 - initial_lr = 1.0 - t_mul = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( - initial_lr, step, num_training_steps, t_mul=t_mul) - expected = self.np_cosine_decay_restarts( - step, num_training_steps, t_mul=t_mul) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - -class LinearCosineDecayTestV2(test_util.TensorFlowTestCase): - - def np_linear_cosine_decay(self, - step, - decay_steps, - alpha=0.0, - beta=0.001, - num_periods=0.5): - step = min(step, decay_steps) - linear_decayed = float(decay_steps - step) / decay_steps - fraction = 2.0 * num_periods * step / float(decay_steps) - cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction)) - return (alpha + linear_decayed) * cosine_decayed + beta - - @test_util.run_in_graph_and_eager_modes - def testDefaultDecay(self): - num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.linear_cosine_decay( - initial_lr, step, num_training_steps) - expected = self.np_linear_cosine_decay(step, num_training_steps) - 
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - @test_util.run_in_graph_and_eager_modes - def testNonDefaultDecay(self): - num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - decayed_lr = learning_rate_decay_v2.linear_cosine_decay( - initial_lr, - step, - num_training_steps, - alpha=0.1, - beta=1e-4, - num_periods=5) - expected = self.np_linear_cosine_decay( - step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5) - self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) - - -class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase): - - @test_util.run_in_graph_and_eager_modes - def testDefaultNoisyLinearCosine(self): - num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - # No numerical check because of noise - decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay( - initial_lr, step, num_training_steps) - # Cannot be deterministically tested - self.evaluate(decayed_lr()) - - @test_util.run_in_graph_and_eager_modes - def testNonDefaultNoisyLinearCosine(self): - num_training_steps = 1000 - initial_lr = 1.0 - for step in range(0, 1500, 250): - # No numerical check because of noise - decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay( - initial_lr, - step, - num_training_steps, - initial_variance=0.5, - variance_decay=0.1, - alpha=0.1, - beta=1e-4, - num_periods=5) - # Cannot be deterministically tested - self.evaluate(decayed_lr()) - -if __name__ == "__main__": - googletest.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-cosine-decay.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-cosine-decay.pbtxt new file mode 100644 index 00000000000..e2549a2ac62 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-cosine-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.experimental.CosineDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.CosineDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt index 5cd6851278d..24684b9f4d4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.experimental" tf_module { + member { + name: "CosineDecay" + mtype: "<type \'type\'>" + } member { name: "PeepholeLSTMCell" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.pbtxt index 7257b02087e..97f7c4b8c9c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.pbtxt @@ -32,6 +32,10 @@ tf_module { name: "SGD" mtype: "<type \'type\'>" } + member { + 
name: "schedules" + mtype: "<type \'module\'>" + } member_method { name: "deserialize" argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt new file mode 100644 index 00000000000..25ae478cb2c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.ExponentialDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.ExponentialDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt new file mode 100644 index 00000000000..b2fe61f4d2c --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.InverseTimeDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.InverseTimeDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt new file mode 100644 index 00000000000..3b33bd7526b --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.keras.optimizers.schedules.LearningRateSchedule" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git 
a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt new file mode 100644 index 00000000000..6f1496492ab --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.PiecewiseConstantDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PiecewiseConstantDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt new file mode 100644 index 00000000000..728436c3611 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.PolynomialDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.pbtxt new file mode 100644 index 00000000000..024e472a734 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.optimizers.schedules.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.keras.optimizers.schedules" +tf_module { + member { + name: "ExponentialDecay" + mtype: "<type \'type\'>" + } + member { + name: "InverseTimeDecay" + mtype: "<type \'type\'>" + } + member { + name: "LearningRateSchedule" + mtype: "<type \'type\'>" + } + member { + name: "PiecewiseConstantDecay" + mtype: "<type \'type\'>" + } + member { + name: "PolynomialDecay" + mtype: "<type \'type\'>" + } + member_method { + name: "deserialize" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize" + argspec: "args=[\'learning_rate_schedule\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay-restarts.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay-restarts.pbtxt new file mode 100644 index 00000000000..58bede556df --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay-restarts.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.experimental.CosineDecayRestarts" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.CosineDecayRestarts\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'first_decay_steps\', \'t_mul\', \'m_mul\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'2.0\', \'1.0\', \'0.0\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay.pbtxt new file mode 100644 index 00000000000..e2549a2ac62 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-cosine-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.experimental.CosineDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.CosineDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-cosine-decay.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-cosine-decay.pbtxt new file mode 100644 index 00000000000..f083120b52c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-cosine-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.experimental.LinearCosineDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LinearCosineDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-noisy-linear-cosine-decay.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-noisy-linear-cosine-decay.pbtxt new file mode 100644 index 00000000000..8ea3c6beb1c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-noisy-linear-cosine-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.experimental.NoisyLinearCosineDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.NoisyLinearCosineDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'initial_variance\', \'variance_decay\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0.55\', \'0.5\', \'0.0\', \'0.001\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt index 5cd6851278d..721c18890ac 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.pbtxt @@ -1,5 +1,21 @@ path: "tensorflow.keras.experimental" tf_module { + member { + name: "CosineDecay" + mtype: "<type \'type\'>" + } + member { + name: "CosineDecayRestarts" + mtype: "<type \'type\'>" + } + member { + name: "LinearCosineDecay" + mtype: "<type \'type\'>" + } + member { + name: "NoisyLinearCosineDecay" + mtype: "<type \'type\'>" + } member { name: "PeepholeLSTMCell" mtype: "<type \'type\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.pbtxt index 7257b02087e..97f7c4b8c9c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.pbtxt @@ -32,6 +32,10 @@ tf_module { name: "SGD" mtype: "<type \'type\'>" } + member { + name: "schedules" + mtype: "<type \'module\'>" + } member_method { name: "deserialize" argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt new file mode 100644 index 00000000000..25ae478cb2c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-exponential-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.ExponentialDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.ExponentialDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + 
} + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt new file mode 100644 index 00000000000..b2fe61f4d2c --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-inverse-time-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.InverseTimeDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.InverseTimeDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt new file mode 100644 index 00000000000..3b33bd7526b --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-learning-rate-schedule.pbtxt @@ -0,0 +1,16 @@ +path: "tensorflow.keras.optimizers.schedules.LearningRateSchedule" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt new file mode 100644 index 00000000000..6f1496492ab --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-piecewise-constant-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.PiecewiseConstantDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PiecewiseConstantDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt new file mode 100644 index 00000000000..728436c3611 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.-polynomial-decay.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.keras.optimizers.schedules.PolynomialDecay" +tf_class { + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay\'>" + is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>" + is_instance: "<type \'object\'>" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.pbtxt new file mode 100644 index 00000000000..024e472a734 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.optimizers.schedules.pbtxt @@ -0,0 +1,31 @@ +path: "tensorflow.keras.optimizers.schedules" +tf_module { + member { + name: "ExponentialDecay" + mtype: "<type \'type\'>" + } + member { + name: "InverseTimeDecay" + mtype: "<type \'type\'>" + } + member { + name: "LearningRateSchedule" + mtype: "<type \'type\'>" + } + member { + name: "PiecewiseConstantDecay" + mtype: "<type \'type\'>" + } + member { + name: "PolynomialDecay" + mtype: "<type \'type\'>" + } + member_method { + name: "deserialize" + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize" + argspec: "args=[\'learning_rate_schedule\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt index 8b39086ed1b..4f293fb40d4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt @@ -68,34 +68,14 @@ tf_module { name: "ServerDef" mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>" } - member_method { - name: "cosine_decay" - argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], " - } - member_method { - name: "cosine_decay_restarts" - argspec: "args=[\'learning_rate\', \'global_step\', \'first_decay_steps\', \'t_mul\', \'m_mul\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'2.0\', \'1.0\', \'0.0\', \'None\'], " - } - member_method { - name: "exponential_decay" - argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " - } member_method { name: "get_checkpoint_state" argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "inverse_time_decay" - argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, 
keywords=None, defaults=[\'False\', \'None\'], " - } member_method { name: "latest_checkpoint" argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "linear_cosine_decay" - argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], " - } member_method { name: "list_variables" argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" @@ -108,22 +88,6 @@ tf_module { name: "load_variable" argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "natural_exp_decay" - argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " - } - member_method { - name: "noisy_linear_cosine_decay" - argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'initial_variance\', \'variance_decay\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0.55\', \'0.5\', \'0.0\', \'0.001\', \'None\'], " - } - member_method { - name: "piecewise_constant_decay" - argspec: "args=[\'x\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "polynomial_decay" - argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], " - } member_method { name: "sdca_fprint" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index 5eb3d35527e..011be5172db 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -693,8 +693,11 @@ renames = { 'tf.train.batch': 'tf.compat.v1.train.batch', 'tf.train.batch_join': 'tf.compat.v1.train.batch_join', 'tf.train.checkpoint_exists': 'tf.compat.v1.train.checkpoint_exists', + 'tf.train.cosine_decay': 'tf.compat.v1.train.cosine_decay', + 'tf.train.cosine_decay_restarts': 'tf.compat.v1.train.cosine_decay_restarts', 'tf.train.create_global_step': 'tf.compat.v1.train.create_global_step', 'tf.train.do_quantize_training_on_graphdef': 'tf.compat.v1.train.do_quantize_training_on_graphdef', + 'tf.train.exponential_decay': 'tf.compat.v1.train.exponential_decay', 'tf.train.export_meta_graph': 'tf.compat.v1.train.export_meta_graph', 'tf.train.generate_checkpoint_state_proto': 'tf.compat.v1.train.generate_checkpoint_state_proto', 'tf.train.get_checkpoint_mtimes': 'tf.compat.v1.train.get_checkpoint_mtimes', @@ -704,13 +707,19 @@ renames = { 'tf.train.import_meta_graph': 'tf.compat.v1.train.import_meta_graph', 'tf.train.init_from_checkpoint': 'tf.compat.v1.train.init_from_checkpoint', 'tf.train.input_producer': 'tf.compat.v1.train.input_producer', + 'tf.train.inverse_time_decay': 'tf.compat.v1.train.inverse_time_decay', 'tf.train.limit_epochs': 'tf.compat.v1.train.limit_epochs', + 'tf.train.linear_cosine_decay': 'tf.compat.v1.train.linear_cosine_decay', 'tf.train.match_filenames_once': 'tf.io.match_filenames_once', 'tf.train.maybe_batch': 'tf.compat.v1.train.maybe_batch', 'tf.train.maybe_batch_join': 
'tf.compat.v1.train.maybe_batch_join', 'tf.train.maybe_shuffle_batch': 'tf.compat.v1.train.maybe_shuffle_batch', 'tf.train.maybe_shuffle_batch_join': 'tf.compat.v1.train.maybe_shuffle_batch_join', + 'tf.train.natural_exp_decay': 'tf.compat.v1.train.natural_exp_decay', + 'tf.train.noisy_linear_cosine_decay': 'tf.compat.v1.train.noisy_linear_cosine_decay', 'tf.train.piecewise_constant': 'tf.compat.v1.train.piecewise_constant', + 'tf.train.piecewise_constant_decay': 'tf.compat.v1.train.piecewise_constant_decay', + 'tf.train.polynomial_decay': 'tf.compat.v1.train.polynomial_decay', 'tf.train.queue_runner.QueueRunner': 'tf.compat.v1.train.queue_runner.QueueRunner', 'tf.train.queue_runner.add_queue_runner': 'tf.compat.v1.train.queue_runner.add_queue_runner', 'tf.train.queue_runner.start_queue_runners': 'tf.compat.v1.train.queue_runner.start_queue_runners',
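
A minimal usage sketch (not part of the patch) of the schedule classes registered in the golden API files above. It assumes `tf.keras.optimizers` and `tf.keras.optimizers.schedules` resolve to the Keras v2 optimizer package and that the v2 optimizers accept a `LearningRateSchedule` instance as `learning_rate`; the constructor and `serialize`/`deserialize` signatures follow the argspecs in the pbtxt files above.

```python
# Minimal sketch, not part of the patch: exercises the schedule API whose
# argspecs appear in the golden pbtxt files above. Assumes tf.keras.optimizers
# resolves to the Keras v2 optimizers and that they accept a
# LearningRateSchedule as `learning_rate`.
import tensorflow as tf

schedules = tf.keras.optimizers.schedules

# ExponentialDecay(initial_learning_rate, decay_steps, decay_rate,
#                  staircase=False, name=None), per the golden argspec.
lr_schedule = schedules.ExponentialDecay(
    initial_learning_rate=0.1,
    decay_steps=10000,
    decay_rate=0.96,
    staircase=True)

# The schedule object stands in for the no-arg-callable returned by the
# tf.train.*_decay helpers removed from the v2 tf.train API above.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)

# Serialization round trip via the module-level helpers listed in
# tensorflow.keras.optimizers.schedules.pbtxt.
config = schedules.serialize(lr_schedule)
restored = schedules.deserialize(config)
assert isinstance(restored, schedules.ExponentialDecay)
```

Existing 1.x call sites keep working through the compat renames added to renames_v2.py above (e.g. tf.train.exponential_decay -> tf.compat.v1.train.exponential_decay).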