Reinstate graph rewrite mixed precision as experimental.
PiperOrigin-RevId: 261365431
parent e785638489, commit b9964d43a3
Changed paths: tensorflow/python/training, tensorflow/tools
@@ -27,7 +27,7 @@ from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export


-def _wrap_optimizer(opt, loss_scale):
+def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
   """Wraps an optimizer with a LossScaleOptimizer."""

   if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
@@ -67,12 +67,60 @@ def _wrap_optimizer(opt, loss_scale):
     from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
     return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)

-  raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
-                   'tf.keras.optimizers.Optimizer, but got: %s' % opt)
+  if use_v1_behavior:
+    raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
+                     'tf.keras.optimizers.Optimizer, but got: %s' % opt)
+  else:
+    raise ValueError('"opt" must be an instance of a '
+                     'tf.keras.optimizers.Optimizer, but got: %s' % opt)
+
+
+@tf_export('train.experimental.enable_mixed_precision_graph_rewrite', v1=[])
+def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
+  """Enable mixed precision in `tf.function`s via a graph rewrite.
+
+  Mixed precision is the use of both float16 and float32 when training a model,
+  and is used to make the model run faster. This function will use mixed
+  precision to speed up the execution time of `tf.function`s when run on a GPU.
+  It does this by changing the dtype of certain operations in the function's
+  graph from float32 to float16.
+
+  This function additionally wraps an Optimizer with a LossScaleOptimizer, which
+  is required to prevent underflow in the float16 tensors during the backwards
+  pass. An optimizer must be passed to this function, which will then be wrapped
+  to use loss scaling.
+
+  When this function is used, gradients should only be computed and applied with
+  the returned optimizer through `opt.minimize()`, and not with a
+  `tf.GradientTape`. This is because the returned optimizer will apply loss
+  scaling, and `tf.GradientTape` will not. If you do use a `tf.GradientTape`,
+  your model may train to a worse quality.
+
+  Currently, mixed precision is only enabled on Volta GPUs and above. TPU
+  support is coming soon. CPUs are not supported, as CPUs do not run float16
+  operations faster than float32 operations.
+
+  WARNING: This rewrite silently affects the entire model and can have
+  unintended consequences. One example: If a NaN occurs during dynamic loss
+  scaling, the data for the batch is silently dropped while the
+  LossScaleOptimizer attempts to find the appropriate scaling value on the next
+  batch.
+
+  Args:
+    opt: An instance of a `tf.keras.optimizers.Optimizer`.
+    loss_scale: Either an int/float, the string "dynamic", or an instance of a
+      `tf.train.experimental.LossScale`. The loss scale to use. It is
+      recommended to keep this as its default value of "dynamic".
+
+  Returns:
+    A version of `opt` that will use loss scaling to prevent underflow.
+  """
+  return _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
+                                                    use_v1_behavior=False)


 @tf_export(v1=['train.experimental.enable_mixed_precision_graph_rewrite'])
-def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
+def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
   """Enable mixed precision via a graph rewrite.

   Mixed precision is the use of both float16 and float32 when training a model,
@@ -94,11 +142,9 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
   `tf.gradients`/`tf.GradientTape` will not. If you do directly use
   `tf.gradients` or `tf.GradientTape`, your model may train to a worse quality.

-  When eager execution is enabled, the mixed precision graph rewrite is only
-  enabled within `tf.function`s, as outside `tf.function`s, there is no graph.
-
-  When enabled, mixed precision is only used on Volta GPUs and above. The parts
-  of the graph on CPUs and TPUs are untouched by the graph rewrite.
+  Currently, mixed precision is only enabled on Volta GPUs and above. TPU
+  support is coming soon. CPUs are not supported, as CPUs do not run float16
+  operations faster than float32 operations.

   Args:
     opt: An instance of a `tf.keras.optimizers.Optimizer` or a
@@ -112,6 +158,13 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
   """
   # TODO(reedwm): If a ConfigProto is passed to Session, either assert that
   # auto_mixed_precision is on or turn it on for the user.
+  return _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
+                                                    use_v1_behavior=True)
+
+
+def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
+                                               use_v1_behavior):
+  """Enables mixed precision. See `enable_mixed_precision_graph_rewrite`."""
   if not mixed_precision_global_state.using_default_mixed_precision_policy:
     raise ValueError(
         'The mixed precision graph rewrite cannot be enabled, because a keras '
@@ -122,10 +175,11 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
         ' 2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
         '(You called this second)\n\n'
         'You called both functions, which is an error, because both functions '
-        'enable you to use mixed precision. The second function enables mixed '
-        'precision in the graph with a graph rewrite. However it is currently '
-        'not very customizable, and does not support eager. The first '
-        'function is for Keras layers, but is not yet fully complete.')
+        'enable you to use mixed precision. If in doubt which function to use, '
+        'use the second, as it is currently more complete and easy to use. The '
+        'second function enables mixed precision in the graph with a graph '
+        'rewrite. However it is currently not very customizable, and does not '
+        'support eager.')

   if mixed_precision_global_state.non_mixed_precision_session_created:
     # TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
@@ -133,16 +187,40 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
     tf_logging.warn('You already have existing Sessions that do not use mixed '
                     'precision. enable_mixed_precision_graph_rewrite() will '
                     'not affect these Sessions.')
-  opt = _wrap_optimizer(opt, loss_scale)
+  opt = _wrap_optimizer(opt, loss_scale, use_v1_behavior=use_v1_behavior)
   config.set_optimizer_experimental_options({'auto_mixed_precision': True})
   mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = True
   return opt


-@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
+@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[])
 def disable_mixed_precision_graph_rewrite():
   """Disables the mixed precision graph rewrite.

+  After this is called, the mixed precision graph rewrite will no longer run for
+  tf.functions, and so float32 operations will no longer be converted to
+  float16.
+
+  This does not undo the effects of loss scaling. Any optimizers wrapped with a
+  LossScaleOptimizer will continue to do loss scaling, although this loss
+  scaling will no longer be useful, as the graph rewrite no longer converts
+  tf.functions to use float16.
+
+  This function is useful for unit testing. A unit test can test using the mixed
+  precision graph rewrite, then disable it so future unit tests continue using
+  float32.
+  """
+  if not mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
+    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
+                    'precision is already disabled.')
+  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
+  mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False
+
+
+@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
+def disable_mixed_precision_graph_rewrite_v1():
+  """Disables the mixed precision graph rewrite.
+
   After this is called, the mixed precision graph rewrite will no longer run for
   new Sessions, and so float32 operations will no longer be converted to float16
   in such Sessions. However, any existing Sessions will continue to have the
@@ -161,8 +239,6 @@ def disable_mixed_precision_graph_rewrite():
   as `enable_mixed_precision_graph_rewrite` and
   `disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
   """
-  if not mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
-    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
-                    'precision is already disabled.')
-  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
-  mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False
+  # We only have a separate V1 version of this function, because the V1
+  # docstring mentions sessions.
+  disable_mixed_precision_graph_rewrite()
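A minimal usage sketch of the TF2 path added above, assuming a TensorFlow build that includes this commit (so `tf.train.experimental.enable_mixed_precision_graph_rewrite` is exported); the model, data, and loss below are illustrative placeholders, not part of the change. Following the docstring, gradients are computed and applied through the returned optimizer's `minimize()` rather than a raw `tf.GradientTape`, and the rewrite is disabled afterwards, as a unit test would:

import tensorflow as tf

# Hypothetical model/optimizer; any tf.keras optimizer works here.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = tf.keras.optimizers.SGD(1.0)

# Wraps opt in a LossScaleOptimizer and turns on the auto_mixed_precision
# grappler pass for subsequently created tf.functions.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
    opt, loss_scale='dynamic')

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
model(x)  # Build the model so trainable_variables is populated.

@tf.function
def train_step():
  loss_fn = lambda: tf.reduce_mean(tf.square(model(x) - y))
  # Gradients go through the returned optimizer so loss scaling is applied.
  return opt.minimize(loss_fn, var_list=model.trainable_variables)

train_step()

# E.g. at the end of a unit test, stop rewriting future tf.functions.
tf.train.experimental.disable_mixed_precision_graph_rewrite()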
@@ -21,6 +21,7 @@ import os
 from absl.testing import parameterized

 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import tf2
 from tensorflow.python.client import session
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -40,6 +41,14 @@ from tensorflow.python.training.experimental import mixed_precision
 from tensorflow.python.training.experimental import mixed_precision_global_state


+if tf2.enabled():
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite)
+else:
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite_v1)
+
+
 class MixedPrecisionTest(test.TestCase, parameterized.TestCase):

   IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
@@ -64,13 +73,13 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def test_wrap_optimizer(self):
     opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
-    opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
+    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
     self.assertIsInstance(
         opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer)
     self.assertEqual(self.evaluate(opt._loss_scale()), 123.)

     opt = gradient_descent_v2.SGD(1.0)
-    opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
+    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
     self.assertIsInstance(
         opt, loss_scale_optimizer_v2.LossScaleOptimizer)
     self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
@@ -78,10 +87,14 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def test_optimizer_errors(self):
     opt = 1
-    expected_regex = ('"opt" must be an instance of a tf.train.Optimizer or '
-                      'a tf.keras.optimizers.Optimizer, but got')
+    if tf2.enabled():
+      expected_regex = ('"opt" must be an instance of a '
+                        'tf.keras.optimizers.Optimizer, but got')
+    else:
+      expected_regex = ('"opt" must be an instance of a tf.train.Optimizer or '
+                        'a tf.keras.optimizers.Optimizer, but got')
     with self.assertRaisesRegexp(ValueError, expected_regex):
-      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
+      enable_mixed_precision_graph_rewrite(opt)
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
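The assertions above read the grappler option that the rewrite toggles. The same check is available through the public config API; a small sketch, assuming a TF 2.x build where `tf.config.optimizer.get_experimental_options` is available (the test itself uses the internal `config` module directly), with `mixed_precision_rewrite_enabled` as a hypothetical helper name:

import tensorflow as tf

def mixed_precision_rewrite_enabled():
  # The graph rewrite is implemented as the 'auto_mixed_precision' grappler
  # optimizer option; enable/disable above simply flip this flag.
  return tf.config.optimizer.get_experimental_options().get(
      'auto_mixed_precision', False)

print(mixed_precision_rewrite_enabled())  # False until the rewrite is enabled.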
@@ -91,7 +104,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError,
                                  '"opt" must not already be an instance of a '
                                  'MixedPrecisionLossScaleOptimizer.'):
-      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
+      enable_mixed_precision_graph_rewrite(opt)
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
@@ -100,7 +113,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     with self.assertRaisesRegexp(ValueError,
                                  '"opt" must not already be an instance of a '
                                  'LossScaleOptimizer.'):
-      mixed_precision.enable_mixed_precision_graph_rewrite(opt)
+      enable_mixed_precision_graph_rewrite(opt)
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
@@ -108,7 +121,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def test_grappler_pass_enabled(self):
     opt = gradient_descent_v2.SGD(1.0)
-    mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
+    enable_mixed_precision_graph_rewrite(opt, 123.)

     var = variables.Variable([[1.0]])
@@ -153,8 +166,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     mixed_precision_global_state.non_mixed_precision_session_created = False

     with session.Session():
-      mixed_precision.enable_mixed_precision_graph_rewrite(
-          gradient_descent_v2.SGD(1.0))
+      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
       mock_warn.assert_any_call(
           'You already have existing Sessions that do not use mixed precision. '
           'enable_mixed_precision_graph_rewrite() will not affect these '
@@ -166,8 +178,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     # the warning.
     mixed_precision_global_state.non_mixed_precision_session_created = False

-    mixed_precision.enable_mixed_precision_graph_rewrite(
-        gradient_descent_v2.SGD(1.0))
+    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
     with session.Session():
       # Make sure the "You already have existing Sessions" warning was not
       # issued, since the Session was only created after
@@ -181,11 +192,9 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     with policy.policy_scope('infer_float32_vars'):
       with self.assertRaisesRegexp(
           ValueError, 'a keras mixed precision Policy has been set'):
-        mixed_precision.enable_mixed_precision_graph_rewrite(
-            gradient_descent_v2.SGD(1.0))
+        enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
     # Test no error is thrown when the policy is currently the default.
-    mixed_precision.enable_mixed_precision_graph_rewrite(
-        gradient_descent_v2.SGD(1.0))
+    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))


 if __name__ == '__main__':
@@ -35,6 +35,7 @@ from tensorflow.python.training.adam import AdamOptimizer
 from tensorflow.python.training.ftrl import FtrlOptimizer
 from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer
 from tensorflow.python.training.experimental.mixed_precision import enable_mixed_precision_graph_rewrite
+from tensorflow.python.training.experimental.mixed_precision import enable_mixed_precision_graph_rewrite_v1
 from tensorflow.python.training.momentum import MomentumOptimizer
 from tensorflow.python.training.moving_averages import ExponentialMovingAverage
 from tensorflow.python.training.optimizer import Optimizer
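The added import re-exports the `_v1` variant used by the `tf.compat.v1.train.experimental` endpoint. A minimal graph-mode sketch of that Session-based path, assuming it runs against a release containing this commit; the placeholder, variable, and loss are illustrative only. Per the V1 docstring, only Sessions created after the call pick up the rewrite:

import numpy as np
import tensorflow.compat.v1 as tf1

tf1.disable_v2_behavior()  # Graph mode, so training runs in a Session.

x = tf1.placeholder(tf1.float32, shape=[None, 4])
w = tf1.get_variable('w', shape=[4, 1])
loss = tf1.reduce_mean(tf1.square(tf1.matmul(x, w)))

opt = tf1.train.GradientDescentOptimizer(1.0)
# Wraps opt in a MixedPrecisionLossScaleOptimizer and enables the
# auto_mixed_precision rewrite for Sessions created from here on.
opt = tf1.train.experimental.enable_mixed_precision_graph_rewrite(opt)
train_op = opt.minimize(loss)

with tf1.Session() as sess:
  sess.run(tf1.global_variables_initializer())
  sess.run(train_op, feed_dict={x: np.random.randn(8, 4).astype(np.float32)})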
@@ -16,4 +16,12 @@ tf_module {
     name: "PythonState"
     mtype: "<type \'type\'>"
   }
+  member_method {
+    name: "disable_mixed_precision_graph_rewrite"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "enable_mixed_precision_graph_rewrite"
+    argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
+  }
 }
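The argspec above shows `loss_scale` defaulting to 'dynamic'. As described in the docstring earlier in this diff, the argument also accepts a plain int/float or a `tf.train.experimental.LossScale` instance; a brief sketch, assuming `FixedLossScale` is exported under `tf.train.experimental` as it was in releases of this era, with a fresh optimizer passed each time since already-wrapped optimizers are rejected:

import tensorflow as tf

enable = tf.train.experimental.enable_mixed_precision_graph_rewrite

# Default: dynamic loss scaling (recommended by the docstring).
opt_dynamic = enable(tf.keras.optimizers.SGD(1.0))

# A fixed scale given as a plain number.
opt_fixed = enable(tf.keras.optimizers.SGD(1.0), loss_scale=128.)

# The same fixed scale given as an explicit LossScale object.
opt_explicit = enable(tf.keras.optimizers.SGD(1.0),
                      loss_scale=tf.train.experimental.FixedLossScale(128.))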
@@ -1459,10 +1459,6 @@ renames = {
        'tf.compat.v1.train.do_quantize_training_on_graphdef',
    'tf.train.experimental.MixedPrecisionLossScaleOptimizer':
        'tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer',
-    'tf.train.experimental.disable_mixed_precision_graph_rewrite':
-        'tf.compat.v1.train.experimental.disable_mixed_precision_graph_rewrite',
-    'tf.train.experimental.enable_mixed_precision_graph_rewrite':
-        'tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite',
    'tf.train.exponential_decay':
        'tf.compat.v1.train.exponential_decay',
    'tf.train.export_meta_graph':