Change LossScaleOptimizer checkpoint format.

The checkpoint format is now identical to the format used when a LossScaleOptimizer is not present, except that the loss scale is additionally saved when a LossScaleOptimizer is used. This allows a checkpoint saved with a LossScaleOptimizer to be restored without one, and vice versa. Checkpoints containing LossScaleOptimizers created by older versions of TensorFlow can still be loaded; newly saved checkpoints use the new format.

PiperOrigin-RevId: 306511555
Change-Id: Ie316ab8c4fbfec7babd6f7803d337799d0ff10a5
Parent: 4e8098235c
Commit: f9e99f4dca
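As context for the new format, a minimal sketch (not part of this change) of the cross-compatibility it enables, assuming TF 2.3-style APIs and a writable /tmp path:

```python
import tensorflow as tf

# Save a checkpoint while training with a LossScaleOptimizer.
var = tf.Variable([2.0])
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(1.0, momentum=1.0), 'dynamic')
opt.minimize(lambda: var + 1.0, var_list=[var])  # creates the momentum slot
ckpt = tf.train.Checkpoint(optimizer=opt, var=var)
save_path = ckpt.save('/tmp/lso_format_demo/ckpt')  # hypothetical path

# Restore the same checkpoint into a plain SGD optimizer. With the new format
# the optimizer keys line up; only the loss-scale entries go unused, which is
# why assert_existing_objects_matched (rather than assert_consumed) is used.
var2 = tf.Variable([0.0])
opt2 = tf.keras.optimizers.SGD(1.0, momentum=1.0)
ckpt2 = tf.train.Checkpoint(optimizer=opt2, var=var2)
ckpt2.restore(save_path).assert_existing_objects_matched()
```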
@@ -224,9 +224,16 @@ cuda_py_test(
    name = "keras_test",
    size = "medium",
    srcs = ["keras_test.py"],
    data = [
        "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_ckpt_tf2.2",
        "//tensorflow/python/keras/mixed_precision/experimental/testdata:lso_savedmodel_tf2.2",
    ],
    python_version = "PY3",
    shard_count = 10,
    tags = ["no_windows"],  # b/139083295: bfloat16 tests fail on Windows
    tags = [
        "no_pip",
        "no_windows",  # b/139083295: bfloat16 tests fail on Windows
    ],
    deps = [
        ":test_util",
        "//tensorflow/python:client_testlib",
@@ -41,6 +41,7 @@ from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.mixed_precision.experimental import get_layer_policy
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
@@ -993,6 +994,56 @@ class KerasModelTest(keras_parameterized.TestCase):
    self.assertEqual(backend.get_value(loss_scale()), 2)
    self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1)

  @keras_parameterized.run_all_keras_modes
  def test_restore_old_loss_scale_checkpoint(self):
    # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
    # of LossScaleOptimizer changed, but old checkpoints can still be loaded
    opt = gradient_descent.SGD(0.1, momentum=0.1)
    opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
    model = sequential.Sequential([core.Dense(2,)])

    # The checkpoint and expected values were obtained from the program in
    # testdata/BUILD.
    ckpt_dir = test.test_src_dir_path(
        'python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2')
    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
    model.compile(opt, 'mse', run_eagerly=testing_utils.should_run_eagerly())
    model(np.zeros((2, 2)))  # Create model weights
    opt._create_all_weights(model.weights)
    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
    expected_slot = np.array([[10.049943, 9.917691], [10.049943, 9.917691]])
    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
    self.assertAllClose(
        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
        expected_slot)
    self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
    self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)

    # Check restoring works even after the model is compiled and the weights
    # have been created.
    model.fit(np.random.normal(size=(2, 2)), np.random.normal(size=(2, 2)))
    self.assertNotAllClose(self.evaluate(model.weights[0]), expected_kernel)
    self.assertNotAllClose(
        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
        expected_slot)
    model.load_weights(os.path.join(ckpt_dir, 'ckpt'))
    self.assertAllClose(self.evaluate(model.weights[0]), expected_kernel)
    self.assertAllClose(
        self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
        expected_slot)
    self.assertEqual(self.evaluate(opt.loss_scale()), 32768)
    self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1)

  def test_restore_old_saved_model(self):
    saved_model_dir = test.test_src_dir_path(
        'python/keras/mixed_precision/experimental/testdata/'
        'lso_savedmodel_tf2.2')
    model = save.load_model(saved_model_dir)
    expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
    self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
    self.assertIsInstance(model.optimizer,
                          loss_scale_optimizer.LossScaleOptimizer)

  @keras_parameterized.run_all_keras_modes
  @parameterized.named_parameters(
      {
@@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.keras import backend
@@ -32,6 +33,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export
@@ -51,8 +53,126 @@ class _UnwrapPreventer(object):
    self.value = value


class _DelegatingTrackableMixin(object):
  """A mixin that delegates all Trackable methods to another trackable object.

  This class must be used with multiple inheritance. A class that subclasses
  Trackable can also subclass this class, which causes all Trackable methods to
  be delegated to the trackable object passed in the constructor.

  A subclass can use this mixin to appear as if it were the trackable passed to
  the constructor, from a Checkpoint's perspective. LossScaleOptimizer uses this
  mixin, so that the checkpoint format for a LossScaleOptimizer is identical to
  the checkpoint format for a normal optimizer. This allows a model to be saved
  with a normal Optimizer and restored with a LossScaleOptimizer, or vice versa.
  The only difference in checkpoint format is that the loss scale is also saved
  with a LossScaleOptimizer.
  """

  def __init__(self, trackable_obj):
    self._trackable = trackable_obj

  # pylint: disable=protected-access
  @property
  def _setattr_tracking(self):
    return self._trackable._setattr_tracking

  @_setattr_tracking.setter
  def _setattr_tracking(self, value):
    self._trackable._setattr_tracking = value

  @property
  def _update_uid(self):
    return self._trackable._update_uid

  @_update_uid.setter
  def _update_uid(self, value):
    self._trackable._update_uid = value

  @property
  def _unconditional_checkpoint_dependencies(self):
    return self._trackable._unconditional_checkpoint_dependencies

  @property
  def _unconditional_dependency_names(self):
    return self._trackable._unconditional_dependency_names

  @property
  def _name_based_restores(self):
    return self._trackable._name_based_restores

  def _maybe_initialize_trackable(self):
    return self._trackable._maybe_initialize_trackable()

  @property
  def _object_identifier(self):
    return self._trackable._object_identifier

  @property
  def _tracking_metadata(self):
    return self._trackable._tracking_metadata

  def _no_dependency(self, value):
    return self._trackable._no_dependency(value)

  def _name_based_attribute_restore(self, checkpoint):
    return self._trackable._name_based_attribute_restore(checkpoint)

  @property
  def _checkpoint_dependencies(self):
    return self._trackable._checkpoint_dependencies

  @property
  def _deferred_dependencies(self):
    return self._trackable._deferred_dependencies

  def _lookup_dependency(self, name):
    return self._trackable._lookup_dependency(name)

  def _add_variable_with_custom_getter(self,
                                       name,
                                       shape=None,
                                       dtype=dtypes.float32,
                                       initializer=None,
                                       getter=None,
                                       overwrite=False,
                                       **kwargs_for_getter):
    return self._trackable._add_variable_with_custom_getter(
        name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)

  def _preload_simple_restoration(self, name, shape):
    return self._trackable._preload_simple_restoration(name, shape)

  def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
    return self._trackable._track_trackable(trackable, name, overwrite)

  def _handle_deferred_dependencies(self, name, trackable):  # pylint: disable=redefined-outer-name
    return self._trackable._handle_deferred_dependencies(name, trackable)

  def _restore_from_checkpoint_position(self, checkpoint_position):
    return self._trackable._restore_from_checkpoint_position(
        checkpoint_position)

  def _single_restoration_from_checkpoint_position(self, checkpoint_position,
                                                   visit_queue):
    return self._trackable._single_restoration_from_checkpoint_position(
        checkpoint_position, visit_queue)

  def _gather_saveables_for_checkpoint(self):
    return self._trackable._gather_saveables_for_checkpoint()

  def _list_extra_dependencies_for_serialization(self, serialization_cache):
    return self._trackable._list_extra_dependencies_for_serialization(
        serialization_cache)

  def _list_functions_for_serialization(self, serialization_cache):
    return self._trackable._list_functions_for_serialization(
        serialization_cache)
  # pylint: enable=protected-access


@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizer(optimizer_v2.OptimizerV2):
class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
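The `_DelegatingTrackableMixin` above is what makes the LossScaleOptimizer's checkpoint entries land at the same keys a plain optimizer would use. A hedged sketch, not part of this change, of how that equivalence could be checked with `tf.train.list_variables`; the helper name `checkpoint_keys` and the /tmp paths are made up, and the expected extra keys are an assumption based on the new format described in the commit message:

```python
import tensorflow as tf

def checkpoint_keys(optimizer, directory):
  """Saves a tiny checkpoint with `optimizer` and returns its variable keys."""
  var = tf.Variable([2.0])
  optimizer.minimize(lambda: var + 1.0, var_list=[var])  # create the slot vars
  path = tf.train.Checkpoint(optimizer=optimizer, var=var).save(directory)
  return {name for name, _ in tf.train.list_variables(path)}

plain = checkpoint_keys(
    tf.keras.optimizers.SGD(1.0, momentum=1.0), '/tmp/plain/ckpt')
wrapped = checkpoint_keys(
    tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        tf.keras.optimizers.SGD(1.0, momentum=1.0), 'dynamic'),
    '/tmp/wrapped/ckpt')

# Under the new format the only extra keys should belong to the loss scale,
# e.g. optimizer/loss_scale/current_loss_scale/.ATTRIBUTES/VARIABLE_VALUE and
# optimizer/loss_scale/good_steps/.ATTRIBUTES/VARIABLE_VALUE.
print(wrapped - plain)
```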
@@ -144,6 +264,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
    self._loss_scale = keras_loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None.')

    # We don't call super().__init__, since we do not want to call OptimizerV2's
    # constructor.
    _DelegatingTrackableMixin.__init__(self, self._optimizer)

    for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
      # We cannot call `track_variable` in the LossScale class itself, because a
      # file outside of Keras cannot depend on a Keras file. Calling it here
@@ -151,12 +276,15 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
      # a Keras class, and the only way to use LossScale with a Keras class is
      # through the LossScaleOptimizer.
      backend.track_variable(weight)
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

    # Needed because the superclass's __getattribute__ checks this.
    self._hyper = {}

    # To support restoring TensorFlow 2.2 checkpoints.
    self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
                          'base_optimizer')

  @property
  def loss_scale(self):
    """The `LossScale` instance associated with this optimizer."""
@@ -348,6 +476,21 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
  def _aggregate_gradients(self, grads_and_vars):
    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    return self._optimizer._restore_slot_variable(slot_name, variable,  # pylint: disable=protected-access
                                                  slot_variable)

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position, slot_name, variable)

  def get_slot(self, var, slot_name):
    return self._optimizer.get_slot(var, slot_name)

  def add_slot(self, var, slot_name, initializer='zeros'):
    return self._optimizer.add_slot(var, slot_name, initializer)

  # For the most part, we only expose methods in the base OptimizerV2, not
  # individual subclasses like Adam. However, although "learning_rate" and "lr"
  # properties are not part of the base OptimizerV2 class, they are part of most
@@ -369,23 +512,6 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
  def lr(self, lr):
    self._optimizer.lr = lr

  def get_slot(self, var, slot_name):
    # We cannot implement get_slot for the following reason: When saving a
    # checkpoint, two optimizers cannot share slot variables. Since both the
    # LossScaleOptimizer and the wrapped optimizer (self and self._optimizer
    # respectively) are checkpointed, we cannot expose the wrapped optimizer's
    # slots in the LossScaleOptimizer. Otherwise, a checkpoint would believe
    # both optimizers share slot variables.
    raise AttributeError(
        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
        'will be removed in the future.')

  def add_slot(self, var, slot_name, initializer='zeros'):
    # We disallow adding a slot for consistency with `get_slot`.
    raise AttributeError(
        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
        'will be removed in the future.')

  # We do not override some OptimizerV2 methods. For each, we describe why we do
  # not delegate them to self._optimizer:
  # * get_updates: get_updates() calls get_gradients(). Since we override
@@ -402,6 +528,51 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
  # TODO(reedwm): Maybe throw an error if mixed precision is used without this
  # optimizer being used.

  # Trackable delegations: Delegate all Trackable methods to the wrapped
  # optimizer. This is so the checkpoint format for a LossScaleOptimizer is
  # identical to the checkpoint format for a normal optimizer, except the loss
  # scale is stored in the checkpoint.


class FakeOptimizerForRestoration(trackable.Trackable):
  """A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.

  The checkpoint format for LossScaleOptimizers changed after TF 2.2. This class
  exists to support restoring TF 2.2 checkpoints in newer versions of TensorFlow.

  In TF 2.2, LossScaleOptimizer would track the wrapped optimizer by calling the
  following in LossScaleOptimizer.__init__

  ```
  self._track_trackable(self._optimizer, 'base_optimizer')
  ```

  This means a dependency from the LossScaleOptimizer to the wrapped optimizer
  would be stored in the checkpoint. However now, the checkpoint format with a
  LossScaleOptimizer is the same as the format without a LossScaleOptimizer,
  except the loss scale is also stored. This means there is no dependency from
  the LossScaleOptimizer to the wrapped optimizer. Instead, the
  LossScaleOptimizer acts as if it is the wrapped optimizer, from a checkpoint's
  perspective, by overriding all Trackable methods and delegating them to the
  wrapped optimizer.

  To allow restoring TF 2.2 checkpoints, LossScaleOptimizer adds a dependency
  on this class instead of the inner optimizer. When restored, this class will
  instead restore the slot variables of the inner optimizer. Since this class
  has no variables, it does not affect the checkpoint when saved.
  """

  def __init__(self, optimizer):
    self._optimizer = optimizer

  def get_slot_names(self):
    return self._optimizer.get_slot_names()

  def _create_or_restore_slot_variable(self, slot_variable_position, slot_name,
                                       variable):
    return self._optimizer._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position, slot_name, variable)


# pylint: disable=protected-access
mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
@@ -305,20 +305,6 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
    opt.set_weights([np.array(2.)])
    self.assertEqual(self.evaluate(opt.variables()[0]), 2)

  def testSlotMethodErrors(self):
    opt = gradient_descent.SGD(1.0, momentum=1.0)
    opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
    with self.assertRaisesRegexp(
        AttributeError,
        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
        'will be removed in the future.'):
      opt.get_slot(None, None)
    with self.assertRaisesRegexp(
        AttributeError,
        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
        'will be removed in the future.'):
      opt.add_slot(None, None)

  def testPassingNoneToLossScale(self):
    opt = gradient_descent.SGD()
    with self.assertRaisesRegexp(ValueError, r'loss_scale cannot be None'):
@@ -394,9 +380,49 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
      run_fn = lambda: opt.minimize(loss, [var])
      strategy.experimental_run(run_fn)

  @parameterized.named_parameters(*TESTCASES)
  def testCheckpoint(self, strategy_fn):
  @parameterized.named_parameters({
      'testcase_name': 'SaveAndRestoreBase',
      'strategy_fn': default_strategy_fn,
      'save_with_ls': True,
      'restore_with_ls': True,
  }, {
      'testcase_name': 'SaveAndRestoreDistribute',
      'strategy_fn': create_mirrored_strategy,
      'save_with_ls': True,
      'restore_with_ls': True,
  }, {
      'testcase_name': 'SaveBase',
      'strategy_fn': default_strategy_fn,
      'save_with_ls': True,
      'restore_with_ls': False,
  }, {
      'testcase_name': 'SaveDistribute',
      'strategy_fn': create_mirrored_strategy,
      'save_with_ls': True,
      'restore_with_ls': False,
  }, {
      'testcase_name': 'RestoreBase',
      'strategy_fn': default_strategy_fn,
      'save_with_ls': False,
      'restore_with_ls': True,
  }, {
      'testcase_name': 'RestoreDistribute',
      'strategy_fn': create_mirrored_strategy,
      'save_with_ls': False,
      'restore_with_ls': True,
  })
  def testCheckpoint(self, strategy_fn, save_with_ls, restore_with_ls):

    class MySGD(gradient_descent.SGD):
      """A custom optimizer that tracks an extra variable."""

      def __init__(self, *args, **kwargs):
        super(MySGD, self).__init__(*args, **kwargs)
        self.my_var = variables.Variable(0.)
        self._track_trackable(self.my_var, 'my_var')

    strategy = strategy_fn()
    replicas = strategy.num_replicas_in_sync
    if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
        not context.executing_eagerly()):
      # TODO(b/121381184): Enable running the test in this case.
@@ -405,38 +431,89 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
    with self.test_session(), strategy.scope():
      # Build and run a simple model.
      var = variables.Variable([2.0])
      loss_scale = loss_scale_module.DynamicLossScale(
          initial_loss_scale=1., increment_period=2.,
          multiplier=2.)
      opt = gradient_descent.SGD(1., momentum=1.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
      run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var])
      opt = inner_opt = MySGD(1., momentum=1.)
      if save_with_ls:
        loss_scale = loss_scale_module.DynamicLossScale(
            initial_loss_scale=1., increment_period=2.,
            multiplier=2.)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
      run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
      opt_op = strategy.experimental_run(run_fn)
      self.evaluate(variables.global_variables_initializer())
      self.evaluate(opt_op)
      self.assertEqual(self.evaluate(loss_scale()), 1.)
      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
      slot_var = opt._optimizer.get_slot(var, 'momentum')
      slot_value = self.evaluate(slot_var).item()
      self.evaluate(strategy.experimental_local_results(opt_op))

      # Assert values.
      self.assertEqual(self.evaluate(var), 1.)
      if save_with_ls:
        self.assertEqual(self.evaluate(loss_scale()), 1.)
        self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
      slot_var = opt.get_slot(var, 'momentum')
      self.assertEqual(self.evaluate(slot_var).item(), -1)
      self.assertEqual(self.evaluate(opt.iterations), 1)

      # Set optimizer variable to check arbitrary optimizer attributes can be
      # saved/restored
      self.evaluate(inner_opt.my_var.assign(1.))

      # Save a checkpoint.
      checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
      prefix = os.path.join(self.get_temp_dir(), 'ckpt')
      save_path = checkpoint.save(prefix)

      # Run model again.
      self.evaluate(strategy.experimental_run(run_fn))
      self.assertEqual(self.evaluate(loss_scale()), 2.)
      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
      self.assertNotAlmostEqual(self.evaluate(slot_var).item(), slot_value)
      # Create new model
      var = variables.Variable([2.0])
      opt = inner_opt = MySGD(1., momentum=1.)
      if restore_with_ls:
        loss_scale = loss_scale_module.DynamicLossScale(
            initial_loss_scale=1., increment_period=2.,
            multiplier=2.)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)

      # Load checkpoint and ensure loss scale is back to its original value.
      # Restore new model.
      checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
      status = checkpoint.restore(save_path)
      status.assert_consumed()
      if save_with_ls:
        status.assert_existing_objects_matched()
      else:
        status.assert_nontrivial_match()

      # Assert restored values. We can only assert in eager mode since the
      # variables are uninitialized in graph mode
      if context.executing_eagerly():
        self.assertEqual(self.evaluate(var), 1.)
        if save_with_ls and restore_with_ls:
          self.assertEqual(self.evaluate(loss_scale()), 1.)
          self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
        elif restore_with_ls:
          self.assertEqual(self.evaluate(loss_scale()), 1.)
          self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
        self.assertEqual(self.evaluate(opt.iterations), 1)

      # Run the model again.
      run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
      opt_op = strategy.experimental_run(run_fn)

      # Assert new values.
      self.evaluate(variables.global_variables_initializer())
      status.run_restore_ops()
      self.assertEqual(self.evaluate(loss_scale()), 1.)
      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
      self.assertAlmostEqual(self.evaluate(slot_var).item(), slot_value)
      self.evaluate(strategy.experimental_local_results(opt_op))
      self.assertEqual(self.evaluate(var), -1)
      slot_var = opt.get_slot(var, 'momentum')
      self.assertEqual(self.evaluate(slot_var).item(), -2)
      self.assertEqual(self.evaluate(opt.iterations), 2)
      self.assertEqual(self.evaluate(inner_opt.my_var), 1)

      # Restore model again to test restoring after slots are created
      status = checkpoint.restore(save_path)
      if save_with_ls and restore_with_ls:
        status.assert_consumed()
      elif save_with_ls:
        status.assert_existing_objects_matched()
      elif restore_with_ls:
        status.assert_nontrivial_match()
      status.run_restore_ops()
      self.assertEqual(self.evaluate(var), 1)
      self.assertEqual(self.evaluate(slot_var).item(), -1)

  def testGetConfig(self):
    opt = gradient_descent.SGD(2., momentum=0.5)
tensorflow/python/keras/mixed_precision/experimental/testdata/BUILD (new file, 48 lines)
@@ -0,0 +1,48 @@
# Description:
#   Contains checkpoints and SavedModels for testing purposes.

package(
    default_visibility = [
        "//tensorflow/python/keras:__subpackages__",
        "//tensorflow/tools/pip_package:__pkg__",
    ],
    licenses = ["notice"],  # Apache 2.0
)

exports_files(["LICENSE"])

# These files were generated by running the following program with TensorFlow
# 2.2rc2. The final release of TF 2.2 was not out when this change was created:

# import os
# import numpy as np
# import tensorflow as tf
#
# tf.random.set_seed(1)
# opt = tf.keras.optimizers.SGD(0.1, momentum=0.1)
# opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')
# model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
# model.compile(opt, 'mse')
#
# x = np.ones((10, 2))
# y = x * 100
# model.fit(x, y)
# weight_dir = os.environ['TF_LSO_WEIGHT_DIR']
# model_dir = os.environ['TF_LSO_MODEL_DIR']
# model.save_weights(weight_dir)
# model.save(model_dir)
# print(model.get_weights()[0])
# print(opt._optimizer.get_slot(model.weights[0], 'momentum'))
# print(opt.loss_scale)

filegroup(
    name = "lso_ckpt_tf2.2",
    srcs = glob(["lso_ckpt_tf2.2/**"]),
    tags = ["no_pip"],
)

filegroup(
    name = "lso_savedmodel_tf2.2",
    srcs = glob(["lso_savedmodel_tf2.2/**"]),
    tags = ["no_pip"],
)
tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/checkpoint (new file, 2 lines)
@@ -0,0 +1,2 @@
model_checkpoint_path: "ckpt"
all_model_checkpoint_paths: "ckpt"
@@ -0,0 +1,45 @@
[New checkpoint data file, apparently part of the lso_ckpt_tf2.2 testdata; binary object-graph metadata rendered as unreadable text. The recoverable checkpoint keys use the old TF 2.2 layout, for example:
  layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE
  layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE
  optimizer/base_optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE
  optimizer/base_optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE
  optimizer/base_optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE
  optimizer/base_optimizer/momentum/.ATTRIBUTES/VARIABLE_VALUE
  optimizer/loss_scale/current_loss_scale/.ATTRIBUTES/VARIABLE_VALUE
  optimizer/loss_scale/good_steps/.ATTRIBUTES/VARIABLE_VALUE
  layer_with_weights-0/kernel/.OPTIMIZER_SLOT/optimizer/base_optimizer/momentum/.ATTRIBUTES/VARIABLE_VALUE
  layer_with_weights-0/bias/.OPTIMIZER_SLOT/optimizer/base_optimizer/momentum/.ATTRIBUTES/VARIABLE_VALUE]
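For reference, a hedged sketch (not part of this change) of inspecting the keys of the TF 2.2 testdata checkpoint above; the relative path assumes the TensorFlow source tree as the working directory, and the commented keys are the ones visible in the metadata dump:

```python
import tensorflow as tf

# Prefix of the TF 2.2 checkpoint added as testdata in this change.
ckpt_prefix = ('tensorflow/python/keras/mixed_precision/experimental/'
               'testdata/lso_ckpt_tf2.2/ckpt')

# In the TF 2.2 format the wrapped optimizer's state sits under an extra
# 'base_optimizer' node (e.g. optimizer/base_optimizer/momentum/...), while the
# new format stores it directly under 'optimizer/', plus the loss-scale keys.
for name, shape in tf.train.list_variables(ckpt_prefix):
  print(name, shape)
```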
New binary files (contents not shown):
  tensorflow/python/keras/mixed_precision/experimental/testdata/lso_ckpt_tf2.2/ckpt.index
  tensorflow/python/keras/mixed_precision/experimental/testdata/lso_savedmodel_tf2.2/saved_model.pb
  (plus additional unnamed binary checkpoint/SavedModel data files)
@@ -1,6 +1,7 @@
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
  is_instance: "<type \'object\'>"
@@ -1,6 +1,7 @@
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class {
  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
  is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
  is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
  is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
  is_instance: "<type \'object\'>"