Merge pull request #32269 from reedwm/mp_cherrypicks
[r2.0-rc1 CherryPick]: Several tf.keras mixed precision API changes
commit 8ea0a418a6
@@ -2592,6 +2592,9 @@ class AddLoss(Layer):
  """

  def __init__(self, unconditional, **kwargs):
    # Pass autocast=False, as there is no reason to cast loss to a different
    # dtype.
    kwargs['autocast'] = False
    super(AddLoss, self).__init__(**kwargs)
    self.unconditional = unconditional

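In practice, `autocast=False` means a loss tensor added through `Model.add_loss` keeps its own dtype instead of being cast to the layer's compute dtype; `test_functional_model_loss_dtype` later in this PR checks this under a 'float16' policy. A minimal sketch with the public API (the float16 layer dtypes below are only there to make the point and are not part of this change):

```python
import tensorflow as tf

# A tiny float16 functional model.
x = tf.keras.layers.Input(shape=(1,), dtype='float16')
y = tf.keras.layers.Dense(1, dtype='float16')(x)
model = tf.keras.Model(x, y)

# Because AddLoss no longer autocasts its input, the added float32 loss is
# not cast down to the layer's float16 compute dtype.
model.add_loss(tf.cast(y, 'float32'))
assert model.losses[0].dtype == tf.float32
```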
@@ -230,12 +230,10 @@ class Network(base_layer.Layer):
    else:
      self._graph = ops.get_default_graph()  # Used in symbolic mode only.

    # Both graph and subclassed networks have a dtype policy. The policy is
    # currently ignored for a graph network, as graph networks disable
    # autocasting (making the policy's compute dtype meaningless) and graph
    # networks have no variables (making the policy's variable_dtype
    # meaningless). For subclassed networks, the dtype policy acts as it does
    # for any ordinary layer.
    # Both graph and subclassed networks have a dtype policy. For graph
    # networks, the policy's compute and variable dtypes are ignored, but other
    # fields, like the loss scale, are used by Models. For subclassed networks,
    # the compute and variable dtypes are used as for any ordinary layer.
    self._set_dtype_policy(kwargs.get('dtype', None))

    # All layers in order of horizontal graph traversal.

@@ -52,6 +52,8 @@ from tensorflow.python.keras.engine import training_generator
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.engine import training_v2
from tensorflow.python.keras.engine import training_v2_utils
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import data_utils
from tensorflow.python.keras.utils import losses_utils

@@ -248,13 +250,9 @@ class Model(network.Network):
    self._experimental_run_tf_function = kwargs.pop(
        'experimental_run_tf_function', True)

    if isinstance(optimizer, (list, tuple)):
      self.optimizer = [optimizers.get(opt) for opt in optimizer]
      is_any_optimizer_v1 = any(
          isinstance(opt, optimizers.Optimizer) for opt in self.optimizer)
    else:
      self.optimizer = optimizers.get(optimizer)
      is_any_optimizer_v1 = isinstance(self.optimizer, optimizers.Optimizer)
    self._set_optimizer(optimizer)
    is_any_optimizer_v1 = any(isinstance(opt, optimizers.Optimizer)
                              for opt in nest.flatten(self.optimizer))

    if ((sample_weight_mode is not None)
        or (target_tensors is not None)

@@ -1443,6 +1441,47 @@ class Model(network.Network):
          'and the first argument in `call` as positional arguments, '
          'found: ' + str(extra_args) + '.')

  def _set_optimizer(self, optimizer):
    """Sets self.optimizer.

    Sets self.optimizer to `optimizer`, potentially wrapping it with a
    LossScaleOptimizer.

    Args:
      optimizer: The optimizer(s) to assign to self.optimizer.
    """
    if isinstance(optimizer, (list, tuple)):
      self.optimizer = [optimizers.get(opt) for opt in optimizer]
    else:
      self.optimizer = optimizers.get(optimizer)

    if (self._dtype_policy.loss_scale is not None and
        not isinstance(self.optimizer,
                       loss_scale_optimizer.LossScaleOptimizer)):
      if isinstance(self.optimizer, list):
        raise ValueError('When a dtype policy with a loss scale is used, you '
                         'can only pass a single optimizer. Using policy %s '
                         'and got optimizers: %s' %
                         (self._dtype_policy, self.optimizer))
      if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
        raise ValueError('"optimizer" must be an instance of '
                         'tf.keras.optimizers.Optimizer when a dtype policy '
                         'with a loss scale is used, but got: %s. Using '
                         'policy: %s' %
                         (self.optimizer, self._dtype_policy))
      self.optimizer = loss_scale_optimizer.LossScaleOptimizer(
          self.optimizer, self._dtype_policy.loss_scale)
    if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and
        self._dtype_policy.loss_scale and
        self.optimizer.loss_scale != self._dtype_policy.loss_scale):
      logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) '
                      'is not the same as the dtype policy\'s loss scale (%s). '
                      'Because the dtype policy has a loss scale, you should '
                      'pass an optimizer that is not wrapped with a '
                      'LossScaleOptimizer.'
                      % (self.optimizer.loss_scale,
                         self._dtype_policy.loss_scale))

  def _prepare_validation_data(self, validation_data, batch_size,
                               validation_steps):
    """Unpack and check the validation data."""

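The net effect of `_set_optimizer` is that `compile` wires the policy's loss scale into the optimizer automatically. A minimal sketch with the public API, assuming v2 dtype behavior is active so the global policy actually reaches the model (the tests below enable it via `testing_utils.enable_v2_dtype_behavior`):

```python
import tensorflow as tf

mp = tf.keras.mixed_precision.experimental

# 'mixed_float16' carries a dynamic loss scale by default, so compile wraps a
# plain Keras V2 optimizer in a LossScaleOptimizer for you.
mp.set_policy('mixed_float16')
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(tf.keras.optimizers.SGD(0.1), loss='mse')
assert isinstance(model.optimizer, mp.LossScaleOptimizer)
```

Passing a list of optimizers, or a legacy (pre-2.0) Keras optimizer, under such a policy now raises the `ValueError`s above instead of silently skipping loss scaling.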
@@ -35,6 +35,7 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import base_layer

@@ -434,7 +435,7 @@ class KerasModelTest(keras_parameterized.TestCase):
      }, {
          'testcase_name': 'infer',
          'strategy_fn': create_mirrored_strategy,
          'policy_name': 'infer_with_float32_vars'
          'policy_name': 'mixed_float16'
      }, {
          'testcase_name': 'norun_distributed',
          'strategy_fn': create_mirrored_strategy,

@@ -445,13 +446,15 @@ class KerasModelTest(keras_parameterized.TestCase):
                 strategy_fn,
                 use_operator=False,
                 use_regularizer=False,
                 policy_name='float16_with_float32_vars',
                 policy_name='mixed_float16',
                 experimental_run_tf_function=True):
    if not self._is_strategy_supported(strategy_fn, check_model_type=True):
      return
    regularizer = IdentityRegularizer() if use_regularizer else None
    with strategy_fn().scope():
      with policy.policy_scope(policy_name):
      # Pass loss_scale=None, as this test will fail if the DynamicLossScale
      # skips applying gradients for a step
      with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)):
        layer_list = []
        if testing_utils.get_model_type() == 'subclass':
          # Subclassed models do not have an Input layer, so the model does not

@@ -579,10 +582,13 @@ class KerasModelTest(keras_parameterized.TestCase):
    strategy = strategy_fn()
    if use_loss_scaling:
      loss_scale = 8.
    else:
      loss_scale = None
    learning_rate = 2**-14

    with strategy.scope():
      with policy.policy_scope(policy.Policy('float16_with_float32_vars')):
      with policy.policy_scope(policy.Policy('mixed_float16',
                                             loss_scale=loss_scale)):
        x = layers.Input(shape=(1,), batch_size=2)
        layer1 = AddLayer(
            assert_type=dtypes.float16,

@@ -618,8 +624,6 @@ class KerasModelTest(keras_parameterized.TestCase):
          return math_ops.reduce_mean(y_pred)

        opt = gradient_descent.SGD(learning_rate)
        if use_loss_scaling:
          opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
        model.compile(
            opt,
            loss=loss_fn,

@@ -646,6 +650,10 @@ class KerasModelTest(keras_parameterized.TestCase):
      }, {
          'testcase_name': 'distribute',
          'strategy_fn': create_mirrored_strategy,
      }, {
          'testcase_name': 'pass_loss_scale_to_policy',
          'strategy_fn': create_mirrored_strategy,
          'pass_loss_scale_to_policy': True,
      }, {
          'testcase_name': 'norun_distributed',
          'strategy_fn': create_mirrored_strategy,

@@ -653,18 +661,27 @@ class KerasModelTest(keras_parameterized.TestCase):
      })
  def test_dynamic_loss_scaling(self,
                                strategy_fn,
                                pass_loss_scale_to_policy=False,
                                experimental_run_tf_function=True):
    if not self._is_strategy_supported(strategy_fn):
      return
    strategy = strategy_fn()
    initial_loss_scale = 2.
    batch_size = 4
    loss_scale = loss_scale_module.DynamicLossScale(
        initial_loss_scale=initial_loss_scale, increment_period=2)
    expected_gradient = backend.variable([initial_loss_scale / batch_size],
                                         dtype=dtypes.float16)
    # If this variable is set to True, the model below will have NaN gradients
    have_nan_gradients = backend.variable(False, dtype=dtypes.bool)
    with strategy.scope():
      with policy.policy_scope(policy.Policy('infer_float32_vars')):
      opt = gradient_descent.SGD(1.)
      if pass_loss_scale_to_policy:
        p = policy.Policy('infer_float32_vars', loss_scale=loss_scale)
      else:
        p = policy.Policy('infer_float32_vars')
        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
      with policy.policy_scope(p):
        x = layers.Input(
            shape=(1,), batch_size=batch_size, dtype=dtypes.float16)
        layer = AddLayer(assert_type=dtypes.float16)

@@ -685,10 +702,6 @@ class KerasModelTest(keras_parameterized.TestCase):
          del y_true
          return math_ops.reduce_mean(y_pred)

        opt = gradient_descent.SGD(1.)
        loss_scale = loss_scale_module.DynamicLossScale(
            initial_loss_scale=initial_loss_scale, increment_period=2)
        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
        model.compile(
            opt,
            loss=loss_fn,

@@ -728,6 +741,40 @@ class KerasModelTest(keras_parameterized.TestCase):
      model.fit(dataset)
      self.assertEqual(backend.eval(layer.v), -3)

  @test_util.run_in_graph_and_eager_modes
  @testing_utils.enable_v2_dtype_behavior
  def test_loss_scale_optimizer_overrides_policy_loss_scale(self):
    with policy.policy_scope(policy.Policy('float32', loss_scale=10.)):
      opt = gradient_descent.SGD(1.)
      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=5.)
      x = layers.Input(shape=(1,))
      y = AddLayer()(x)
      model = models.Model(x, y)
      model.compile(opt, loss='mse')
      self.assertEqual(self.evaluate(model.optimizer.loss_scale()), 5.)

  @test_util.run_in_graph_and_eager_modes
  @testing_utils.enable_v2_dtype_behavior
  def test_pass_invalid_optimizer_with_loss_scaling(self):
    with policy.policy_scope(policy.Policy('float32', loss_scale=10.)):
      x = layers.Input(shape=(1,))
      y = AddLayer()(x)
      model = models.Model(x, y)
      with self.assertRaisesRegexp(ValueError,
                                   'optimizer" must be an instance of '):
        model.compile(optimizers.SGD(1.), 'mse')

  @test_util.run_in_graph_and_eager_modes
  @testing_utils.enable_v2_dtype_behavior
  def test_functional_model_loss_dtype(self):
    with policy.policy_scope('float16'):
      x = layers.Input(shape=(1,))
      y = AddLayer()(x)
      model = models.Model(x, y)
      model.add_loss(math_ops.cast(y, 'float32'))
      # The loss should not be casted to the policy's dtype.
      self.assertEqual(model.losses[0].dtype, 'float32')

  @parameterized.named_parameters(
      {
          'testcase_name': 'base',

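One behavior these tests pin down: if the policy has a loss scale and the optimizer passed to `compile` is already wrapped in a `LossScaleOptimizer`, the optimizer's loss scale wins and only a warning is logged. A condensed sketch of that case with the public API (again assuming v2 dtype behavior is enabled):

```python
import tensorflow as tf

mp = tf.keras.mixed_precision.experimental

# The policy asks for a fixed loss scale of 10, but the explicitly wrapped
# optimizer carries a loss scale of 5; compile keeps the optimizer's value
# and logs a warning about the mismatch.
mp.set_policy(mp.Policy('float32', loss_scale=10.))
opt = mp.LossScaleOptimizer(tf.keras.optimizers.SGD(1.), loss_scale=5.)

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(opt, loss='mse')
assert float(model.optimizer.loss_scale()) == 5.
```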
@@ -24,19 +24,26 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util.tf_export import keras_export


# Default value of certain arguments, indicating the default behavior for
# that argument should be used.
USE_DEFAULT = 'USE_DEFAULT'


@keras_export('keras.mixed_precision.experimental.Policy')
class Policy(object):
  """A dtype policy for a Keras layer.

  A dtype policy determines the computation dtype and the variable dtype of a
  Keras layer. Each layer has a policy. Policies can be passed to the 'dtype'
  argument of layer constructors, or a global policy can be set with
  'tf.keras.mixed_precision.experimental.set_policy'. A layer will default to
  the global policy if no policy is passed to its constructor.
  A dtype policy determines dtype-related aspects of a layer, such as its
  computation and variable dtypes. Each layer has a policy. Policies can be
  passed to the 'dtype' argument of layer constructors, or a global policy can
  be set with 'tf.keras.mixed_precision.experimental.set_policy'. A layer will
  default to the global policy if no policy is passed to its constructor.

  For most models, each layer will have the same computation dtype and variable
  dtype, which will typically be float32. However, when mixed precision

@@ -47,55 +54,49 @@ class Policy(object):
  not match the computation dtype, variables will be automatically casted to the
  computation dtype to avoid type errors.

  In the near future, policies will also determine the loss scaling algorithm
  for Keras models.
  Policies also have a `tf.train.experimental.LossScale` instance, which is used
  by `tf.keras.Model`s to perform loss scaling. Loss scaling is only done by
  Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models
  ignore the loss scale.

  Policies are constructed by passing a string to the constructor, e.g.
  `tf.keras.mixed_precision.experimental.Policy('float32')`. The string
  determines the compute and variable dtypes. Currently, it can be one of
  in one of the following forms:
  determines the compute and variable dtypes. It can be one of the following:

    * Any dtype name, such as 'float32' or 'float64'. Both the variable and
      compute dtypes will be that dtype.
    * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
      will be <dtype>, while the variable dtype is float32. This is intended for
      the use of mixed precision, which uses float16 or bfloat16 for most
      computations, and float32 for variables. This policy is only useful if
      <dtype> is float16 or bfloat16, although <dtype> is allowed to be any
      dtype. Note we will have a "mixed" policy in the future, which will make
      it even easier to use mixed precision by enabling other features such as
      loss scaling.
      compute dtypes will be that dtype. No loss scaling is done by default.
    * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
      bfloat16, while the variable dtype is float32. These policies are used for
      mixed precision training. With 'mixed_float16', a dynamic loss scale is
      used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
      scaling is unnecessary with bfloat16.

  ### How to use mixed precision in layers with Policies

  To use mixed precision in a model, the 'float16_with_float32_vars' policy can
  To use mixed precision in a model, the 'mixed_float16' policy can
  be used. `tf.keras.mixed_precision.experimental.set_policy` can be used to set
  the default policy for layers if no policy is passed to them. Note loss
  scaling must also be done, e.g. with a
  `tf.keras.mixed_precision.experimental.LossScaleOptimizer`. For example
  the default policy for layers if no policy is passed to them. For example:

  ```python
  tf.keras.mixed_precision.experimental.set_policy(
      'float16_with_float32_vars')
  tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
  model = tf.keras.models.Sequential(
      tf.keras.layers.Input((100,)),
      # Dense layers use global policy of 'float16_with_float32_vars'
      # Dense layers use global policy of 'mixed_float16', which does
      # computations in float16 while keeping variables in float32.
      tf.keras.layers.Dense(10),
      tf.keras.layers.Dense(10),
      # Softmax should be done in float32 for numeric stability. We pass
      # dtype='float32' to use float32 instead of the global policy.
      tf.keras.layers.Activation('Softmax', dtype='float32')
  )
  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
  ... # Train `model` with `opt`.
  model.fit(...)  # Train `model`
  ```

  Alternatively, the policy can be passed to individual layers instead of
  setting the global policy with `set_policy`:

  ```python
  policy = tf.keras.mixed_precision.experimental.Policy(
      'float16_with_float32_vars')
  policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
  model = tf.keras.models.Sequential(
      tf.keras.layers.Input((100,)),
      tf.keras.layers.Dense(10, dtype=policy),

@@ -103,14 +104,21 @@ class Policy(object):
      # Softmax should be done in float32 for numeric stability.
      tf.keras.layers.Activation('Softmax', dtype='float32')
  )
  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
  ... # Train `model` with `opt`.
  model.fit(...)  # Train `model`
  ```

  As the above example shows, strings can be directly passed to layer
  constructors in the `dtype` argument instead of policies, but only if the
  string is convertible to a dtype.

  Note the 'mixed_float16' policy will apply loss scaling by default in
  `Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., a
  custom training loop is used) and 'mixed_float16' is used, the loss scale must
  be manually applied. See
  `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For
  'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be
  manually applied.

  ### The deprecated "infer" policy

  In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be

@@ -122,7 +130,13 @@ class Policy(object):
  the dtype of the first input.

  Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
  that infers the compute dtype, but not the variable dtype.
  that infers the compute dtype, but not the variable dtype. Once a layer with
  an "infer_with_float32_vars" policy is called for the first time, the layer's
  policy will change to "<dtype>_with_float32_vars", where <dtype> is the dtype
  of the first input. These policies force variables in float32.

  Warning: Policies ending in "_with_float32_vars" will be removed in TensorFlow
  2.1. Please use "mixed_float16" or "mixed_bfloat16" instead.

  In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
  available.

@@ -130,7 +144,13 @@ class Policy(object):
  # TODO(reedwm): Replace link in above docstring with a version that is more
  # TensorFlow-specific, and that also mentions bfloat16.

  def __init__(self, name):
  # If True, warn when a policy is created whose name ends in
  # "_with_float32_vars". We always want to warn when a user creates such a
  # policy, but when TensorFlow creates a policy, it suppresses the warning
  # by setting this to False when creating the policy.
  _warn_about_float32_vars = True

  def __init__(self, name, loss_scale=USE_DEFAULT):
    """Constructs the policy.

    The `name` argument determines the compute and variable dtype, and has no

@@ -141,16 +161,17 @@ class Policy(object):
      name: A string. Can be one of the following values:
        * Any dtype name, such as 'float32' or 'float64'. Both the variable and
          compute dtypes will be that dtype.
        * <dtype>_with_float32_vars, where <dtype> is any dtype. The compute
          dtype will be <dtype>, while the variable dtype is float32. This is
          intended for the use of mixed precision, which uses float16 or
          bfloat16 for most computations, and float32 for variables. This policy
          is only useful if <dtype> is float16 or bfloat16, although <dtype> is
          allowed to be any dtype. Note we will have a "mixed" policy in the
          future, which will make it even easier to use mixed precision by
          enabling other features such as loss scaling.
        * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
          computation dtype from the input dtype.
        * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
          bfloat16, while the variable dtype is float32. With 'mixed_float16',
          a dynamic loss scale is used. These policies are used for mixed
          precision training.
        * 'infer' (deprecated): Infer the compute and variable dtype from the
          input dtype.
      loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
        one such as "dynamic". Defaults to using no loss scaling unless `name`
        is "mixed_float16", in which case this defaults to "dynamic". Only
        `tf.keras.Model`s, not layers, use the loss scale, and it is only used
        during `Model.fit` or `Model.train_on_batch`.
    """
    if isinstance(name, dtypes.DType):

@@ -168,6 +189,27 @@ class Policy(object):
    self._name = name
    self._compute_dtype, self._variable_dtype = self._parse_name(name)

    if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
      warning = ("WARNING: The '%s' policy is deprecated and will be removed "
                 "in TensorFlow 2.1." % name)
      if name == 'infer_with_float32_vars':
        warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
                    "policy instead.")
      elif name == 'float16_with_float32_vars':
        warning += " Please use the 'mixed_float16' policy instead."
      elif name == 'bfloat16_with_float32_vars':
        warning += " Please use the 'mixed_bfloat16' policy instead."
      tf_logging.warn(warning)

    if loss_scale == USE_DEFAULT:
      loss_scale = 'dynamic' if name == 'mixed_float16' else None
    if loss_scale and self._compute_dtype not in (None, 'float16'):
      tf_logging.warn('Creating a Policy with a loss scale is only useful for '
                      'float16 policies. You passed loss_scale=%r for policy '
                      '%s. Consider not passing any loss_scale instead.' %
                      (loss_scale, name))
    self._loss_scale = loss_scale_module.get(loss_scale)

  def _parse_name(self, name):
    """Parses a Policy name into a compute and variable dtype.

@@ -177,6 +219,11 @@ class Policy(object):
    Returns:
      The (compute_dtype, variable_dtype) pair.
    """
    if name == 'mixed_float16':
      return 'float16', 'float32'
    elif name == 'mixed_bfloat16':
      return 'bfloat16', 'float32'

    if name.endswith('_with_float32_vars'):
      base_name = name[:-len('_with_float32_vars')]
      float32_vars = True

@@ -190,10 +237,10 @@ class Policy(object):
    try:
      base_dtype = dtypes.as_dtype(base_name).name
    except TypeError:
      error = ('Cannot convert value %s to a mixed precision Policy. '
               'Valid policies include those in the form "<dtype>" '
               'and "<dtype>_with_float32_vars", where <dtype> is the name '
               'of a dtype.' % (name,))
      error = ("Cannot convert value %s to a mixed precision Policy. "
               "Valid policies include 'mixed_float16', "
               "'mixed_bfloat16', and the name of any dtype such as "
               "'float32'." % (name,))
      if float32_vars:
        error += (' The value %s ends with _with_float32_vars, but %s cannot '
                  'be converted to a DType' % (name, base_name))

@@ -264,13 +311,22 @@ class Policy(object):
    """
    return self.variable_dtype != self.compute_dtype

  @property
  def loss_scale(self):
    """Returns the loss scale of this Policy.

    Returns:
      A `tf.train.experimental.LossScale`, or None.
    """
    return self._loss_scale

  @property
  def name(self):
    """Returns the name of this policy."""
    return self._name

  def __repr__(self):
    return '<Policy "%s">' % self._name
    return '<Policy "%s", loss_scale=%s>' % (self._name, self.loss_scale)


def with_input_dtype(policy, dtype):

@@ -297,7 +353,11 @@ def with_input_dtype(policy, dtype):
  # Policies without a compute dtype are either "infer" or
  # "infer_with_float32_vars", so the variable_dtype must be float32 here.
  assert policy.variable_dtype == 'float32'
  return Policy(dtype + '_with_float32_vars')
  try:
    Policy._warn_about_float32_vars = False  # pylint: disable=protected-access
    return Policy(dtype + '_with_float32_vars')
  finally:
    Policy._warn_about_float32_vars = True  # pylint: disable=protected-access


# The current global policy in effect. If None, it means the current value of

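Putting the constructor changes together, the new names and the `loss_scale` argument behave as sketched below (a summary of the code above, written against the public aliases rather than the internal modules):

```python
import tensorflow as tf

mp = tf.keras.mixed_precision.experimental

# 'mixed_float16' splits compute and variable dtypes and defaults to a
# dynamic loss scale, matching the new _parse_name and loss_scale defaults.
p = mp.Policy('mixed_float16')
print(p.compute_dtype, p.variable_dtype)  # float16 float32
print(p.loss_scale)                       # a DynamicLossScale

# 'mixed_bfloat16' also keeps variables in float32 but does no loss scaling.
p = mp.Policy('mixed_bfloat16')
print(p.compute_dtype, p.variable_dtype)  # bfloat16 float32
print(p.loss_scale)                       # None

# Plain dtype names use that dtype for both and no loss scale; the default
# can be overridden per policy.
p = mp.Policy('float16', loss_scale=128.)
print(p.loss_scale)                       # FixedLossScale(128.0)
p = mp.Policy('mixed_float16', loss_scale=None)
print(p.loss_scale)                       # None
```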
@@ -26,6 +26,8 @@ from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision


@@ -45,10 +47,20 @@ class PolicyTest(test.TestCase):

    for dtype in 'int32', 'bool', 'float16', 'float32':
      policy = mp_policy.Policy(dtype)
      self.assertEqual(policy.name, dtype)
      self.assertEqual(policy.compute_dtype, dtype)
      self.assertEqual(policy.variable_dtype, dtype)

      policy = mp_policy.Policy(dtype + '_with_float32_vars')
      expected_name = (
          dtype if dtype == 'float32' else dtype + '_with_float32_vars')
      self.assertEqual(policy.name, expected_name)
      self.assertEqual(policy.compute_dtype, dtype)
      self.assertEqual(policy.variable_dtype, 'float32')

    for dtype in 'float16', 'bfloat16':
      policy = mp_policy.Policy('mixed_' + dtype)
      self.assertEqual(policy.name, 'mixed_' + dtype)
      self.assertEqual(policy.compute_dtype, dtype)
      self.assertEqual(policy.variable_dtype, 'float32')

@@ -57,9 +69,11 @@ class PolicyTest(test.TestCase):
    for policy in ('infer', 'infer_with_float32_vars', 'float32',
                   'float16_with_float32_vars'):
      self.assertEqual(repr(mp_policy.Policy(policy)),
                       '<Policy "%s">' % policy)
                       '<Policy "%s", loss_scale=None>' % policy)
    self.assertEqual(repr(mp_policy.Policy('float32_with_float32_vars')),
                     '<Policy "float32">')
                     '<Policy "float32", loss_scale=None>')
    self.assertEqual(repr(mp_policy.Policy('float16', loss_scale=2)),
                     '<Policy "float16", loss_scale=FixedLossScale(2.0)>')

  @testing_utils.enable_v2_dtype_behavior
  def test_policy_errors(self):

@@ -99,6 +113,30 @@ class PolicyTest(test.TestCase):
    self.assertEqual(policy.compute_dtype, 'float32')
    self.assertEqual(policy.variable_dtype, 'float32')

  @testing_utils.enable_v2_dtype_behavior
  def test_loss_scale(self):
    policy = mp_policy.Policy('float32')
    self.assertEqual(policy.loss_scale, None)

    policy = mp_policy.Policy('float32', loss_scale=None)
    self.assertEqual(policy.loss_scale, None)

    ls = loss_scale_module.DynamicLossScale()
    policy = mp_policy.Policy('float32', loss_scale=ls)
    self.assertIs(policy.loss_scale, ls)

    policy = mp_policy.Policy('float32', loss_scale='dynamic')
    self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)

    policy = mp_policy.Policy('mixed_float16')
    self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)

    policy = mp_policy.Policy('mixed_float16', loss_scale=None)
    self.assertEqual(policy.loss_scale, None)

    policy = mp_policy.Policy('mixed_bfloat16')
    self.assertEqual(policy.loss_scale, None)

  @testing_utils.enable_v2_dtype_behavior
  def test_global_policy(self):
    if base_layer_utils.v2_dtype_behavior_enabled():

@@ -121,6 +159,59 @@ class PolicyTest(test.TestCase):
    finally:
      mp_policy.set_policy(None)

  @testing_utils.enable_v2_dtype_behavior
  def test_loss_scale_warning(self):
    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
      mp_policy.Policy('float32', loss_scale=2.)
      self.assertEqual(
          mock_warn.call_args[0][0],
          'Creating a Policy with a loss scale is only useful for float16 '
          'policies. You passed loss_scale=2.0 for policy float32. Consider '
          'not passing any loss_scale instead.')

    for policy_name in 'float16', 'mixed_float16':
      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
        mp_policy.Policy(policy_name, loss_scale=2.)
      mock_warn.assert_not_called()

  @testing_utils.enable_v2_dtype_behavior
  def test_float32_vars_warning(self):
    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
      mp_policy.Policy('infer_with_float32_vars')
      self.assertEqual(
          mock_warn.call_args[0][0],
          "WARNING: The 'infer_with_float32_vars' policy is deprecated and "
          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
          "or 'mixed_bfloat16' policy instead.")

    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
      mp_policy.Policy('float16_with_float32_vars')
      self.assertEqual(
          mock_warn.call_args[0][0],
          "WARNING: The 'float16_with_float32_vars' policy is deprecated and "
          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
          "policy instead.")

    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
      mp_policy.Policy('bfloat16_with_float32_vars')
      self.assertEqual(
          mock_warn.call_args[0][0],
          "WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and "
          "will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' "
          "policy instead.")

    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
      mp_policy.Policy('float64_with_float32_vars')
      self.assertEqual(
          mock_warn.call_args[0][0],
          "WARNING: The 'float64_with_float32_vars' policy is deprecated and "
          "will be removed in TensorFlow 2.1.")

    for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16':
      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
        mp_policy.Policy(policy_name)
      mock_warn.assert_not_called()

  @testing_utils.enable_v2_dtype_behavior
  def test_policy_scope(self):
    if base_layer_utils.v2_dtype_behavior_enabled():

@@ -227,6 +227,9 @@ class FixedLossScale(LossScale):
    del grads
    return control_flow_ops.no_op(), True

  def __repr__(self):
    return 'FixedLossScale(%s)' % self._loss_scale_value

  def get_config(self):
    return {'loss_scale_value': self._loss_scale_value}

@@ -376,6 +379,17 @@ class DynamicLossScale(LossScale):
    should_apply_gradients = is_finite
    return update_op, should_apply_gradients

  def __repr__(self):
    if context.executing_eagerly():
      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
               self.initial_loss_scale, self.increment_period, self.multiplier))
    else:
      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
              'multiplier=%s)' %
              (self.initial_loss_scale, self.increment_period, self.multiplier))

  def get_config(self):
    return {
        'initial_loss_scale': self.initial_loss_scale,

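The new `__repr__` methods make loss scales readable when they show up in Policy reprs and in logs. For example, under eager execution (the TF 2 default):

```python
import tensorflow as tf

print(repr(tf.train.experimental.FixedLossScale(123)))
# FixedLossScale(123.0)

ls = tf.train.experimental.DynamicLossScale(
    initial_loss_scale=1, increment_period=2, multiplier=3)
print(repr(ls))
# DynamicLossScale(current_loss_scale=1.0, num_good_steps=0,
#                  initial_loss_scale=1.0, increment_period=2, multiplier=3.0)
# In graph mode, only the constructor arguments are shown.
```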
@@ -92,6 +92,11 @@ class FixedLossScaleTest(test.TestCase):
    scalar = loss_scale_module.FixedLossScale(123)
    self.assertIsInstance(scalar(), ops.Tensor)

  @test_util.run_in_graph_and_eager_modes
  def test_repr(self):
    loss_scale = loss_scale_module.FixedLossScale(123)
    self.assertEqual(repr(loss_scale), 'FixedLossScale(123.0)')


def _get_example_iter(inputs):
  dataset = dataset_ops.Dataset.from_tensor_slices(inputs)

@@ -302,5 +307,22 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
    scalar = loss_scale_module.DynamicLossScale()
    self.assertIsInstance(scalar(), ops.Tensor)

  @parameterized.named_parameters(*TESTCASES)
  @test_util.run_in_graph_and_eager_modes
  def test_repr(self, strategy_fn):
    with strategy_fn().scope():
      loss_scale = loss_scale_module.DynamicLossScale(
          initial_loss_scale=1, increment_period=2, multiplier=3)
      if context.executing_eagerly():
        self.assertEqual(repr(loss_scale),
                         'DynamicLossScale(current_loss_scale=1.0, '
                         'num_good_steps=0, initial_loss_scale=1.0, '
                         'increment_period=2, multiplier=3.0)')
      else:
        self.assertEqual(repr(loss_scale),
                         'DynamicLossScale(initial_loss_scale=1.0, '
                         'increment_period=2, multiplier=3.0)')


if __name__ == '__main__':
  test.main()

@@ -6,6 +6,10 @@ tf_class {
    name: "compute_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "loss_scale"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
@@ -20,6 +24,6 @@ tf_class {
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
  }
}

@@ -6,6 +6,10 @@ tf_class {
    name: "compute_dtype"
    mtype: "<type \'property\'>"
  }
  member {
    name: "loss_scale"
    mtype: "<type \'property\'>"
  }
  member {
    name: "name"
    mtype: "<type \'property\'>"
@@ -20,6 +24,6 @@ tf_class {
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
    argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
  }
}