Reinstate graph rewrite mixed precision as experimental.

PiperOrigin-RevId: 261365431
Toby Boyd 2019-08-02 12:04:19 -07:00 committed by TensorFlower Gardener
parent e785638489
commit b9964d43a3
5 changed files with 130 additions and 40 deletions

@@ -27,7 +27,7 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
def _wrap_optimizer(opt, loss_scale):
def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
"""Wraps an optimizer with a LossScaleOptimizer."""
if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
@@ -67,12 +67,60 @@ def _wrap_optimizer(opt, loss_scale):
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2 # pylint: disable=g-import-not-at-top
return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)
raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
'tf.keras.optimizers.Optimizer, but got: %s' % opt)
if use_v1_behavior:
raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
'tf.keras.optimizers.Optimizer, but got: %s' % opt)
else:
raise ValueError('"opt" must be an instance of a '
'tf.keras.optimizers.Optimizer, but got: %s' % opt)
@tf_export('train.experimental.enable_mixed_precision_graph_rewrite', v1=[])
def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
"""Enable mixed precision in `tf.function`s via a graph rewrite.
Mixed precision is the use of both float16 and float32 when training a model,
and is used to make the model run faster. This function will use mixed
precision to speed up the execution time of `tf.function`s when run on a GPU.
It does this by changing the dtype of certain operations in the function's
graph from float32 to float16.
This function additionally wraps an Optimizer with a LossScaleOptimizer, which
is required to prevent underflow in the float16 tensors during the backwards
pass. An optimizer must be passed to this function, which will then be wrapped
to use loss scaling.
When this function is used, gradients should only be computed and applied with
the returned optimizer through `opt.minimize()`, and not with a
`tf.GradientTape`. This is because the returned optimizer will apply loss
scaling, and `tf.GradientTape` will not. If you do use a `tf.GradientTape`,
your model may train to a worse quality.
Currently, mixed precision is only enabled on Volta GPUs and above. TPU
support is coming soon. CPUs are not supported, as CPUs do not run float16
operations faster than float32 operations.
WARNING: This rewrite silently affects the entire model and can have
unintended consequences. One example: If a NaN occurs during dynamic loss
scaling, the data for the batch is silently dropped while the
LossScaleOptimizer attempts to find the appropriate scaling value on the next
batch.
Args:
opt: An instance of a `tf.keras.optimizers.Optimizer`.
loss_scale: Either an int/float, the string "dynamic", or an instance of a
`tf.train.experimental.LossScale`. The loss scale to use. It is
recommended to keep this as its default value of "dynamic".
Returns:
A version of `opt` that will use loss scaling to prevent underflow.
"""
return _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
use_v1_behavior=False)
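
To make the docstring above concrete, here is a minimal usage sketch (illustrative, not part of this change): the model, shapes, and learning rate are placeholders, and it assumes the `tf.train.experimental` export added above. The important part is that training goes through `opt.minimize()` inside a `tf.function`, so both loss scaling and the float16 rewrite apply.

```python
import tensorflow as tf

# Build a small model up front so its variables exist before minimize() runs.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(4,))])
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 10])

opt = tf.keras.optimizers.SGD(0.01)
# Wraps `opt` in a LossScaleOptimizer (dynamic loss scaling by default) and
# turns on the auto_mixed_precision grappler pass for tf.functions.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

@tf.function
def train_step():
  # The loss is passed as a callable so the wrapped optimizer can scale it,
  # compute gradients, unscale them, and apply the update in one call.
  loss_fn = lambda: tf.reduce_mean(tf.square(model(x) - y))
  opt.minimize(loss_fn, var_list=model.trainable_variables)

train_step()
```
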
@tf_export(v1=['train.experimental.enable_mixed_precision_graph_rewrite'])
def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
"""Enable mixed precision via a graph rewrite.
Mixed precision is the use of both float16 and float32 when training a model,
@@ -94,11 +142,9 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
`tf.gradients`/`tf.GradientTape` will not. If you do directly use
`tf.gradients` or `tf.GradientTape`, your model may train to a worse quality.
When eager execution is enabled, the mixed precision graph rewrite is only
enabled within `tf.function`s, as outside `tf.function`s, there is no graph.
When enabled, mixed precision is only used on Volta GPUs and above. The parts
of the graph on CPUs and TPUs are untouched by the graph rewrite.
Currently, mixed precision is only enabled on Volta GPUs and above. TPU
support is coming soon. CPUs are not supported, as CPUs do not run float16
operations faster than float32 operations.
Args:
opt: An instance of a `tf.keras.optimizers.Optimizer` or a
@@ -112,6 +158,13 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
"""
# TODO(reedwm): If a ConfigProto is passed to Session, either assert that
# auto_mixed_precision is on or turn it on for the user.
return _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
use_v1_behavior=True)
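
The V1 export wraps `tf.compat.v1.train.Optimizer`s and takes effect in Sessions created after the call. A rough graph-mode sketch follows, again illustrative rather than part of the diff; the variable, shape, and fixed loss scale of 128 are placeholders (the recommended default remains "dynamic").

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
# Wraps `opt` in a MixedPrecisionLossScaleOptimizer; the rewrite affects
# Sessions created after this call.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, loss_scale=128)

w = tf.get_variable('w', shape=[16, 16], initializer=tf.ones_initializer())
loss = tf.reduce_mean(tf.matmul(w, w))
# minimize() builds a train op that applies loss scaling; eligible float32
# ops in the graph are converted to float16 when the Session runs it.
train_op = opt.minimize(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
```
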
def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
use_v1_behavior):
"""Enables mixed precision. See `enable_mixed_precision_graph_rewrite`."""
if not mixed_precision_global_state.using_default_mixed_precision_policy:
raise ValueError(
'The mixed precision graph rewrite cannot be enabled, because a keras '
@@ -122,10 +175,11 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
' 2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
'(You called this second)\n\n'
'You called both functions, which is an error, because both functions '
'enable you to use mixed precision. The second function enables mixed '
'precision in the graph with a graph rewrite. However it is currently '
'not very customizable, and does not support eager. The first '
'function is for Keras layers, but is not yet fully complete.')
'enable you to use mixed precision. If in doubt which function to use, '
'use the second, as it is currently more complete and easy to use. The '
'second function enables mixed precision in the graph with a graph '
'rewrite. However it is currently not very customizable, and does not '
'support eager.')
if mixed_precision_global_state.non_mixed_precision_session_created:
# TODO(reedwm): Give the stacktrace of the existing Sessions. And if the
@@ -133,16 +187,40 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'):
tf_logging.warn('You already have existing Sessions that do not use mixed '
'precision. enable_mixed_precision_graph_rewrite() will '
'not affect these Sessions.')
opt = _wrap_optimizer(opt, loss_scale)
opt = _wrap_optimizer(opt, loss_scale, use_v1_behavior=use_v1_behavior)
config.set_optimizer_experimental_options({'auto_mixed_precision': True})
mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = True
return opt
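
The guard at the top of this helper rejects mixing the two mechanisms named in the error message. A hypothetical sequence that would trigger it, mirroring the policy name used in the test file below:

```python
import tensorflow as tf

# Setting a Keras mixed precision policy first...
tf.keras.mixed_precision.experimental.set_policy('infer_float32_vars')

# ...and then enabling the graph rewrite raises the ValueError constructed
# above, because both mechanisms enable mixed precision.
tf.train.experimental.enable_mixed_precision_graph_rewrite(
    tf.keras.optimizers.SGD(1.0))
```
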
@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[])
def disable_mixed_precision_graph_rewrite():
"""Disables the mixed precision graph rewrite.
After this is called, the mixed precision graph rewrite will no longer run for
tf.functions, and so float32 operations will no longer be converted to
float16.
This does not undo the effects of loss scaling. Any optimizers wrapped with a
LossScaleOptimizer will continue to do loss scaling, although this loss
scaling will no longer be useful, as the graph rewrite no longer converts
tf.functions to use float16.
This function is useful for unit testing. A unit test can test using the mixed
precision graph rewrite, then disable it so future unit tests continue using
float32.
"""
if not mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
'precision is already disabled.')
config.set_optimizer_experimental_options({'auto_mixed_precision': False})
mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False
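
The unit-testing pattern described in the docstring might look roughly like this; the test class and assertion are placeholders, not part of this change:

```python
import tensorflow as tf

class RewriteEnabledTest(tf.test.TestCase):

  def test_rewrite_active(self):
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        tf.keras.optimizers.SGD(1.0))
    # ... assertions that depend on the float16 rewrite being active ...
    self.assertIsNotNone(opt)

  def tearDown(self):
    # Turn the grappler pass back off so later tests run in plain float32.
    tf.train.experimental.disable_mixed_precision_graph_rewrite()
    super(RewriteEnabledTest, self).tearDown()

if __name__ == '__main__':
  tf.test.main()
```
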
@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
def disable_mixed_precision_graph_rewrite_v1():
"""Disables the mixed precision graph rewrite.
After this is called, the mixed precision graph rewrite will no longer run for
new Sessions, and so float32 operations will no longer be converted to float16
in such Sessions. However, any existing Sessions will continue to have the
@@ -161,8 +239,6 @@ def disable_mixed_precision_graph_rewrite():
as `enable_mixed_precision_graph_rewrite` and
`disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
"""
if not mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
'precision is already disabled.')
config.set_optimizer_experimental_options({'auto_mixed_precision': False})
mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled = False
# We only have a separate V1 version of this function, because the V1
# docstring mentions sessions.
disable_mixed_precision_graph_rewrite()


@@ -21,6 +21,7 @@ import os
from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@@ -40,6 +41,14 @@ from tensorflow.python.training.experimental import mixed_precision
from tensorflow.python.training.experimental import mixed_precision_global_state
if tf2.enabled():
enable_mixed_precision_graph_rewrite = (
mixed_precision.enable_mixed_precision_graph_rewrite)
else:
enable_mixed_precision_graph_rewrite = (
mixed_precision.enable_mixed_precision_graph_rewrite_v1)
class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
@@ -64,13 +73,13 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_wrap_optimizer(self):
opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
opt = enable_mixed_precision_graph_rewrite(opt, 123.)
self.assertIsInstance(
opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer)
self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
opt = gradient_descent_v2.SGD(1.0)
opt = mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
opt = enable_mixed_precision_graph_rewrite(opt, 123.)
self.assertIsInstance(
opt, loss_scale_optimizer_v2.LossScaleOptimizer)
self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
@@ -78,10 +87,14 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_optimizer_errors(self):
opt = 1
expected_regex = ('"opt" must be an instance of a tf.train.Optimizer or '
'a tf.keras.optimizers.Optimizer, but got')
if tf2.enabled():
expected_regex = ('"opt" must be an instance of a '
'tf.keras.optimizers.Optimizer, but got')
else:
expected_regex = ('"opt" must be an instance of a tf.train.Optimizer or '
'a tf.keras.optimizers.Optimizer, but got')
with self.assertRaisesRegexp(ValueError, expected_regex):
mixed_precision.enable_mixed_precision_graph_rewrite(opt)
enable_mixed_precision_graph_rewrite(opt)
self.assertFalse(config.get_optimizer_experimental_options()
.get('auto_mixed_precision', False))
@@ -91,7 +104,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError,
'"opt" must not already be an instance of a '
'MixedPrecisionLossScaleOptimizer.'):
mixed_precision.enable_mixed_precision_graph_rewrite(opt)
enable_mixed_precision_graph_rewrite(opt)
self.assertFalse(config.get_optimizer_experimental_options()
.get('auto_mixed_precision', False))
@@ -100,7 +113,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError,
'"opt" must not already be an instance of a '
'LossScaleOptimizer.'):
mixed_precision.enable_mixed_precision_graph_rewrite(opt)
enable_mixed_precision_graph_rewrite(opt)
self.assertFalse(config.get_optimizer_experimental_options()
.get('auto_mixed_precision', False))
@@ -108,7 +121,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_grappler_pass_enabled(self):
opt = gradient_descent_v2.SGD(1.0)
mixed_precision.enable_mixed_precision_graph_rewrite(opt, 123.)
enable_mixed_precision_graph_rewrite(opt, 123.)
var = variables.Variable([[1.0]])
@@ -153,8 +166,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
mixed_precision_global_state.non_mixed_precision_session_created = False
with session.Session():
mixed_precision.enable_mixed_precision_graph_rewrite(
gradient_descent_v2.SGD(1.0))
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
mock_warn.assert_any_call(
'You already have existing Sessions that do not use mixed precision. '
'enable_mixed_precision_graph_rewrite() will not affect these '
@@ -166,8 +178,7 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
# the warning.
mixed_precision_global_state.non_mixed_precision_session_created = False
mixed_precision.enable_mixed_precision_graph_rewrite(
gradient_descent_v2.SGD(1.0))
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
with session.Session():
# Make sure the "You already have existing Sessions" warning was not
# issued, since the Session was only created after
@@ -181,11 +192,9 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
with policy.policy_scope('infer_float32_vars'):
with self.assertRaisesRegexp(
ValueError, 'a keras mixed precision Policy has been set'):
mixed_precision.enable_mixed_precision_graph_rewrite(
gradient_descent_v2.SGD(1.0))
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
# Test no error is thrown when the policy is currently the default.
mixed_precision.enable_mixed_precision_graph_rewrite(
gradient_descent_v2.SGD(1.0))
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
if __name__ == '__main__':


@@ -35,6 +35,7 @@ from tensorflow.python.training.adam import AdamOptimizer
from tensorflow.python.training.ftrl import FtrlOptimizer
from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer
from tensorflow.python.training.experimental.mixed_precision import enable_mixed_precision_graph_rewrite
from tensorflow.python.training.experimental.mixed_precision import enable_mixed_precision_graph_rewrite_v1
from tensorflow.python.training.momentum import MomentumOptimizer
from tensorflow.python.training.moving_averages import ExponentialMovingAverage
from tensorflow.python.training.optimizer import Optimizer


@@ -16,4 +16,12 @@ tf_module {
name: "PythonState"
mtype: "<type \'type\'>"
}
member_method {
name: "disable_mixed_precision_graph_rewrite"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "enable_mixed_precision_graph_rewrite"
argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
}
}


@@ -1459,10 +1459,6 @@ renames = {
'tf.compat.v1.train.do_quantize_training_on_graphdef',
'tf.train.experimental.MixedPrecisionLossScaleOptimizer':
'tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer',
'tf.train.experimental.disable_mixed_precision_graph_rewrite':
'tf.compat.v1.train.experimental.disable_mixed_precision_graph_rewrite',
'tf.train.experimental.enable_mixed_precision_graph_rewrite':
'tf.compat.v1.train.experimental.enable_mixed_precision_graph_rewrite',
'tf.train.exponential_decay':
'tf.compat.v1.train.exponential_decay',
'tf.train.export_meta_graph':