Merge pull request #32482 from reedwm/cherrypick_lso_saving
[r2.0-CherryPick]: Have LossScaleOptimizer better emulate the OptimizerV2 interface.
Commit 25d8dda8bb
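In short, the change gives `LossScaleOptimizer` the missing pieces of the `OptimizerV2` surface: the weight accessors (`weights`, `get_weights`, `set_weights`, `variables`, `get_slot_names`) now delegate to the wrapped optimizer, `get_slot`/`add_slot` raise explicit errors, and `get_config`/`from_config` are implemented so the wrapper can be serialized and restored. A minimal sketch of the resulting config round trip (not part of the diff; internal module paths are taken from the files touched below):

```python
# Sketch only: config round trip enabled by the new get_config()/from_config().
# Module paths mirror the internal imports used in this diff.
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.optimizer_v2 import gradient_descent

# Wrap SGD with dynamic loss scaling (the string 'dynamic' is accepted per the
# constructor docstring in loss_scale_optimizer.py).
opt = loss_scale_optimizer.LossScaleOptimizer(
    gradient_descent.SGD(0.1, momentum=0.9), 'dynamic')

config = opt.get_config()  # {'optimizer': {...}, 'loss_scale': {...}}
restored = loss_scale_optimizer.LossScaleOptimizer.from_config(config)
```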
@@ -97,11 +97,22 @@ py_test(
     ],
 )
 
+py_library(
+    name = "loss_scale",
+    srcs = ["loss_scale.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:loss_scale",
+        "//tensorflow/python/keras:generic_utils",
+    ],
+)
+
 py_library(
     name = "loss_scale_optimizer",
     srcs = ["loss_scale_optimizer.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":loss_scale",
        "//tensorflow/python:loss_scale",
        "//tensorflow/python/keras/optimizer_v2",
        "@absl_py//absl/testing:parameterized",
@@ -37,6 +37,7 @@ from tensorflow.python.keras import layers
 from tensorflow.python.keras import models
 from tensorflow.python.keras import optimizers
 from tensorflow.python.keras import regularizers
+from tensorflow.python.keras import saving
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
@@ -59,7 +60,8 @@ class AssertTypeLayer(base_layer.Layer):
   """A layer which asserts it's inputs are a certain type."""
 
   def __init__(self, assert_type=None, **kwargs):
-    self._assert_type = assert_type
+    self._assert_type = (dtypes.as_dtype(assert_type).name if assert_type
+                         else None)
     super(AssertTypeLayer, self).__init__(**kwargs)
 
   def assert_input_types(self, inputs):
@@ -112,6 +114,15 @@ class AddLayer(AssertTypeLayer):
     else:
       return math_ops.add(x, y)
 
+  def get_config(self):
+    config = super(AddLayer, self).get_config()
+    assert self._regularizer is None, (
+        'regularizer must be None to get config for AddLayer')
+    config['use_operator'] = self._use_operator
+    config['var_name'] = self._var_name
+    config['assert_type'] = self._assert_type
+    return config
+
 
 class AddLayerWithoutAutoCast(AddLayer):
   """Same as AddLayer, but does not use AutoCastVariables."""
@@ -911,6 +922,78 @@ class KerasModelTest(keras_parameterized.TestCase):
     self.assertEqual(backend.get_value(loss_scale()), 2)
     self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)
 
+  @keras_parameterized.run_all_keras_modes
+  @parameterized.named_parameters(
+      {
+          'testcase_name': 'base',
+          'strategy_fn': default_strategy_fn,
+      }, {
+          'testcase_name': 'distribute',
+          'strategy_fn': create_mirrored_strategy,
+      }, {
+          'testcase_name': 'base_h5',
+          'strategy_fn': default_strategy_fn,
+          'h5': True,
+      }, {
+          'testcase_name': 'distribute_h5',
+          'strategy_fn': create_mirrored_strategy,
+          'h5': True,
+      })
+  def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False):
+    if not self._is_strategy_supported(strategy_fn):
+      return
+    strategy = strategy_fn()
+    if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
+        not context.executing_eagerly()):
+      # TODO(b/121381184): Enable running the test in this case.
+      return
+
+    # Create and run model.
+    with strategy.scope():
+      x = layers.Input(shape=(2,), batch_size=2, dtype=dtypes.float32)
+      y = AddLayer()(x)
+      model = models.Model(inputs=x, outputs=y)
+
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=1., increment_period=2., multiplier=2.)
+      opt = gradient_descent.SGD(1.)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      model.compile(
+          optimizer=opt,
+          loss='mse',
+          run_eagerly=testing_utils.should_run_eagerly(),
+          experimental_run_tf_function=testing_utils.should_run_tf_function())
+    # Run for 3 steps (6 examples with a batch size of 2)
+    model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2)
+    self.assertEqual(backend.get_value(loss_scale()), 2)
+    self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)
+    (weight,) = model.trainable_weights
+    orig_weight = backend.get_value(weight)
+
+    # Save model weights.
+    save_path = os.path.join(self.get_temp_dir(), 'model')
+    model.save(save_path, save_format='h5' if h5 else 'tf')
+
+    # Run model again for 1 step (2 examples with a batch size of 2)
+    model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2)
+    new_weight = backend.get_value(weight)
+    self.assertNotEqual(new_weight, orig_weight)
+    self.assertEqual(backend.get_value(loss_scale()), 4)
+    self.assertEqual(backend.get_value(loss_scale._num_good_steps), 0)
+
+    # Load model weights and ensure loss scale weights are restored.
+    model = saving.load_model(save_path, custom_objects={'AddLayer': AddLayer})
+    loss_scale = model.optimizer.loss_scale
+    (weight,) = model.trainable_weights
+    loaded_weight = backend.get_value(weight)
+    self.assertEqual(loaded_weight, orig_weight)
+    # Currently the loss scale isn't always saved when the model is saved with
+    # Model.save(). So we assert the loss scale either has the value when it was
+    # saved, or the value it was initialized with.
+    # TODO(reedwm): Always save/restore the loss scale with Model.save().
+    self.assertIn(backend.get_value(loss_scale()), (1, 2))
+    self.assertIn(backend.get_value(loss_scale._num_good_steps), (0, 1))
+
 
 class RnnTest(keras_parameterized.TestCase):
   """Test mixed precision with RNNs."""
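Outside the test harness, the save/load flow the new test exercises boils down to the sketch below. The model and paths are hypothetical, and it assumes the public TF 2.0 `tf.keras` aliases mirror the internal modules used in the diff; note the caveat from the test's comments that the loss scale value itself is not always restored by `Model.save()`.

```python
# Sketch: saving and reloading a Keras model compiled with a LossScaleOptimizer,
# following the pattern of test_save_model_with_dynamic_loss_scaling above.
import numpy as np
import tensorflow as tf  # assumes the public TF 2.0 API

opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(1.), loss_scale='dynamic')
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])
model.compile(optimizer=opt, loss='mse')
model.fit(np.zeros((6, 2)), np.zeros((6, 1)), batch_size=2)

model.save('/tmp/model', save_format='tf')            # or save_format='h5'
reloaded = tf.keras.models.load_model('/tmp/model')   # restores the wrapped optimizer
```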
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains keras-specific LossScale functionality.
+
+This functions cannot be in the non-keras loss_scale.py file since they depend
+on keras, and files outside of keras should not depend on files inside keras.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.training.experimental import loss_scale as loss_scale_module
+
+
+def serialize(loss_scale):
+  return generic_utils.serialize_keras_object(loss_scale)
+
+
+def deserialize(config, custom_objects=None):
+  loss_scale_module_objects = {
+      'FixedLossScale': loss_scale_module.FixedLossScale,
+      'DynamicLossScale': loss_scale_module.DynamicLossScale,
+  }
+
+  return generic_utils.deserialize_keras_object(
+      config,
+      module_objects=loss_scale_module_objects,
+      custom_objects=custom_objects,
+      printable_module_name='loss scale'
+  )
+
+
+def get(identifier):
+  if isinstance(identifier, dict):
+    return deserialize(identifier)
+  return loss_scale_module.get(identifier)
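The new helpers are thin wrappers around `serialize_keras_object`/`deserialize_keras_object`. A short usage sketch (assuming the imports shown in the file above):

```python
# Sketch of the Keras-side loss-scale (de)serialization helpers added above.
from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
from tensorflow.python.training.experimental import loss_scale as loss_scale_module

ls = loss_scale_module.DynamicLossScale(
    initial_loss_scale=2., increment_period=3., multiplier=4.)
config = keras_loss_scale_module.serialize(ls)
# config is a dict of the form {'class_name': 'DynamicLossScale', 'config': {...}}
restored = keras_loss_scale_module.deserialize(config)

# get() accepts a dict, an int/float, the string 'dynamic', or a LossScale.
fixed = keras_loss_scale_module.get(128)
```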
@@ -20,6 +20,8 @@ from __future__ import print_function
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.keras import backend
+from tensorflow.python.keras import optimizers
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
@@ -99,11 +101,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   ```
   """
 
-  def __init__(self, opt, loss_scale):
+  def __init__(self, optimizer, loss_scale):
     """Initializes this loss scale optimizer.
 
     Args:
-      opt: The Optimizer instance to wrap.
+      optimizer: The Optimizer instance to wrap.
       loss_scale: The loss scale to scale the loss and gradients. This can
         either be an int/float to use a fixed loss scale, the string "dynamic"
        to use dynamic loss scaling, or an instance of a LossScale. The string
@@ -111,21 +113,21 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
         int/float is equivalent to passing a FixedLossScale with the given loss
         scale.
     """
-    if not isinstance(opt, optimizer_v2.OptimizerV2):
-      raise ValueError('"opt" must be an instance of OptimizerV2, but got: %s'
-                       % opt)
-    if hasattr(opt, 'clipnorm'):
+    if not isinstance(optimizer, optimizer_v2.OptimizerV2):
+      raise ValueError('"optimizer" must be an instance of OptimizerV2, but '
+                       'got: %s' % optimizer)
+    if hasattr(optimizer, 'clipnorm'):
       raise ValueError('LossScaleOptimizer does not support wrapping '
                        'optimizers with a clipnorm. Optimizer %s has clipnorm '
-                       '%s' % (opt, opt.clipnorm))
+                       '%s' % (optimizer, optimizer.clipnorm))
 
-    if hasattr(opt, 'clipvalue'):
+    if hasattr(optimizer, 'clipvalue'):
       raise ValueError('LossScaleOptimizer does not support wrapping '
                        'optimizers with a clipvalue. Optimizer %s has '
-                       'clipvalue %s' % (opt, opt.clipvalue))
+                       'clipvalue %s' % (optimizer, optimizer.clipvalue))
 
-    self._optimizer = opt
-    self._loss_scale = loss_scale_module.get(loss_scale)
+    self._optimizer = optimizer
+    self._loss_scale = keras_loss_scale_module.get(loss_scale)
     for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
       # We cannot call `track_variable` in the LossScale class itself, because a
       # file outside of Keras cannot depend on a Keras file. Calling it here
@@ -243,6 +245,26 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
     return self._optimizer.apply_gradients(list(zip(grads, wrapped_vars.value)),
                                            name)
 
+  def get_config(self):
+    serialized_optimizer = optimizers.serialize(self._optimizer)
+    serialized_loss_scale = keras_loss_scale_module.serialize(self._loss_scale)
+    return {
+        'optimizer': serialized_optimizer,
+        'loss_scale': serialized_loss_scale,
+    }
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    config = config.copy()  # Make a copy, since we mutate config
+    config['optimizer'] = optimizers.deserialize(
+        config['optimizer'], custom_objects=custom_objects)
+    config['loss_scale'] = keras_loss_scale_module.deserialize(
+        config['loss_scale'], custom_objects=custom_objects)
+    return cls(**config)
+
+  # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
+  # below.
+
   @property
   def iterations(self):
     return self._optimizer.iterations
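Because `from_config` routes the nested 'optimizer' config through `optimizers.deserialize`, a custom wrapped optimizer class has to be supplied via `custom_objects`, exactly as the new `testSerializationWithCustomOptimizer` further down does. A hedged sketch (MySGD is a hypothetical subclass, not part of the diff):

```python
# Sketch: restoring a LossScaleOptimizer that wraps a custom optimizer subclass.
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.optimizer_v2 import gradient_descent


class MySGD(gradient_descent.SGD):
  pass  # hypothetical custom optimizer


opt = loss_scale_optimizer.LossScaleOptimizer(MySGD(2.), 'dynamic')
config = optimizers.serialize(opt)
# custom_objects makes 'MySGD' resolvable by name during deserialization.
restored = optimizers.deserialize(config, custom_objects={'MySGD': MySGD})
```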
@@ -251,6 +273,22 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   def iterations(self, variable):
     self._optimizer.iterations = variable
 
+  def get_slot_names(self):
+    return self._optimizer.get_slot_names()
+
+  def variables(self):
+    return self._optimizer.variables()
+
+  @property
+  def weights(self):
+    return self._optimizer.weights
+
+  def get_weights(self):
+    return self._optimizer.get_weights()
+
+  def set_weights(self, weights):
+    return self._optimizer.set_weights(weights)
+
   # For the most part, we only expose methods in the base OptimizerV2, not
   # individual subclasses like Adam. However, although "learning_rate" and "lr"
   # properties are not part of the base OptimizerV2 class, they are part of most
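These accessors simply forward to the wrapped optimizer, so the wrapper exposes the same weight state as the inner optimizer. A small sketch of what that looks like in eager mode (values are illustrative, mirroring the new `testWeightMethods` below):

```python
# Sketch: the delegated weight accessors surface the wrapped optimizer's state.
import numpy as np
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.ops import variables

var = variables.Variable([1.0])
opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(1.0), 'dynamic')
opt.minimize(lambda: var * 2, [var])  # creates the inner optimizer's variables (eager mode)

print(opt.get_slot_names())   # e.g. [] for plain SGD without momentum
print(opt.weights)            # the wrapped optimizer's 'iterations' variable
opt.set_weights([np.array(2.)])  # writes through to the wrapped optimizer
```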
@@ -272,22 +310,35 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   def lr(self, lr):
     self._optimizer.lr = lr
 
-  def get_slot_names(self):
-    """A list of names for this optimizer's slots."""
-    return self._optimizer.get_slot_names()
+  def get_slot(self, var, slot_name):
+    # We cannot implement get_slot for the following reason: When saving a
+    # checkpoint, two optimizers cannot share slot variables. Since both the
+    # LossScaleOptimizer and the wrapped optimizer (self and self._optimizer
+    # respectively) are checkpointed, we cannot expose the wrapped optimizer's
+    # slots in the LossScaleOptimizer. Otherwise, a checkpoint would believe
+    # both optimizers share slot variables.
+    raise AttributeError(
+        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
+        'will be removed in the future.')
+
+  def add_slot(self, var, slot_name, initializer='zeros'):
+    # We disallow adding a slot for consistency with `get_slot`.
+    raise AttributeError(
+        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
+        'will be removed in the future.')
+
+  # We do not override some OptimizerV2 methods. For each, we describe why we do
+  # not delegate them to self._optimizer:
+  # * get_updates: get_updates() calls get_gradients(). Since we override
+  #   get_gradients(), we cannot delegate get_updates() to self._optimizer,
+  #   otherwise the overridden get_gradients() method would not be called.
+  #   Luckily, get_updates() does not access any OptimizerV2 fields, so
+  #   inheriting the OptimizerV2 version works fine.
+  # * minimize: We don't delegate for a similar as get_updates(): it calls
+  #   both self._compute_gradients() and self.apply_gradients(), and both need
+  #   to have the LossScaleOptimizer version called.
 
   # TODO(reedwm): Maybe merge this class's functionality into OptimizerV2.
 
   # TODO(reedwm): Maybe throw an error if mixed precision is used without this
   # optimizer being used.
-
-  # TODO(reedwm): Implement get_config and from_config. This will first require
-  # implementing deserialization support for OptimizerV2.
-  def get_config(self):
-    raise NotImplementedError('get_config() is not yet implemented for '
-                              'LossScaleOptimizers')
-
-  @classmethod
-  def from_config(cls, config, custom_objects=None):
-    raise NotImplementedError('from_config() is not yet implemented for '
-                              'LossScaleOptimizers')
@@ -21,12 +21,14 @@ from __future__ import print_function
 import os
 
 from absl.testing import parameterized
+import numpy as np
 
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras import optimizers
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
 from tensorflow.python.keras.optimizer_v2 import adam
@@ -230,6 +232,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       self.assertEqual(self.evaluate(opt.loss_scale()),
                        initial_loss_scale * 16)
 
+      self.assertEqual(opt.get_slot_names(), ['momentum'])
+
   @test_util.run_in_graph_and_eager_modes
   def testIterations(self):
     opt = gradient_descent.SGD(2.0)
||||
def testIterations(self):
|
||||
opt = gradient_descent.SGD(2.0)
|
||||
@ -238,6 +242,41 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(lso.iterations, 7)
|
||||
self.assertEqual(opt.iterations, 7)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testWeightMethods(self):
|
||||
var = variables.Variable([1.0])
|
||||
opt = gradient_descent.SGD(1.0)
|
||||
initial_loss_scale = 2.
|
||||
loss_scale = loss_scale_module.DynamicLossScale(
|
||||
initial_loss_scale=initial_loss_scale, increment_period=1,
|
||||
multiplier=4)
|
||||
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
|
||||
run_op = opt.minimize(lambda: var * 2, [var])
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self._run_if_in_graph_mode(run_op)
|
||||
|
||||
self.assertLen(opt.weights, 1) # The 'iterations' weight
|
||||
self.assertEqual(self.evaluate(opt.weights[0]), 1)
|
||||
self.assertEqual(opt.get_weights()[0], 1)
|
||||
self.assertEqual(self.evaluate(opt.variables()[0]), 1)
|
||||
opt.set_weights([np.array(2.)])
|
||||
self.assertEqual(self.evaluate(opt.variables()[0]), 2)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testSlotMethodErrors(self):
|
||||
opt = gradient_descent.SGD(1.0, momentum=1.0)
|
||||
opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
|
||||
with self.assertRaisesRegexp(
|
||||
AttributeError,
|
||||
'You cannot call get_slot on a LossScaleOptimizer. This limitation '
|
||||
'will be removed in the future.'):
|
||||
opt.get_slot(None, None)
|
||||
with self.assertRaisesRegexp(
|
||||
AttributeError,
|
||||
'You cannot call add_slot on a LossScaleOptimizer. This limitation '
|
||||
'will be removed in the future.'):
|
||||
opt.add_slot(None, None)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testGettingAndSettingLearningRate(self, strategy_fn):
|
||||
@@ -351,6 +390,71 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
       self.assertAlmostEqual(self.evaluate(slot_var).item(), slot_value)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testGetConfig(self):
+    opt = gradient_descent.SGD(2., momentum=0.5)
+    loss_scale = loss_scale_module.DynamicLossScale(
+        initial_loss_scale=2., increment_period=3.,
+        multiplier=4.)
+    opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+    config = opt.get_config()
+    opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config)
+    # Force hyperparameters to be created
+    opt.lr  # pylint: disable=pointless-statement
+    self.evaluate(variables.global_variables_initializer())
+
+    self.assertEqual(self.evaluate(opt.lr), 2.)
+    self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5)
+    self.assertEqual(self.evaluate(opt.loss_scale()), 2.)
+    self.assertEqual(opt.loss_scale.increment_period, 3.)
+    self.assertEqual(opt.loss_scale.multiplier, 4.)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testSerializationWithBuiltInOptimizer(self):
+    opt = gradient_descent.SGD(2., momentum=0.5)
+    loss_scale = loss_scale_module.DynamicLossScale(
+        initial_loss_scale=2., increment_period=3.,
+        multiplier=4.)
+    opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+    config = optimizers.serialize(opt)
+    opt = optimizers.deserialize(config)
+    # Force hyperparameters to be created
+    opt.lr  # pylint: disable=pointless-statement
+    self.evaluate(variables.global_variables_initializer())
+
+    self.assertEqual(self.evaluate(opt.lr), 2.)
+    self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5)
+    self.assertEqual(self.evaluate(opt.loss_scale()), 2.)
+    self.assertEqual(opt.loss_scale.increment_period, 3.)
+    self.assertEqual(opt.loss_scale.multiplier, 4.)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testSerializationWithCustomOptimizer(self):
+    class MySGD(gradient_descent.SGD):
+
+      def __init__(self, *args, **kwargs):
+        super(MySGD, self).__init__(*args, **kwargs)
+        self.my_attribute = 123
+
+    opt = MySGD(2., momentum=0.5)
+    loss_scale = loss_scale_module.DynamicLossScale(
+        initial_loss_scale=2., increment_period=3.,
+        multiplier=4.)
+    opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+    config = optimizers.serialize(opt)
+    custom_objects = {'MySGD': MySGD}
+    opt = optimizers.deserialize(config, custom_objects=custom_objects)
+    # Force hyperparameters to be created
+    opt.lr  # pylint: disable=pointless-statement
+    self.evaluate(variables.global_variables_initializer())
+
+    self.assertEqual(self.evaluate(opt.lr), 2.)
+    self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5)
+    self.assertEqual(self.evaluate(opt.loss_scale()), 2.)
+    self.assertEqual(opt.loss_scale.increment_period, 3.)
+    self.assertEqual(opt.loss_scale.multiplier, 4.)
+    self.assertEqual(opt._optimizer.my_attribute, 123)
+
 
 if __name__ == '__main__':
   test.main()
@@ -794,6 +794,7 @@ def deserialize(config, custom_objects=None):
   Returns:
       A Keras Optimizer instance.
   """
+  from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer  # pylint: disable=g-import-not-at-top
   all_classes = {
       'adadelta': adadelta_v2.Adadelta,
       'adagrad': adagrad_v2.Adagrad,
@@ -802,7 +803,8 @@ def deserialize(config, custom_objects=None):
       'nadam': nadam_v2.Nadam,
       'rmsprop': rmsprop_v2.RMSprop,
       'sgd': gradient_descent_v2.SGD,
-      'ftrl': ftrl.Ftrl
+      'ftrl': ftrl.Ftrl,
+      'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
   }
 
   # Make deserialization case-insensitive for built-in optimizers.
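With 'lossscaleoptimizer' registered here, `keras.optimizers.deserialize` can resolve the class name emitted by `optimizers.serialize`, which is what allows a model compiled with the wrapper to be restored. A sketch mirroring the new `testSerializationWithBuiltInOptimizer` (internal module paths as in the diff):

```python
# Sketch: the registry entry lets keras.optimizers round-trip the wrapper by name.
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.optimizer_v2 import gradient_descent

opt = loss_scale_optimizer.LossScaleOptimizer(
    gradient_descent.SGD(2., momentum=0.5), 'dynamic')
config = optimizers.serialize(opt)  # {'class_name': 'LossScaleOptimizer', 'config': {...}}
restored = optimizers.deserialize(config)
assert isinstance(restored, loss_scale_optimizer.LossScaleOptimizer)
```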
@@ -26,7 +26,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'optimizer\', \'loss_scale\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "add_slot"
@@ -26,7 +26,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'optimizer\', \'loss_scale\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "add_slot"