Deprecate "_with_float32_vars" policies.

These policies will be removed in TensorFlow 2.1; I plan to remove them very shortly.

PiperOrigin-RevId: 267429724
Reed Wanderman-Milne 2019-09-05 12:38:21 -07:00 committed by Reed
parent a3a9fd34e9
commit 6b9a66d543
2 changed files with 97 additions and 39 deletions

tensorflow/python/keras/mixed_precision/experimental/policy.py

@ -55,28 +55,21 @@ class Policy(object):
computation dtype to avoid type errors.
Policies also have a `tf.train.experimental.LossScale` instance, which is used
by Models to perform loss scaling. Layers which are not Models ignore
the loss scale.
by `tf.keras.Model`s to perform loss scaling. Loss scaling is only done by
Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models
ignore the loss scale.
Policies are constructed by passing a string to the constructor, e.g.
`tf.keras.mixed_precision.experimental.Policy('float32')`. The string
determines the compute and variable dtypes. Currently, it can be in one of
the following forms:
determines the compute and variable dtypes. It can be one of the following
(see the construction sketch after the list):
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
compute dtypes will be that dtype.
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
will be <dtype>, while the variable dtype is float32. This can be used for
mixed precision, which uses float16 or bfloat16 for most computations, and
float32 for variables, but it is recommended to use the 'mixed_float16' or
'mixed_bfloat16' policies instead.
* 'mixed_float16' or 'mixed_bfloat16': Similar to
'float16_with_float32_vars' or 'bfloat16_with_float32_vars' respectively.
'mixed_float16' is identical to 'float16_with_float32_vars' except the
loss_scale is dynamic by default. 'mixed_bfloat16' is currently identical
to 'bfloat16_with_float32_vars'. More changes may be added to these mixed
policies in the future, to further differentiate them from
[b]float16_with_float32_vars.
compute dtypes will be that dtype. No loss scaling is done by default.
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
bfloat16, while the variable dtype is float32. These policies are used for
mixed precision training. With 'mixed_float16', a dynamic loss scale is
used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
scaling is unnecessary with bfloat16.
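To make the list above concrete, here is a minimal sketch using the TF 2.0-era experimental API; the `compute_dtype`, `variable_dtype`, and `loss_scale` properties are part of this `Policy` class:

```python
import tensorflow as tf

# Both the compute and variable dtypes are float32; no loss scaling.
policy = tf.keras.mixed_precision.experimental.Policy('float32')
print(policy.compute_dtype, policy.variable_dtype)  # float32 float32

# Mixed precision: float16 compute, float32 variables, and a dynamic
# loss scale by default.
mixed = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
print(mixed.compute_dtype, mixed.variable_dtype)  # float16 float32
print(mixed.loss_scale)  # a DynamicLossScale instance
```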
### How to use mixed precision in layers with Policies
@ -118,6 +111,14 @@ class Policy(object):
constructors in the `dtype` argument instead of policies, but only if the
string is convertible to a dtype.
Note that the 'mixed_float16' policy will apply loss scaling by default in
`Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., in a
custom training loop) with 'mixed_float16', the loss scale must be applied
manually. See `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for
details, and the sketch below. For 'mixed_bfloat16', no loss scaling is done
and loss scaling never needs to be manually applied.
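As a hedged sketch of that note, manual loss scaling in a custom training loop might look like the following; `model`, `loss_fn`, `x`, and `y` are assumed placeholders:

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(0.1)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, 'dynamic')

with tf.GradientTape() as tape:
  loss = loss_fn(y, model(x))               # compute the loss as usual
  scaled_loss = opt.get_scaled_loss(loss)   # multiply by the loss scale
scaled_grads = tape.gradient(scaled_loss, model.trainable_variables)
grads = opt.get_unscaled_gradients(scaled_grads)  # divide the scale back out
opt.apply_gradients(zip(grads, model.trainable_variables))
```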
### The deprecated "infer" policy
In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
@ -129,7 +130,13 @@ class Policy(object):
the dtype of the first input.
Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
that infers the compute dtype, but not the variable dtype.
that infers the compute dtype, but not the variable dtype. Once a layer with
an "infer_with_float32_vars" policy is called for the first time, the layer's
policy will change to "<dtype>_with_float32_vars", where <dtype> is the dtype
of the first input. These policies force variables to be created in float32.
Warning: Policies ending in "_with_float32_vars" will be removed in TensorFlow
2.1. Please use "mixed_float16" or "mixed_bfloat16" instead.
In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
available.
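A small sketch of the deprecated names, assuming 'infer' parses to dtypes of `None` (consistent with the `with_input_dtype` helper later in this diff); constructing a "_with_float32_vars" policy now logs the warning added by this commit:

```python
import tensorflow as tf

Policy = tf.keras.mixed_precision.experimental.Policy

# 'infer' fixes no dtypes until the layer sees its first input.
p = Policy('infer')
print(p.compute_dtype, p.variable_dtype)  # None None

# Logs: WARNING: The 'float16_with_float32_vars' policy is deprecated ...
Policy('float16_with_float32_vars')
```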
@ -137,6 +144,12 @@ class Policy(object):
# TODO(reedwm): Replace link in above docstring with a version that is more
# TensorFlow-specific, and that also mentions bfloat16.
# If True, warn when a policy is created whose name ends in
# "_with_float32_vars". We always want to warn when a user creates such a
# policy, but when TensorFlow itself creates one, it suppresses the warning
# by setting this attribute to False while creating the policy.
_warn_about_float32_vars = True
def __init__(self, name, loss_scale=USE_DEFAULT):
"""Constructs the policy.
@ -148,21 +161,12 @@ class Policy(object):
name: A string. Can be one of the following values:
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
compute dtypes will be that dtype.
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute
dtype will be <dtype>, while the variable dtype is float32. This can
be used for mixed precision, which uses float16 or bfloat16 for most
computations, and float32 for variables, but it is recommended to use
the 'mixed_float16' or 'mixed_bfloat16' policies instead.
* 'mixed_float16' or 'mixed_bfloat16': Similar to
'float16_with_float32_vars' or 'bfloat16_with_float32_vars'
respectively. 'mixed_float16' is identical to
'float16_with_float32_vars' except the loss_scale is dynamic by
default. 'mixed_bfloat16' is currently identical to
'bfloat16_with_float32_vars'. More changes may be added to these mixed
policies in the future, to further differentiate them from
[b]float16_with_float32_vars.
* 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
computation dtype from the input dtype.
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
bfloat16, while the variable dtype is float32. With 'mixed_float16',
a dynamic loss scale is used. These policies are used for mixed
precision training.
* 'infer' (deprecated): Infer the compute and variable dtype from the
input dtype.
loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
one such as "dynamic". Defaults to using no loss scaling unless `name`
is "mixed_float16", in which case this defaults to "dynamic". Only
@ -185,6 +189,18 @@ class Policy(object):
self._name = name
self._compute_dtype, self._variable_dtype = self._parse_name(name)
if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
warning = ("WARNING: The '%s' policy is deprecated and will be removed "
"in TensorFlow 2.1." % name)
if name == 'infer_with_float32_vars':
warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
"policy instead.")
elif name == 'float16_with_float32_vars':
warning += " Please use the 'mixed_float16' policy instead."
elif name == 'bfloat16_with_float32_vars':
warning += " Please use the 'mixed_bfloat16' policy instead."
tf_logging.warn(warning)
if loss_scale == USE_DEFAULT:
loss_scale = 'dynamic' if name == 'mixed_float16' else None
if loss_scale and self._compute_dtype not in (None, 'float16'):
@ -221,10 +237,10 @@ class Policy(object):
try:
base_dtype = dtypes.as_dtype(base_name).name
except TypeError:
error = ('Cannot convert value %s to a mixed precision Policy. '
'Valid policies include those in the form "<dtype>" '
'and "<dtype>_with_float32_vars", where <dtype> is the name '
'of a dtype.' % (name,))
error = ("Cannot convert value %s to a mixed precision Policy. "
"Valid policies include 'mixed_float16', "
"'mixed_bfloat16', and the name of any dtype such as "
"'float32'." % (name,))
if float32_vars:
error += (' The value %s ends with _with_float32_vars, but %s cannot '
'be converted to a DType' % (name, base_name))
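For illustration, an unrecognized name produces the message above; a hedged sketch (the surrounding `raise` is outside this hunk, and `ValueError` is an assumption):

```python
import tensorflow as tf

try:
  tf.keras.mixed_precision.experimental.Policy('int7')  # not a real dtype
except ValueError as e:
  print(e)  # Cannot convert value int7 to a mixed precision Policy. ...
```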
@ -337,7 +353,11 @@ def with_input_dtype(policy, dtype):
# Policies without a compute dtype are either "infer" or
# "infer_with_float32_vars", so the variable_dtype must be float32 here.
assert policy.variable_dtype == 'float32'
return Policy(dtype + '_with_float32_vars')
try:
Policy._warn_about_float32_vars = False # pylint: disable=protected-access
return Policy(dtype + '_with_float32_vars')
finally:
Policy._warn_about_float32_vars = True # pylint: disable=protected-access
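The try/finally toggle above could equivalently be packaged as a context manager; a hypothetical sketch (`_suppress_float32_vars_warning` is not part of this commit). Saving and restoring the previous value, rather than hard-coding True, also behaves correctly if such suppressions ever nest:

```python
import contextlib

@contextlib.contextmanager
def _suppress_float32_vars_warning():
  """Temporarily silences the "_with_float32_vars" deprecation warning."""
  old_value = Policy._warn_about_float32_vars  # pylint: disable=protected-access
  Policy._warn_about_float32_vars = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    Policy._warn_about_float32_vars = old_value  # pylint: disable=protected-access
```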
# The current global policy in effect. If None, it means the current value of

tensorflow/python/keras/mixed_precision/experimental/policy_test.py

@ -169,11 +169,49 @@ class PolicyTest(test.TestCase):
'policies. You passed loss_scale=2.0 for policy float32. Consider '
'not passing any loss_scale instead.')
for policy_name in 'float16', 'float16_with_float32_vars', 'mixed_float16':
for policy_name in 'float16', 'mixed_float16':
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy(policy_name, loss_scale=2.)
mock_warn.assert_not_called()
@testing_utils.enable_v2_dtype_behavior
def test_float32_vars_warning(self):
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('infer_with_float32_vars')
self.assertEqual(
mock_warn.call_args[0][0],
"WARNING: The 'infer_with_float32_vars' policy is deprecated and "
"will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
"or 'mixed_bfloat16' policy instead.")
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('float16_with_float32_vars')
self.assertEqual(
mock_warn.call_args[0][0],
"WARNING: The 'float16_with_float32_vars' policy is deprecated and "
"will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
"policy instead.")
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('bfloat16_with_float32_vars')
self.assertEqual(
mock_warn.call_args[0][0],
"WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and "
"will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' "
"policy instead.")
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy('float64_with_float32_vars')
self.assertEqual(
mock_warn.call_args[0][0],
"WARNING: The 'float64_with_float32_vars' policy is deprecated and "
"will be removed in TensorFlow 2.1.")
for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16':
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
mp_policy.Policy(policy_name)
mock_warn.assert_not_called()
@testing_utils.enable_v2_dtype_behavior
def test_policy_scope(self):
if base_layer_utils.v2_dtype_behavior_enabled():