Add disable_mixed_precision_graph_rewrite function for unit testing
PiperOrigin-RevId: 248036105
This commit is contained in:
parent
09ca187c09
commit
20f681b89a
@ -169,3 +169,54 @@ def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
|
||||
config.set_optimizer_experimental_options({'auto_mixed_precision': True})
|
||||
mixed_precision_global_state.mixed_precision_is_enabled = True
|
||||
return opt
|
||||
|
||||
|
||||
@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[])
|
||||
def disable_mixed_precision_graph_rewrite():
|
||||
"""Disables the mixed precision graph rewrite.
|
||||
|
||||
After this is called, the mixed precision graph rewrite will no longer run for
|
||||
tf.functions, and so float32 operations will no longer be converted to
|
||||
float16.
|
||||
|
||||
This does not undo the effects of loss scaling. Any optimizers wrapped with a
|
||||
LossScaleOptimizer will continue to do loss scaling, although this loss
|
||||
scaling will no longer be useful, as the graph rewrite no longer converts
|
||||
tf.functions to use float16.
|
||||
|
||||
This function is useful for unit testing. A unit test can test using the mixed
|
||||
precision graph rewrite, then disable it so future unit tests continue using
|
||||
float32.
|
||||
"""
|
||||
if not mixed_precision_global_state.mixed_precision_is_enabled:
|
||||
tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
|
||||
'precision is already disabled.')
|
||||
config.set_optimizer_experimental_options({'auto_mixed_precision': False})
|
||||
mixed_precision_global_state.mixed_precision_is_enabled = False
|
||||
|
||||
|
||||
@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
|
||||
def disable_mixed_precision_graph_rewrite_v1():
|
||||
"""Disables the mixed precision graph rewrite.
|
||||
|
||||
After this is called, the mixed precision graph rewrite will no longer run for
|
||||
new Sessions, and so float32 operations will no longer be converted to float16
|
||||
in such Sessions. However, any existing Sessions will continue to have the
|
||||
graph rewrite enabled if they were created after
|
||||
`enable_mixed_precision_graph_rewrite` was called but before
|
||||
`disable_mixed_precision_graph_rewrite` was called.
|
||||
|
||||
This does not undo the effects of loss scaling. Any optimizers wrapped with a
|
||||
LossScaleOptimizer will continue to do loss scaling, although this loss
|
||||
scaling will no longer be useful if the optimizer is used in new Sessions, as
|
||||
the graph rewrite no longer converts the graph to use float16.
|
||||
|
||||
This function is useful for unit testing. A unit tests can test using the
|
||||
mixed precision graph rewrite, then disable it so future unit tests continue
|
||||
using float32. If this is done, unit tests should not share a single session,
|
||||
as `enable_mixed_precision_graph_rewrite` and
|
||||
`disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
|
||||
"""
|
||||
# We only have a separate V1 version of this function, because the V1
|
||||
# docstring mentions sessions.
|
||||
disable_mixed_precision_graph_rewrite()
|
||||
|
@ -58,22 +58,15 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
# to ignore performance and always transform the graph.
|
||||
self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
|
||||
os.environ[self.IGNORE_PERF_VAR] = '1'
|
||||
# Set global variables to their original state, in case other tests modified
|
||||
# them
|
||||
mixed_precision_global_state.mixed_precision_is_enabled = False
|
||||
mixed_precision_global_state.non_mixed_precision_session_created = False
|
||||
|
||||
def tearDown(self):
|
||||
# Set auto_mixed_precision back to it's default value.
|
||||
config.set_optimizer_experimental_options({'auto_mixed_precision': False})
|
||||
# Set the IGNORE_PERF_VAR variable back to it's original value.
|
||||
if self._original_ignore_perf_value is not None:
|
||||
os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
|
||||
else:
|
||||
del os.environ[self.IGNORE_PERF_VAR]
|
||||
# Revert global variables
|
||||
mixed_precision_global_state.mixed_precision_is_enabled = False
|
||||
mixed_precision_global_state.non_mixed_precision_session_created = False
|
||||
|
||||
mixed_precision.disable_mixed_precision_graph_rewrite()
|
||||
super(MixedPrecisionTest, self).tearDown()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@ -141,6 +134,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(f().numpy(), float('Inf'))
|
||||
# Outside a def_function.function, the grappler pass will not be applied.
|
||||
self.assertAlmostEqual(overflow_in_float16().numpy(), 2 ** 20)
|
||||
|
||||
# Test disabling mixed precision.
|
||||
mixed_precision.disable_mixed_precision_graph_rewrite()
|
||||
self.assertEqual(f().numpy(), 2 ** 20)
|
||||
else:
|
||||
with session.Session() as sess:
|
||||
out = overflow_in_float16()
|
||||
@ -154,8 +151,19 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
sess.run(var.initializer)
|
||||
self.assertEqual(sess.run(out), float('Inf'))
|
||||
|
||||
# Test disabling mixed precision.
|
||||
mixed_precision.disable_mixed_precision_graph_rewrite()
|
||||
with session.Session() as sess:
|
||||
out = overflow_in_float16()
|
||||
sess.run(var.initializer)
|
||||
self.assertAlmostEqual(sess.run(out), 2 ** 20)
|
||||
|
||||
@test.mock.patch.object(tf_logging, 'warn')
|
||||
def test_warn_if_session_already_exists(self, mock_warn):
|
||||
# Set this to False, so Sessions created in previous tests do not trigger
|
||||
# the warning.
|
||||
mixed_precision_global_state.non_mixed_precision_session_created = False
|
||||
|
||||
with session.Session():
|
||||
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
||||
mock_warn.assert_any_call(
|
||||
@ -165,6 +173,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@test.mock.patch.object(tf_logging, 'warn')
|
||||
def test_do_not_warn_if_session_does_not_already_exist(self, mock_warn):
|
||||
# Set this to False, so Sessions created in previous tests do not trigger
|
||||
# the warning.
|
||||
mixed_precision_global_state.non_mixed_precision_session_created = False
|
||||
|
||||
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
||||
with session.Session():
|
||||
# Make sure the "You already have existing Sessions" warning was not
|
||||
|
@ -8,6 +8,10 @@ tf_module {
|
||||
name: "PythonState"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_mixed_precision_graph_rewrite"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_mixed_precision_graph_rewrite"
|
||||
argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
|
||||
|
@ -4,6 +4,10 @@ tf_module {
|
||||
name: "PythonState"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "disable_mixed_precision_graph_rewrite"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "enable_mixed_precision_graph_rewrite"
|
||||
argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user