This CL introduces serializable/deserializable learning rate decay schedules for the Keras v2 optimizers.
PiperOrigin-RevId: 231623483
Parent: 19c79e944b
Commit: 234c738bcd
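The change makes a decay schedule a first-class, serializable hyperparameter of the v2 optimizers. A minimal usage sketch, assuming only the classes this CL adds; the surrounding calls mirror the new tests and are not themselves part of the CL:

# Sketch: a schedule is a callable mapping a step to a learning rate, and it
# round-trips through the optimizer's get_config()/from_config().
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule

lr_schedule = learning_rate_schedule.InverseTimeDecay(
    3.0, decay_steps=1.0, decay_rate=0.5)

# The schedule can be passed anywhere a float learning rate was accepted.
sgd = gradient_descent.SGD(learning_rate=lr_schedule)

# Serialization round trip, as exercised by the new gradient_descent tests.
restored = gradient_descent.SGD.from_config(sgd.get_config())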
@@ -3722,6 +3722,7 @@ py_library(
         "//tensorflow/python/distribute:reduce_util",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/keras/optimizer_v2:learning_rate_schedule",
         "//tensorflow/python/ops/losses",
         "//tensorflow/python/training/checkpointable:base",
         "//tensorflow/python/training/checkpointable:util",
@@ -4807,7 +4808,6 @@ cuda_py_tests(
         "training/ftrl_test.py",
         "training/gradient_descent_test.py",
         "training/learning_rate_decay_test.py",
-        "training/learning_rate_decay_v2_test.py",
         "training/momentum_test.py",
         "training/optimizer_test.py",
         "training/proximal_adagrad_test.py",
@@ -310,7 +310,6 @@ py_library(
         "layers/recurrent.py",
         "layers/serialization.py",
         "layers/wrappers.py",
-        "utils/generic_utils.py",
         "utils/kernelized_utils.py",
         "utils/layer_utils.py",
         "utils/tf_utils.py",
@@ -318,6 +317,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":engine",
+        ":generic_utils",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:cudnn_rnn_ops_gen",
         "//tensorflow/python:dtypes",
@@ -339,6 +339,18 @@ py_library(
     ],
 )
 
+py_library(
+    name = "generic_utils",
+    srcs = [
+        "utils/generic_utils.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:util",
+        "//third_party/py/numpy",
+    ],
+)
+
 tf_py_test(
     name = "integration_test",
     size = "medium",
@@ -25,6 +25,7 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        ":learning_rate_schedule",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:distribute",
         "//tensorflow/python:framework",
@@ -39,6 +40,21 @@ py_library(
     ],
 )
 
+py_library(
+    name = "learning_rate_schedule",
+    srcs = [
+        "learning_rate_schedule.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:random_ops",
+        "//tensorflow/python/keras:generic_utils",
+    ],
+)
+
 cuda_py_test(
     name = "adagrad_test",
     size = "medium",
@@ -197,6 +213,19 @@ py_test(
     ],
 )
 
+py_test(
+    name = "learning_rate_schedule_test",
+    size = "small",
+    srcs = ["learning_rate_schedule_test.py"],
+    deps = [
+        ":optimizer_v2",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/keras",
+        "//third_party/py/numpy",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 cuda_py_test(
     name = "rmsprop_test",
     size = "medium",
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -160,6 +161,52 @@ class AdagradOptimizerTest(test.TestCase):
          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
 
+  def testBasicWithLearningRateInverseTimeDecay(self):
+    for dtype in [dtypes.float32, dtypes.float64]:
+      with self.cached_session():
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+        var0 = resource_variable_ops.ResourceVariable(var0_np)
+        var1 = resource_variable_ops.ResourceVariable(var1_np)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+
+        learning_rate = 3.0
+        decay = 0.5
+        lr_schedule = learning_rate_schedule.InverseTimeDecay(
+            learning_rate, decay_steps=1.0, decay_rate=decay)
+
+        ada_opt = adagrad.Adagrad(lr_schedule)
+
+        accum0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        accum1_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+
+        if not context.executing_eagerly():
+          ada_update = ada_opt.apply_gradients(
+              zip([grads0, grads1], [var0, var1]))
+          self.evaluate(variables.global_variables_initializer())
+
+        # Fetch params to validate initial values
+        v0_val, v1_val = self.evaluate([var0, var1])
+        self.assertAllClose([1.0, 2.0], v0_val)
+        self.assertAllClose([3.0, 4.0], v1_val)
+
+        # Run 3 steps of adagrad
+        for t in range(3):
+          if not context.executing_eagerly():
+            self.evaluate(ada_update)
+          else:
+            ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+          lr_np = learning_rate / (1 + decay * t)
+          var0_np, accum0_np = adagrad_update_numpy(var0_np, accum0_np,
+                                                    grads0_np, lr_np)
+          var1_np, accum1_np = adagrad_update_numpy(var1_np, accum1_np,
+                                                    grads1_np, lr_np)
+          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
   @test_util.run_deprecated_v1
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -399,6 +400,55 @@ class AdamOptimizerTest(test.TestCase):
         self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
         self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
 
+  @test_util.run_deprecated_v1
+  def testBasicWithLearningRateInverseTimeDecay(self):
+    for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+      with self.session(graph=ops.Graph()):
+        # Initialize variables for numpy implementation.
+        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+        grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+        var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+        grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+        var0 = resource_variable_ops.ResourceVariable(
+            var0_np, name="var0_%d" % i)
+        var1 = resource_variable_ops.ResourceVariable(
+            var1_np, name="var1_%d" % i)
+        grads0 = constant_op.constant(grads0_np)
+        grads1 = constant_op.constant(grads1_np)
+
+        learning_rate = 0.001
+        decay = 0.5
+        lr_schedule = learning_rate_schedule.InverseTimeDecay(
+            learning_rate, decay_steps=1.0, decay_rate=decay)
+        beta_1 = 0.9
+        beta_2 = 0.999
+        epsilon = 1e-7
+
+        opt = adam.Adam(
+            learning_rate=lr_schedule,
+            beta_1=beta_1,
+            beta_2=beta_2,
+            epsilon=epsilon)
+        update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+        self.evaluate(variables.global_variables_initializer())
+        # Run 3 steps of Adam
+        for t in range(3):
+          self.evaluate(update)
+
+          lr_np = learning_rate / (1 + decay * t)
+
+          var0_np, m0, v0 = adam_update_numpy(
+              var0_np, grads0_np, t, m0, v0, lr=lr_np)
+          var1_np, m1, v1 = adam_update_numpy(
+              var1_np, grads1_np, t, m1, v1, lr=lr_np)
+
+          # Validate updated params
+          self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+          self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
   @test_util.run_deprecated_v1
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
@@ -57,17 +58,11 @@ class GradientDescentOptimizerTest(test.TestCase):
       self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
                                          self.evaluate(var1))
 
-  @test_util.run_in_graph_and_eager_modes
-  def testBasicWithLearningRateDecay(self):
-    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+  def _test_basic_sgd_with_learning_rate_decay(self, sgd, dtype):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
         grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
-        learning_rate = 3.0
-        decay = 0.5
-        sgd = gradient_descent.SGD(learning_rate=learning_rate, decay=decay)
         if not context.executing_eagerly():
           sgd_op = sgd.apply_gradients(zip([grads0, grads1], [var0, var1]))
           self.evaluate(variables.global_variables_initializer())
@@ -94,6 +89,31 @@ class GradientDescentOptimizerTest(test.TestCase):
           [3.0 - 3.0 * 0.01 - 2.0 * 0.01, 4.0 - 3.0 * 0.01 - 2.0 * 0.01],
           self.evaluate(var1))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testBasicWithLearningRateDecay(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      learning_rate = 3.0
+      decay = 0.5
+      sgd = gradient_descent.SGD(learning_rate=learning_rate, decay=decay)
+      self._test_basic_sgd_with_learning_rate_decay(sgd, dtype)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBasicWithLearningRateInverseTimeDecay(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      learning_rate = learning_rate_schedule.InverseTimeDecay(
+          3.0, decay_steps=1.0, decay_rate=0.5)
+      sgd = gradient_descent.SGD(learning_rate=learning_rate)
+      self._test_basic_sgd_with_learning_rate_decay(sgd, dtype)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBasicWithLearningRateInverseTimeDecaySerializeAndDeserialize(self):
+    for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+      learning_rate = learning_rate_schedule.InverseTimeDecay(
+          3.0, decay_steps=1.0, decay_rate=0.5)
+      sgd = gradient_descent.SGD(learning_rate=learning_rate)
+      sgd = gradient_descent.SGD.from_config(sgd.get_config())
+      self._test_basic_sgd_with_learning_rate_decay(sgd, dtype)
+
   @test_util.run_in_graph_and_eager_modes
   def testBasicCallableParams(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py (new file, 1031 lines)
File diff suppressed because it is too large.
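Because the new module's own diff is suppressed above, its interface can only be inferred from the call sites elsewhere in this CL (the optimizer changes and the test file below). A rough, hypothetical sketch of that contract follows; everything other than the names LearningRateSchedule, get_config, from_config, serialize, and deserialize is an assumption, not the file's actual contents:

# Hypothetical sketch only -- inferred from call sites, not the real module.
class LearningRateSchedule(object):
  """A serializable callable mapping a training step to a learning rate."""

  def __call__(self, step):
    raise NotImplementedError("Learning rate schedules must override __call__")

  def get_config(self):
    raise NotImplementedError("Learning rate schedules must override get_config")

  @classmethod
  def from_config(cls, config):
    # Rebuild a schedule from the dict produced by get_config().
    return cls(**config)


def serialize(schedule):
  # Assumed to follow Keras' usual {"class_name": ..., "config": ...} layout,
  # via the keras generic_utils dependency this CL adds to the BUILD rule.
  return {"class_name": schedule.__class__.__name__,
          "config": schedule.get_config()}


def deserialize(config):
  # Illustrative lookup only; the real module presumably dispatches through
  # Keras' generic deserialization helpers.
  schedule_class = globals()[config["class_name"]]
  return schedule_class.from_config(config["config"])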
@@ -0,0 +1,527 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functional test for learning rate decay."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+from absl.testing import parameterized
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
+# Import resource_variable_ops for the variables-to-tensor implicit conversion.
+from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+def _maybe_serialized(lr_decay, serialize_and_deserialize):
+  if serialize_and_deserialize:
+    serialized = learning_rate_schedule.serialize(lr_decay)
+    return learning_rate_schedule.deserialize(serialized)
+  else:
+    return lr_decay
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class LRDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes
+  def testContinuous(self, serialize):
+    self.evaluate(variables.global_variables_initializer())
+    step = 5
+    decayed_lr = learning_rate_schedule.ExponentialDecay(0.05, 10, 0.96)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = .05 * 0.96**(5.0 / 10.0)
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testStaircase(self, serialize):
+    if context.executing_eagerly():
+      step = resource_variable_ops.ResourceVariable(0)
+      self.evaluate(variables.global_variables_initializer())
+      decayed_lr = learning_rate_schedule.ExponentialDecay(
+          .1, 3, 0.96, staircase=True)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+
+      # No change to learning rate due to staircase
+      expected = .1
+      self.evaluate(step.assign(1))
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+      expected = .1
+      self.evaluate(step.assign(2))
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+      # Decayed learning rate
+      expected = .1 * 0.96 ** (100 // 3)
+      self.evaluate(step.assign(100))
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_deprecated_v1
+  def testVariables(self, serialize):
+    step = variables.Variable(1)
+    assign_1 = step.assign(1)
+    assign_2 = step.assign(2)
+    assign_100 = step.assign(100)
+    decayed_lr = learning_rate_schedule.ExponentialDecay(
+        .1, 3, 0.96, staircase=True)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+
+    self.evaluate(variables.global_variables_initializer())
+    # No change to learning rate
+    self.evaluate(assign_1.op)
+    self.assertAllClose(self.evaluate(decayed_lr(step)), .1, 1e-6)
+    self.evaluate(assign_2.op)
+    self.assertAllClose(self.evaluate(decayed_lr(step)), .1, 1e-6)
+    # Decayed learning rate
+    self.evaluate(assign_100.op)
+    expected = .1 * 0.96**(100 // 3)
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testPiecewiseConstant(self, serialize):
+    x = resource_variable_ops.ResourceVariable(-999)
+    decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
+        [100, 110, 120], [1.0, 0.1, 0.01, 0.001])
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+
+    self.evaluate(variables.global_variables_initializer())
+
+    self.assertAllClose(self.evaluate(decayed_lr(x)), 1.0, 1e-6)
+    self.evaluate(x.assign(100))
+    self.assertAllClose(self.evaluate(decayed_lr(x)), 1.0, 1e-6)
+    self.evaluate(x.assign(105))
+    self.assertAllClose(self.evaluate(decayed_lr(x)), 0.1, 1e-6)
+    self.evaluate(x.assign(110))
+    self.assertAllClose(self.evaluate(decayed_lr(x)), 0.1, 1e-6)
+    self.evaluate(x.assign(120))
+    self.assertAllClose(self.evaluate(decayed_lr(x)), 0.01, 1e-6)
+    self.evaluate(x.assign(999))
+    self.assertAllClose(self.evaluate(decayed_lr(x)), 0.001, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testPiecewiseConstantEdgeCases(self, serialize):
+    x_int = resource_variable_ops.ResourceVariable(
+        0, dtype=variables.dtypes.int32)
+    boundaries, values = [-1.0, 1.0], [1, 2, 3]
+    with self.assertRaises(ValueError):
+      decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
+          boundaries, values)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      decayed_lr(x_int)
+
+    x = resource_variable_ops.ResourceVariable(0.0)
+    boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
+    with self.assertRaises(ValueError):
+      decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
+          boundaries, values)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      decayed_lr(x)
+
+    # Test casting boundaries from int32 to int64.
+    x_int64 = resource_variable_ops.ResourceVariable(
+        0, dtype=variables.dtypes.int64)
+    boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
+    decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
+        boundaries, values)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+
+    self.evaluate(variables.global_variables_initializer())
+    self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.4, 1e-6)
+    self.evaluate(x_int64.assign(1))
+    self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.4, 1e-6)
+    self.evaluate(x_int64.assign(2))
+    self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.5, 1e-6)
+    self.evaluate(x_int64.assign(3))
+    self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.6, 1e-6)
+    self.evaluate(x_int64.assign(4))
+    self.assertAllClose(self.evaluate(decayed_lr(x_int64)), 0.7, 1e-6)
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class LinearDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes
+  def testHalfWay(self, serialize):
+    step = 5
+    lr = 0.05
+    end_lr = 0.0
+    decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = lr * 0.5
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testEnd(self, serialize):
+    step = 10
+    lr = 0.05
+    end_lr = 0.001
+    decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = end_lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testHalfWayWithEnd(self, serialize):
+    step = 5
+    lr = 0.05
+    end_lr = 0.001
+    decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = (lr + end_lr) * 0.5
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBeyondEnd(self, serialize):
+    step = 15
+    lr = 0.05
+    end_lr = 0.001
+    decayed_lr = learning_rate_schedule.PolynomialDecay(lr, 10, end_lr)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = end_lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBeyondEndWithCycle(self, serialize):
+    step = 15
+    lr = 0.05
+    end_lr = 0.001
+    decayed_lr = learning_rate_schedule.PolynomialDecay(
+        lr, 10, end_lr, cycle=True)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = (lr - end_lr) * 0.25 + end_lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class SqrtDecayTestV2(test_util.TensorFlowTestCase,
+                      parameterized.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes
+  def testHalfWay(self, serialize):
+    step = 5
+    lr = 0.05
+    end_lr = 0.0
+    power = 0.5
+    decayed_lr = learning_rate_schedule.PolynomialDecay(
+        lr, 10, end_lr, power=power)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = lr * 0.5**power
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testEnd(self, serialize):
+    step = 10
+    lr = 0.05
+    end_lr = 0.001
+    power = 0.5
+    decayed_lr = learning_rate_schedule.PolynomialDecay(
+        lr, 10, end_lr, power=power)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = end_lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testHalfWayWithEnd(self, serialize):
+    step = 5
+    lr = 0.05
+    end_lr = 0.001
+    power = 0.5
+    decayed_lr = learning_rate_schedule.PolynomialDecay(
+        lr, 10, end_lr, power=power)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = (lr - end_lr) * 0.5**power + end_lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBeyondEnd(self, serialize):
+    step = 15
+    lr = 0.05
+    end_lr = 0.001
+    power = 0.5
+    decayed_lr = learning_rate_schedule.PolynomialDecay(
+        lr, 10, end_lr, power=power)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = end_lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBeyondEndWithCycle(self, serialize):
+    step = 15
+    lr = 0.05
+    end_lr = 0.001
+    power = 0.5
+    decayed_lr = learning_rate_schedule.PolynomialDecay(
+        lr, 10, end_lr, power=power, cycle=True)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = (lr - end_lr) * 0.25**power + end_lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class PolynomialDecayTestV2(test_util.TensorFlowTestCase,
+                            parameterized.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes
+  def testBeginWithCycle(self, serialize):
+    lr = 0.001
+    decay_steps = 10
+    step = 0
+    decayed_lr = learning_rate_schedule.PolynomialDecay(
+        lr, decay_steps, cycle=True)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+    expected = lr
+    self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class InverseDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDecay(self, serialize):
+    initial_lr = 0.1
+    k = 10
+    decay_rate = 0.96
+    step = resource_variable_ops.ResourceVariable(0)
+    decayed_lr = learning_rate_schedule.InverseTimeDecay(initial_lr, k,
+                                                         decay_rate)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+
+    self.evaluate(variables.global_variables_initializer())
+    for i in range(k + 1):
+      expected = initial_lr / (1 + i / k * decay_rate)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+      self.evaluate(step.assign_add(1))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testStaircase(self, serialize):
+    initial_lr = 0.1
+    k = 10
+    decay_rate = 0.96
+    step = resource_variable_ops.ResourceVariable(0)
+    decayed_lr = learning_rate_schedule.InverseTimeDecay(
+        initial_lr, k, decay_rate, staircase=True)
+    decayed_lr = _maybe_serialized(decayed_lr, serialize)
+
+    self.evaluate(variables.global_variables_initializer())
+    for i in range(k + 1):
+      expected = initial_lr / (1 + decay_rate * (i // k))
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+      self.evaluate(step.assign_add(1))
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class CosineDecayTestV2(test_util.TensorFlowTestCase, parameterized.TestCase):
+
+  def np_cosine_decay(self, step, decay_steps, alpha=0.0):
+    step = min(step, decay_steps)
+    completed_fraction = step / decay_steps
+    decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
+    return (1.0 - alpha) * decay + alpha
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDecay(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.CosineDecay(initial_lr,
+                                                      num_training_steps)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_cosine_decay(step, num_training_steps)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testAlpha(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    alpha = 0.1
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.CosineDecay(initial_lr,
+                                                      num_training_steps,
+                                                      alpha)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_cosine_decay(step, num_training_steps, alpha)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase,
+                                parameterized.TestCase):
+
+  def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0,
+                               alpha=0.0):
+    fac = 1.0
+    while step >= decay_steps:
+      step -= decay_steps
+      decay_steps *= t_mul
+      fac *= m_mul
+
+    completed_fraction = step / decay_steps
+    decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
+    return (1.0 - alpha) * decay + alpha
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDecay(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.CosineDecayRestarts(
+          initial_lr, num_training_steps)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_cosine_decay_restarts(step, num_training_steps)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testAlpha(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    alpha = 0.1
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.CosineDecayRestarts(
+          initial_lr, num_training_steps, alpha=alpha)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_cosine_decay_restarts(
+          step, num_training_steps, alpha=alpha)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testMMul(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    m_mul = 0.9
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.CosineDecayRestarts(
+          initial_lr, num_training_steps, m_mul=m_mul)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_cosine_decay_restarts(
+          step, num_training_steps, m_mul=m_mul)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testTMul(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    t_mul = 1.0
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.CosineDecayRestarts(
+          initial_lr, num_training_steps, t_mul=t_mul)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_cosine_decay_restarts(
+          step, num_training_steps, t_mul=t_mul)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class LinearCosineDecayTestV2(test_util.TensorFlowTestCase,
+                              parameterized.TestCase):
+
+  def np_linear_cosine_decay(self,
+                             step,
+                             decay_steps,
+                             alpha=0.0,
+                             beta=0.001,
+                             num_periods=0.5):
+    step = min(step, decay_steps)
+    linear_decayed = float(decay_steps - step) / decay_steps
+    fraction = 2.0 * num_periods * step / float(decay_steps)
+    cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction))
+    return (alpha + linear_decayed) * cosine_decayed + beta
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDefaultDecay(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.LinearCosineDecay(
+          initial_lr, num_training_steps)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_linear_cosine_decay(step, num_training_steps)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testNonDefaultDecay(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    for step in range(0, 1500, 250):
+      decayed_lr = learning_rate_schedule.LinearCosineDecay(
+          initial_lr,
+          num_training_steps,
+          alpha=0.1,
+          beta=1e-4,
+          num_periods=5)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      expected = self.np_linear_cosine_decay(
+          step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5)
+      self.assertAllClose(self.evaluate(decayed_lr(step)), expected, 1e-6)
+
+
+@parameterized.named_parameters(
+    ("NotSerialized", False),
+    ("Serialized", True))
+class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase,
+                                   parameterized.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes
+  def testDefaultNoisyLinearCosine(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    for step in range(0, 1500, 250):
+      # No numerical check because of noise
+      decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay(
+          initial_lr, num_training_steps)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      # Cannot be deterministically tested
+      self.evaluate(decayed_lr(step))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testNonDefaultNoisyLinearCosine(self, serialize):
+    num_training_steps = 1000
+    initial_lr = 1.0
+    for step in range(0, 1500, 250):
+      # No numerical check because of noise
+      decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay(
+          initial_lr,
+          num_training_steps,
+          initial_variance=0.5,
+          variance_decay=0.1,
+          alpha=0.1,
+          beta=1e-4,
+          num_periods=5)
+      decayed_lr = _maybe_serialized(decayed_lr, serialize)
+      # Cannot be deterministically tested
+      self.evaluate(decayed_lr(step))
+
+if __name__ == "__main__":
+  googletest.main()
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend_config
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -86,6 +87,12 @@ class Nadam(optimizer_v2.OptimizerV2):
 
     # Backwards compatiblity with keras NAdam optimizer.
     kwargs['decay'] = kwargs.pop('schedule_decay', 0.004)
+    learning_rate = kwargs.get('lr', learning_rate)
+    if isinstance(learning_rate, learning_rate_schedule.LearningRateSchedule):
+      raise ValueError('The Nadam optimizer does not support '
+                       'tf.keras.optimizers.LearningRateSchedules as the '
+                       'learning rate.')
+
     if epsilon is None:
       epsilon = backend_config.epsilon()
     super(Nadam, self).__init__(name, **kwargs)
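The v2 Nadam implementation does not yet support schedule objects, so its constructor rejects them up front instead of failing later in the update step. Illustrative only (the error text comes from the hunk above; the surrounding call is not part of the CL):

from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
from tensorflow.python.keras.optimizer_v2 import nadam

schedule = learning_rate_schedule.InverseTimeDecay(
    0.001, decay_steps=1.0, decay_rate=0.5)
try:
  nadam.Nadam(learning_rate=schedule)
except ValueError as e:
  print(e)  # "The Nadam optimizer does not support ..."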
@@ -36,6 +36,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.keras import backend
 from tensorflow.python.keras import initializers
 from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import gradients
@@ -452,8 +453,11 @@ class OptimizerV2(checkpointable.Checkpointable):
       self._hyper[name] = value
     else:
       prev_value = self._hyper[name]
-      if callable(prev_value) or isinstance(prev_value,
-                                            (ops.Tensor, int, float)):
+      if (callable(prev_value)
+          or isinstance(prev_value,
+                        (ops.Tensor, int, float,
+                         learning_rate_schedule.LearningRateSchedule))
+          or isinstance(value, learning_rate_schedule.LearningRateSchedule)):
         self._hyper[name] = value
       else:
         backend.set_value(self._hyper[name], value)
@@ -462,6 +466,8 @@ class OptimizerV2(checkpointable.Checkpointable):
     if not self._hypers_created:
       self._create_hypers()
     value = self._hyper[name]
+    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
+      return value
     if callable(value):
       value = value()
     if dtype:
@@ -575,6 +581,9 @@ class OptimizerV2(checkpointable.Checkpointable):
   def _decayed_lr(self, var_dtype):
     """Get decayed learning rate as a Tensor with dtype=var_dtype."""
     lr_t = self._get_hyper("learning_rate", var_dtype)
+    if isinstance(lr_t, learning_rate_schedule.LearningRateSchedule):
+      local_step = math_ops.cast(self.iterations, var_dtype)
+      lr_t = math_ops.cast(lr_t(local_step), var_dtype)
     if self._initial_decay > 0.:
       local_step = math_ops.cast(self.iterations, var_dtype)
       decay_t = self._get_hyper("decay", var_dtype)
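With this hunk, the effective learning rate for a step is simply the schedule evaluated at the optimizer's `iterations` counter (the pre-existing `decay` handling still applies afterwards). A small sketch of the per-step value an InverseTimeDecay schedule yields, matching the expected-value math used in the new tests (plain Python, no TF dependency; illustrative only):

initial_lr = 3.0
decay_rate = 0.5
decay_steps = 1.0

def inverse_time_decay(step):
  # lr / (1 + decay_rate * step / decay_steps), the formula the tests assume.
  return initial_lr / (1 + decay_rate * (step / decay_steps))

for step in range(3):
  print(step, inverse_time_decay(step))  # 3.0, 2.0, 1.5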
@@ -619,11 +628,17 @@ class OptimizerV2(checkpointable.Checkpointable):
     """
     if "lr" in config:
       config["learning_rate"] = config.pop("lr")
+    if "learning_rate" in config:
+      if isinstance(config["learning_rate"], dict):
+        config["learning_rate"] = learning_rate_schedule.deserialize(
+            config["learning_rate"])
     return cls(**config)
 
   def _serialize_hyperparameter(self, hyperparameter_name):
     """Serialize a hyperparameter that can be a float, callable, or Tensor."""
     value = self._hyper[hyperparameter_name]
+    if isinstance(value, learning_rate_schedule.LearningRateSchedule):
+      return learning_rate_schedule.serialize(value)
     if callable(value):
       return value()
     if isinstance(value, (ops.Tensor, tf_variables.Variable,
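Serialization stores the schedule as a nested dict under the "learning_rate" key, and from_config detects that dict and rebuilds the schedule before constructing the optimizer. A sketch of the round trip; the exact nested layout is assumed to be Keras' usual class_name/config convention, since the schedule module's diff is suppressed above:

from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule

schedule = learning_rate_schedule.InverseTimeDecay(
    0.5, decay_steps=1.0, decay_rate=0.1)
opt = gradient_descent.SGD(learning_rate=schedule)

config = opt.get_config()
# Assumed shape of the serialized hyperparameter:
#   config["learning_rate"] == {"class_name": "InverseTimeDecay",
#                               "config": {...}}
restored = gradient_descent.SGD.from_config(config)
# Both optimizers resolve the same decayed value for a given step, which is
# what testConfigWithLearningRateDecay below asserts.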
@@ -41,6 +41,7 @@ from tensorflow.python.keras.optimizer_v2 import adagrad
 from tensorflow.python.keras.optimizer_v2 import adam
 from tensorflow.python.keras.optimizer_v2 import adamax
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.keras.optimizer_v2 import nadam
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.optimizer_v2 import rmsprop
@@ -113,6 +114,13 @@ class OptimizerTest(test.TestCase):
       # var1 = [0., 1.] - 0.5 * [3, 3]
       self.assertAllClose([-1.5, -0.5], self.evaluate(var1))
 
+      sgd.learning_rate = learning_rate_schedule.InverseTimeDecay(
+          0.5, decay_steps=1.0, decay_rate=0.5)
+      if context.executing_eagerly():
+        sgd.minimize(loss, [var0, var1])
+      else:
+        self.evaluate(opt_op)
+
   @test_util.run_in_graph_and_eager_modes
   def testPrecomputedGradient(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
@@ -281,6 +289,33 @@ class OptimizerTest(test.TestCase):
       self.evaluate(variables.global_variables_initializer())
       self.assertEqual(self.evaluate(lr), self.evaluate(lr3))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testConfigWithLearningRateDecay(self):
+    with self.cached_session():
+      decay_schedule = learning_rate_schedule.InverseTimeDecay(
+          0.5, decay_steps=1.0, decay_rate=0.1)
+      step = 10
+      opt = gradient_descent.SGD(decay_schedule)
+      config = opt.get_config()
+      opt2 = gradient_descent.SGD.from_config(config)
+      # assert both are equal float values.
+      self.assertAllEqual(
+          decay_schedule(step),
+          opt._get_hyper('learning_rate')(step))
+      self.assertAllEqual(
+          decay_schedule(step),
+          opt2._get_hyper('learning_rate')(step))
+      var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
+      loss = lambda: 3 * var0
+      # learning rate variable created when calling minimize.
+      opt.minimize(loss, [var0])
+      self.evaluate(variables.global_variables_initializer())
+      config = opt.get_config()
+      opt3 = gradient_descent.SGD.from_config(config)
+      self.assertAllEqual(
+          self.evaluate(opt._get_hyper('learning_rate')(step)),
+          opt3._get_hyper('learning_rate')(step))
+
   @test_util.run_in_graph_and_eager_modes
   def testGradClipValue(self):
     with self.cached_session():
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
 from tensorflow.python.keras.optimizer_v2 import rmsprop
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
@@ -244,6 +245,78 @@ class RMSpropOptimizerTest(test.TestCase):
         self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
         self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
 
+  @test_util.run_deprecated_v1
+  def testDenseWithLearningRateInverseTimeDecay(self):
+    var0_np = np.array([1.0, 2.0])
+    grads0_np = np.array([0.1, 0.2])
+    var1_np = np.array([3.0, 4.0])
+    grads1_np = np.array([0.01, 0.2])
+
+    var0 = resource_variable_ops.ResourceVariable(var0_np)
+    var1 = resource_variable_ops.ResourceVariable(var1_np)
+    grads0 = constant_op.constant(grads0_np)
+    grads1 = constant_op.constant(grads1_np)
+    learning_rate = 0.01
+    rho = 0.9
+    momentum = 0.0
+    epsilon = 1e-7
+    centered = False
+    decay = 0.5
+    lr_schedule = learning_rate_schedule.InverseTimeDecay(
+        learning_rate, decay_steps=1.0, decay_rate=decay)
+    opt = rmsprop.RMSprop(
+        learning_rate=lr_schedule,
+        rho=rho,
+        momentum=momentum,
+        epsilon=epsilon,
+        centered=centered)
+
+    update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+    self.evaluate(variables.global_variables_initializer())
+
+    rms0 = opt.get_slot(var0, "rms")
+    self.assertTrue(rms0 is not None)
+    rms1 = opt.get_slot(var1, "rms")
+    self.assertTrue(rms1 is not None)
+    if momentum > 0.:
+      mom0 = opt.get_slot(var0, "momentum")
+      mom1 = opt.get_slot(var1, "momentum")
+    else:
+      mom0 = None
+      mom1 = None
+
+    mg0_np = np.array([0.0, 0.0])
+    mg1_np = np.array([0.0, 0.0])
+    rms0_np = np.array([0.0, 0.0])
+    rms1_np = np.array([0.0, 0.0])
+    mom0_np = np.array([0.0, 0.0])
+    mom1_np = np.array([0.0, 0.0])
+
+    # Fetch params to validate initial values
+    self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+    self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+    # Run 4 steps of RMSprop
+    for t in range(2):
+      self.evaluate(update)
+
+      lr = learning_rate / (1 + decay * t)
+      var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+          var0_np, grads0_np, mg0_np, rms0_np, mom0_np, lr, rho, momentum,
+          epsilon, centered)
+      var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+          var1_np, grads1_np, mg1_np, rms1_np, mom1_np, lr, rho, momentum,
+          epsilon, centered)
+
+      # Validate updated params
+      self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
+      self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
+      if momentum > 0.:
+        self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
+        self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
+      self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+      self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+
   @test_util.run_deprecated_v1
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.float32, dtypes.float64]:
@@ -84,6 +84,7 @@ KERAS_API_INIT_FILES = [
     "keras/metrics/__init__.py",
     "keras/models/__init__.py",
     "keras/optimizers/__init__.py",
+    "keras/optimizers/schedules/__init__.py",
    "keras/preprocessing/__init__.py",
    "keras/preprocessing/image/__init__.py",
    "keras/preprocessing/sequence/__init__.py",
@@ -107,6 +107,7 @@ KERAS_API_INIT_FILES_V1 = [
    "keras/metrics/__init__.py",
    "keras/models/__init__.py",
    "keras/optimizers/__init__.py",
+   "keras/optimizers/schedules/__init__.py",
    "keras/preprocessing/__init__.py",
    "keras/preprocessing/image/__init__.py",
    "keras/preprocessing/sequence/__init__.py",
@@ -17,8 +17,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+
 from tensorflow.python.eager import context
-from tensorflow.python.training import learning_rate_decay_v2
+from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule
+from tensorflow.python.ops import math_ops
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -88,15 +91,15 @@ def exponential_decay(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.exponential_decay(learning_rate,
-                                                        global_step,
+  decayed_lr = learning_rate_schedule.ExponentialDecay(learning_rate,
                                                         decay_steps,
                                                         decay_rate,
                                                         staircase=staircase,
                                                         name=name)
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
 
 
@@ -143,11 +146,12 @@ def piecewise_constant(x, boundaries, values, name=None):
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.piecewise_constant(x, boundaries, values,
-                                                         name=name)
+  decayed_lr = learning_rate_schedule.PiecewiseConstantDecay(
+      boundaries, values, name=name)
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(x)
+  else:
+    decayed_lr = functools.partial(decayed_lr, x)
   return decayed_lr
 
 
@@ -236,9 +240,8 @@ def polynomial_decay(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.polynomial_decay(
+  decayed_lr = learning_rate_schedule.PolynomialDecay(
       learning_rate,
-      global_step,
       decay_steps,
       end_learning_rate=end_learning_rate,
       power=power,
@@ -246,8 +249,9 @@ def polynomial_decay(learning_rate,
       name=name)
 
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
 
 
@@ -323,13 +327,15 @@ def natural_exp_decay(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.natural_exp_decay(
-      learning_rate, global_step, decay_steps, decay_rate, staircase=staircase,
+  natural_exp_rate = math_ops.exp(math_ops.negative(decay_rate))
+  decayed_lr = learning_rate_schedule.ExponentialDecay(
+      learning_rate, decay_steps, natural_exp_rate, staircase=staircase,
       name=name)
 
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
 
 
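There is no separate natural-exponential schedule class here; the hunk above folds the rate into the base of the generic `ExponentialDecay` schedule, using the identity `lr * exp(-decay_rate * p) == lr * exp(-decay_rate) ** p` with `p = global_step / decay_steps`. A small numeric check of that identity (illustrative only, not part of the change):

```python
import math

def natural_exp(lr, step, decay_steps, decay_rate):
  # Original v1 formulation: lr * exp(-decay_rate * step / decay_steps).
  return lr * math.exp(-decay_rate * step / decay_steps)

def via_exponential_decay(lr, step, decay_steps, decay_rate):
  # Equivalent ExponentialDecay formulation: lr * rate ** (step / decay_steps)
  # with rate = exp(-decay_rate), matching natural_exp_rate in the hunk above.
  rate = math.exp(-decay_rate)
  return lr * rate ** (step / decay_steps)

assert abs(natural_exp(0.1, 7, 10, 0.5)
           - via_exponential_decay(0.1, 7, 10, 0.5)) < 1e-12
```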
@@ -405,17 +411,17 @@ def inverse_time_decay(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.inverse_time_decay(
+  decayed_lr = learning_rate_schedule.InverseTimeDecay(
       learning_rate,
-      global_step,
       decay_steps,
       decay_rate,
       staircase=staircase,
       name=name)
 
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
 
 
@@ -468,12 +474,13 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None):
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.cosine_decay(
-      learning_rate, global_step, decay_steps, alpha=alpha, name=name)
+  decayed_lr = learning_rate_schedule.CosineDecay(
+      learning_rate, decay_steps, alpha=alpha, name=name)
 
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
 
 
@@ -535,9 +542,8 @@ def cosine_decay_restarts(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
+  decayed_lr = learning_rate_schedule.CosineDecayRestarts(
       learning_rate,
-      global_step,
       first_decay_steps,
       t_mul=t_mul,
       m_mul=m_mul,
@@ -545,8 +551,9 @@ def cosine_decay_restarts(learning_rate,
       name=name)
 
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
 
 
@@ -617,9 +624,8 @@ def linear_cosine_decay(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
+  decayed_lr = learning_rate_schedule.LinearCosineDecay(
       learning_rate,
-      global_step,
       decay_steps,
       num_periods=num_periods,
       alpha=alpha,
@@ -627,8 +633,9 @@ def linear_cosine_decay(learning_rate,
       name=name)
 
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
 
 
@@ -707,8 +714,8 @@ def noisy_linear_cosine_decay(learning_rate,
   the learning rate value across different invocations of optimizer functions.
   @end_compatibility
   """
-  decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
-      learning_rate, global_step,
+  decayed_lr = learning_rate_schedule.NoisyLinearCosineDecay(
+      learning_rate,
       decay_steps,
       initial_variance=initial_variance,
       variance_decay=variance_decay,
@@ -718,6 +725,7 @@ def noisy_linear_cosine_decay(learning_rate,
       name=name)
 
   if not context.executing_eagerly():
-    decayed_lr = decayed_lr()
+    decayed_lr = decayed_lr(global_step)
+  else:
+    decayed_lr = functools.partial(decayed_lr, global_step)
   return decayed_lr
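The schedule objects used above are Keras-style configurable objects, so a schedule can be round-tripped through its configuration. A brief sketch, assuming the classes follow the usual Keras `get_config()` / `from_config()` convention and that eager execution is enabled so the returned tensors can be compared directly:

```python
from tensorflow.python.keras.optimizer_v2 import learning_rate_schedule

# Build a schedule, capture its config, and rebuild an equivalent schedule.
schedule = learning_rate_schedule.ExponentialDecay(
    0.1, 100000, 0.96, staircase=True)
config = schedule.get_config()  # plain dict of constructor arguments
restored = learning_rate_schedule.ExponentialDecay.from_config(config)

# Both schedules produce the same decayed rate for any given step.
step = 250000
assert float(schedule(step)) == float(restored(step))
```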
@@ -1,898 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various learning rate decay functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.exponential_decay", v1=[])
|
|
||||||
def exponential_decay(learning_rate,
|
|
||||||
global_step,
|
|
||||||
decay_steps,
|
|
||||||
decay_rate,
|
|
||||||
staircase=False,
|
|
||||||
name=None):
|
|
||||||
"""Applies exponential decay to the learning rate.
|
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
|
||||||
the training progresses. This function applies an exponential decay function
|
|
||||||
to a provided initial learning rate. It requires a `global_step` value to
|
|
||||||
compute the decayed learning rate. You can just pass a TensorFlow variable
|
|
||||||
that you increment at each training step.
|
|
||||||
|
|
||||||
The function returns a no-arg function that produces the decayed learning
|
|
||||||
rate. This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions.
|
|
||||||
It is computed as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
decayed_learning_rate = learning_rate *
|
|
||||||
decay_rate ^ (global_step / decay_steps)
|
|
||||||
```
|
|
||||||
|
|
||||||
If the argument `staircase` is `True`, then `global_step / decay_steps` is an
|
|
||||||
integer division and the decayed learning rate follows a staircase function.
|
|
||||||
|
|
||||||
Example: decay every 100000 steps with a base of 0.96:
|
|
||||||
|
|
||||||
```python
|
|
||||||
...
|
|
||||||
global_step = tf.Variable(0, trainable=False)
|
|
||||||
starter_learning_rate = 0.1
|
|
||||||
learning_rate_fn = tf.train.exponential_decay(starter_learning_rate,
|
|
||||||
global_step, 100000, 0.96,
|
|
||||||
staircase=True)
|
|
||||||
# Passing global_step to minimize() will increment it at each step.
|
|
||||||
learning_step = (
|
|
||||||
tf.train.GradientDescentOptimizer(learning_rate_fn)
|
|
||||||
.minimize(...my loss..., global_step=global_step)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
|
||||||
Python number. The initial learning rate.
|
|
||||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Global step to use for the decay computation. Must not be negative.
|
|
||||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Must be positive. See the decay computation above.
|
|
||||||
decay_rate: A scalar `float32` or `float64` `Tensor` or a
|
|
||||||
Python number. The decay rate.
|
|
||||||
staircase: Boolean. If `True` decay the learning rate at discrete intervals
|
|
||||||
name: String. Optional name of the operation. Defaults to
|
|
||||||
'ExponentialDecay'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("global_step is required for exponential_decay.")
|
|
||||||
def decayed_lr(learning_rate, global_step, decay_steps, decay_rate,
|
|
||||||
staircase, name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(
|
|
||||||
name, "ExponentialDecay",
|
|
||||||
[learning_rate, global_step, decay_steps, decay_rate]) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
decay_steps = math_ops.cast(decay_steps, dtype)
|
|
||||||
decay_rate = math_ops.cast(decay_rate, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
p = global_step_recomp / decay_steps
|
|
||||||
if staircase:
|
|
||||||
p = math_ops.floor(p)
|
|
||||||
return math_ops.multiply(
|
|
||||||
learning_rate, math_ops.pow(decay_rate, p), name=name)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
|
|
||||||
decay_rate, staircase, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.piecewise_constant_decay", v1=[])
|
|
||||||
def piecewise_constant(x, boundaries, values, name=None):
|
|
||||||
"""Piecewise constant from boundaries and interval values.
|
|
||||||
|
|
||||||
This function returns a no-arg callable to compute the piecewise constant.
|
|
||||||
This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions.
|
|
||||||
|
|
||||||
Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5
|
|
||||||
for the next 10000 steps, and 0.1 for any additional steps.
|
|
||||||
|
|
||||||
```python
|
|
||||||
global_step = tf.Variable(0, trainable=False)
|
|
||||||
boundaries = [100000, 110000]
|
|
||||||
values = [1.0, 0.5, 0.1]
|
|
||||||
learning_rate_fn = tf.train.piecewise_constant(global_step, boundaries,
|
|
||||||
values)
|
|
||||||
learning_rate = learning_rate_fn()
|
|
||||||
|
|
||||||
# Later, whenever we perform an optimization step, we increment global_step.
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
|
|
||||||
`float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
|
|
||||||
boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
|
|
||||||
increasing entries, and with all elements having the same type as `x`.
|
|
||||||
values: A list of `Tensor`s or `float`s or `int`s that specifies the values
|
|
||||||
for the intervals defined by `boundaries`. It should have one more element
|
|
||||||
than `boundaries`, and all elements should have the same type.
|
|
||||||
name: A string. Optional name of the operation. Defaults to
|
|
||||||
'PiecewiseConstant'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs a 0-D Tensor. The output of the no-arg
|
|
||||||
function is `values[0]` when `x <= boundaries[0]`,
|
|
||||||
`values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
|
|
||||||
and values[-1] when `x > boundaries[-1]`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if types of `x` and `boundaries` do not match, or types of all
|
|
||||||
`values` do not match or
|
|
||||||
the number of elements in the lists does not match.
|
|
||||||
"""
|
|
||||||
if len(boundaries) != len(values) - 1:
|
|
||||||
raise ValueError(
|
|
||||||
"The length of boundaries should be 1 less than the length of values")
|
|
||||||
def decayed_lr(x, boundaries, values, name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(name, "PiecewiseConstant",
|
|
||||||
[x, boundaries, values, name]) as name:
|
|
||||||
boundaries = ops.convert_n_to_tensor(boundaries)
|
|
||||||
values = ops.convert_n_to_tensor(values)
|
|
||||||
x_recomp = ops.convert_to_tensor(x)
|
|
||||||
# Avoid explicit conversion to x's dtype. This could result in faulty
|
|
||||||
# comparisons, for example if floats are converted to integers.
|
|
||||||
for i, b in enumerate(boundaries):
|
|
||||||
if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
|
|
||||||
# We can promote int32 boundaries to int64 without loss of precision.
|
|
||||||
# This covers the most common case where the user passes in boundaries
|
|
||||||
# as an array of Python integers.
|
|
||||||
if (b.dtype.base_dtype == dtypes.int32 and
|
|
||||||
x_recomp.dtype.base_dtype == dtypes.int64):
|
|
||||||
b = math_ops.cast(b, x_recomp.dtype.base_dtype)
|
|
||||||
boundaries[i] = b
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Boundaries (%s) must have the same dtype as x (%s)." %
|
|
||||||
(b.dtype.base_dtype, x_recomp.dtype.base_dtype))
|
|
||||||
# TODO(rdipietro): Ensure that boundaries' elements strictly increases.
|
|
||||||
for v in values[1:]:
|
|
||||||
if v.dtype.base_dtype != values[0].dtype.base_dtype:
|
|
||||||
raise ValueError(
|
|
||||||
"Values must have elements all with the same dtype (%s vs %s)." %
|
|
||||||
(values[0].dtype.base_dtype, v.dtype.base_dtype))
|
|
||||||
pred_fn_pairs = []
|
|
||||||
pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
|
|
||||||
pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
|
|
||||||
for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
|
|
||||||
# Need to bind v here; can do this with lambda v=v: ...
|
|
||||||
pred = (x_recomp > low) & (x_recomp <= high)
|
|
||||||
pred_fn_pairs.append((pred, lambda v=v: v))
|
|
||||||
|
|
||||||
# The default isn't needed here because our conditions are mutually
|
|
||||||
# exclusive and exhaustive, but tf.case requires it.
|
|
||||||
default = lambda: values[0]
|
|
||||||
return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, x, boundaries, values, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.polynomial_decay", v1=[])
|
|
||||||
def polynomial_decay(learning_rate,
|
|
||||||
global_step,
|
|
||||||
decay_steps,
|
|
||||||
end_learning_rate=0.0001,
|
|
||||||
power=1.0,
|
|
||||||
cycle=False,
|
|
||||||
name=None):
|
|
||||||
"""Applies a polynomial decay to the learning rate.
|
|
||||||
|
|
||||||
It is commonly observed that a monotonically decreasing learning rate, whose
|
|
||||||
degree of change is carefully chosen, results in a better performing model.
|
|
||||||
This function applies a polynomial decay function to a provided initial
|
|
||||||
`learning_rate` to reach an `end_learning_rate` in the given `decay_steps`.
|
|
||||||
|
|
||||||
It requires a `global_step` value to compute the decayed learning rate. You
|
|
||||||
can just pass a TensorFlow variable that you increment at each training step.
|
|
||||||
|
|
||||||
The function returns a no-arg callable that outputs the decayed learning
|
|
||||||
rate. This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions. It is computed as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
global_step = min(global_step, decay_steps)
|
|
||||||
decayed_learning_rate = (learning_rate - end_learning_rate) *
|
|
||||||
(1 - global_step / decay_steps) ^ (power) +
|
|
||||||
end_learning_rate
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
If `cycle` is True then a multiple of `decay_steps` is used, the first one
|
|
||||||
that is bigger than `global_steps`.
|
|
||||||
|
|
||||||
```python
|
|
||||||
decay_steps = decay_steps * ceil(global_step / decay_steps)
|
|
||||||
decayed_learning_rate_fn = (learning_rate - end_learning_rate) *
|
|
||||||
(1 - global_step / decay_steps) ^ (power) +
|
|
||||||
end_learning_rate
|
|
||||||
decayed_learning_rate = decayed_learning_rate_fn()
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5):
|
|
||||||
|
|
||||||
```python
|
|
||||||
...
|
|
||||||
global_step = tf.Variable(0, trainable=False)
|
|
||||||
starter_learning_rate = 0.1
|
|
||||||
end_learning_rate = 0.01
|
|
||||||
decay_steps = 10000
|
|
||||||
learning_rate_fn = tf.train.polynomial_decay(starter_learning_rate,
|
|
||||||
global_step, decay_steps,
|
|
||||||
end_learning_rate,
|
|
||||||
power=0.5)
|
|
||||||
# Passing global_step to minimize() will increment it at each step.
|
|
||||||
learning_step = (
|
|
||||||
tf.train.GradientDescentOptimizer(learning_rate_fn)
|
|
||||||
.minimize(...my loss..., global_step=global_step)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
|
||||||
Python number. The initial learning rate.
|
|
||||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Global step to use for the decay computation. Must not be negative.
|
|
||||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Must be positive. See the decay computation above.
|
|
||||||
end_learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
|
||||||
Python number. The minimal end learning rate.
|
|
||||||
power: A scalar `float32` or `float64` `Tensor` or a
|
|
||||||
Python number. The power of the polynomial. Defaults to linear, 1.0.
|
|
||||||
cycle: A boolean, whether or not it should cycle beyond decay_steps.
|
|
||||||
name: String. Optional name of the operation. Defaults to
|
|
||||||
'PolynomialDecay'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("global_step is required for polynomial_decay.")
|
|
||||||
def decayed_lr(learning_rate, global_step, decay_steps, end_learning_rate,
|
|
||||||
power, cycle, name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(
|
|
||||||
name, "PolynomialDecay",
|
|
||||||
[learning_rate, global_step, decay_steps, end_learning_rate, power]
|
|
||||||
) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
end_learning_rate = math_ops.cast(end_learning_rate, dtype)
|
|
||||||
power = math_ops.cast(power, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
decay_steps_recomp = math_ops.cast(decay_steps, dtype)
|
|
||||||
if cycle:
|
|
||||||
# Find the first multiple of decay_steps that is bigger than
|
|
||||||
# global_step. If global_step is zero set the multiplier to 1
|
|
||||||
multiplier = control_flow_ops.cond(
|
|
||||||
math_ops.equal(global_step_recomp, 0), lambda: 1.0,
|
|
||||||
lambda: math_ops.ceil(global_step_recomp / decay_steps))
|
|
||||||
decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
|
|
||||||
else:
|
|
||||||
# Make sure that the global_step used is not bigger than decay_steps.
|
|
||||||
global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
|
|
||||||
|
|
||||||
p = math_ops.div(global_step_recomp, decay_steps_recomp)
|
|
||||||
return math_ops.add(
|
|
||||||
math_ops.multiply(learning_rate - end_learning_rate,
|
|
||||||
math_ops.pow(1 - p, power)),
|
|
||||||
end_learning_rate,
|
|
||||||
name=name)
|
|
||||||
|
|
||||||
return functools.partial(
|
|
||||||
decayed_lr, learning_rate, global_step, decay_steps, end_learning_rate,
|
|
||||||
power, cycle, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.natural_exp_decay", v1=[])
|
|
||||||
def natural_exp_decay(learning_rate,
|
|
||||||
global_step,
|
|
||||||
decay_steps,
|
|
||||||
decay_rate,
|
|
||||||
staircase=False,
|
|
||||||
name=None):
|
|
||||||
"""Applies natural exponential decay to the initial learning rate.
|
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
|
||||||
the training progresses. This function applies an exponential decay function
|
|
||||||
to a provided initial learning rate. It requires an `global_step` value to
|
|
||||||
compute the decayed learning rate. You can just pass a TensorFlow variable
|
|
||||||
that you increment at each training step.
|
|
||||||
|
|
||||||
The function returns a no-arg callable that produces the decayed learning
|
|
||||||
rate. This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions. It is computed as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
decayed_learning_rate = learning_rate * exp(-decay_rate * global_step /
|
|
||||||
decay_step)
|
|
||||||
```
|
|
||||||
|
|
||||||
or, if `staircase` is `True`, as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step /
|
|
||||||
decay_step))
|
|
||||||
```
|
|
||||||
|
|
||||||
Example: decay exponentially with a base of 0.96:
|
|
||||||
|
|
||||||
```python
|
|
||||||
...
|
|
||||||
global_step = tf.Variable(0, trainable=False)
|
|
||||||
learning_rate = 0.1
|
|
||||||
decay_steps = 5
|
|
||||||
k = 0.5
|
|
||||||
learning_rate_fn = tf.train.natural_exp_decay(learning_rate, global_step,
|
|
||||||
decay_steps, k)
|
|
||||||
|
|
||||||
# Passing global_step to minimize() will increment it at each step.
|
|
||||||
learning_step = (
|
|
||||||
tf.train.GradientDescentOptimizer(learning_rate_fn)
|
|
||||||
.minimize(...my loss..., global_step=global_step)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
|
||||||
Python number. The initial learning rate.
|
|
||||||
global_step: A Python number.
|
|
||||||
Global step to use for the decay computation. Must not be negative.
|
|
||||||
decay_steps: How often to apply decay.
|
|
||||||
decay_rate: A Python number. The decay rate.
|
|
||||||
staircase: Whether to apply decay in a discrete staircase, as opposed to
|
|
||||||
continuous, fashion.
|
|
||||||
name: String. Optional name of the operation. Defaults to
|
|
||||||
'ExponentialTimeDecay'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("global_step is required for natural_exp_decay.")
|
|
||||||
def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
|
|
||||||
name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(name, "NaturalExpDecay",
|
|
||||||
[learning_rate, global_step, decay_rate]) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
decay_steps = math_ops.cast(decay_steps, dtype)
|
|
||||||
decay_rate = math_ops.cast(decay_rate, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
p = global_step_recomp / decay_steps
|
|
||||||
if staircase:
|
|
||||||
p = math_ops.floor(p)
|
|
||||||
exponent = math_ops.exp(
|
|
||||||
math_ops.multiply(math_ops.negative(decay_rate), p))
|
|
||||||
return math_ops.multiply(learning_rate, exponent, name=name)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
|
|
||||||
decay_rate, staircase, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.inverse_time_decay", v1=[])
|
|
||||||
def inverse_time_decay(learning_rate,
|
|
||||||
global_step,
|
|
||||||
decay_steps,
|
|
||||||
decay_rate,
|
|
||||||
staircase=False,
|
|
||||||
name=None):
|
|
||||||
"""Applies inverse time decay to the initial learning rate.
|
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
|
||||||
the training progresses. This function applies an inverse decay function
|
|
||||||
to a provided initial learning rate. It requires an `global_step` value to
|
|
||||||
compute the decayed learning rate. You can just pass a TensorFlow variable
|
|
||||||
that you increment at each training step.
|
|
||||||
|
|
||||||
The function returns a no-arg callable that produces the decayed learning
|
|
||||||
rate. This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions. It is computed as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
decayed_learning_rate = learning_rate / (1 + decay_rate * global_step /
|
|
||||||
decay_step)
|
|
||||||
```
|
|
||||||
|
|
||||||
or, if `staircase` is `True`, as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step /
|
|
||||||
decay_step))
|
|
||||||
```
|
|
||||||
|
|
||||||
Example: decay 1/t with a rate of 0.5:
|
|
||||||
|
|
||||||
```python
|
|
||||||
...
|
|
||||||
global_step = tf.Variable(0, trainable=False)
|
|
||||||
learning_rate = 0.1
|
|
||||||
decay_steps = 1.0
|
|
||||||
decay_rate = 0.5
|
|
||||||
learning_rate_fn = tf.train.inverse_time_decay(learning_rate, global_step,
|
|
||||||
decay_steps, decay_rate)
|
|
||||||
|
|
||||||
# Passing global_step to minimize() will increment it at each step.
|
|
||||||
learning_step = (
|
|
||||||
tf.train.GradientDescentOptimizer(learning_rate_fn)
|
|
||||||
.minimize(...my loss..., global_step=global_step)
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` `Tensor` or a
|
|
||||||
Python number. The initial learning rate.
|
|
||||||
global_step: A Python number.
|
|
||||||
Global step to use for the decay computation. Must not be negative.
|
|
||||||
decay_steps: How often to apply decay.
|
|
||||||
decay_rate: A Python number. The decay rate.
|
|
||||||
staircase: Whether to apply decay in a discrete staircase, as opposed to
|
|
||||||
continuous, fashion.
|
|
||||||
name: String. Optional name of the operation. Defaults to
|
|
||||||
'InverseTimeDecay'.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("global_step is required for inverse_time_decay.")
|
|
||||||
def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase,
|
|
||||||
name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(name, "InverseTimeDecay",
|
|
||||||
[learning_rate, global_step, decay_rate]) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
decay_steps = math_ops.cast(decay_steps, dtype)
|
|
||||||
decay_rate = math_ops.cast(decay_rate, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
p = global_step_recomp / decay_steps
|
|
||||||
if staircase:
|
|
||||||
p = math_ops.floor(p)
|
|
||||||
const = math_ops.cast(constant_op.constant(1), dtype)
|
|
||||||
denom = math_ops.add(const, math_ops.multiply(decay_rate, p))
|
|
||||||
return math_ops.div(learning_rate, denom, name=name)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
|
|
||||||
decay_rate, staircase, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.cosine_decay", v1=[])
|
|
||||||
def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0,
|
|
||||||
name=None):
|
|
||||||
"""Applies cosine decay to the learning rate.
|
|
||||||
|
|
||||||
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
|
||||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
|
||||||
the training progresses. This function applies a cosine decay function
|
|
||||||
to a provided initial learning rate. It requires a `global_step` value to
|
|
||||||
compute the decayed learning rate. You can just pass a TensorFlow variable
|
|
||||||
that you increment at each training step.
|
|
||||||
|
|
||||||
The function returns a no-arg callable that produces the decayed learning
|
|
||||||
rate. This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions. It is computed as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
global_step = min(global_step, decay_steps)
|
|
||||||
cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
|
|
||||||
decayed = (1 - alpha) * cosine_decay + alpha
|
|
||||||
decayed_learning_rate = learning_rate * decayed
|
|
||||||
```
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
```python
|
|
||||||
decay_steps = 1000
|
|
||||||
lr_decayed_fn = tf.train.cosine_decay(learning_rate, global_step, decay_steps)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
|
||||||
The initial learning rate.
|
|
||||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Global step to use for the decay computation.
|
|
||||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Number of steps to decay over.
|
|
||||||
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
|
||||||
Minimum learning rate value as a fraction of learning_rate.
|
|
||||||
name: String. Optional name of the operation. Defaults to 'CosineDecay'.
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("cosine decay requires global_step")
|
|
||||||
def decayed_lr(learning_rate, global_step, decay_steps, alpha, name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(name, "CosineDecay",
|
|
||||||
[learning_rate, global_step]) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
decay_steps = math_ops.cast(decay_steps, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
|
|
||||||
completed_fraction = global_step_recomp / decay_steps
|
|
||||||
cosine_decayed = 0.5 * (1.0 + math_ops.cos(
|
|
||||||
constant_op.constant(math.pi) * completed_fraction))
|
|
||||||
|
|
||||||
decayed = (1 - alpha) * cosine_decayed + alpha
|
|
||||||
return math_ops.multiply(learning_rate, decayed)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
|
|
||||||
alpha, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.cosine_decay_restarts", v1=[])
|
|
||||||
def cosine_decay_restarts(learning_rate,
|
|
||||||
global_step,
|
|
||||||
first_decay_steps,
|
|
||||||
t_mul=2.0,
|
|
||||||
m_mul=1.0,
|
|
||||||
alpha=0.0,
|
|
||||||
name=None):
|
|
||||||
"""Applies cosine decay with restarts to the learning rate.
|
|
||||||
|
|
||||||
See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent
|
|
||||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
|
||||||
the training progresses. This function applies a cosine decay function with
|
|
||||||
restarts to a provided initial learning rate. It requires a `global_step`
|
|
||||||
value to compute the decayed learning rate. You can just pass a TensorFlow
|
|
||||||
variable that you increment at each training step.
|
|
||||||
|
|
||||||
The function returns a no-arg callable that produces the decayed learning
|
|
||||||
rate while taking into account possible warm restarts. This can be useful for
|
|
||||||
changing the learning rate value across different invocations of optimizer
|
|
||||||
functions.
|
|
||||||
|
|
||||||
The learning rate multiplier first decays
|
|
||||||
from 1 to `alpha` for `first_decay_steps` steps. Then, a warm
|
|
||||||
restart is performed. Each new warm restart runs for `t_mul` times more steps
|
|
||||||
and with `m_mul` times smaller initial learning rate.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
```python
|
|
||||||
first_decay_steps = 1000
|
|
||||||
lr_decayed_fn = tf.train.cosine_decay_restarts(learning_rate, global_step,
|
|
||||||
first_decay_steps)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
|
||||||
The initial learning rate.
|
|
||||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Global step to use for the decay computation.
|
|
||||||
first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Number of steps to decay over.
|
|
||||||
t_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
|
|
||||||
Used to derive the number of iterations in the i-th period
|
|
||||||
m_mul: A scalar `float32` or `float64` `Tensor` or a Python number.
|
|
||||||
Used to derive the initial learning rate of the i-th period:
|
|
||||||
alpha: A scalar `float32` or `float64` Tensor or a Python number.
|
|
||||||
Minimum learning rate value as a fraction of the learning_rate.
|
|
||||||
name: String. Optional name of the operation. Defaults to 'SGDRDecay'.
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("cosine decay restarts requires global_step")
|
|
||||||
def decayed_lr(learning_rate, global_step, first_decay_steps, t_mul, m_mul,
|
|
||||||
alpha, name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]
|
|
||||||
) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(
|
|
||||||
learning_rate, name="initial_learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
first_decay_steps = math_ops.cast(first_decay_steps, dtype)
|
|
||||||
alpha = math_ops.cast(alpha, dtype)
|
|
||||||
t_mul = math_ops.cast(t_mul, dtype)
|
|
||||||
m_mul = math_ops.cast(m_mul, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
completed_fraction = global_step_recomp / first_decay_steps
|
|
||||||
|
|
||||||
def compute_step(completed_fraction, geometric=False):
|
|
||||||
"""Helper for `cond` operation."""
|
|
||||||
if geometric:
|
|
||||||
i_restart = math_ops.floor(
|
|
||||||
math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) /
|
|
||||||
math_ops.log(t_mul))
|
|
||||||
|
|
||||||
sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul)
|
|
||||||
completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart
|
|
||||||
|
|
||||||
else:
|
|
||||||
i_restart = math_ops.floor(completed_fraction)
|
|
||||||
completed_fraction -= i_restart
|
|
||||||
|
|
||||||
return i_restart, completed_fraction
|
|
||||||
|
|
||||||
i_restart, completed_fraction = control_flow_ops.cond(
|
|
||||||
math_ops.equal(t_mul, 1.0),
|
|
||||||
lambda: compute_step(completed_fraction, geometric=False),
|
|
||||||
lambda: compute_step(completed_fraction, geometric=True))
|
|
||||||
|
|
||||||
m_fac = m_mul**i_restart
|
|
||||||
cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos(
|
|
||||||
constant_op.constant(math.pi) * completed_fraction))
|
|
||||||
decayed = (1 - alpha) * cosine_decayed + alpha
|
|
||||||
|
|
||||||
return math_ops.multiply(learning_rate, decayed, name=name)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, learning_rate, global_step,
|
|
||||||
first_decay_steps, t_mul, m_mul, alpha, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.linear_cosine_decay", v1=[])
|
|
||||||
def linear_cosine_decay(learning_rate,
|
|
||||||
global_step,
|
|
||||||
decay_steps,
|
|
||||||
num_periods=0.5,
|
|
||||||
alpha=0.0,
|
|
||||||
beta=0.001,
|
|
||||||
name=None):
|
|
||||||
"""Applies linear cosine decay to the learning rate.
|
|
||||||
|
|
||||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
|
||||||
https://arxiv.org/abs/1709.07417
|
|
||||||
|
|
||||||
For the idea of warm starts here controlled by `num_periods`,
|
|
||||||
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
|
||||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
|
||||||
|
|
||||||
Note that linear cosine decay is more aggressive than cosine decay and
|
|
||||||
larger initial learning rates can typically be used.
|
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
|
||||||
the training progresses. This function applies a linear cosine decay function
|
|
||||||
to a provided initial learning rate. It requires a `global_step` value to
|
|
||||||
compute the decayed learning rate. You can just pass a TensorFlow variable
|
|
||||||
that you increment at each training step.
|
|
||||||
|
|
||||||
The function returns a no-arg callable that produces the decayed learning
|
|
||||||
rate. This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions. It is computed as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
global_step = min(global_step, decay_steps)
|
|
||||||
linear_decay = (decay_steps - global_step) / decay_steps)
|
|
||||||
cosine_decay = 0.5 * (
|
|
||||||
1 + cos(pi * 2 * num_periods * global_step / decay_steps))
|
|
||||||
decayed = (alpha + linear_decay) * cosine_decay + beta
|
|
||||||
decayed_learning_rate = learning_rate * decayed
|
|
||||||
```
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
```python
|
|
||||||
decay_steps = 1000
|
|
||||||
lr_decayed_fn = tf.train.linear_cosine_decay(learning_rate, global_step,
|
|
||||||
decay_steps)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
|
||||||
The initial learning rate.
|
|
||||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Global step to use for the decay computation.
|
|
||||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Number of steps to decay over.
|
|
||||||
num_periods: Number of periods in the cosine part of the decay.
|
|
||||||
See computation above.
|
|
||||||
alpha: See computation above.
|
|
||||||
beta: See computation above.
|
|
||||||
name: String. Optional name of the operation. Defaults to
|
|
||||||
'LinearCosineDecay'.
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("linear cosine decay requires global_step")
|
|
||||||
def decayed_lr(learning_rate, global_step, decay_steps, num_periods, alpha,
|
|
||||||
beta, name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(name, "LinearCosineDecay",
|
|
||||||
[learning_rate, global_step]) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
decay_steps = math_ops.cast(decay_steps, dtype)
|
|
||||||
num_periods = math_ops.cast(num_periods, dtype)
|
|
||||||
alpha = math_ops.cast(alpha, dtype)
|
|
||||||
beta = math_ops.cast(beta, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
|
|
||||||
linear_decayed = (decay_steps - global_step_recomp) / decay_steps
|
|
||||||
completed_fraction = global_step_recomp / decay_steps
|
|
||||||
fraction = 2.0 * num_periods * completed_fraction
|
|
||||||
cosine_decayed = 0.5 * (
|
|
||||||
1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
|
|
||||||
|
|
||||||
linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta
|
|
||||||
return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
|
|
||||||
num_periods, alpha, beta, name)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("train.noisy_linear_cosine_decay", v1=[])
|
|
||||||
def noisy_linear_cosine_decay(learning_rate,
|
|
||||||
global_step,
|
|
||||||
decay_steps,
|
|
||||||
initial_variance=1.0,
|
|
||||||
variance_decay=0.55,
|
|
||||||
num_periods=0.5,
|
|
||||||
alpha=0.0,
|
|
||||||
beta=0.001,
|
|
||||||
name=None):
|
|
||||||
"""Applies noisy linear cosine decay to the learning rate.
|
|
||||||
|
|
||||||
See [Bello et al., ICML2017] Neural Optimizer Search with RL.
|
|
||||||
https://arxiv.org/abs/1709.07417
|
|
||||||
|
|
||||||
For the idea of warm starts here controlled by `num_periods`,
|
|
||||||
see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent
|
|
||||||
with Warm Restarts. https://arxiv.org/abs/1608.03983
|
|
||||||
|
|
||||||
Note that linear cosine decay is more aggressive than cosine decay and
|
|
||||||
larger initial learning rates can typically be used.
|
|
||||||
|
|
||||||
When training a model, it is often recommended to lower the learning rate as
|
|
||||||
the training progresses. This function applies a noisy linear
|
|
||||||
cosine decay function to a provided initial learning rate.
|
|
||||||
It requires a `global_step` value to compute the decayed learning rate.
|
|
||||||
You can just pass a TensorFlow variable that you increment at each
|
|
||||||
training step.
|
|
||||||
|
|
||||||
The function returns a no-arg callable that produces the decayed learning
|
|
||||||
rate. This can be useful for changing the learning rate value across
|
|
||||||
different invocations of optimizer functions. It is computed as:
|
|
||||||
|
|
||||||
```python
|
|
||||||
global_step = min(global_step, decay_steps)
|
|
||||||
linear_decay = (decay_steps - global_step) / decay_steps)
|
|
||||||
cosine_decay = 0.5 * (
|
|
||||||
1 + cos(pi * 2 * num_periods * global_step / decay_steps))
|
|
||||||
decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta
|
|
||||||
decayed_learning_rate = learning_rate * decayed
|
|
||||||
```
|
|
||||||
where eps_t is 0-centered gaussian noise with variance
|
|
||||||
initial_variance / (1 + global_step) ** variance_decay
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
```python
|
|
||||||
decay_steps = 1000
|
|
||||||
lr_decayed_fn = tf.train.noisy_linear_cosine_decay(learning_rate, global_step,
|
|
||||||
decay_steps)
|
|
||||||
```
|
|
||||||
|
|
||||||
Args:
|
|
||||||
learning_rate: A scalar `float32` or `float64` Tensor or a Python number.
|
|
||||||
The initial learning rate.
|
|
||||||
global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Global step to use for the decay computation.
|
|
||||||
decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
|
|
||||||
Number of steps to decay over.
|
|
||||||
initial_variance: initial variance for the noise. See computation above.
|
|
||||||
variance_decay: decay for the noise's variance. See computation above.
|
|
||||||
num_periods: Number of periods in the cosine part of the decay.
|
|
||||||
See computation above.
|
|
||||||
alpha: See computation above.
|
|
||||||
beta: See computation above.
|
|
||||||
name: String. Optional name of the operation. Defaults to
|
|
||||||
'NoisyLinearCosineDecay'.
|
|
||||||
Returns:
|
|
||||||
A no-arg function that outputs the decayed learning rate, a scalar `Tensor`
|
|
||||||
of the same type as `learning_rate`.
|
|
||||||
Raises:
|
|
||||||
ValueError: if `global_step` is not supplied.
|
|
||||||
"""
|
|
||||||
if global_step is None:
|
|
||||||
raise ValueError("noisy linear cosine decay requires global_step")
|
|
||||||
def decayed_lr(learning_rate, global_step, decay_steps, initial_variance,
|
|
||||||
variance_decay, num_periods, alpha, beta, name):
|
|
||||||
"""Helper to recompute learning rate; most helpful in eager-mode."""
|
|
||||||
with ops.name_scope(name, "NoisyLinearCosineDecay",
|
|
||||||
[learning_rate, global_step]) as name:
|
|
||||||
learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
|
|
||||||
dtype = learning_rate.dtype
|
|
||||||
decay_steps = math_ops.cast(decay_steps, dtype)
|
|
||||||
initial_variance = math_ops.cast(initial_variance, dtype)
|
|
||||||
variance_decay = math_ops.cast(variance_decay, dtype)
|
|
||||||
num_periods = math_ops.cast(num_periods, dtype)
|
|
||||||
alpha = math_ops.cast(alpha, dtype)
|
|
||||||
beta = math_ops.cast(beta, dtype)
|
|
||||||
|
|
||||||
global_step_recomp = math_ops.cast(global_step, dtype)
|
|
||||||
global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps)
|
|
||||||
linear_decayed = (decay_steps - global_step_recomp) / decay_steps
|
|
||||||
variance = initial_variance / (
|
|
||||||
math_ops.pow(1.0 + global_step_recomp, variance_decay))
|
|
||||||
std = math_ops.sqrt(variance)
|
|
||||||
noisy_linear_decayed = (
|
|
||||||
linear_decayed + random_ops.random_normal(
|
|
||||||
linear_decayed.shape, stddev=std))
|
|
||||||
|
|
||||||
completed_fraction = global_step_recomp / decay_steps
|
|
||||||
fraction = 2.0 * num_periods * completed_fraction
|
|
||||||
cosine_decayed = 0.5 * (
|
|
||||||
1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction))
|
|
||||||
noisy_linear_cosine_decayed = (
|
|
||||||
(alpha + noisy_linear_decayed) * cosine_decayed + beta)
|
|
||||||
|
|
||||||
return math_ops.multiply(
|
|
||||||
learning_rate, noisy_linear_cosine_decayed, name=name)
|
|
||||||
|
|
||||||
return functools.partial(decayed_lr, learning_rate, global_step, decay_steps,
|
|
||||||
initial_variance, variance_decay, num_periods, alpha,
|
|
||||||
beta, name)
|
|
@ -1,497 +0,0 @@
|
|||||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
|
|
||||||
"""Functional test for learning rate decay."""
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
|
||||||
from tensorflow.python.framework import test_util
|
|
||||||
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
|
|
||||||
from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import
|
|
||||||
from tensorflow.python.ops import variables
|
|
||||||
from tensorflow.python.platform import googletest
|
|
||||||
from tensorflow.python.training import learning_rate_decay_v2
|
|
||||||
|
|
||||||
|
|
||||||
class LRDecayTestV2(test_util.TensorFlowTestCase):
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testContinuous(self):
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
step = 5
|
|
||||||
decayed_lr = learning_rate_decay_v2.exponential_decay(0.05, step, 10, 0.96)
|
|
||||||
expected = .05 * 0.96**(5.0 / 10.0)
|
|
||||||
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
|
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testStaircase(self):
|
|
||||||
if context.executing_eagerly():
|
|
||||||
step = resource_variable_ops.ResourceVariable(0)
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
decayed_lr = learning_rate_decay_v2.exponential_decay(
|
|
||||||
.1, step, 3, 0.96, staircase=True)
|
|
||||||
|
|
||||||
# No change to learning rate due to staircase
|
|
||||||
expected = .1
|
|
||||||
self.evaluate(step.assign(1))
|
|
||||||
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
|
|
||||||
|
|
||||||
expected = .1
|
|
||||||
self.evaluate(step.assign(2))
|
|
||||||
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
|
|
||||||
|
|
||||||
# Decayed learning rate
|
|
||||||
expected = .1 * 0.96 ** (100 // 3)
|
|
||||||
self.evaluate(step.assign(100))
|
|
||||||
self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
|
|
||||||
|
|
||||||
  @test_util.run_deprecated_v1
  def testVariables(self):
    step = variables.Variable(1)
    assign_1 = step.assign(1)
    assign_2 = step.assign(2)
    assign_100 = step.assign(100)
    decayed_lr = learning_rate_decay_v2.exponential_decay(
        .1, step, 3, 0.96, staircase=True)
    self.evaluate(variables.global_variables_initializer())
    # No change to learning rate
    self.evaluate(assign_1.op)
    self.assertAllClose(self.evaluate(decayed_lr()), .1, 1e-6)
    self.evaluate(assign_2.op)
    self.assertAllClose(self.evaluate(decayed_lr()), .1, 1e-6)
    # Decayed learning rate
    self.evaluate(assign_100.op)
    expected = .1 * 0.96**(100 // 3)
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testPiecewiseConstant(self):
    x = resource_variable_ops.ResourceVariable(-999)
    decayed_lr = learning_rate_decay_v2.piecewise_constant(
        x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001])

    self.evaluate(variables.global_variables_initializer())

    self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
    self.evaluate(x.assign(100))
    self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6)
    self.evaluate(x.assign(105))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
    self.evaluate(x.assign(110))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6)
    self.evaluate(x.assign(120))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.01, 1e-6)
    self.evaluate(x.assign(999))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.001, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testPiecewiseConstantEdgeCases(self):
    x_int = resource_variable_ops.ResourceVariable(
        0, dtype=variables.dtypes.int32)
    boundaries, values = [-1.0, 1.0], [1, 2, 3]
    with self.assertRaises(ValueError):
      decayed_lr = learning_rate_decay_v2.piecewise_constant(
          x_int, boundaries, values)
      decayed_lr()

    x = resource_variable_ops.ResourceVariable(0.0)
    boundaries, values = [-1.0, 1.0], [1.0, 2, 3]
    with self.assertRaises(ValueError):
      decayed_lr = learning_rate_decay_v2.piecewise_constant(
          x, boundaries, values)()
      decayed_lr()

    # Test that ref types are valid.
    if not context.executing_eagerly():
      x = variables.Variable(0.0)
      x_ref = x.op.outputs[0]  # float32_ref tensor should be accepted
      boundaries, values = [1.0, 2.0], [1, 2, 3]
      learning_rate_decay_v2.piecewise_constant(x_ref, boundaries, values)

    # Test casting boundaries from int32 to int64.
    x_int64 = resource_variable_ops.ResourceVariable(
        0, dtype=variables.dtypes.int64)
    boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]
    decayed_lr = learning_rate_decay_v2.piecewise_constant(
        x_int64, boundaries, values)

    self.evaluate(variables.global_variables_initializer())
    self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
    self.evaluate(x_int64.assign(1))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6)
    self.evaluate(x_int64.assign(2))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.5, 1e-6)
    self.evaluate(x_int64.assign(3))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.6, 1e-6)
    self.evaluate(x_int64.assign(4))
    self.assertAllClose(self.evaluate(decayed_lr()), 0.7, 1e-6)


class LinearDecayTestV2(test_util.TensorFlowTestCase):

  @test_util.run_in_graph_and_eager_modes
  def testHalfWay(self):
    step = 5
    lr = 0.05
    end_lr = 0.0
    decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
    expected = lr * 0.5
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testEnd(self):
    step = 10
    lr = 0.05
    end_lr = 0.001
    decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
    expected = end_lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testHalfWayWithEnd(self):
    step = 5
    lr = 0.05
    end_lr = 0.001
    decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
    expected = (lr + end_lr) * 0.5
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEnd(self):
    step = 15
    lr = 0.05
    end_lr = 0.001
    decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr)
    expected = end_lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEndWithCycle(self):
    step = 15
    lr = 0.05
    end_lr = 0.001
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, cycle=True)
    expected = (lr - end_lr) * 0.25 + end_lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)


class SqrtDecayTestV2(test_util.TensorFlowTestCase):

  @test_util.run_in_graph_and_eager_modes
  def testHalfWay(self):
    step = 5
    lr = 0.05
    end_lr = 0.0
    power = 0.5
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, power=power)
    expected = lr * 0.5**power
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testEnd(self):
    step = 10
    lr = 0.05
    end_lr = 0.001
    power = 0.5
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, power=power)
    expected = end_lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testHalfWayWithEnd(self):
    step = 5
    lr = 0.05
    end_lr = 0.001
    power = 0.5
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, power=power)
    expected = (lr - end_lr) * 0.5**power + end_lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEnd(self):
    step = 15
    lr = 0.05
    end_lr = 0.001
    power = 0.5
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, power=power)
    expected = end_lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testBeyondEndWithCycle(self):
    step = 15
    lr = 0.05
    end_lr = 0.001
    power = 0.5
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, 10, end_lr, power=power, cycle=True)
    expected = (lr - end_lr) * 0.25**power + end_lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)


class PolynomialDecayTestV2(test_util.TensorFlowTestCase):

  @test_util.run_in_graph_and_eager_modes
  def testBeginWithCycle(self):
    lr = 0.001
    decay_steps = 10
    step = 0
    decayed_lr = learning_rate_decay_v2.polynomial_decay(
        lr, step, decay_steps, cycle=True)
    expected = lr
    self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)


class ExponentialDecayTestV2(test_util.TensorFlowTestCase):

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.natural_exp_decay(initial_lr, step, k,
                                                          decay_rate)

    self.evaluate(variables.global_variables_initializer())
    for i in range(k + 1):
      expected = initial_lr * math.exp(-i / k * decay_rate)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
      self.evaluate(step.assign_add(1))

  @test_util.run_in_graph_and_eager_modes
  def testStaircase(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.natural_exp_decay(
        initial_lr, step, k, decay_rate, staircase=True)

    self.evaluate(variables.global_variables_initializer())
    for i in range(k + 1):
      expected = initial_lr * math.exp(-decay_rate * (i // k))
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
      self.evaluate(step.assign_add(1))


class InverseDecayTestV2(test_util.TensorFlowTestCase):

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.inverse_time_decay(initial_lr, step, k,
                                                           decay_rate)

    self.evaluate(variables.global_variables_initializer())
    for i in range(k + 1):
      expected = initial_lr / (1 + i / k * decay_rate)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
      self.evaluate(step.assign_add(1))

  @test_util.run_in_graph_and_eager_modes
  def testStaircase(self):
    initial_lr = 0.1
    k = 10
    decay_rate = 0.96
    step = resource_variable_ops.ResourceVariable(0)
    decayed_lr = learning_rate_decay_v2.inverse_time_decay(
        initial_lr, step, k, decay_rate, staircase=True)

    self.evaluate(variables.global_variables_initializer())
    for i in range(k + 1):
      expected = initial_lr / (1 + decay_rate * (i // k))
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)
      self.evaluate(step.assign_add(1))


class CosineDecayTestV2(test_util.TensorFlowTestCase):

  def np_cosine_decay(self, step, decay_steps, alpha=0.0):
    step = min(step, decay_steps)
    completed_fraction = step / decay_steps
    decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
    return (1.0 - alpha) * decay + alpha

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    num_training_steps = 1000
    initial_lr = 1.0
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step,
                                                       num_training_steps)
      expected = self.np_cosine_decay(step, num_training_steps)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testAlpha(self):
    num_training_steps = 1000
    initial_lr = 1.0
    alpha = 0.1
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step,
                                                       num_training_steps,
                                                       alpha)
      expected = self.np_cosine_decay(step, num_training_steps, alpha)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)


class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase):

  def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0,
                               alpha=0.0):
    fac = 1.0
    while step >= decay_steps:
      step -= decay_steps
      decay_steps *= t_mul
      fac *= m_mul

    completed_fraction = step / decay_steps
    decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))
    return (1.0 - alpha) * decay + alpha

  @test_util.run_in_graph_and_eager_modes
  def testDecay(self):
    num_training_steps = 1000
    initial_lr = 1.0
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          initial_lr, step, num_training_steps)
      expected = self.np_cosine_decay_restarts(step, num_training_steps)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testAlpha(self):
    num_training_steps = 1000
    initial_lr = 1.0
    alpha = 0.1
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          initial_lr, step, num_training_steps, alpha=alpha)
      expected = self.np_cosine_decay_restarts(
          step, num_training_steps, alpha=alpha)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testMMul(self):
    num_training_steps = 1000
    initial_lr = 1.0
    m_mul = 0.9
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          initial_lr, step, num_training_steps, m_mul=m_mul)
      expected = self.np_cosine_decay_restarts(
          step, num_training_steps, m_mul=m_mul)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testTMul(self):
    num_training_steps = 1000
    initial_lr = 1.0
    t_mul = 1.0
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.cosine_decay_restarts(
          initial_lr, step, num_training_steps, t_mul=t_mul)
      expected = self.np_cosine_decay_restarts(
          step, num_training_steps, t_mul=t_mul)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)


class LinearCosineDecayTestV2(test_util.TensorFlowTestCase):

  def np_linear_cosine_decay(self,
                             step,
                             decay_steps,
                             alpha=0.0,
                             beta=0.001,
                             num_periods=0.5):
    step = min(step, decay_steps)
    linear_decayed = float(decay_steps - step) / decay_steps
    fraction = 2.0 * num_periods * step / float(decay_steps)
    cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction))
    return (alpha + linear_decayed) * cosine_decayed + beta

  @test_util.run_in_graph_and_eager_modes
  def testDefaultDecay(self):
    num_training_steps = 1000
    initial_lr = 1.0
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
          initial_lr, step, num_training_steps)
      expected = self.np_linear_cosine_decay(step, num_training_steps)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)

  @test_util.run_in_graph_and_eager_modes
  def testNonDefaultDecay(self):
    num_training_steps = 1000
    initial_lr = 1.0
    for step in range(0, 1500, 250):
      decayed_lr = learning_rate_decay_v2.linear_cosine_decay(
          initial_lr,
          step,
          num_training_steps,
          alpha=0.1,
          beta=1e-4,
          num_periods=5)
      expected = self.np_linear_cosine_decay(
          step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5)
      self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6)


class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase):

  @test_util.run_in_graph_and_eager_modes
  def testDefaultNoisyLinearCosine(self):
    num_training_steps = 1000
    initial_lr = 1.0
    for step in range(0, 1500, 250):
      # No numerical check because of noise
      decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
          initial_lr, step, num_training_steps)
      # Cannot be deterministically tested
      self.evaluate(decayed_lr())

  @test_util.run_in_graph_and_eager_modes
  def testNonDefaultNoisyLinearCosine(self):
    num_training_steps = 1000
    initial_lr = 1.0
    for step in range(0, 1500, 250):
      # No numerical check because of noise
      decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay(
          initial_lr,
          step,
          num_training_steps,
          initial_variance=0.5,
          variance_decay=0.1,
          alpha=0.1,
          beta=1e-4,
          num_periods=5)
      # Cannot be deterministically tested
      self.evaluate(decayed_lr())


if __name__ == "__main__":
  googletest.main()
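The assertions in the test file above compare each decay op against a closed-form reference computed inline. As a plain-Python illustration of that arithmetic (not part of the CL; the helper names here are throwaway), the staircase exponential decay and the cosine-restarts recurrence the tests encode work out like this:

import math

def staircase_exponential_decay(initial_lr, step, decay_steps, decay_rate):
  # Matches `expected = .1 * 0.96 ** (100 // 3)` in testStaircase/testVariables.
  return initial_lr * decay_rate ** (step // decay_steps)

def cosine_decay_restarts(step, first_decay_steps, t_mul=2.0, m_mul=1.0,
                          alpha=0.0):
  # Same recurrence as np_cosine_decay_restarts above: fold `step` into the
  # current restart period, shrinking the amplitude by m_mul at each restart.
  fac, decay_steps = 1.0, first_decay_steps
  while step >= decay_steps:
    step -= decay_steps
    decay_steps *= t_mul
    fac *= m_mul
  decay = fac * 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
  return (1.0 - alpha) * decay + alpha

print(staircase_exponential_decay(0.1, 100, 3, 0.96))  # 0.1 * 0.96 ** 33
print(cosine_decay_restarts(1250, 1000))                # inside the second period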
@ -0,0 +1,18 @@
path: "tensorflow.keras.experimental.CosineDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.CosineDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -1,5 +1,9 @@
 path: "tensorflow.keras.experimental"
 tf_module {
+  member {
+    name: "CosineDecay"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "PeepholeLSTMCell"
     mtype: "<type \'type\'>"
@ -32,6 +32,10 @@ tf_module {
     name: "SGD"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "schedules"
+    mtype: "<type \'module\'>"
+  }
   member_method {
     name: "deserialize"
     argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.ExponentialDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.ExponentialDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.InverseTimeDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.InverseTimeDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,16 @@
path: "tensorflow.keras.optimizers.schedules.LearningRateSchedule"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.PiecewiseConstantDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PiecewiseConstantDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.PolynomialDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,31 @@
path: "tensorflow.keras.optimizers.schedules"
tf_module {
  member {
    name: "ExponentialDecay"
    mtype: "<type \'type\'>"
  }
  member {
    name: "InverseTimeDecay"
    mtype: "<type \'type\'>"
  }
  member {
    name: "LearningRateSchedule"
    mtype: "<type \'type\'>"
  }
  member {
    name: "PiecewiseConstantDecay"
    mtype: "<type \'type\'>"
  }
  member {
    name: "PolynomialDecay"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "deserialize"
    argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "serialize"
    argspec: "args=[\'learning_rate_schedule\'], varargs=None, keywords=None, defaults=None"
  }
}
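Taken together, the goldens above fix the public surface of the new schedules: a checked constructor argspec, a get_config/from_config round-trip, and module-level serialize/deserialize helpers. A minimal usage sketch consistent with those argspecs (illustrative only; the last line assumes, per this CL's description, that a schedule object can be passed as the learning_rate of a Keras v2 optimizer):

import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96,
    staircase=True)

# Serialization round-trip pinned down by the class goldens above.
config = lr_schedule.get_config()
restored = tf.keras.optimizers.schedules.ExponentialDecay.from_config(config)

# Module-level helpers from the schedules module golden.
blob = tf.keras.optimizers.schedules.serialize(lr_schedule)
again = tf.keras.optimizers.schedules.deserialize(blob)

# Intended integration point: a v2 optimizer configured with a schedule.
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule)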
@ -0,0 +1,18 @@
path: "tensorflow.keras.experimental.CosineDecayRestarts"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.CosineDecayRestarts\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'first_decay_steps\', \'t_mul\', \'m_mul\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'2.0\', \'1.0\', \'0.0\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.experimental.CosineDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.CosineDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.experimental.LinearCosineDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LinearCosineDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.experimental.NoisyLinearCosineDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.NoisyLinearCosineDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'initial_variance\', \'variance_decay\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0.55\', \'0.5\', \'0.0\', \'0.001\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -1,5 +1,21 @@
 path: "tensorflow.keras.experimental"
 tf_module {
+  member {
+    name: "CosineDecay"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "CosineDecayRestarts"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "LinearCosineDecay"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "NoisyLinearCosineDecay"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "PeepholeLSTMCell"
     mtype: "<type \'type\'>"
@ -32,6 +32,10 @@ tf_module {
     name: "SGD"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "schedules"
+    mtype: "<type \'module\'>"
+  }
   member_method {
     name: "deserialize"
     argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.ExponentialDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.ExponentialDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.InverseTimeDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.InverseTimeDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,16 @@
path: "tensorflow.keras.optimizers.schedules.LearningRateSchedule"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.PiecewiseConstantDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PiecewiseConstantDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,18 @@
path: "tensorflow.keras.optimizers.schedules.PolynomialDecay"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.PolynomialDecay\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], "
  }
  member_method {
    name: "from_config"
    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_config"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,31 @@
path: "tensorflow.keras.optimizers.schedules"
tf_module {
  member {
    name: "ExponentialDecay"
    mtype: "<type \'type\'>"
  }
  member {
    name: "InverseTimeDecay"
    mtype: "<type \'type\'>"
  }
  member {
    name: "LearningRateSchedule"
    mtype: "<type \'type\'>"
  }
  member {
    name: "PiecewiseConstantDecay"
    mtype: "<type \'type\'>"
  }
  member {
    name: "PolynomialDecay"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "deserialize"
    argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "serialize"
    argspec: "args=[\'learning_rate_schedule\'], varargs=None, keywords=None, defaults=None"
  }
}
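The LearningRateSchedule golden only lists the config methods, but the serialization machinery is also meant to cover user-defined schedules. A sketch of a custom schedule, under the assumption that the base class's runtime contract is a __call__(step) method returning the learning rate (that method is not visible in the golden format, so treat it as an assumption; the class and parameter names here are hypothetical):

import tensorflow as tf

class WarmupThenConstant(tf.keras.optimizers.schedules.LearningRateSchedule):
  """Hypothetical schedule: linear warmup to `peak_lr`, then constant."""

  def __init__(self, peak_lr, warmup_steps):
    self.peak_lr = peak_lr
    self.warmup_steps = warmup_steps

  def __call__(self, step):
    # Assumed contract: called with the current step, returns the LR value.
    step = tf.cast(step, tf.float32)
    warmup = self.peak_lr * step / float(self.warmup_steps)
    return tf.minimum(warmup, self.peak_lr)

  def get_config(self):
    # Same config round-trip the goldens require of the built-in schedules.
    return {"peak_lr": self.peak_lr, "warmup_steps": self.warmup_steps}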
@ -68,34 +68,14 @@ tf_module {
     name: "ServerDef"
     mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
   }
-  member_method {
-    name: "cosine_decay"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
-  }
-  member_method {
-    name: "cosine_decay_restarts"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'first_decay_steps\', \'t_mul\', \'m_mul\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'2.0\', \'1.0\', \'0.0\', \'None\'], "
-  }
-  member_method {
-    name: "exponential_decay"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
-  }
   member_method {
     name: "get_checkpoint_state"
     argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "inverse_time_decay"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
-  }
   member_method {
     name: "latest_checkpoint"
     argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "linear_cosine_decay"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], "
-  }
   member_method {
     name: "list_variables"
     argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None"
@ -108,22 +88,6 @@ tf_module {
     name: "load_variable"
     argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "natural_exp_decay"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
-  }
-  member_method {
-    name: "noisy_linear_cosine_decay"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'initial_variance\', \'variance_decay\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0.55\', \'0.5\', \'0.0\', \'0.001\', \'None\'], "
-  }
-  member_method {
-    name: "piecewise_constant_decay"
-    argspec: "args=[\'x\', \'boundaries\', \'values\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
-  member_method {
-    name: "polynomial_decay"
-    argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], "
-  }
   member_method {
     name: "sdca_fprint"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -693,8 +693,11 @@ renames = {
     'tf.train.batch': 'tf.compat.v1.train.batch',
     'tf.train.batch_join': 'tf.compat.v1.train.batch_join',
     'tf.train.checkpoint_exists': 'tf.compat.v1.train.checkpoint_exists',
+    'tf.train.cosine_decay': 'tf.compat.v1.train.cosine_decay',
+    'tf.train.cosine_decay_restarts': 'tf.compat.v1.train.cosine_decay_restarts',
     'tf.train.create_global_step': 'tf.compat.v1.train.create_global_step',
     'tf.train.do_quantize_training_on_graphdef': 'tf.compat.v1.train.do_quantize_training_on_graphdef',
+    'tf.train.exponential_decay': 'tf.compat.v1.train.exponential_decay',
     'tf.train.export_meta_graph': 'tf.compat.v1.train.export_meta_graph',
     'tf.train.generate_checkpoint_state_proto': 'tf.compat.v1.train.generate_checkpoint_state_proto',
     'tf.train.get_checkpoint_mtimes': 'tf.compat.v1.train.get_checkpoint_mtimes',
@ -704,13 +707,19 @@ renames = {
     'tf.train.import_meta_graph': 'tf.compat.v1.train.import_meta_graph',
     'tf.train.init_from_checkpoint': 'tf.compat.v1.train.init_from_checkpoint',
     'tf.train.input_producer': 'tf.compat.v1.train.input_producer',
+    'tf.train.inverse_time_decay': 'tf.compat.v1.train.inverse_time_decay',
     'tf.train.limit_epochs': 'tf.compat.v1.train.limit_epochs',
+    'tf.train.linear_cosine_decay': 'tf.compat.v1.train.linear_cosine_decay',
     'tf.train.match_filenames_once': 'tf.io.match_filenames_once',
     'tf.train.maybe_batch': 'tf.compat.v1.train.maybe_batch',
     'tf.train.maybe_batch_join': 'tf.compat.v1.train.maybe_batch_join',
     'tf.train.maybe_shuffle_batch': 'tf.compat.v1.train.maybe_shuffle_batch',
     'tf.train.maybe_shuffle_batch_join': 'tf.compat.v1.train.maybe_shuffle_batch_join',
+    'tf.train.natural_exp_decay': 'tf.compat.v1.train.natural_exp_decay',
+    'tf.train.noisy_linear_cosine_decay': 'tf.compat.v1.train.noisy_linear_cosine_decay',
     'tf.train.piecewise_constant': 'tf.compat.v1.train.piecewise_constant',
+    'tf.train.piecewise_constant_decay': 'tf.compat.v1.train.piecewise_constant_decay',
+    'tf.train.polynomial_decay': 'tf.compat.v1.train.polynomial_decay',
     'tf.train.queue_runner.QueueRunner': 'tf.compat.v1.train.queue_runner.QueueRunner',
     'tf.train.queue_runner.add_queue_runner': 'tf.compat.v1.train.queue_runner.add_queue_runner',
     'tf.train.queue_runner.start_queue_runners': 'tf.compat.v1.train.queue_runner.start_queue_runners',
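For reference, the effect of the rename entries above on a 1.x call site, with the 2.x-native replacement being the schedule classes added earlier in this CL (the concrete arguments below are made up for illustration):

# TF 1.x code:
#   lr = tf.train.exponential_decay(0.1, global_step, 1000, 0.96, staircase=True)
# After tf_upgrade_v2 applies the renames above:
#   lr = tf.compat.v1.train.exponential_decay(0.1, global_step, 1000, 0.96,
#                                              staircase=True)
# Idiomatic TF 2.x replacement using the new serializable schedule:
#   lr = tf.keras.optimizers.schedules.ExponentialDecay(
#       initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.96,
#       staircase=True)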