Merge pull request #32269 from reedwm/mp_cherrypicks
[r2.0-rc1 CherryPick]: Several tf.keras mixed precision API changes
commit 8ea0a418a6
@@ -2592,6 +2592,9 @@ class AddLoss(Layer):
   """
 
   def __init__(self, unconditional, **kwargs):
+    # Pass autocast=False, as there is no reason to cast loss to a different
+    # dtype.
+    kwargs['autocast'] = False
     super(AddLoss, self).__init__(**kwargs)
     self.unconditional = unconditional
 
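For orientation (not part of the diff): with `autocast=False`, the tensor wrapped by an `AddLoss` layer keeps its own dtype instead of being cast to the active policy's compute dtype. A hedged sketch of the user-visible effect, mirroring the `test_functional_model_loss_dtype` test added later in this commit:

```python
# Illustrative sketch only; assumes TF 2.0 with this change applied.
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('float16')
x = tf.keras.layers.Input(shape=(1,))
y = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(x, y)
# add_loss routes the tensor through an AddLoss layer built with
# autocast=False, so the float32 loss is not cast to the float16 policy dtype.
model.add_loss(tf.cast(tf.reduce_sum(y), 'float32'))
print(model.losses[0].dtype)  # expected: float32
```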
@@ -230,12 +230,10 @@ class Network(base_layer.Layer):
     else:
       self._graph = ops.get_default_graph()  # Used in symbolic mode only.
 
-    # Both graph and subclassed networks have a dtype policy. The policy is
-    # currently ignored for a graph network, as graph networks disable
-    # autocasting (making the policy's compute dtype meaningless) and graph
-    # networks have no variables (making the policy's variable_dtype
-    # meaningless). For subclassed networks, the dtype policy acts as it does
-    # for any ordinary layer.
+    # Both graph and subclassed networks have a dtype policy. For graph
+    # networks, the policy's compute and variable dtypes are ignored, but other
+    # fields, like the loss scale, are used by Models. For subclassed networks,
+    # the compute and variable dtypes are used as like any ordinary layer.
     self._set_dtype_policy(kwargs.get('dtype', None))
 
     # All layers in order of horizontal graph traversal.
@@ -52,6 +52,8 @@ from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.engine import training_v2
 from tensorflow.python.keras.engine import training_v2_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import losses_utils
@@ -248,13 +250,9 @@ class Model(network.Network):
     self._experimental_run_tf_function = kwargs.pop(
         'experimental_run_tf_function', True)
 
-    if isinstance(optimizer, (list, tuple)):
-      self.optimizer = [optimizers.get(opt) for opt in optimizer]
-      is_any_optimizer_v1 = any(
-          isinstance(opt, optimizers.Optimizer) for opt in self.optimizer)
-    else:
-      self.optimizer = optimizers.get(optimizer)
-      is_any_optimizer_v1 = isinstance(self.optimizer, optimizers.Optimizer)
+    self._set_optimizer(optimizer)
+    is_any_optimizer_v1 = any(isinstance(opt, optimizers.Optimizer)
+                              for opt in nest.flatten(self.optimizer))
 
     if ((sample_weight_mode is not None)
         or (target_tensors is not None)
@@ -1443,6 +1441,47 @@ class Model(network.Network):
         'and the first argument in `call` as positional arguments, '
         'found: ' + str(extra_args) + '.')
 
+  def _set_optimizer(self, optimizer):
+    """Sets self.optimizer.
+
+    Sets self.optimizer to `optimizer`, potentially wrapping it with a
+    LossScaleOptimizer.
+
+    Args:
+      optimizer: The optimizer(s) to assign to self.optimizer.
+    """
+    if isinstance(optimizer, (list, tuple)):
+      self.optimizer = [optimizers.get(opt) for opt in optimizer]
+    else:
+      self.optimizer = optimizers.get(optimizer)
+
+    if (self._dtype_policy.loss_scale is not None and
+        not isinstance(self.optimizer,
+                       loss_scale_optimizer.LossScaleOptimizer)):
+      if isinstance(self.optimizer, list):
+        raise ValueError('When a dtype policy with a loss scale is used, you '
+                         'can only pass a single optimizer. Using policy %s '
+                         'and got optimizers: %s' %
+                         self._dtype_policy, self.optimizer)
+      if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
+        raise ValueError('"optimizer" must be an instance of '
+                         'tf.keras.optimizers.Optimizer when a dype policy '
+                         'with a loss scale used, but got: %s. Using policy: '
+                         '%s' %
+                         (self.optimizer, self._dtype_policy))
+      self.optimizer = loss_scale_optimizer.LossScaleOptimizer(
+          self.optimizer, self._dtype_policy.loss_scale)
+    if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and
+        self._dtype_policy.loss_scale and
+        self.optimizer.loss_scale != self._dtype_policy.loss_scale):
+      logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) '
+                      'is not the same as the dtype policy\'s loss scale (%s). '
+                      'Because the dtype policy has a loss scale, you should '
+                      'pass an optimizer that is not wrapped with a '
+                      'LossScaleOptimizer,'
+                      % (self.optimizer.loss_scale,
+                         self._dtype_policy.loss_scale))
+
   def _prepare_validation_data(self, validation_data, batch_size,
                                validation_steps):
     """Unpack and check the validation data."""
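For orientation (not part of the diff), a sketch of the compile-time behavior this new `_set_optimizer` helper enables: when the active dtype policy carries a loss scale, the optimizer passed to `compile` is wrapped in a `LossScaleOptimizer` automatically, so users of `Model.fit` no longer wrap it themselves.

```python
# Hedged sketch, assuming TF 2.0 with this change applied.
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
# 'mixed_float16' defaults to a dynamic loss scale, so compile wraps SGD.
model.compile(tf.keras.optimizers.SGD(0.1), loss='mse')
print(type(model.optimizer).__name__)  # expected: LossScaleOptimizer
```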
@@ -35,6 +35,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import layers
 from tensorflow.python.keras import models
+from tensorflow.python.keras import optimizers
 from tensorflow.python.keras import regularizers
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
@@ -434,7 +435,7 @@ class KerasModelTest(keras_parameterized.TestCase):
       }, {
           'testcase_name': 'infer',
           'strategy_fn': create_mirrored_strategy,
-          'policy_name': 'infer_with_float32_vars'
+          'policy_name': 'mixed_float16'
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
@@ -445,13 +446,15 @@ class KerasModelTest(keras_parameterized.TestCase):
                  strategy_fn,
                  use_operator=False,
                  use_regularizer=False,
-                 policy_name='float16_with_float32_vars',
+                 policy_name='mixed_float16',
                  experimental_run_tf_function=True):
     if not self._is_strategy_supported(strategy_fn, check_model_type=True):
       return
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
-      with policy.policy_scope(policy_name):
+      # Pass loss_scale=None, as this test will fail if the DynamicLossScale
+      # skips applying gradients for a step
+      with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)):
         layer_list = []
         if testing_utils.get_model_type() == 'subclass':
           # Subclassed models do not have an Input layer, so the model does not
@@ -579,10 +582,13 @@ class KerasModelTest(keras_parameterized.TestCase):
     strategy = strategy_fn()
     if use_loss_scaling:
       loss_scale = 8.
+    else:
+      loss_scale = None
     learning_rate = 2**-14
 
     with strategy.scope():
-      with policy.policy_scope(policy.Policy('float16_with_float32_vars')):
+      with policy.policy_scope(policy.Policy('mixed_float16',
+                                             loss_scale=loss_scale)):
         x = layers.Input(shape=(1,), batch_size=2)
         layer1 = AddLayer(
             assert_type=dtypes.float16,
@@ -618,8 +624,6 @@ class KerasModelTest(keras_parameterized.TestCase):
         return math_ops.reduce_mean(y_pred)
 
       opt = gradient_descent.SGD(learning_rate)
-      if use_loss_scaling:
-        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
       model.compile(
           opt,
           loss=loss_fn,
@@ -646,6 +650,10 @@ class KerasModelTest(keras_parameterized.TestCase):
       }, {
           'testcase_name': 'distribute',
           'strategy_fn': create_mirrored_strategy,
+      }, {
+          'testcase_name': 'pass_loss_scale_to_policy',
+          'strategy_fn': create_mirrored_strategy,
+          'pass_loss_scale_to_policy': True,
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
@@ -653,18 +661,27 @@ class KerasModelTest(keras_parameterized.TestCase):
       })
   def test_dynamic_loss_scaling(self,
                                 strategy_fn,
+                                pass_loss_scale_to_policy=False,
                                 experimental_run_tf_function=True):
     if not self._is_strategy_supported(strategy_fn):
       return
     strategy = strategy_fn()
     initial_loss_scale = 2.
     batch_size = 4
+    loss_scale = loss_scale_module.DynamicLossScale(
+        initial_loss_scale=initial_loss_scale, increment_period=2)
     expected_gradient = backend.variable([initial_loss_scale / batch_size],
                                          dtype=dtypes.float16)
     # If this variable is set to True, the model below will have NaN gradients
     have_nan_gradients = backend.variable(False, dtype=dtypes.bool)
     with strategy.scope():
-      with policy.policy_scope(policy.Policy('infer_float32_vars')):
+      opt = gradient_descent.SGD(1.)
+      if pass_loss_scale_to_policy:
+        p = policy.Policy('infer_float32_vars', loss_scale=loss_scale)
+      else:
+        p = policy.Policy('infer_float32_vars')
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      with policy.policy_scope(p):
         x = layers.Input(
             shape=(1,), batch_size=batch_size, dtype=dtypes.float16)
         layer = AddLayer(assert_type=dtypes.float16)
@@ -685,10 +702,6 @@ class KerasModelTest(keras_parameterized.TestCase):
         del y_true
         return math_ops.reduce_mean(y_pred)
 
-      opt = gradient_descent.SGD(1.)
-      loss_scale = loss_scale_module.DynamicLossScale(
-          initial_loss_scale=initial_loss_scale, increment_period=2)
-      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
       model.compile(
           opt,
           loss=loss_fn,
@@ -728,6 +741,40 @@ class KerasModelTest(keras_parameterized.TestCase):
       model.fit(dataset)
       self.assertEqual(backend.eval(layer.v), -3)
 
+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_loss_scale_optimizer_overrides_policy_loss_scale(self):
+    with policy.policy_scope(policy.Policy('float32', loss_scale=10.)):
+      opt = gradient_descent.SGD(1.)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=5.)
+      x = layers.Input(shape=(1,))
+      y = AddLayer()(x)
+      model = models.Model(x, y)
+      model.compile(opt, loss='mse')
+      self.assertEqual(self.evaluate(model.optimizer.loss_scale()), 5.)
+
+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_pass_invalid_optimizer_with_loss_scaling(self):
+    with policy.policy_scope(policy.Policy('float32', loss_scale=10.)):
+      x = layers.Input(shape=(1,))
+      y = AddLayer()(x)
+      model = models.Model(x, y)
+      with self.assertRaisesRegexp(ValueError,
+                                   'optimizer" must be an instance of '):
+        model.compile(optimizers.SGD(1.), 'mse')
+
+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_functional_model_loss_dtype(self):
+    with policy.policy_scope('float16'):
+      x = layers.Input(shape=(1,))
+      y = AddLayer()(x)
+      model = models.Model(x, y)
+      model.add_loss(math_ops.cast(y, 'float32'))
+      # The loss should not be casted to the policy's dtype.
+      self.assertEqual(model.losses[0].dtype, 'float32')
+
   @parameterized.named_parameters(
       {
           'testcase_name': 'base',
@@ -24,19 +24,26 @@ import six
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision_global_state
 from tensorflow.python.util.tf_export import keras_export
 
 
+# Default value of certain arguments, indicating the default behavior for
+# that argument should be used.
+USE_DEFAULT = 'USE_DEFAULT'
+
+
 @keras_export('keras.mixed_precision.experimental.Policy')
 class Policy(object):
   """A dtype policy for a Keras layer.
 
-  A dtype policy determines the computation dtype and the variable dtype of a
-  Keras layer. Each layer has a policy. Policies can be passed to the 'dtype'
-  argument of layer constructors, or a global policy can be set with
-  'tf.keras.mixed_precision.experimental.set_policy'. A layer will default to
-  the global policy if no policy is passed to it's constructor.
+  A dtype policy determines dtype-related aspects of a layer, such as its
+  computation and variable dtypes. Each layer has a policy. Policies can be
+  passed to the 'dtype' argument of layer constructors, or a global policy can
+  be set with 'tf.keras.mixed_precision.experimental.set_policy'. A layer will
+  default to the global policy if no policy is passed to it's constructor.
 
   For most models, each layer will have the same computation dtype and variable
   dtype, which will typically be float32. However, when mixed precision
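The `USE_DEFAULT` string acts as a sentinel: `None` is itself a meaningful `loss_scale` value (explicitly no loss scaling), so a distinct marker is needed to detect that the argument was not passed at all. A generic sketch of the pattern, with hypothetical names:

```python
# Generic sentinel-default pattern (hypothetical names, illustration only).
_UNSET = 'USE_DEFAULT'

def resolve_loss_scale(name, loss_scale=_UNSET):
  if loss_scale == _UNSET:
    # Only pick a default when the caller passed nothing; an explicit
    # None must still mean "no loss scaling".
    loss_scale = 'dynamic' if name == 'mixed_float16' else None
  return loss_scale

assert resolve_loss_scale('mixed_float16') == 'dynamic'
assert resolve_loss_scale('mixed_float16', loss_scale=None) is None
assert resolve_loss_scale('float32') is None
```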
@@ -47,55 +54,49 @@ class Policy(object):
   not match the computation dtype, variables will be automatically casted to the
   computation dtype to avoid type errors.
 
-  In the near future, policies will also determine the loss scaling algorithm
-  for Keras models.
+  Policies also have a `tf.train.experimental.LossScale` instance, which is used
+  by `tf.keras.Model`s to performance loss scaling. Loss scaling is only done by
+  Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models
+  ignore the loss scale.
 
   Policies are constructed by passing a string to the constructor, e.g.
   `tf.keras.mixed_precision.experimental.Policy('float32')`. The string
-  determines the compute and variable dtypes. Currently, it can be one of
-  in one of the following forms:
+  determines the compute and variable dtypes. It can be one of the following:
 
   * Any dtype name, such as 'float32' or 'float64'. Both the variable and
-    compute dtypes will be that dtype.
-  * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
-    will be <dtype>, while the variable dtype is float32. This is intended for
-    the use of mixed precision, which uses float16 or bfloat16 for most
-    computations, and float32 for variables. This policy is only useful if
-    <dtype> is float16 or bfloat16, although <dtype> is allowed to be any
-    dtype. Note we will have a "mixed" policy in the future, which will make
-    it even easier to use mixed precision by enabling other features such as
-    loss scaling.
+    compute dtypes will be that dtype. No loss scaling is done by default.
+  * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
+    bfloat16, while the variable dtype is float32. These policies are used for
+    mixed precision training. With 'mixed_float16', a dynamic loss scale is
+    used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
+    scaling is unnecessary with bfloat16.
 
   ### How to use mixed precision in layers with Policies
 
-  To use mixed precision in a model, the 'float16_with_float32_vars' policy can
+  To use mixed precision in a model, the 'mixed_float16' policy can
   be used. `tf.keras.mixed_precision.experimental.set_policy` can be used to set
-  the default policy for layers if no policy is passed to them. Note loss
-  scaling must also be done, e.g. with a
-  `tf.keras.mixed_precision.experimental.LossScaleOptimizer`. For example
+  the default policy for layers if no policy is passed to them. For example:
 
   ```python
-  tf.keras.mixed_precision.experimental.set_policy(
-      'float16_with_float32_vars')
+  tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
   model = tf.keras.models.Sequential(
       tf.keras.layers.Input((100,)),
-      # Dense layers use global policy of 'float16_with_float32_vars'
+      # Dense layers use global policy of 'mixed_float16', which does
+      # computations in float16 while keeping variables in float32.
       tf.keras.layers.Dense(10),
       tf.keras.layers.Dense(10),
       # Softmax should be done in float32 for numeric stability. We pass
      # dtype='float32' to use float32 instead of the global policy.
       tf.keras.layers.Activation('Softmax', dtype='float32')
   )
-  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
-  ... # Train `model` with `opt`.
+  model.fit(...)  # Train `model`
   ```
 
   Alternatively, the policy can be passed to individual layers instead of
   setting the global policy with `set_policy`:
 
   ```python
-  policy = tf.keras.mixed_precision.experimental.Policy(
-      'float16_with_float32_vars')
+  policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
   model = tf.keras.models.Sequential(
       tf.keras.layers.Input((100,)),
       tf.keras.layers.Dense(10, dtype=policy),
@@ -103,14 +104,21 @@ class Policy(object):
       # Softmax should be done in float32 for numeric stability.
       tf.keras.layers.Activation('Softmax', dtype='float32')
   )
-  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
-  ... # Train `model` with `opt`.
+  model.fit(...)  # Train `model`
   ```
 
   As the above example shows, strings can be directly passed to layer
   constructors in the `dtype` argument instead of policies, but only if the
   string is convertible to a dtype.
 
+  Note the 'mixed_float16' policy will apply loss scaling by default in
+  `Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., a
+  custom training loop is used) and 'mixed_float16' is used, the loss scale must
+  be manually applied. See
+  `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For
+  'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be
+  manually applied.
+
   ### The deprecated "infer" policy
 
   In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
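As a hedged illustration of the "manually applied" case mentioned above (a custom training loop outside `Model.fit`), using the `get_scaled_loss`/`get_unscaled_gradients` helpers of `tf.keras.mixed_precision.experimental.LossScaleOptimizer`; this sketch is not part of the diff:

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(0.1)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')

def train_step(model, x, y, loss_fn):
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x))
    # Scale the loss up before differentiating ...
    scaled_loss = opt.get_scaled_loss(loss)
  scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
  # ... and unscale the gradients before applying them.
  grads = opt.get_unscaled_gradients(scaled_grads)
  opt.apply_gradients(zip(grads, model.trainable_variables))
  return loss
```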
@@ -122,7 +130,13 @@ class Policy(object):
   the dtype of the first input.
 
   Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
-  that infers the compute dtype, but not the variable dtype.
+  that infers the compute dtype, but not the variable dtype. Once a layer with
+  an "infer_with_float32_vars" policy is called for the first time, the layer's
+  policy will change to "<dtype>_with_float32_vars", where <dtype> is the dtype
+  of the first input. These policies force variables in float32.
+
+  Warning: Policies ending in "_with_float32_vars" will be removed in TensorFlow
+  2.1. Please use "mixed_float16" or "mixed_bfloat16" instead.
 
   In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
   available.
@@ -130,7 +144,13 @@ class Policy(object):
   # TODO(reedwm): Replace link in above docstring with a version that is more
   # TensorFlow-specific, and that also mentions bfloat16.
 
-  def __init__(self, name):
+  # If True, warn when a policy is created whose name ends in
+  # "_with_float32_vars". We always want to warn when a user creates such a
+  # policy, but when the TensorFlow creates a policy, it suppresses the warning
+  # by setting this to False when creating the policy.
+  _warn_about_float32_vars = True
+
+  def __init__(self, name, loss_scale=USE_DEFAULT):
     """Constructs the policy.
 
     The `name` argument determines the compute and variable dtype, and has no
@@ -141,16 +161,17 @@ class Policy(object):
       name: A string. Can be one of the following values:
         * Any dtype name, such as 'float32' or 'float64'. Both the variable and
           compute dtypes will be that dtype.
-        * <dtype>_with_float32_vars, where <dtype> is any dtype. The compute
-          dtype will be <dtype>, while the variable dtype is float32. This is
-          intended for the use of mixed precision, which uses float16 or
-          bfloat16 for most computations, and float32 for variables. This policy
-          is only useful if <dtype> is float16 or bfloat16, although <dtype> is
-          allowed to be any dtype. Note we will have a "mixed" policy in the
-          future, which will make it even easier to use mixed precision by
-          enabling other features such as loss scaling.
-        * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
-          computation dtype from the input dtype.
+        * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
+          bfloat16, while the variable dtype is float32. With 'mixed_float16',
+          a dynamic loss scale is used. These policies are used for mixed
+          precision training.
+        * 'infer' (deprecated): Infer the compute and variable dtype from the
+          input dtype.
+      loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
+        one such as "dynamic". Defaults to using no loss scaling unless `name`
+        is "mixed_float16", in which case this defaults to "dynamic". Only
+        `tf.keras.Model`s, not layers, use the loss scale, and it is only used
+        during `Model.fit` or `Model.train_on_batch`.
 
     """
     if isinstance(name, dtypes.DType):
@@ -168,6 +189,27 @@ class Policy(object):
     self._name = name
     self._compute_dtype, self._variable_dtype = self._parse_name(name)
 
+    if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
+      warning = ("WARNING: The '%s' policy is deprecated and will be removed "
+                 "in TensorFlow 2.1." % name)
+      if name == 'infer_with_float32_vars':
+        warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
+                    "policy instead.")
+      elif name == 'float16_with_float32_vars':
+        warning += " Please use the 'mixed_float16' policy instead."
+      elif name == 'bfloat16_with_float32_vars':
+        warning += " Please use the 'mixed_bfloat16' policy instead."
+      tf_logging.warn(warning)
+
+    if loss_scale == USE_DEFAULT:
+      loss_scale = 'dynamic' if name == 'mixed_float16' else None
+    if loss_scale and self._compute_dtype not in (None, 'float16'):
+      tf_logging.warn('Creating a Policy with a loss scale is only useful for '
+                      'float16 policies. You passed loss_scale=%r for policy '
+                      '%s. Consider not passing any loss_scale instead.' %
+                      (loss_scale, name))
+    self._loss_scale = loss_scale_module.get(loss_scale)
+
   def _parse_name(self, name):
     """Parses a Policy name into a compute and variable dtype.
 
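The defaults this constructor resolves to can be summarized as follows (behavior exercised by the `test_loss_scale` test added below; the snippet itself is only an illustration):

```python
import tensorflow as tf

Policy = tf.keras.mixed_precision.experimental.Policy

p = Policy('mixed_float16')
assert (p.compute_dtype, p.variable_dtype) == ('float16', 'float32')
print(p.loss_scale)  # a DynamicLossScale, since 'mixed_float16' defaults to 'dynamic'

assert Policy('mixed_bfloat16').loss_scale is None  # bfloat16 needs no loss scaling
assert Policy('float32').loss_scale is None         # plain dtypes: no scaling by default
print(Policy('float16', loss_scale=2).loss_scale)   # a FixedLossScale wrapping 2.0
```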
@@ -177,6 +219,11 @@ class Policy(object):
     Returns:
       The (compute_dtype, variable_dtype) pair.
     """
+    if name == 'mixed_float16':
+      return 'float16', 'float32'
+    elif name == 'mixed_bfloat16':
+      return 'bfloat16', 'float32'
+
     if name.endswith('_with_float32_vars'):
       base_name = name[:-len('_with_float32_vars')]
       float32_vars = True
@@ -190,10 +237,10 @@ class Policy(object):
     try:
       base_dtype = dtypes.as_dtype(base_name).name
     except TypeError:
-      error = ('Cannot convert value %s to a mixed precision Policy. '
-               'Valid policies include include those in the form "<dtype>" '
-               'and "<dtype>_with_float32_vars", where <dtype> is the name '
-               'of a dtype.' % (name,))
+      error = ("Cannot convert value %s to a mixed precision Policy. "
+               "Valid policies include include 'mixed_float16', "
+               "'mixed_bfloat16', and the name of any dtype such as "
+               "'float32'." % (name,))
       if float32_vars:
         error += (' The value %s ends with _with_float32_vars, but %s cannot '
                   'be converted to a DType' % (name, base_name))
@@ -264,13 +311,22 @@ class Policy(object):
     """
     return self.variable_dtype != self.compute_dtype
 
+  @property
+  def loss_scale(self):
+    """Returns the loss scale of this Policy.
+
+    Returns:
+      A `tf.train.experimental.LossScale`, or None.
+    """
+    return self._loss_scale
+
   @property
   def name(self):
     """Returns the name of this policy."""
     return self._name
 
   def __repr__(self):
-    return '<Policy "%s">' % self._name
+    return '<Policy "%s", loss_scale=%s>' % (self._name, self.loss_scale)
 
 
 def with_input_dtype(policy, dtype):
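A small illustration of the updated `__repr__` and the new `loss_scale` property (matching the expectations updated in the policy tests below):

```python
import tensorflow as tf

Policy = tf.keras.mixed_precision.experimental.Policy
print(repr(Policy('float32')))                # <Policy "float32", loss_scale=None>
print(repr(Policy('float16', loss_scale=2)))  # <Policy "float16", loss_scale=FixedLossScale(2.0)>
```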
@@ -297,7 +353,11 @@ def with_input_dtype(policy, dtype):
   # Policies without a compute dtype are either "infer" or
   # "infer_with_float32_vars", so the variable_dtype must be float32 here.
   assert policy.variable_dtype == 'float32'
-  return Policy(dtype + '_with_float32_vars')
+  try:
+    Policy._warn_about_float32_vars = False  # pylint: disable=protected-access
+    return Policy(dtype + '_with_float32_vars')
+  finally:
+    Policy._warn_about_float32_vars = True  # pylint: disable=protected-access
 
 
 # The current global policy in effect. If None, it means the current value of
@@ -26,6 +26,8 @@ from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision
 
 
@@ -45,10 +47,20 @@ class PolicyTest(test.TestCase):
 
     for dtype in 'int32', 'bool', 'float16', 'float32':
       policy = mp_policy.Policy(dtype)
+      self.assertEqual(policy.name, dtype)
       self.assertEqual(policy.compute_dtype, dtype)
       self.assertEqual(policy.variable_dtype, dtype)
 
       policy = mp_policy.Policy(dtype + '_with_float32_vars')
+      expected_name = (
+          dtype if dtype == 'float32' else dtype + '_with_float32_vars')
+      self.assertEqual(policy.name, expected_name)
+      self.assertEqual(policy.compute_dtype, dtype)
+      self.assertEqual(policy.variable_dtype, 'float32')
+
+    for dtype in 'float16', 'bfloat16':
+      policy = mp_policy.Policy('mixed_' + dtype)
+      self.assertEqual(policy.name, 'mixed_' + dtype)
       self.assertEqual(policy.compute_dtype, dtype)
       self.assertEqual(policy.variable_dtype, 'float32')
 
@@ -57,9 +69,11 @@ class PolicyTest(test.TestCase):
     for policy in ('infer', 'infer_with_float32_vars', 'float32',
                    'float16_with_float32_vars'):
       self.assertEqual(repr(mp_policy.Policy(policy)),
-                       '<Policy "%s">' % policy)
+                       '<Policy "%s", loss_scale=None>' % policy)
     self.assertEqual(repr(mp_policy.Policy('float32_with_float32_vars')),
-                     '<Policy "float32">')
+                     '<Policy "float32", loss_scale=None>')
+    self.assertEqual(repr(mp_policy.Policy('float16', loss_scale=2)),
+                     '<Policy "float16", loss_scale=FixedLossScale(2.0)>')
 
   @testing_utils.enable_v2_dtype_behavior
   def test_policy_errors(self):
@@ -99,6 +113,30 @@ class PolicyTest(test.TestCase):
     self.assertEqual(policy.compute_dtype, 'float32')
     self.assertEqual(policy.variable_dtype, 'float32')
 
+  @testing_utils.enable_v2_dtype_behavior
+  def test_loss_scale(self):
+    policy = mp_policy.Policy('float32')
+    self.assertEqual(policy.loss_scale, None)
+
+    policy = mp_policy.Policy('float32', loss_scale=None)
+    self.assertEqual(policy.loss_scale, None)
+
+    ls = loss_scale_module.DynamicLossScale()
+    policy = mp_policy.Policy('float32', loss_scale=ls)
+    self.assertIs(policy.loss_scale, ls)
+
+    policy = mp_policy.Policy('float32', loss_scale='dynamic')
+    self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
+
+    policy = mp_policy.Policy('mixed_float16')
+    self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
+
+    policy = mp_policy.Policy('mixed_float16', loss_scale=None)
+    self.assertEqual(policy.loss_scale, None)
+
+    policy = mp_policy.Policy('mixed_bfloat16')
+    self.assertEqual(policy.loss_scale, None)
+
   @testing_utils.enable_v2_dtype_behavior
   def test_global_policy(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
@@ -121,6 +159,59 @@ class PolicyTest(test.TestCase):
     finally:
       mp_policy.set_policy(None)
 
+  @testing_utils.enable_v2_dtype_behavior
+  def test_loss_scale_warning(self):
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float32', loss_scale=2.)
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          'Creating a Policy with a loss scale is only useful for float16 '
+          'policies. You passed loss_scale=2.0 for policy float32. Consider '
+          'not passing any loss_scale instead.')
+
+    for policy_name in 'float16', 'mixed_float16':
+      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+        mp_policy.Policy(policy_name, loss_scale=2.)
+        mock_warn.assert_not_called()
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_float32_vars_warning(self):
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('infer_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'infer_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
+          "or 'mixed_bfloat16' policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float16_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'float16_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
+          "policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('bfloat16_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' "
+          "policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float64_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'float64_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1.")
+
+    for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16':
+      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+        mp_policy.Policy(policy_name)
+        mock_warn.assert_not_called()
+
   @testing_utils.enable_v2_dtype_behavior
   def test_policy_scope(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
@@ -227,6 +227,9 @@ class FixedLossScale(LossScale):
     del grads
     return control_flow_ops.no_op(), True
 
+  def __repr__(self):
+    return 'FixedLossScale(%s)' % self._loss_scale_value
+
   def get_config(self):
     return {'loss_scale_value': self._loss_scale_value}
 
@@ -376,6 +379,17 @@ class DynamicLossScale(LossScale):
     should_apply_gradients = is_finite
     return update_op, should_apply_gradients
 
+  def __repr__(self):
+    if context.executing_eagerly():
+      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
+              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
+              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
+               self.initial_loss_scale, self.increment_period, self.multiplier))
+    else:
+      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
+              'multiplier=%s)' %
+              (self.initial_loss_scale, self.increment_period, self.multiplier))
+
   def get_config(self):
     return {
         'initial_loss_scale': self.initial_loss_scale,
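And the corresponding reprs added to the loss scale classes themselves (illustration only; matches the tests added below, assuming eager execution for the DynamicLossScale case):

```python
import tensorflow as tf

print(repr(tf.train.experimental.FixedLossScale(123)))
# FixedLossScale(123.0)

ls = tf.train.experimental.DynamicLossScale(
    initial_loss_scale=1, increment_period=2, multiplier=3)
print(repr(ls))
# DynamicLossScale(current_loss_scale=1.0, num_good_steps=0,
#                  initial_loss_scale=1.0, increment_period=2, multiplier=3.0)
```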
@@ -92,6 +92,11 @@ class FixedLossScaleTest(test.TestCase):
     scalar = loss_scale_module.FixedLossScale(123)
     self.assertIsInstance(scalar(), ops.Tensor)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self):
+    loss_scale = loss_scale_module.FixedLossScale(123)
+    self.assertEqual(repr(loss_scale), 'FixedLossScale(123.0)')
+
 
 def _get_example_iter(inputs):
   dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
@@ -302,5 +307,22 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     scalar = loss_scale_module.DynamicLossScale()
     self.assertIsInstance(scalar(), ops.Tensor)
 
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=1, increment_period=2, multiplier=3)
+      if context.executing_eagerly():
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(current_loss_scale=1.0, '
+                         'num_good_steps=0, initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+      else:
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+
+
 if __name__ == '__main__':
   test.main()
@@ -6,6 +6,10 @@ tf_class {
     name: "compute_dtype"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "loss_scale"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "name"
     mtype: "<type \'property\'>"
@@ -20,6 +24,6 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
   }
 }
@@ -6,6 +6,10 @@ tf_class {
     name: "compute_dtype"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "loss_scale"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "name"
     mtype: "<type \'property\'>"
@@ -20,6 +24,6 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
   }
 }