Deprecate LossScale and modify Keras APIs to not use it.

LossScale and its subclasses are deprecated and will be removed from the TF 2 namespace in TensorFlow 2.5. They will remain accessible under the tf.compat.v1 namespace, and this change makes them non-experimental there, exporting the base class as `tf.compat.v1.mixed_precision.LossScale`. LossScale cannot be removed from the tf.compat.v1 namespace because it is used by the V1-only class tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer.
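
For reference, the V1 path now looks roughly like this (I'm assuming the DynamicLossScale subclass follows the same compat.v1 export pattern as the base class):

```python
import tensorflow as tf

# LossScale and its subclasses remain available, non-experimentally, under
# tf.compat.v1 (assumed path for the DynamicLossScale subclass).
loss_scale = tf.compat.v1.mixed_precision.DynamicLossScale()

# The V1-only optimizer wrapper that keeps LossScale alive in compat.v1.
opt = tf.compat.v1.train.GradientDescentOptimizer(0.1)
opt = tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer(
    opt, loss_scale)
```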

LossScaleOptimizer previously delegated to a LossScale, but now it performs loss scaling directly within the class itself. Additionally, a new non-experimental `tf.keras.mixed_precision.LossScaleOptimizer` has been introduced. Unlike the experimental LossScaleOptimizer, the non-experimental LossScaleOptimizer does not accept a LossScale; instead it has different constructor arguments that specify the kind of loss scaling to perform. The old experimental LossScaleOptimizer will be removed in TensorFlow 2.5, at which point a LossScale can no longer be used with any Keras LossScaleOptimizer.
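
Roughly, the constructor difference looks like this (the argument names and defaults come from this change):

```python
import tensorflow as tf

# Old experimental API: loss scaling is specified via a LossScale (or the
# string "dynamic" / a number).
old_opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.25), "dynamic")

# New API: dynamic loss scaling by default, no LossScale involved.
new_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.25))

# Fixed loss scaling is requested explicitly instead of passing a
# FixedLossScale.
fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(0.25), dynamic=False, initial_scale=1024)
```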

Internally, LossScaleOptimizer uses a fork of DynamicLossScale called _DynamicLossScaleState, but this is not exposed to the user. In the future, _DynamicLossScaleState will be merged into LossScaleOptimizer.

LossScaleOptimizer now exposes several attributes that DynamicLossScale previously exposed. `increment_period` is renamed to `dynamic_growth_steps` for consistency with `ExponentialDecay.decay_steps`, and `num_good_steps` is replaced by `dynamic_counter`.
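
For example (the defaults shown are the ones introduced here):

```python
import tensorflow as tf

opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
print(opt.dynamic)               # True
print(opt.initial_scale)         # 32768.0 (2 ** 15 by default)
print(opt.dynamic_growth_steps)  # 2000 (was DynamicLossScale.increment_period)
print(int(opt.dynamic_counter))  # 0 (replaces DynamicLossScale.num_good_steps)
```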

LossScaleOptimizer.loss_scale is now a tensor, not a LossScale. This means the previous way of getting the loss scale as a tensor (calling `optimizer.loss_scale()`) now raises an error, since a tensor is not callable. I don't know of any users who do this, so I do not anticipate any breakages.
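
In other words, a small sketch of the change:

```python
import tensorflow as tf

opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())

scale = opt.loss_scale  # now a scalar float32 tensor

# Previously `opt.loss_scale` was a LossScale object, so `opt.loss_scale()`
# returned a tensor. Calling the tensor now fails because tensors are not
# callable.
```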

Policy previously held a LossScale instance and optionally took one in its constructor. By default, the "mixed_float16" policy had a DynamicLossScale, while all other policies had no loss scale. Now, Policy no longer holds or accepts a loss scale. To temporarily preserve backwards compatibility with the old API, the symbol `tf.keras.mixed_precision.experimental.Policy` still takes and holds a LossScale, as it did before. A new non-experimental symbol, `tf.keras.mixed_precision.Policy`, removes the use of the LossScale. The old experimental symbol will be removed in the future.
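
A rough sketch of the two Policy symbols during the transition:

```python
import tensorflow as tf

# Old experimental symbol: still accepts and holds a loss scale.
old_policy = tf.keras.mixed_precision.experimental.Policy(
    'mixed_float16', loss_scale='dynamic')

# New symbol: describes dtypes only. Loss scaling now lives on the optimizer;
# Model.compile wraps the optimizer in a LossScaleOptimizer when the
# 'mixed_float16' policy is used.
new_policy = tf.keras.mixed_precision.Policy('mixed_float16')
layer = tf.keras.layers.Dense(4, dtype=new_policy)
```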

When deserializing a layer or model with an old experimental policy, it will be restored as the new policy and the loss scale will be silently dropped. This preserves SavedModel compatibility, so that models saved in TensorFlow 2.3 can still be restored in future versions of TensorFlow once the old experimental Policy is removed. Luckily, dropping the loss scale is unlikely to break anyone, as a bug in the mixed precision API causes models not to save their dtype policies at all when serialized. Similarly, when deserializing a model with the old experimental LossScaleOptimizer, it will be restored as the new LossScaleOptimizer, but unlike the policy case, nothing is silently dropped.
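
The updated tests exercise this round trip; roughly (using a Dense layer here purely for illustration):

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(
    2, dtype=tf.keras.mixed_precision.experimental.Policy(
        'mixed_float16', loss_scale=None))
config = layer.get_config()  # serializes the old (experimental) policy
restored = tf.keras.layers.Dense.from_config(config)

# The restored layer holds the new Policy, so re-serializing drops any loss
# scale from the dtype config:
print(restored.get_config()['dtype'])
# -> {'class_name': 'Policy', 'config': {'name': 'mixed_float16'}}
```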

This change differs from what is described in the mixed precision RFC (https://github.com/tensorflow/community/pull/293), but I think this API is a lot clearer and simpler than the one in the RFC. The RFC forked the LossScale classes into Keras, but I now think it is better to simply not use them and expose LossScale under tf.compat.v1 only. This new API was designed based on feedback from @fchollet and @omalleyt12. I will retroactively update the RFC to reflect this API.

PiperOrigin-RevId: 337938270
Change-Id: Id7bb3bb89eb2143e5fadabeb2f57d1f8267379b3

@ -137,9 +137,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
Attributes: Attributes:
name: The name of the layer (string). name: The name of the layer (string).
dtype: The dtype of the layer's computations and weights. If mixed dtype: The dtype of the layer's computations and weights. If mixed
precision is used with a `tf.keras.mixed_precision.experimental.Policy`, precision is used with a `tf.keras.mixed_precision.Policy`, this is
this is instead just the dtype of the layer's weights, as the computations instead just the dtype of the layer's weights, as the computations are
are done in a different dtype. done in a different dtype.
trainable_weights: List of variables to be included in backprop. trainable_weights: List of variables to be included in backprop.
non_trainable_weights: List of variables that should not be non_trainable_weights: List of variables that should not be
included in backprop. included in backprop.
@ -279,8 +279,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
if no dtype is passed. `floatx()` itself defaults to "float32". Additionally, if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed
precision is used, layers may have different computation and variable dtypes. precision is used, layers may have different computation and variable dtypes.
See `tf.keras.mixed_precision.experimental.Policy` for details on layer See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
dtypes.
""" """
# See tf.Module for the usage of this property. # See tf.Module for the usage of this property.
@ -388,9 +387,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._metrics_lock = threading.Lock() self._metrics_lock = threading.Lock()
# Both graph and subclassed networks have a dtype policy. For graph # Both graph and subclassed networks have a dtype policy. For graph
# networks, the policy's compute and variable dtypes are ignored, but other # networks, the policy's compute and variable dtypes are ignored. Such
# fields, like the loss scale, are used by Models. For subclassed networks, # networks only use the policy if it is a PolicyV1, in which case it uses
# the compute and variable dtypes are used as like any ordinary layer. # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
# subclassed networks, the compute and variable dtypes are used as like any
# ordinary layer.
self._set_dtype_policy(dtype) self._set_dtype_policy(dtype)
# Boolean indicating whether the layer automatically casts its inputs to the # Boolean indicating whether the layer automatically casts its inputs to the
# layer's compute_dtype. # layer's compute_dtype.


@ -741,9 +741,9 @@ def enable_v2_dtype_behavior():
autocasting part of the V2 behavior for that layer, but not the defaulting to autocasting part of the V2 behavior for that layer, but not the defaulting to
floatx part of the V2 behavior. floatx part of the V2 behavior.
When a global `tf.keras.mixed_precision.experimental.Policy` is set, a Keras When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype
layer's dtype will default to the global policy instead of floatx. Layers will default to the global policy instead of floatx. Layers will automatically
will automatically cast inputs to the policy's compute_dtype. cast inputs to the policy's compute_dtype.
""" """
global V2_DTYPE_BEHAVIOR global V2_DTYPE_BEHAVIOR
V2_DTYPE_BEHAVIOR = True V2_DTYPE_BEHAVIOR = True


@ -112,9 +112,9 @@ class Layer(base_layer.Layer):
Attributes: Attributes:
name: The name of the layer (string). name: The name of the layer (string).
dtype: The dtype of the layer's computations and weights. If mixed dtype: The dtype of the layer's computations and weights. If mixed
precision is used with a `tf.keras.mixed_precision.experimental.Policy`, precision is used with a `tf.keras.mixed_precision.Policy`, this is
this is instead just the dtype of the layer's weights, as the computations instead just the dtype of the layer's weights, as the computations are
are done in a different dtype. done in a different dtype.
updates: List of update ops of this layer. updates: List of update ops of this layer.
losses: List of losses added by this layer. losses: List of losses added by this layer.
trainable_weights: List of variables to be included in backprop. trainable_weights: List of variables to be included in backprop.
@ -133,8 +133,7 @@ class Layer(base_layer.Layer):
if no dtype is passed. `floatx()` itself defaults to "float32". Additionally, if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed
precision is used, layers may have different computation and variable dtypes. precision is used, layers may have different computation and variable dtypes.
See `tf.keras.mixed_precision.experimental.Policy` for details on layer See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
dtypes.
""" """
# See tf.Module for the usage of this property. # See tf.Module for the usage of this property.
@ -199,9 +198,11 @@ class Layer(base_layer.Layer):
self._metrics = [] self._metrics = []
# Both graph and subclassed networks have a dtype policy. For graph # Both graph and subclassed networks have a dtype policy. For graph
# networks, the policy's compute and variable dtypes are ignored, but other # networks, the policy's compute and variable dtypes are ignored. Such
# fields, like the loss scale, are used by Models. For subclassed networks, # networks only use the policy if it is a PolicyV1, in which case it uses
# the compute and variable dtypes are used as like any ordinary layer. # the PolicyV1's loss_scale (Policy does not have a loss_scale). For
# subclassed networks, the compute and variable dtypes are used as like any
# ordinary layer.
self._set_dtype_policy(dtype) self._set_dtype_policy(dtype)
# Boolean indicating whether the layer automatically casts its inputs to the # Boolean indicating whether the layer automatically casts its inputs to the
# layer's compute_dtype. # layer's compute_dtype.


@ -52,6 +52,7 @@ from tensorflow.python.keras.engine import compile_utils
from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving.saved_model import json_utils from tensorflow.python.keras.saving.saved_model import json_utils
@ -549,12 +550,25 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
def _get_optimizer(self, optimizer): def _get_optimizer(self, optimizer):
"""Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" """Wraps `optimizer` in `LossScaleOptimizer` if necessary."""
# The deprecated PolicyV1 has a loss_scale, which we use for backwards
# compatibility to match TF 2.3 behavior. The new Policy does not have a
# loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is
# used.
if isinstance(self._dtype_policy, policy.PolicyV1):
loss_scale = self._dtype_policy.loss_scale
elif self._dtype_policy.name == 'mixed_float16':
loss_scale = 'dynamic'
else:
loss_scale = None
def _get_single_optimizer(opt): def _get_single_optimizer(opt):
opt = optimizers.get(opt) opt = optimizers.get(opt)
if (self._dtype_policy.loss_scale is not None and if (loss_scale is not None and
not isinstance(opt, lso.LossScaleOptimizer)): not isinstance(opt, lso.LossScaleOptimizer)):
opt = lso.LossScaleOptimizer(opt, self._dtype_policy.loss_scale) if loss_scale == 'dynamic':
opt = lso.LossScaleOptimizer(opt)
else:
opt = lso.LossScaleOptimizerV1(opt, loss_scale)
return opt return opt
return nest.map_structure(_get_single_optimizer, optimizer) return nest.map_structure(_get_single_optimizer, optimizer)


@ -53,6 +53,7 @@ from tensorflow.python.keras.engine import training_generator_v1
from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.engine import training_utils_v1
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving.saved_model import model_serialization from tensorflow.python.keras.saving.saved_model import model_serialization
from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import data_utils
@ -1342,7 +1343,14 @@ class Model(training_lib.Model):
else: else:
self.optimizer = optimizers.get(optimizer) self.optimizer = optimizers.get(optimizer)
if (self._dtype_policy.loss_scale is not None and if isinstance(self._dtype_policy, policy.PolicyV1):
loss_scale = self._dtype_policy.loss_scale
elif self._dtype_policy.name == 'mixed_float16':
loss_scale = 'dynamic'
else:
loss_scale = None
if (loss_scale is not None and
not isinstance(self.optimizer, not isinstance(self.optimizer,
loss_scale_optimizer.LossScaleOptimizer)): loss_scale_optimizer.LossScaleOptimizer)):
if isinstance(self.optimizer, list): if isinstance(self.optimizer, list):
@ -1356,18 +1364,11 @@ class Model(training_lib.Model):
'with a loss scale used, but got: %s. Using policy: ' 'with a loss scale used, but got: %s. Using policy: '
'%s' % '%s' %
(self.optimizer, self._dtype_policy)) (self.optimizer, self._dtype_policy))
self.optimizer = loss_scale_optimizer.LossScaleOptimizer( if loss_scale == 'dynamic':
self.optimizer, self._dtype_policy.loss_scale) self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer)
if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and else:
self._dtype_policy.loss_scale and self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1(
self.optimizer.loss_scale != self._dtype_policy.loss_scale): self.optimizer, loss_scale)
logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) '
'is not the same as the dtype policy\'s loss scale (%s). '
'Because the dtype policy has a loss scale, you should '
'pass an optimizer that is not wrapped with a '
'LossScaleOptimizer,'
% (self.optimizer.loss_scale,
self._dtype_policy.loss_scale))
def _prepare_validation_data(self, validation_data, batch_size, def _prepare_validation_data(self, validation_data, batch_size,
validation_steps): validation_steps):


@ -344,32 +344,10 @@ class KerasLayerTest(keras_parameterized.TestCase):
self.assertEqual(layer.dtype, 'float32') self.assertEqual(layer.dtype, 'float32')
self.assertEqual(layer(x).dtype, 'float16') self.assertEqual(layer(x).dtype, 'float16')
self.assertEqual(layer.v.dtype, 'float32') self.assertEqual(layer.v.dtype, 'float32')
layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16',
loss_scale=None))
config = layer.get_config() config = layer.get_config()
self.assertEqual(config['dtype'], self.assertEqual(config['dtype'],
{'class_name': 'Policy', {'class_name': 'Policy',
'config': {'name': 'mixed_float16', 'config': {'name': 'mixed_float16'}})
'loss_scale': None}})
layer = mp_test_util.MultiplyLayer.from_config(config)
self.assertEqual(layer.dtype, 'float32')
self.assertEqual(layer(x).dtype, 'float16')
self.assertEqual(layer.v.dtype, 'float32')
layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('float64',
loss_scale=2.))
config = layer.get_config()
self.assertEqual(config['dtype'],
{'class_name': 'Policy',
'config': {'name': 'float64',
'loss_scale': {
'class_name': 'FixedLossScale',
'config': {'loss_scale_value': 2.0}}}})
layer = mp_test_util.MultiplyLayer.from_config(config)
self.assertEqual(layer.dtype, 'float64')
self.assertEqual(layer(x).dtype, 'float64')
self.assertEqual(layer.v.dtype, 'float64')
layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer')) layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer'))
config = layer.get_config() config = layer.get_config()
@ -383,11 +361,53 @@ class KerasLayerTest(keras_parameterized.TestCase):
self.assertEqual(layer(x).dtype, 'float32') self.assertEqual(layer(x).dtype, 'float32')
self.assertEqual(layer.v.dtype, 'float32') self.assertEqual(layer.v.dtype, 'float32')
layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer', @parameterized.named_parameters(*TESTCASES)
def test_config_policy_v1(self, strategy_fn):
x = constant_op.constant([1.], dtype=dtypes.float16)
with strategy_fn().scope():
layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('mixed_float16',
loss_scale=None))
config = layer.get_config()
self.assertEqual(config['dtype'],
{'class_name': 'PolicyV1',
'config': {'name': 'mixed_float16',
'loss_scale': None}})
layer = mp_test_util.MultiplyLayer.from_config(config)
self.assertEqual(layer.dtype, 'float32')
self.assertEqual(layer(x).dtype, 'float16')
self.assertEqual(layer.v.dtype, 'float32')
# Restoring a PolicyV1 silently converts it to a Policy and drops the loss
# scale.
self.assertEqual(type(layer._dtype_policy), policy.Policy)
config = layer.get_config()
# The loss_scale is silently dropped
self.assertEqual(config['dtype'],
{'class_name': 'Policy',
'config': {'name': 'mixed_float16'}})
layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('float64',
loss_scale=2.)) loss_scale=2.))
config = layer.get_config() config = layer.get_config()
self.assertEqual(config['dtype'], self.assertEqual(config['dtype'],
{'class_name': 'Policy', {'class_name': 'PolicyV1',
'config': {'name': 'float64',
'loss_scale': {
'class_name': 'FixedLossScale',
'config': {'loss_scale_value': 2.0}}}})
layer = mp_test_util.MultiplyLayer.from_config(config)
self.assertEqual(layer.dtype, 'float64')
self.assertEqual(layer(x).dtype, 'float64')
self.assertEqual(layer.v.dtype, 'float64')
self.assertEqual(type(layer._dtype_policy), policy.Policy)
config = layer.get_config()
self.assertEqual(config['dtype'], 'float64')
layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('_infer',
loss_scale=2.))
config = layer.get_config()
self.assertEqual(config['dtype'],
{'class_name': 'PolicyV1',
'config': {'name': '_infer', 'config': {'name': '_infer',
'loss_scale': { 'loss_scale': {
'class_name': 'FixedLossScale', 'class_name': 'FixedLossScale',
@ -396,6 +416,9 @@ class KerasLayerTest(keras_parameterized.TestCase):
self.assertEqual(layer.dtype, None) self.assertEqual(layer.dtype, None)
self.assertEqual(layer(x).dtype, 'float16') self.assertEqual(layer(x).dtype, 'float16')
self.assertEqual(layer.v.dtype, 'float16') self.assertEqual(layer.v.dtype, 'float16')
self.assertEqual(type(layer._dtype_policy), policy.Policy)
config = layer.get_config()
self.assertEqual(config['dtype'], 'float16')
def test_delete_variable(self): def test_delete_variable(self):
layer = base_layer.Layer(dtype=policy.Policy('mixed_float16')) layer = base_layer.Layer(dtype=policy.Policy('mixed_float16'))
@ -501,6 +524,11 @@ class KerasModelTest(keras_parameterized.TestCase):
'strategy_fn': create_mirrored_strategy, 'strategy_fn': create_mirrored_strategy,
'save_format': 'h5', 'save_format': 'h5',
'use_regularizer': True, 'use_regularizer': True,
}, {
'testcase_name': 'saved_model_v1_policy',
'strategy_fn': create_mirrored_strategy,
'use_v1_policy': True,
'save_format': 'tf',
}) })
def test_model(self, def test_model(self,
strategy_fn, strategy_fn,
@ -509,15 +537,15 @@ class KerasModelTest(keras_parameterized.TestCase):
policy_name='mixed_float16', policy_name='mixed_float16',
get_config=False, get_config=False,
save_format=None, save_format=None,
use_input_spec=False): use_input_spec=False,
use_v1_policy=False):
self._skip_if_strategy_unsupported(strategy_fn) self._skip_if_strategy_unsupported(strategy_fn)
self._skip_if_save_format_unsupported(save_format) self._skip_if_save_format_unsupported(save_format)
regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer
else None) else None)
with strategy_fn().scope(): with strategy_fn().scope():
# Pass loss_scale=None, as this test will fail if the DynamicLossScale cls = policy.PolicyV1 if use_v1_policy else policy.Policy
# skips applying gradients for a step with policy.policy_scope(cls(policy_name)):
with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)):
layer = mp_test_util.MultiplyLayer( layer = mp_test_util.MultiplyLayer(
assert_type=dtypes.float16, assert_type=dtypes.float16,
use_operator=use_operator, use_operator=use_operator,
@ -543,6 +571,10 @@ class KerasModelTest(keras_parameterized.TestCase):
# the variable will not change. So this tests the learning rate not # the variable will not change. So this tests the learning rate not
# applied to a float16 value, but instead the float32 variable. # applied to a float16 value, but instead the float32 variable.
opt = gradient_descent.SGD(2**-14) opt = gradient_descent.SGD(2**-14)
# Use a fixed loss scale, as this test will fail if gradients are
# skipped for a step due to dynamic loss scaling.
opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=8)
model.compile( model.compile(
opt, opt,
loss=loss_fn, loss=loss_fn,
@ -597,6 +629,13 @@ class KerasModelTest(keras_parameterized.TestCase):
self.assertEqual(layer.v.dtype, 'float32') self.assertEqual(layer.v.dtype, 'float32')
self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16') self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16')
# Loading a model always loads with a v2 Policy, even if saved with a
# PolicyV1.
self.assertEqual(type(model._dtype_policy), policy.Policy)
self.assertEqual(layer.get_config()['dtype'],
{'class_name': 'Policy', 'config': {
'name': 'mixed_float16'}})
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
@parameterized.named_parameters( @parameterized.named_parameters(
{ {
@ -630,7 +669,8 @@ class KerasModelTest(keras_parameterized.TestCase):
return math_ops.reduce_mean(y_pred) return math_ops.reduce_mean(y_pred)
opt = gradient_descent.SGD(1.) opt = gradient_descent.SGD(1.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=loss_scale)
model.compile( model.compile(
opt, opt,
loss=loss_fn, loss=loss_fn,
@ -669,13 +709,10 @@ class KerasModelTest(keras_parameterized.TestCase):
strategy = strategy_fn() strategy = strategy_fn()
if use_loss_scaling: if use_loss_scaling:
loss_scale = 8. loss_scale = 8.
else:
loss_scale = None
learning_rate = 2**-14 learning_rate = 2**-14
with strategy.scope(): with strategy.scope():
with policy.policy_scope(policy.Policy('mixed_float16', with policy.policy_scope(policy.Policy('mixed_float16')):
loss_scale=loss_scale)):
x = layers.Input(shape=(1,), batch_size=2) x = layers.Input(shape=(1,), batch_size=2)
layer1 = mp_test_util.MultiplyLayer( layer1 = mp_test_util.MultiplyLayer(
assert_type=dtypes.float16, assert_type=dtypes.float16,
@ -710,6 +747,9 @@ class KerasModelTest(keras_parameterized.TestCase):
return math_ops.reduce_mean(y_pred) return math_ops.reduce_mean(y_pred)
opt = gradient_descent.SGD(learning_rate) opt = gradient_descent.SGD(learning_rate)
if use_loss_scaling:
opt = loss_scale_optimizer.LossScaleOptimizer(
opt, dynamic=False, initial_scale=loss_scale)
model.compile( model.compile(
opt, opt,
loss=loss_fn, loss=loss_fn,
@ -743,6 +783,11 @@ class KerasModelTest(keras_parameterized.TestCase):
'testcase_name': 'get_config', 'testcase_name': 'get_config',
'strategy_fn': create_mirrored_strategy, 'strategy_fn': create_mirrored_strategy,
'get_config': True, 'get_config': True,
}, {
'testcase_name': 'get_config_v1_lso',
'strategy_fn': create_mirrored_strategy,
'get_config': True,
'use_v1_loss_scale_optimizer': True,
}, { }, {
'testcase_name': 'get_config_and_pass_loss_scale_to_policy', 'testcase_name': 'get_config_and_pass_loss_scale_to_policy',
'strategy_fn': create_mirrored_strategy, 'strategy_fn': create_mirrored_strategy,
@ -752,12 +797,11 @@ class KerasModelTest(keras_parameterized.TestCase):
def test_dynamic_loss_scaling(self, def test_dynamic_loss_scaling(self,
strategy_fn, strategy_fn,
pass_loss_scale_to_policy=False, pass_loss_scale_to_policy=False,
get_config=False): get_config=False,
use_v1_loss_scale_optimizer=False):
strategy = strategy_fn() strategy = strategy_fn()
initial_loss_scale = 2. initial_loss_scale = 2.
batch_size = 4 batch_size = 4
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=initial_loss_scale, increment_period=2)
expected_gradient = backend.variable([initial_loss_scale / batch_size], expected_gradient = backend.variable([initial_loss_scale / batch_size],
dtype=dtypes.float16) dtype=dtypes.float16)
# If this variable is set to True, the model below will have NaN gradients # If this variable is set to True, the model below will have NaN gradients
@ -765,10 +809,19 @@ class KerasModelTest(keras_parameterized.TestCase):
with strategy.scope(): with strategy.scope():
opt = gradient_descent.SGD(1.) opt = gradient_descent.SGD(1.)
if pass_loss_scale_to_policy: if pass_loss_scale_to_policy:
p = policy.Policy('mixed_float16', loss_scale=loss_scale) loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=initial_loss_scale, increment_period=2)
p = policy.PolicyV1('mixed_float16', loss_scale=loss_scale)
elif use_v1_loss_scale_optimizer:
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=initial_loss_scale, increment_period=2)
p = policy.Policy('mixed_float16')
opt = loss_scale_optimizer.LossScaleOptimizerV1(
opt, loss_scale)
else: else:
p = policy.Policy('mixed_float16', loss_scale=None) p = policy.Policy('mixed_float16')
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) opt = loss_scale_optimizer.LossScaleOptimizer(
opt, initial_scale=initial_loss_scale, dynamic_growth_steps=2)
with policy.policy_scope(p): with policy.policy_scope(p):
x = layers.Input( x = layers.Input(
shape=(1,), batch_size=batch_size, dtype=dtypes.float16) shape=(1,), batch_size=batch_size, dtype=dtypes.float16)
@ -835,19 +888,32 @@ class KerasModelTest(keras_parameterized.TestCase):
self.assertEqual(backend.eval(layer.v), -3) self.assertEqual(backend.eval(layer.v), -3)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_loss_scale_optimizer_overrides_policy_loss_scale(self): def test_loss_scale_optimizer_overrides_policy_v1_loss_scale(self):
with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)):
opt = gradient_descent.SGD(1.) opt = gradient_descent.SGD(1.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=5.) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=5.)
x = layers.Input(shape=(1,)) x = layers.Input(shape=(1,))
y = mp_test_util.MultiplyLayer()(x) y = mp_test_util.MultiplyLayer()(x)
model = models.Model(x, y) model = models.Model(x, y)
model.compile(opt, loss='mse') model.compile(opt, loss='mse')
self.assertEqual(self.evaluate(model.optimizer.loss_scale()), 5.) self.assertEqual(self.evaluate(model.optimizer.loss_scale), 5.)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_policy_v1_without_loss_scale(self):
with policy.policy_scope(policy.PolicyV1('mixed_float16',
loss_scale=None)):
opt = gradient_descent.SGD(1.)
x = layers.Input(shape=(1,))
y = mp_test_util.MultiplyLayer()(x)
model = models.Model(x, y)
model.compile(opt, loss='mse')
self.assertNotIsInstance(model.optimizer,
loss_scale_optimizer.LossScaleOptimizer)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_pass_invalid_optimizer_with_loss_scaling(self): def test_pass_invalid_optimizer_with_loss_scaling(self):
with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)):
x = layers.Input(shape=(1,)) x = layers.Input(shape=(1,))
y = mp_test_util.MultiplyLayer()(x) y = mp_test_util.MultiplyLayer()(x)
model = models.Model(x, y) model = models.Model(x, y)
@ -926,7 +992,7 @@ class KerasModelTest(keras_parameterized.TestCase):
def test_save_slot_variables_with_autocast_vars(self, def test_save_slot_variables_with_autocast_vars(self,
strategy_fn, strategy_fn,
var_name='v'): var_name='v'):
p = policy.Policy('mixed_float16', loss_scale=None) p = policy.Policy('mixed_float16')
with strategy_fn().scope(), policy.policy_scope(p): with strategy_fn().scope(), policy.policy_scope(p):
x = layers.Input(shape=(2,), batch_size=2) x = layers.Input(shape=(2,), batch_size=2)
# Having a var_name other than 'v' tests that a fixed bug (b/134713714) # Having a var_name other than 'v' tests that a fixed bug (b/134713714)
@ -938,6 +1004,8 @@ class KerasModelTest(keras_parameterized.TestCase):
y = layer(x) y = layer(x)
model = models.Model(inputs=x, outputs=y) model = models.Model(inputs=x, outputs=y)
opt = gradient_descent.SGD(1., 1.) opt = gradient_descent.SGD(1., 1.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=1)
model.compile( model.compile(
optimizer=opt, optimizer=opt,
loss='mse', loss='mse',
@ -971,18 +1039,17 @@ class KerasModelTest(keras_parameterized.TestCase):
y = mp_test_util.MultiplyLayer(assert_type=dtypes.float32)(x) y = mp_test_util.MultiplyLayer(assert_type=dtypes.float32)(x)
model = models.Model(inputs=x, outputs=y) model = models.Model(inputs=x, outputs=y)
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=1., increment_period=2., multiplier=2.)
opt = gradient_descent.SGD(1.) opt = gradient_descent.SGD(1.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) opt = loss_scale_optimizer.LossScaleOptimizer(
opt, initial_scale=1., dynamic_growth_steps=2.)
model.compile( model.compile(
optimizer=opt, optimizer=opt,
loss='mse', loss='mse',
run_eagerly=testing_utils.should_run_eagerly()) run_eagerly=testing_utils.should_run_eagerly())
# Run for 3 steps (6 examples with a batch size of 2) # Run for 3 steps (6 examples with a batch size of 2)
model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2) model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2)
self.assertEqual(backend.get_value(loss_scale()), 2) self.assertEqual(backend.get_value(opt.loss_scale), 2)
self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) self.assertEqual(backend.get_value(opt.dynamic_counter), 1)
# Save model weights. # Save model weights.
save_prefix = os.path.join(self.get_temp_dir(), 'ckpt') save_prefix = os.path.join(self.get_temp_dir(), 'ckpt')
@ -990,20 +1057,20 @@ class KerasModelTest(keras_parameterized.TestCase):
# Run model again for 1 step (2 examples with a batch size of 2) # Run model again for 1 step (2 examples with a batch size of 2)
model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2) model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2)
self.assertEqual(backend.get_value(loss_scale()), 4) self.assertEqual(backend.get_value(opt.loss_scale), 4)
self.assertEqual(backend.get_value(loss_scale._num_good_steps), 0) self.assertEqual(backend.get_value(opt.dynamic_counter), 0)
# Load model weights and ensure loss scale weights are restored. # Load model weights and ensure loss scale weights are restored.
model.load_weights(save_prefix) model.load_weights(save_prefix)
self.assertEqual(backend.get_value(loss_scale()), 2) self.assertEqual(backend.get_value(opt.loss_scale), 2)
self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) self.assertEqual(backend.get_value(opt.dynamic_counter), 1)
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
def test_restore_old_loss_scale_checkpoint(self): def test_restore_old_loss_scale_checkpoint(self):
# Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format
# of LossScaleOptimizer changed, but old checkpoints can still be loaded # of LossScaleOptimizer changed, but old checkpoints can still be loaded
opt = gradient_descent.SGD(0.1, momentum=0.1) opt = gradient_descent.SGD(0.1, momentum=0.1)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') opt = loss_scale_optimizer.LossScaleOptimizer(opt)
model = sequential.Sequential([core.Dense(2,)]) model = sequential.Sequential([core.Dense(2,)])
# The checkpoint and expected values were obtained from the program in # The checkpoint and expected values were obtained from the program in
@ -1024,8 +1091,8 @@ class KerasModelTest(keras_parameterized.TestCase):
self.assertAllClose( self.assertAllClose(
self.evaluate(opt.get_slot(model.weights[0], 'momentum')), self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
expected_slot) expected_slot)
self.assertEqual(self.evaluate(opt.loss_scale()), 32768) self.assertEqual(self.evaluate(opt.loss_scale), 32768)
self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1) self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
# Check restoring works even after the model is compiled and the weights # Check restoring works even after the model is compiled and the weights
# have been created. # have been created.
@ -1039,8 +1106,8 @@ class KerasModelTest(keras_parameterized.TestCase):
self.assertAllClose( self.assertAllClose(
self.evaluate(opt.get_slot(model.weights[0], 'momentum')), self.evaluate(opt.get_slot(model.weights[0], 'momentum')),
expected_slot) expected_slot)
self.assertEqual(self.evaluate(opt.loss_scale()), 32768) self.assertEqual(self.evaluate(opt.loss_scale), 32768)
self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1) self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
def test_restore_old_saved_model(self): def test_restore_old_saved_model(self):
saved_model_dir = os.path.join( saved_model_dir = os.path.join(
@ -1053,7 +1120,7 @@ class KerasModelTest(keras_parameterized.TestCase):
model = save.load_model(saved_model_dir) model = save.load_model(saved_model_dir)
expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]]) expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]])
self.assertAllClose(backend.eval(model.weights[0]), expected_kernel) self.assertAllClose(backend.eval(model.weights[0]), expected_kernel)
self.assertIsInstance(model.optimizer, self.assertEqual(type(model.optimizer),
loss_scale_optimizer.LossScaleOptimizer) loss_scale_optimizer.LossScaleOptimizer)
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
@ -1064,6 +1131,10 @@ class KerasModelTest(keras_parameterized.TestCase):
}, { }, {
'testcase_name': 'distribute', 'testcase_name': 'distribute',
'strategy_fn': create_mirrored_strategy, 'strategy_fn': create_mirrored_strategy,
}, {
'testcase_name': 'use_v1_lso',
'strategy_fn': create_mirrored_strategy,
'use_v1_loss_scale_optimizer': True
}, { }, {
'testcase_name': 'base_h5', 'testcase_name': 'base_h5',
'strategy_fn': default_strategy_fn, 'strategy_fn': default_strategy_fn,
@ -1073,7 +1144,8 @@ class KerasModelTest(keras_parameterized.TestCase):
'strategy_fn': create_mirrored_strategy, 'strategy_fn': create_mirrored_strategy,
'h5': True, 'h5': True,
}) })
def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False): def test_save_model_with_dynamic_loss_scaling(
self, strategy_fn, h5=False, use_v1_loss_scale_optimizer=False):
# TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy
# as well. # as well.
strategy = strategy_fn() strategy = strategy_fn()
@ -1088,18 +1160,22 @@ class KerasModelTest(keras_parameterized.TestCase):
y = mp_test_util.MultiplyLayer()(x) y = mp_test_util.MultiplyLayer()(x)
model = models.Model(inputs=x, outputs=y) model = models.Model(inputs=x, outputs=y)
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=1., increment_period=2., multiplier=2.)
opt = gradient_descent.SGD(1.) opt = gradient_descent.SGD(1.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) if use_v1_loss_scale_optimizer:
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=1., increment_period=2.)
opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
else:
opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1.,
dynamic_growth_steps=2.)
model.compile( model.compile(
optimizer=opt, optimizer=opt,
loss='mse', loss='mse',
run_eagerly=testing_utils.should_run_eagerly()) run_eagerly=testing_utils.should_run_eagerly())
# Run for 3 steps (6 examples with a batch size of 2) # Run for 3 steps (6 examples with a batch size of 2)
model.fit(np.ones((6, 2)), np.zeros((6, 2)), batch_size=2) model.fit(np.ones((6, 2)), np.zeros((6, 2)), batch_size=2)
self.assertEqual(backend.get_value(loss_scale()), 2) self.assertEqual(backend.get_value(opt.loss_scale), 2)
self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) self.assertEqual(backend.get_value(opt.dynamic_counter), 1)
(weight,) = model.trainable_weights (weight,) = model.trainable_weights
orig_weight = backend.get_value(weight) orig_weight = backend.get_value(weight)
@ -1111,13 +1187,12 @@ class KerasModelTest(keras_parameterized.TestCase):
model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2) model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2)
new_weight = backend.get_value(weight) new_weight = backend.get_value(weight)
self.assertNotEqual(new_weight, orig_weight) self.assertNotEqual(new_weight, orig_weight)
self.assertEqual(backend.get_value(loss_scale()), 4) self.assertEqual(backend.get_value(opt.loss_scale), 4)
self.assertEqual(backend.get_value(loss_scale._num_good_steps), 0) self.assertEqual(backend.get_value(opt.dynamic_counter), 0)
# Load model weights and ensure loss scale weights are restored. # Load model weights and ensure loss scale weights are restored.
model = save.load_model( model = save.load_model(
save_path, custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer}) save_path, custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer})
loss_scale = model.optimizer.loss_scale
(weight,) = model.trainable_weights (weight,) = model.trainable_weights
loaded_weight = backend.get_value(weight) loaded_weight = backend.get_value(weight)
self.assertEqual(loaded_weight, orig_weight) self.assertEqual(loaded_weight, orig_weight)
@ -1125,8 +1200,14 @@ class KerasModelTest(keras_parameterized.TestCase):
# Model.save(). So we assert the loss scale either has the value when it was # Model.save(). So we assert the loss scale either has the value when it was
# saved, or the value it was initialized with. # saved, or the value it was initialized with.
# TODO(reedwm): Always save/restore the loss scale with Model.save(). # TODO(reedwm): Always save/restore the loss scale with Model.save().
self.assertIn(backend.get_value(loss_scale()), (1, 2)) self.assertIn(backend.get_value(model.optimizer.loss_scale), (1, 2))
self.assertIn(backend.get_value(loss_scale._num_good_steps), (0, 1)) self.assertIn(backend.get_value(model.optimizer.dynamic_counter), (0, 1))
# Test optimizer attributes and type
self.assertEqual(model.optimizer.initial_scale, 1.)
self.assertEqual(model.optimizer.dynamic_growth_steps, 2.)
self.assertEqual(type(model.optimizer),
loss_scale_optimizer.LossScaleOptimizer)
if __name__ == '__main__': if __name__ == '__main__':


@ -21,8 +21,10 @@ from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_strategy from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond from tensorflow.python.framework import smart_cond
@ -32,8 +34,13 @@ from tensorflow.python.keras.mixed_precision.experimental import loss_scale as k
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision from tensorflow.python.training.experimental import mixed_precision
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -173,7 +180,220 @@ class _DelegatingTrackableMixin(object):
# pylint: enable=protected-access # pylint: enable=protected-access
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') def _is_all_finite(grads):
"""Returns a scalar boolean tensor indicating if all gradients are finite."""
is_finite_per_grad = [
math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None
]
return math_ops.reduce_all(is_finite_per_grad)
def _op_in_graph_mode(tensor):
"""Returns the tensor's op in graph mode, or the tensor in eager mode.
This is useful because sometimes an op is needed in graph mode instead of a
tensor. In eager mode, there are no ops.
Args:
tensor: A tensor.
Returns:
The tensor's op in graph mode. The tensor in eager mode.
"""
if context.executing_eagerly():
return tensor
return tensor.op
def _assign_if_finite(var, value):
"""Assigns a value to a variable if the value is finite."""
return control_flow_ops.cond(
math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
control_flow_ops.no_op)
class _DynamicLossScaleState(trackable.Trackable):
"""The state of a dynamic loss scale."""
def __init__(self,
initial_loss_scale,
growth_steps,
multiplier):
"""Creates the dynamic loss scale."""
super(_DynamicLossScaleState, self).__init__()
self._initial_loss_scale = float(initial_loss_scale)
self._growth_steps = int(growth_steps)
self._multiplier = float(multiplier)
self._weights = {}
self._current_loss_scale = self._add_weight(
name='current_loss_scale',
dtype=dtypes.float32,
initial_value=self._initial_loss_scale)
# The number of consecutive steps with finite gradients since the last
# nonfinite gradient or change in loss scale. The name is 'good_steps' for
# backwards compatibility with older checkpoints.
self._counter = self._add_weight(
name='good_steps', dtype=dtypes.int64, initial_value=0)
def _add_weight(self, name, initial_value, dtype=None):
"""Adds a weight to this loss scale.
Args:
name: Variable name.
initial_value: The variable's initial value.
dtype: The type of the variable.
Returns:
A variable.
Raises:
RuntimeError: If a weight with `name` has already been added.
"""
variable = variable_scope.variable(
initial_value=initial_value,
name=name,
dtype=dtype,
trainable=False,
use_resource=True,
synchronization=variables.VariableSynchronization.AUTO,
# Set aggregation to NONE, as loss scaling variables should never be
# aggregated.
aggregation=variables.VariableAggregation.NONE)
if context.executing_eagerly():
graph_key = None
else:
graph = ops.get_default_graph()
graph_key = graph._graph_key # pylint: disable=protected-access
key = (name, graph_key)
self._weights[key] = variable
self._handle_deferred_dependencies(name=name, trackable=variable)
backend.track_variable(variable)
return variable
@property
def _checkpoint_dependencies(self):
"""From Trackable. Gather graph-specific weights to save."""
if context.executing_eagerly():
graph_key = None
else:
graph = ops.get_default_graph()
graph_key = graph._graph_key # pylint: disable=protected-access
weights = []
for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
if g == graph_key:
weights.append(trackable.TrackableReference(name=name, ref=v))
return (super(_DynamicLossScaleState, self)._checkpoint_dependencies +
weights)
def _lookup_dependency(self, name):
"""From Trackable. Find a weight in the current graph."""
unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name)
if unconditional is not None:
return unconditional
if context.executing_eagerly():
graph_key = None
else:
graph = ops.get_default_graph()
graph_key = graph._graph_key # pylint: disable=protected-access
return self._weights.get((name, graph_key), None)
@property
def initial_loss_scale(self):
return self._initial_loss_scale
@property
def growth_steps(self):
return self._growth_steps
@property
def multiplier(self):
return self._multiplier
@property
def current_loss_scale(self):
"""Returns the current loss scale as a float32 `tf.Variable`."""
return self._current_loss_scale
@property
def counter(self):
"""Returns the counter as a float32 `tf.Variable`."""
return self._counter
def __call__(self):
"""Returns the current loss scale as a scalar `float32` tensor."""
return ops.convert_to_tensor(self._current_loss_scale)
def update(self, grads):
"""Updates the value of the loss scale.
Args:
grads: A nested structure of unscaled gradients, each which is the
gradient of the loss with respect to a weight.
Returns:
update_op: In eager mode, None. In graph mode, an op to update the loss
scale.
should_apply_gradients: Either a bool or a scalar boolean tensor. If
False, the caller should skip applying `grads` to the variables this
step.
"""
grads = nest.flatten(grads)
if distribution_strategy_context.has_strategy():
distribution = distribution_strategy_context.get_strategy()
def get_is_finite(grads):
is_finite = _is_all_finite(grads)
# We cast to float, because we cannot reduce booleans with
# DistributionStrategy.
return math_ops.cast(is_finite, dtypes.float32)
is_finite_float = distribution.extended.call_for_each_replica(
get_is_finite, args=(grads,))
reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
is_finite_float, axis=None)
is_finite = math_ops.equal(reduced_is_finite_float,
distribution.num_replicas_in_sync)
else:
is_finite = _is_all_finite(grads)
def update_if_finite_grads():
"""Update assuming the gradients are finite."""
def incr_loss_scale():
new_loss_scale = self.current_loss_scale * self.multiplier
return control_flow_ops.group(
_assign_if_finite(self.current_loss_scale, new_loss_scale),
self.counter.assign(0))
return control_flow_ops.cond(
self.counter + 1 >= self.growth_steps,
incr_loss_scale,
lambda: _op_in_graph_mode(self.counter.assign_add(1)))
def update_if_not_finite_grads():
"""Update assuming the gradients are nonfinite."""
new_loss_scale = math_ops.maximum(
self.current_loss_scale / self.multiplier, 1)
return control_flow_ops.group(
self.counter.assign(0),
self.current_loss_scale.assign(new_loss_scale))
update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
update_if_not_finite_grads)
should_apply_gradients = is_finite
return update_op, should_apply_gradients
# See LossScaleOptimizer docstring for why this is so big
_DEFAULT_INITIAL_SCALE = 2 ** 15
_DEFAULT_GROWTH_STEPS = 2000
# pylint: disable=g-classes-have-attributes
@keras_export('keras.mixed_precision.LossScaleOptimizer')
class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
"""An optimizer that applies loss scaling. """An optimizer that applies loss scaling.
@ -194,19 +414,21 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
applied. applied.
The loss scale can either be a fixed constant, chosen by the user, or be The loss scale can either be a fixed constant, chosen by the user, or be
dynamically determined. Dynamically determining the loss scale is convenient dynamically determined. Using a dynamic loss scale is highly recommend and is
as a loss scale does not have to be explicitly chosen. However it reduces the default behavior, as choosing a specific fixed loss scale is difficult.
performance. Every step, the dynamic loss scale is potentially updated to a new value.
Dynamic loss scaling sometimes causes the loss scale to be too high and cause
the gradients to overflow, in which case gradients are not applied to
variables that step.
This optimizer wraps another optimizer and applies loss scaling to it via a `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it.
`LossScale`. Loss scaling is applied whenever gradients are Loss scaling is applied whenever gradients are computed, either through
computed, either through `minimize()` or `get_gradients()`. The loss scale is `minimize()` or `get_gradients()`. If dynamic, the loss scale is updated
updated via `LossScale.update()` whenever gradients are applied, either whenever gradients are applied, either through `minimize()` or
through `minimize()` or `apply_gradients()`. For example: `apply_gradients()`. For example:
>>> opt = tf.keras.optimizers.SGD(0.25) >>> opt = tf.keras.optimizers.SGD(0.25)
>>> opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
... "dynamic")
>>> var = tf.Variable(1.) >>> var = tf.Variable(1.)
>>> loss_fn = lambda: var ** 2 >>> loss_fn = lambda: var ** 2
>>> # 'minimize' applies loss scaling to the loss and updates the loss sale. >>> # 'minimize' applies loss scaling to the loss and updates the loss sale.
@ -230,12 +452,43 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
>>> var.numpy() >>> var.numpy()
0.25 0.25
Args:
optimizer: The Optimizer instance to wrap.
dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to
True. If True, the loss scale will be dynamically updated over time using
an algorithm that keeps the loss scale at approximately its optimal value.
If False, a single fixed loss scale is used and `initial_scale` must be
specified, which is used as the loss scale. Recommended to keep as True,
as choosing a fixed loss scale can be tricky. Currently, there is a small
performance overhead to dynamic loss scaling compared to fixed loss
scaling.
initial_scale: The initial loss scale. If `dynamic` is True, this defaults
to 2 ** 15. If `dynamic` is False, this must be specified and acts as the
sole loss scale, as the loss scale does not change over time. When dynamic
loss scaling is used, is better for this to be a very high number, because
a loss scale that is too high gets lowered far more quickly than a loss
scale that is too low gets raised.
dynamic_growth_steps: With dynamic loss scaling, every
`dynamic_growth_steps` steps with finite gradients, the loss scale is
doubled. Defaults to 2000. If a nonfinite gradient is encountered, the
count is reset back to zero, gradients are skipped that step, and the loss
scale is halved. The count can be queried with
`LossScaleOptimizer.dynamic_counter`. This argument can only be specified
if `dynamic` is True.
To use a fixed loss scale instead of dynamic loss scale, pass `dynamic=False`
and pass the loss scale to `initial_scale`. For example:
>>> opt = tf.keras.mixed_precision.LossScaleOptimizer(
... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=1024)
>>> opt.loss_scale.numpy()
1024.
Hyperparameters can be accessed and set on the LossScaleOptimizer, which will Hyperparameters can be accessed and set on the LossScaleOptimizer, which will
be delegated to the wrapped optimizer. be delegated to the wrapped optimizer.
>>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5)
>>> lso = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, >>> lso = tf.keras.mixed_precision.LossScaleOptimizer(opt)
... "dynamic")
>>> opt.beta_1 >>> opt.beta_1
0.8 0.8
>>> lso.beta_1 # Equivalent to `opt.beta_1` >>> lso.beta_1 # Equivalent to `opt.beta_1`
@ -268,50 +521,99 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
_HAS_AGGREGATE_GRAD = True _HAS_AGGREGATE_GRAD = True
def __init__(self, optimizer, loss_scale): def __init__(self, optimizer, dynamic=True, initial_scale=None,
"""Initializes this loss scale optimizer. dynamic_growth_steps=None):
Args:
optimizer: The Optimizer instance to wrap.
loss_scale: The loss scale to scale the loss and gradients. This can
either be an int/float to use a fixed loss scale, the string "dynamic"
to use dynamic loss scaling, or an instance of a LossScale. The string
"dynamic" equivalent to passing `DynamicLossScale()`, and passing an
int/float is equivalent to passing a FixedLossScale with the given loss
scale.
"""
if not isinstance(optimizer, optimizer_v2.OptimizerV2): if not isinstance(optimizer, optimizer_v2.OptimizerV2):
raise ValueError('"optimizer" must be an instance of OptimizerV2, but ' raise TypeError('"optimizer" must be an instance of OptimizerV2, but '
'got: %s' % optimizer) 'got: %s' % optimizer)
if not isinstance(dynamic, bool):
# Catch errors if a user incorrectly passes a string or float to the
# second argument argument, as this is commonly done for
# LossScaleOptimizerV1.
raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must '
'be a bool, but got: %r' % (dynamic,))
self._raise_if_strategy_unsupported() self._raise_if_strategy_unsupported()
self._optimizer = optimizer self._optimizer = optimizer
self._loss_scale = keras_loss_scale_module.get(loss_scale)
if self._loss_scale is None:
raise ValueError('loss_scale cannot be None.')
# We don't call super().__init__, since we do not want to call OptimizerV2's # We don't call super().__init__, since we do not want to call OptimizerV2's
# constructor. # constructor.
_DelegatingTrackableMixin.__init__(self, self._optimizer) _DelegatingTrackableMixin.__init__(self, self._optimizer)
for weight in self._loss_scale._weights.values(): # pylint: disable=protected-access if dynamic:
# We cannot call `track_variable` in the LossScale class itself, because a if initial_scale is None:
# file outside of Keras cannot depend on a Keras file. Calling it here initial_scale = _DEFAULT_INITIAL_SCALE
# instead is OK, because a variable only needs to be tracked if used with if dynamic_growth_steps is None:
# a Keras class, and the only way to use LossScale with a Keras class is dynamic_growth_steps = _DEFAULT_GROWTH_STEPS
# through the LossScaleOptimizer. self._loss_scale = _DynamicLossScaleState(
backend.track_variable(weight) initial_scale, dynamic_growth_steps, multiplier=2)
self._track_trackable(self._loss_scale, 'loss_scale') self._track_trackable(self._loss_scale, 'loss_scale')
else:
if initial_scale is None:
raise ValueError('"initial_scale" must be specified if "dynamic" is '
'False')
self._loss_scale = float(initial_scale)
if dynamic_growth_steps is not None:
raise ValueError('"dynamic_growth_steps" must be None if "dynamic" '
'is False, but got: %s' % (dynamic_growth_steps,))
# To support restoring TensorFlow 2.2 checkpoints. # To support restoring TensorFlow 2.2 checkpoints.
self._track_trackable(FakeOptimizerForRestoration(self._optimizer), self._track_trackable(FakeOptimizerForRestoration(self._optimizer),
'base_optimizer') 'base_optimizer')
@property
def dynamic(self):
return isinstance(self._loss_scale, _DynamicLossScaleState)
@property
def loss_scale(self):
"""The `LossScale` instance associated with this optimizer."""
"""The current loss scale as a float32 scalar tensor."""
if isinstance(self._loss_scale, _DynamicLossScaleState):
return ops.convert_to_tensor(self._loss_scale.current_loss_scale)
else:
return ops.convert_to_tensor(self._loss_scale)
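Because the property now returns a tensor rather than a LossScale, reading the value no longer involves calling it; a rough sketch (the float() conversion assumes eager execution):

import tensorflow as tf

opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
scale_tensor = opt.loss_scale        # a float32 scalar tensor
scale_value = float(opt.loss_scale)  # 32768.0 with the default settings
# The old `opt.loss_scale()` pattern no longer works, since a tensor is not
# callable.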
@property
def dynamic_counter(self):
"""The number of steps since the loss scale was last increased or decreased.
This is None if `LossScaleOptimizer.dynamic` is False.
The counter is incremented every step. Once it reaches
`LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled
and the counter will be reset back to zero. If nonfinite gradients are
encountered, the loss scale will be halved and the counter will be reset
back to zero.
"""
if isinstance(self._loss_scale, _DynamicLossScaleState):
return self._loss_scale.counter
else:
return None
@property
def initial_scale(self):
"""The initial loss scale.
This is None if `LossScaleOptimizer.dynamic` is False.
"""
if isinstance(self._loss_scale, _DynamicLossScaleState):
return self._loss_scale.initial_loss_scale
else:
return self._loss_scale
@property
def dynamic_growth_steps(self):
"""The number of steps it takes to increase the loss scale.
This is None if `LossScaleOptimizer.dynamic` is False.
Every `dynamic_growth_steps` consecutive steps with finite gradients, the
loss scale is increased.
"""
if isinstance(self._loss_scale, _DynamicLossScaleState):
return self._loss_scale.growth_steps
else:
return None
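A brief sketch of how these attributes fit together under the defaults (the values match the defaults exercised in the tests further below; illustrative only):

import tensorflow as tf

opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
assert opt.dynamic
assert opt.initial_scale == 2 ** 15       # default initial scale
assert opt.dynamic_growth_steps == 2000   # default growth interval
# dynamic_counter starts at 0; after dynamic_growth_steps consecutive finite
# steps the scale doubles and the counter resets, while nonfinite gradients
# halve the scale and also reset the counter.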
def get_scaled_loss(self, loss): def get_scaled_loss(self, loss):
"""Scales the loss by the loss scale. """Scales the loss by the loss scale.
@ -322,7 +624,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
scaling is automatically applied and this method is unneeded. scaling is automatically applied and this method is unneeded.
If this method is called, `get_unscaled_gradients` should also be called. If this method is called, `get_unscaled_gradients` should also be called.
See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for
See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for
an example. an example.
Args: Args:
@ -330,16 +632,15 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
a tensor or a callable returning a tensor. a tensor or a callable returning a tensor.
Returns: Returns:
`loss` multiplied by `LossScaleOptimizer.loss_scale()`.
`loss` multiplied by `LossScaleOptimizer.loss_scale`.
""" """
loss_scale = self._loss_scale()
if callable(loss):
def new_loss():
loss_val = loss()
return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype)
return new_loss
else:
return loss * math_ops.cast(loss_scale, loss.dtype)
return loss * math_ops.cast(self.loss_scale, loss.dtype)
def get_unscaled_gradients(self, grads): def get_unscaled_gradients(self, grads):
"""Unscales the gradients by the loss scale. """Unscales the gradients by the loss scale.
@ -351,7 +652,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
scaling is automatically applied and this method is unneeded. scaling is automatically applied and this method is unneeded.
If this method is called, `get_scaled_loss` should also be called. See If this method is called, `get_scaled_loss` should also be called. See
the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an
the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an
example. example.
Args: Args:
@ -360,10 +661,9 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
Returns: Returns:
A new list the same size as `grads`, where every non-None value in `grads` A new list the same size as `grads`, where every non-None value in `grads`
is divided by `LossScaleOptimizer.loss_scale()`.
is divided by `LossScaleOptimizer.loss_scale`.
"""
loss_scale = self._loss_scale()
loss_scale_reciprocal = 1. / loss_scale
loss_scale_reciprocal = 1. / self.loss_scale
return [
_multiply_gradient(g, loss_scale_reciprocal) if g is not None else None
for g in grads
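For reference, a sketch of the intended pairing of these two methods in a custom training loop (the model, data, and loss below are placeholders, not part of this change):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
x, y = tf.ones((8, 4)), tf.ones((8, 1))

with tf.GradientTape() as tape:
  loss = tf.reduce_mean(tf.abs(model(x) - y))
  scaled_loss = opt.get_scaled_loss(loss)       # multiply the loss by the scale
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = opt.get_unscaled_gradients(scaled_grads)  # divide grads by the scale
opt.apply_gradients(zip(grads, model.trainable_variables))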
@ -379,9 +679,9 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
grad_loss,
tape=tape)
grads = [g for g, _ in grads_and_vars]
variables = [v for _, v in grads_and_vars]
weights = [v for _, v in grads_and_vars]
unscaled_grads = self.get_unscaled_gradients(grads)
return list(zip(unscaled_grads, variables))
return list(zip(unscaled_grads, weights))
def get_gradients(self, loss, params):
loss = self.get_scaled_loss(loss)
@ -409,7 +709,11 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name,
experimental_aggregate_gradients):
grads = [g for g, _ in grads_and_vars]
if isinstance(self._loss_scale, _DynamicLossScaleState):
loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)
else:
loss_scale_update_op = control_flow_ops.no_op()
should_apply_grads = True
def apply_fn():
# We do not want DistributionStrategy to unwrap any MirroredVariables in
@ -447,10 +751,11 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
def get_config(self):
serialized_optimizer = optimizers.serialize(self._optimizer)
serialized_loss_scale = keras_loss_scale_module.serialize(self._loss_scale)
return {
'optimizer': serialized_optimizer,
'loss_scale': serialized_loss_scale,
'dynamic': self.dynamic,
'initial_scale': self.initial_scale,
'dynamic_growth_steps': self.dynamic_growth_steps,
}
@classmethod
@ -458,8 +763,28 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
config = config.copy()  # Make a copy, since we mutate config
config['optimizer'] = optimizers.deserialize(
config['optimizer'], custom_objects=custom_objects)
config['loss_scale'] = keras_loss_scale_module.deserialize(
config['loss_scale'], custom_objects=custom_objects)
if 'loss_scale' in config:
# If loss_scale is in config, we assume we are deserializing a
# LossScaleOptimizer from TF 2.3 or below. We convert the config so it
# can be deserialized in the current LossScaleOptimizer.
loss_scale = keras_loss_scale_module.deserialize(
config.pop('loss_scale'))
if isinstance(loss_scale, loss_scale_module.FixedLossScale):
config['dynamic'] = False
config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access
elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
config['dynamic'] = True
config['initial_scale'] = loss_scale.initial_loss_scale
config['dynamic_growth_steps'] = loss_scale.increment_period
if loss_scale.multiplier != 2:
raise ValueError('Cannot deserialize LossScaleOptimizer with a '
'DynamicLossScale whose multiplier is not 2. Got '
'DynamicLossScale: %s' % (loss_scale,))
else:
raise ValueError(
'Serialized LossScaleOptimizers with a LossScale that is neither a '
'FixedLossScale nor a DynamicLossScale can no longer be '
'deserialized')
return cls(**config)
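To make the conversion above concrete, a TF 2.3-style config carrying a serialized DynamicLossScale (the dict shape mirrors the test configs further below) is effectively rewritten into the new keyword arguments:

old_style_config = {
    'optimizer': {'class_name': 'SGD', 'config': {'learning_rate': 2.0}},
    'loss_scale': {'class_name': 'DynamicLossScale',
                   'config': {'initial_loss_scale': 2.0,
                              'increment_period': 3,
                              'multiplier': 2.0}},
}
# from_config turns this into, roughly:
#   LossScaleOptimizer(SGD(2.0), dynamic=True, initial_scale=2.0,
#                      dynamic_growth_steps=3)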
def _raise_if_strategy_unsupported(self):
@ -601,15 +926,162 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
# both self._compute_gradients() and self.apply_gradients(), and both need
# to have the LossScaleOptimizer version called.
# TODO(reedwm): Maybe merge this class's functionality into OptimizerV2.
# TODO(reedwm): Maybe throw an error if mixed precision is used without this
# optimizer being used.
# Trackable delegations: Delegate all Trackable methods to the wrapped
# optimizer. This is so the checkpoint format for a LossScaleOptimizer is
# identical to the checkpoint format for a normal optimizer, except the loss
# scale is stored in the checkpoint.
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizerV1(LossScaleOptimizer):
"""A deprecated optimizer that applies loss scaling.
Warning: This class is deprecated and will be removed in TensorFlow 2.5.
Please use the non-experimental class
`tf.keras.mixed_precision.LossScaleOptimizer` instead.
This class is identical to the non-experimental
`keras.mixed_precision.LossScaleOptimizer` except its constructor takes
different arguments. For this class (the experimental version), the
constructor takes a `loss_scale` argument. For the non-experimental class,
the constructor encodes the loss scaling information in multiple arguments.
Note that unlike this class, the non-experimental class does not accept a
`tf.compat.v1.mixed_precision.LossScale`, which is deprecated.
If you currently use this class, you should switch to the non-experimental
`tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several
examples of converting the use of the experimental class to the equivalent
non-experimental class.
>>> # In all of the examples below, `opt1` and `opt2` are identical
>>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
... tf.keras.optimizers.SGD(), loss_scale='dynamic')
>>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
... tf.keras.optimizers.SGD())
>>> assert opt1.get_config() == opt2.get_config()
>>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
... tf.keras.optimizers.SGD(), loss_scale=123)
>>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123
>>> # refers to the initial loss scale, which is the single fixed loss scale
>>> # when dynamic=False.
>>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123)
>>> assert opt1.get_config() == opt2.get_config()
>>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale(
... initial_loss_scale=2048, increment_period=500)
>>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
... tf.keras.optimizers.SGD(), loss_scale=loss_scale)
>>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
... tf.keras.optimizers.SGD(), initial_scale=2048,
... dynamic_growth_steps=500)
>>> assert opt1.get_config() == opt2.get_config()
Args:
optimizer: The Optimizer instance to wrap.
loss_scale: The loss scale to scale the loss and gradients. This can
either be an int/float to use a fixed loss scale, the string "dynamic"
to use dynamic loss scaling, or an instance of a LossScale. The string
"dynamic" equivalent to passing `DynamicLossScale()`, and passing an
int/float is equivalent to passing a FixedLossScale with the given loss
scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must
be 2 (the default).
"""
def __init__(self, optimizer, loss_scale):
warn_msg_prefix = (
'tf.keras.mixed_precision.experimental.LossScaleOptimizer is '
'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer '
'instead. ')
if isinstance(loss_scale, dict):
loss_scale = keras_loss_scale_module.deserialize(loss_scale)
if isinstance(loss_scale, (int, float)):
tf_logging.warn(
warn_msg_prefix + 'For example\n'
' opt = tf.keras.mixed_precision.LossScaleOptimizer('
'opt, dynamic=False, initial_scale={})'.format(loss_scale))
super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
initial_scale=loss_scale)
elif isinstance(loss_scale, loss_scale_module.FixedLossScale):
ls_val = loss_scale._loss_scale_value # pylint: disable=protected-access
tf_logging.warn(
warn_msg_prefix + 'For example\n'
' opt = tf.keras.mixed_precision.LossScaleOptimizer('
'opt, dynamic=False, initial_scale={})'.format(ls_val))
super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False,
initial_scale=ls_val)
elif loss_scale == 'dynamic':
tf_logging.warn(
warn_msg_prefix + 'For example\n'
' opt = tf.keras.mixed_precision.LossScaleOptimizer('
'opt)')
super(LossScaleOptimizerV1, self).__init__(optimizer)
elif isinstance(loss_scale, loss_scale_module.DynamicLossScale):
kwargs = {}
extra_arguments = ''
if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE:
kwargs['initial_scale'] = loss_scale.initial_loss_scale
extra_arguments += (', initial_scale=%s' %
loss_scale.initial_loss_scale)
if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS:
kwargs['dynamic_growth_steps'] = loss_scale.increment_period
extra_arguments += (', dynamic_growth_steps=%s' %
loss_scale.increment_period)
if loss_scale.multiplier != 2:
raise ValueError('When passing a DynamicLossScale to "loss_scale", '
'DynamicLossScale.multiplier must be 2. Got: %s'
% (loss_scale,))
tf_logging.warn(
warn_msg_prefix +
'Note that the non-experimental LossScaleOptimizer does not take a '
'DynamicLossScale but instead takes the dynamic configuration '
'directly in the constructor. For example:\n'
' opt = tf.keras.mixed_precision.LossScaleOptimizer('
'opt{})\n'.format(extra_arguments))
super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs)
elif isinstance(loss_scale, loss_scale_module.LossScale):
raise TypeError('Passing a LossScale that is not a FixedLossScale or a '
'DynamicLossScale is no longer supported. Got: {}'
.format(loss_scale))
else:
raise ValueError('Invalid value passed to loss_scale. loss_scale '
'must be the string "dynamic" (recommended), an int, '
'a float, a FixedLossScale, or a DynamicLossScale. Got '
'value: {}'.format(loss_scale))
@classmethod
def from_config(cls, config, custom_objects=None):
config = config.copy() # Make a copy, since we mutate config
config['optimizer'] = optimizers.deserialize(
config['optimizer'], custom_objects=custom_objects)
# If loss_scale is in config, we assume we are deserializing a
# LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are
# deserializing a LossScaleOptimizer from TF 2.4 or above.
if 'loss_scale' in config:
config['loss_scale'] = keras_loss_scale_module.deserialize(
config['loss_scale'])
if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale)
and config['loss_scale'].multiplier != 2):
raise ValueError('Cannot deserialize LossScaleOptimizer with a '
'DynamicLossScale whose multiplier is not 2. Got '
'DynamicLossScale: %s' % (config['loss_scale'],))
return cls(**config)
# We convert the config, as generated by LossScaleOptimizer.get_config, to a
# version that can be passed to LossScaleOptimizerV1.__init__
if config['dynamic']:
config['loss_scale'] = loss_scale_module.DynamicLossScale(
config['initial_scale'], config['dynamic_growth_steps'], multiplier=2)
else:
config['loss_scale'] = loss_scale_module.FixedLossScale(
config['initial_scale'])
del config['dynamic']
del config['initial_scale']
del config['dynamic_growth_steps']
return cls(**config)
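One consequence worth spelling out: serializing the experimental optimizer and deserializing it produces the new class, because both share the same config format. A sketch using the public helpers (assuming TensorFlow 2.4, where both symbols exist):

import tensorflow as tf

v1_opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(), loss_scale='dynamic')
config = tf.keras.optimizers.serialize(v1_opt)
restored = tf.keras.optimizers.deserialize(config)
# restored is a non-experimental tf.keras.mixed_precision.LossScaleOptimizer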
class FakeOptimizerForRestoration(trackable.Trackable):
@ -654,7 +1126,7 @@ class FakeOptimizerForRestoration(trackable.Trackable):
# pylint: disable=protected-access
mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
LossScaleOptimizer)
LossScaleOptimizerV1)
def _multiply_gradient(gradient, scale):


@ -40,7 +40,7 @@ from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import loss_scale as tf_loss_scale_module
from tensorflow.python.training.tracking import util as trackable_utils from tensorflow.python.training.tracking import util as trackable_utils
# Disable not-callable lint error, as the linter is unable to detect that # Disable not-callable lint error, as the linter is unable to detect that
@ -93,7 +93,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
var = variables.Variable([5.0]) var = variables.Variable([5.0])
opt = gradient_descent.SGD(2.0) opt = gradient_descent.SGD(2.0)
loss_scale = 10. loss_scale = 10.
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=loss_scale)
self.assertEqual(self.evaluate(opt.loss_scale), loss_scale)
self.assertIsInstance(opt.loss_scale, ops.Tensor)
# We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale # We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale
# / strategy.num_replicas_in_sync will not be exact, which could lead to # / strategy.num_replicas_in_sync will not be exact, which could lead to
# assertion failures due to rounding issues. # assertion failures due to rounding issues.
@ -112,7 +115,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
var = variables.Variable([2.0]) var = variables.Variable([2.0])
opt = gradient_descent.SGD(1.0) opt = gradient_descent.SGD(1.0)
loss_scale = 10. loss_scale = 10.
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=loss_scale)
grad_check_fn = mp_test_util.create_identity_with_grad_check_fn( grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
loss_scale) loss_scale)
loss = grad_check_fn(var) loss = grad_check_fn(var)
@ -122,9 +126,18 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# mp_test_util.create_identity_with_grad_check_fn added an assertion op. # mp_test_util.create_identity_with_grad_check_fn added an assertion op.
self.evaluate(run_op) self.evaluate(run_op)
def testDynamicAttrsWithFixedLossScale(self):
opt = gradient_descent.SGD()
opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=2.)
self.assertFalse(opt.dynamic)
self.assertIsNone(opt.dynamic_counter)
self.assertIsNone(opt.dynamic_growth_steps)
def testGetScaledLoss(self): def testGetScaledLoss(self):
opt = gradient_descent.SGD(2.0) opt = gradient_descent.SGD(2.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=2.)
loss = ops.convert_to_tensor_v2_with_dispatch(5.) loss = ops.convert_to_tensor_v2_with_dispatch(5.)
self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss))) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss)))
self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)())) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)()))
@ -134,7 +147,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
def testGetUnscaledGradients(self): def testGetUnscaledGradients(self):
opt = gradient_descent.SGD(2.0) opt = gradient_descent.SGD(2.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=2)
scaled_grads = [ scaled_grads = [
ops.convert_to_tensor_v2_with_dispatch(3.), None, ops.convert_to_tensor_v2_with_dispatch(3.), None,
ops.convert_to_tensor_v2_with_dispatch(-4., dtype='float16') ops.convert_to_tensor_v2_with_dispatch(-4., dtype='float16')
@ -145,7 +159,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
def testGetUnscaledSparseGradients(self): def testGetUnscaledSparseGradients(self):
opt = gradient_descent.SGD(2.0) opt = gradient_descent.SGD(2.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=2)
sparse_scaled_grad = ops.IndexedSlices( sparse_scaled_grad = ops.IndexedSlices(
ops.convert_to_tensor_v2_with_dispatch([[4., 2.], [8., 5.]]), ops.convert_to_tensor_v2_with_dispatch([[4., 2.], [8., 5.]]),
ops.convert_to_tensor_v2_with_dispatch([1, 3], dtype='int32'), ops.convert_to_tensor_v2_with_dispatch([1, 3], dtype='int32'),
@ -165,12 +180,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with strategy.scope(): with strategy.scope():
var = variables.Variable([5.0]) var = variables.Variable([5.0])
opt = gradient_descent.SGD(learning_rate) opt = gradient_descent.SGD(learning_rate)
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2,
initial_loss_scale=2, increment_period=1, multiplier=2) dynamic_growth_steps=1)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) self.assertEqual(opt.initial_scale, 2.)
self.assertEqual( self.assertIsInstance(opt.initial_scale, float)
loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0) self.assertEqual(opt.dynamic_growth_steps, 1)
self.assertIsInstance(opt.dynamic_growth_steps, int)
self.assertEqual(opt.initial_scale % strategy.num_replicas_in_sync, 0)
run_fn = self._run_fn_with_grad_check(strategy, var, opt, run_fn = self._run_fn_with_grad_check(strategy, var, opt,
expected_gradient) expected_gradient)
run_op = strategy.experimental_run(run_fn) run_op = strategy.experimental_run(run_fn)
@ -189,6 +206,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# 1. # 1.
self.assertAllClose([1.], self.evaluate(var)) self.assertAllClose([1.], self.evaluate(var))
def testDynamicLossScaleDefaultValues(self):
opt = gradient_descent.SGD()
opt = loss_scale_optimizer.LossScaleOptimizer(opt)
self.assertEqual(opt.initial_scale, 2 ** 15)
self.assertEqual(opt.dynamic_growth_steps, 2000)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 15)
# pylint: disable=cell-var-from-loop # pylint: disable=cell-var-from-loop
@parameterized.named_parameters(*TESTCASES) @parameterized.named_parameters(*TESTCASES)
def testClipping(self, strategy_fn): def testClipping(self, strategy_fn):
@ -198,12 +223,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with strategy.scope(), self.subTest(clip_type=clip_type): with strategy.scope(), self.subTest(clip_type=clip_type):
var = variables.Variable([5.0]) var = variables.Variable([5.0])
opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0}) opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0})
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2,
initial_loss_scale=2, increment_period=1, multiplier=2) dynamic_growth_steps=1)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
self.assertEqual(getattr(opt, clip_type), 2.0) self.assertEqual(getattr(opt, clip_type), 2.0)
self.assertEqual( self.assertEqual(opt.initial_scale % strategy.num_replicas_in_sync, 0)
loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0)
loss = lambda: var * 4 / strategy.num_replicas_in_sync loss = lambda: var * 4 / strategy.num_replicas_in_sync
run_fn = lambda: opt.minimize(loss, var_list=[var]) run_fn = lambda: opt.minimize(loss, var_list=[var])
@ -215,7 +238,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# The gradient is 4 but is clipped to 2, so the variable will be # The gradient is 4 but is clipped to 2, so the variable will be
# init_val - clipped_grad * lr == 5 - 2 * 2 == 1 # init_val - clipped_grad * lr == 5 - 2 * 2 == 1
self.assertAllClose([1.], self.evaluate(var)) self.assertAllClose([1.], self.evaluate(var))
self.assertEqual(self.evaluate(opt.loss_scale()), 4) self.assertEqual(self.evaluate(opt.loss_scale), 4)
# Test changing the clip amount and running again # Test changing the clip amount and running again
setattr(opt, clip_type, 3.0) setattr(opt, clip_type, 3.0)
@ -224,7 +247,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# The gradient is 4 but is clipped to 3, so the variable will be # The gradient is 4 but is clipped to 3, so the variable will be
# prev_var - clipped_grad * lr == 1 - 3 * 2 == -5 # prev_var - clipped_grad * lr == 1 - 3 * 2 == -5
self.assertAllClose([-5.], self.evaluate(var)) self.assertAllClose([-5.], self.evaluate(var))
self.assertEqual(self.evaluate(opt.loss_scale()), 8) self.assertEqual(self.evaluate(opt.loss_scale), 8)
# Test Inf gradients are still skipped instead of being clipped # Test Inf gradients are still skipped instead of being clipped
loss = lambda: var * float('Inf') loss = lambda: var * float('Inf')
@ -232,7 +255,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
run_op = strategy.experimental_run(run_fn) run_op = strategy.experimental_run(run_fn)
self._run_if_in_graph_mode(run_op) self._run_if_in_graph_mode(run_op)
self.assertAllClose([-5.], self.evaluate(var)) # Var does not change self.assertAllClose([-5.], self.evaluate(var)) # Var does not change
self.assertEqual(self.evaluate(opt.loss_scale()), 4) self.assertEqual(self.evaluate(opt.loss_scale), 4)
# pylint: enable=cell-var-from-loop # pylint: enable=cell-var-from-loop
@parameterized.named_parameters(*TESTCASES) @parameterized.named_parameters(*TESTCASES)
@ -240,9 +263,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with strategy_fn().scope() as strategy: with strategy_fn().scope() as strategy:
var = variables.Variable([1.0, 2.0]) var = variables.Variable([1.0, 2.0])
opt = gradient_descent.SGD(1.0) opt = gradient_descent.SGD(1.0)
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2,
initial_loss_scale=2, increment_period=1, multiplier=2) dynamic_growth_steps=1)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
# Test optimizer with finite gradients # Test optimizer with finite gradients
loss = lambda: var * 2.0 / strategy.num_replicas_in_sync loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
@ -253,7 +275,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# Gradient is 2, so variable will have 2 subtracted from it # Gradient is 2, so variable will have 2 subtracted from it
self.assertAllClose([-1.0, 0.0], self.evaluate(var)) self.assertAllClose([-1.0, 0.0], self.evaluate(var))
# Loss scale has doubled from 2 to 4 # Loss scale has doubled from 2 to 4
self.assertEqual(4., self.evaluate(opt.loss_scale())) self.assertEqual(4., self.evaluate(opt.loss_scale))
# Test optimizer with NaN gradients # Test optimizer with NaN gradients
loss = lambda: var * float('NaN') loss = lambda: var * float('NaN')
@ -263,7 +285,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# Variable should not change from before, due to NaN gradients. # Variable should not change from before, due to NaN gradients.
self.assertAllClose(self.evaluate(var), [-1.0, 0.0]) self.assertAllClose(self.evaluate(var), [-1.0, 0.0])
# Loss scale should half due to NaN gradients. # Loss scale should half due to NaN gradients.
self.assertEqual(2., self.evaluate(opt.loss_scale())) self.assertEqual(2., self.evaluate(opt.loss_scale))
@parameterized.named_parameters(*TESTCASES) @parameterized.named_parameters(*TESTCASES)
def testDynamicLossScaleWithFloat16Loss(self, strategy_fn): def testDynamicLossScaleWithFloat16Loss(self, strategy_fn):
@ -272,9 +294,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with strategy.scope(): with strategy.scope():
var = variables.Variable([5.0]) var = variables.Variable([5.0])
opt = gradient_descent.SGD(learning_rate) opt = gradient_descent.SGD(learning_rate)
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2,
initial_loss_scale=2, increment_period=1, multiplier=2) dynamic_growth_steps=1)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
def loss(): def loss():
return math_ops.cast(var / strategy.num_replicas_in_sync, 'float16') return math_ops.cast(var / strategy.num_replicas_in_sync, 'float16')
@ -297,11 +318,9 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
var = variables.Variable([1.0, 2.0]) var = variables.Variable([1.0, 2.0])
# An SGD optimizer with momentum has slot variables. # An SGD optimizer with momentum has slot variables.
opt = gradient_descent.SGD(1.0, momentum=1.) opt = gradient_descent.SGD(1.0, momentum=1.)
initial_loss_scale = 2. initial_scale = 2.
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(
initial_loss_scale=initial_loss_scale, increment_period=1, opt, initial_scale=initial_scale, dynamic_growth_steps=1)
multiplier=4)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
loss = lambda: var / strategy.num_replicas_in_sync loss = lambda: var / strategy.num_replicas_in_sync
run_fn = lambda: opt.minimize(loss, var_list=[var]) run_fn = lambda: opt.minimize(loss, var_list=[var])
run_op = strategy.experimental_run(run_fn) run_op = strategy.experimental_run(run_fn)
@ -312,7 +331,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# variable is subtracted by the accumulator, so the variable is subtracted # variable is subtracted by the accumulator, so the variable is subtracted
# by 1. # by 1.
self.assertAllClose([0.0, 1.0], self.evaluate(var)) self.assertAllClose([0.0, 1.0], self.evaluate(var))
self.assertEqual(self.evaluate(opt.loss_scale()), initial_loss_scale * 4) self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 2)
run_op = strategy.experimental_run(run_fn) run_op = strategy.experimental_run(run_fn)
self._run_if_in_graph_mode(run_op) self._run_if_in_graph_mode(run_op)
@ -321,14 +340,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# variable is subtracted by the accumulator, so the variable is subtracted # variable is subtracted by the accumulator, so the variable is subtracted
# by 2. # by 2.
self.assertAllClose([-2., -1.], self.evaluate(var)) self.assertAllClose([-2., -1.], self.evaluate(var))
self.assertEqual(self.evaluate(opt.loss_scale()), self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 4)
initial_loss_scale * 16)
self.assertEqual(opt.get_slot_names(), ['momentum']) self.assertEqual(opt.get_slot_names(), ['momentum'])
def testIterations(self): def testIterations(self):
opt = gradient_descent.SGD(2.0) opt = gradient_descent.SGD(2.0)
lso = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=10.) lso = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=10.)
lso.iterations = 7 lso.iterations = 7
self.assertEqual(lso.iterations, 7) self.assertEqual(lso.iterations, 7)
self.assertEqual(opt.iterations, 7) self.assertEqual(opt.iterations, 7)
@ -338,7 +357,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with strategy_fn().scope() as strategy: with strategy_fn().scope() as strategy:
# Test iterations is incremented in opt.minimize. # Test iterations is incremented in opt.minimize.
opt = gradient_descent.SGD(1.0) opt = gradient_descent.SGD(1.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale='dynamic') opt = loss_scale_optimizer.LossScaleOptimizer(opt)
var = variables.Variable([5.0]) var = variables.Variable([5.0])
loss = lambda: var * 2.0 / strategy.num_replicas_in_sync loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
run_fn = lambda: opt.minimize(loss, [var]) run_fn = lambda: opt.minimize(loss, [var])
@ -361,11 +380,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with self.test_session(): with self.test_session():
var = variables.Variable([1.0]) var = variables.Variable([1.0])
opt = gradient_descent.SGD(1.0) opt = gradient_descent.SGD(1.0)
initial_loss_scale = 2. opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2.,
loss_scale = loss_scale_module.DynamicLossScale( dynamic_growth_steps=1)
initial_loss_scale=initial_loss_scale, increment_period=1,
multiplier=4)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
run_op = opt.minimize(lambda: var * 2, [var]) run_op = opt.minimize(lambda: var * 2, [var])
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op) self._run_if_in_graph_mode(run_op)
@ -377,15 +393,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
opt.set_weights([np.array(2.)]) opt.set_weights([np.array(2.)])
self.assertEqual(self.evaluate(opt.variables()[0]), 2) self.assertEqual(self.evaluate(opt.variables()[0]), 2)
def testPassingNoneToLossScale(self):
opt = gradient_descent.SGD()
with self.assertRaisesRegex(ValueError, r'loss_scale cannot be None'):
loss_scale_optimizer.LossScaleOptimizer(opt, None)
def testHyperParametersExposed(self): def testHyperParametersExposed(self):
with self.cached_session(): with self.cached_session():
opt = adam.Adam(learning_rate=1.0, beta_1=0.5, beta_2=0.9) opt = adam.Adam(learning_rate=1.0, beta_1=0.5, beta_2=0.9)
lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') lso = loss_scale_optimizer.LossScaleOptimizer(opt)
# Force hyperparameters to be created # Force hyperparameters to be created
opt.lr # pylint: disable=pointless-statement opt.lr # pylint: disable=pointless-statement
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
@ -420,13 +431,13 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
self._set_hyper('loss_scale', 123.) self._set_hyper('loss_scale', 123.)
opt = MyOpt() opt = MyOpt()
lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') lso = loss_scale_optimizer.LossScaleOptimizer(opt)
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
lso.loss_scale = loss_scale_module.FixedLossScale(2.) lso.loss_scale = 2.
def testArbitraryAttributesNotExposed(self): def testArbitraryAttributesNotExposed(self):
opt = gradient_descent.SGD() opt = gradient_descent.SGD()
lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') lso = loss_scale_optimizer.LossScaleOptimizer(opt)
self.assertFalse(opt.nesterov) self.assertFalse(opt.nesterov)
with self.assertRaisesRegex( with self.assertRaisesRegex(
AttributeError, AttributeError,
@ -438,8 +449,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
self.assertFalse(opt.nesterov) self.assertFalse(opt.nesterov)
def testDir(self): def testDir(self):
lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD())
'dynamic')
dir_result = dir(lso) dir_result = dir(lso)
self.assertIn('learning_rate', dir_result) # Hyperparameter self.assertIn('learning_rate', dir_result) # Hyperparameter
self.assertIn('lr', dir_result) # Hyperparameter self.assertIn('lr', dir_result) # Hyperparameter
@ -471,11 +481,125 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with create_mirrored_strategy().scope() as strategy: with create_mirrored_strategy().scope() as strategy:
var = variables.Variable([5.0]) var = variables.Variable([5.0])
opt = MyOptimizer(learning_rate=1.0) opt = MyOptimizer(learning_rate=1.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=1) opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False,
initial_scale=1)
loss = lambda: var * 2.0 loss = lambda: var * 2.0
run_fn = lambda: opt.minimize(loss, [var]) run_fn = lambda: opt.minimize(loss, [var])
strategy.experimental_run(run_fn) strategy.experimental_run(run_fn)
@parameterized.named_parameters(*TESTCASES)
def testV1Optimizer(self, strategy_fn):
strategy = strategy_fn()
learning_rate = 2.
with strategy.scope():
# Test FixedLossScale
var = variables.Variable([5.0])
opt = gradient_descent.SGD(learning_rate)
opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale=2)
self.assertIsInstance(opt.loss_scale, ops.Tensor)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(opt.loss_scale), 2)
self.assertEqual(opt.initial_scale, 2)
self.assertIsNone(opt.dynamic_growth_steps)
run_fn = self._run_fn_with_grad_check(
strategy, var, opt, 2 / strategy.num_replicas_in_sync)
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
# The loss is the identity of the variable. Therefore the gradient is 1,
# and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
self.assertAllClose([3.], self.evaluate(var))
# Test DynamicLossScale
var = variables.Variable([5.0])
opt = gradient_descent.SGD(learning_rate)
opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 'dynamic')
self.assertEqual(opt.initial_scale, 2 ** 15)
self.assertEqual(opt.dynamic_growth_steps, 2000)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 15)
for s in strategy.experimental_local_results(opt.dynamic_counter):
self.assertEqual(self.evaluate(s), 0)
loss = lambda: var * float('NaN')
run_fn = lambda: opt.minimize(loss, var_list=[var])
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
self.assertAllClose([5.], self.evaluate(var))
self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 14)
for s in strategy.experimental_local_results(opt.dynamic_counter):
self.assertEqual(self.evaluate(s), 0)
@parameterized.named_parameters(*TESTCASES)
def testPassingV1LossScale(self, strategy_fn):
strategy = strategy_fn()
learning_rate = 2.
with strategy.scope():
# Test FixedLossScale
var = variables.Variable([5.0])
opt = gradient_descent.SGD(learning_rate)
loss_scale = tf_loss_scale_module.FixedLossScale(2.)
opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
self.assertIsInstance(opt.loss_scale, ops.Tensor)
self.evaluate(variables.global_variables_initializer())
self.assertEqual(self.evaluate(opt.loss_scale), 2)
run_fn = self._run_fn_with_grad_check(
strategy, var, opt, 2 / strategy.num_replicas_in_sync)
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
# The loss is the identity of the variable. Therefore the gradient is 1,
# and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
self.assertAllClose([3.], self.evaluate(var))
# Test DynamicLossScale
var = variables.Variable([5.0])
opt = gradient_descent.SGD(learning_rate)
loss_scale = tf_loss_scale_module.DynamicLossScale(
initial_loss_scale=4, increment_period=1, multiplier=2)
loss_scale._current_loss_scale.assign(2)
opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
self.assertEqual(opt.initial_scale, 4)
self.assertEqual(opt.dynamic_growth_steps, 1)
self.evaluate(variables.global_variables_initializer())
# Current loss scale is not copied so loss scale is reinitialized to 4
self.assertEqual(self.evaluate(opt.loss_scale), 4)
for s in strategy.experimental_local_results(opt.dynamic_counter):
self.assertEqual(self.evaluate(s), 0)
run_fn = self._run_fn_with_grad_check(
strategy, var, opt, 4 / strategy.num_replicas_in_sync)
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
self.assertAllClose([3.], self.evaluate(var))
def testPassingV1LossScaleErrors(self):
opt = gradient_descent.SGD()
loss_scale = tf_loss_scale_module.DynamicLossScale(multiplier=4)
with self.assertRaisesRegex(
ValueError, 'When passing a DynamicLossScale to "loss_scale", '
'DynamicLossScale.multiplier must be 2. Got: '
'DynamicLossScale'):
loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
class MyLossScale(tf_loss_scale_module.LossScale):
def __call__(self):
return 1.
def update(self, grads):
return None, True
def get_config(self):
return {}
with self.assertRaisesRegex(
TypeError, 'Passing a LossScale that is not a FixedLossScale or a '
'DynamicLossScale is no longer supported. Got:'):
loss_scale_optimizer.LossScaleOptimizerV1(opt, MyLossScale())
@parameterized.named_parameters({ @parameterized.named_parameters({
'testcase_name': 'SaveAndRestoreBase', 'testcase_name': 'SaveAndRestoreBase',
'strategy_fn': default_strategy_fn, 'strategy_fn': default_strategy_fn,
@ -529,10 +653,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
var = variables.Variable([2.0]) var = variables.Variable([2.0])
opt = inner_opt = MySGD(1., momentum=1.) opt = inner_opt = MySGD(1., momentum=1.)
if save_with_ls: if save_with_ls:
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1.,
initial_loss_scale=1., increment_period=2., dynamic_growth_steps=2.)
multiplier=2.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var]) run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var])
opt_op = strategy.experimental_run(run_fn) opt_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
@ -541,8 +663,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# Assert values. # Assert values.
self.assertEqual(self.evaluate(var), 1.) self.assertEqual(self.evaluate(var), 1.)
if save_with_ls: if save_with_ls:
self.assertEqual(self.evaluate(loss_scale()), 1.) self.assertEqual(self.evaluate(opt.loss_scale), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
slot_var = opt.get_slot(var, 'momentum') slot_var = opt.get_slot(var, 'momentum')
self.assertEqual(self.evaluate(slot_var).item(), -1) self.assertEqual(self.evaluate(slot_var).item(), -1)
self.assertEqual(self.evaluate(opt.iterations), 1) self.assertEqual(self.evaluate(opt.iterations), 1)
@ -560,10 +682,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
var = variables.Variable([2.0]) var = variables.Variable([2.0])
opt = inner_opt = MySGD(1., momentum=1.) opt = inner_opt = MySGD(1., momentum=1.)
if restore_with_ls: if restore_with_ls:
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1.,
initial_loss_scale=1., increment_period=2., dynamic_growth_steps=2.)
multiplier=2.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
# Restore new model. # Restore new model.
checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var) checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var)
@ -578,11 +698,11 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly(): if context.executing_eagerly():
self.assertEqual(self.evaluate(var), 1.) self.assertEqual(self.evaluate(var), 1.)
if save_with_ls and restore_with_ls: if save_with_ls and restore_with_ls:
self.assertEqual(self.evaluate(loss_scale()), 1.) self.assertEqual(self.evaluate(opt.loss_scale), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
elif restore_with_ls: elif restore_with_ls:
self.assertEqual(self.evaluate(loss_scale()), 1.) self.assertEqual(self.evaluate(opt.loss_scale), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0) self.assertEqual(self.evaluate(opt.dynamic_counter), 0)
self.assertEqual(self.evaluate(opt.iterations), 1) self.assertEqual(self.evaluate(opt.iterations), 1)
# Run the model again. # Run the model again.
@ -611,30 +731,180 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(var), 1) self.assertEqual(self.evaluate(var), 1)
self.assertEqual(self.evaluate(slot_var).item(), -1) self.assertEqual(self.evaluate(slot_var).item(), -1)
def testGetConfig(self):
@combinations.generate(combinations.combine(
get_config=['v1', 'v2', 'tf2_3'], from_config=['v1', 'v2']))
def testGetConfigFixed(self, get_config, from_config):
# Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the
# LossScaleOptimizer from TF 2.3. Then restore the config into a
# LossScaleOptimizerV1 or LossScaleOptimizer
opt = gradient_descent.SGD(2., momentum=0.5) opt = gradient_descent.SGD(2., momentum=0.5)
loss_scale = loss_scale_module.DynamicLossScale( if get_config == 'v1':
initial_loss_scale=2., increment_period=3., opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 2)
multiplier=4.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
config = opt.get_config() config = opt.get_config()
elif get_config == 'v2':
opt = loss_scale_optimizer.LossScaleOptimizer(
opt, dynamic=False, initial_scale=2)
config = opt.get_config()
else:
self.assertEqual(get_config, 'tf2_3')
config = {
'optimizer': {
'class_name': 'SGD',
'config': {
'learning_rate': 2.0,
'momentum': 0.5,
'decay': 0.0,
'nesterov': False,
'name': 'SGD',
}
},
'loss_scale': {
'class_name': 'FixedLossScale',
'config': {'loss_scale_value': 2.0}
},
}
if from_config == 'v1':
opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config)
else:
self.assertEqual(from_config, 'v2')
opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config) opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config)
# Force hyperparameters to be created # Force hyperparameters to be created
opt.lr # pylint: disable=pointless-statement opt.lr # pylint: disable=pointless-statement
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
# Test attributes on the optimizer
self.assertEqual(self.evaluate(opt.lr), 2.) self.assertEqual(self.evaluate(opt.lr), 2.)
self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) self.assertEqual(self.evaluate(opt._optimizer.lr), 2.)
self.assertEqual(self.evaluate(opt.loss_scale()), 2.) self.assertEqual(self.evaluate(opt.momentum), 0.5)
self.assertEqual(opt.loss_scale.increment_period, 3.) self.assertEqual(self.evaluate(opt.loss_scale), 2.)
self.assertEqual(opt.loss_scale.multiplier, 4.) self.assertEqual(opt.initial_scale, 2.)
self.assertIsNone(opt.dynamic_growth_steps)
self.assertIsNone(opt.dynamic_counter)
self.assertFalse(opt.dynamic)
def testSerializationWithBuiltInOptimizer(self): # Ensure the optimizer can be used
var = variables.Variable([5.0])
run_op = self._run_fn_with_grad_check(
distribution_strategy_context.get_strategy(), var, opt, 2)()
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
self.assertEqual(self.evaluate(var), [3.])
@combinations.generate(combinations.combine(
get_config=['v1', 'v2', 'tf2_3'], from_config=['v1', 'v2']))
def testGetConfigDynamic(self, get_config, from_config):
# Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the
# LossScaleOptimizer from TF 2.3. Then restore the config into a
# LossScaleOptimizerV1 or LossScaleOptimizer
opt = gradient_descent.SGD(2., momentum=0.5) opt = gradient_descent.SGD(2., momentum=0.5)
loss_scale = loss_scale_module.DynamicLossScale( if get_config == 'v1':
initial_loss_scale=2., increment_period=3., loss_scale = tf_loss_scale_module.DynamicLossScale(
multiplier=4.) initial_loss_scale=2, increment_period=3)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
config = opt.get_config()
elif get_config == 'v2':
opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2,
dynamic_growth_steps=3)
config = opt.get_config()
else:
self.assertEqual(get_config, 'tf2_3')
config = {
'optimizer': {
'class_name': 'SGD',
'config': {
'learning_rate': 2.0,
'momentum': 0.5,
'decay': 0.0,
'nesterov': False,
'name': 'SGD',
}
},
'loss_scale': {
'class_name': 'DynamicLossScale',
'config': {
'initial_loss_scale': 2.0,
'increment_period': 3,
'multiplier': 2.0,
}
},
}
if from_config == 'v1':
opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config)
else:
self.assertEqual(from_config, 'v2')
opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config)
# Force hyperparameters to be created
opt.lr # pylint: disable=pointless-statement
self.evaluate(variables.global_variables_initializer())
# Test attributes on the optimizer
self.assertEqual(self.evaluate(opt.lr), 2.)
self.assertEqual(self.evaluate(opt._optimizer.lr), 2.)
self.assertEqual(self.evaluate(opt.momentum), 0.5)
self.assertEqual(self.evaluate(opt.loss_scale), 2.)
self.assertEqual(opt.initial_scale, 2.)
self.assertEqual(opt.dynamic_growth_steps, 3.)
self.assertTrue(opt.dynamic)
# Ensure the optimizer can be used
var = variables.Variable([5.0])
run_op = self._run_fn_with_grad_check(
distribution_strategy_context.get_strategy(), var, opt, 2)()
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
self.assertEqual(self.evaluate(var), [3.])
self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
def test_from_config_with_invalid_multiplier(self):
config = {
'optimizer': {
'class_name': 'SGD',
'config': {
'learning_rate': 2.0,
'momentum': 0.5,
'decay': 0.0,
'nesterov': False,
'name': 'SGD',
}
},
'loss_scale': {
'class_name': 'DynamicLossScale',
'config': {
'initial_loss_scale': 2.0,
'increment_period': 3,
'multiplier': 4.0,
}
},
}
expected_error = ('Cannot deserialize LossScaleOptimizer with a '
'DynamicLossScale whose multiplier is not 2. Got '
'DynamicLossScale: DynamicLossScale\\(')
with self.assertRaisesRegex(ValueError, expected_error):
loss_scale_optimizer.LossScaleOptimizer.from_config(config)
with self.assertRaisesRegex(ValueError, expected_error):
loss_scale_optimizer.LossScaleOptimizerV1.from_config(config)
@parameterized.named_parameters({
'testcase_name': 'V2',
'use_v1': False,
}, {
'testcase_name': 'V1',
'use_v1': True,
},)
def testSerializationWithBuiltInOptimizer(self, use_v1):
opt = gradient_descent.SGD(2., momentum=0.5)
if use_v1:
loss_scale = tf_loss_scale_module.DynamicLossScale(
initial_loss_scale=2., increment_period=3.)
opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale)
else:
opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2.,
dynamic_growth_steps=3.)
config = optimizers.serialize(opt) config = optimizers.serialize(opt)
opt = optimizers.deserialize(config) opt = optimizers.deserialize(config)
# Force hyperparameters to be created # Force hyperparameters to be created
@ -643,9 +913,21 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(opt.lr), 2.)
self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5)
self.assertEqual(self.evaluate(opt.loss_scale()), 2.)
self.assertEqual(self.evaluate(opt.loss_scale), 2.)
self.assertEqual(opt.loss_scale.increment_period, 3.)
self.assertEqual(opt.dynamic_growth_steps, 3.)
self.assertEqual(opt.loss_scale.multiplier, 4.)
self.assertTrue(opt.dynamic)
# Deserializing a LossScaleOptimizer always results in a V2
# LossScaleOptimizer, even if serialized with a LossScaleOptimizerV1.
self.assertAllEqual(type(opt), loss_scale_optimizer.LossScaleOptimizer)
# Ensure the optimizer can be used
var = variables.Variable([5.0])
run_op = self._run_fn_with_grad_check(
distribution_strategy_context.get_strategy(), var, opt, 2)()
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
self.assertEqual(self.evaluate(var), [3.])
self.assertEqual(self.evaluate(opt.dynamic_counter), 1)
def testSerializationWithCustomOptimizer(self): def testSerializationWithCustomOptimizer(self):
class MySGD(gradient_descent.SGD): class MySGD(gradient_descent.SGD):
@ -655,10 +937,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
self.my_attribute = 123 self.my_attribute = 123
opt = MySGD(2., momentum=0.5) opt = MySGD(2., momentum=0.5)
loss_scale = loss_scale_module.DynamicLossScale( opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2.,
initial_loss_scale=2., increment_period=3., dynamic_growth_steps=3.)
multiplier=4.)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
config = optimizers.serialize(opt) config = optimizers.serialize(opt)
custom_objects = {'MySGD': MySGD} custom_objects = {'MySGD': MySGD}
opt = optimizers.deserialize(config, custom_objects=custom_objects) opt = optimizers.deserialize(config, custom_objects=custom_objects)
@ -668,9 +948,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
self.assertEqual(self.evaluate(opt.lr), 2.) self.assertEqual(self.evaluate(opt.lr), 2.)
self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5)
self.assertEqual(self.evaluate(opt.loss_scale()), 2.) self.assertEqual(self.evaluate(opt.loss_scale), 2.)
self.assertEqual(opt.loss_scale.increment_period, 3.) self.assertEqual(opt.dynamic_growth_steps, 3.)
self.assertEqual(opt.loss_scale.multiplier, 4.)
self.assertEqual(opt._optimizer.my_attribute, 123) self.assertEqual(opt._optimizer.my_attribute, 123)
def testUnsupportedStrategy(self): def testUnsupportedStrategy(self):
@ -680,8 +959,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
'CentralStorageStrategy. Try using a different Strategy, e.g. a ' 'CentralStorageStrategy. Try using a different Strategy, e.g. a '
'MirroredStrategy') 'MirroredStrategy')
with strategy.scope(), self.assertRaisesRegex(ValueError, expected_error): with strategy.scope(), self.assertRaisesRegex(ValueError, expected_error):
loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), 1.) loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD())
opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), 1.) opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD())
with strategy.scope(): with strategy.scope():
var = variables.Variable(1.0) var = variables.Variable(1.0)
loss = lambda: var * 2.0 loss = lambda: var * 2.0
@ -689,6 +968,24 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegex(ValueError, expected_error): with self.assertRaisesRegex(ValueError, expected_error):
strategy.experimental_run(run_fn) strategy.experimental_run(run_fn)
def testInvalidArgsWithFixedLossScale(self):
opt = gradient_descent.SGD()
with self.assertRaisesRegex(
ValueError, '"initial_scale" must be specified if "dynamic" is False'):
loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False)
with self.assertRaisesRegex(
ValueError, '"dynamic_growth_steps" must be None if "dynamic" is '
'False, but got: 2'):
loss_scale_optimizer.LossScaleOptimizer(
opt, dynamic=False, initial_scale=1, dynamic_growth_steps=2)
def testDynamicMustBeBool(self):
opt = gradient_descent.SGD()
with self.assertRaisesRegex(
TypeError, '"dynamic" argument to LossScaleOptimizer.__init__ must be '
"a bool, but got: 'dynamic'"):
loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic')
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()


@ -65,13 +65,13 @@ class MixedPrecisionTest(keras_parameterized.TestCase):
opt = gradient_descent_v2.SGD(1.0) opt = gradient_descent_v2.SGD(1.0)
opt = enable_mixed_precision_graph_rewrite(opt, 123.) opt = enable_mixed_precision_graph_rewrite(opt, 123.)
self.assertIsInstance( self.assertIsInstance(
opt, loss_scale_optimizer_v2.LossScaleOptimizer) opt, loss_scale_optimizer_v2.LossScaleOptimizerV1)
self.assertEqual(self.evaluate(opt._loss_scale()), 123.) self.assertEqual(self.evaluate(opt.loss_scale), 123.)
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph', 'eager']))
def test_optimizer_errors(self): def test_optimizer_errors(self):
opt = gradient_descent_v2.SGD(1.0) opt = gradient_descent_v2.SGD(1.0)
opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic') opt = loss_scale_optimizer_v2.LossScaleOptimizerV1(opt, 'dynamic')
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, '"opt" must not already be an instance of a ' ValueError, '"opt" must not already be an instance of a '
'LossScaleOptimizer.'): 'LossScaleOptimizer.'):

View File

@ -32,15 +32,15 @@ from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
-@keras_export('keras.mixed_precision.experimental.Policy', v1=[])
+@keras_export('keras.mixed_precision.Policy', v1=[])
class Policy(object):
"""A dtype policy for a Keras layer.
-A dtype policy determines dtype-related aspects of a layer, such as its
-computation and variable dtypes. Each layer has a policy. Policies can be
-passed to the `dtype` argument of layer constructors, or a global policy can
-be set with `tf.keras.mixed_precision.experimental.set_policy`. A layer will
-default to the global policy if no policy is passed to it's constructor.
+A dtype policy determines a layer's computation and variable dtypes. Each
+layer has a policy. Policies can be passed to the `dtype` argument of layer
+constructors, or a global policy can be set with
+`tf.keras.mixed_precision.experimental.set_policy`. A layer will default to
+the global policy if no policy is passed to its constructor.
For many models, each layer's policy will have the same compute dtype and For many models, each layer's policy will have the same compute dtype and
variable dtype, which will typically be float32. In this case, we refer to the variable dtype, which will typically be float32. In this case, we refer to the
@ -56,24 +56,17 @@ class Policy(object):
https://www.tensorflow.org/guide/keras/mixed_precision) for more https://www.tensorflow.org/guide/keras/mixed_precision) for more
information on how to use mixed precision. information on how to use mixed precision.
Certain policies also have a `tf.mixed_precision.experimental.LossScale`
instance, which is used by `tf.keras.Model`s to performance loss scaling. Loss
scaling is a technique used with mixed precision to avoid numerical underflow
in float16 gradients. Loss scaling is only done by Models in `Model.fit`,
`Model.train_on_batch`, and similar methods. Layers which are not Models
ignore the loss scale.
Policies are constructed by passing a string to the constructor, e.g.
-`tf.keras.mixed_precision.experimental.Policy('float32')`. The string
-determines the compute and variable dtypes. It can be one of the following:
+`tf.keras.mixed_precision.Policy('float32')`. The string determines the
+compute and variable dtypes. It can be one of the following:
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
-compute dtypes will be that dtype. No loss scaling is done by default.
+compute dtypes will be that dtype.
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
-bfloat16, while the variable dtype is float32. These policies are used for
-mixed precision training. With 'mixed_float16', a dynamic loss scale is
-used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
-scaling is unnecessary with bfloat16.
+bfloat16, while the variable dtype is float32. With 'mixed_float16',
+`tf.keras.Model.compile` will wrap the optimizer with a
+`tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used for
+mixed precision training.
### How to use mixed precision in a Keras model ### How to use mixed precision in a Keras model
@ -97,7 +90,7 @@ class Policy(object):
Alternatively, the policy can be passed to individual layers instead of
setting the global policy with `set_policy`:
->>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
+>>> policy = tf.keras.mixed_precision.Policy('mixed_float16')
>>> model = tf.keras.models.Sequential([
... tf.keras.layers.Input((100,)),
... tf.keras.layers.Dense(10, dtype=policy),
@ -110,7 +103,7 @@ class Policy(object):
`Model.fit`, `Model.train_on_batch`, and other training methods. If no such
method is used (e.g., a custom training loop is used) and `'mixed_float16'` is
used, the loss scale must be manually applied. See
-`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For
+`tf.keras.mixed_precision.LossScaleOptimizer` for details. For
`'mixed_bfloat16'`, no loss scaling is done and loss scaling never needs to be
manually applied.
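As a rough illustration of the manual loss scaling mentioned above (a sketch only; it assumes a `model`, a `loss_fn`, and `(x, y)` batches defined elsewhere):

import tensorflow as tf

optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))
    # Scale the loss so float16 gradients do not underflow.
    scaled_loss = optimizer.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # Unscale before applying; with dynamic scaling, apply_gradients also skips
  # steps whose gradients are non-finite and lowers the loss scale instead.
  grads = optimizer.get_unscaled_gradients(scaled_grads)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss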
@ -227,7 +220,7 @@ class Policy(object):
... def build(self, input_shape): ... def build(self, input_shape):
... self.x = self.add_weight('x') ... self.x = self.add_weight('x')
... self.y = self.add_weight('y', experimental_autocast=False) ... self.y = self.add_weight('y', experimental_autocast=False)
>>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') >>> policy = tf.keras.mixed_precision.Policy('mixed_float16')
>>> layer = MyLayer(dtype=policy) >>> layer = MyLayer(dtype=policy)
>>> layer.build((2, 2)) >>> layer.build((2, 2))
>>> layer.x >>> layer.x
@ -258,7 +251,7 @@ class Policy(object):
... def call(self, inputs): ... def call(self, inputs):
... return tf.matmul(inputs, self.kernel) ... return tf.matmul(inputs, self.kernel)
>>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') >>> policy = tf.keras.mixed_precision.Policy('mixed_float16')
>>> layer = MyDense(dtype=policy) >>> layer = MyDense(dtype=policy)
>>> x = np.random.rand(10, 10) >>> x = np.random.rand(10, 10)
>>> y = layer(x) >>> y = layer(x)
@ -289,29 +282,27 @@ class Policy(object):
layer would only work if the inputs were float32. layer would only work if the inputs were float32.
""" """
-def __init__(self, name, loss_scale='auto'):
+def __init__(self, name):
"""Constructs the policy.
-The `name` argument determines the compute and variable dtype, the default
-loss scale, and has no additional effect on the Policy. The compute and
-variable dtypes can only be specified through `name`, and cannot be
-specified directly.
+The `name` argument determines the compute and variable dtype. The compute
+and variable dtypes can only be specified through `name`, and cannot be
+specified directly.
+`name` is also used by `tf.keras.Model.compile`. If `name` is
+`"mixed_float16"`, `tf.keras.Model.compile` will automatically wrap the
+optimizer with a LossScaleOptimizer if it is not already a
+LossScaleOptimizer.
Args:
name: A string. Can be one of the following values:
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
compute dtypes will be that dtype.
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
bfloat16, while the variable dtype is float32. With 'mixed_float16',
-a dynamic loss scale is used. These policies are used for mixed
-precision training.
-loss_scale: A `tf.mixed_precision.experimental.LossScale`, an int (which
-uses a `FixedLossScale`), the string "dynamic" (which uses a
-`DynamicLossScale`), or None (which uses no loss scale). Defaults to
-`"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then
-use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only
-`tf.keras.Model`s, not layers, use the loss scale, and it is only used
-during `Model.fit`, `Model.train_on_batch`, and other similar methods.
+`tf.keras.Model.compile` will wrap the optimizer with a
+`tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used
+for mixed precision training.
"""
if isinstance(name, dtypes.DType): if isinstance(name, dtypes.DType):
raise TypeError("'name' must be a string, not a DType. " raise TypeError("'name' must be a string, not a DType. "
@ -320,19 +311,6 @@ class Policy(object):
raise TypeError("'name' must be a string, but got: %s" % (name,)) raise TypeError("'name' must be a string, but got: %s" % (name,))
self._name = name self._name = name
self._compute_dtype, self._variable_dtype = self._parse_name(name) self._compute_dtype, self._variable_dtype = self._parse_name(name)
if loss_scale == 'auto':
loss_scale = 'dynamic' if name == 'mixed_float16' else None
self._using_default_loss_scale = True
else:
self._using_default_loss_scale = False
if loss_scale and self._compute_dtype not in (None, 'float16'):
tf_logging.warn('Creating a Policy with a loss scale is only useful for '
'float16 policies. You passed loss_scale=%r for policy '
'%s. Consider not passing any loss_scale instead.' %
(loss_scale, name))
self._loss_scale = keras_loss_scale_module.get(loss_scale)
if name in ('mixed_float16', 'mixed_bloat16'): if name in ('mixed_float16', 'mixed_bloat16'):
device_compatibility_check.log_device_compatibility_check(name) device_compatibility_check.log_device_compatibility_check(name)
@ -438,22 +416,99 @@ class Policy(object):
""" """
return self.variable_dtype != self.compute_dtype return self.variable_dtype != self.compute_dtype
@property
def loss_scale(self):
"""Returns the loss scale of this Policy.
Returns:
A `tf.mixed_precision.experimental.LossScale`, or None.
"""
return self._loss_scale
@property
def name(self):
"""Returns the name of this policy."""
return self._name
def __repr__(self):
-return '<Policy "%s", loss_scale=%s>' % (self._name, self.loss_scale)
+return '<Policy "%s">' % self._name
def get_config(self):
return {'name': self.name}
@classmethod
def from_config(cls, config, custom_objects=None):
del custom_objects
if 'loss_scale' in config:
config = config.copy()
# Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We
# silently drop it.
del config['loss_scale']
return cls(**config)
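A quick sketch of what this `from_config` behavior means for configs written by the old experimental Policy (the nested `loss_scale` value shown here is illustrative):

import tensorflow as tf

old_config = {
    'name': 'mixed_float16',
    'loss_scale': {'class_name': 'DynamicLossScale',
                   'config': {'initial_loss_scale': 32768.0,
                              'increment_period': 2000,
                              'multiplier': 2.0}},
}
policy = tf.keras.mixed_precision.Policy.from_config(old_config)
print(policy.get_config())  # {'name': 'mixed_float16'}; the loss scale is dropped.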
@keras_export('keras.mixed_precision.experimental.Policy', v1=[])
class PolicyV1(Policy):
"""A deprecated dtype policy for a Keras layer.
Warning: This class is now deprecated and will be removed soon. Please use the
non-experimental class `tf.keras.mixed_precision.Policy` instead.
The difference between this class and the non-experimental class is that this
class has a `loss_scale` field and the non-experimental class does not. The
loss scale is only used by `tf.keras.Model.compile`, which automatically wraps
the optimizer with a `LossScaleOptimizer` if the optimizer is not already a
`LossScaleOptimizer`. For the non-experimental Policy class, `Model.compile`
instead wraps the optimizer with a `LossScaleOptimizer` if `Policy.name` is
"mixed_float16".
When deserializing objects with an experimental policy using functions like
`tf.keras.utils.deserialize_keras_object`, the policy will be deserialized as
the non-experimental `tf.keras.mixed_precision.Policy`, and the loss scale
will silently be dropped. This is so that SavedModels that are generated
with an experimental policy can be restored after the experimental policy is
removed.
"""
def __init__(self, name, loss_scale='auto'):
"""Constructs the policy.
The `name` argument determines the compute and variable dtype, the default
loss scale, and has no additional effect on the Policy. The compute and
variable dtypes can only be specified through `name`, and cannot be
specified directly.
Args:
name: A string. Can be one of the following values:
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
compute dtypes will be that dtype.
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
bfloat16, while the variable dtype is float32. With 'mixed_float16',
a dynamic loss scale is used. These policies are used for mixed
precision training.
loss_scale: A `tf.compat.v1.mixed_precision.LossScale`, an int (which
uses a `FixedLossScale`), the string "dynamic" (which uses a
`DynamicLossScale`), or None (which uses no loss scale). Defaults to
`"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then
use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only
`tf.keras.Model`s, not layers, use the loss scale, and it is only used
during `Model.fit`, `Model.train_on_batch`, and other similar methods.
"""
super(PolicyV1, self).__init__(name)
if loss_scale == 'auto':
loss_scale = 'dynamic' if name == 'mixed_float16' else None
self._using_default_loss_scale = True
else:
self._using_default_loss_scale = False
if loss_scale and self._compute_dtype not in (None, 'float16'):
tf_logging.warn('Creating a Policy with a loss scale is only useful for '
'float16 policies. You passed loss_scale=%r for policy '
'%s. Consider not passing any loss_scale instead.' %
(loss_scale, name))
self._loss_scale = keras_loss_scale_module.get(loss_scale)
@property
def loss_scale(self):
"""Returns the loss scale of this Policy.
Returns:
A `tf.compat.v1.mixed_precision.experimental.LossScale`, or None.
"""
return self._loss_scale
def __repr__(self):
return '<PolicyV1 "%s", loss_scale=%s>' % (self._name, self.loss_scale)
def get_config(self): def get_config(self):
config = { config = {
@ -497,8 +552,7 @@ def global_policy():
first time the layer is called. This behavior matches the behavior that first time the layer is called. This behavior matches the behavior that
existed in TensorFlow 1. existed in TensorFlow 1.
-See `tf.keras.mixed_precision.experimental.Policy` for more information on
-policies.
+See `tf.keras.mixed_precision.Policy` for more information on policies.
Returns: Returns:
The global Policy. The global Policy.
@ -540,7 +594,7 @@ def set_policy(policy):
`'int32'` and `'complex64'` cannot be set as the global policy because most `'int32'` and `'complex64'` cannot be set as the global policy because most
layers do not support such policies. layers do not support such policies.
-See `tf.keras.mixed_precision.experimental.Policy` for more information.
+See `tf.keras.mixed_precision.Policy` for more information.
Args: Args:
policy: A Policy, or a string that will be converted to a Policy.
@ -597,8 +651,8 @@ def _policy_equivalent_to_dtype(policy):
"""Returns True if the Policy is equivalent to a single dtype. """Returns True if the Policy is equivalent to a single dtype.
A policy is equivalent to a single dtype if the policy's compute and variable A policy is equivalent to a single dtype if the policy's compute and variable
dtypes are the same and the policy does not cause the layer/model to have dtypes are the same and the policy's type is Policy and not a subclass of
additional behavior, such as loss scaling. Policy (such as PolicyV1).
The "_infer" policy is considered equivalent to a single dtype. The "_infer" policy is considered equivalent to a single dtype.
@ -629,7 +683,7 @@ def deserialize(config, custom_objects=None):
return Policy(config)
if config is None:
return Policy('_infer')
-module_objects = {'Policy': Policy}
+module_objects = {'Policy': Policy, 'PolicyV1': Policy}
return generic_utils.deserialize_keras_object(
config,
module_objects=module_objects,

View File

@ -60,14 +60,21 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def test_repr(self): def test_repr(self):
for policy in ('float32', 'int8', 'mixed_bfloat16', '_infer'): # Test Policy repr
for policy in ('float32', 'int8', 'mixed_float16', 'mixed_bfloat16',
'_infer'):
self.assertEqual(repr(mp_policy.Policy(policy)), self.assertEqual(repr(mp_policy.Policy(policy)),
'<Policy "%s", loss_scale=None>' % policy) '<Policy "%s">' % policy)
self.assertEqual(repr(mp_policy.Policy('float16', loss_scale=2)),
'<Policy "float16", loss_scale=FixedLossScale(2.0)>') # Test PolicyV1 repr
for policy in ('float32', 'int8', 'mixed_bfloat16', '_infer'):
self.assertEqual(repr(mp_policy.PolicyV1(policy)),
'<PolicyV1 "%s", loss_scale=None>' % policy)
self.assertEqual(repr(mp_policy.PolicyV1('float16', loss_scale=2)),
'<PolicyV1 "float16", loss_scale=FixedLossScale(2.0)>')
self.assertStartsWith( self.assertStartsWith(
repr(mp_policy.Policy('mixed_float16')), repr(mp_policy.PolicyV1('mixed_float16')),
'<Policy "mixed_float16", loss_scale=DynamicLossScale(') '<PolicyV1 "mixed_float16", loss_scale=DynamicLossScale(')
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def test_policy_errors(self): def test_policy_errors(self):
@ -113,26 +120,26 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def test_loss_scale(self): def test_loss_scale(self):
policy = mp_policy.Policy('float32') policy = mp_policy.PolicyV1('float32')
self.assertEqual(policy.loss_scale, None) self.assertEqual(policy.loss_scale, None)
policy = mp_policy.Policy('float32', loss_scale=None) policy = mp_policy.PolicyV1('float32', loss_scale=None)
self.assertEqual(policy.loss_scale, None) self.assertEqual(policy.loss_scale, None)
ls = loss_scale_module.DynamicLossScale() ls = loss_scale_module.DynamicLossScale()
policy = mp_policy.Policy('float32', loss_scale=ls) policy = mp_policy.PolicyV1('float32', loss_scale=ls)
self.assertIs(policy.loss_scale, ls) self.assertIs(policy.loss_scale, ls)
policy = mp_policy.Policy('float32', loss_scale='dynamic') policy = mp_policy.PolicyV1('float32', loss_scale='dynamic')
self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale) self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
policy = mp_policy.Policy('mixed_float16') policy = mp_policy.PolicyV1('mixed_float16')
self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale) self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
policy = mp_policy.Policy('mixed_float16', loss_scale=None) policy = mp_policy.PolicyV1('mixed_float16', loss_scale=None)
self.assertEqual(policy.loss_scale, None) self.assertEqual(policy.loss_scale, None)
policy = mp_policy.Policy('mixed_bfloat16') policy = mp_policy.PolicyV1('mixed_bfloat16')
self.assertEqual(policy.loss_scale, None) self.assertEqual(policy.loss_scale, None)
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
@ -173,7 +180,7 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def test_loss_scale_warning(self): def test_loss_scale_warning(self):
with test.mock.patch.object(tf_logging, 'warn') as mock_warn: with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('float32', loss_scale=2.) mp_policy.PolicyV1('float32', loss_scale=2.)
self.assertEqual( self.assertEqual(
mock_warn.call_args[0][0], mock_warn.call_args[0][0],
'Creating a Policy with a loss scale is only useful for float16 ' 'Creating a Policy with a loss scale is only useful for float16 '
@ -182,9 +189,9 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
for policy_name in 'float16', 'mixed_float16': for policy_name in 'float16', 'mixed_float16':
# Trigger any other warnings that occur only once # Trigger any other warnings that occur only once
mp_policy.Policy(policy_name, loss_scale=2.) mp_policy.PolicyV1(policy_name, loss_scale=2.)
with test.mock.patch.object(tf_logging, 'warn') as mock_warn: with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy(policy_name, loss_scale=2.) mp_policy.PolicyV1(policy_name, loss_scale=2.)
mock_warn.assert_not_called() mock_warn.assert_not_called()
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
@ -230,12 +237,6 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
mp_policy.Policy('mixed_float16'), mp_policy.Policy('mixed_float16'),
mp_policy.Policy('mixed_bfloat16'), mp_policy.Policy('mixed_bfloat16'),
mp_policy.Policy('_infer'), mp_policy.Policy('_infer'),
mp_policy.Policy('float32', loss_scale=2.),
mp_policy.Policy('float32', loss_scale=None),
mp_policy.Policy('mixed_float16', loss_scale=2.),
mp_policy.Policy('mixed_float16', loss_scale=None),
mp_policy.Policy('mixed_bfloat16', loss_scale=2.),
mp_policy.Policy('mixed_bfloat16', loss_scale=None),
): ):
config = policy.get_config() config = policy.get_config()
new_policy = mp_policy.Policy.from_config(config) new_policy = mp_policy.Policy.from_config(config)
@ -263,7 +264,7 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
class MyPolicy(mp_policy.Policy): class MyPolicy(mp_policy.Policy):
pass pass
# Test policies that do not override the loss scale # Test policies that are not equivalent to a single dtype
for policy in ( for policy in (
mp_policy.Policy('mixed_float16'), mp_policy.Policy('mixed_float16'),
mp_policy.Policy('mixed_bfloat16'), mp_policy.Policy('mixed_bfloat16'),
@ -276,14 +277,14 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
custom_objects={'MyPolicy': MyPolicy}) custom_objects={'MyPolicy': MyPolicy})
self.assertEqual(str(policy), str(new_policy)) self.assertEqual(str(policy), str(new_policy))
# Test policies that override the loss scale # Test V1 policies that override the loss scale
for policy in ( for policy in (
mp_policy.Policy('float32', loss_scale=2.), mp_policy.PolicyV1('float32', loss_scale=2.),
mp_policy.Policy('float32', loss_scale=None), mp_policy.PolicyV1('float32', loss_scale=None),
mp_policy.Policy('mixed_float16', loss_scale=2.), mp_policy.PolicyV1('mixed_float16', loss_scale=2.),
mp_policy.Policy('mixed_float16', loss_scale=None), mp_policy.PolicyV1('mixed_float16', loss_scale=None),
mp_policy.Policy('mixed_bfloat16', loss_scale=2.), mp_policy.PolicyV1('mixed_bfloat16', loss_scale=2.),
mp_policy.Policy('mixed_bfloat16', loss_scale=None), mp_policy.PolicyV1('mixed_bfloat16', loss_scale=None),
): ):
config = mp_policy.serialize(policy) config = mp_policy.serialize(policy)
expected_loss_scale_config = None expected_loss_scale_config = None
@ -300,9 +301,6 @@ class PolicyTest(test.TestCase, parameterized.TestCase):
'loss_scale': expected_loss_scale_config 'loss_scale': expected_loss_scale_config
} }
}) })
new_policy = mp_policy.deserialize(
config, custom_objects={'MyPolicy': MyPolicy})
self.assertEqual(str(policy), str(new_policy))
@testing_utils.enable_v2_dtype_behavior @testing_utils.enable_v2_dtype_behavior
def test_error_if_graph_rewrite_enabled(self): def test_error_if_graph_rewrite_enabled(self):

View File

@ -71,6 +71,10 @@ def deserialize(config, custom_objects=None):
'sgd': gradient_descent_v2.SGD, 'sgd': gradient_descent_v2.SGD,
'ftrl': ftrl.Ftrl, 'ftrl': ftrl.Ftrl,
'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer, 'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer,
# LossScaleOptimizerV1 deserializes into LossScaleOptimizer, as
# LossScaleOptimizerV1 will be removed soon but deserializing it will
# still be supported.
'lossscaleoptimizerv1': loss_scale_optimizer.LossScaleOptimizer,
} }
# Make deserialization case-insensitive for built-in optimizers.
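For context, a small sketch of the round trip this mapping supports; configs whose class name is the soon-to-be-removed 'LossScaleOptimizerV1' resolve to the same class (a sketch, assuming the public Keras serialization helpers):

import tensorflow as tf

opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD())
config = tf.keras.optimizers.serialize(opt)
restored = tf.keras.optimizers.deserialize(config)
print(type(restored).__name__)  # LossScaleOptimizer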

View File

@ -176,9 +176,9 @@ class Layer(object):
Attributes:
name: The name of the layer (string).
dtype: The dtype of the layer's computations and weights. If mixed
-precision is used with a `tf.keras.mixed_precision.experimental.Policy`,
-this is instead just the dtype of the layer's weights, as the computations
-are done in a different dtype.
+precision is used with a `tf.keras.mixed_precision.Policy`, this is
+instead just the dtype of the layer's weights, as the computations are
+done in a different dtype.
updates: List of update ops of this layer.
losses: List of losses added by this layer.
trainable_weights: List of variables to be included in backprop.
@ -197,7 +197,6 @@ class Layer(object):
if no dtype is passed. `floatx()` itself defaults to "float32". Additionally,
layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed
precision is used, layers may have different computation and variable dtypes.
-See `tf.keras.mixed_precision.experimental.Policy` for details on layer
-dtypes.
+See `tf.keras.mixed_precision.Policy` for details on layer dtypes.
"""
pass

View File

@ -37,20 +37,31 @@ from tensorflow.python.util.tf_export import tf_export
@six.add_metaclass(abc.ABCMeta)
-@deprecation.deprecated_endpoints('train.experimental.LossScale')
-@tf_export('mixed_precision.experimental.LossScale',
-'train.experimental.LossScale')
+@deprecation.deprecated_endpoints('mixed_precision.experimental.LossScale',
+'train.experimental.LossScale')
+@tf_export(
+'mixed_precision.experimental.LossScale',
+'train.experimental.LossScale',
+v1=[
+'mixed_precision.LossScale',
+'mixed_precision.experimental.LossScale',
+'train.experimental.LossScale'
+])
class LossScale(trackable.Trackable):
-"""Base class for all loss scales.
+"""Base class for all TF1 loss scales.
+WARNING: This class is deprecated and will be unexposed from the TF 2
+namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only
+be accessible as `tf.compat.v1.mixed_precision.LossScale`. Additionally in
+2.5, you will no longer be able to pass a `LossScale` to a
+`tf.keras.mixed_precision.Policy`. All the functionality in this class has
+been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class
+is no longer needed.
This is an abstract base class, so you cannot instantiate it directly.
Instead, use one of its concrete subclasses:
-* `tf.mixed_precision.experimental.DynamicLossScale` (recommended)
-* `tf.mixed_precision.experimental.FixedLossScale`
-It's recommended to use a loss scale with a
-`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, as its easier than
-using a loss scale directly.
+* `tf.compat.v1.mixed_precision.DynamicLossScale`
+* `tf.compat.v1.mixed_precision.FixedLossScale`
Loss scaling is a process that multiplies the loss by a multiplier called the
loss scale, and divides each gradient by the same multiplier. The pseudocode
@ -198,16 +209,35 @@ class LossScale(trackable.Trackable):
return cls(**config) return cls(**config)
-@deprecation.deprecated_endpoints('train.experimental.FixedLossScale')
-@tf_export('mixed_precision.experimental.FixedLossScale',
-'train.experimental.FixedLossScale')
+@deprecation.deprecated_endpoints('mixed_precision.experimental.FixedLossScale',
+'train.experimental.FixedLossScale')
@tf_export(
'mixed_precision.experimental.FixedLossScale',
'train.experimental.FixedLossScale',
v1=[
'mixed_precision.FixedLossScale',
'mixed_precision.experimental.FixedLossScale',
'train.experimental.FixedLossScale'
])
class FixedLossScale(LossScale): class FixedLossScale(LossScale):
"""Loss scale with a fixed value. """Loss scale with a fixed value.
WARNING: This class is deprecated and will be unexposed from the TF 2
namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only
be accessible as `tf.compat.v1.mixed_precision.FixedLossScale`. Additionally
in 2.5, you will no longer be able to pass a `FixedLossScale` to a
`tf.keras.mixed_precision.Policy`. All the functionality in this class has
been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class
is no longer needed.
The loss scale is not updated for the lifetime of instances of this class. The loss scale is not updated for the lifetime of instances of this class.
A given instance of this class always returns the same number when called. A given instance of this class always returns the same number when called.
""" """
@deprecation.deprecated(
None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
'LossScaleOptimizer now has all the functionality of '
'FixedLossScale')
def __init__(self, loss_scale_value): def __init__(self, loss_scale_value):
"""Creates the fixed loss scale. """Creates the fixed loss scale.
@ -280,12 +310,28 @@ def _assign_if_finite(var, value):
control_flow_ops.no_op) control_flow_ops.no_op)
-@deprecation.deprecated_endpoints('train.experimental.DynamicLossScale')
-@tf_export('mixed_precision.experimental.DynamicLossScale',
-'train.experimental.DynamicLossScale')
+@deprecation.deprecated_endpoints(
+'mixed_precision.experimental.DynamicLossScale',
+'train.experimental.DynamicLossScale')
@tf_export(
'mixed_precision.experimental.DynamicLossScale',
'train.experimental.DynamicLossScale',
v1=[
'mixed_precision.DynamicLossScale',
'mixed_precision.experimental.DynamicLossScale',
'train.experimental.DynamicLossScale'
])
class DynamicLossScale(LossScale): class DynamicLossScale(LossScale):
"""Loss scale that dynamically adjusts itself. """Loss scale that dynamically adjusts itself.
WARNING: This class is deprecated and will be unexposed from the TF 2
namespace starting in TensorFlow 2.5. In TensorFlow 2.5, this class will only
be accessible as `tf.compat.v1.mixed_precision.DynamicLossScale`. Additionally
in 2.5, you will no longer be able to pass a `DynamicLossScale` to a
`tf.keras.mixed_precision.Policy`. All the functionality in this class has
been merged into `tf.keras.mixed_precision.LossScaleOptimizer`, so this class
is no longer needed.
Dynamic loss scaling works by adjusting the loss scale as training progresses. Dynamic loss scaling works by adjusting the loss scale as training progresses.
The goal is to keep the loss scale as high as possible without overflowing the The goal is to keep the loss scale as high as possible without overflowing the
gradients. As long as the gradients do not overflow, raising the loss scale gradients. As long as the gradients do not overflow, raising the loss scale
@ -299,6 +345,10 @@ class DynamicLossScale(LossScale):
overflowing. overflowing.
""" """
@deprecation.deprecated(
None, 'Use tf.keras.mixed_precision.LossScaleOptimizer instead. '
'LossScaleOptimizer now has all the functionality of '
'DynamicLossScale')
def __init__(self, def __init__(self,
initial_loss_scale=2 ** 15, # See docstring for why this is big. initial_loss_scale=2 ** 15, # See docstring for why this is big.
increment_period=2000, increment_period=2000,

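To make the migration concrete, a sketch of roughly equivalent code before and after this deprecation (the numeric values mirror the defaults shown above):

import tensorflow as tf

# Before (deprecated): an explicit loss scale object passed to the
# experimental optimizer.
loss_scale = tf.mixed_precision.experimental.DynamicLossScale(
    initial_loss_scale=2 ** 15, increment_period=2000, multiplier=2.)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(), loss_scale)

# After: the same behavior expressed directly as constructor arguments.
opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(), initial_scale=2 ** 15, dynamic_growth_steps=2000)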
View File

@ -33,7 +33,7 @@ mixed_precision_graph_rewrite_is_enabled = False
# Session has already been created.
non_mixed_precision_session_created = False
-# Whether the global tf.keras.mixed_precision.experimental.Policy uses mixed
-# precision. Used to raise an error message if both a mixed Policy and the graph
-# rewrite are used at the same time.
+# Whether the global tf.keras.mixed_precision.Policy uses mixed precision. Used
+# to raise an error message if both a mixed Policy and the graph rewrite are
+# used at the same time.
using_mixed_precision_policy = False

View File

@ -0,0 +1,112 @@
path: "tensorflow.keras.mixed_precision.LossScaleOptimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "clipnorm"
mtype: "<type \'property\'>"
}
member {
name: "clipvalue"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_counter"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_growth_steps"
mtype: "<type \'property\'>"
}
member {
name: "global_clipnorm"
mtype: "<type \'property\'>"
}
member {
name: "initial_scale"
mtype: "<type \'property\'>"
}
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "loss_scale"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'optimizer\', \'dynamic\', \'initial_scale\', \'dynamic_growth_steps\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\', \'experimental_aggregate_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_scaled_loss"
argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_unscaled_gradients"
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\', \'tape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,6 @@
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer" path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizerV1\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>" is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>" is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>" is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
@ -13,10 +14,26 @@ tf_class {
name: "clipvalue" name: "clipvalue"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_counter"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_growth_steps"
mtype: "<type \'property\'>"
}
member { member {
name: "global_clipnorm" name: "global_clipnorm"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "initial_scale"
mtype: "<type \'property\'>"
}
member { member {
name: "iterations" name: "iterations"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -1,5 +1,9 @@
path: "tensorflow.keras.mixed_precision" path: "tensorflow.keras.mixed_precision"
tf_module { tf_module {
member {
name: "LossScaleOptimizer"
mtype: "<type \'type\'>"
}
member { member {
name: "experimental" name: "experimental"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"

View File

@ -0,0 +1,35 @@
path: "tensorflow.mixed_precision.DynamicLossScale"
tf_class {
is_instance: "<class \'tensorflow.python.training.experimental.loss_scale.DynamicLossScale\'>"
is_instance: "<class \'tensorflow.python.training.experimental.loss_scale.LossScale\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "increment_period"
mtype: "<type \'property\'>"
}
member {
name: "initial_loss_scale"
mtype: "<type \'property\'>"
}
member {
name: "multiplier"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initial_loss_scale\', \'increment_period\', \'multiplier\'], varargs=None, keywords=None, defaults=[\'32768\', \'2000\', \'2.0\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,23 @@
path: "tensorflow.mixed_precision.FixedLossScale"
tf_class {
is_instance: "<class \'tensorflow.python.training.experimental.loss_scale.FixedLossScale\'>"
is_instance: "<class \'tensorflow.python.training.experimental.loss_scale.LossScale\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'loss_scale_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,22 @@
path: "tensorflow.mixed_precision.LossScale"
tf_class {
is_instance: "<class \'tensorflow.python.training.experimental.loss_scale.LossScale\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,17 @@
path: "tensorflow.mixed_precision" path: "tensorflow.mixed_precision"
tf_module { tf_module {
member {
name: "DynamicLossScale"
mtype: "<type \'type\'>"
}
member {
name: "FixedLossScale"
mtype: "<type \'type\'>"
}
member {
name: "LossScale"
mtype: "<type \'type\'>"
}
member { member {
name: "experimental" name: "experimental"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"

View File

@ -0,0 +1,112 @@
path: "tensorflow.keras.mixed_precision.LossScaleOptimizer"
tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "clipnorm"
mtype: "<type \'property\'>"
}
member {
name: "clipvalue"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_counter"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_growth_steps"
mtype: "<type \'property\'>"
}
member {
name: "global_clipnorm"
mtype: "<type \'property\'>"
}
member {
name: "initial_scale"
mtype: "<type \'property\'>"
}
member {
name: "iterations"
mtype: "<type \'property\'>"
}
member {
name: "loss_scale"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'optimizer\', \'dynamic\', \'initial_scale\', \'dynamic_growth_steps\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "add_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], "
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'name\', \'experimental_aggregate_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_scaled_loss"
argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_unscaled_gradients"
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\', \'tape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,33 @@
path: "tensorflow.keras.mixed_precision.Policy"
tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>"
is_instance: "<type \'object\'>"
member {
name: "compute_dtype"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "should_cast_variables"
mtype: "<type \'property\'>"
}
member {
name: "variable_dtype"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,6 @@
path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer" path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizerV1\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>" is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer.LossScaleOptimizer\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>" is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.loss_scale_optimizer._DelegatingTrackableMixin\'>"
is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>" is_instance: "<class \'tensorflow.python.keras.optimizer_v2.optimizer_v2.OptimizerV2\'>"
@ -13,10 +14,26 @@ tf_class {
name: "clipvalue" name: "clipvalue"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_counter"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_growth_steps"
mtype: "<type \'property\'>"
}
member { member {
name: "global_clipnorm" name: "global_clipnorm"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"
} }
member {
name: "initial_scale"
mtype: "<type \'property\'>"
}
member { member {
name: "iterations" name: "iterations"
mtype: "<type \'property\'>" mtype: "<type \'property\'>"

View File

@ -1,5 +1,6 @@
path: "tensorflow.keras.mixed_precision.experimental.Policy" path: "tensorflow.keras.mixed_precision.experimental.Policy"
tf_class { tf_class {
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.PolicyV1\'>"
is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>" is_instance: "<class \'tensorflow.python.keras.mixed_precision.experimental.policy.Policy\'>"
is_instance: "<type \'object\'>" is_instance: "<type \'object\'>"
member { member {

View File

@ -1,5 +1,13 @@
path: "tensorflow.keras.mixed_precision" path: "tensorflow.keras.mixed_precision"
tf_module { tf_module {
member {
name: "LossScaleOptimizer"
mtype: "<type \'type\'>"
}
member {
name: "Policy"
mtype: "<type \'type\'>"
}
member { member {
name: "experimental" name: "experimental"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"

View File

@ -770,6 +770,12 @@ renames = {
'tf.linalg.matrix_transpose', 'tf.linalg.matrix_transpose',
'tf.matrix_triangular_solve': 'tf.matrix_triangular_solve':
'tf.linalg.triangular_solve', 'tf.linalg.triangular_solve',
'tf.mixed_precision.DynamicLossScale':
'tf.compat.v1.mixed_precision.DynamicLossScale',
'tf.mixed_precision.FixedLossScale':
'tf.compat.v1.mixed_precision.FixedLossScale',
'tf.mixed_precision.LossScale':
'tf.compat.v1.mixed_precision.LossScale',
'tf.metrics.accuracy': 'tf.metrics.accuracy':
'tf.compat.v1.metrics.accuracy', 'tf.compat.v1.metrics.accuracy',
'tf.metrics.auc': 'tf.metrics.auc':