Add mixed_float16 and mixed_bfloat16 dtype policies.

These policies will be the recommended way of using mixed precision in tf.keras. So far, the only difference from the [b]float16_with_float32_vars policies is that mixed_float16 enables dynamic loss scaling by default (mixed_bfloat16 is currently identical to bfloat16_with_float32_vars). In the future, the *_with_float32_vars policies will be removed.

PiperOrigin-RevId: 263206151
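A minimal sketch of the defaults described above, using only the public `tf.keras.mixed_precision.experimental.Policy` constructor and the `compute_dtype`/`variable_dtype`/`loss_scale` attributes that appear in the diff below (the comments show expected values, not captured output):

```python
import tensorflow as tf

# Both mixed policies compute in a 16-bit dtype while keeping variables in float32.
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
print(policy.compute_dtype)   # 'float16'
print(policy.variable_dtype)  # 'float32'
# mixed_float16 turns on dynamic loss scaling by default ...
print(policy.loss_scale)      # a DynamicLossScale instance

# ... while mixed_bfloat16 currently matches bfloat16_with_float32_vars and
# therefore has no loss scale by default.
bfloat16_policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
print(bfloat16_policy.loss_scale)  # None
```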
Reed Wanderman-Milne 2019-08-13 13:45:02 -07:00 committed by Reed
parent 0fe04a9a52
commit 75a9d99941
3 changed files with 88 additions and 33 deletions

View File

@@ -435,7 +435,7 @@ class KerasModelTest(keras_parameterized.TestCase):
}, {
'testcase_name': 'infer',
'strategy_fn': create_mirrored_strategy,
'policy_name': 'infer_with_float32_vars'
'policy_name': 'mixed_float16'
}, {
'testcase_name': 'norun_distributed',
'strategy_fn': create_mirrored_strategy,
@@ -446,13 +446,15 @@ class KerasModelTest(keras_parameterized.TestCase):
strategy_fn,
use_operator=False,
use_regularizer=False,
policy_name='float16_with_float32_vars',
policy_name='mixed_float16',
experimental_run_tf_function=True):
if not self._is_strategy_supported(strategy_fn, check_model_type=True):
return
regularizer = IdentityRegularizer() if use_regularizer else None
with strategy_fn().scope():
with policy.policy_scope(policy_name):
# Pass loss_scale=None, as this test will fail if the DynamicLossScale
# skips applying gradients for a step
with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)):
layer_list = []
if testing_utils.get_model_type() == 'subclass':
# Subclassed models do not have an Input layer, so the model does not
@@ -580,10 +582,13 @@ class KerasModelTest(keras_parameterized.TestCase):
strategy = strategy_fn()
if use_loss_scaling:
loss_scale = 8.
else:
loss_scale = None
learning_rate = 2**-14
with strategy.scope():
with policy.policy_scope(policy.Policy('float16_with_float32_vars')):
with policy.policy_scope(policy.Policy('mixed_float16',
loss_scale=loss_scale)):
x = layers.Input(shape=(1,), batch_size=2)
layer1 = AddLayer(
assert_type=dtypes.float16,
@@ -619,8 +624,6 @@ class KerasModelTest(keras_parameterized.TestCase):
return math_ops.reduce_mean(y_pred)
opt = gradient_descent.SGD(learning_rate)
if use_loss_scaling:
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
model.compile(
opt,
loss=loss_fn,
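The hunk above stops wrapping the optimizer in a `LossScaleOptimizer` by hand and instead hands the loss scale to the policy. A hedged sketch of the before/after at the public API level (the `LossScaleOptimizer` export under the same experimental namespace is assumed; the optimizer and scale values mirror the test):

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(2. ** -14)

# Old pattern: a *_with_float32_vars policy plus a manually wrapped optimizer.
old_policy = tf.keras.mixed_precision.experimental.Policy(
    'float16_with_float32_vars')
wrapped_opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 8.)

# New pattern: the loss scale rides on the policy, and the model applies it
# during fit/train_on_batch, so the plain optimizer is passed to compile().
new_policy = tf.keras.mixed_precision.experimental.Policy(
    'mixed_float16', loss_scale=8.)
```

In the test itself the new policy is entered through `policy.policy_scope(...)` and the unwrapped SGD optimizer goes straight into `model.compile`.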

View File

@@ -24,6 +24,7 @@ import six
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util.tf_export import keras_export
@@ -65,28 +66,30 @@ class Policy(object):
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
compute dtypes will be that dtype.
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
will be <dtype>, while the variable dtype is float32. This is intended for
the use of mixed precision, which uses float16 or bfloat16 for most
computations, and float32 for variables. This policy is only useful if
<dtype> is float16 or bfloat16, although <dtype> is allowed to be any
dtype. Note we will have a "mixed" policy in the future, which will make
it even easier to use mixed precision by enabling other features, such as
using loss scaling by default.
will be <dtype>, while the variable dtype is float32. This can be used for
mixed precision, which uses float16 or bfloat16 for most computations, and
float32 for variables, but it is recommended to use the 'mixed_float16' or
'mixed_bfloat16' policies instead.
* 'mixed_float16' or 'mixed_bfloat16': Similar to
'float16_with_float32_vars' or 'bfloat16_with_float32_vars' respectively.
'mixed_float16' is identical to 'float16_with_float32_vars' except the
loss_scale is dynamic by default. 'mixed_bfloat16' is currently identical
to 'bfloat16_with_float32_vars'. More changes may be added to these mixed
policies in the future, to further differentiate them from
[b]float16_with_float32_vars.
### How to use mixed precision in layers with Policies
To use mixed precision in a model, the 'float16_with_float32_vars' policy can
To use mixed precision in a model, the 'mixed_float16' policy can
be used. `tf.keras.mixed_precision.experimental.set_policy` can be used to set
the default policy for layers if no policy is passed to them. The loss scale
should be set to "dynamic" to perform loss scaling and dynamically determine
the optimal loss scale. For example:
the default policy for layers if no policy is passed to them. For example:
```python
tf.keras.mixed_precision.experimental.set_policy(
'float16_with_float32_vars', loss_scale='dynamic')
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
model = tf.keras.models.Sequential(
tf.keras.layers.Input((100,)),
# Dense layers use global policy of 'float16_with_float32_vars'
# Dense layers use global policy of 'mixed_float16', which does
# computations in float16 while keeping variables in float32.
tf.keras.layers.Dense(10),
tf.keras.layers.Dense(10),
# Softmax should be done in float32 for numeric stability. We pass
@@ -100,8 +103,7 @@ class Policy(object):
setting the global policy with `set_policy`:
```python
policy = tf.keras.mixed_precision.experimental.Policy(
'float16_with_float32_vars', loss_scale='dynamic')
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
model = tf.keras.models.Sequential(
tf.keras.layers.Input((100,)),
tf.keras.layers.Dense(10, dtype=policy),
@@ -146,18 +148,24 @@ class Policy(object):
name: A string. Can be one of the following values:
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
compute dtypes will be that dtype.
* <dtype>_with_float32_vars, where <dtype> is any dtype. The compute
dtype will be <dtype>, while the variable dtype is float32. This is
intended for the use of mixed precision, which uses float16 or
bfloat16 for most computations, and float32 for variables. This policy
is only useful if <dtype> is float16 or bfloat16, although <dtype> is
allowed to be any dtype. Note we will have a "mixed" policy in the
future, which will make it even easier to use mixed precision by
enabling other features such as loss scaling.
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute
dtype will be <dtype>, while the variable dtype is float32. This can
be used for mixed precision, which uses float16 or bfloat16 for most
computations, and float32 for variables, but it is recommended to use
the 'mixed_float16' or 'mixed_bfloat16' policies instead.
* 'mixed_float16' or 'mixed_bfloat16': Similar to
'float16_with_float32_vars' or 'bfloat16_with_float32_vars'
respectively. 'mixed_float16' is identical to
'float16_with_float32_vars' except the loss_scale is dynamic by
default. 'mixed_bfloat16' is currently identical to
'bfloat16_with_float32_vars'. More changes may be added to these mixed
policies in the future, to further differentiate them from
[b]float16_with_float32_vars.
* 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
computation dtype from the input dtype.
loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
one such as "dynamic". Defaults to using no loss scaling. Only
one such as "dynamic". Defaults to using no loss scaling unless `name`
is "mixed_float16", in which case this defaults to "dynamic". Only
`tf.keras.Model`s, not layers, use the loss scale, and it is only used
during `Model.fit` or `Model.train_on_batch`.
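In code form, the default just described (a rough sketch; `DynamicLossScale` is assumed to be the `tf.train.experimental` loss-scale class referenced at the start of this argument description):

```python
import tensorflow as tf

Policy = tf.keras.mixed_precision.experimental.Policy

# Only mixed_float16 defaults to dynamic loss scaling.
assert isinstance(Policy('mixed_float16').loss_scale,
                  tf.train.experimental.DynamicLossScale)
assert Policy('mixed_bfloat16').loss_scale is None
assert Policy('float32').loss_scale is None

# An explicit loss_scale=None opts mixed_float16 out of loss scaling entirely.
assert Policy('mixed_float16', loss_scale=None).loss_scale is None
```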
@@ -178,8 +186,12 @@ class Policy(object):
self._compute_dtype, self._variable_dtype = self._parse_name(name)
if loss_scale == USE_DEFAULT:
loss_scale = None
loss_scale = 'dynamic' if name == 'mixed_float16' else None
if loss_scale and self._compute_dtype not in (None, 'float16'):
tf_logging.warn('Creating a Policy with a loss scale is only useful for '
'float16 policies. You passed loss_scale=%r for policy '
'%s. Consider not passing any loss_scale instead.' %
(loss_scale, name))
self._loss_scale = loss_scale_module.get(loss_scale)
def _parse_name(self, name):
@@ -191,6 +203,11 @@ class Policy(object):
Returns:
The (compute_dtype, variable_dtype) pair.
"""
if name == 'mixed_float16':
return 'float16', 'float32'
elif name == 'mixed_bfloat16':
return 'bfloat16', 'float32'
if name.endswith('_with_float32_vars'):
base_name = name[:-len('_with_float32_vars')]
float32_vars = True

View File

@@ -26,6 +26,7 @@ from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision
@@ -46,10 +47,20 @@ class PolicyTest(test.TestCase):
for dtype in 'int32', 'bool', 'float16', 'float32':
policy = mp_policy.Policy(dtype)
self.assertEqual(policy.name, dtype)
self.assertEqual(policy.compute_dtype, dtype)
self.assertEqual(policy.variable_dtype, dtype)
policy = mp_policy.Policy(dtype + '_with_float32_vars')
expected_name = (
dtype if dtype == 'float32' else dtype + '_with_float32_vars')
self.assertEqual(policy.name, expected_name)
self.assertEqual(policy.compute_dtype, dtype)
self.assertEqual(policy.variable_dtype, 'float32')
for dtype in 'float16', 'bfloat16':
policy = mp_policy.Policy('mixed_' + dtype)
self.assertEqual(policy.name, 'mixed_' + dtype)
self.assertEqual(policy.compute_dtype, dtype)
self.assertEqual(policy.variable_dtype, 'float32')
@@ -117,6 +128,15 @@ class PolicyTest(test.TestCase):
policy = mp_policy.Policy('float32', loss_scale='dynamic')
self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
policy = mp_policy.Policy('mixed_float16')
self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
policy = mp_policy.Policy('mixed_float16', loss_scale=None)
self.assertEqual(policy.loss_scale, None)
policy = mp_policy.Policy('mixed_bfloat16')
self.assertEqual(policy.loss_scale, None)
@testing_utils.enable_v2_dtype_behavior
def test_global_policy(self):
if base_layer_utils.v2_dtype_behavior_enabled():
@@ -139,6 +159,21 @@ class PolicyTest(test.TestCase):
finally:
mp_policy.set_policy(None)
@testing_utils.enable_v2_dtype_behavior
def test_loss_scale_warning(self):
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('float32', loss_scale=2.)
self.assertEqual(
mock_warn.call_args[0][0],
'Creating a Policy with a loss scale is only useful for float16 '
'policies. You passed loss_scale=2.0 for policy float32. Consider '
'not passing any loss_scale instead.')
for policy_name in 'float16', 'float16_with_float32_vars', 'mixed_float16':
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy(policy_name, loss_scale=2.)
mock_warn.assert_not_called()
@testing_utils.enable_v2_dtype_behavior
def test_policy_scope(self):
if base_layer_utils.v2_dtype_behavior_enabled():