Improve mixed precision docstrings.
The docstrings have been reworded to make them more clear and concise. Some clarifying information is also added. I removed some examples for uncommon use cases in order to shorten the docstrings, and rewrote or shortened other examples to make them easier and faster to read. All references to experimental mixed precision APIs have been changed to use the non-experimental APIs. Examples now use the newly added attribute Layer.dtype_policy as well. The section "How to use float64 in a Keras model" has been removed. Float64 can be enabled by setting floatx to float64 so I don't think its necessary to mention it in the policy section. Also, it's fairly obvious after reading the Policy docstring how to use float64 using policies: Just set the global policy to "float64". I intend on cherrypicking this change into TF 2.4. PiperOrigin-RevId: 339780420 Change-Id: I5f6ad44f54964114c398c306d2d7a4da39bb1c54
This commit is contained in:
parent
6fd2c09dec
commit
be5b64a639
@ -138,10 +138,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
|
||||
Attributes:
|
||||
name: The name of the layer (string).
|
||||
dtype: The dtype of the layer's computations and weights. If mixed
|
||||
precision is used with a `tf.keras.mixed_precision.Policy`, this is
|
||||
instead just the dtype of the layer's weights, as the computations are
|
||||
done in a different dtype.
|
||||
dtype: The dtype of the layer's weights.
|
||||
variable_dtype: Alias of `dtype`.
|
||||
compute_dtype: The dtype of the layer's computations. Layers automatically
|
||||
cast inputs to this dtype which causes the computations and output to also
|
||||
be in this dtype. When mixed precision is used with a
|
||||
`tf.keras.mixed_precision.Policy`, this will be different than
|
||||
`variable_dtype`.
|
||||
dtype_policy: The layer's dtype policy. See the
|
||||
`tf.keras.mixed_precision.Policy` documentation for details.
|
||||
trainable_weights: List of variables to be included in backprop.
|
||||
non_trainable_weights: List of variables that should not be
|
||||
included in backprop.
|
||||
@ -517,7 +522,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
Arguments:
|
||||
name: Variable name.
|
||||
shape: Variable shape. Defaults to scalar if unspecified.
|
||||
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
|
||||
dtype: The type of the variable. Defaults to `self.dtype`.
|
||||
initializer: Initializer instance (callable).
|
||||
regularizer: Regularizer instance (callable).
|
||||
trainable: Boolean, whether the variable should be part of the layer's
|
||||
@ -2373,6 +2378,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||
mixed precision is used, this is the same as `Layer.dtype`, the dtype of
|
||||
the weights.
|
||||
|
||||
Layers automatically cast their inputs to the compute dtype, which causes
|
||||
computations and the output to be in the compute dtype as well. This is done
|
||||
by the base Layer class in `Layer.__call__`, so you do not have to insert
|
||||
these casts if implementing your own layer.
|
||||
|
||||
Layers often perform certain internal computations in higher precision when
|
||||
`compute_dtype` is float16 or bfloat16 for numeric stability. The output
|
||||
will still typically be float16 or bfloat16 in such cases.
|
||||
|
@ -395,53 +395,35 @@ _DEFAULT_GROWTH_STEPS = 2000
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
@keras_export('keras.mixed_precision.LossScaleOptimizer')
|
||||
class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
||||
"""An optimizer that applies loss scaling.
|
||||
"""An optimizer that applies loss scaling to prevent numeric underflow.
|
||||
|
||||
Loss scaling is a process that multiplies the loss by a multiplier called the
|
||||
loss scale, and divides each gradient by the same multiplier. The pseudocode
|
||||
for this process is:
|
||||
|
||||
```
|
||||
loss = ...
|
||||
loss *= loss_scale
|
||||
grads = gradients(loss, vars)
|
||||
grads /= loss_scale
|
||||
```
|
||||
|
||||
Mathematically, loss scaling has no effect, but can help avoid numerical
|
||||
underflow in intermediate gradients when float16 tensors are used. By
|
||||
multiplying the loss, each intermediate gradient will have the same multiplier
|
||||
applied.
|
||||
|
||||
The loss scale can either be a fixed constant, chosen by the user, or be
|
||||
dynamically determined. Using a dynamic loss scale is highly recommend and is
|
||||
the default behavior, as choosing a specific fixed loss scale is difficult.
|
||||
Every step, the dynamic loss scale is potentially updated to a new value.
|
||||
Dynamic loss scaling sometimes causes the loss scale to be too high and cause
|
||||
the gradients to overflow, in which case gradients are not applied to
|
||||
variables that step.
|
||||
Loss scaling is a technique to prevent numeric underflow in intermediate
|
||||
gradients when float16 is used. To prevent underflow, the loss is multiplied
|
||||
(or "scaled") by a certain factor called the "loss scale", which causes
|
||||
intermediate gradients to be scaled by the loss scale as well. The final
|
||||
gradients are divided (or "unscaled") by the loss scale to bring them back to
|
||||
their original value.
|
||||
|
||||
`LossScaleOptimizer` wraps another optimizer and applies loss scaling to it.
|
||||
Loss scaling is applied whenever gradients are computed, either through
|
||||
`minimize()` or `get_gradients()`. If dynamic, the loss scale is updated
|
||||
whenever gradients are applied, either through `minimize()` or
|
||||
`apply_gradients()`. For example:
|
||||
By default, the loss scale is dynamically updated over time so you do not have
|
||||
to choose the loss scale. The `minimize` method automatically scales the loss,
|
||||
unscales the gradients, and updates the loss scale so all you have to do is
|
||||
wrap your optimizer with a `LossScaleOptimizer` if you use `minimize`. For
|
||||
example:
|
||||
|
||||
>>> opt = tf.keras.optimizers.SGD(0.25)
|
||||
>>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
|
||||
>>> var = tf.Variable(1.)
|
||||
>>> loss_fn = lambda: var ** 2
|
||||
>>> # 'minimize' applies loss scaling to the loss and updates the loss sale.
|
||||
>>> # 'minimize' applies loss scaling and updates the loss sale.
|
||||
>>> opt.minimize(loss_fn, var_list=var)
|
||||
>>> var.numpy()
|
||||
0.5
|
||||
|
||||
If a `tf.GradientTape` is used to compute gradients instead of
|
||||
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss
|
||||
and gradients must be scaled manually. This can be done by calling
|
||||
`LossScaleOptimizer.get_scaled_loss` before passing the loss to
|
||||
`tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after
|
||||
computing the gradients with `tf.GradientTape`. For example:
|
||||
If a `tf.GradientTape` is used to compute gradients instead of `minimize`, you
|
||||
must scale the loss and gradients manually. This can be done with the
|
||||
`LossScaleOptimizer.get_scaled_loss` and
|
||||
`LossScaleOptimizer.get_unscaled_gradients` methods. For example:
|
||||
|
||||
>>> with tf.GradientTape() as tape:
|
||||
... loss = loss_fn()
|
||||
@ -452,8 +434,18 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
||||
>>> var.numpy()
|
||||
0.25
|
||||
|
||||
Warning: If you forget to call `get_scaled_loss` or `get_unscaled_gradients`
|
||||
(or both) when using a `tf.GradientTape`, the model will likely converge to a
|
||||
worse quality. Please make sure you call each function exactly once.
|
||||
|
||||
When mixed precision with float16 is used, there is typically no risk of
|
||||
underflow affecting model quality if loss scaling is properly used. See
|
||||
[the mixed precision guide](
|
||||
https://www.tensorflow.org/guide/keras/mixed_precision) for more information
|
||||
on how to use mixed precision.
|
||||
|
||||
Args:
|
||||
inner_optimizer: The Optimizer instance to wrap.
|
||||
inner_optimizer: The `tf.keras.optimizers.Optimizer` instance to wrap.
|
||||
dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to
|
||||
True. If True, the loss scale will be dynamically updated over time using
|
||||
an algorithm that keeps the loss scale at approximately its optimal value.
|
||||
@ -463,11 +455,11 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
||||
performance overhead to dynamic loss scaling compared to fixed loss
|
||||
scaling.
|
||||
initial_scale: The initial loss scale. If `dynamic` is True, this defaults
|
||||
to 2 ** 15. If `dynamic` is False, this must be specified and acts as the
|
||||
sole loss scale, as the loss scale does not change over time. When dynamic
|
||||
loss scaling is used, is better for this to be a very high number, because
|
||||
a loss scale that is too high gets lowered far more quickly than a loss
|
||||
scale that is too low gets raised.
|
||||
to `2 ** 15`. If `dynamic` is False, this must be specified and acts as
|
||||
the sole loss scale, as the loss scale does not change over time. When
|
||||
dynamic loss scaling is used, is better for this to be a very high number,
|
||||
because a loss scale that is too high gets lowered far more quickly than a
|
||||
loss scale that is too low gets raised.
|
||||
dynamic_growth_steps: With dynamic loss scaling, every
|
||||
`dynamic_growth_steps` steps with finite gradients, the loss scale is
|
||||
doubled. Defaults to 2000. If a nonfinite gradient is encountered, the
|
||||
@ -476,27 +468,33 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
||||
`LossScaleOptimizer.dynamic_counter`. This argument can only be specified
|
||||
if `dynamic` is True.
|
||||
|
||||
To use a fixed loss scale instead of dynamic loss scale, pass `dynamic=False`
|
||||
and pass the loss scale to `initial_scale`. For example:
|
||||
`LossScaleOptimizer` will occasionally skip applying gradients to the
|
||||
variables, in which case the trainable variables will not change that step.
|
||||
This is done because the dynamic loss scale will sometimes be raised too
|
||||
high, causing overflow in the gradients. Typically, the first 2 to 15 steps of
|
||||
the model are skipped as the initial loss scale is very high, but afterwards
|
||||
steps will only be skipped on average 0.05% of the time (the fraction of steps
|
||||
skipped is `1 / dynamic_growth_steps`).
|
||||
|
||||
>>> opt = tf.keras.mixed_precision.LossScaleOptimizer(
|
||||
... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=1024)
|
||||
>>> opt.loss_scale.numpy()
|
||||
1024.
|
||||
`LossScaleOptimizer` delegates all public `Optimizer` methods to the inner
|
||||
optimizer. Additionally, in methods `minimize` and `get_gradients, it scales
|
||||
the loss and unscales the gradients. In methods `minimize` and
|
||||
`apply_gradients`, it additionally updates the loss scale and skips applying
|
||||
gradients if any gradient has a nonfinite value.
|
||||
|
||||
### Hyperparameters
|
||||
|
||||
Hyperparameters can be accessed and set on the LossScaleOptimizer, which will
|
||||
be delegated to the wrapped optimizer.
|
||||
|
||||
>>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5)
|
||||
>>> lso = tf.keras.mixed_precision.LossScaleOptimizer(opt)
|
||||
>>> opt.beta_1
|
||||
>>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)
|
||||
>>> opt.beta_1 # Equivalent to `opt.inner_optimizer.beta_1`
|
||||
0.8
|
||||
>>> lso.beta_1 # Equivalent to `opt.beta_1`
|
||||
0.8
|
||||
>>> lso.beta_1 = 0.7 # Equivalent to `opt.beta_1 = 0.7`
|
||||
>>> opt.beta_1 = 0.7 # Equivalent to `opt.inner_optimizer.beta_1 = 0.7`
|
||||
>>> opt.beta_1
|
||||
0.7
|
||||
>>> lso.beta_1
|
||||
>>> opt.inner_optimizer.beta_1
|
||||
0.7
|
||||
|
||||
However, accessing or setting non-hyperparameters is not delegated to the
|
||||
@ -504,19 +502,19 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
||||
`epsilon` is not, as the Adam optimizer only calls `Optimizer._set_hyper` on
|
||||
`beta_1`.
|
||||
|
||||
>>> opt.epsilon
|
||||
>>> opt.inner_optimizer.epsilon
|
||||
1e-5
|
||||
>>> lso.epsilon
|
||||
>>> opt.epsilon
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
AttributeError: 'LossScaleOptimizer' object has no attribute 'epsilon'
|
||||
>>> lso.epsilon = 1e-4
|
||||
>>> opt.epsilon
|
||||
>>> opt.epsilon = 1e-4 # This does NOT set epsilon on `opt.inner_optimizer`
|
||||
>>> opt.inner_optimizer.epsilon
|
||||
>>> 1e-5
|
||||
|
||||
In the above example, despite epsilon being set on the LossScaleOptimizer, the
|
||||
old epsilon value will still be used when training as epsilon was not set on
|
||||
the Adam optimizer.
|
||||
the inner optimizer.
|
||||
"""
|
||||
|
||||
_HAS_AGGREGATE_GRAD = True
|
||||
@ -562,6 +560,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
||||
|
||||
@property
|
||||
def dynamic(self):
|
||||
"""Bool indicating whether dynamic loss scaling is used."""
|
||||
return isinstance(self._loss_scale, _DynamicLossScaleState)
|
||||
|
||||
@property
|
||||
@ -593,7 +592,8 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2):
|
||||
def initial_scale(self):
|
||||
"""The initial loss scale.
|
||||
|
||||
This is None if `LossScaleOptimizer.dynamic` is False.
|
||||
If `LossScaleOptimizer.dynamic` is False, this is the same number as
|
||||
`LossScaleOptimizer.loss_scale`, as the loss scale never changes.
|
||||
"""
|
||||
if isinstance(self._loss_scale, _DynamicLossScaleState):
|
||||
return self._loss_scale.initial_loss_scale
|
||||
@ -982,6 +982,24 @@ class LossScaleOptimizerV1(LossScaleOptimizer):
|
||||
... dynamic_growth_steps=500)
|
||||
>>> assert opt1.get_config() == opt2.get_config()
|
||||
|
||||
Make sure to also switch from this class to the non-experimental class in
|
||||
isinstance checks, if you have any. If you do not do this, your model may run
|
||||
into hard-to-debug issues, as the experimental `LossScaleOptimizer` subclasses
|
||||
the non-experimental `LossScaleOptimizer`, but not vice versa. It is safe to
|
||||
switch isinstance checks to the non-experimental `LossScaleOptimizer` even
|
||||
before using the non-experimental `LossScaleOptimizer`.
|
||||
|
||||
>>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
|
||||
... tf.keras.optimizers.SGD(), loss_scale='dynamic')
|
||||
>>> # The experimental class subclasses the non-experimental class
|
||||
>>> isinstance(opt1, tf.keras.mixed_precision.LossScaleOptimizer)
|
||||
True
|
||||
>>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer(
|
||||
... tf.keras.optimizers.SGD())
|
||||
>>> # The non-experimental class does NOT subclass the experimental class.
|
||||
>>> isinstance(opt2, tf.keras.mixed_precision.experimental.LossScaleOptimizer)
|
||||
False
|
||||
|
||||
Args:
|
||||
optimizer: The Optimizer instance to wrap.
|
||||
loss_scale: The loss scale to scale the loss and gradients. This can
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.training.experimental import mixed_precision_global_state
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
@keras_export('keras.mixed_precision.Policy', v1=[])
|
||||
class Policy(object):
|
||||
"""A dtype policy for a Keras layer.
|
||||
@ -39,106 +40,57 @@ class Policy(object):
|
||||
A dtype policy determines a layer's computation and variable dtypes. Each
|
||||
layer has a policy. Policies can be passed to the `dtype` argument of layer
|
||||
constructors, or a global policy can be set with
|
||||
`tf.keras.mixed_precision.experimental.set_policy`. A layer will default to
|
||||
the global policy if no policy is passed to it's constructor.
|
||||
`tf.keras.mixed_precision.set_global_policy`.
|
||||
|
||||
For many models, each layer's policy will have the same compute dtype and
|
||||
variable dtype, which will typically be float32. In this case, we refer to the
|
||||
singular dtype as the layer's dtype, which can be queried by the property
|
||||
`tf.keras.layers.Layer.dtype`.
|
||||
Args:
|
||||
name: The policy name, which determines the compute and variable dtypes. Can
|
||||
be any dtype name, such as `'float32'` or `'float64'`, which causes both
|
||||
the compute and variable dtypes will be that dtype. Can also be the string
|
||||
`'mixed_float16'` or `'mixed_bfloat16'`, which causes the compute dtype to
|
||||
be float16 or bfloat16 and the variable dtype to be float32.
|
||||
|
||||
When mixed precision training is used, most layers will instead have a float16
|
||||
or bfloat16 compute dtype and a float32 variable dtype, and so the layer does
|
||||
not have a single dtype. When the variable dtype does not match the compute
|
||||
dtype, variables will be automatically casted to the compute dtype to avoid
|
||||
type errors. In this case, `tf.keras.layers.Layer.dtype` refers to the
|
||||
variable dtype, not the compute dtype. See [the mixed precision guide](
|
||||
https://www.tensorflow.org/guide/keras/mixed_precision) for more
|
||||
information on how to use mixed precision.
|
||||
Typically you only need to interact with dtype policies when using mixed
|
||||
precision, which is the use of float16 or bfloat16 for computations and
|
||||
float32 for variables. This is why the term `mixed_precision` appears in the
|
||||
API name. Mixed precision can be enabled by passing `'mixed_float16'` or
|
||||
`'mixed_bfloat16'` to `tf.keras.mixed_precision.set_global_policy`. See [the
|
||||
mixed precision guide](https://www.tensorflow.org/guide/keras/mixed_precision)
|
||||
for more information on how to use mixed precision.
|
||||
|
||||
Policies are constructed by passing a string to the constructor, e.g.
|
||||
`tf.keras.mixed_precision.Policy('float32')`. The string determines the
|
||||
compute and variable dtypes. It can be one of the following:
|
||||
>>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
|
||||
>>> layer1 = tf.keras.layers.Dense(10)
|
||||
>>> layer1.dtype_policy # `layer1` will automatically use mixed precision
|
||||
<Policy "mixed_float16">
|
||||
>>> # Can optionally override layer to use float32 instead of mixed precision.
|
||||
>>> layer2 = tf.keras.layers.Dense(10, dtype='float32')
|
||||
>>> layer2.dtype_policy
|
||||
<Policy "float32">
|
||||
>>> # Set policy back to initial float32 for future examples.
|
||||
>>> tf.keras.mixed_precision.set_global_policy('float32')
|
||||
|
||||
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
||||
compute dtypes will be that dtype.
|
||||
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
|
||||
bfloat16, while the variable dtype is float32. With 'mixed_float16',
|
||||
`tf.keras.Model.compile` will wrap the optimizer with a
|
||||
`tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used for
|
||||
mixed precision training.
|
||||
In the example above, passing `dtype='float32'` to the layer is equivalent to
|
||||
passing `dtype=tf.keras.mixed_precision.Policy('float32')`. In general,
|
||||
passing a dtype to a layer is equivalent to passing the corresponding policy,
|
||||
so it is never necessary to explicitly construct a `Policy` object.
|
||||
|
||||
### How to use mixed precision in a Keras model
|
||||
|
||||
To use mixed precision in a Keras model, the `'mixed_float16'` or
|
||||
`'mixed_bfloat16'` policy can be used.
|
||||
`tf.keras.mixed_precision.experimental.set_policy` can be used to set the
|
||||
default policy for layers if no policy is passed to them. For example:
|
||||
|
||||
>>> tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
|
||||
>>> model = tf.keras.models.Sequential([
|
||||
... tf.keras.layers.Input((100,)),
|
||||
... # Dense layers use global policy of 'mixed_float16', which does
|
||||
... # computations in float16 while keeping variables in float32.
|
||||
... tf.keras.layers.Dense(10),
|
||||
... tf.keras.layers.Dense(10),
|
||||
... # Softmax should be done in float32 for numeric stability. We pass
|
||||
... # dtype='float32' to use float32 instead of the global policy.
|
||||
... tf.keras.layers.Activation('softmax', dtype='float32')
|
||||
... ])
|
||||
|
||||
Alternatively, the policy can be passed to individual layers instead of
|
||||
setting the global policy with `set_policy`:
|
||||
|
||||
>>> policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
>>> model = tf.keras.models.Sequential([
|
||||
... tf.keras.layers.Input((100,)),
|
||||
... tf.keras.layers.Dense(10, dtype=policy),
|
||||
... tf.keras.layers.Dense(10, dtype=policy),
|
||||
... # Softmax should be done in float32 for numeric stability.
|
||||
... tf.keras.layers.Activation('softmax', dtype='float32')
|
||||
... ])
|
||||
|
||||
Note the `'mixed_float16'` policy will apply loss scaling by default in
|
||||
`Model.fit`, `Model.train_on_batch`, and other training methods. If no such
|
||||
method is used (e.g., a custom training loop is used) and `'mixed_float16'` is
|
||||
used, the loss scale must be manually applied. See
|
||||
`tf.keras.mixed_precision.LossScaleOptimizer` for details. For
|
||||
`'mixed_bfloat16'`, no loss scaling is done and loss scaling never needs to be
|
||||
manually applied.
|
||||
|
||||
See [the mixed precision guide](
|
||||
https://www.tensorflow.org/guide/keras/mixed_precision) for more
|
||||
information on using mixed precision
|
||||
|
||||
### How to use float64 in a Keras model
|
||||
|
||||
Using float64 is similar to mixed precision. Either the global policy can be
|
||||
set to float64, or `dtype='float64'` can be passed to individual layers. For
|
||||
example, to set the global policy:
|
||||
|
||||
>>> tf.keras.mixed_precision.experimental.set_policy('float64')
|
||||
>>> model = tf.keras.models.Sequential([
|
||||
... tf.keras.layers.Input((100,)),
|
||||
... # All layers use global policy of 'float64', which does computations
|
||||
... # and creates variables in float64.
|
||||
... tf.keras.layers.Dense(10),
|
||||
... tf.keras.layers.Dense(10),
|
||||
... tf.keras.layers.Activation('softmax')
|
||||
... ])
|
||||
>>> # Optionaly set policy back to float32 if any other models use float32
|
||||
>>> tf.keras.mixed_precision.experimental.set_policy('float32')
|
||||
Note: `Model.compile` will automatically wrap an optimizer with a
|
||||
`tf.keras.mixed_precision.LossScaleOptimizer` if you use the `'mixed_float16'`
|
||||
policy. If you use a custom training loop instead of calling `Model.compile`,
|
||||
you should explicitly use a `tf.keras.mixed_precision.LossScaleOptimizer` to
|
||||
avoid numeric underflow with float16.
|
||||
|
||||
### How a layer uses its policy's compute dtype
|
||||
|
||||
A layer will cast its inputs to its compute dtype in TensorFlow 2. For
|
||||
example:
|
||||
A layer casts its inputs to its compute dtype. This causes the layer's
|
||||
computations and output to also be in the compute dtype. For example:
|
||||
|
||||
>>> x = tf.ones((4, 4, 4, 4), dtype='float64')
|
||||
>>> # `layer`'s policy defaults to float32.
|
||||
>>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
|
||||
>>> # `layer` casts it's inputs to its compute dtype, which is float32, and
|
||||
>>> # does computations in float32.
|
||||
>>> layer.compute_dtype # Equivalent to layer.dtype_policy.compute_dtype
|
||||
'float32'
|
||||
>>> # `layer` casts it's inputs to its compute dtype and does computations in
|
||||
>>> # that dtype.
|
||||
>>> y = layer(x)
|
||||
>>> y.dtype
|
||||
tf.float32
|
||||
@ -147,7 +99,8 @@ class Policy(object):
|
||||
subclassing your own layer, you do not have to insert any casts.
|
||||
|
||||
Currently, only tensors in the first argument to the layer's `call` method are
|
||||
casted. For example:
|
||||
casted (although this will likely be changed in a future minor release). For
|
||||
example:
|
||||
|
||||
>>> class MyLayer(tf.keras.layers.Layer):
|
||||
... # Bug! `b` will not be casted.
|
||||
@ -162,45 +115,13 @@ class Policy(object):
|
||||
>>> y.dtype
|
||||
tf.float32
|
||||
|
||||
If writing your own layer, it is recommended to accept tensors only in the
|
||||
first argument. This way, all tensors are casted to the layer's compute dtype.
|
||||
`MyLayer` should therefore be written as:
|
||||
If writing your own layer with multiple inputs, you should either explicitly
|
||||
cast other tensors to `self.compute_dtype` in `call` or accept all tensors in
|
||||
the first argument as a list.
|
||||
|
||||
>>> class MyLayer(tf.keras.layers.Layer):
|
||||
... # Now, all tensor inputs will be casted.
|
||||
... def call(self, inputs):
|
||||
... a, b = inputs
|
||||
... return a + 1., b + 1.
|
||||
>>> a = tf.constant(1., dtype="float32")
|
||||
>>> b = tf.constant(1., dtype="float32")
|
||||
>>> layer = MyLayer(dtype="float64")
|
||||
>>> x, y = layer((a, b))
|
||||
>>> x.dtype
|
||||
tf.float64
|
||||
>>> y.dtype
|
||||
tf.float64
|
||||
|
||||
Other arguments are not automatically casted for technical reasons, but this
|
||||
may change in a future minor release.
|
||||
|
||||
The casting only occurs in TensorFlow 2, but can be enabled if
|
||||
`tf.compat.v1.disable_v2_behavior()` has been called with
|
||||
`tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`.
|
||||
|
||||
A layer subclass can prevent its inputs from being autocasted by passing
|
||||
`autocast=False` to the layer constructor. For example:
|
||||
|
||||
>>> class NonAutoCastingLayer(tf.keras.layers.Layer):
|
||||
... def __init__(self, **kwargs):
|
||||
... kwargs['autocast'] = False
|
||||
... super(NonAutoCastingLayer, self).__init__(**kwargs)
|
||||
... def call(self, inp):
|
||||
... return inp
|
||||
>>> x = tf.ones((4, 4, 4, 4), dtype='float32')
|
||||
>>> layer = NonAutoCastingLayer(dtype='float64')
|
||||
>>> y = layer(x) # Will not cast inputs to it's compute dtype of float64
|
||||
>>> y.dtype
|
||||
tf.float32
|
||||
The casting only occurs in TensorFlow 2. If
|
||||
`tf.compat.v1.disable_v2_behavior()` has been called, you can enable the
|
||||
casting behavior with `tf.compat.v1.keras.layers.enable_v2_dtype_behavior()`.
|
||||
|
||||
### How a layer uses its policy's variable dtype
|
||||
|
||||
@ -209,30 +130,33 @@ class Policy(object):
|
||||
|
||||
If a layer's compute and variable dtypes differ, `add_weight` will wrap
|
||||
floating-point variables with a special wrapper called an `AutoCastVariable`.
|
||||
This wrapper is identical to the original variable except it casts itself to
|
||||
the layer's compute dtype when used within `Layer.call`. Outside `Layer.call`,
|
||||
the variable is not casted.
|
||||
`AutoCastVariable` is identical to the original variable except it casts
|
||||
itself to the layer's compute dtype when used within `Layer.call`. This means
|
||||
if you are writing a layer, you do not have to explicitly cast the variables
|
||||
to the layer's compute dtype. For example:
|
||||
|
||||
>>> class SimpleDense(tf.keras.layers.Layer):
|
||||
...
|
||||
... def build(self, input_shape):
|
||||
... # With mixed precision, self.kernel is a float32 AutoCastVariable
|
||||
... self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
|
||||
...
|
||||
... def call(self, inputs):
|
||||
... # With mixed precision, self.kernel will be casted to float16
|
||||
... return tf.linalg.matmul(inputs, self.kernel)
|
||||
...
|
||||
>>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
>>> layer = SimpleDense(dtype=dtype_policy)
|
||||
>>> y = layer(tf.ones((10, 10)))
|
||||
>>> y.dtype
|
||||
tf.float16
|
||||
>>> layer.kernel.dtype
|
||||
tf.float32
|
||||
|
||||
A layer author can prevent a variable from being wrapped with an
|
||||
`AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`:
|
||||
|
||||
>>> class MyLayer(tf.keras.layers.Layer):
|
||||
... def build(self, input_shape):
|
||||
... self.x = self.add_weight('x')
|
||||
... self.y = self.add_weight('y', experimental_autocast=False)
|
||||
>>> policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
>>> layer = MyLayer(dtype=policy)
|
||||
>>> layer.build((2, 2))
|
||||
>>> layer.x
|
||||
<AutoCastVariable 'x:0' shape=() dtype=float32 dtype_to_cast_to=float32,
|
||||
numpy=...>
|
||||
>>> layer.y
|
||||
<tf.Variable 'y:0' shape=() dtype=float32, numpy=...>
|
||||
|
||||
Passing `experimental_autocast=False` is useful for layers which may
|
||||
internally do some math in the variable dtype instead of the compute dtype.
|
||||
For example, you may wish to compute variable statistics, such as mean and
|
||||
variance, in the variable dtype.
|
||||
`AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`,
|
||||
which is useful if the float32 value of the variable must be accessed within
|
||||
the layer.
|
||||
|
||||
### How to write a layer that supports mixed precision and float64.
|
||||
|
||||
@ -241,69 +165,33 @@ class Policy(object):
|
||||
automatically casts inputs, creates variables of the correct type, and in the
|
||||
case of mixed precision, wraps variables with `AutoCastVariables`.
|
||||
|
||||
For example, this simple dense layer does not require any additional work to
|
||||
support mixed precision or float64. Keras automatically casts the inputs and
|
||||
variable to the appropriate dtype.
|
||||
|
||||
>>> class MyDense(tf.keras.layers.Layer):
|
||||
... def build(self, input_shape):
|
||||
... self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
|
||||
... def call(self, inputs):
|
||||
... return tf.matmul(inputs, self.kernel)
|
||||
|
||||
>>> policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
>>> layer = MyDense(dtype=policy)
|
||||
>>> x = np.random.rand(10, 10)
|
||||
>>> y = layer(x)
|
||||
>>> y.dtype
|
||||
tf.float16
|
||||
|
||||
The primary case where you need extra work to support mixed precision or
|
||||
float64 is when you create a new tensor, such as with `tf.ones` or
|
||||
`tf.constant`. In such cases, you must create the tensor of the correct dtype.
|
||||
For example, suppose you modify the `MyDense` layer to add a random number to
|
||||
the output using `tf.random.normal`. You must pass the input dtype to
|
||||
`tf.random.normal` to ensure the dtypes match.
|
||||
`tf.random.normal`, In such cases, you must create the tensor of the correct
|
||||
dtype. For example, if you call `tf.random.normal`, you must pass the compute
|
||||
dtype, which is the dtype the inputs have been casted to:
|
||||
|
||||
>>> class MyDense(tf.keras.layers.Layer):
|
||||
... def build(self, input_shape):
|
||||
... self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
|
||||
>>> class AddRandom(tf.keras.layers.Layer):
|
||||
...
|
||||
... def call(self, inputs):
|
||||
... # We must pass `dtype=inputs.dtype`, otherwise a TypeError may
|
||||
... # occur when adding `inputs` to `rand`.
|
||||
... rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
|
||||
... return tf.matmul(inputs, self.kernel) + rand
|
||||
>>>
|
||||
>>> layer = MyDense(dtype=policy)
|
||||
... return inputs + rand
|
||||
|
||||
>>> dtype_policy = tf.keras.mixed_precision.Policy('mixed_float16')
|
||||
>>> layer = AddRandom(dtype=dtype_policy)
|
||||
>>> y = layer(x)
|
||||
>>> y.dtype
|
||||
tf.float16
|
||||
|
||||
If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a `TypeError`
|
||||
would have occurred. This is because the dtype defaults to `"float32"`, so the
|
||||
layer would only work if the inputs were float32.
|
||||
If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a
|
||||
`TypeError` would have occurred. This is because the `tf.random.normal`'s
|
||||
dtype defaults to `"float32"`, but the input dtype is float16. You cannot add
|
||||
a float32 tensor with a float16 tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Constructs the policy.
|
||||
|
||||
The `name` argument determines the compute and variable dtype. The compute
|
||||
and variable dtypes can only be specified through `name`, and cannot be
|
||||
specified directly.
|
||||
|
||||
`name` is also used by `tf.keras.Model.compile`. If `name` is
|
||||
`"mixed_float16"`, `tf.keras.Model.compile` will automatically wrap the
|
||||
optimizer with a LossScaleOptimizer if it is not already a
|
||||
LossScaleOptimizer.
|
||||
|
||||
Args:
|
||||
name: A string. Can be one of the following values:
|
||||
* Any dtype name, such as 'float32' or 'float64'. Both the variable and
|
||||
compute dtypes will be that dtype.
|
||||
* 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
|
||||
bfloat16, while the variable dtype is float32. With 'mixed_float16',
|
||||
`tf.keras.Model.compile` will wrap the optimizer with a
|
||||
`tf.keras.mixed_precision.LossScaleOptimizer. These policies are used
|
||||
for mixed precision training.
|
||||
"""
|
||||
if isinstance(name, dtypes.DType):
|
||||
raise TypeError("'name' must be a string, not a DType. "
|
||||
"Instead, pass DType.name. Got: %s" % (name.name,))
|
||||
@ -373,8 +261,10 @@ class Policy(object):
|
||||
`Policy.compute_dtype`, Layers will cast variables to the compute dtype to
|
||||
avoid type errors.
|
||||
|
||||
Variable regularizers are run in the variable dtype, not the compute dtype.
|
||||
|
||||
Returns:
|
||||
The variable dtype of this policy.
|
||||
The variable dtype of this policy, as a string
|
||||
"""
|
||||
return self._variable_dtype
|
||||
|
||||
@ -382,26 +272,27 @@ class Policy(object):
|
||||
def compute_dtype(self):
|
||||
"""The compute dtype of this policy.
|
||||
|
||||
This is the dtype layers will do their computations in.
|
||||
This is the dtype layers will do their computations in. Typically layers
|
||||
output tensors with the compute dtype as well.
|
||||
|
||||
Note that even if the compute dtype is float16 or bfloat16, hardware devices
|
||||
may not do individual adds, multiplies, and other fundamental operations in
|
||||
[b]float16, but instead may do some of them in float32 for numeric
|
||||
float16 or bfloat16, but instead may do some of them in float32 for numeric
|
||||
stability. The compute dtype is the dtype of the inputs and outputs of the
|
||||
TensorFlow ops that the layer executes. Internally, many TensorFlow ops will
|
||||
do certain internal calculations in float32, or some other device-internal
|
||||
intermediate format with higher precision than [b]float16, to increase
|
||||
do certain internal calculations in float32 or some other device-internal
|
||||
intermediate format with higher precision than float16/bfloat16, to increase
|
||||
numeric stability.
|
||||
|
||||
For example, a `tf.keras.layers.Dense` layer, when run on a GPU with a
|
||||
float16 compute dtype, will pass float16 inputs to tf.matmul. But, tf.matmul
|
||||
will do use float32 intermediate math. The performance benefit of float16 is
|
||||
still apparent, due to increased memory bandwidth and the fact modern GPUs
|
||||
have specialized hardware for computing matmuls on float16 while still
|
||||
keeping intermediate computations in float32.
|
||||
float16 compute dtype, will pass float16 inputs to `tf.linalg.matmul`. But,
|
||||
`tf.linalg.matmul` will do use float32 intermediate math. The performance
|
||||
benefit of float16 is still apparent, due to increased memory bandwidth and
|
||||
the fact modern GPUs have specialized hardware for computing matmuls on
|
||||
float16 inputs while still keeping intermediate computations in float32.
|
||||
|
||||
Returns:
|
||||
The compute dtype of this policy.
|
||||
The compute dtype of this policy, as a string.
|
||||
"""
|
||||
return self._compute_dtype
|
||||
|
||||
@ -529,13 +420,18 @@ _global_policy = None
|
||||
@keras_export('keras.mixed_precision.global_policy',
|
||||
'keras.mixed_precision.experimental.global_policy', v1=[])
|
||||
def global_policy():
|
||||
"""Returns the global Policy.
|
||||
"""Returns the global dtype policy.
|
||||
|
||||
The global policy is the default policy used for layers, if no policy is
|
||||
passed to the layer constructor. If no policy has been set with
|
||||
`keras.mixed_precision.experimental.set_policy`, this will return a policy
|
||||
The global policy is the default `tf.keras.mixed_precision.Policy` used for
|
||||
layers, if no policy is passed to the layer constructor. If no policy has been
|
||||
set with `keras.mixed_precision.set_global_policy`, this will return a policy
|
||||
constructed from `tf.keras.backend.floatx()` (floatx defaults to float32).
|
||||
|
||||
>>> tf.keras.mixed_precision.global_policy()
|
||||
<Policy "float32">
|
||||
>>> tf.keras.layers.Dense(10).dtype_policy # Defaults to the global policy
|
||||
<Policy "float32">
|
||||
|
||||
If TensorFlow 2 behavior has been disabled with
|
||||
`tf.compat.v1.disable_v2_behavior()`, this will instead return a special
|
||||
"_infer" policy which infers the dtype from the dtype of the first input the
|
||||
@ -574,11 +470,27 @@ def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
|
||||
@keras_export('keras.mixed_precision.set_global_policy',
|
||||
'keras.mixed_precision.experimental.set_policy', v1=[])
|
||||
def set_policy(policy):
|
||||
"""Sets the global Policy.
|
||||
"""Sets the global dtype policy.
|
||||
|
||||
The global policy is the default policy used for layers, if no policy is
|
||||
passed to the layer constructor. If no global policy is set, layers will
|
||||
instead default to a Policy constructed from `tf.keras.backend.floatx()`.
|
||||
The global policy is the default `tf.keras.mixed_precision.Policy` used for
|
||||
layers, if no policy is passed to the layer constructor.
|
||||
|
||||
>>> tf.keras.mixed_precision.set_global_policy('mixed_float16')
|
||||
>>> tf.keras.mixed_precision.global_policy()
|
||||
<Policy "mixed_float16">
|
||||
>>> tf.keras.layers.Dense(10).dtype_policy
|
||||
<Policy "mixed_float16">
|
||||
>>> # Global policy is not used if a policy is directly passed to constructor
|
||||
>>> tf.keras.layers.Dense(10, dtype='float64').dtype_policy
|
||||
<Policy "float64">
|
||||
>>> tf.keras.mixed_precision.set_global_policy('float32')
|
||||
|
||||
If no global policy is set, layers will instead default to a Policy
|
||||
constructed from `tf.keras.backend.floatx()`.
|
||||
|
||||
To use mixed precision, the global policy should be set to `'mixed_float16'`
|
||||
or `'mixed_bfloat16'`, so that every layer uses a 16-bit compute dtype and
|
||||
float32 variable dtype by default.
|
||||
|
||||
Only floating point policies can be set as the global policy, such as
|
||||
`'float32'` and `'mixed_float16'`. Non-floating point policies such as
|
||||
@ -588,7 +500,9 @@ def set_policy(policy):
|
||||
See `tf.keras.mixed_precision.Policy` for more information.
|
||||
|
||||
Args:
|
||||
policy: A Policy, or a string that will be converted to a Policy..
|
||||
policy: A Policy, or a string that will be converted to a Policy. Can also
|
||||
be None, in which case the global policy will be constructed from
|
||||
`tf.keras.backend.floatx()`
|
||||
"""
|
||||
global _global_policy
|
||||
if not base_layer_utils.v2_dtype_behavior_enabled():
|
||||
|
Loading…
Reference in New Issue
Block a user