Various mixed precision doc improvements.
The primary changes are: 1. Use the >>> format for code samples. But I couldn't change the code in mixed_precision.py to use the >>> format, since an error is thrown if both the Keras mixed precision API and the graph rewrite API are used in the same process. 2. In the enable_mixed_precision_graph_rewrite docstring, compare the graph rewrite API with the Keras mixed precision API. 3. In the error message that occurs if you try to use both the graph rewrite and the Keras MP API, recommend the Keras API instead of the graph rewrite. 4. Add new sections to Policy docstring: "How a layer users its policy's variable dtype" and "How to write a layer that supports mixed precision and float64". The latter section might be more appropriate for the mixed precision guide, but the guide is already very long and its hard to fit in. PiperOrigin-RevId: 298419101 Change-Id: I1936e1bcce401d02d4fdeb36b9866eb9cd40c421
This commit is contained in:
parent
27348ebd06
commit
f0756e1c19
tensorflow/python
@ -37,9 +37,8 @@ def epsilon():
|
||||
A float.
|
||||
|
||||
Example:
|
||||
```python
|
||||
keras.backend.epsilon() >>>1e-07
|
||||
```
|
||||
>>> tf.keras.backend.epsilon()
|
||||
1e-07
|
||||
"""
|
||||
return _EPSILON
|
||||
|
||||
@ -50,8 +49,14 @@ def set_epsilon(value):
|
||||
|
||||
Arguments:
|
||||
value: float. New value of epsilon.
|
||||
Example: ```python from keras import backend as K K.epsilon() >>> 1e-07
|
||||
K.set_epsilon(1e-05) K.epsilon() >>> 1e-05 ```
|
||||
|
||||
Example:
|
||||
>>> tf.keras.backend.epsilon()
|
||||
1e-07
|
||||
>>> tf.keras.backend.set_epsilon(1e-5)
|
||||
>>> tf.keras.backend.epsilon()
|
||||
1e-05
|
||||
>>> tf.keras.backend.set_epsilon(1e-7)
|
||||
"""
|
||||
global _EPSILON
|
||||
_EPSILON = value
|
||||
@ -61,15 +66,14 @@ def set_epsilon(value):
|
||||
def floatx():
|
||||
"""Returns the default float type, as a string.
|
||||
|
||||
E.g. 'float16', 'float32', 'float64'.
|
||||
E.g. `'float16'`, `'float32'`, `'float64'`.
|
||||
|
||||
Returns:
|
||||
String, the current default float type.
|
||||
|
||||
Example:
|
||||
```python
|
||||
keras.backend.floatx() >>> 'float32'
|
||||
```
|
||||
>>> tf.keras.backend.floatx()
|
||||
'float32'
|
||||
"""
|
||||
return _FLOATX
|
||||
|
||||
@ -78,10 +82,23 @@ def floatx():
|
||||
def set_floatx(value):
|
||||
"""Sets the default float type.
|
||||
|
||||
Note: It is not recommended to set this to float16 for training, as this will
|
||||
likely cause numeric stability issues. Instead, mixed precision, which is
|
||||
using a mix of float16 and float32, can be used by calling
|
||||
`tf.keras.mixed_precision.experimental.set_policy('mixed_float16')`. See the
|
||||
[mixed precision
|
||||
guide](https://www.tensorflow.org/guide/keras/mixed_precision) for details.
|
||||
|
||||
Arguments:
|
||||
value: String; 'float16', 'float32', or 'float64'.
|
||||
Example: ```python from keras import backend as K K.floatx() >>> 'float32'
|
||||
K.set_floatx('float16') K.floatx() >>> 'float16' ```
|
||||
value: String; `'float16'`, `'float32'`, or `'float64'`.
|
||||
|
||||
Example:
|
||||
>>> tf.keras.backend.floatx()
|
||||
'float32'
|
||||
>>> tf.keras.backend.set_floatx('float64')
|
||||
>>> tf.keras.backend.floatx()
|
||||
'float64'
|
||||
>>> tf.keras.backend.set_floatx('float32')
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid value.
|
||||
@ -100,9 +117,8 @@ def image_data_format():
|
||||
A string, either `'channels_first'` or `'channels_last'`
|
||||
|
||||
Example:
|
||||
```python
|
||||
keras.backend.image_data_format() >>> 'channels_first'
|
||||
```
|
||||
>>> tf.keras.backend.image_data_format()
|
||||
'channels_last'
|
||||
"""
|
||||
return _IMAGE_DATA_FORMAT
|
||||
|
||||
@ -113,9 +129,14 @@ def set_image_data_format(data_format):
|
||||
|
||||
Arguments:
|
||||
data_format: string. `'channels_first'` or `'channels_last'`.
|
||||
Example: ```python from keras import backend as K K.image_data_format() >>>
|
||||
'channels_first' K.set_image_data_format('channels_last')
|
||||
K.image_data_format() >>> 'channels_last' ```
|
||||
|
||||
Example:
|
||||
>>> tf.keras.backend.image_data_format()
|
||||
'channels_last'
|
||||
>>> tf.keras.backend.set_image_data_format('channels_first')
|
||||
>>> tf.keras.backend.image_data_format()
|
||||
'channels_first'
|
||||
>>> tf.keras.backend.set_image_data_format('channels_last')
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid `data_format` value.
|
||||
|
24
tensorflow/python/keras/mixed_precision/__init__.py
Normal file
24
tensorflow/python/keras/mixed_precision/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Keras mixed precision API.
|
||||
|
||||
See [the mixed precision
|
||||
guide](https://www.tensorflow.org/guide/keras/mixed_precision) to learn how to
|
||||
use the API.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
@ -28,21 +28,23 @@ from tensorflow.python.ops import variables
|
||||
class AutoCastVariable(variables.Variable):
|
||||
"""Variable that will cast itself to a different dtype in applicable contexts.
|
||||
|
||||
This class wraps a floating-point tf.Variable. It emulates the variable
|
||||
This class wraps a floating-point `tf.Variable`. It emulates the variable
|
||||
interface and delegates to the wrapped variable, but it additionally will cast
|
||||
the wrapped variable under a `Graph._enable_variable_auto_cast(dtype)` context
|
||||
manager.
|
||||
the wrapped variable under a `Graph._enable_auto_casting_variables(dtype)`
|
||||
context manager.
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
v = tf.Variable(1.0, dtype=tf.float32)
|
||||
v = AutoCastVariable(v)
|
||||
print(tf.identity(v).dtype) # tf.float32
|
||||
with ops.get_default_graph()._enable_variable_auto_cast(tf.float16):
|
||||
print(tf.identity(v).dtype) # tf.float16, as v will cast itself to float16
|
||||
print(v.dtype) # tf.float16, as v.dtype also changes under the ctx manager.
|
||||
```
|
||||
>>> v = tf.Variable(1.0, dtype=tf.float32)
|
||||
>>> v = AutoCastVariable(v)
|
||||
>>> tf.identity(v).dtype
|
||||
tf.float32
|
||||
>>> with ops.get_default_graph()._enable_auto_casting_variables(tf.float16):
|
||||
... tf.identity(v).dtype
|
||||
tf.float16
|
||||
>>> with ops.get_default_graph()._enable_auto_casting_variables(tf.float16):
|
||||
... v.dtype # v.dtype also changes under the context manager
|
||||
tf.float16
|
||||
|
||||
The purpose of this class is to allow Keras layers to create variables in
|
||||
float32, and automatically cast them to float16 or bfloat16 when the layer is
|
||||
|
@ -76,12 +76,15 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
|
||||
updated via `LossScale.update()` whenever gradients are applied, either
|
||||
through `minimize()` or `apply_gradients()`. For example:
|
||||
|
||||
```python
|
||||
opt = tf.keras.optimizers.SGD(0.1)
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
|
||||
# 'minimize' applies loss scaling to the loss and updates the loss sale.
|
||||
opt.minimize(loss_fn)
|
||||
```
|
||||
>>> opt = tf.keras.optimizers.SGD(0.25)
|
||||
>>> opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt,
|
||||
... "dynamic")
|
||||
>>> var = tf.Variable(1.)
|
||||
>>> loss_fn = lambda: var ** 2
|
||||
>>> # 'minimize' applies loss scaling to the loss and updates the loss sale.
|
||||
>>> opt.minimize(loss_fn, var_list=var)
|
||||
>>> var.numpy()
|
||||
0.5
|
||||
|
||||
If a `tf.GradientTape` is used to compute gradients instead of
|
||||
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss
|
||||
@ -90,16 +93,14 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
|
||||
`tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after
|
||||
computing the gradients with `tf.GradientTape`. For example:
|
||||
|
||||
```python
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
|
||||
vars = ...
|
||||
with tf.GradientTape() as tape:
|
||||
loss = ...
|
||||
scaled_loss = opt.get_scaled_loss(loss)
|
||||
scaled_grads = tape.gradient(scaled_loss, vars)
|
||||
grads = opt.get_unscaled_gradients(scaled_grads)
|
||||
opt.apply_gradients(zip(grads, vars)) # Loss scale will be updated here
|
||||
```
|
||||
>>> with tf.GradientTape() as tape:
|
||||
... loss = loss_fn()
|
||||
... scaled_loss = opt.get_scaled_loss(loss)
|
||||
>>> scaled_grad = tape.gradient(scaled_loss, var)
|
||||
>>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
|
||||
>>> opt.apply_gradients([(grad, var)]) # Loss scale is updated here
|
||||
>>> var.numpy()
|
||||
0.25
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, loss_scale):
|
||||
|
@ -41,6 +41,10 @@ from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
|
||||
from tensorflow.python.training.tracking import util as trackable_utils
|
||||
|
||||
# Disable not-callable lint error, as the linter is unable to detect that
|
||||
# LossScale instances are callable.
|
||||
# pylint: disable=not-callable
|
||||
|
||||
|
||||
# If called outside any strategy.scope() calls, this will return the default
|
||||
# strategy.
|
||||
|
@ -54,13 +54,12 @@ class Policy(object):
|
||||
|
||||
When mixed precision training is used, most layers will instead have a float16
|
||||
or bfloat16 compute dtype and a float32 variable dtype, and so the layer does
|
||||
not have a single dtype. See [this
|
||||
link](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html)
|
||||
for more information on mixed precision training. When the variable dtype does
|
||||
not match the compute dtype, variables will be automatically casted to the
|
||||
compute dtype to avoid type errors. In this case,
|
||||
`tf.keras.layers.Layer.dtype` refers to the variable dtype, not the compute
|
||||
dtype.
|
||||
not have a single dtype. When the variable dtype does not match the compute
|
||||
dtype, variables will be automatically casted to the compute dtype to avoid
|
||||
type errors. In this case, `tf.keras.layers.Layer.dtype` refers to the
|
||||
variable dtype, not the compute dtype. See [the mixed precision
|
||||
guide](https://www.tensorflow.org/guide/keras/mixed_precision) for more
|
||||
information on how to use mixed precision.
|
||||
|
||||
Certain policies also have a `tf.mixed_precision.experimental.LossScale`
|
||||
instance, which is used by `tf.keras.Model`s to performance loss scaling. Loss
|
||||
@ -88,37 +87,29 @@ class Policy(object):
|
||||
`tf.keras.mixed_precision.experimental.set_policy` can be used to set the
|
||||
default policy for layers if no policy is passed to them. For example:
|
||||
|
||||
```python
|
||||
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Input((100,)),
|
||||
# Dense layers use global policy of 'mixed_float16', which does
|
||||
# computations in float16 while keeping variables in float32.
|
||||
tf.keras.layers.Dense(10),
|
||||
tf.keras.layers.Dense(10),
|
||||
# Softmax should be done in float32 for numeric stability. We pass
|
||||
# dtype='float32' to use float32 instead of the global policy.
|
||||
tf.keras.layers.Activation('softmax', dtype='float32')
|
||||
])
|
||||
model.compile(...)
|
||||
model.fit(...) # Train `model`
|
||||
```
|
||||
>>> tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
|
||||
>>> model = tf.keras.models.Sequential([
|
||||
... tf.keras.layers.Input((100,)),
|
||||
... # Dense layers use global policy of 'mixed_float16', which does
|
||||
... # computations in float16 while keeping variables in float32.
|
||||
... tf.keras.layers.Dense(10),
|
||||
... tf.keras.layers.Dense(10),
|
||||
... # Softmax should be done in float32 for numeric stability. We pass
|
||||
... # dtype='float32' to use float32 instead of the global policy.
|
||||
... tf.keras.layers.Activation('softmax', dtype='float32')
|
||||
... ])
|
||||
|
||||
Alternatively, the policy can be passed to individual layers instead of
|
||||
setting the global policy with `set_policy`:
|
||||
|
||||
```python
|
||||
policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Input((100,)),
|
||||
tf.keras.layers.Dense(10, dtype=policy),
|
||||
tf.keras.layers.Dense(10, dtype=policy),
|
||||
# Softmax should be done in float32 for numeric stability.
|
||||
tf.keras.layers.Activation('softmax', dtype='float32')
|
||||
])
|
||||
model.compile(...)
|
||||
model.fit(...) # Train `model`
|
||||
```
|
||||
>>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
|
||||
>>> model = tf.keras.models.Sequential([
|
||||
... tf.keras.layers.Input((100,)),
|
||||
... tf.keras.layers.Dense(10, dtype=policy),
|
||||
... tf.keras.layers.Dense(10, dtype=policy),
|
||||
... # Softmax should be done in float32 for numeric stability.
|
||||
... tf.keras.layers.Activation('softmax', dtype='float32')
|
||||
... ])
|
||||
|
||||
Note the `'mixed_float16'` policy will apply loss scaling by default in
|
||||
`Model.fit`, `Model.train_on_batch`, and other training methods. If no such
|
||||
@ -128,79 +119,78 @@ class Policy(object):
|
||||
`'mixed_bfloat16'`, no loss scaling is done and loss scaling never needs to be
|
||||
manually applied.
|
||||
|
||||
See [the mixed precision
|
||||
guide](https://www.tensorflow.org/guide/keras/mixed_precision) for more
|
||||
information on using mixed precision
|
||||
|
||||
### How to use float64 in a Keras model
|
||||
|
||||
Using float64 is similar to mixed precision. Either the global policy can be
|
||||
set to float64, or `dtype='float64'` can be passed to individual layers. For
|
||||
example, to set the global policy:
|
||||
|
||||
```python
|
||||
tf.keras.mixed_precision.experimental.set_policy('float64')
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Input((100,)),
|
||||
# All layers use global policy of 'float64', which does computations and
|
||||
# creates variables in float64.
|
||||
tf.keras.layers.Dense(10),
|
||||
tf.keras.layers.Dense(10),
|
||||
tf.keras.layers.Activation('softmax')
|
||||
])
|
||||
model.compile(...)
|
||||
model.fit(...) # Train `model`
|
||||
```
|
||||
>>> tf.keras.mixed_precision.experimental.set_policy('float64')
|
||||
>>> model = tf.keras.models.Sequential([
|
||||
... tf.keras.layers.Input((100,)),
|
||||
... # All layers use global policy of 'float64', which does computations
|
||||
... # and creates variables in float64.
|
||||
... tf.keras.layers.Dense(10),
|
||||
... tf.keras.layers.Dense(10),
|
||||
... tf.keras.layers.Activation('softmax')
|
||||
... ])
|
||||
>>> # Optionaly set policy back to float32 if any other models use float32
|
||||
>>> tf.keras.mixed_precision.experimental.set_policy('float32')
|
||||
|
||||
### How a layer uses its policy's compute dtype
|
||||
|
||||
A layer will cast its inputs to its compute dtype in TensorFlow 2. For
|
||||
example:
|
||||
|
||||
```python
|
||||
x = tf.ones((4, 4, 4, 4), dtype='float64')
|
||||
# `layer`'s policy defaults to float32.
|
||||
layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
|
||||
>>> x = tf.ones((4, 4, 4, 4), dtype='float64')
|
||||
>>> # `layer`'s policy defaults to float32.
|
||||
>>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
|
||||
>>> # `layer` casts it's inputs to its compute dtype, which is float32, and
|
||||
>>> # does computations in float32.
|
||||
>>> y = layer(x)
|
||||
>>> y.dtype
|
||||
tf.float32
|
||||
|
||||
# `layer` casts it's inputs to its compute dtype, which is float32, and does
|
||||
# computations in float32.
|
||||
y = layer(x)
|
||||
print(y.dtype) # float32
|
||||
```
|
||||
Note that the base `tf.keras.layers.Layer` class inserts the casts. If
|
||||
subclassing your own layer, you do not have to insert any casts.
|
||||
|
||||
Currently, only tensors in the first argument to the layer's `call` method are
|
||||
casted. For example:
|
||||
|
||||
```python
|
||||
class MyLayer(tf.keras.layers.Layer):
|
||||
# Bug! `b` will not be casted.
|
||||
def call(self, a, b):
|
||||
return a + 1., b + 1.
|
||||
>>> class MyLayer(tf.keras.layers.Layer):
|
||||
... # Bug! `b` will not be casted.
|
||||
... def call(self, a, b):
|
||||
... return a + 1., b + 1.
|
||||
>>> a = tf.constant(1., dtype="float32")
|
||||
>>> b = tf.constant(1., dtype="float32")
|
||||
>>> layer = MyLayer(dtype="float64")
|
||||
>>> x, y = layer(a, b)
|
||||
>>> x.dtype
|
||||
tf.float64
|
||||
>>> y.dtype
|
||||
tf.float32
|
||||
|
||||
a = tf.constant(1., dtype="float32")
|
||||
b = tf.constant(1., dtype="float32")
|
||||
If writing your own layer, it is recommended to accept tensors only in the
|
||||
first argument. This way, all tensors are casted to the layer's compute dtype.
|
||||
`MyLayer` should therefore be written as:
|
||||
|
||||
layer = MyLayer(dtype="float64")
|
||||
x, y = layer(a, b)
|
||||
print(x.dtype) # float64
|
||||
print(y.dtype) # float32. Not casted since `b` was not passed to first input
|
||||
```
|
||||
|
||||
It is recommended to accept tensors only in the first argument. This way, all
|
||||
tensors are casted to the layer's compute dtype. `MyLayer` should therefore be
|
||||
written as:
|
||||
|
||||
```python
|
||||
class MyLayer(tf.keras.layers.Layer):
|
||||
# Now, all tensor inputs will be casted.
|
||||
def call(self, inputs):
|
||||
a, b = inputs
|
||||
return a + 1., b + 1.
|
||||
|
||||
a = tf.constant(1., dtype="float32")
|
||||
b = tf.constant(1., dtype="float32")
|
||||
|
||||
layer = MyLayer(dtype="float64")
|
||||
x, y = layer((a, b))
|
||||
print(x.dtype) # float64
|
||||
print(y.dtype) # float64.
|
||||
```
|
||||
>>> class MyLayer(tf.keras.layers.Layer):
|
||||
... # Now, all tensor inputs will be casted.
|
||||
... def call(self, inputs):
|
||||
... a, b = inputs
|
||||
... return a + 1., b + 1.
|
||||
>>> a = tf.constant(1., dtype="float32")
|
||||
>>> b = tf.constant(1., dtype="float32")
|
||||
>>> layer = MyLayer(dtype="float64")
|
||||
>>> x, y = layer((a, b))
|
||||
>>> x.dtype
|
||||
tf.float64
|
||||
>>> y.dtype
|
||||
tf.float64
|
||||
|
||||
Other arguments are not automatically casted for technical reasons, but this
|
||||
may change in a future minor release.
|
||||
@ -208,21 +198,95 @@ class Policy(object):
|
||||
A layer subclass can prevent its inputs from being autocasted by passing
|
||||
`autocast=False` to the layer constructor. For example:
|
||||
|
||||
```python
|
||||
class NonAutoCastingLayer(tf.keras.layers.Layer):
|
||||
>>> class NonAutoCastingLayer(tf.keras.layers.Layer):
|
||||
... def __init__(self, **kwargs):
|
||||
... kwargs['autocast'] = False
|
||||
... super(NonAutoCastingLayer, self).__init__(**kwargs)
|
||||
... def call(self, inp):
|
||||
... return inp
|
||||
>>> x = tf.ones((4, 4, 4, 4), dtype='float32')
|
||||
>>> layer = NonAutoCastingLayer(dtype='float64')
|
||||
>>> y = layer(x) # Will not cast inputs to it's compute dtype of float64
|
||||
>>> y.dtype
|
||||
tf.float32
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs['autocast'] = False
|
||||
super(NonAutoCastingLayer, self).__init__(**kwargs)
|
||||
### How a layer uses its policy's variable dtype
|
||||
|
||||
def call(self, inp):
|
||||
return inp
|
||||
The default dtype of variables created by `tf.keras.layers.Layer.add_weight`
|
||||
is the layer's policy's variable dtype.
|
||||
|
||||
x = tf.ones((4, 4, 4, 4), dtype='float32')
|
||||
layer = NonAutoCastingLayer(dtype='float64')
|
||||
y = layer(x) # MyLayer will not cast inputs to it's compute dtype of float32
|
||||
print(y.dtype) # float32
|
||||
```
|
||||
If a layer's compute and variable dtypes differ, `add_weight` will wrap
|
||||
floating-point variables with a special wrapper called an `AutoCastVariable`.
|
||||
This wrapper is identical to the original variable except it casts itself to
|
||||
the layer's compute dtype when used within `Layer.call`. Outside `Layer.call`,
|
||||
the variable is not casted.
|
||||
|
||||
A layer author can prevent a variable from being wrapped with an
|
||||
`AutoCastVariable` by passing `experimental_autocast=False` to `add_weight`:
|
||||
|
||||
>>> class MyLayer(tf.keras.layers.Layer):
|
||||
... def build(self, input_shape):
|
||||
... self.x = self.add_weight('x')
|
||||
... self.y = self.add_weight('y', experimental_autocast=False)
|
||||
>>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
|
||||
>>> layer = MyLayer(dtype=policy)
|
||||
>>> layer.build((2, 2))
|
||||
>>> layer.x
|
||||
<AutoCastVariable 'x:0' shape=() dtype=float32 true_dtype=float32, numpy=...>
|
||||
>>> layer.y
|
||||
<tf.Variable 'y:0' shape=() dtype=float32, numpy=...>
|
||||
|
||||
Passing `experimental_autocast=False` is useful for layers which may
|
||||
internally do some math in the variable dtype instead of the compute dtype.
|
||||
For example, you may wish to compute variable statistics, such as mean and
|
||||
variance, in the variable dtype.
|
||||
|
||||
### How to write a layer that supports mixed precision and float64.
|
||||
|
||||
For the most part, layers will automatically support mixed precision and
|
||||
float64 without any additional work, due to the fact the base layer
|
||||
automatically casts inputs, creates variables of the correct type, and in the
|
||||
case of mixed precision, wraps variables with `AutoCastVariables`.
|
||||
|
||||
For example, this simple dense layer does not require any additional work to
|
||||
support mixed precision or float64. Keras automatically casts the inputs and
|
||||
variable to the appropriate dtype.
|
||||
|
||||
>>> class MyDense(tf.keras.layers.Layer):
|
||||
... def build(self, input_shape):
|
||||
... self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
|
||||
... def call(self, inputs):
|
||||
... return tf.matmul(inputs, self.kernel)
|
||||
|
||||
>>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
|
||||
>>> layer = MyDense(dtype=policy)
|
||||
>>> x = np.random.rand(10, 10)
|
||||
>>> y = layer(x)
|
||||
>>> y.dtype
|
||||
tf.float16
|
||||
|
||||
The primary case where you need extra work to support mixed precision or
|
||||
float64 is when you create a new tensor, such as with `tf.ones` or
|
||||
`tf.constant`. In such cases, you must create the tensor of the correct dtype.
|
||||
For example, suppose you modify the `MyDense` layer to add a random number to
|
||||
the output using `tf.random.normal`. You must pass the input dtype to
|
||||
`tf.random.normal` to ensure the dtypes match.
|
||||
|
||||
>>> class MyDense(tf.keras.layers.Layer):
|
||||
... def build(self, input_shape):
|
||||
... self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
|
||||
... def call(self, inputs):
|
||||
... rand = tf.random.normal(shape=inputs.shape, dtype=inputs.dtype)
|
||||
... return tf.matmul(inputs, self.kernel) + rand
|
||||
>>>
|
||||
>>> layer = MyDense(dtype=policy)
|
||||
>>> y = layer(x)
|
||||
>>> y.dtype
|
||||
tf.float16
|
||||
|
||||
If you did not pass `dtype=inputs.dtype` to `tf.random.normal`, a `TypeError`
|
||||
would have occurred. This is because the dtype defaults to `"float32"`, so the
|
||||
layer would only work if the inputs were float32.
|
||||
|
||||
### The deprecated "infer" policy
|
||||
|
||||
@ -234,8 +298,6 @@ class Policy(object):
|
||||
|
||||
In TensorFlow 1, only the "infer" policy is available.
|
||||
"""
|
||||
# TODO(reedwm): Replace link in above docstring with a version that is more
|
||||
# TensorFlow-specific, and that also mentions bfloat16.
|
||||
|
||||
def __init__(self, name, loss_scale=USE_DEFAULT):
|
||||
"""Constructs the policy.
|
||||
@ -463,10 +525,9 @@ def _check_if_mixed_precision_graph_rewrite_is_enabled():
|
||||
' 2. tf.keras.mixed_precision.experimental.set_policy() (You called '
|
||||
'this second)\n\n'
|
||||
'You called both functions, which is an error, because both functions '
|
||||
'enable you to use mixed precision. The first function enables mixed '
|
||||
'precision in the graph with a graph rewrite. However it is currently '
|
||||
'not very customizable, and does not support eager. The second '
|
||||
'function is for Keras layers, but is not yet fully complete.')
|
||||
'enable you to use mixed precision. If in doubt which function to use, '
|
||||
'use the second, as it supports Eager execution and is more '
|
||||
'customizable.')
|
||||
|
||||
|
||||
@keras_export('keras.mixed_precision.experimental.set_policy')
|
||||
|
@ -41,7 +41,16 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
@tf_export('mixed_precision.experimental.LossScale',
|
||||
'train.experimental.LossScale')
|
||||
class LossScale(trackable.Trackable):
|
||||
"""Loss scale base class.
|
||||
"""Base class for all loss scales.
|
||||
|
||||
This is an abstract base class, so you cannot instantiate it directly.
|
||||
Instead, use one of its concrete subclasses:
|
||||
* `tf.mixed_precision.experimental.DynamicLossScale` (recommended)
|
||||
* `tf.mixed_precision.experimental.FixedLossScale`
|
||||
|
||||
It's recommended to use a loss scale with a
|
||||
`tf.keras.mixed_precision.experimental.LossScaleOptimizer`, as its easier than
|
||||
using a loss scale directly.
|
||||
|
||||
Loss scaling is a process that multiplies the loss by a multiplier called the
|
||||
loss scale, and divides each gradient by the same multiplier. The pseudocode
|
||||
@ -63,6 +72,10 @@ class LossScale(trackable.Trackable):
|
||||
class returns the loss scale as a scalar float32 tensor, while method
|
||||
`update()` updates the loss scale depending on the values of the gradients.
|
||||
Optimizers use instances of this class to scale loss and gradients.
|
||||
|
||||
In most functions that accept a LossScale, you can also pass an int (such as
|
||||
8) to create a `FixedLossScale` or the string `"dynamic"` to create a dynamic
|
||||
loss scale.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
@ -84,7 +84,7 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
|
||||
operation and a loss-scale optimizer.
|
||||
|
||||
Performing arithmetic operations in float16 takes advantage of specialized
|
||||
processing units, such as NVIDIA Tensor Cores for much higher arithmetic
|
||||
processing units, such as NVIDIA Tensor Cores, for much higher arithmetic
|
||||
throughput. However, due to the smaller representable range, performing the
|
||||
entire training with float16 can result in gradient underflow, that is, small
|
||||
gradient values becoming zeroes. Instead, performing only select arithmetic
|
||||
@ -105,39 +105,27 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
|
||||
|
||||
```python
|
||||
model = tf.keras.models.Sequential([
|
||||
...
|
||||
tf.keras.layers.Dense(64, activation='relu'),
|
||||
tf.keras.layers.Dense(64, activation='softmax'),
|
||||
])
|
||||
|
||||
opt = tf.keras.optimizers.SGD()
|
||||
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
|
||||
model.compile(loss="mse", optimizer=opt)
|
||||
|
||||
model.compile(loss="categorical_crossentropy",
|
||||
optimizer=opt,
|
||||
metrics=["accuracy"])
|
||||
|
||||
model.fit(x_train, y_train,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs)
|
||||
x_train = np.random.random((1024, 64))
|
||||
y_train = np.random.random((1024, 64))
|
||||
model.fit(x_train, y_train)
|
||||
```
|
||||
|
||||
For a complete example showing the speed-up on training an image
|
||||
classification task on CIFAR10, check out this
|
||||
<a href="https://colab.research.google.com/github/NVIDIA/
|
||||
DeepLearningExamples/blob/master/TensorFlow/docs/amp/notebook_v1.14/
|
||||
auto_mixed_precision_demo_cifar10.ipynb">Colab notebook</a>.
|
||||
|
||||
Calling `enable_mixed_precision_graph_rewrite(opt)` enables the graph rewrite
|
||||
operation before computing gradients. The function additionally returns an
|
||||
`Optimizer`(`opt`) wrapped with a `LossScaleOptimizer`. This prevents
|
||||
`Optimizer` (`opt`) wrapped with a `LossScaleOptimizer`. This prevents
|
||||
underflow in the float16 tensors during the backward pass. An optimizer of
|
||||
type `tf.train.Optimizer` or `tf.keras.optimizers.Optimizer` must be passed
|
||||
to this function, which will then be wrapped to use loss scaling.
|
||||
type `tf.keras.optimizers.Optimizer` or `tf.compat.v1.train.Optimizer` must be
|
||||
passed to this function, which will then be wrapped to use loss scaling.
|
||||
|
||||
<img src="
|
||||
http://developer.download.nvidia.com/compute/machine-learning/frameworks/
|
||||
TF_mixed_precision_training.png" width="500px">
|
||||
|
||||
The graph rewrite operation changes the `dtype` of certain operations in the
|
||||
The graph rewrite operation changes the dtype of certain operations in the
|
||||
graph from float32 to float16. There are several categories of operations
|
||||
that are either included or excluded by this rewrite operation. The following
|
||||
categories of Ops are defined inside corresponding functions under the class
|
||||
@ -155,17 +143,19 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
|
||||
* `GrayList`: Ops that are considered numerically safe for execution in
|
||||
float16 unless downstream from a BlackList Op. E.g. `Add` and `AvgPool`.
|
||||
|
||||
When this function is used, gradients should only be computed and applied
|
||||
with the returned optimizer, either by calling `opt.minimize()` or
|
||||
`opt.compute_gradients()` followed by `opt.apply_gradients()`.
|
||||
Gradients should not be computed with `tf.gradients` or `tf.GradientTape`.
|
||||
This is because the returned optimizer will apply loss scaling, and
|
||||
`tf.gradients` or `tf.GradientTape` will not. If you do directly use
|
||||
`tf.gradients` or `tf.GradientTape`, your model may not converge due to
|
||||
float16 underflow problems.
|
||||
When this function is used, gradients should be computed and applied with the
|
||||
returned optimizer, either by calling `opt.minimize()` or
|
||||
`opt.compute_gradients()` followed by `opt.apply_gradients()`. If gradients
|
||||
are instead computed with `tf.gradients` or `tf.GradientTape`, loss scaling
|
||||
will not be applied, which will likely cause your model not to converge due to
|
||||
float16 underflow problems. To apply lossing scaling with `tf.gradients` or
|
||||
`tf.GradientTape`, `LossScaleOptimizer.get_scaled_loss` and
|
||||
`LossScaleOptimizer.get_unscaled_gradients`. See
|
||||
`keras.mixed_precision.experimental.LossScaleOptimizer` for details how to do
|
||||
this.
|
||||
|
||||
When eager execution is enabled, the mixed precision graph rewrite is only
|
||||
enabled within `tf.function`, as outside `tf.function`, there is no graph.
|
||||
enabled within `tf.function`s, as outside `tf.function`s, there is no graph.
|
||||
|
||||
For NVIDIA GPUs with Tensor cores, as a general performance guide, dimensions
|
||||
(such as batch size, input size, output size, and channel counts)
|
||||
@ -176,16 +166,45 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
|
||||
|
||||
Currently, mixed precision is only enabled on NVIDIA Tensor Core GPUs with
|
||||
Compute Capability 7.0 and above (Volta, Turing, or newer architectures). The
|
||||
parts of the graph on CPUs and TPUs are untouched by the graph rewrite. TPU
|
||||
support is coming soon. CPUs are not supported, as CPUs do not run float16
|
||||
operations faster than float32 operations.
|
||||
parts of the graph on CPUs and TPUs are untouched by the graph rewrite.
|
||||
|
||||
## Comparison with the Keras mixed precision API
|
||||
Both this function and the [Keras mixed precision
|
||||
API](https://www.tensorflow.org/guide/keras/mixed_precision) enable the use of
|
||||
mixed precision in a model. Therefore, only one of the two APIs can be used.
|
||||
We recommend using the Keras mixed precision API, as it is more customizable
|
||||
and supports Eager execution. However, it only supports models which use Keras
|
||||
layers, while the graph rewrite works in any model that uses `tf.function`s.
|
||||
|
||||
The core difference between the two APIs is that this function is a graph
|
||||
rewrite, and so it changes the graph to use mixed precision under the hood.
|
||||
You still build your graph in float32, and the graph rewrite will change
|
||||
certain ops to float16. The Keras mixed precision API directly builds the
|
||||
Keras Model using a mix of float16 and float32.
|
||||
|
||||
One core advantage of the Keras API is it supports mixed precision with Eager
|
||||
execution, i.e. mixed precision outside `tf.function`s. The graph rewrite will
|
||||
only affect ops within `tf.function`s, making it harder to debug if issues
|
||||
occur with mixed precision. The Keras API is also more customizable, as you
|
||||
can override any layer to run in float32 by passing `dtype="float32"` to the
|
||||
layer constructor. Additionally, you can query the dtype of tensors in the
|
||||
model by checking `tensor.dtype`. With the graph rewrite, all tensors appear
|
||||
to be float32 since the dtype is only changed under the hood.
|
||||
|
||||
The main advantage of the graph rewrite (this function) is that it works even
|
||||
if you do not use Keras layers or any other part of Keras. The Keras mixed
|
||||
precision API requires models which use Keras layers, as it only inserts casts
|
||||
inside Keras layers and models. Another advantage is that the graph rewrite
|
||||
never results in a TypeError, which the Keras API may introduce if you do
|
||||
certain operations outside Keras. For example, the following will result in a
|
||||
TypeError if the Keras mixed precision API is enabled, as a float16 and
|
||||
float32 tensor will be added:
|
||||
`tf.keras.layers.Dense(2)(x) + tf.keras.layers.Dense(2, dtype="float32")(x)`
|
||||
|
||||
Raises:
|
||||
`ValueError` when
|
||||
`mixed_precision_global_state.using_default_mixed_precision_policy`
|
||||
is set to `False` before
|
||||
`tf.train.experimental.enable_mixed_precision_graph_rewrite()`
|
||||
is called.
|
||||
`ValueError`, if the `tf.keras.mixed_precision` API is also used by calling
|
||||
`tf.keras.mixed_precision.experimental.set_policy`. Only one mixed precision
|
||||
API can be used.
|
||||
|
||||
Args:
|
||||
opt: An instance of a `tf.keras.optimizers.Optimizer`.
|
||||
@ -208,16 +227,16 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
||||
Mixed precision is the use of both float32 and float16 data types when
|
||||
training a model to improve performance. This is achieved via a graph rewrite
|
||||
operation and a loss-scale optimizer.
|
||||
|
||||
|
||||
Performing arithmetic operations in float16 takes advantage of specialized
|
||||
processing units, such as NVIDIA Tensor Cores for much higher arithmetic
|
||||
processing units, such as NVIDIA Tensor Cores, for much higher arithmetic
|
||||
throughput. However, due to the smaller representable range, performing the
|
||||
entire training with float16 can result in gradient underflow, that is, small
|
||||
gradient values becoming zeroes. Instead, performing only select arithmetic
|
||||
operations in float16 results in higher throughput and decreased training
|
||||
time when using compatible hardware accelerators while also reducing memory
|
||||
usage, typically without sacrificing model accuracy.
|
||||
|
||||
|
||||
Note: While the mixed precision rewrite changes the datatype of various
|
||||
layers throughout the model, the same accuracy reached in float32 is
|
||||
expected. If a `NaN` gradient occurs with dynamic loss scaling, the model
|
||||
@ -226,43 +245,31 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
||||
scaling value to avoid `NaN` values in subsequent iterations. This approach
|
||||
has been shown to achieve the same accuracy as float32 and, in most cases,
|
||||
better training throughput.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
```python
|
||||
model = tf.keras.models.Sequential([
|
||||
...
|
||||
tf.keras.layers.Dense(64, activation='relu'),
|
||||
tf.keras.layers.Dense(64, activation='softmax'),
|
||||
])
|
||||
|
||||
|
||||
opt = tf.keras.optimizers.SGD()
|
||||
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
|
||||
|
||||
model.compile(loss="categorical_crossentropy",
|
||||
optimizer=opt,
|
||||
metrics=["accuracy"])
|
||||
|
||||
model.fit(x_train, y_train,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs)
|
||||
model.compile(loss="mse", optimizer=opt)
|
||||
|
||||
x_train = np.random.random((1024, 64))
|
||||
y_train = np.random.random((1024, 64))
|
||||
model.fit(x_train, y_train)
|
||||
```
|
||||
|
||||
For a complete example showing the speed-up on training an image
|
||||
classification task on CIFAR10, check out this
|
||||
<a href="https://colab.research.google.com/github/NVIDIA/
|
||||
DeepLearningExamples/blob/master/TensorFlow/docs/amp/notebook_v1.14/
|
||||
auto_mixed_precision_demo_cifar10.ipynb">Colab notebook</a>.
|
||||
|
||||
|
||||
Calling `enable_mixed_precision_graph_rewrite(opt)` enables the graph rewrite
|
||||
operation before computing gradients. The function additionally returns an
|
||||
`Optimizer`(`opt`) wrapped with a `LossScaleOptimizer`. This prevents
|
||||
`Optimizer` (`opt`) wrapped with a `LossScaleOptimizer`. This prevents
|
||||
underflow in the float16 tensors during the backward pass. An optimizer of
|
||||
type `tf.train.Optimizer` or `tf.keras.optimizers.Optimizer` must be passed
|
||||
to this function, which will then be wrapped to use loss scaling.
|
||||
|
||||
<img src="
|
||||
http://developer.download.nvidia.com/compute/machine-learning/frameworks/
|
||||
TF_mixed_precision_training.png" width="500px">
|
||||
|
||||
|
||||
The graph rewrite operation changes the `dtype` of certain operations in the
|
||||
graph from float32 to float16. There are several categories of operations
|
||||
that are either included or excluded by this rewrite operation. The following
|
||||
@ -271,7 +278,7 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
||||
<a href="https://github.com/tensorflow/tensorflow/blob/master/tensorflow/
|
||||
core/grappler/optimizers/auto_mixed_precision_lists.h">
|
||||
auto_mixed_precision_lists.h</a>:
|
||||
|
||||
|
||||
* `ClearList`: Ops that do not have numerically significant adverse effects.
|
||||
E.g. `ArgMax` and `Floor`.
|
||||
* `WhiteList`: Ops that are considered numerically safe for execution in
|
||||
@ -280,7 +287,7 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
||||
can negatively affect downstream nodes. E.g. `Softmax`.
|
||||
* `GrayList`: Ops that are considered numerically safe for execution in
|
||||
float16 unless downstream from a BlackList Op. E.g. `Add` and `AvgPool`.
|
||||
|
||||
|
||||
When this function is used, gradients should only be computed and applied
|
||||
with the returned optimizer, either by calling `opt.minimize()` or
|
||||
`opt.compute_gradients()` followed by `opt.apply_gradients()`.
|
||||
@ -289,30 +296,26 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
||||
`tf.gradients` or `tf.GradientTape` will not. If you do directly use
|
||||
`tf.gradients` or `tf.GradientTape`, your model may not converge due to
|
||||
float16 underflow problems.
|
||||
|
||||
|
||||
When eager execution is enabled, the mixed precision graph rewrite is only
|
||||
enabled within `tf.function`, as outside `tf.function`, there is no graph.
|
||||
|
||||
enabled within `tf.function`s, as outside `tf.function`s, there is no graph.
|
||||
|
||||
For NVIDIA GPUs with Tensor cores, as a general performance guide, dimensions
|
||||
(such as batch size, input size, output size, and channel counts)
|
||||
should be powers of two if under 256, or otherwise divisible by 8 if above
|
||||
256. For more information, check out the
|
||||
[NVIDIA Deep Learning Performance Guide](
|
||||
https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html).
|
||||
|
||||
|
||||
Currently, mixed precision is only enabled on NVIDIA Tensor Core GPUs with
|
||||
Compute Capability 7.0 and above (Volta, Turing, or newer architectures). The
|
||||
parts of the graph on CPUs and TPUs are untouched by the graph rewrite. TPU
|
||||
support is coming soon. CPUs are not supported, as CPUs do not run float16
|
||||
operations faster than float32 operations.
|
||||
|
||||
parts of the graph on CPUs and TPUs are untouched by the graph rewrite.
|
||||
|
||||
Raises:
|
||||
`ValueError` when
|
||||
`mixed_precision_global_state.using_default_mixed_precision_policy`
|
||||
is set to `False` before
|
||||
`tf.train.experimental.enable_mixed_precision_graph_rewrite()`
|
||||
is called.
|
||||
|
||||
`ValueError`, if the `tf.keras.mixed_precision` API is also used by calling
|
||||
`tf.keras.mixed_precision.experimental.set_policy`. Only one mixed precision
|
||||
API can be used.
|
||||
|
||||
Args:
|
||||
opt: An instance of a `tf.keras.optimizers.Optimizer` or a
|
||||
`tf.train.Optimizer`.
|
||||
@ -320,7 +323,7 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
||||
a `tf.mixed_precision.experimental.LossScale`. The loss scale to use. It
|
||||
is recommended to keep this as its default value of `"dynamic"`, which
|
||||
will adjust the scaling automatically to prevent `Inf` or `NaN` values.
|
||||
|
||||
|
||||
Returns:
|
||||
A version of `opt` that will use loss scaling to prevent underflow.
|
||||
"""
|
||||
@ -344,10 +347,8 @@ def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
|
||||
'(You called this second)\n\n'
|
||||
'You called both functions, which is an error, because both functions '
|
||||
'enable you to use mixed precision. If in doubt which function to use, '
|
||||
'use the second, as it is currently more complete and easy to use. The '
|
||||
'second function enables mixed precision in the graph with a graph '
|
||||
'rewrite. However it is currently not very customizable, and does not '
|
||||
'support eager.')
|
||||
'use the first, as it supports Eager execution and is more '
|
||||
'customizable.')
|
||||
|
||||
if mixed_precision_global_state.non_mixed_precision_session_created:
|
||||
# TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
|
||||
|
Loading…
Reference in New Issue
Block a user