From 710cc7e0e5335e052d7e43542d6fe57232174d90 Mon Sep 17 00:00:00 2001
From: Scott Zhu
Date: Fri, 12 Jun 2020 14:57:08 -0700
Subject: [PATCH] Fork the keras related MP test to keras folder.

PiperOrigin-RevId: 316183309
Change-Id: I89aca251b05817811a382e098cf0e882714dbd45
---
 .../keras/mixed_precision/experimental/BUILD  | 19 ++++
 .../mixed_precision_graph_rewrite_test.py     | 97 +++++++++++++++++++
 .../experimental/mixed_precision_test.py      | 39 +-------
 3 files changed, 121 insertions(+), 34 deletions(-)
 create mode 100644 tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py

diff --git a/tensorflow/python/keras/mixed_precision/experimental/BUILD b/tensorflow/python/keras/mixed_precision/experimental/BUILD
index ec89fa0c987..024b093c469 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/BUILD
+++ b/tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -211,6 +211,25 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "mixed_precision_graph_rewrite_test",
+    size = "small",
+    srcs = ["mixed_precision_graph_rewrite_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":loss_scale_optimizer",
+        ":policy",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:config",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:mixed_precision",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python/keras:testing_utils",
+        "//tensorflow/python/keras/optimizer_v2",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 py_library(
     name = "test_util",
     srcs = ["test_util.py"],
diff --git a/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py b/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py
new file mode 100644
index 00000000000..d7454a89bad
--- /dev/null
+++ b/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py
@@ -0,0 +1,97 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests Keras integration with enable_mixed_precision_graph_rewrite()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+
+from tensorflow.python import tf2
+from tensorflow.python.framework import config
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2
+from tensorflow.python.keras.mixed_precision.experimental import policy
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+from tensorflow.python.platform import test
+from tensorflow.python.training.experimental import mixed_precision
+
+
+if tf2.enabled():
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite)
+else:
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite_v1)
+
+
+class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
+
+  IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
+
+  def setUp(self):
+    super(MixedPrecisionTest, self).setUp()
+    # Enable the tests to be run on pre-Volta GPUs by telling the grappler pass
+    # to ignore performance and always transform the graph.
+    self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
+    os.environ[self.IGNORE_PERF_VAR] = '1'
+
+  def tearDown(self):
+    # Set the IGNORE_PERF_VAR variable back to its original value.
+    if self._original_ignore_perf_value is not None:
+      os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
+    else:
+      del os.environ[self.IGNORE_PERF_VAR]
+
+    mixed_precision.disable_mixed_precision_graph_rewrite()
+    super(MixedPrecisionTest, self).tearDown()
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_wrap_optimizer(self):
+    opt = gradient_descent_v2.SGD(1.0)
+    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
+    self.assertIsInstance(
+        opt, loss_scale_optimizer_v2.LossScaleOptimizer)
+    self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_optimizer_errors(self):
+    opt = gradient_descent_v2.SGD(1.0)
+    opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic')
+    with self.assertRaisesRegexp(ValueError,
+                                 '"opt" must not already be an instance of a '
+                                 'LossScaleOptimizer.'):
+      enable_mixed_precision_graph_rewrite(opt)
+    self.assertFalse(config.get_optimizer_experimental_options()
+                     .get('auto_mixed_precision', False))
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_error_if_policy_is_set(self):
+    with policy.policy_scope('mixed_float16'):
+      with self.assertRaisesRegexp(
+          ValueError, 'the global Keras dtype Policy has been set'):
+        enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+    # Test no error is thrown when the policy is currently the default.
+    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+    # Test no error is thrown when the policy is a non-mixed policy.
+    with policy.policy_scope('float64'):
+      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/training/experimental/mixed_precision_test.py b/tensorflow/python/training/experimental/mixed_precision_test.py
index 65d5b690b32..2ce93245413 100644
--- a/tensorflow/python/training/experimental/mixed_precision_test.py
+++ b/tensorflow/python/training/experimental/mixed_precision_test.py
@@ -28,10 +28,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import config
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2
-from tensorflow.python.keras.mixed_precision.experimental import policy
-from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
@@ -80,12 +76,6 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
         opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer)
     self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
 
-    opt = gradient_descent_v2.SGD(1.0)
-    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
-    self.assertIsInstance(
-        opt, loss_scale_optimizer_v2.LossScaleOptimizer)
-    self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
-
   @test_util.run_in_graph_and_eager_modes
   def test_optimizer_errors(self):
     opt = 1
@@ -110,19 +100,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
 
-    opt = gradient_descent_v2.SGD(1.0)
-    opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic')
-    with self.assertRaisesRegexp(ValueError,
-                                 '"opt" must not already be an instance of a '
-                                 'LossScaleOptimizer.'):
-      enable_mixed_precision_graph_rewrite(opt)
-    self.assertFalse(config.get_optimizer_experimental_options()
-                     .get('auto_mixed_precision', False))
-
   @test_util.run_gpu_only
   @test_util.run_in_graph_and_eager_modes
   def test_grappler_pass_enabled(self):
-    opt = gradient_descent_v2.SGD(1.0)
+    opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
     enable_mixed_precision_graph_rewrite(opt, 123.)
     var = variables.Variable([[1.0]])
 
@@ -168,7 +149,8 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     mixed_precision_global_state.non_mixed_precision_session_created = False
 
     with session.Session():
-      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+      enable_mixed_precision_graph_rewrite(
+          gradient_descent_v1.GradientDescentOptimizer(1.0))
       mock_warn.assert_any_call(
           'You already have existing Sessions that do not use mixed precision. '
           'enable_mixed_precision_graph_rewrite() will not affect these '
@@ -180,7 +162,8 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     # the warning.
     mixed_precision_global_state.non_mixed_precision_session_created = False
 
-    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+    enable_mixed_precision_graph_rewrite(
+        gradient_descent_v1.GradientDescentOptimizer(1.0))
     with session.Session():
       # Make sure the "You already have existing Sessions" warning was not
       # issued, since the Session was only created after
@@ -190,18 +173,6 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
       self.assertNotIn('You already have existing Sessions that do not use '
                        'mixed precision', msg)
 
-  @testing_utils.enable_v2_dtype_behavior
-  def test_error_if_policy_is_set(self):
-    with policy.policy_scope('mixed_float16'):
-      with self.assertRaisesRegexp(
-          ValueError, 'the global Keras dtype Policy has been set'):
-        enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
-    # Test no error is thrown when the policy is currently the default.
-    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
-    # Test no error is thrown when the policy is a non-mixed policy.
-    with policy.policy_scope('float64'):
-      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
-
 
 if __name__ == '__main__':
   test.main()
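
Usage note: the tests forked here all exercise enable_mixed_precision_graph_rewrite(),
which wraps an optimizer in a LossScaleOptimizer and enables the
auto_mixed_precision grappler pass. Below is a minimal sketch of the pattern
under test, assuming a TF 2.x build where the function is exported as
tf.train.experimental.enable_mixed_precision_graph_rewrite; the variable names
are illustrative, not taken from this patch:

    import tensorflow as tf

    # Wrapping a Keras optimizer returns a LossScaleOptimizer and turns on
    # the auto_mixed_precision graph-rewrite option globally.
    opt = tf.keras.optimizers.SGD(1.0)
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        opt, loss_scale=123.)

    # Re-wrapping an optimizer that is already a LossScaleOptimizer raises a
    # ValueError (covered by test_optimizer_errors), as does calling the
    # rewrite while a global 'mixed_float16' Keras dtype policy is set
    # (covered by test_error_if_policy_is_set).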