Fork the Keras-related mixed precision test into the keras folder.

PiperOrigin-RevId: 316183309
Change-Id: I89aca251b05817811a382e098cf0e882714dbd45
parent 06f3120f42
commit 710cc7e0e5
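
For context, the forked test covers TensorFlow's mixed precision graph rewrite
as used with Keras (V2) optimizers. Below is a minimal sketch of the API under
test, assuming the public TF 2.x endpoints of this era (tf.train.experimental.*),
not the internal modules the test itself imports:

    import tensorflow as tf

    opt = tf.keras.optimizers.SGD(1.0)
    # Wraps the optimizer in a LossScaleOptimizer and enables the
    # auto_mixed_precision grappler pass for subsequently built graphs.
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        opt, loss_scale=123.)
    assert isinstance(
        opt, tf.keras.mixed_precision.experimental.LossScaleOptimizer)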
@@ -211,6 +211,25 @@ cuda_py_test(
     ],
 )
 
+cuda_py_test(
+    name = "mixed_precision_graph_rewrite_test",
+    size = "small",
+    srcs = ["mixed_precision_graph_rewrite_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":loss_scale_optimizer",
+        ":policy",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:config",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:mixed_precision",
+        "//tensorflow/python:tf2",
+        "//tensorflow/python/keras:testing_utils",
+        "//tensorflow/python/keras/optimizer_v2",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
+
 py_library(
     name = "test_util",
     srcs = ["test_util.py"],
@@ -0,0 +1,97 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests Keras integration with enable_mixed_precision_graph_rewrite()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+
+from tensorflow.python import tf2
+from tensorflow.python.framework import config
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2
+from tensorflow.python.keras.mixed_precision.experimental import policy
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+from tensorflow.python.platform import test
+from tensorflow.python.training.experimental import mixed_precision
+
+
+if tf2.enabled():
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite)
+else:
+  enable_mixed_precision_graph_rewrite = (
+      mixed_precision.enable_mixed_precision_graph_rewrite_v1)
+
+
+class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
+
+  IGNORE_PERF_VAR = 'TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'
+
+  def setUp(self):
+    super(MixedPrecisionTest, self).setUp()
+    # Enable the tests to be run on pre-Volta GPUs by telling the grappler pass
+    # to ignore performance and always transform the graph.
+    self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
+    os.environ[self.IGNORE_PERF_VAR] = '1'
+
+  def tearDown(self):
+    # Set the IGNORE_PERF_VAR variable back to its original value.
+    if self._original_ignore_perf_value is not None:
+      os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
+    else:
+      del os.environ[self.IGNORE_PERF_VAR]
+
+    mixed_precision.disable_mixed_precision_graph_rewrite()
+    super(MixedPrecisionTest, self).tearDown()
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_wrap_optimizer(self):
+    opt = gradient_descent_v2.SGD(1.0)
+    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
+    self.assertIsInstance(
+        opt, loss_scale_optimizer_v2.LossScaleOptimizer)
+    self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_optimizer_errors(self):
+    opt = gradient_descent_v2.SGD(1.0)
+    opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic')
+    with self.assertRaisesRegexp(ValueError,
+                                 '"opt" must not already be an instance of a '
+                                 'LossScaleOptimizer.'):
+      enable_mixed_precision_graph_rewrite(opt)
+    self.assertFalse(config.get_optimizer_experimental_options()
+                     .get('auto_mixed_precision', False))
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_error_if_policy_is_set(self):
+    with policy.policy_scope('mixed_float16'):
+      with self.assertRaisesRegexp(
+          ValueError, 'the global Keras dtype Policy has been set'):
+        enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+    # Test no error is thrown when the policy is currently the default.
+    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+    # Test no error is thrown when the policy is a non-mixed policy.
+    with policy.policy_scope('float64'):
+      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+
+
+if __name__ == '__main__':
+  test.main()
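
The test_error_if_policy_is_set case above guards against combining the two
mixed precision APIs. A hedged, user-level sketch of the same conflict,
assuming the experimental Keras policy endpoints available at this time:

    import tensorflow as tf

    # Setting a global mixed dtype policy and then requesting the graph
    # rewrite is expected to raise, since the two APIs must not be mixed.
    tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
    try:
      tf.train.experimental.enable_mixed_precision_graph_rewrite(
          tf.keras.optimizers.SGD(1.0))
    except ValueError as e:
      print(e)  # mentions that the global Keras dtype Policy has been set
    finally:
      tf.keras.mixed_precision.experimental.set_policy('float32')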
@@ -28,10 +28,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import config
 from tensorflow.python.framework import test_util
-from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2
-from tensorflow.python.keras.mixed_precision.experimental import policy
-from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
@@ -80,12 +76,6 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
         opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer)
     self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
 
-    opt = gradient_descent_v2.SGD(1.0)
-    opt = enable_mixed_precision_graph_rewrite(opt, 123.)
-    self.assertIsInstance(
-        opt, loss_scale_optimizer_v2.LossScaleOptimizer)
-    self.assertEqual(self.evaluate(opt._loss_scale()), 123.)
-
   @test_util.run_in_graph_and_eager_modes
   def test_optimizer_errors(self):
     opt = 1
@@ -110,19 +100,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     self.assertFalse(config.get_optimizer_experimental_options()
                      .get('auto_mixed_precision', False))
 
-    opt = gradient_descent_v2.SGD(1.0)
-    opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic')
-    with self.assertRaisesRegexp(ValueError,
-                                 '"opt" must not already be an instance of a '
-                                 'LossScaleOptimizer.'):
-      enable_mixed_precision_graph_rewrite(opt)
-    self.assertFalse(config.get_optimizer_experimental_options()
-                     .get('auto_mixed_precision', False))
-
   @test_util.run_gpu_only
   @test_util.run_in_graph_and_eager_modes
   def test_grappler_pass_enabled(self):
-    opt = gradient_descent_v2.SGD(1.0)
+    opt = gradient_descent_v1.GradientDescentOptimizer(1.0)
     enable_mixed_precision_graph_rewrite(opt, 123.)
 
     var = variables.Variable([[1.0]])
@@ -168,7 +149,8 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     mixed_precision_global_state.non_mixed_precision_session_created = False
 
     with session.Session():
-      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+      enable_mixed_precision_graph_rewrite(
+          gradient_descent_v1.GradientDescentOptimizer(1.0))
       mock_warn.assert_any_call(
           'You already have existing Sessions that do not use mixed precision. '
           'enable_mixed_precision_graph_rewrite() will not affect these '
@@ -180,7 +162,8 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     # the warning.
     mixed_precision_global_state.non_mixed_precision_session_created = False
 
-    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
+    enable_mixed_precision_graph_rewrite(
+        gradient_descent_v1.GradientDescentOptimizer(1.0))
     with session.Session():
       # Make sure the "You already have existing Sessions" warning was not
       # issued, since the Session was only created after
@@ -190,18 +173,6 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
     self.assertNotIn('You already have existing Sessions that do not use '
                      'mixed precision', msg)
 
-  @testing_utils.enable_v2_dtype_behavior
-  def test_error_if_policy_is_set(self):
-    with policy.policy_scope('mixed_float16'):
-      with self.assertRaisesRegexp(
-          ValueError, 'the global Keras dtype Policy has been set'):
-        enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
-    # Test no error is thrown when the policy is currently the default.
-    enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
-    # Test no error is thrown when the policy is a non-mixed policy.
-    with policy.policy_scope('float64'):
-      enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
-
 
 if __name__ == '__main__':
   test.main()
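
With the Keras cases forked out, the remaining test exercises only the V1
optimizer path. A hedged sketch of that path, assuming the tf.compat.v1
endpoints (graph mode):

    import tensorflow.compat.v1 as tf

    opt = tf.train.GradientDescentOptimizer(1.0)
    # For a V1 optimizer the rewrite returns a
    # MixedPrecisionLossScaleOptimizer rather than the Keras wrapper.
    opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, 123.)
    assert isinstance(
        opt, tf.train.experimental.MixedPrecisionLossScaleOptimizer)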