Deprecate "_with_float32_vars" policies.
These policies will be removed in TensorFlow 2.1. I plan on removing them very shortly. PiperOrigin-RevId: 267429724
This commit is contained in:
parent
a3a9fd34e9
commit
6b9a66d543
@ -55,28 +55,21 @@ class Policy(object):
|
|||||||
computation dtype to avoid type errors.
|
computation dtype to avoid type errors.
|
||||||
|
|
||||||
Policies also have a `tf.train.experimental.LossScale` instance, which is used
|
Policies also have a `tf.train.experimental.LossScale` instance, which is used
|
||||||
by Models to performance loss scaling. Layers which are not Models ignore
|
by `tf.keras.Model`s to performance loss scaling. Loss scaling is only done by
|
||||||
the loss scale.
|
Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models
|
||||||
|
ignore the loss scale.
|
||||||
|
|
||||||
Policies are constructed by passing a string to the constructor, e.g.
|
Policies are constructed by passing a string to the constructor, e.g.
|
||||||
`tf.keras.mixed_precision.experimental.Policy('float32')`. The string
|
`tf.keras.mixed_precision.experimental.Policy('float32')`. The string
|
||||||
determines the compute and variable dtypes. Currently, it can be one of
|
determines the compute and variable dtypes. It can be one of the following:
|
||||||
in one of the following forms:
|
|
||||||
|
|
||||||
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
||||||
compute dtypes will be that dtype.
|
compute dtypes will be that dtype. No loss scaling is done by default.
|
||||||
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
|
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
|
||||||
will be <dtype>, while the variable dtype is float32. This can be used for
|
bfloat16, while the variable dtype is float32. These policies are used for
|
||||||
mixed precision, which uses float16 or bfloat16 for most computations, and
|
mixed precision training. With 'mixed_float16', a dynamic loss scale is
|
||||||
float32 for variables, but it is recommended to use the 'mixed_float16' or
|
used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
|
||||||
'mixed_bfloat16' policies instead.
|
scaling is unnecessary with bfloat16.
|
||||||
* 'mixed_float16' or 'mixed_bfloat16': Similar to
|
|
||||||
'float16_with_float32_vars' or 'bfloat16_with_float32_vars' respectively.
|
|
||||||
'mixed_float16' is identical to 'float16_with_float32_vars' except the
|
|
||||||
loss_scale is dynamic by default. 'mixed_bfloat16' is currently identical
|
|
||||||
to 'bfloat16_with_float32_vars'. More changes may be added to these mixed
|
|
||||||
policies in the future, to further differentiate them from
|
|
||||||
[b]float16_with_float32_vars.
|
|
||||||
|
|
||||||
### How to use mixed precision in layers with Policies
|
### How to use mixed precision in layers with Policies
|
||||||
|
|
||||||
@ -118,6 +111,14 @@ class Policy(object):
|
|||||||
constructors in the `dtype` argument instead of policies, but only if the
|
constructors in the `dtype` argument instead of policies, but only if the
|
||||||
string is convertible to a dtype.
|
string is convertible to a dtype.
|
||||||
|
|
||||||
|
Note the 'mixed_float16' policy will apply loss scaling by default in
|
||||||
|
`Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., a
|
||||||
|
custom training loop is used) and 'mixed_float16' is used, the loss scale must
|
||||||
|
be manually applied. See
|
||||||
|
`tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For
|
||||||
|
'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be
|
||||||
|
manually applied.
|
||||||
|
|
||||||
### The deprecated "infer" policy
|
### The deprecated "infer" policy
|
||||||
|
|
||||||
In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
|
In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
|
||||||
@ -129,7 +130,13 @@ class Policy(object):
|
|||||||
the dtype of the first input.
|
the dtype of the first input.
|
||||||
|
|
||||||
Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
|
Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
|
||||||
that infers the compute dtype, but not the variable dtype.
|
that infers the compute dtype, but not the variable dtype. Once a layer with
|
||||||
|
an "infer_with_float32_vars" policy is called for the first time, the layer's
|
||||||
|
policy will change to "<dtype>_with_float32_vars", where <dtype> is the dtype
|
||||||
|
of the first input. These policies force variables in float32.
|
||||||
|
|
||||||
|
Warning: Policies ending in "_with_float32_vars" will be removed in TensorFlow
|
||||||
|
2.1. Please use "mixed_float16" or "mixed_bfloat16" instead.
|
||||||
|
|
||||||
In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
|
In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
|
||||||
available.
|
available.
|
||||||
@ -137,6 +144,12 @@ class Policy(object):
|
|||||||
# TODO(reedwm): Replace link in above docstring with a version that is more
|
# TODO(reedwm): Replace link in above docstring with a version that is more
|
||||||
# TensorFlow-specific, and that also mentions bfloat16.
|
# TensorFlow-specific, and that also mentions bfloat16.
|
||||||
|
|
||||||
|
# If True, warn when a policy is created whose name ends in
|
||||||
|
# "_with_float32_vars". We always want to warn when a user creates such a
|
||||||
|
# policy, but when the TensorFlow creates a policy, it suppresses the warning
|
||||||
|
# by setting this to False when creating the policy.
|
||||||
|
_warn_about_float32_vars = True
|
||||||
|
|
||||||
def __init__(self, name, loss_scale=USE_DEFAULT):
|
def __init__(self, name, loss_scale=USE_DEFAULT):
|
||||||
"""Constructs the policy.
|
"""Constructs the policy.
|
||||||
|
|
||||||
@ -148,21 +161,12 @@ class Policy(object):
|
|||||||
name: A string. Can be one of the following values:
|
name: A string. Can be one of the following values:
|
||||||
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
||||||
compute dtypes will be that dtype.
|
compute dtypes will be that dtype.
|
||||||
* '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute
|
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
|
||||||
dtype will be <dtype>, while the variable dtype is float32. This can
|
bfloat16, while the variable dtype is float32. With 'mixed_float16',
|
||||||
be used for mixed precision, which uses float16 or bfloat16 for most
|
a dynamic loss scale is used. These policies are used for mixed
|
||||||
computations, and float32 for variables, but it is recommended to use
|
precision training.
|
||||||
the 'mixed_float16' or 'mixed_bfloat16' policies instead.
|
* 'infer' (deprecated): Infer the compute and variable dtype from the
|
||||||
* 'mixed_float16' or 'mixed_bfloat16': Similar to
|
input dtype.
|
||||||
'float16_with_float32_vars' or 'bfloat16_with_float32_vars'
|
|
||||||
respectively. 'mixed_float16' is identical to
|
|
||||||
'float16_with_float32_vars' except the loss_scale is dynamic by
|
|
||||||
default. 'mixed_bfloat16' is currently identical to
|
|
||||||
'bfloat16_with_float32_vars'. More changes may be added to these mixed
|
|
||||||
policies in the future, to further differentiate them from
|
|
||||||
[b]float16_with_float32_vars.
|
|
||||||
* 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
|
|
||||||
computation dtype from the input dtype.
|
|
||||||
loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
|
loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
|
||||||
one such as "dynamic". Defaults to using no loss scaling unless `name`
|
one such as "dynamic". Defaults to using no loss scaling unless `name`
|
||||||
is "mixed_float16", in which case this defaults to "dynamic". Only
|
is "mixed_float16", in which case this defaults to "dynamic". Only
|
||||||
@ -185,6 +189,18 @@ class Policy(object):
|
|||||||
self._name = name
|
self._name = name
|
||||||
self._compute_dtype, self._variable_dtype = self._parse_name(name)
|
self._compute_dtype, self._variable_dtype = self._parse_name(name)
|
||||||
|
|
||||||
|
if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
|
||||||
|
warning = ("WARNING: The '%s' policy is deprecated and will be removed "
|
||||||
|
"in TensorFlow 2.1." % name)
|
||||||
|
if name == 'infer_with_float32_vars':
|
||||||
|
warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
|
||||||
|
"policy instead.")
|
||||||
|
elif name == 'float16_with_float32_vars':
|
||||||
|
warning += " Please use the 'mixed_float16' policy instead."
|
||||||
|
elif name == 'bfloat16_with_float32_vars':
|
||||||
|
warning += " Please use the 'mixed_bfloat16' policy instead."
|
||||||
|
tf_logging.warn(warning)
|
||||||
|
|
||||||
if loss_scale == USE_DEFAULT:
|
if loss_scale == USE_DEFAULT:
|
||||||
loss_scale = 'dynamic' if name == 'mixed_float16' else None
|
loss_scale = 'dynamic' if name == 'mixed_float16' else None
|
||||||
if loss_scale and self._compute_dtype not in (None, 'float16'):
|
if loss_scale and self._compute_dtype not in (None, 'float16'):
|
||||||
@ -221,10 +237,10 @@ class Policy(object):
|
|||||||
try:
|
try:
|
||||||
base_dtype = dtypes.as_dtype(base_name).name
|
base_dtype = dtypes.as_dtype(base_name).name
|
||||||
except TypeError:
|
except TypeError:
|
||||||
error = ('Cannot convert value %s to a mixed precision Policy. '
|
error = ("Cannot convert value %s to a mixed precision Policy. "
|
||||||
'Valid policies include include those in the form "<dtype>" '
|
"Valid policies include include 'mixed_float16', "
|
||||||
'and "<dtype>_with_float32_vars", where <dtype> is the name '
|
"'mixed_bfloat16', and the name of any dtype such as "
|
||||||
'of a dtype.' % (name,))
|
"'float32'." % (name,))
|
||||||
if float32_vars:
|
if float32_vars:
|
||||||
error += (' The value %s ends with _with_float32_vars, but %s cannot '
|
error += (' The value %s ends with _with_float32_vars, but %s cannot '
|
||||||
'be converted to a DType' % (name, base_name))
|
'be converted to a DType' % (name, base_name))
|
||||||
@ -337,7 +353,11 @@ def with_input_dtype(policy, dtype):
|
|||||||
# Policies without a compute dtype are either "infer" or
|
# Policies without a compute dtype are either "infer" or
|
||||||
# "infer_with_float32_vars", so the variable_dtype must be float32 here.
|
# "infer_with_float32_vars", so the variable_dtype must be float32 here.
|
||||||
assert policy.variable_dtype == 'float32'
|
assert policy.variable_dtype == 'float32'
|
||||||
|
try:
|
||||||
|
Policy._warn_about_float32_vars = False # pylint: disable=protected-access
|
||||||
return Policy(dtype + '_with_float32_vars')
|
return Policy(dtype + '_with_float32_vars')
|
||||||
|
finally:
|
||||||
|
Policy._warn_about_float32_vars = True # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
# The current global policy in effect. If None, it means the current value of
|
# The current global policy in effect. If None, it means the current value of
|
||||||
|
@ -169,11 +169,49 @@ class PolicyTest(test.TestCase):
|
|||||||
'policies. You passed loss_scale=2.0 for policy float32. Consider '
|
'policies. You passed loss_scale=2.0 for policy float32. Consider '
|
||||||
'not passing any loss_scale instead.')
|
'not passing any loss_scale instead.')
|
||||||
|
|
||||||
for policy_name in 'float16', 'float16_with_float32_vars', 'mixed_float16':
|
for policy_name in 'float16', 'mixed_float16':
|
||||||
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
mp_policy.Policy(policy_name, loss_scale=2.)
|
mp_policy.Policy(policy_name, loss_scale=2.)
|
||||||
mock_warn.assert_not_called()
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
|
@testing_utils.enable_v2_dtype_behavior
|
||||||
|
def test_float32_vars_warning(self):
|
||||||
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
|
mp_policy.Policy('infer_with_float32_vars')
|
||||||
|
self.assertEqual(
|
||||||
|
mock_warn.call_args[0][0],
|
||||||
|
"WARNING: The 'infer_with_float32_vars' policy is deprecated and "
|
||||||
|
"will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
|
||||||
|
"or 'mixed_bfloat16' policy instead.")
|
||||||
|
|
||||||
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
|
mp_policy.Policy('float16_with_float32_vars')
|
||||||
|
self.assertEqual(
|
||||||
|
mock_warn.call_args[0][0],
|
||||||
|
"WARNING: The 'float16_with_float32_vars' policy is deprecated and "
|
||||||
|
"will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
|
||||||
|
"policy instead.")
|
||||||
|
|
||||||
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
|
mp_policy.Policy('bfloat16_with_float32_vars')
|
||||||
|
self.assertEqual(
|
||||||
|
mock_warn.call_args[0][0],
|
||||||
|
"WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and "
|
||||||
|
"will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' "
|
||||||
|
"policy instead.")
|
||||||
|
|
||||||
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
|
mp_policy.Policy('float64_with_float32_vars')
|
||||||
|
self.assertEqual(
|
||||||
|
mock_warn.call_args[0][0],
|
||||||
|
"WARNING: The 'float64_with_float32_vars' policy is deprecated and "
|
||||||
|
"will be removed in TensorFlow 2.1.")
|
||||||
|
|
||||||
|
for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16':
|
||||||
|
with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
|
||||||
|
mp_policy.Policy(policy_name)
|
||||||
|
mock_warn.assert_not_called()
|
||||||
|
|
||||||
@testing_utils.enable_v2_dtype_behavior
|
@testing_utils.enable_v2_dtype_behavior
|
||||||
def test_policy_scope(self):
|
def test_policy_scope(self):
|
||||||
if base_layer_utils.v2_dtype_behavior_enabled():
|
if base_layer_utils.v2_dtype_behavior_enabled():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user