Add mixed_float16 and mixed_bfloat16 dtype policies.

These policies will be the recommended way of using mixed precision in tf.keras. So far, the only difference from the [b]float16_with_float32_vars policies is that mixed_float16 enables dynamic loss scaling by default (mixed_bfloat16 currently has no difference from bfloat16_with_float32_vars). In the future, the *_with_float32_vars policies will be removed.

PiperOrigin-RevId: 263206151
commit 75a9d99941
parent 0fe04a9a52
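For reference, this is the usage pattern the new policies enable, adapted from the updated `Policy` docstring in the diff below. A minimal sketch: the layer sizes, the `input_shape`, and the explicit float32 softmax layer are illustrative assumptions, not part of the change itself.

```python
import tensorflow as tf

# Make 'mixed_float16' the default policy: layers created afterwards compute in
# float16 but keep their variables in float32, and dynamic loss scaling is
# enabled by default for Model.fit / Model.train_on_batch.
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

model = tf.keras.models.Sequential([
    # Dense layers pick up the global 'mixed_float16' policy.
    tf.keras.layers.Dense(10, input_shape=(100,)),
    tf.keras.layers.Dense(10),
    # The softmax is kept in float32 for numeric stability by overriding the
    # layer dtype, as the docstring example recommends.
    tf.keras.layers.Activation('softmax', dtype='float32'),
])
```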
@@ -435,7 +435,7 @@ class KerasModelTest(keras_parameterized.TestCase):
       }, {
           'testcase_name': 'infer',
           'strategy_fn': create_mirrored_strategy,
-          'policy_name': 'infer_with_float32_vars'
+          'policy_name': 'mixed_float16'
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
@@ -446,13 +446,15 @@ class KerasModelTest(keras_parameterized.TestCase):
                  strategy_fn,
                  use_operator=False,
                  use_regularizer=False,
-                 policy_name='float16_with_float32_vars',
+                 policy_name='mixed_float16',
                  experimental_run_tf_function=True):
     if not self._is_strategy_supported(strategy_fn, check_model_type=True):
       return
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
-      with policy.policy_scope(policy_name):
+      # Pass loss_scale=None, as this test will fail if the DynamicLossScale
+      # skips applying gradients for a step
+      with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)):
         layer_list = []
         if testing_utils.get_model_type() == 'subclass':
           # Subclassed models do not have an Input layer, so the model does not
@@ -580,10 +582,13 @@ class KerasModelTest(keras_parameterized.TestCase):
     strategy = strategy_fn()
     if use_loss_scaling:
       loss_scale = 8.
+    else:
+      loss_scale = None
     learning_rate = 2**-14
 
     with strategy.scope():
-      with policy.policy_scope(policy.Policy('float16_with_float32_vars')):
+      with policy.policy_scope(policy.Policy('mixed_float16',
+                                             loss_scale=loss_scale)):
         x = layers.Input(shape=(1,), batch_size=2)
         layer1 = AddLayer(
             assert_type=dtypes.float16,
@@ -619,8 +624,6 @@ class KerasModelTest(keras_parameterized.TestCase):
       return math_ops.reduce_mean(y_pred)
 
     opt = gradient_descent.SGD(learning_rate)
-    if use_loss_scaling:
-      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
     model.compile(
         opt,
         loss=loss_fn,

@@ -24,6 +24,7 @@ import six
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision_global_state
 from tensorflow.python.util.tf_export import keras_export
@@ -65,28 +66,30 @@ class Policy(object):
   * Any dtype name, such as 'float32' or 'float64'. Both the variable and
     compute dtypes will be that dtype.
   * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
-    will be <dtype>, while the variable dtype is float32. This is intended for
-    the use of mixed precision, which uses float16 or bfloat16 for most
-    computations, and float32 for variables. This policy is only useful if
-    <dtype> is float16 or bfloat16, although <dtype> is allowed to be any
-    dtype. Note we will have a "mixed" policy in the future, which will make
-    it even easier to use mixed precision by enabling other features, such as
-    using loss scaling by default.
+    will be <dtype>, while the variable dtype is float32. This can be used for
+    mixed precision, which uses float16 or bfloat16 for most computations, and
+    float32 for variables, but it is recommended to use the 'mixed_float16' or
+    'mixed_bfloat16' policies instead.
+  * 'mixed_float16' or 'mixed_bfloat16': Similar to
+    'float16_with_float32_vars' or 'bfloat16_with_float32_vars' respectively.
+    'mixed_float16' is identical to 'float16_with_float32_vars' except the
+    loss_scale is dynamic by default. 'mixed_bfloat16' is currently identical
+    to 'bfloat16_with_float32_vars'. More changes may be added to these mixed
+    policies in the future, to further differentiate them from
+    [b]float16_with_float32_vars.
 
   ### How to use mixed precision in layers with Policies
 
-  To use mixed precision in a model, the 'float16_with_float32_vars' policy can
+  To use mixed precision in a model, the 'mixed_float16' policy can
   be used. `tf.keras.mixed_precision.experimental.set_policy` can be used to set
-  the default policy for layers if no policy is passed to them. The loss scale
-  should be set to "dynamic" to perform loss scaling and dynamically determine
-  the optimal loss scale. For example:
+  the default policy for layers if no policy is passed to them. For example:
 
   ```python
-  tf.keras.mixed_precision.experimental.set_policy(
-      'float16_with_float32_vars', loss_scale='dynamic')
+  tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
   model = tf.keras.models.Sequential(
       tf.keras.layers.Input((100,)),
-      # Dense layers use global policy of 'float16_with_float32_vars'
+      # Dense layers use global policy of 'mixed_float16', which does
+      # computations in float16 while keeping variables in float32.
       tf.keras.layers.Dense(10),
       tf.keras.layers.Dense(10),
       # Softmax should be done in float32 for numeric stability. We pass
@@ -100,8 +103,7 @@ class Policy(object):
   setting the global policy with `set_policy`:
 
   ```python
-  policy = tf.keras.mixed_precision.experimental.Policy(
-      'float16_with_float32_vars', loss_scale='dynamic')
+  policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
   model = tf.keras.models.Sequential(
       tf.keras.layers.Input((100,)),
       tf.keras.layers.Dense(10, dtype=policy),
@@ -146,18 +148,24 @@ class Policy(object):
       name: A string. Can be one of the following values:
         * Any dtype name, such as 'float32' or 'float64'. Both the variable and
          compute dtypes will be that dtype.
-        * <dtype>_with_float32_vars, where <dtype> is any dtype. The compute
-          dtype will be <dtype>, while the variable dtype is float32. This is
-          intended for the use of mixed precision, which uses float16 or
-          bfloat16 for most computations, and float32 for variables. This policy
-          is only useful if <dtype> is float16 or bfloat16, although <dtype> is
-          allowed to be any dtype. Note we will have a "mixed" policy in the
-          future, which will make it even easier to use mixed precision by
-          enabling other features such as loss scaling.
+        * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute
+          dtype will be <dtype>, while the variable dtype is float32. This can
+          be used for mixed precision, which uses float16 or bfloat16 for most
+          computations, and float32 for variables, but it is recommended to use
+          the 'mixed_float16' or 'mixed_bfloat16' policies instead.
+        * 'mixed_float16' or 'mixed_bfloat16': Similar to
+          'float16_with_float32_vars' or 'bfloat16_with_float32_vars'
+          respectively. 'mixed_float16' is identical to
+          'float16_with_float32_vars' except the loss_scale is dynamic by
+          default. 'mixed_bfloat16' is currently identical to
+          'bfloat16_with_float32_vars'. More changes may be added to these mixed
+          policies in the future, to further differentiate them from
+          [b]float16_with_float32_vars.
         * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
           computation dtype from the input dtype.
       loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
-        one such as "dynamic". Defaults to using no loss scaling. Only
+        one such as "dynamic". Defaults to using no loss scaling unless `name`
+        is "mixed_float16", in which case this defaults to "dynamic". Only
         `tf.keras.Model`s, not layers, use the loss scale, and it is only used
         during `Model.fit` or `Model.train_on_batch`.
 
@@ -178,8 +186,12 @@ class Policy(object):
     self._compute_dtype, self._variable_dtype = self._parse_name(name)
 
     if loss_scale == USE_DEFAULT:
-      loss_scale = None
-
+      loss_scale = 'dynamic' if name == 'mixed_float16' else None
+    if loss_scale and self._compute_dtype not in (None, 'float16'):
+      tf_logging.warn('Creating a Policy with a loss scale is only useful for '
+                      'float16 policies. You passed loss_scale=%r for policy '
+                      '%s. Consider not passing any loss_scale instead.' %
+                      (loss_scale, name))
     self._loss_scale = loss_scale_module.get(loss_scale)
 
   def _parse_name(self, name):
@@ -191,6 +203,11 @@ class Policy(object):
     Returns:
       The (compute_dtype, variable_dtype) pair.
     """
+    if name == 'mixed_float16':
+      return 'float16', 'float32'
+    elif name == 'mixed_bfloat16':
+      return 'bfloat16', 'float32'
+
     if name.endswith('_with_float32_vars'):
       base_name = name[:-len('_with_float32_vars')]
       float32_vars = True

@@ -26,6 +26,7 @@ from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision
 
@@ -46,10 +47,20 @@ class PolicyTest(test.TestCase):
 
     for dtype in 'int32', 'bool', 'float16', 'float32':
       policy = mp_policy.Policy(dtype)
       self.assertEqual(policy.name, dtype)
       self.assertEqual(policy.compute_dtype, dtype)
       self.assertEqual(policy.variable_dtype, dtype)
 
       policy = mp_policy.Policy(dtype + '_with_float32_vars')
+      expected_name = (
+          dtype if dtype == 'float32' else dtype + '_with_float32_vars')
+      self.assertEqual(policy.name, expected_name)
       self.assertEqual(policy.compute_dtype, dtype)
       self.assertEqual(policy.variable_dtype, 'float32')
+
+    for dtype in 'float16', 'bfloat16':
+      policy = mp_policy.Policy('mixed_' + dtype)
+      self.assertEqual(policy.name, 'mixed_' + dtype)
+      self.assertEqual(policy.compute_dtype, dtype)
+      self.assertEqual(policy.variable_dtype, 'float32')
+
@@ -117,6 +128,15 @@ class PolicyTest(test.TestCase):
     policy = mp_policy.Policy('float32', loss_scale='dynamic')
     self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
 
+    policy = mp_policy.Policy('mixed_float16')
+    self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
+
+    policy = mp_policy.Policy('mixed_float16', loss_scale=None)
+    self.assertEqual(policy.loss_scale, None)
+
+    policy = mp_policy.Policy('mixed_bfloat16')
+    self.assertEqual(policy.loss_scale, None)
+
   @testing_utils.enable_v2_dtype_behavior
   def test_global_policy(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
@@ -139,6 +159,21 @@ class PolicyTest(test.TestCase):
     finally:
       mp_policy.set_policy(None)
 
+  @testing_utils.enable_v2_dtype_behavior
+  def test_loss_scale_warning(self):
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float32', loss_scale=2.)
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          'Creating a Policy with a loss scale is only useful for float16 '
+          'policies. You passed loss_scale=2.0 for policy float32. Consider '
+          'not passing any loss_scale instead.')
+
+    for policy_name in 'float16', 'float16_with_float32_vars', 'mixed_float16':
+      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+        mp_policy.Policy(policy_name, loss_scale=2.)
+      mock_warn.assert_not_called()
+
   @testing_utils.enable_v2_dtype_behavior
   def test_policy_scope(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
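The new policy_test.py cases above boil down to the following behavior sketch (hypothetical standalone usage of the internal `mp_policy` module, assuming the defaults introduced by this change):

```python
from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
from tensorflow.python.training.experimental import loss_scale as loss_scale_module

# 'mixed_float16' now defaults to dynamic loss scaling...
policy = mp_policy.Policy('mixed_float16')
assert isinstance(policy.loss_scale, loss_scale_module.DynamicLossScale)

# ...unless a loss_scale (including None) is passed explicitly.
policy = mp_policy.Policy('mixed_float16', loss_scale=None)
assert policy.loss_scale is None

# 'mixed_bfloat16' keeps no loss scaling by default.
policy = mp_policy.Policy('mixed_bfloat16')
assert policy.loss_scale is None

# Both mixed policies compute in the 16-bit dtype and keep variables in float32.
assert policy.compute_dtype == 'bfloat16'
assert policy.variable_dtype == 'float32'
```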