From 20f681b89ab2aef33126826f625a31d8e547270f Mon Sep 17 00:00:00 2001 From: Reed Wanderman-Milne Date: Mon, 13 May 2019 16:38:52 -0700 Subject: [PATCH] Add disable_mixed_precision_graph_rewrite function for unit testing PiperOrigin-RevId: 248036105 --- .../training/experimental/mixed_precision.py | 51 +++++++++++++++++++ .../experimental/mixed_precision_test.py | 30 +++++++---- .../v1/tensorflow.train.experimental.pbtxt | 4 ++ .../v2/tensorflow.train.experimental.pbtxt | 4 ++ 4 files changed, 80 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/training/experimental/mixed_precision.py b/tensorflow/python/training/experimental/mixed_precision.py index 9d4fa4b605e..8bc20f0c4c2 100644 --- a/tensorflow/python/training/experimental/mixed_precision.py +++ b/tensorflow/python/training/experimental/mixed_precision.py @@ -169,3 +169,54 @@ def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale, config.set_optimizer_experimental_options({'auto_mixed_precision': True}) mixed_precision_global_state.mixed_precision_is_enabled = True return opt + + +@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[]) +def disable_mixed_precision_graph_rewrite(): + """Disables the mixed precision graph rewrite. + + After this is called, the mixed precision graph rewrite will no longer run for + tf.functions, and so float32 operations will no longer be converted to + float16. + + This does not undo the effects of loss scaling. Any optimizers wrapped with a + LossScaleOptimizer will continue to do loss scaling, although this loss + scaling will no longer be useful, as the graph rewrite no longer converts + tf.functions to use float16. + + This function is useful for unit testing. A unit test can test using the mixed + precision graph rewrite, then disable it so future unit tests continue using + float32. + """ + if not mixed_precision_global_state.mixed_precision_is_enabled: + tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed ' + 'precision is already disabled.') + config.set_optimizer_experimental_options({'auto_mixed_precision': False}) + mixed_precision_global_state.mixed_precision_is_enabled = False + + +@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite']) +def disable_mixed_precision_graph_rewrite_v1(): + """Disables the mixed precision graph rewrite. + + After this is called, the mixed precision graph rewrite will no longer run for + new Sessions, and so float32 operations will no longer be converted to float16 + in such Sessions. However, any existing Sessions will continue to have the + graph rewrite enabled if they were created after + `enable_mixed_precision_graph_rewrite` was called but before + `disable_mixed_precision_graph_rewrite` was called. + + This does not undo the effects of loss scaling. Any optimizers wrapped with a + LossScaleOptimizer will continue to do loss scaling, although this loss + scaling will no longer be useful if the optimizer is used in new Sessions, as + the graph rewrite no longer converts the graph to use float16. + + This function is useful for unit testing. A unit tests can test using the + mixed precision graph rewrite, then disable it so future unit tests continue + using float32. If this is done, unit tests should not share a single session, + as `enable_mixed_precision_graph_rewrite` and + `disable_mixed_precision_graph_rewrite` have no effect on existing sessions. + """ + # We only have a separate V1 version of this function, because the V1 + # docstring mentions sessions. + disable_mixed_precision_graph_rewrite() diff --git a/tensorflow/python/training/experimental/mixed_precision_test.py b/tensorflow/python/training/experimental/mixed_precision_test.py index a52e8471e03..1ee3fe93d57 100644 --- a/tensorflow/python/training/experimental/mixed_precision_test.py +++ b/tensorflow/python/training/experimental/mixed_precision_test.py @@ -58,22 +58,15 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase): # to ignore performance and always transform the graph. self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR) os.environ[self.IGNORE_PERF_VAR] = '1' - # Set global variables to their original state, in case other tests modified - # them - mixed_precision_global_state.mixed_precision_is_enabled = False - mixed_precision_global_state.non_mixed_precision_session_created = False def tearDown(self): - # Set auto_mixed_precision back to it's default value. - config.set_optimizer_experimental_options({'auto_mixed_precision': False}) # Set the IGNORE_PERF_VAR variable back to it's original value. if self._original_ignore_perf_value is not None: os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value else: del os.environ[self.IGNORE_PERF_VAR] - # Revert global variables - mixed_precision_global_state.mixed_precision_is_enabled = False - mixed_precision_global_state.non_mixed_precision_session_created = False + + mixed_precision.disable_mixed_precision_graph_rewrite() super(MixedPrecisionTest, self).tearDown() @test_util.run_in_graph_and_eager_modes @@ -141,6 +134,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase): self.assertEqual(f().numpy(), float('Inf')) # Outside a def_function.function, the grappler pass will not be applied. self.assertAlmostEqual(overflow_in_float16().numpy(), 2 ** 20) + + # Test disabling mixed precision. + mixed_precision.disable_mixed_precision_graph_rewrite() + self.assertEqual(f().numpy(), 2 ** 20) else: with session.Session() as sess: out = overflow_in_float16() @@ -154,8 +151,19 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase): sess.run(var.initializer) self.assertEqual(sess.run(out), float('Inf')) + # Test disabling mixed precision. + mixed_precision.disable_mixed_precision_graph_rewrite() + with session.Session() as sess: + out = overflow_in_float16() + sess.run(var.initializer) + self.assertAlmostEqual(sess.run(out), 2 ** 20) + @test.mock.patch.object(tf_logging, 'warn') def test_warn_if_session_already_exists(self, mock_warn): + # Set this to False, so Sessions created in previous tests do not trigger + # the warning. + mixed_precision_global_state.non_mixed_precision_session_created = False + with session.Session(): enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0)) mock_warn.assert_any_call( @@ -165,6 +173,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase): @test.mock.patch.object(tf_logging, 'warn') def test_do_not_warn_if_session_does_not_already_exist(self, mock_warn): + # Set this to False, so Sessions created in previous tests do not trigger + # the warning. + mixed_precision_global_state.non_mixed_precision_session_created = False + enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0)) with session.Session(): # Make sure the "You already have existing Sessions" warning was not diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt index ea5403d44d1..6d41033370e 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt @@ -8,6 +8,10 @@ tf_module { name: "PythonState" mtype: "" } + member_method { + name: "disable_mixed_precision_graph_rewrite" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "enable_mixed_precision_graph_rewrite" argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt index 998a01e378b..aab0b8d7ecb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.experimental.pbtxt @@ -4,6 +4,10 @@ tf_module { name: "PythonState" mtype: "" } + member_method { + name: "disable_mixed_precision_graph_rewrite" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } member_method { name: "enable_mixed_precision_graph_rewrite" argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "