Add disable_mixed_precision_graph_rewrite function for unit testing

PiperOrigin-RevId: 248036105
Reed Wanderman-Milne 2019-05-13 16:38:52 -07:00 committed by TensorFlower Gardener
parent 09ca187c09
commit 20f681b89a
4 changed files with 80 additions and 9 deletions


@@ -169,3 +169,54 @@ def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
  config.set_optimizer_experimental_options({'auto_mixed_precision': True})
  mixed_precision_global_state.mixed_precision_is_enabled = True
  return opt


@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[])
def disable_mixed_precision_graph_rewrite():
  """Disables the mixed precision graph rewrite.

  After this is called, the mixed precision graph rewrite will no longer run
  for tf.functions, and so float32 operations will no longer be converted to
  float16.

  This does not undo the effects of loss scaling. Any optimizers wrapped with a
  LossScaleOptimizer will continue to do loss scaling, although this loss
  scaling will no longer be useful, as the graph rewrite no longer converts
  tf.functions to use float16.

  This function is useful for unit testing. A unit test can test using the
  mixed precision graph rewrite, then disable it so future unit tests continue
  using float32.
  """
  if not mixed_precision_global_state.mixed_precision_is_enabled:
    tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
                    'precision is already disabled.')
  config.set_optimizer_experimental_options({'auto_mixed_precision': False})
  mixed_precision_global_state.mixed_precision_is_enabled = False
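
(For context: a hypothetical sketch of the unit-testing pattern this docstring
describes, using the tf.train.experimental names exported above. The optimizer,
test class, and function body are illustrative only, and the rewrite only takes
effect in environments where the auto_mixed_precision grappler pass actually
runs.)

import tensorflow as tf

class RewriteTest(tf.test.TestCase):

  def test_with_rewrite_enabled(self):
    opt = tf.keras.optimizers.SGD(1.0)
    # Turn on the auto_mixed_precision grappler pass for tf.functions.
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

    @tf.function
    def matmul_fn(x):
      return tf.matmul(x, x)

    # Assertions about float16 behavior would go here.
    matmul_fn(tf.ones((4, 4)))

  def tearDown(self):
    # Ensure later tests run in plain float32 again, mirroring the change to
    # MixedPrecisionTest.tearDown in the test file below.
    tf.train.experimental.disable_mixed_precision_graph_rewrite()
    super(RewriteTest, self).tearDown()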


@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
def disable_mixed_precision_graph_rewrite_v1():
  """Disables the mixed precision graph rewrite.

  After this is called, the mixed precision graph rewrite will no longer run
  for new Sessions, and so float32 operations will no longer be converted to
  float16 in such Sessions. However, any existing Sessions will continue to
  have the graph rewrite enabled if they were created after
  `enable_mixed_precision_graph_rewrite` was called but before
  `disable_mixed_precision_graph_rewrite` was called.

  This does not undo the effects of loss scaling. Any optimizers wrapped with a
  LossScaleOptimizer will continue to do loss scaling, although this loss
  scaling will no longer be useful if the optimizer is used in new Sessions, as
  the graph rewrite no longer converts the graph to use float16.

  This function is useful for unit testing. A unit test can test using the
  mixed precision graph rewrite, then disable it so future unit tests continue
  using float32. If this is done, unit tests should not share a single Session,
  as `enable_mixed_precision_graph_rewrite` and
  `disable_mixed_precision_graph_rewrite` have no effect on existing Sessions.
  """
  # We only have a separate V1 version of this function, because the V1
  # docstring mentions sessions.
  disable_mixed_precision_graph_rewrite()
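
(For context: a hypothetical sketch of the Session caveat described above, in
TF 1.x style. Each test should create its own Session, since toggling the
rewrite has no effect on Sessions that already exist. The optimizer and graph
contents are illustrative only.)

import tensorflow as tf

opt = tf.train.GradientDescentOptimizer(1.0)
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
with tf.Session() as sess:
  # Ops run through this new Session are eligible for the float16 rewrite.
  pass

tf.train.experimental.disable_mixed_precision_graph_rewrite()
with tf.Session() as sess:
  # A Session created after the call above runs in float32 again; a Session
  # created before it would still have the rewrite enabled.
  pass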


@@ -58,22 +58,15 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
    # to ignore performance and always transform the graph.
    self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
    os.environ[self.IGNORE_PERF_VAR] = '1'
    # Set global variables to their original state, in case other tests
    # modified them.
    mixed_precision_global_state.mixed_precision_is_enabled = False
    mixed_precision_global_state.non_mixed_precision_session_created = False
  def tearDown(self):
    # Set auto_mixed_precision back to its default value.
    config.set_optimizer_experimental_options({'auto_mixed_precision': False})
    # Set the IGNORE_PERF_VAR variable back to its original value.
    if self._original_ignore_perf_value is not None:
      os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
    else:
      del os.environ[self.IGNORE_PERF_VAR]
    # Revert global variables
    mixed_precision_global_state.mixed_precision_is_enabled = False
    mixed_precision_global_state.non_mixed_precision_session_created = False
    mixed_precision.disable_mixed_precision_graph_rewrite()
    super(MixedPrecisionTest, self).tearDown()

  @test_util.run_in_graph_and_eager_modes
@@ -141,6 +134,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
      self.assertEqual(f().numpy(), float('Inf'))
      # Outside a def_function.function, the grappler pass will not be applied.
      self.assertAlmostEqual(overflow_in_float16().numpy(), 2 ** 20)

      # Test disabling mixed precision.
      mixed_precision.disable_mixed_precision_graph_rewrite()
      self.assertEqual(f().numpy(), 2 ** 20)
    else:
      with session.Session() as sess:
        out = overflow_in_float16()
@@ -154,8 +151,19 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
        sess.run(var.initializer)
        self.assertEqual(sess.run(out), float('Inf'))

      # Test disabling mixed precision.
      mixed_precision.disable_mixed_precision_graph_rewrite()
      with session.Session() as sess:
        out = overflow_in_float16()
        sess.run(var.initializer)
        self.assertAlmostEqual(sess.run(out), 2 ** 20)
  @test.mock.patch.object(tf_logging, 'warn')
  def test_warn_if_session_already_exists(self, mock_warn):
    # Set this to False, so Sessions created in previous tests do not trigger
    # the warning.
    mixed_precision_global_state.non_mixed_precision_session_created = False
    with session.Session():
      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
      mock_warn.assert_any_call(
@@ -165,6 +173,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
  @test.mock.patch.object(tf_logging, 'warn')
  def test_do_not_warn_if_session_does_not_already_exist(self, mock_warn):
    # Set this to False, so Sessions created in previous tests do not trigger
    # the warning.
    mixed_precision_global_state.non_mixed_precision_session_created = False
    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
    with session.Session():
      # Make sure the "You already have existing Sessions" warning was not


@@ -8,6 +8,10 @@ tf_module {
name: "PythonState" name: "PythonState"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
  member_method {
    name: "disable_mixed_precision_graph_rewrite"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "enable_mixed_precision_graph_rewrite"
    argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "


@@ -4,6 +4,10 @@ tf_module {
name: "PythonState" name: "PythonState"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
  member_method {
    name: "disable_mixed_precision_graph_rewrite"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "enable_mixed_precision_graph_rewrite"
    argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "