Raise mixed precision error in fewer cases.
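Previously, an error was raised whenever both the mixed precision graph rewrite was enabled and a global dtype policy was set. Now the error is raised only if the global dtype policy is a mixed precision policy, i.e. "mixed_float16" or "mixed_bfloat16"; non-mixed policies such as "float64" can be combined with the graph rewrite.
PiperOrigin-RevId: 300826448
Change-Id: If5f32f2aa438c44a67b38c2a756c93840da8cb50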
This commit is contained in:
parent
733e721881
commit
5d8abb7a4e
@ -515,21 +515,20 @@ def policy_defaults_to_floatx():
|
|||||||
return _global_policy is None and base_layer_utils.v2_dtype_behavior_enabled()
|
return _global_policy is None and base_layer_utils.v2_dtype_behavior_enabled()
|
||||||
|
|
||||||
|
|
||||||
def _check_if_mixed_precision_graph_rewrite_is_enabled():
|
def _check_if_mixed_precision_graph_rewrite_is_enabled(policy):
|
||||||
# TODO(reedwm): Update this comment once the Keras API is complete.
|
|
||||||
if mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
|
if mixed_precision_global_state.mixed_precision_graph_rewrite_is_enabled:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The mixed precision policy cannot be set, because the mixed '
|
'The global dtype policy cannot be set to "{policy.name}", because the '
|
||||||
'precision graph rewrite has already been enabled.\n'
|
'mixed precision graph rewrite has already been enabled.\n'
|
||||||
'At most, one of the following functions can be called:\n\n'
|
'At most, one of the following can be called:\n\n'
|
||||||
' 1. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
|
' 1. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
|
||||||
'(You called this first)\n'
|
'(You called this first)\n'
|
||||||
' 2. tf.keras.mixed_precision.experimental.set_policy() (You called '
|
' 2. tf.keras.mixed_precision.experimental.set_policy() with a mixed '
|
||||||
'this second)\n\n'
|
'precision policy (You called this second)\n\n'
|
||||||
'You called both functions, which is an error, because both functions '
|
'You called both functions, which is an error, because both functions '
|
||||||
'enable you to use mixed precision. If in doubt which function to use, '
|
'enable you to use mixed precision. If in doubt which function to use, '
|
||||||
'use the second, as it supports Eager execution and is more '
|
'use the second, as it supports Eager execution and is more '
|
||||||
'customizable.')
|
'customizable.'.format(policy=policy))
|
||||||
|
|
||||||
|
|
||||||
@keras_export('keras.mixed_precision.experimental.set_policy', v1=[])
|
@keras_export('keras.mixed_precision.experimental.set_policy', v1=[])
|
||||||
@ -546,14 +545,15 @@ def set_policy(policy):
|
|||||||
policy: A Policy, or a string that will be converted to a Policy..
|
policy: A Policy, or a string that will be converted to a Policy..
|
||||||
"""
|
"""
|
||||||
global _global_policy
|
global _global_policy
|
||||||
_check_if_mixed_precision_graph_rewrite_is_enabled()
|
|
||||||
if not base_layer_utils.v2_dtype_behavior_enabled():
|
if not base_layer_utils.v2_dtype_behavior_enabled():
|
||||||
raise ValueError('The global policy can only be set in TensorFlow 2')
|
raise ValueError('The global policy can only be set in TensorFlow 2')
|
||||||
if policy is not None and not isinstance(policy, Policy):
|
if policy is not None and not isinstance(policy, Policy):
|
||||||
policy = Policy(policy)
|
policy = Policy(policy)
|
||||||
|
is_mixed_policy = policy is not None and policy.should_cast_variables
|
||||||
|
if is_mixed_policy:
|
||||||
|
_check_if_mixed_precision_graph_rewrite_is_enabled(policy)
|
||||||
_global_policy = policy
|
_global_policy = policy
|
||||||
mixed_precision_global_state.using_default_mixed_precision_policy = (
|
mixed_precision_global_state.using_mixed_precision_policy = is_mixed_policy
|
||||||
_global_policy is None)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(reedwm): Make this thread local
|
# TODO(reedwm): Make this thread local
|
||||||
|
@ -293,9 +293,11 @@ class PolicyTest(test.TestCase):
|
|||||||
mixed_precision.enable_mixed_precision_graph_rewrite(
|
mixed_precision.enable_mixed_precision_graph_rewrite(
|
||||||
gradient_descent.SGD(1.))
|
gradient_descent.SGD(1.))
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, 'the mixed precision graph rewrite has already been '
|
ValueError, 'cannot be set to "mixed_float16", .* the mixed '
|
||||||
'enabled'):
|
'precision graph rewrite has already been enabled'):
|
||||||
mp_policy.set_policy('mixed_float16')
|
mp_policy.set_policy('mixed_float16')
|
||||||
|
with mp_policy.policy_scope('float64'):
|
||||||
|
pass # Non-mixed policies are allowed
|
||||||
finally:
|
finally:
|
||||||
mixed_precision.disable_mixed_precision_graph_rewrite()
|
mixed_precision.disable_mixed_precision_graph_rewrite()
|
||||||
|
|
||||||
|
@ -336,15 +336,15 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'):
|
|||||||
def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
|
def _enable_mixed_precision_graph_rewrite_base(opt, loss_scale,
|
||||||
use_v1_behavior):
|
use_v1_behavior):
|
||||||
"""Enables mixed precision. See `enable_mixed_precision_graph_rewrite`."""
|
"""Enables mixed precision. See `enable_mixed_precision_graph_rewrite`."""
|
||||||
if not mixed_precision_global_state.using_default_mixed_precision_policy:
|
if mixed_precision_global_state.using_mixed_precision_policy:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'The mixed precision graph rewrite cannot be enabled, because a keras '
|
'The mixed precision graph rewrite cannot be enabled, because the '
|
||||||
'mixed precision Policy has been set. At most, one of the following '
|
'global Keras dtype Policy has been set to a mixed precision policy. '
|
||||||
'functions can be called:\n\n'
|
'At most, one of the following can be called:\n\n'
|
||||||
' 1. tf.keras.mixed_precision.experimental.set_policy() (You called '
|
' 1. tf.keras.mixed_precision.experimental.set_policy() with a mixed '
|
||||||
'this first)\n'
|
'precision policy (You called this first)\n\n'
|
||||||
' 2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
|
' 2. tf.train.experimental.enable_mixed_precision_graph_rewrite() '
|
||||||
'(You called this second)\n\n'
|
'(You called this second)\n'
|
||||||
'You called both functions, which is an error, because both functions '
|
'You called both functions, which is an error, because both functions '
|
||||||
'enable you to use mixed precision. If in doubt which function to use, '
|
'enable you to use mixed precision. If in doubt which function to use, '
|
||||||
'use the first, as it supports Eager execution and is more '
|
'use the first, as it supports Eager execution and is more '
|
||||||
|
@ -33,7 +33,7 @@ mixed_precision_graph_rewrite_is_enabled = False
|
|||||||
# Session has already been created.
|
# Session has already been created.
|
||||||
non_mixed_precision_session_created = False
|
non_mixed_precision_session_created = False
|
||||||
|
|
||||||
# Whether the default tf.keras.mixed_precision.experimental.Policy is in effect.
|
# Whether the global tf.keras.mixed_precision.experimental.Policy uses mixed
|
||||||
# Used to raise an error message if both a non-default Policy and the graph
|
# precision. Used to raise an error message if both a mixed Policy and the graph
|
||||||
# rewrite are used at the same time.
|
# rewrite are used at the same time.
|
||||||
using_default_mixed_precision_policy = True
|
using_mixed_precision_policy = False
|
||||||
|
@ -194,10 +194,13 @@ class MixedPrecisionTest(test.TestCase, parameterized.TestCase):
|
|||||||
def test_error_if_policy_is_set(self):
|
def test_error_if_policy_is_set(self):
|
||||||
with policy.policy_scope('mixed_float16'):
|
with policy.policy_scope('mixed_float16'):
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, 'a keras mixed precision Policy has been set'):
|
ValueError, 'the global Keras dtype Policy has been set'):
|
||||||
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
||||||
# Test no error is thrown when the policy is current the default.
|
# Test no error is thrown when the policy is currently the default.
|
||||||
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
||||||
|
# Test no error is thrown when the policy is a non-mixed policy.
|
||||||
|
with policy.policy_scope('float64'):
|
||||||
|
enable_mixed_precision_graph_rewrite(gradient_descent_v2.SGD(1.0))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
Reference in New Issue
Block a user