diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py
index a0eb11164e7..5e8ebf4fa47 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py
@@ -55,28 +55,21 @@ class Policy(object):
   computation dtype to avoid type errors.
 
   Policies also have a `tf.train.experimental.LossScale` instance, which is used
-  by Models to performance loss scaling. Layers which are not Models ignore
-  the loss scale.
+  by `tf.keras.Model`s to perform loss scaling. Loss scaling is only done by
+  Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models
+  ignore the loss scale.
 
   Policies are constructed by passing a string to the constructor, e.g.
   `tf.keras.mixed_precision.experimental.Policy('float32')`. The string
-  determines the compute and variable dtypes. Currently, it can be one of
-  in one of the following forms:
+  determines the compute and variable dtypes. It can be one of the following:
 
   * Any dtype name, such as 'float32' or 'float64'. Both the variable and
-    compute dtypes will be that dtype.
-  * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute dtype
-    will be <dtype>, while the variable dtype is float32. This can be used for
-    mixed precision, which uses float16 or bfloat16 for most computations, and
-    float32 for variables, but it is recommended to use the 'mixed_float16' or
-    'mixed_bfloat16' policies instead.
-  * 'mixed_float16' or 'mixed_bfloat16': Similar to
-    'float16_with_float32_vars' or 'bfloat16_with_float32_vars' respectively.
-    'mixed_float16' is identical to 'float16_with_float32_vars' except the
-    loss_scale is dynamic by default. 'mixed_bfloat16' is currently identical
-    to 'bfloat16_with_float32_vars'. More changes may be added to these mixed
-    policies in the future, to further differentiate them from
-    [b]float16_with_float32_vars.
+    compute dtypes will be that dtype. No loss scaling is done by default.
+  * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
+    bfloat16, while the variable dtype is float32. These policies are used for
+    mixed precision training. With 'mixed_float16', a dynamic loss scale is
+    used by default. 'mixed_bfloat16' does no loss scaling by default, as loss
+    scaling is unnecessary with bfloat16.
 
   ### How to use mixed precision in layers with Policies
 
@@ -118,6 +111,14 @@ class Policy(object):
   constructors in the `dtype` argument instead of policies, but only if the
   string is convertible to a dtype.
 
+  Note the 'mixed_float16' policy will apply loss scaling by default in
+  `Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., a
+  custom training loop is used) and 'mixed_float16' is used, the loss scale must
+  be manually applied. See
+  `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For
+  'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be
+  manually applied.
+
   ### The deprecated "infer" policy
 
   In addition to a dtype or "<dtype>_with_float32_vars", a policy can also be
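Note: the added paragraph above says the loss scale must be applied manually
when 'mixed_float16' is used with a custom training loop. A minimal sketch of
what that looks like, assuming the TF 2.0-era experimental API (the model,
data, and loss function below are hypothetical placeholders, not part of this
patch):

    import tensorflow as tf

    policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    tf.keras.mixed_precision.experimental.set_policy(policy)

    # Placeholder model and data; any Keras model would do here.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    loss_fn = tf.keras.losses.MeanSquaredError()
    features = tf.random.normal([8, 4])
    labels = tf.random.normal([8, 10])

    # Wrap the optimizer so the loss can be scaled and gradients unscaled.
    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        tf.keras.optimizers.SGD(), loss_scale=policy.loss_scale)

    with tf.GradientTape() as tape:
      predictions = model(features)  # computed in float16 under this policy
      # Cast to float32 before the loss; keeps the reduction numerically safe.
      loss = loss_fn(labels, tf.cast(predictions, tf.float32))
      scaled_loss = optimizer.get_scaled_loss(loss)  # multiply by loss scale
    grads = tape.gradient(scaled_loss, model.trainable_variables)
    grads = optimizer.get_unscaled_gradients(grads)  # divide the scale out
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

With 'mixed_bfloat16', the same loop works with the unwrapped optimizer, since
no loss scaling is needed.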
@@ -129,7 +130,13 @@ class Policy(object):
   the dtype of the first input.
 
   Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy
-  that infers the compute dtype, but not the variable dtype.
+  that infers the compute dtype, but not the variable dtype. Once a layer with
+  an "infer_with_float32_vars" policy is called for the first time, the layer's
+  policy will change to "<dtype>_with_float32_vars", where <dtype> is the dtype
+  of the first input. These policies force variables to be in float32.
+
+  Warning: Policies ending in "_with_float32_vars" will be removed in
+  TensorFlow 2.1. Please use "mixed_float16" or "mixed_bfloat16" instead.
 
   In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are
   available.
@@ -137,6 +144,12 @@ class Policy(object):
 
   # TODO(reedwm): Replace link in above docstring with a version that is more
   # TensorFlow-specific, and that also mentions bfloat16.
 
+  # If True, warn when a policy is created whose name ends in
+  # "_with_float32_vars". We always want to warn when a user creates such a
+  # policy, but when TensorFlow creates a policy internally, it suppresses the
+  # warning by setting this to False when creating the policy.
+  _warn_about_float32_vars = True
+
   def __init__(self, name, loss_scale=USE_DEFAULT):
     """Constructs the policy.
 
@@ -148,21 +161,12 @@ class Policy(object):
       name: A string. Can be one of the following values:
         * Any dtype name, such as 'float32' or 'float64'. Both the variable and
           compute dtypes will be that dtype.
-        * '<dtype>_with_float32_vars', where <dtype> is any dtype. The compute
-          dtype will be <dtype>, while the variable dtype is float32. This can
-          be used for mixed precision, which uses float16 or bfloat16 for most
-          computations, and float32 for variables, but it is recommended to use
-          the 'mixed_float16' or 'mixed_bfloat16' policies instead.
-        * 'mixed_float16' or 'mixed_bfloat16': Similar to
-          'float16_with_float32_vars' or 'bfloat16_with_float32_vars'
-          respectively. 'mixed_float16' is identical to
-          'float16_with_float32_vars' except the loss_scale is dynamic by
-          default. 'mixed_bfloat16' is currently identical to
-          'bfloat16_with_float32_vars'. More changes may be added to these mixed
-          policies in the future, to further differentiate them from
-          [b]float16_with_float32_vars.
-        * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the
-          computation dtype from the input dtype.
+        * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or
+          bfloat16, while the variable dtype is float32. With 'mixed_float16',
+          a dynamic loss scale is used. These policies are used for mixed
+          precision training.
+        * 'infer' (deprecated): Infer the compute and variable dtype from the
+          input dtype.
       loss_scale: A `tf.train.experimental.LossScale`, or a value convertible
         to one such as "dynamic". Defaults to using no loss scaling unless
         `name` is "mixed_float16", in which case this defaults to "dynamic". Only
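For reference, the name-to-dtype mapping the docstring above describes can be
seen directly on a constructed policy. A hedged illustration using the module
path from this patch (exact `loss_scale` repr may differ by version):

    from tensorflow.python.keras.mixed_precision.experimental import (
        policy as mp_policy)

    p = mp_policy.Policy('mixed_float16')
    assert p.compute_dtype == 'float16'
    assert p.variable_dtype == 'float32'
    assert p.loss_scale is not None  # a dynamic loss scale by default

    p = mp_policy.Policy('float64')
    assert p.compute_dtype == 'float64'
    assert p.variable_dtype == 'float64'
    assert p.loss_scale is None  # plain dtype names: no loss scaling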
@@ -185,6 +189,18 @@ class Policy(object):
     self._name = name
     self._compute_dtype, self._variable_dtype = self._parse_name(name)
 
+    if name.endswith('_with_float32_vars') and self._warn_about_float32_vars:
+      warning = ("WARNING: The '%s' policy is deprecated and will be removed "
+                 "in TensorFlow 2.1." % name)
+      if name == 'infer_with_float32_vars':
+        warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' "
+                    "policy instead.")
+      elif name == 'float16_with_float32_vars':
+        warning += " Please use the 'mixed_float16' policy instead."
+      elif name == 'bfloat16_with_float32_vars':
+        warning += " Please use the 'mixed_bfloat16' policy instead."
+      tf_logging.warn(warning)
+
     if loss_scale == USE_DEFAULT:
       loss_scale = 'dynamic' if name == 'mixed_float16' else None
     if loss_scale and self._compute_dtype not in (None, 'float16'):
@@ -221,10 +237,10 @@ class Policy(object):
     try:
       base_dtype = dtypes.as_dtype(base_name).name
     except TypeError:
-      error = ('Cannot convert value %s to a mixed precision Policy. '
-               'Valid policies include include those in the form "<dtype>" '
-               'and "<dtype>_with_float32_vars", where <dtype> is the name '
-               'of a dtype.' % (name,))
+      error = ("Cannot convert value %s to a mixed precision Policy. "
+               "Valid policies include 'mixed_float16', "
+               "'mixed_bfloat16', and the name of any dtype such as "
+               "'float32'." % (name,))
       if float32_vars:
         error += (' The value %s ends with _with_float32_vars, but %s cannot '
                   'be converted to a DType' % (name, base_name))
@@ -337,7 +353,11 @@ def with_input_dtype(policy, dtype):
   # Policies without a compute dtype are either "infer" or
   # "infer_with_float32_vars", so the variable_dtype must be float32 here.
   assert policy.variable_dtype == 'float32'
-  return Policy(dtype + '_with_float32_vars')
+  try:
+    Policy._warn_about_float32_vars = False  # pylint: disable=protected-access
+    return Policy(dtype + '_with_float32_vars')
+  finally:
+    Policy._warn_about_float32_vars = True  # pylint: disable=protected-access
 
 
 # The current global policy in effect. If None, it means the current value of
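A note on the try/finally in `with_input_dtype` above: it guarantees
`_warn_about_float32_vars` is restored even if the `Policy` constructor
raises. The same idiom could be packaged as a context manager; a hypothetical
sketch, not part of this patch (assumes `Policy` is in scope, i.e. it lives
in policy.py):

    import contextlib

    @contextlib.contextmanager
    def _suppress_float32_vars_warning():
      # Silence the deprecation warning for internally created policies,
      # restoring the flag even if construction fails.
      Policy._warn_about_float32_vars = False  # pylint: disable=protected-access
      try:
        yield
      finally:
        Policy._warn_about_float32_vars = True  # pylint: disable=protected-access

    # with_input_dtype could then be written as:
    #   with _suppress_float32_vars_warning():
    #     return Policy(dtype + '_with_float32_vars')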
diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
index f1c2504a990..09b25960f53 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py
@@ -169,11 +169,49 @@ class PolicyTest(test.TestCase):
         'policies. You passed loss_scale=2.0 for policy float32. Consider '
         'not passing any loss_scale instead.')
 
-    for policy_name in 'float16', 'float16_with_float32_vars', 'mixed_float16':
+    for policy_name in 'float16', 'mixed_float16':
      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
         mp_policy.Policy(policy_name, loss_scale=2.)
       mock_warn.assert_not_called()
 
+  @testing_utils.enable_v2_dtype_behavior
+  def test_float32_vars_warning(self):
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('infer_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'infer_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
+          "or 'mixed_bfloat16' policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float16_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'float16_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' "
+          "policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('bfloat16_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' "
+          "policy instead.")
+
+    with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+      mp_policy.Policy('float64_with_float32_vars')
+      self.assertEqual(
+          mock_warn.call_args[0][0],
+          "WARNING: The 'float64_with_float32_vars' policy is deprecated and "
+          "will be removed in TensorFlow 2.1.")
+
+    for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16':
+      with test.mock.patch.object(tf_logging, 'warn') as mock_warn:
+        mp_policy.Policy(policy_name)
+      mock_warn.assert_not_called()
+
   @testing_utils.enable_v2_dtype_behavior
   def test_policy_scope(self):
     if base_layer_utils.v2_dtype_behavior_enabled():
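The tests above all follow one pattern: patch `tf_logging.warn`, construct the
policy, then inspect `call_args`. In plain `unittest.mock` terms, a sketch
assuming the import paths used in this patch (note the real tests also enable
v2 dtype behavior first, which this standalone snippet does not):

    from unittest import mock

    from tensorflow.python.keras.mixed_precision.experimental import (
        policy as mp_policy)
    from tensorflow.python.platform import tf_logging

    with mock.patch.object(tf_logging, 'warn') as mock_warn:
      mp_policy.Policy('float64_with_float32_vars')

    # call_args is an (args, kwargs) pair for the most recent call, so
    # [0][0] is the formatted warning message asserted in the tests above.
    message = mock_warn.call_args[0][0]
    assert message.startswith(
        "WARNING: The 'float64_with_float32_vars' policy is deprecated")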