Change LossScaleOptimizer checkpoint format.
The checkpoint format is now identical to the format used when a LossScaleOptimizer is not used, except that the loss scale is additionally saved when a LossScaleOptimizer is used. This allows saving a checkpoint with a LossScaleOptimizer and restoring it without one, and vice versa. Checkpoints with LossScaleOptimizers created in older versions of TensorFlow can still be loaded. Newly saved checkpoints will use the new format.

PiperOrigin-RevId: 306511555
Change-Id: Ie316ab8c4fbfec7babd6f7803d337799d0ff10a5
Parent: 4e8098235c
Commit: f9e99f4dca
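For orientation, here is a minimal sketch of the round trip this change enables (not part of the commit; the checkpoint path and layer sizes are illustrative): a checkpoint written with a plain optimizer can be restored into a model whose optimizer is wrapped in a LossScaleOptimizer of the same underlying type, and vice versa.

import numpy as np
import tensorflow as tf

# Train and checkpoint a model that uses a plain SGD optimizer.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
model.compile(opt, 'mse')
x = np.ones((10, 2))
model.fit(x, x * 100)
ckpt = tf.train.Checkpoint(model=model, optimizer=opt)
save_path = ckpt.save('/tmp/lso_demo_ckpt')  # illustrative prefix

# Restore the same checkpoint into a model whose optimizer is wrapped in a
# LossScaleOptimizer. With the new format the optimizer variables and slots
# line up; only the loss-scale variables are absent from the checkpoint and
# therefore keep their initial values.
model2 = tf.keras.Sequential([tf.keras.layers.Dense(2)])
opt2 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.1, momentum=0.1), 'dynamic')
model2.compile(opt2, 'mse')
ckpt2 = tf.train.Checkpoint(model=model2, optimizer=opt2)
ckpt2.restore(save_path)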
@@ -224,9 +224,16 @@ cuda_py_test(
     name = "keras_test",
     size = "medium",
     srcs = ["keras_test.py"],
+    data = [
+        "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_ckpt_tf2.2",
+        "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_savedmodel_tf2.2",
+    ],
     python_version = "PY3",
     shard_count = 10,
-    tags = ["no_windows"],  # b/139083295: bfloat16 tests fail on Windows
+    tags = [
+        "no_pip",
+        "no_windows",  # b/139083295: bfloat16 tests fail on Windows
+    ],
     deps = [
         ":test_util",
         "//tensorflow/python:client_testlib",
@@ -41,6 +41,7 @@ from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
 from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.engine import input_spec
+from tensorflow.python.keras.engine import sequential
 from tensorflow.python.keras.layers import core
 from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
@@ -993,6 +994,56 @@ class KerasModelTest(keras_parameterized.TestCase):
     self.assertEqual(backend.get_value(loss_scale()), 2)
     self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)
 
+  @keras_parameterized.run_all_keras_modes
+  def test_restore_old_loss_scale_checkpoint(self):
+    # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
+    # of LossScaleOptimizer changed, but old checkpoints can still be loaded
+    opt = gradient_descent.SGD(0.1, momentum=0.1)
+    opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
+    model = sequential.Sequential([core.Dense(2,)])
+
+    # The checkpoint and expected values were obtained from the program in
+    # testdata/BUILD.
+    ckpt_dir = test.test_src_dir_path(
+        'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2')
+    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
+    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    model(np.zeros((2, 2)))  # Create model weights
+    opt._create_all_weights(model.weights)
+    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
+    expected_slot = np.array([[10.049943, 9.917691], [10.049943, 9.917691]])
+    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
+    self.assertAllClose(
+        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+        expected_slot)
+    self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
+    self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
+
+    # Check restoring works even after the model is compiled and the weights
+    # have been created.
+    model.fit(np.random.normal(size=(2, 2)), np.random.normal(size=(2, 2)))
+    self.assertNotAllClose(self.evaluate(model.weights[0]), expected_kernel)
+    self.assertNotAllClose(
+        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+        expected_slot)
+    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
+    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
+    self.assertAllClose(
+        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
+        expected_slot)
+    self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
+    self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)
+
+  def test_restore_old_saved_model(self):
+    saved_model_dir = test.test_src_dir_path(
+        'python/keras/mixed_precision/experimental/testdata/'
+        'lso_savedmodel_tf2.2')
+    model = save.load_model(saved_model_dir)
+    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
+    self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
+    self.assertIsInstance(model.optimizer,
+                          loss_scale_optimizer.LossScaleOptimizer)
+
   @keras_parameterized.run_all_keras_modes
   @parameterized.named_parameters(
       {
@@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import one_device_strategy
 from tensorflow.python.distribute import tpu_strategy
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.keras import backend
@@ -32,6 +33,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision
+from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -51,8 +53,126 @@ class _UnwrapPreventer(object):
     self.value = value
 
 
+class _DelegatingTrackableMixin(object):
+  """A mixin that delegates all Trackable methods to another trackable object.
+
+  This class must be used with multiple inheritance. A class that subclasses
+  Trackable can also subclass this class, which causes all Trackable methods to
+  be delegated to the trackable object passed in the constructor.
+
+  A subclass can use this mixin to appear as if it were the trackable passed to
+  the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
+  mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
+  the checkpoint format for a normal optimizer. This allows a model to be saved
+  with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
+  The only difference in checkpoint format is that the loss scale is also saved
+  with a LossScaleOptimizer.
+  """
+
+  def __init__(self, trackable_obj):
+    self._trackable = trackable_obj
+
+  # pylint: disable=protected-access
+  @property
+  def _setattr_tracking(self):
+    return self._trackable._setattr_tracking
+
+  @_setattr_tracking.setter
+  def _setattr_tracking(self, value):
+    self._trackable._setattr_tracking = value
+
+  @property
+  def _update_uid(self):
+    return self._trackable._update_uid
+
+  @_update_uid.setter
+  def _update_uid(self, value):
+    self._trackable._update_uid = value
+
+  @property
+  def _unconditional_checkpoint_dependencies(self):
+    return self._trackable._unconditional_checkpoint_dependencies
+
+  @property
+  def _unconditional_dependency_names(self):
+    return self._trackable._unconditional_dependency_names
+
+  @property
+  def _name_based_restores(self):
+    return self._trackable._name_based_restores
+
+  def _maybe_initialize_trackable(self):
+    return self._trackable._maybe_initialize_trackable()
+
+  @property
+  def _object_identifier(self):
+    return self._trackable._object_identifier
+
+  @property
+  def _tracking_metadata(self):
+    return self._trackable._tracking_metadata
+
+  def _no_dependency(self, value):
+    return self._trackable._no_dependency(value)
+
+  def _name_based_attribute_restore(self, checkpoint):
+    return self._trackable._name_based_attribute_restore(checkpoint)
+
+  @property
+  def _checkpoint_dependencies(self):
+    return self._trackable._checkpoint_dependencies
+
+  @property
+  def _deferred_dependencies(self):
+    return self._trackable._deferred_dependencies
+
+  def _lookup_dependency(self, name):
+    return self._trackable._lookup_dependency(name)
+
+  def _add_variable_with_custom_getter(self,
+                                       name,
+                                       shape=None,
+                                       dtype=dtypes.float32,
+                                       initializer=None,
+                                       getter=None,
+                                       overwrite=False,
+                                       **kwargs_for_getter):
+    return self._trackable._add_variable_with_custom_getter(
+        name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)
+
+  def _preload_simple_restoration(self, name, shape):
+    return self._trackable._preload_simple_restoration(name, shape)
+
+  def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
+    return self._trackable._track_trackable(trackable, name, overwrite)
+
+  def _handle_deferred_dependencies(self, name, trackable):  # pylint: disable=redefined-outer-name
+    return self._trackable._handle_deferred_dependencies(name, trackable)
+
+  def _restore_from_checkpoint_position(self, checkpoint_position):
+    return self._trackable._restore_from_checkpoint_position(
+        checkpoint_position)
+
+  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
+                                                   visit_queue):
+    return self._trackable._single_restoration_from_checkpoint_position(
+        checkpoint_position, visit_queue)
+
+  def _gather_saveables_for_checkpoint(self):
+    return self._trackable._gather_saveables_for_checkpoint()
+
+  def _list_extra_dependencies_for_serialization(self, serialization_cache):
+    return self._trackable._list_extra_dependencies_for_serialization(
+        serialization_cache)
+
+  def _list_functions_for_serialization(self, serialization_cache):
+    return self._trackable._list_functions_for_serialization(
+        serialization_cache)
+  # pylint: enable=protected-access
+
+
 @keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
-class LossScaleOptimizer(optimizer_v2.OptimizerV2):
+class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
   """An optimizer that applies loss scaling.
 
   Loss scaling is a process that multiplies the loss by a multiplier called the
@@ -144,6 +264,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
     self._loss_scale = keras_loss_scale_module.get(loss_scale)
     if self._loss_scale is None:
       raise ValueError('loss_scale cannot be None.')
+
+    # We don't call super().__init__, since we do not want to call OptimizerV2's
+    # constructor.
+    _DelegatingTrackableMixin.__init__(self, self._optimizer)
+
     for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
       # We cannot call `track_variable` in the LossScale class itself, because a
       # file outside of Keras cannot depend on a Keras file. Calling it here
@@ -151,12 +276,15 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
       # a Keras class, and the only way to use LossScale with a Keras class is
       # through the LossScaleOptimizer.
      backend.track_variable(weight)
-    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')
 
    # Needed because the superclass's __getattribute__ checks this.
    self._hyper = {}
 
+    # To support restoring TensorFlow 2.2 checkpoints.
+    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
+                          'base_optimizer')
+
   @property
   def loss_scale(self):
     """The `LossScale` instance associated with this optimizer."""
@@ -348,6 +476,21 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   def _aggregate_gradients(self, grads_and_vars):
     return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access
 
+  def _restore_slot_variable(self, slot_name, variable, slot_variable):
+    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
+                                                  slot_variable)
+
+  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
+                                       variable):
+    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
+        slot_variable_position, slot_name, variable)
+
+  def get_slot(self, var, slot_name):
+    return self._optimizer.get_slot(var, slot_name)
+
+  def add_slot(self, var, slot_name, initializer='zeros'):
+    return self._optimizer.add_slot(var, slot_name, initializer)
+
   # For the most part, we only expose methods in the base OptimizerV2, not
   # individual subclasses like Adam. However, although "learning_rate" and "lr"
   # properties are not part of the base OptimizerV2 class, they are part of most
@@ -369,23 +512,6 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   def lr(self, lr):
     self._optimizer.lr = lr
 
-  def get_slot(self, var, slot_name):
-    # We cannot implement get_slot for the following reason: When saving a
-    # checkpoint, two optimizers cannot share slot variables. Since both the
-    # LossScaleOptimizer and the wrapped optimizer (self and self._optimizer
-    # respectively) are checkpointed, we cannot expose the wrapped optimizer's
-    # slots in the LossScaleOptimizer. Otherwise, a checkpoint would believe
-    # both optimizers share slot variables.
-    raise AttributeError(
-        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.')
-
-  def add_slot(self, var, slot_name, initializer='zeros'):
-    # We disallow adding a slot for consistency with `get_slot`.
-    raise AttributeError(
-        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.')
-
   # We do not override some OptimizerV2 methods. For each, we describe why we do
   # not delegate them to self._optimizer:
   # * get_updates: get_updates() calls get_gradients(). Since we override
@@ -402,6 +528,51 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   # TODO(reedwm): Maybe throw an error if mixed precision is used without this
   # optimizer being used.
 
+  # Trackable delegations: Delegate all Trackable methods to the wrapped
+  # optimizer. This is so the checkpoint format for a LossScaleOptimizer is
+  # identical to the checkpoint format for a normal optimizer, except the loss
+  # scale is stored in the checkpoint.
+
+
+class FakeOptimizerForRestoration(trackable.Trackable):
+  """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.
+
+  The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class
+  exists to support restoring TF 2.2 checkpoints in newer versions of TensorFlow.
+
+  In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the
+  following in LossScaleOptimizer.__init__:
+
+  ```
+  self._track_trackable(self._optimizer, 'base_optimizer')
+  ```
+
+  This means a dependency from the LossScaleOptimizer to the wrapped optimizer
+  would be stored in the checkpoint. However now, the checkpoint format with a
+  LossScaleOptimizer is the same as the format without a LossScaleOptimizer,
+  except the loss scale is also stored. This means there is no dependency from
+  the LossScaleOptimizer to the wrapped optimizer. Instead, the
+  LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's
+  perspective, by overriding all Trackable methods and delegating them to the
+  wrapped optimizer.
+
+  To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a dependency
+  on this class instead of the inner optimizer. When restored, this class will
+  instead restore the slot variables of the inner optimizer. Since this class
+  has no variables, it does not affect the checkpoint when saved.
+  """
+
+  def __init__(self, optimizer):
+    self._optimizer = optimizer
+
+  def get_slot_names(self):
+    return self._optimizer.get_slot_names()
+
+  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
+                                       variable):
+    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
+        slot_variable_position, slot_name, variable)
+
+
 # pylint: disable=protected-access
 mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
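As context for the test changes below, a short usage sketch (not part of the diff; the checkpoint path is hypothetical) of what FakeOptimizerForRestoration makes possible: a TF 2.2-era checkpoint, which stored the wrapped optimizer under a 'base_optimizer' dependency, can still be loaded into a model compiled with a LossScaleOptimizer.

import tensorflow as tf

opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(opt, 'mse')
# 'ckpt' stands for a checkpoint prefix written by TF 2.2, as in the testdata
# used by keras_test.py; the directory below is illustrative.
model.load_weights('/path/to/tf2.2_lso_checkpoint/ckpt')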
@@ -305,20 +305,6 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
     opt.set_weights([np.array(2.)])
     self.assertEqual(self.evaluate(opt.variables()[0]), 2)
 
-  def testSlotMethodErrors(self):
-    opt = gradient_descent.SGD(1.0, momentum=1.0)
-    opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
-    with self.assertRaisesRegexp(
-        AttributeError,
-        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.'):
-      opt.get_slot(None, None)
-    with self.assertRaisesRegexp(
-        AttributeError,
-        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
-        'will be removed in the future.'):
-      opt.add_slot(None, None)
-
   def testPassingNoneToLossScale(self):
     opt = gradient_descent.SGD()
     with self.assertRaisesRegexp(ValueError, r'loss_scale cannot be None'):
@@ -394,9 +380,49 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       run_fn = lambda: opt.minimize(loss, [var])
       strategy.experimental_run(run_fn)
 
-  @parameterized.named_parameters(*TESTCASES)
-  def testCheckpoint(self, strategy_fn):
+  @parameterized.named_parameters({
+      'testcase_name': 'SaveAndRestoreBase',
+      'strategy_fn': default_strategy_fn,
+      'save_with_ls': True,
+      'restore_with_ls': True,
+  }, {
+      'testcase_name': 'SaveAndRestoreDistribute',
+      'strategy_fn': create_mirrored_strategy,
+      'save_with_ls': True,
+      'restore_with_ls': True,
+  }, {
+      'testcase_name': 'SaveBase',
+      'strategy_fn': default_strategy_fn,
+      'save_with_ls': True,
+      'restore_with_ls': False,
+  }, {
+      'testcase_name': 'SaveDistribute',
+      'strategy_fn': create_mirrored_strategy,
+      'save_with_ls': True,
+      'restore_with_ls': False,
+  }, {
+      'testcase_name': 'RestoreBase',
+      'strategy_fn': default_strategy_fn,
+      'save_with_ls': False,
+      'restore_with_ls': True,
+  }, {
+      'testcase_name': 'RestoreDistribute',
+      'strategy_fn': create_mirrored_strategy,
+      'save_with_ls': False,
+      'restore_with_ls': True,
+  })
+  def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):
+
+    class MySGD(gradient_descent.SGD):
+      """A custom optimizer that tracks an extra variable."""
+
+      def __init__(self, *args, **kwargs):
+        super(MySGD, self).__init__(*args, **kwargs)
+        self.my_var = variables.Variable(0.)
+        self._track_trackable(self.my_var, 'my_var')
+
     strategy = strategy_fn()
+    replicas = strategy.num_replicas_in_sync
     if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
         not context.executing_eagerly()):
       # TODO(b/121381184): Enable running the test in this case.
@@ -405,38 +431,89 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
     with self.test_session(), strategy.scope():
       # Build and run a simple model.
       var = variables.Variable([2.0])
-      loss_scale = loss_scale_module.DynamicLossScale(
-          initial_loss_scale=1., increment_period=2.,
-          multiplier=2.)
-      opt = gradient_descent.SGD(1., momentum=1.)
-      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
-      run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var])
+      opt = inner_opt = MySGD(1., momentum=1.)
+      if save_with_ls:
+        loss_scale = loss_scale_module.DynamicLossScale(
+            initial_loss_scale=1., increment_period=2.,
+            multiplier=2.)
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
       opt_op = strategy.experimental_run(run_fn)
       self.evaluate(variables.global_variables_initializer())
-      self.evaluate(opt_op)
-      self.assertEqual(self.evaluate(loss_scale()), 1.)
-      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
-      slot_var = opt._optimizer.get_slot(var, 'momentum')
-      slot_value = self.evaluate(slot_var).item()
+      self.evaluate(strategy.experimental_local_results(opt_op))
+
+      # Assert values.
+      self.assertEqual(self.evaluate(var), 1.)
+      if save_with_ls:
+        self.assertEqual(self.evaluate(loss_scale()), 1.)
+        self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+      slot_var = opt.get_slot(var, 'momentum')
+      self.assertEqual(self.evaluate(slot_var).item(), -1)
+      self.assertEqual(self.evaluate(opt.iterations), 1)
+
+      # Set optimizer variable to check arbitrary optimizer attributes can be
+      # saved/restored
+      self.evaluate(inner_opt.my_var.assign(1.))
 
       # Save a checkpoint.
       checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
       prefix = os.path.join(self.get_temp_dir(), 'ckpt')
       save_path = checkpoint.save(prefix)
 
-      # Run model again.
-      self.evaluate(strategy.experimental_run(run_fn))
-      self.assertEqual(self.evaluate(loss_scale()), 2.)
-      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
-      self.assertNotAlmostEqual(self.evaluate(slot_var).item(), slot_value)
+      # Create new model
+      var = variables.Variable([2.0])
+      opt = inner_opt = MySGD(1., momentum=1.)
+      if restore_with_ls:
+        loss_scale = loss_scale_module.DynamicLossScale(
+            initial_loss_scale=1., increment_period=2.,
+            multiplier=2.)
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
 
-      # Load checkpoint and ensure loss scale is back to it's original value.
+      # Restore new model.
+      checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
       status = checkpoint.restore(save_path)
-      status.assert_consumed()
+      if save_with_ls:
+        status.assert_existing_objects_matched()
+      else:
+        status.assert_nontrivial_match()
+
+      # Assert restored values. We can only assert in eager mode since the
+      # variables are uninitialized in graph mode
+      if context.executing_eagerly():
+        self.assertEqual(self.evaluate(var), 1.)
+        if save_with_ls and restore_with_ls:
+          self.assertEqual(self.evaluate(loss_scale()), 1.)
+          self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+        elif restore_with_ls:
+          self.assertEqual(self.evaluate(loss_scale()), 1.)
+          self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
+        self.assertEqual(self.evaluate(opt.iterations), 1)
+
+      # Run the model again.
+      run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
+      opt_op = strategy.experimental_run(run_fn)
+
+      # Assert new values.
+      self.evaluate(variables.global_variables_initializer())
       status.run_restore_ops()
-      self.assertEqual(self.evaluate(loss_scale()), 1.)
-      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
-      self.assertAlmostEqual(self.evaluate(slot_var).item(), slot_value)
+      self.evaluate(strategy.experimental_local_results(opt_op))
+      self.assertEqual(self.evaluate(var), -1)
+      slot_var = opt.get_slot(var, 'momentum')
+      self.assertEqual(self.evaluate(slot_var).item(), -2)
+      self.assertEqual(self.evaluate(opt.iterations), 2)
+      self.assertEqual(self.evaluate(inner_opt.my_var), 1)
+
+      # Restore model again to test restoring after slots are created
+      status = checkpoint.restore(save_path)
+      if save_with_ls and restore_with_ls:
+        status.assert_consumed()
+      elif save_with_ls:
+        status.assert_existing_objects_matched()
+      elif restore_with_ls:
+        status.assert_nontrivial_match()
+      status.run_restore_ops()
+      self.assertEqual(self.evaluate(var), 1)
+      self.assertEqual(self.evaluate(slot_var).item(), -1)
 
   def testGetConfig(self):
     opt = gradient_descent.SGD(2., momentum=0.5)
tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD (new file, 48 lines)
@@ -0,0 +1,48 @@
+# Description:
+#   Contains checkpoints and SavedModels for testing purposes.
+
+package(
+    default_visibility = [
+        "//tensorflow/python/keras:__subpackages__",
+        "//tensorflow/tools/pip_package:__pkg__",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+# These files were generated by running the following program with TensorFlow
+# 2.2rc2. The final release of TF 2.2 was not out when this change was created:
+
+# import os
+# import numpy as np
+# import tensorflow as tf
+#
+# tf.random.set_seed(1)
+# opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
+# opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
+# model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
+# model.compile(opt, 'mse')
+#
+# x = np.ones((10, 2))
+# y = x * 100
+# model.fit(x, y)
+# weight_dir = os.environ['TF_LSO_WEIGHT_DIR']
+# model_dir = os.environ['TF_LSO_MODEL_DIR']
+# model.save_weights(weight_dir)
+# model.save(model_dir)
+# print(model.get_weights()[0])
+# print(opt._optimizer.get_slot(model.weights[0], 'momentum'))
+# print(opt.loss_scale)
+
+filegroup(
+    name = "lso_ckpt_tf2.2",
+    srcs = glob(["lso_ckpt_tf2.2/**"]),
+    tags = ["no_pip"],
+)
+
+filegroup(
+    name = "lso_savedmodel_tf2.2",
+    srcs = glob(["lso_savedmodel_tf2.2/**"]),
+    tags = ["no_pip"],
+)
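As an aside, a small sketch (not part of the commit) for listing the variable keys stored in the old-format test checkpoint added under this directory; the 'optimizer/base_optimizer/...' and 'optimizer/loss_scale/...' prefixes it prints are the TF 2.2 layout that FakeOptimizerForRestoration keeps loadable.

import tensorflow as tf

ckpt_dir = ('tensorflow/python/keras/mixed_precision/experimental/'
            'testdata/lso_ckpt_tf2.2')
for name, shape in tf.train.list_variables(ckpt_dir):
  print(name, shape)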
tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint (new file, 2 lines)
@@ -0,0 +1,2 @@
+model_checkpoint_path: "ckpt"
+all_model_checkpoint_paths: "ckpt"
@@ -0,0 +1,45 @@
[New serialized checkpoint data file, rendered here as 45 lines of mostly unreadable binary text. The recoverable object-graph entries show the TF 2.2 checkpoint layout: keys under "optimizer/base_optimizer/..." (iter, decay, learning_rate, momentum), "optimizer/loss_scale/current_loss_scale" and "optimizer/loss_scale/good_steps", and momentum slot keys such as "layer_with_weights-0/kernel/.OPTIMIZER_SLOT/optimizer/base_optimizer/momentum/.ATTRIBUTES/VARIABLE_VALUE".]
tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index (new binary file, not shown)
tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb (new binary file, not shown)
[Other binary files in the new testdata directories are not shown.]
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
+  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
+  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
   is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
   is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
   is_instance: "<type \'object\'>"