Merge pull request #32269 from reedwm/mp_cherrypicks

[r2.0-rc1 CherryPick]: Several tf.keras mixed precision API changes
Goldie Gadde 2019-09-06 14:15:19 -07:00 committed by GitHub
commit 8ea0a418a6
10 changed files with 359 additions and 77 deletions


@@ -2592,6 +2592,9 @@ class AddLoss(Layer):
   """

   def __init__(self, unconditional, **kwargs):
+    # Pass autocast=False, as there is no reason to cast loss to a different
+    # dtype.
+    kwargs['autocast'] = False
     super(AddLoss, self).__init__(**kwargs)
     self.unconditional = unconditional

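For context, a minimal sketch (not part of this diff, assuming the TF 2.0 `tf.keras.mixed_precision.experimental` API) of the behavior the `autocast=False` change preserves: a tensor passed to `add_loss` keeps its own dtype even when the surrounding layers compute in float16.

```python
import tensorflow as tf

# Assumed: TF 2.0-style API; 'mixed_float16' computes in float16 while
# keeping variables in float32.
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

inp = tf.keras.layers.Input(shape=(4,))
out = tf.keras.layers.Dense(1)(inp)  # float16 output under the policy
model = tf.keras.Model(inp, out)

# The added loss is computed in float32 and, because the internal AddLoss
# layer no longer autocasts, it is not cast back to the policy's float16.
model.add_loss(tf.reduce_mean(tf.cast(out, 'float32')))
assert model.losses[0].dtype == tf.float32
```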

@@ -230,12 +230,10 @@ class Network(base_layer.Layer):
     else:
       self._graph = ops.get_default_graph()  # Used in symbolic mode only.
-    # Both graph and subclassed networks have a dtype policy. The policy is
-    # currently ignored for a graph network, as graph networks disable
-    # autocasting (making the policy's compute dtype meaningless) and graph
-    # networks have no variables (making the policy's variable_dtype
-    # meaningless). For subclassed networks, the dtype policy acts as it does
-    # for any ordinary layer.
+    # Both graph and subclassed networks have a dtype policy. For graph
+    # networks, the policy's compute and variable dtypes are ignored, but other
+    # fields, like the loss scale, are used by Models. For subclassed networks,
+    # the compute and variable dtypes are used as in any ordinary layer.
     self._set_dtype_policy(kwargs.get('dtype', None))

     # All layers in order of horizontal graph traversal.


@@ -52,6 +52,8 @@ from tensorflow.python.keras.engine import training_generator
 from tensorflow.python.keras.engine import training_utils
 from tensorflow.python.keras.engine import training_v2
 from tensorflow.python.keras.engine import training_v2_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.utils import data_utils
 from tensorflow.python.keras.utils import losses_utils
@@ -248,13 +250,9 @@ class Model(network.Network):
     self._experimental_run_tf_function = kwargs.pop(
         'experimental_run_tf_function', True)

-    if isinstance(optimizer, (list, tuple)):
-      self.optimizer = [optimizers.get(opt) for opt in optimizer]
-      is_any_optimizer_v1 = any(
-          isinstance(opt, optimizers.Optimizer) for opt in self.optimizer)
-    else:
-      self.optimizer = optimizers.get(optimizer)
-      is_any_optimizer_v1 = isinstance(self.optimizer, optimizers.Optimizer)
+    self._set_optimizer(optimizer)
+    is_any_optimizer_v1 = any(isinstance(opt, optimizers.Optimizer)
+                              for opt in nest.flatten(self.optimizer))

     if ((sample_weight_mode is not None)
         or (target_tensors is not None)
@@ -1443,6 +1441,47 @@ class Model(network.Network):
           'and the first argument in `call` as positional arguments, '
           'found: ' + str(extra_args) + '.')

+  def _set_optimizer(self, optimizer):
+    """Sets self.optimizer.
+
+    Sets self.optimizer to `optimizer`, potentially wrapping it with a
+    LossScaleOptimizer.
+
+    Args:
+      optimizer: The optimizer(s) to assign to self.optimizer.
+    """
+    if isinstance(optimizer, (list, tuple)):
+      self.optimizer = [optimizers.get(opt) for opt in optimizer]
+    else:
+      self.optimizer = optimizers.get(optimizer)
+
+    if (self._dtype_policy.loss_scale is not None and
+        not isinstance(self.optimizer,
+                       loss_scale_optimizer.LossScaleOptimizer)):
+      if isinstance(self.optimizer, list):
+        raise ValueError('When a dtype policy with a loss scale is used, you '
+                         'can only pass a single optimizer. Using policy %s '
+                         'and got optimizers: %s' %
+                         (self._dtype_policy, self.optimizer))
+      if not isinstance(self.optimizer, optimizer_v2.OptimizerV2):
+        raise ValueError('"optimizer" must be an instance of '
+                         'tf.keras.optimizers.Optimizer when a dtype policy '
+                         'with a loss scale is used, but got: %s. Using '
+                         'policy: %s' %
+                         (self.optimizer, self._dtype_policy))
+      self.optimizer = loss_scale_optimizer.LossScaleOptimizer(
+          self.optimizer, self._dtype_policy.loss_scale)
+    if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and
+        self._dtype_policy.loss_scale and
+        self.optimizer.loss_scale != self._dtype_policy.loss_scale):
+      logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) '
+                      'is not the same as the dtype policy\'s loss scale (%s). '
+                      'Because the dtype policy has a loss scale, you should '
+                      'pass an optimizer that is not wrapped with a '
+                      'LossScaleOptimizer.'
+                      % (self.optimizer.loss_scale,
+                         self._dtype_policy.loss_scale))
+
   def _prepare_validation_data(self, validation_data, batch_size,
                                validation_steps):
     """Unpack and check the validation data."""

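For context, a minimal sketch (not part of this diff, assuming the 2.0 `mixed_float16` policy defaults described later in this change) of what the new `_set_optimizer` path means for users: compiling under a policy that carries a loss scale wraps a plain optimizer in a `LossScaleOptimizer` automatically.

```python
import tensorflow as tf

tf.keras.mixed_precision.experimental.set_policy('mixed_float16')

model = tf.keras.Sequential([
    tf.keras.layers.Dense(1, input_shape=(4,)),
])
# 'mixed_float16' defaults to a dynamic loss scale, so compile() wraps the
# optimizer for us; no explicit LossScaleOptimizer is needed.
model.compile(optimizer=tf.keras.optimizers.SGD(0.1), loss='mse')
print(type(model.optimizer).__name__)  # expected: LossScaleOptimizer
```

Passing an already-wrapped `LossScaleOptimizer` still works; its loss scale takes precedence and a warning is logged, which the tests below exercise.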

@@ -35,6 +35,7 @@ from tensorflow.python.keras import backend
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import layers
 from tensorflow.python.keras import models
+from tensorflow.python.keras import optimizers
 from tensorflow.python.keras import regularizers
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import base_layer
@@ -434,7 +435,7 @@ class KerasModelTest(keras_parameterized.TestCase):
       }, {
           'testcase_name': 'infer',
           'strategy_fn': create_mirrored_strategy,
-          'policy_name': 'infer_with_float32_vars'
+          'policy_name': 'mixed_float16'
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
@@ -445,13 +446,15 @@ class KerasModelTest(keras_parameterized.TestCase):
                           strategy_fn,
                           use_operator=False,
                           use_regularizer=False,
-                          policy_name='float16_with_float32_vars',
+                          policy_name='mixed_float16',
                           experimental_run_tf_function=True):
     if not self._is_strategy_supported(strategy_fn, check_model_type=True):
       return
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
-      with policy.policy_scope(policy_name):
+      # Pass loss_scale=None, as this test will fail if the DynamicLossScale
+      # skips applying gradients for a step
+      with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)):
         layer_list = []
         if testing_utils.get_model_type() == 'subclass':
           # Subclassed models do not have an Input layer, so the model does not
@@ -579,10 +582,13 @@ class KerasModelTest(keras_parameterized.TestCase):
     strategy = strategy_fn()
     if use_loss_scaling:
       loss_scale = 8.
+    else:
+      loss_scale = None
     learning_rate = 2**-14
     with strategy.scope():
-      with policy.policy_scope(policy.Policy('float16_with_float32_vars')):
+      with policy.policy_scope(policy.Policy('mixed_float16',
+                                             loss_scale=loss_scale)):
         x = layers.Input(shape=(1,), batch_size=2)
         layer1 = AddLayer(
             assert_type=dtypes.float16,
@@ -618,8 +624,6 @@ class KerasModelTest(keras_parameterized.TestCase):
           return math_ops.reduce_mean(y_pred)

         opt = gradient_descent.SGD(learning_rate)
-        if use_loss_scaling:
-          opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
         model.compile(
             opt,
             loss=loss_fn,
@@ -646,6 +650,10 @@ class KerasModelTest(keras_parameterized.TestCase):
       }, {
           'testcase_name': 'distribute',
           'strategy_fn': create_mirrored_strategy,
+      }, {
+          'testcase_name': 'pass_loss_scale_to_policy',
+          'strategy_fn': create_mirrored_strategy,
+          'pass_loss_scale_to_policy': True,
       }, {
           'testcase_name': 'norun_distributed',
           'strategy_fn': create_mirrored_strategy,
@@ -653,18 +661,27 @@ class KerasModelTest(keras_parameterized.TestCase):
       })
   def test_dynamic_loss_scaling(self,
                                 strategy_fn,
+                                pass_loss_scale_to_policy=False,
                                 experimental_run_tf_function=True):
     if not self._is_strategy_supported(strategy_fn):
       return
     strategy = strategy_fn()
     initial_loss_scale = 2.
     batch_size = 4
+    loss_scale = loss_scale_module.DynamicLossScale(
+        initial_loss_scale=initial_loss_scale, increment_period=2)
     expected_gradient = backend.variable([initial_loss_scale / batch_size],
                                          dtype=dtypes.float16)
     # If this variable is set to True, the model below will have NaN gradients
     have_nan_gradients = backend.variable(False, dtype=dtypes.bool)
     with strategy.scope():
-      with policy.policy_scope(policy.Policy('infer_float32_vars')):
+      opt = gradient_descent.SGD(1.)
+      if pass_loss_scale_to_policy:
+        p = policy.Policy('infer_float32_vars', loss_scale=loss_scale)
+      else:
+        p = policy.Policy('infer_float32_vars')
+        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      with policy.policy_scope(p):
         x = layers.Input(
             shape=(1,), batch_size=batch_size, dtype=dtypes.float16)
         layer = AddLayer(assert_type=dtypes.float16)
@@ -685,10 +702,6 @@ class KerasModelTest(keras_parameterized.TestCase):
           del y_true
           return math_ops.reduce_mean(y_pred)

-        opt = gradient_descent.SGD(1.)
-        loss_scale = loss_scale_module.DynamicLossScale(
-            initial_loss_scale=initial_loss_scale, increment_period=2)
-        opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
         model.compile(
             opt,
             loss=loss_fn,
@@ -728,6 +741,40 @@ class KerasModelTest(keras_parameterized.TestCase):
     model.fit(dataset)
     self.assertEqual(backend.eval(layer.v), -3)

+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_loss_scale_optimizer_overrides_policy_loss_scale(self):
+    with policy.policy_scope(policy.Policy('float32', loss_scale=10.)):
+      opt = gradient_descent.SGD(1.)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=5.)
+      x = layers.Input(shape=(1,))
+      y = AddLayer()(x)
+      model = models.Model(x, y)
+      model.compile(opt, loss='mse')
+      self.assertEqual(self.evaluate(model.optimizer.loss_scale()), 5.)
+
+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_pass_invalid_optimizer_with_loss_scaling(self):
+    with policy.policy_scope(policy.Policy('float32', loss_scale=10.)):
+      x = layers.Input(shape=(1,))
+      y = AddLayer()(x)
+      model = models.Model(x, y)
+      with self.assertRaisesRegexp(ValueError,
+                                   'optimizer" must be an instance of '):
+        model.compile(optimizers.SGD(1.), 'mse')
+
+  @test_util.run_in_graph_and_eager_modes
+  @testing_utils.enable_v2_dtype_behavior
+  def test_functional_model_loss_dtype(self):
+    with policy.policy_scope('float16'):
+      x = layers.Input(shape=(1,))
+      y = AddLayer()(x)
+      model = models.Model(x, y)
+      model.add_loss(math_ops.cast(y, 'float32'))
+      # The loss should not be casted to the policy's dtype.
+      self.assertEqual(model.losses[0].dtype, 'float32')
+
   @parameterized.named_parameters(
       {
           'testcase_name': 'base',


@@ -24,19 +24,26 @@ import six
 from tensorflow.python.framework import dtypes
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import base_layer_utils
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision_global_state
 from tensorflow.python.util.tf_export import keras_export

+# Default value of certain arguments, indicating the default behavior for
+# that argument should be used.
+USE_DEFAULT = 'USE_DEFAULT'

 @keras_export('keras.mixed_precision.experimental.Policy')
 class Policy(object):
   """A dtype policy for a Keras layer.

-  A dtype policy determines the computation dtype and the variable dtype of a
-  Keras layer. Each layer has a policy. Policies can be passed to the 'dtype'
-  argument of layer constructors, or a global policy can be set with
-  'tf.keras.mixed_precision.experimental.set_policy'. A layer will default to
-  the global policy if no policy is passed to it's constructor.
+  A dtype policy determines dtype-related aspects of a layer, such as its
+  computation and variable dtypes. Each layer has a policy. Policies can be
+  passed to the 'dtype' argument of layer constructors, or a global policy can
+  be set with 'tf.keras.mixed_precision.experimental.set_policy'. A layer will
+  default to the global policy if no policy is passed to its constructor.

   For most models, each layer will have the same computation dtype and variable
   dtype, which will typically be float32. However, when mixed precision
@@ -47,55 +54,49 @@ class Policy(object):
   not match the computation dtype, variables will be automatically casted to the
   computation dtype to avoid type errors.

-  In the near future, policies will also determine the loss scaling algorithm
-  for Keras models.
+  Policies also have a `tf.train.experimental.LossScale` instance, which is used
+  by `tf.keras.Model`s to perform loss scaling. Loss scaling is only done by
+  Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models
+  ignore the loss scale.

   Policies are constructed by passing a string to the constructor, e.g.
   `tf.keras.mixed_precision.experimental.Policy('float32')`. The string
-  determines the compute and variable dtypes. Currently, it can be one of
-  in one of the following forms:
+  determines the compute and variable dtypes. It can be one of the following:

   * Any dtype name, such as 'float32' or 'float64'. Both the variable and
-    compute dtypes will be that dtype.
-  * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
-    will be <dtype>, while the variable dtype is float32. This is intended for
-    the use of mixed precision, which uses float16 or bfloat16 for most
-    computations, and float32 for variables. This policy is only useful if
-    <dtype> is float16 or bfloat16, although <dtype> is allowed to be any
-    dtype. Note we will have a "mixed" policy in the future, which will make
-    it even easier to use mixed precision by enabling other features such as
-    loss scaling.
+    compute dtypes will be that dtype. No loss scaling is done by default.
+  * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
+    bfloat16, while the variable dtype is float32. These policies are used for
+    mixed precision training. With 'mixed_float16', a dynamic loss scale is
+    used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
+    scaling is unnecessary with bfloat16.

   ### How to use mixed precision in layers with Policies

-  To use mixed precision in a model, the 'float16_with_float32_vars' policy can
+  To use mixed precision in a model, the 'mixed_float16' policy can
   be used. `tf.keras.mixed_precision.experimental.set_policy` can be used to set
-  the default policy for layers if no policy is passed to them. Note loss
-  scaling must also be done, e.g. with a
-  `tf.keras.mixed_precision.experimental.LossScaleOptimizer`. For example
+  the default policy for layers if no policy is passed to them. For example:

   ```python
-  tf.keras.mixed_precision.experimental.set_policy(
-      'float16_with_float32_vars')
+  tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
   model = tf.keras.models.Sequential(
       tf.keras.layers.Input((100,)),
-      # Dense layers use global policy of 'float16_with_float32_vars'
+      # Dense layers use global policy of 'mixed_float16', which does
+      # computations in float16 while keeping variables in float32.
       tf.keras.layers.Dense(10),
       tf.keras.layers.Dense(10),
       # Softmax should be done in float32 for numeric stability. We pass
       # dtype='float32' to use float32 instead of the global policy.
       tf.keras.layers.Activation('Softmax', dtype='float32')
   )
-  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
-  ...  # Train `model` with `opt`.
+  model.fit(...)  # Train `model`
   ```

   Alternatively, the policy can be passed to individual layers instead of
   setting the global policy with `set_policy`:

   ```python
-  policy = tf.keras.mixed_precision.experimental.Policy(
-      'float16_with_float32_vars')
+  policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
   model = tf.keras.models.Sequential(
       tf.keras.layers.Input((100,)),
       tf.keras.layers.Dense(10, dtype=policy),
@@ -103,14 +104,21 @@ class Policy(object):
       # Softmax should be done in float32 for numeric stability.
       tf.keras.layers.Activation('Softmax', dtype='float32')
   )
-  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
-  ...  # Train `model` with `opt`.
+  model.fit(...)  # Train `model`
   ```

   As the above example shows, strings can be directly passed to layer
   constructors in the `dtype` argument instead of policies, but only if the
   string is convertible to a dtype.

+  Note the 'mixed_float16' policy will apply loss scaling by default in
+  `Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., a
+  custom training loop is used) and 'mixed_float16' is used, the loss scale must
+  be manually applied. See
+  `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For
+  'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be
+  manually applied.

   ### The deprecated "infer" policy

   In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
@@ -122,7 +130,13 @@ class Policy(object):
   the dtype of the first input.

   Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
-  that infers the compute dtype, but not the variable dtype.
+  that infers the compute dtype, but not the variable dtype. Once a layer with
+  an "infer_with_float32_vars" policy is called for the first time, the layer's
+  policy will change to "<dtype>_with_float32_vars", where <dtype> is the dtype
+  of the first input. These policies force variables to float32.
+
+  Warning: Policies ending in "_with_float32_vars" will be removed in TensorFlow
+  2.1. Please use "mixed_float16" or "mixed_bfloat16" instead.

   In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
   available.
@@ -130,7 +144,13 @@ class Policy(object):
   # TODO(reedwm): Replace link in above docstring with a version that is more
   # TensorFlow-specific, and that also mentions bfloat16.

-  def __init__(self, name):
+  # If True, warn when a policy is created whose name ends in
+  # "_with_float32_vars". We always want to warn when a user creates such a
+  # policy, but when TensorFlow creates a policy, it suppresses the warning
+  # by setting this to False when creating the policy.
+  _warn_about_float32_vars = True
+
+  def __init__(self, name, loss_scale=USE_DEFAULT):
     """Constructs the policy.

     The `name` argument determines the compute and variable dtype, and has no
@@ -141,16 +161,17 @@ class Policy(object):
       name: A string. Can be one of the following values:
         * Any dtype name, such as 'float32' or 'float64'. Both the variable and
           compute dtypes will be that dtype.
-        * <dtype>_with_float32_vars, where <dtype> is any dtype. The compute
-          dtype will be <dtype>, while the variable dtype is float32. This is
-          intended for the use of mixed precision, which uses float16 or
-          bfloat16 for most computations, and float32 for variables. This policy
-          is only useful if <dtype> is float16 or bfloat16, although <dtype> is
-          allowed to be any dtype. Note we will have a "mixed" policy in the
-          future, which will make it even easier to use mixed precision by
-          enabling other features such as loss scaling.
-        * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
-          computation dtype from the input dtype.
+        * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
+          bfloat16, while the variable dtype is float32. With 'mixed_float16',
+          a dynamic loss scale is used. These policies are used for mixed
+          precision training.
+        * 'infer' (deprecated): Infer the compute and variable dtype from the
+          input dtype.
+      loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to
+        one such as "dynamic". Defaults to using no loss scaling unless `name`
+        is "mixed_float16", in which case this defaults to "dynamic". Only
+        `tf.keras.Model`s, not layers, use the loss scale, and it is only used
+        during `Model.fit` or `Model.train_on_batch`.
     """
     if isinstance(name, dtypes.DType):
@@ -168,6 +189,27 @@ class Policy(object):
     self._name = name
     self._compute_dtype, self._variable_dtype = self._parse_name(name)

+    if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
+      warning = ("WARNING: The '%s' policy is deprecated and will be removed "
+                 "in TensorFlow 2.1." % name)
+      if name == 'infer_with_float32_vars':
+        warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
+                    "policy instead.")
+      elif name == 'float16_with_float32_vars':
+        warning += " Please use the 'mixed_float16' policy instead."
+      elif name == 'bfloat16_with_float32_vars':
+        warning += " Please use the 'mixed_bfloat16' policy instead."
+      tf_logging.warn(warning)
+
+    if loss_scale == USE_DEFAULT:
+      loss_scale = 'dynamic' if name == 'mixed_float16' else None
+    if loss_scale and self._compute_dtype not in (None, 'float16'):
+      tf_logging.warn('Creating a Policy with a loss scale is only useful for '
+                      'float16 policies. You passed loss_scale=%r for policy '
+                      '%s. Consider not passing any loss_scale instead.' %
+                      (loss_scale, name))
+    self._loss_scale = loss_scale_module.get(loss_scale)
+
   def _parse_name(self, name):
     """Parses a Policy name into a compute and variable dtype.
@@ -177,6 +219,11 @@ class Policy(object):
     Returns:
       The (compute_dtype, variable_dtype) pair.
     """
+    if name == 'mixed_float16':
+      return 'float16', 'float32'
+    elif name == 'mixed_bfloat16':
+      return 'bfloat16', 'float32'
+
     if name.endswith('_with_float32_vars'):
       base_name = name[:-len('_with_float32_vars')]
       float32_vars = True
@@ -190,10 +237,10 @@ class Policy(object):
     try:
       base_dtype = dtypes.as_dtype(base_name).name
     except TypeError:
-      error = ('Cannot convert value %s to a mixed precision Policy. '
-               'Valid policies include include those in the form "<dtype>" '
-               'and "<dtype>_with_float32_vars", where <dtype> is the name '
-               'of a dtype.' % (name,))
+      error = ("Cannot convert value %s to a mixed precision Policy. "
+               "Valid policies include 'mixed_float16', "
+               "'mixed_bfloat16', and the name of any dtype such as "
+               "'float32'." % (name,))
       if float32_vars:
         error += (' The value %s ends with _with_float32_vars, but %s cannot '
                   'be converted to a DType' % (name, base_name))
@@ -264,13 +311,22 @@ class Policy(object):
     """
     return self.variable_dtype != self.compute_dtype

+  @property
+  def loss_scale(self):
+    """Returns the loss scale of this Policy.
+
+    Returns:
+      A `tf.train.experimental.LossScale`, or None.
+    """
+    return self._loss_scale
+
   @property
   def name(self):
     """Returns the name of this policy."""
     return self._name

   def __repr__(self):
-    return '<Policy "%s">' % self._name
+    return '<Policy "%s", loss_scale=%s>' % (self._name, self.loss_scale)


 def with_input_dtype(policy, dtype):
@@ -297,7 +353,11 @@ def with_input_dtype(policy, dtype):
   # Policies without a compute dtype are either "infer" or
   # "infer_with_float32_vars", so the variable_dtype must be float32 here.
   assert policy.variable_dtype == 'float32'
-  return Policy(dtype + '_with_float32_vars')
+  try:
+    Policy._warn_about_float32_vars = False  # pylint: disable=protected-access
+    return Policy(dtype + '_with_float32_vars')
+  finally:
+    Policy._warn_about_float32_vars = True  # pylint: disable=protected-access


 # The current global policy in effect. If None, it means the current value of

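For context, a minimal sketch (not part of this diff) of the new `loss_scale` constructor argument and its defaults, assuming the `tf.keras.mixed_precision.experimental.Policy` and `tf.train.experimental` loss-scale classes shown above.

```python
import tensorflow as tf

Policy = tf.keras.mixed_precision.experimental.Policy

print(Policy('float32').loss_scale)         # None: no loss scaling by default
print(Policy('mixed_bfloat16').loss_scale)  # None: bfloat16 needs no scaling
print(Policy('mixed_float16').loss_scale)   # a DynamicLossScale by default

# An explicit numeric value becomes a FixedLossScale; 'dynamic' gives a
# DynamicLossScale. Non-float16 policies with a loss scale log a warning.
print(Policy('float16', loss_scale=128.).loss_scale)       # FixedLossScale(128.0)
print(Policy('float32', loss_scale='dynamic').loss_scale)  # DynamicLossScale (warns)
```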

@@ -26,6 +26,8 @@ from tensorflow.python.keras.engine import base_layer_utils
 from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy
 from tensorflow.python.keras.optimizer_v2 import gradient_descent
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training.experimental import loss_scale as loss_scale_module
 from tensorflow.python.training.experimental import mixed_precision
@@ -45,10 +47,20 @@ class PolicyTest(test.TestCase):
     for dtype in 'int32', 'bool', 'float16', 'float32':
       policy = mp_policy.Policy(dtype)
+      self.assertEqual(policy.name, dtype)
       self.assertEqual(policy.compute_dtype, dtype)
       self.assertEqual(policy.variable_dtype, dtype)

       policy = mp_policy.Policy(dtype + '_with_float32_vars')
+      expected_name = (
+          dtype if dtype == 'float32' else dtype + '_with_float32_vars')
+      self.assertEqual(policy.name, expected_name)
+      self.assertEqual(policy.compute_dtype, dtype)
+      self.assertEqual(policy.variable_dtype, 'float32')
+
+    for dtype in 'float16', 'bfloat16':
+      policy = mp_policy.Policy('mixed_' + dtype)
+      self.assertEqual(policy.name, 'mixed_' + dtype)
       self.assertEqual(policy.compute_dtype, dtype)
       self.assertEqual(policy.variable_dtype, 'float32')
@@ -57,9 +69,11 @@ class PolicyTest(test.TestCase):
     for policy in ('infer', 'infer_with_float32_vars', 'float32',
                    'float16_with_float32_vars'):
       self.assertEqual(repr(mp_policy.Policy(policy)),
-                       '<Policy "%s">' % policy)
+                       '<Policy "%s", loss_scale=None>' % policy)
     self.assertEqual(repr(mp_policy.Policy('float32_with_float32_vars')),
-                     '<Policy "float32">')
+                     '<Policy "float32", loss_scale=None>')
+    self.assertEqual(repr(mp_policy.Policy('float16', loss_scale=2)),
+                     '<Policy "float16", loss_scale=FixedLossScale(2.0)>')

   @testing_utils.enable_v2_dtype_behavior
   def test_policy_errors(self):
@@ -99,6 +113,30 @@ class PolicyTest(test.TestCase):
     self.assertEqual(policy.compute_dtype, 'float32')
     self.assertEqual(policy.variable_dtype, 'float32')

+  @testing_utils.enable_v2_dtype_behavior
+  def test_loss_scale(self):
+    policy = mp_policy.Policy('float32')
+    self.assertEqual(policy.loss_scale, None)
+
+    policy = mp_policy.Policy('float32', loss_scale=None)
+    self.assertEqual(policy.loss_scale, None)
+
+    ls = loss_scale_module.DynamicLossScale()
+    policy = mp_policy.Policy('float32', loss_scale=ls)
+    self.assertIs(policy.loss_scale, ls)
+
+    policy = mp_policy.Policy('float32', loss_scale='dynamic')
+    self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
+
+    policy = mp_policy.Policy('mixed_float16')
+    self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale)
+
+    policy = mp_policy.Policy('mixed_float16', loss_scale=None)
+    self.assertEqual(policy.loss_scale, None)
+
+    policy = mp_policy.Policy('mixed_bfloat16')
+    self.assertEqual(policy.loss_scale, None)
+
   @testing_utils.enable_v2_dtype_behavior
   def test_global_policy(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
@@ -121,6 +159,59 @@ class PolicyTest(test.TestCase):
     finally:
       mp_policy.set_policy(None)

+  @testing_utils.enable_v2_dtype_behavior
+  def test_loss_scale_warning(self):
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float32', loss_scale=2.)
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          'Creating a Policy with a loss scale is only useful for float16 '
+          'policies. You passed loss_scale=2.0 for policy float32. Consider '
+          'not passing any loss_scale instead.')
+
+    for policy_name in 'float16', 'mixed_float16':
+      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+        mp_policy.Policy(policy_name, loss_scale=2.)
+        mock_warn.assert_not_called()
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_float32_vars_warning(self):
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('infer_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'infer_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
+          "or 'mixed_bfloat16' policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float16_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'float16_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
+          "policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('bfloat16_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' "
+          "policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float64_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'float64_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1.")
+
+    for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16':
+      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+        mp_policy.Policy(policy_name)
+        mock_warn.assert_not_called()
+
   @testing_utils.enable_v2_dtype_behavior
   def test_policy_scope(self):
     if base_layer_utils.v2_dtype_behavior_enabled():


@@ -227,6 +227,9 @@ class FixedLossScale(LossScale):
     del grads
     return control_flow_ops.no_op(), True

+  def __repr__(self):
+    return 'FixedLossScale(%s)' % self._loss_scale_value
+
   def get_config(self):
     return {'loss_scale_value': self._loss_scale_value}

@@ -376,6 +379,17 @@ class DynamicLossScale(LossScale):
     should_apply_gradients = is_finite
     return update_op, should_apply_gradients

+  def __repr__(self):
+    if context.executing_eagerly():
+      return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, '
+              'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' %
+              (self._current_loss_scale.numpy(), self._num_good_steps.numpy(),
+               self.initial_loss_scale, self.increment_period, self.multiplier))
+    else:
+      return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, '
+              'multiplier=%s)' %
+              (self.initial_loss_scale, self.increment_period, self.multiplier))
+
   def get_config(self):
     return {
         'initial_loss_scale': self.initial_loss_scale,

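For context, a quick sketch (not part of this diff) of what the new `__repr__` methods print, assuming eager execution and the public `tf.train.experimental` aliases for these classes.

```python
import tensorflow as tf

print(repr(tf.train.experimental.FixedLossScale(123)))
# FixedLossScale(123.0)

print(repr(tf.train.experimental.DynamicLossScale(
    initial_loss_scale=1, increment_period=2, multiplier=3)))
# DynamicLossScale(current_loss_scale=1.0, num_good_steps=0,
#                  initial_loss_scale=1.0, increment_period=2, multiplier=3.0)
# (printed on a single line at runtime; wrapped here for readability)
```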

@@ -92,6 +92,11 @@ class FixedLossScaleTest(test.TestCase):
     scalar = loss_scale_module.FixedLossScale(123)
     self.assertIsInstance(scalar(), ops.Tensor)

+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self):
+    loss_scale = loss_scale_module.FixedLossScale(123)
+    self.assertEqual(repr(loss_scale), 'FixedLossScale(123.0)')
+

 def _get_example_iter(inputs):
   dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
@@ -302,5 +307,22 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
     scalar = loss_scale_module.DynamicLossScale()
     self.assertIsInstance(scalar(), ops.Tensor)

+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_repr(self, strategy_fn):
+    with strategy_fn().scope():
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=1, increment_period=2, multiplier=3)
+      if context.executing_eagerly():
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(current_loss_scale=1.0, '
+                         'num_good_steps=0, initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+      else:
+        self.assertEqual(repr(loss_scale),
+                         'DynamicLossScale(initial_loss_scale=1.0, '
+                         'increment_period=2, multiplier=3.0)')
+

 if __name__ == '__main__':
   test.main()


@@ -6,6 +6,10 @@ tf_class {
     name: "compute_dtype"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "loss_scale"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "name"
     mtype: "<type \'property\'>"
@@ -20,6 +24,6 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
   }
 }


@@ -6,6 +6,10 @@ tf_class {
     name: "compute_dtype"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "loss_scale"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "name"
     mtype: "<type \'property\'>"
@@ -20,6 +24,6 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], "
  }
}