Deprecate "_with_float32_vars" policies.
These policies will be removed in TensorFlow 2.1. I plan on removing them very shortly. PiperOrigin-RevId: 267429724
This commit is contained in:
parent
a3a9fd34e9
commit
6b9a66d543
@ -55,28 +55,21 @@ class Policy(object):
|
||||
computation dtype to avoid type errors.
|
||||
|
||||
Policies also have a `tf.train.experimental.LossScale` instance, which is used
|
||||
by Models to perform loss scaling. Layers which are not Models ignore
|
||||
the loss scale.
|
||||
by `tf.keras.Model`s to perform loss scaling. Loss scaling is only done by
|
||||
Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models
|
||||
ignore the loss scale.
|
||||
|
||||
Policies are constructed by passing a string to the constructor, e.g.
|
||||
`tf.keras.mixed_precision.experimental.Policy('float32')`. The string
|
||||
determines the compute and variable dtypes. Currently, it can be one of
|
||||
in one of the following forms:
|
||||
determines the compute and variable dtypes. It can be one of the following:
|
||||
|
||||
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
||||
compute dtypes will be that dtype.
|
||||
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
|
||||
will be <dtype>, while the variable dtype is float32. This can be used for
|
||||
mixed precision, which uses float16 or bfloat16 for most computations, and
|
||||
float32 for variables, but it is recommended to use the 'mixed_float16' or
|
||||
'mixed_bfloat16' policies instead.
|
||||
* 'mixed_float16' or 'mixed_bfloat16': Similar to
|
||||
'float16_with_float32_vars' or 'bfloat16_with_float32_vars' respectively.
|
||||
'mixed_float16' is identical to 'float16_with_float32_vars' except the
|
||||
loss_scale is dynamic by default. 'mixed_bfloat16' is currently identical
|
||||
to 'bfloat16_with_float32_vars'. More changes may be added to these mixed
|
||||
policies in the future, to further differentiate them from
|
||||
[b]float16_with_float32_vars.
|
||||
compute dtypes will be that dtype. No loss scaling is done by default.
|
||||
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
|
||||
bfloat16, while the variable dtype is float32. These policies are used for
|
||||
mixed precision training. With 'mixed_float16', a dynamic loss scale is
|
||||
used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
|
||||
scaling is unnecessary with bfloat16.
|
||||
|
||||
### How to use mixed precision in layers with Policies
|
||||
|
||||
@ -118,6 +111,14 @@ class Policy(object):
|
||||
constructors in the `dtype` argument instead of policies, but only if the
|
||||
string is convertible to a dtype.
|
||||
|
||||
Note the 'mixed_float16' policy will apply loss scaling by default in
|
||||
`Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., a
|
||||
custom training loop is used) and 'mixed_float16' is used, the loss scale must
|
||||
be manually applied. See
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For
|
||||
'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be
|
||||
manually applied.
|
||||
|
||||
### The deprecated "infer" policy
|
||||
|
||||
In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
|
||||
@ -129,7 +130,13 @@ class Policy(object):
|
||||
the dtype of the first input.
|
||||
|
||||
Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
|
||||
that infers the compute dtype, but not the variable dtype.
|
||||
that infers the compute dtype, but not the variable dtype. Once a layer with
|
||||
an "infer_with_float32_vars" policy is called for the first time, the layer's
|
||||
policy will change to "<dtype>_with_float32_vars", where <dtype> is the dtype
|
||||
of the first input. These policies force variables in float32.
|
||||
|
||||
Warning: Policies ending in "_with_float32_vars" will be removed in TensorFlow
|
||||
2.1. Please use "mixed_float16" or "mixed_bfloat16" instead.
|
||||
|
||||
In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
|
||||
available.
|
||||
@ -137,6 +144,12 @@ class Policy(object):
|
||||
# TODO(reedwm): Replace link in above docstring with a version that is more
|
||||
# TensorFlow-specific, and that also mentions bfloat16.
|
||||
|
||||
# If True, warn when a policy is created whose name ends in
|
||||
# "_with_float32_vars". We always want to warn when a user creates such a
|
||||
# policy, but when the TensorFlow creates a policy, it suppresses the warning
|
||||
# by setting this to False when creating the policy.
|
||||
_warn_about_float32_vars = True
|
||||
|
||||
def __init__(self, name, loss_scale=USE_DEFAULT):
|
||||
"""Constructs the policy.
|
||||
|
||||
@ -148,21 +161,12 @@ class Policy(object):
|
||||
name: A string. Can be one of the following values:
|
||||
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
||||
compute dtypes will be that dtype.
|
||||
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute
|
||||
dtype will be <dtype>, while the variable dtype is float32. This can
|
||||
be used for mixed precision, which uses float16 or bfloat16 for most
|
||||
computations, and float32 for variables, but it is recommended to use
|
||||
the 'mixed_float16' or 'mixed_bfloat16' policies instead.
|
||||
* 'mixed_float16' or 'mixed_bfloat16': Similar to
|
||||
'float16_with_float32_vars' or 'bfloat16_with_float32_vars'
|
||||
respectively. 'mixed_float16' is identical to
|
||||
'float16_with_float32_vars' except the loss_scale is dynamic by
|
||||
default. 'mixed_bfloat16' is currently identical to
|
||||
'bfloat16_with_float32_vars'. More changes may be added to these mixed
|
||||
policies in the future, to further differentiate them from
|
||||
[b]float16_with_float32_vars.
|
||||
* 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
|
||||
computation dtype from the input dtype.
|
||||
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
|
||||
bfloat16, while the variable dtype is float32. With 'mixed_float16',
|
||||
a dynamic loss scale is used. These policies are used for mixed
|
||||
precision training.
|
||||
* 'infer' (deprecated): Infer the compute and variable dtype from the
|
||||
input dtype.
|
||||
loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
|
||||
one such as "dynamic". Defaults to using no loss scaling unless `name`
|
||||
is "mixed_float16", in which case this defaults to "dynamic". Only
|
||||
@ -185,6 +189,18 @@ class Policy(object):
|
||||
self._name = name
|
||||
self._compute_dtype, self._variable_dtype = self._parse_name(name)
|
||||
|
||||
if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
|
||||
warning = ("WARNING: The '%s' policy is deprecated and will be removed "
|
||||
"in TensorFlow 2.1." % name)
|
||||
if name == 'infer_with_float32_vars':
|
||||
warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
|
||||
"policy instead.")
|
||||
elif name == 'float16_with_float32_vars':
|
||||
warning += " Please use the 'mixed_float16' policy instead."
|
||||
elif name == 'bfloat16_with_float32_vars':
|
||||
warning += " Please use the 'mixed_bfloat16' policy instead."
|
||||
tf_logging.warn(warning)
|
||||
|
||||
if loss_scale == USE_DEFAULT:
|
||||
loss_scale = 'dynamic' if name == 'mixed_float16' else None
|
||||
if loss_scale and self._compute_dtype not in (None, 'float16'):
|
||||
@ -221,10 +237,10 @@ class Policy(object):
|
||||
try:
|
||||
base_dtype = dtypes.as_dtype(base_name).name
|
||||
except TypeError:
|
||||
error = ('Cannot convert value %s to a mixed precision Policy. '
|
||||
'Valid policies include those in the form "<dtype>" '
|
||||
'and "<dtype>_with_float32_vars", where <dtype> is the name '
|
||||
'of a dtype.' % (name,))
|
||||
error = ("Cannot convert value %s to a mixed precision Policy. "
|
||||
"Valid policies include include 'mixed_float16', "
|
||||
"'mixed_bfloat16', and the name of any dtype such as "
|
||||
"'float32'." % (name,))
|
||||
if float32_vars:
|
||||
error += (' The value %s ends with _with_float32_vars, but %s cannot '
|
||||
'be converted to a DType' % (name, base_name))
|
||||
@ -337,7 +353,11 @@ def with_input_dtype(policy, dtype):
|
||||
# Policies without a compute dtype are either "infer" or
|
||||
# "infer_with_float32_vars", so the variable_dtype must be float32 here.
|
||||
assert policy.variable_dtype == 'float32'
|
||||
return Policy(dtype + '_with_float32_vars')
|
||||
try:
|
||||
Policy._warn_about_float32_vars = False # pylint: disable=protected-access
|
||||
return Policy(dtype + '_with_float32_vars')
|
||||
finally:
|
||||
Policy._warn_about_float32_vars = True # pylint: disable=protected-access
|
||||
|
||||
|
||||
# The current global policy in effect. If None, it means the current value of
|
||||
|
@ -169,11 +169,49 @@ class PolicyTest(test.TestCase):
|
||||
'policies. You passed loss_scale=2.0 for policy float32. Consider '
|
||||
'not passing any loss_scale instead.')
|
||||
|
||||
for policy_name in 'float16', 'float16_with_float32_vars', 'mixed_float16':
|
||||
for policy_name in 'float16', 'mixed_float16':
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy(policy_name, loss_scale=2.)
|
||||
mock_warn.assert_not_called()
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_float32_vars_warning(self):
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy('infer_with_float32_vars')
|
||||
self.assertEqual(
|
||||
mock_warn.call_args[0][0],
|
||||
"WARNING: The 'infer_with_float32_vars' policy is deprecated and "
|
||||
"will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
|
||||
"or 'mixed_bfloat16' policy instead.")
|
||||
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy('float16_with_float32_vars')
|
||||
self.assertEqual(
|
||||
mock_warn.call_args[0][0],
|
||||
"WARNING: The 'float16_with_float32_vars' policy is deprecated and "
|
||||
"will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
|
||||
"policy instead.")
|
||||
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy('bfloat16_with_float32_vars')
|
||||
self.assertEqual(
|
||||
mock_warn.call_args[0][0],
|
||||
"WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and "
|
||||
"will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' "
|
||||
"policy instead.")
|
||||
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy('float64_with_float32_vars')
|
||||
self.assertEqual(
|
||||
mock_warn.call_args[0][0],
|
||||
"WARNING: The 'float64_with_float32_vars' policy is deprecated and "
|
||||
"will be removed in TensorFlow 2.1.")
|
||||
|
||||
for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16':
|
||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||
mp_policy.Policy(policy_name)
|
||||
mock_warn.assert_not_called()
|
||||
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_policy_scope(self):
|
||||
if base_layer_utils.v2_dtype_behavior_enabled():
|
||||
|
Loading…
x
Reference in New Issue
Block a user