Add disable_mixed_precision_graph_rewrite function for unit testing
PiperOrigin-RevId: 248036105
This commit is contained in:
parent
09ca187c09
commit
20f681b89a
@ -169,3 +169,54 @@ def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
|
|||||||
config.set_optimizer_experimental_options({'auto_mixed_precision': True})
|
config.set_optimizer_experimental_options({'auto_mixed_precision': True})
|
||||||
mixed_precision_global_state.mixed_precision_is_enabled = True
|
mixed_precision_global_state.mixed_precision_is_enabled = True
|
||||||
return opt
|
return opt
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('train.experimental.disable_mixed_precision_graph_rewrite', v1=[])
|
||||||
|
def disable_mixed_precision_graph_rewrite():
|
||||||
|
"""Disables the mixed precision graph rewrite.
|
||||||
|
|
||||||
|
After this is called, the mixed precision graph rewrite will no longer run for
|
||||||
|
tf.functions, and so float32 operations will no longer be converted to
|
||||||
|
float16.
|
||||||
|
|
||||||
|
This does not undo the effects of loss scaling. Any optimizers wrapped with a
|
||||||
|
LossScaleOptimizer will continue to do loss scaling, although this loss
|
||||||
|
scaling will no longer be useful, as the graph rewrite no longer converts
|
||||||
|
tf.functions to use float16.
|
||||||
|
|
||||||
|
This function is useful for unit testing. A unit test can test using the mixed
|
||||||
|
precision graph rewrite, then disable it so future unit tests continue using
|
||||||
|
float32.
|
||||||
|
"""
|
||||||
|
if not mixed_precision_global_state.mixed_precision_is_enabled:
|
||||||
|
tf_logging.warn('disable_mixed_precision_graph_rewrite() called when mixed '
|
||||||
|
'precision is already disabled.')
|
||||||
|
config.set_optimizer_experimental_options({'auto_mixed_precision': False})
|
||||||
|
mixed_precision_global_state.mixed_precision_is_enabled = False
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export(v1=['train.experimental.disable_mixed_precision_graph_rewrite'])
|
||||||
|
def disable_mixed_precision_graph_rewrite_v1():
|
||||||
|
"""Disables the mixed precision graph rewrite.
|
||||||
|
|
||||||
|
After this is called, the mixed precision graph rewrite will no longer run for
|
||||||
|
new Sessions, and so float32 operations will no longer be converted to float16
|
||||||
|
in such Sessions. However, any existing Sessions will continue to have the
|
||||||
|
graph rewrite enabled if they were created after
|
||||||
|
`enable_mixed_precision_graph_rewrite` was called but before
|
||||||
|
`disable_mixed_precision_graph_rewrite` was called.
|
||||||
|
|
||||||
|
This does not undo the effects of loss scaling. Any optimizers wrapped with a
|
||||||
|
LossScaleOptimizer will continue to do loss scaling, although this loss
|
||||||
|
scaling will no longer be useful if the optimizer is used in new Sessions, as
|
||||||
|
the graph rewrite no longer converts the graph to use float16.
|
||||||
|
|
||||||
|
This function is useful for unit testing. A unit tests can test using the
|
||||||
|
mixed precision graph rewrite, then disable it so future unit tests continue
|
||||||
|
using float32. If this is done, unit tests should not share a single session,
|
||||||
|
as `enable_mixed_precision_graph_rewrite` and
|
||||||
|
`disable_mixed_precision_graph_rewrite` have no effect on existing sessions.
|
||||||
|
"""
|
||||||
|
# We only have a separate V1 version of this function, because the V1
|
||||||
|
# docstring mentions sessions.
|
||||||
|
disable_mixed_precision_graph_rewrite()
|
||||||
|
@ -58,22 +58,15 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
|||||||
# to ignore performance and always transform the graph.
|
# to ignore performance and always transform the graph.
|
||||||
self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
|
self._original_ignore_perf_value = os.getenv(self.IGNORE_PERF_VAR)
|
||||||
os.environ[self.IGNORE_PERF_VAR] = '1'
|
os.environ[self.IGNORE_PERF_VAR] = '1'
|
||||||
# Set global variables to their original state, in case other tests modified
|
|
||||||
# them
|
|
||||||
mixed_precision_global_state.mixed_precision_is_enabled = False
|
|
||||||
mixed_precision_global_state.non_mixed_precision_session_created = False
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
# Set auto_mixed_precision back to it's default value.
|
|
||||||
config.set_optimizer_experimental_options({'auto_mixed_precision': False})
|
|
||||||
# Set the IGNORE_PERF_VAR variable back to it's original value.
|
# Set the IGNORE_PERF_VAR variable back to it's original value.
|
||||||
if self._original_ignore_perf_value is not None:
|
if self._original_ignore_perf_value is not None:
|
||||||
os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
|
os.environ[self.IGNORE_PERF_VAR] = self._original_ignore_perf_value
|
||||||
else:
|
else:
|
||||||
del os.environ[self.IGNORE_PERF_VAR]
|
del os.environ[self.IGNORE_PERF_VAR]
|
||||||
# Revert global variables
|
|
||||||
mixed_precision_global_state.mixed_precision_is_enabled = False
|
mixed_precision.disable_mixed_precision_graph_rewrite()
|
||||||
mixed_precision_global_state.non_mixed_precision_session_created = False
|
|
||||||
super(MixedPrecisionTest, self).tearDown()
|
super(MixedPrecisionTest, self).tearDown()
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
@test_util.run_in_graph_and_eager_modes
|
||||||
@ -141,6 +134,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertEqual(f().numpy(), float('Inf'))
|
self.assertEqual(f().numpy(), float('Inf'))
|
||||||
# Outside a def_function.function, the grappler pass will not be applied.
|
# Outside a def_function.function, the grappler pass will not be applied.
|
||||||
self.assertAlmostEqual(overflow_in_float16().numpy(), 2 ** 20)
|
self.assertAlmostEqual(overflow_in_float16().numpy(), 2 ** 20)
|
||||||
|
|
||||||
|
# Test disabling mixed precision.
|
||||||
|
mixed_precision.disable_mixed_precision_graph_rewrite()
|
||||||
|
self.assertEqual(f().numpy(), 2 ** 20)
|
||||||
else:
|
else:
|
||||||
with session.Session() as sess:
|
with session.Session() as sess:
|
||||||
out = overflow_in_float16()
|
out = overflow_in_float16()
|
||||||
@ -154,8 +151,19 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
|||||||
sess.run(var.initializer)
|
sess.run(var.initializer)
|
||||||
self.assertEqual(sess.run(out), float('Inf'))
|
self.assertEqual(sess.run(out), float('Inf'))
|
||||||
|
|
||||||
|
# Test disabling mixed precision.
|
||||||
|
mixed_precision.disable_mixed_precision_graph_rewrite()
|
||||||
|
with session.Session() as sess:
|
||||||
|
out = overflow_in_float16()
|
||||||
|
sess.run(var.initializer)
|
||||||
|
self.assertAlmostEqual(sess.run(out), 2 ** 20)
|
||||||
|
|
||||||
@test.mock.patch.object(tf_logging, 'warn')
|
@test.mock.patch.object(tf_logging, 'warn')
|
||||||
def test_warn_if_session_already_exists(self, mock_warn):
|
def test_warn_if_session_already_exists(self, mock_warn):
|
||||||
|
# Set this to False, so Sessions created in previous tests do not trigger
|
||||||
|
# the warning.
|
||||||
|
mixed_precision_global_state.non_mixed_precision_session_created = False
|
||||||
|
|
||||||
with session.Session():
|
with session.Session():
|
||||||
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
||||||
mock_warn.assert_any_call(
|
mock_warn.assert_any_call(
|
||||||
@ -165,6 +173,10 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
@test.mock.patch.object(tf_logging, 'warn')
|
@test.mock.patch.object(tf_logging, 'warn')
|
||||||
def test_do_not_warn_if_session_does_not_already_exist(self, mock_warn):
|
def test_do_not_warn_if_session_does_not_already_exist(self, mock_warn):
|
||||||
|
# Set this to False, so Sessions created in previous tests do not trigger
|
||||||
|
# the warning.
|
||||||
|
mixed_precision_global_state.non_mixed_precision_session_created = False
|
||||||
|
|
||||||
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
||||||
with session.Session():
|
with session.Session():
|
||||||
# Make sure the "You already have existing Sessions" warning was not
|
# Make sure the "You already have existing Sessions" warning was not
|
||||||
|
@ -8,6 +8,10 @@ tf_module {
|
|||||||
name: "PythonState"
|
name: "PythonState"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "disable_mixed_precision_graph_rewrite"
|
||||||
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "enable_mixed_precision_graph_rewrite"
|
name: "enable_mixed_precision_graph_rewrite"
|
||||||
argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
|
argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
|
||||||
|
@ -4,6 +4,10 @@ tf_module {
|
|||||||
name: "PythonState"
|
name: "PythonState"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "disable_mixed_precision_graph_rewrite"
|
||||||
|
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "enable_mixed_precision_graph_rewrite"
|
name: "enable_mixed_precision_graph_rewrite"
|
||||||
argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
|
argspec: "args=[\'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'dynamic\'], "
|
||||||
|
Loading…
Reference in New Issue
Block a user