diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 917e871a7f0..2d1f05855ec 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2592,6 +2592,9 @@ class AddLoss(Layer): """ def __init__(self, unconditional, **kwargs): + # Pass autocast=False, as there is no reason to cast loss to a different + # dtype. + kwargs['autocast'] = False super(AddLoss, self).__init__(**kwargs) self.unconditional = unconditional diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index aaee57c2b4b..2074ef99d72 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -230,12 +230,10 @@ class Network(base_layer.Layer): else: self._graph = ops.get_default_graph() # Used in symbolic mode only. - # Both graph and subclassed networks have a dtype policy. The policy is - # currently ignored for a graph network, as graph networks disable - # autocasting (making the policy's compute dtype meaningless) and graph - # networks have no variables (making the policy's variable_dtype - # meaningless). For subclassed networks, the dtype policy acts as it does - # for any ordinary layer. + # Both graph and subclassed networks have a dtype policy. For graph + # networks, the policy's compute and variable dtypes are ignored, but other + # fields, like the loss scale, are used by Models. For subclassed networks, + # the compute and variable dtypes are used as like any ordinary layer. self._set_dtype_policy(kwargs.get('dtype', None)) # All layers in order of horizontal graph traversal. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index f7e0710557a..5f764cad9da 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -52,6 +52,8 @@ from tensorflow.python.keras.engine import training_generator from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine import training_v2 from tensorflow.python.keras.engine import training_v2_utils +from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer +from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.utils import data_utils from tensorflow.python.keras.utils import losses_utils @@ -248,13 +250,9 @@ class Model(network.Network): self._experimental_run_tf_function = kwargs.pop( 'experimental_run_tf_function', True) - if isinstance(optimizer, (list, tuple)): - self.optimizer = [optimizers.get(opt) for opt in optimizer] - is_any_optimizer_v1 = any( - isinstance(opt, optimizers.Optimizer) for opt in self.optimizer) - else: - self.optimizer = optimizers.get(optimizer) - is_any_optimizer_v1 = isinstance(self.optimizer, optimizers.Optimizer) + self._set_optimizer(optimizer) + is_any_optimizer_v1 = any(isinstance(opt, optimizers.Optimizer) + for opt in nest.flatten(self.optimizer)) if ((sample_weight_mode is not None) or (target_tensors is not None) @@ -1443,6 +1441,47 @@ class Model(network.Network): 'and the first argument in `call` as positional arguments, ' 'found: ' + str(extra_args) + '.') + def _set_optimizer(self, optimizer): + """Sets self.optimizer. + + Sets self.optimizer to `optimizer`, potentially wrapping it with a + LossScaleOptimizer. + + Args: + optimizer: The optimizer(s) to assign to self.optimizer. + """ + if isinstance(optimizer, (list, tuple)): + self.optimizer = [optimizers.get(opt) for opt in optimizer] + else: + self.optimizer = optimizers.get(optimizer) + + if (self._dtype_policy.loss_scale is not None and + not isinstance(self.optimizer, + loss_scale_optimizer.LossScaleOptimizer)): + if isinstance(self.optimizer, list): + raise ValueError('When a dtype policy with a loss scale is used, you ' + 'can only pass a single optimizer. Using policy %s ' + 'and got optimizers: %s' % + self._dtype_policy, self.optimizer) + if not isinstance(self.optimizer, optimizer_v2.OptimizerV2): + raise ValueError('"optimizer" must be an instance of ' + 'tf.keras.optimizers.Optimizer when a dype policy ' + 'with a loss scale used, but got: %s. Using policy: ' + '%s' % + (self.optimizer, self._dtype_policy)) + self.optimizer = loss_scale_optimizer.LossScaleOptimizer( + self.optimizer, self._dtype_policy.loss_scale) + if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and + self._dtype_policy.loss_scale and + self.optimizer.loss_scale != self._dtype_policy.loss_scale): + logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) ' + 'is not the same as the dtype policy\'s loss scale (%s). ' + 'Because the dtype policy has a loss scale, you should ' + 'pass an optimizer that is not wrapped with a ' + 'LossScaleOptimizer,' + % (self.optimizer.loss_scale, + self._dtype_policy.loss_scale)) + def _prepare_validation_data(self, validation_data, batch_size, validation_steps): """Unpack and check the validation data.""" diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py index cc8fa18784a..784d7b304dd 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py @@ -35,6 +35,7 @@ from tensorflow.python.keras import backend from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers from tensorflow.python.keras import models +from tensorflow.python.keras import optimizers from tensorflow.python.keras import regularizers from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import base_layer @@ -434,7 +435,7 @@ class KerasModelTest(keras_parameterized.TestCase): }, { 'testcase_name': 'infer', 'strategy_fn': create_mirrored_strategy, - 'policy_name': 'infer_with_float32_vars' + 'policy_name': 'mixed_float16' }, { 'testcase_name': 'norun_distributed', 'strategy_fn': create_mirrored_strategy, @@ -445,13 +446,15 @@ class KerasModelTest(keras_parameterized.TestCase): strategy_fn, use_operator=False, use_regularizer=False, - policy_name='float16_with_float32_vars', + policy_name='mixed_float16', experimental_run_tf_function=True): if not self._is_strategy_supported(strategy_fn, check_model_type=True): return regularizer = IdentityRegularizer() if use_regularizer else None with strategy_fn().scope(): - with policy.policy_scope(policy_name): + # Pass loss_scale=None, as this test will fail if the DynamicLossScale + # skips applying gradients for a step + with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)): layer_list = [] if testing_utils.get_model_type() == 'subclass': # Subclassed models do not have an Input layer, so the model does not @@ -579,10 +582,13 @@ class KerasModelTest(keras_parameterized.TestCase): strategy = strategy_fn() if use_loss_scaling: loss_scale = 8. + else: + loss_scale = None learning_rate = 2**-14 with strategy.scope(): - with policy.policy_scope(policy.Policy('float16_with_float32_vars')): + with policy.policy_scope(policy.Policy('mixed_float16', + loss_scale=loss_scale)): x = layers.Input(shape=(1,), batch_size=2) layer1 = AddLayer( assert_type=dtypes.float16, @@ -618,8 +624,6 @@ class KerasModelTest(keras_parameterized.TestCase): return math_ops.reduce_mean(y_pred) opt = gradient_descent.SGD(learning_rate) - if use_loss_scaling: - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) model.compile( opt, loss=loss_fn, @@ -646,6 +650,10 @@ class KerasModelTest(keras_parameterized.TestCase): }, { 'testcase_name': 'distribute', 'strategy_fn': create_mirrored_strategy, + }, { + 'testcase_name': 'pass_loss_scale_to_policy', + 'strategy_fn': create_mirrored_strategy, + 'pass_loss_scale_to_policy': True, }, { 'testcase_name': 'norun_distributed', 'strategy_fn': create_mirrored_strategy, @@ -653,18 +661,27 @@ class KerasModelTest(keras_parameterized.TestCase): }) def test_dynamic_loss_scaling(self, strategy_fn, + pass_loss_scale_to_policy=False, experimental_run_tf_function=True): if not self._is_strategy_supported(strategy_fn): return strategy = strategy_fn() initial_loss_scale = 2. batch_size = 4 + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=initial_loss_scale, increment_period=2) expected_gradient = backend.variable([initial_loss_scale / batch_size], dtype=dtypes.float16) # If this variable is set to True, the model below will have NaN gradients have_nan_gradients = backend.variable(False, dtype=dtypes.bool) with strategy.scope(): - with policy.policy_scope(policy.Policy('infer_float32_vars')): + opt = gradient_descent.SGD(1.) + if pass_loss_scale_to_policy: + p = policy.Policy('infer_float32_vars', loss_scale=loss_scale) + else: + p = policy.Policy('infer_float32_vars') + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + with policy.policy_scope(p): x = layers.Input( shape=(1,), batch_size=batch_size, dtype=dtypes.float16) layer = AddLayer(assert_type=dtypes.float16) @@ -685,10 +702,6 @@ class KerasModelTest(keras_parameterized.TestCase): del y_true return math_ops.reduce_mean(y_pred) - opt = gradient_descent.SGD(1.) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=initial_loss_scale, increment_period=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) model.compile( opt, loss=loss_fn, @@ -728,6 +741,40 @@ class KerasModelTest(keras_parameterized.TestCase): model.fit(dataset) self.assertEqual(backend.eval(layer.v), -3) + @test_util.run_in_graph_and_eager_modes + @testing_utils.enable_v2_dtype_behavior + def test_loss_scale_optimizer_overrides_policy_loss_scale(self): + with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): + opt = gradient_descent.SGD(1.) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=5.) + x = layers.Input(shape=(1,)) + y = AddLayer()(x) + model = models.Model(x, y) + model.compile(opt, loss='mse') + self.assertEqual(self.evaluate(model.optimizer.loss_scale()), 5.) + + @test_util.run_in_graph_and_eager_modes + @testing_utils.enable_v2_dtype_behavior + def test_pass_invalid_optimizer_with_loss_scaling(self): + with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): + x = layers.Input(shape=(1,)) + y = AddLayer()(x) + model = models.Model(x, y) + with self.assertRaisesRegexp(ValueError, + 'optimizer" must be an instance of '): + model.compile(optimizers.SGD(1.), 'mse') + + @test_util.run_in_graph_and_eager_modes + @testing_utils.enable_v2_dtype_behavior + def test_functional_model_loss_dtype(self): + with policy.policy_scope('float16'): + x = layers.Input(shape=(1,)) + y = AddLayer()(x) + model = models.Model(x, y) + model.add_loss(math_ops.cast(y, 'float32')) + # The loss should not be casted to the policy's dtype. + self.assertEqual(model.losses[0].dtype, 'float32') + @parameterized.named_parameters( { 'testcase_name': 'base', diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py index a4f5f9fda4c..5e8ebf4fa47 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py @@ -24,19 +24,26 @@ import six from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.platform import tf_logging +from tensorflow.python.training.experimental import loss_scale as loss_scale_module from tensorflow.python.training.experimental import mixed_precision_global_state from tensorflow.python.util.tf_export import keras_export +# Default value of certain arguments, indicating the default behavior for +# that argument should be used. +USE_DEFAULT = 'USE_DEFAULT' + + @keras_export('keras.mixed_precision.experimental.Policy') class Policy(object): """A dtype policy for a Keras layer. - A dtype policy determines the computation dtype and the variable dtype of a - Keras layer. Each layer has a policy. Policies can be passed to the 'dtype' - argument of layer constructors, or a global policy can be set with - 'tf.keras.mixed_precision.experimental.set_policy'. A layer will default to - the global policy if no policy is passed to it's constructor. + A dtype policy determines dtype-related aspects of a layer, such as its + computation and variable dtypes. Each layer has a policy. Policies can be + passed to the 'dtype' argument of layer constructors, or a global policy can + be set with 'tf.keras.mixed_precision.experimental.set_policy'. A layer will + default to the global policy if no policy is passed to it's constructor. For most models, each layer will have the same computation dtype and variable dtype, which will typically be float32. However, when mixed precision @@ -47,55 +54,49 @@ class Policy(object): not match the computation dtype, variables will be automatically casted to the computation dtype to avoid type errors. - In the near future, policies will also determine the loss scaling algorithm - for Keras models. + Policies also have a `tf.train.experimental.LossScale` instance, which is used + by `tf.keras.Model`s to performance loss scaling. Loss scaling is only done by + Models in `Model.fit` and `Model.train_on_batch`. Layers which are not Models + ignore the loss scale. Policies are constructed by passing a string to the constructor, e.g. `tf.keras.mixed_precision.experimental.Policy('float32')`. The string - determines the compute and variable dtypes. Currently, it can be one of - in one of the following forms: + determines the compute and variable dtypes. It can be one of the following: * Any dtype name, such as 'float32' or 'float64'. Both the variable and - compute dtypes will be that dtype. - * '_with_float32_vars', where is any dtype. The compute dtype - will be , while the variable dtype is float32. This is intended for - the use of mixed precision, which uses float16 or bfloat16 for most - computations, and float32 for variables. This policy is only useful if - is float16 or bfloat16, although is allowed to be any - dtype. Note we will have a "mixed" policy in the future, which will make - it even easier to use mixed precision by enabling other features such as - loss scaling. + compute dtypes will be that dtype. No loss scaling is done by default. + * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or + bfloat16, while the variable dtype is float32. These policies are used for + mixed precision training. With 'mixed_float16', a dynamic loss scale is + used by default. 'mixed_bfloat16' does no loss scaling by default, as loss + scaling is unnecessary with bfloat16. ### How to use mixed precision in layers with Policies - To use mixed precision in a model, the 'float16_with_float32_vars' policy can + To use mixed precision in a model, the 'mixed_float16' policy can be used. `tf.keras.mixed_precision.experimental.set_policy` can be used to set - the default policy for layers if no policy is passed to them. Note loss - scaling must also be done, e.g. with a - `tf.keras.mixed_precision.experimental.LossScaleOptimizer`. For example + the default policy for layers if no policy is passed to them. For example: ```python - tf.keras.mixed_precision.experimental.set_policy( - 'float16_with_float32_vars') + tf.keras.mixed_precision.experimental.set_policy('mixed_float16') model = tf.keras.models.Sequential( tf.keras.layers.Input((100,)), - # Dense layers use global policy of 'float16_with_float32_vars' + # Dense layers use global policy of 'mixed_float16', which does + # computations in float16 while keeping variables in float32. tf.keras.layers.Dense(10), tf.keras.layers.Dense(10), # Softmax should be done in float32 for numeric stability. We pass # dtype='float32' to use float32 instead of the global policy. tf.keras.layers.Activation('Softmax', dtype='float32') ) - opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...) - ... # Train `model` with `opt`. + model.fit(...) # Train `model` ``` Alternatively, the policy can be passed to individual layers instead of setting the global policy with `set_policy`: ```python - policy = tf.keras.mixed_precision.experimental.Policy( - 'float16_with_float32_vars') + policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') model = tf.keras.models.Sequential( tf.keras.layers.Input((100,)), tf.keras.layers.Dense(10, dtype=policy), @@ -103,14 +104,21 @@ class Policy(object): # Softmax should be done in float32 for numeric stability. tf.keras.layers.Activation('Softmax', dtype='float32') ) - opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...) - ... # Train `model` with `opt`. + model.fit(...) # Train `model` ``` As the above example shows, strings can be directly passed to layer constructors in the `dtype` argument instead of policies, but only if the string is convertible to a dtype. + Note the 'mixed_float16' policy will apply loss scaling by default in + `Model.fit` and `Model.train_on_batch`. If neither method is used (e.g., a + custom training loop is used) and 'mixed_float16' is used, the loss scale must + be manually applied. See + `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For + 'mixed_bfloat16', no loss scaling is done and loss scaling never needs to be + manually applied. + ### The deprecated "infer" policy In addition to a dtype or "_with_float32_vars", a policy can also be @@ -122,7 +130,13 @@ class Policy(object): the dtype of the first input. Similarly to "infer", there is a deprecated "infer_with_float32_vars" policy - that infers the compute dtype, but not the variable dtype. + that infers the compute dtype, but not the variable dtype. Once a layer with + an "infer_with_float32_vars" policy is called for the first time, the layer's + policy will change to "_with_float32_vars", where is the dtype + of the first input. These policies force variables in float32. + + Warning: Policies ending in "_with_float32_vars" will be removed in TensorFlow + 2.1. Please use "mixed_float16" or "mixed_bfloat16" instead. In TensorFlow 1, only the "infer" and "infer_with_float32_vars" policies are available. @@ -130,7 +144,13 @@ class Policy(object): # TODO(reedwm): Replace link in above docstring with a version that is more # TensorFlow-specific, and that also mentions bfloat16. - def __init__(self, name): + # If True, warn when a policy is created whose name ends in + # "_with_float32_vars". We always want to warn when a user creates such a + # policy, but when the TensorFlow creates a policy, it suppresses the warning + # by setting this to False when creating the policy. + _warn_about_float32_vars = True + + def __init__(self, name, loss_scale=USE_DEFAULT): """Constructs the policy. The `name` argument determines the compute and variable dtype, and has no @@ -141,16 +161,17 @@ class Policy(object): name: A string. Can be one of the following values: * Any dtype name, such as 'float32' or 'float64'. Both the variable and compute dtypes will be that dtype. - * _with_float32_vars, where is any dtype. The compute - dtype will be , while the variable dtype is float32. This is - intended for the use of mixed precision, which uses float16 or - bfloat16 for most computations, and float32 for variables. This policy - is only useful if is float16 or bfloat16, although is - allowed to be any dtype. Note we will have a "mixed" policy in the - future, which will make it even easier to use mixed precision by - enabling other features such as loss scaling. - * 'infer' or 'infer_with_float32_vars' (deprecated): Infer the - computation dtype from the input dtype. + * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or + bfloat16, while the variable dtype is float32. With 'mixed_float16', + a dynamic loss scale is used. These policies are used for mixed + precision training. + * 'infer' (deprecated): Infer the compute and variable dtype from the + input dtype. + loss_scale: A `tf.train.experimental.LossScale`, or a value convertible to + one such as "dynamic". Defaults to using no loss scaling unless `name` + is "mixed_float16", in which case this defaults to "dynamic". Only + `tf.keras.Model`s, not layers, use the loss scale, and it is only used + during `Model.fit` or `Model.train_on_batch`. """ if isinstance(name, dtypes.DType): @@ -168,6 +189,27 @@ class Policy(object): self._name = name self._compute_dtype, self._variable_dtype = self._parse_name(name) + if name.endswith('_with_float32_vars') and self._warn_about_float32_vars: + warning = ("WARNING: The '%s' policy is deprecated and will be removed " + "in TensorFlow 2.1." % name) + if name == 'infer_with_float32_vars': + warning += (" Please use the 'mixed_float16' or 'mixed_bfloat16' " + "policy instead.") + elif name == 'float16_with_float32_vars': + warning += " Please use the 'mixed_float16' policy instead." + elif name == 'bfloat16_with_float32_vars': + warning += " Please use the 'mixed_bfloat16' policy instead." + tf_logging.warn(warning) + + if loss_scale == USE_DEFAULT: + loss_scale = 'dynamic' if name == 'mixed_float16' else None + if loss_scale and self._compute_dtype not in (None, 'float16'): + tf_logging.warn('Creating a Policy with a loss scale is only useful for ' + 'float16 policies. You passed loss_scale=%r for policy ' + '%s. Consider not passing any loss_scale instead.' % + (loss_scale, name)) + self._loss_scale = loss_scale_module.get(loss_scale) + def _parse_name(self, name): """Parses a Policy name into a compute and variable dtype. @@ -177,6 +219,11 @@ class Policy(object): Returns: The (compute_dtype, variable_dtype) pair. """ + if name == 'mixed_float16': + return 'float16', 'float32' + elif name == 'mixed_bfloat16': + return 'bfloat16', 'float32' + if name.endswith('_with_float32_vars'): base_name = name[:-len('_with_float32_vars')] float32_vars = True @@ -190,10 +237,10 @@ class Policy(object): try: base_dtype = dtypes.as_dtype(base_name).name except TypeError: - error = ('Cannot convert value %s to a mixed precision Policy. ' - 'Valid policies include include those in the form "" ' - 'and "_with_float32_vars", where is the name ' - 'of a dtype.' % (name,)) + error = ("Cannot convert value %s to a mixed precision Policy. " + "Valid policies include include 'mixed_float16', " + "'mixed_bfloat16', and the name of any dtype such as " + "'float32'." % (name,)) if float32_vars: error += (' The value %s ends with _with_float32_vars, but %s cannot ' 'be converted to a DType' % (name, base_name)) @@ -264,13 +311,22 @@ class Policy(object): """ return self.variable_dtype != self.compute_dtype + @property + def loss_scale(self): + """Returns the loss scale of this Policy. + + Returns: + A `tf.train.experimental.LossScale`, or None. + """ + return self._loss_scale + @property def name(self): """Returns the name of this policy.""" return self._name def __repr__(self): - return '' % self._name + return '' % (self._name, self.loss_scale) def with_input_dtype(policy, dtype): @@ -297,7 +353,11 @@ def with_input_dtype(policy, dtype): # Policies without a compute dtype are either "infer" or # "infer_with_float32_vars", so the variable_dtype must be float32 here. assert policy.variable_dtype == 'float32' - return Policy(dtype + '_with_float32_vars') + try: + Policy._warn_about_float32_vars = False # pylint: disable=protected-access + return Policy(dtype + '_with_float32_vars') + finally: + Policy._warn_about_float32_vars = True # pylint: disable=protected-access # The current global policy in effect. If None, it means the current value of diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py index 15a237d3b3a..09b25960f53 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py @@ -26,6 +26,8 @@ from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.mixed_precision.experimental import policy as mp_policy from tensorflow.python.keras.optimizer_v2 import gradient_descent from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging +from tensorflow.python.training.experimental import loss_scale as loss_scale_module from tensorflow.python.training.experimental import mixed_precision @@ -45,10 +47,20 @@ class PolicyTest(test.TestCase): for dtype in 'int32', 'bool', 'float16', 'float32': policy = mp_policy.Policy(dtype) + self.assertEqual(policy.name, dtype) self.assertEqual(policy.compute_dtype, dtype) self.assertEqual(policy.variable_dtype, dtype) policy = mp_policy.Policy(dtype + '_with_float32_vars') + expected_name = ( + dtype if dtype == 'float32' else dtype + '_with_float32_vars') + self.assertEqual(policy.name, expected_name) + self.assertEqual(policy.compute_dtype, dtype) + self.assertEqual(policy.variable_dtype, 'float32') + + for dtype in 'float16', 'bfloat16': + policy = mp_policy.Policy('mixed_' + dtype) + self.assertEqual(policy.name, 'mixed_' + dtype) self.assertEqual(policy.compute_dtype, dtype) self.assertEqual(policy.variable_dtype, 'float32') @@ -57,9 +69,11 @@ class PolicyTest(test.TestCase): for policy in ('infer', 'infer_with_float32_vars', 'float32', 'float16_with_float32_vars'): self.assertEqual(repr(mp_policy.Policy(policy)), - '' % policy) + '' % policy) self.assertEqual(repr(mp_policy.Policy('float32_with_float32_vars')), - '') + '') + self.assertEqual(repr(mp_policy.Policy('float16', loss_scale=2)), + '') @testing_utils.enable_v2_dtype_behavior def test_policy_errors(self): @@ -99,6 +113,30 @@ class PolicyTest(test.TestCase): self.assertEqual(policy.compute_dtype, 'float32') self.assertEqual(policy.variable_dtype, 'float32') + @testing_utils.enable_v2_dtype_behavior + def test_loss_scale(self): + policy = mp_policy.Policy('float32') + self.assertEqual(policy.loss_scale, None) + + policy = mp_policy.Policy('float32', loss_scale=None) + self.assertEqual(policy.loss_scale, None) + + ls = loss_scale_module.DynamicLossScale() + policy = mp_policy.Policy('float32', loss_scale=ls) + self.assertIs(policy.loss_scale, ls) + + policy = mp_policy.Policy('float32', loss_scale='dynamic') + self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale) + + policy = mp_policy.Policy('mixed_float16') + self.assertIsInstance(policy.loss_scale, loss_scale_module.DynamicLossScale) + + policy = mp_policy.Policy('mixed_float16', loss_scale=None) + self.assertEqual(policy.loss_scale, None) + + policy = mp_policy.Policy('mixed_bfloat16') + self.assertEqual(policy.loss_scale, None) + @testing_utils.enable_v2_dtype_behavior def test_global_policy(self): if base_layer_utils.v2_dtype_behavior_enabled(): @@ -121,6 +159,59 @@ class PolicyTest(test.TestCase): finally: mp_policy.set_policy(None) + @testing_utils.enable_v2_dtype_behavior + def test_loss_scale_warning(self): + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy('float32', loss_scale=2.) + self.assertEqual( + mock_warn.call_args[0][0], + 'Creating a Policy with a loss scale is only useful for float16 ' + 'policies. You passed loss_scale=2.0 for policy float32. Consider ' + 'not passing any loss_scale instead.') + + for policy_name in 'float16', 'mixed_float16': + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy(policy_name, loss_scale=2.) + mock_warn.assert_not_called() + + @testing_utils.enable_v2_dtype_behavior + def test_float32_vars_warning(self): + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy('infer_with_float32_vars') + self.assertEqual( + mock_warn.call_args[0][0], + "WARNING: The 'infer_with_float32_vars' policy is deprecated and " + "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' " + "or 'mixed_bfloat16' policy instead.") + + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy('float16_with_float32_vars') + self.assertEqual( + mock_warn.call_args[0][0], + "WARNING: The 'float16_with_float32_vars' policy is deprecated and " + "will be removed in TensorFlow 2.1. Please use the 'mixed_float16' " + "policy instead.") + + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy('bfloat16_with_float32_vars') + self.assertEqual( + mock_warn.call_args[0][0], + "WARNING: The 'bfloat16_with_float32_vars' policy is deprecated and " + "will be removed in TensorFlow 2.1. Please use the 'mixed_bfloat16' " + "policy instead.") + + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy('float64_with_float32_vars') + self.assertEqual( + mock_warn.call_args[0][0], + "WARNING: The 'float64_with_float32_vars' policy is deprecated and " + "will be removed in TensorFlow 2.1.") + + for policy_name in 'float16', 'float32', 'mixed_float16', 'mixed_bfloat16': + with test.mock.patch.object(tf_logging, 'warn') as mock_warn: + mp_policy.Policy(policy_name) + mock_warn.assert_not_called() + @testing_utils.enable_v2_dtype_behavior def test_policy_scope(self): if base_layer_utils.v2_dtype_behavior_enabled(): diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py index 46da10183df..46f52f0a955 100644 --- a/tensorflow/python/training/experimental/loss_scale.py +++ b/tensorflow/python/training/experimental/loss_scale.py @@ -227,6 +227,9 @@ class FixedLossScale(LossScale): del grads return control_flow_ops.no_op(), True + def __repr__(self): + return 'FixedLossScale(%s)' % self._loss_scale_value + def get_config(self): return {'loss_scale_value': self._loss_scale_value} @@ -376,6 +379,17 @@ class DynamicLossScale(LossScale): should_apply_gradients = is_finite return update_op, should_apply_gradients + def __repr__(self): + if context.executing_eagerly(): + return ('DynamicLossScale(current_loss_scale=%s, num_good_steps=%s, ' + 'initial_loss_scale=%s, increment_period=%s, multiplier=%s)' % + (self._current_loss_scale.numpy(), self._num_good_steps.numpy(), + self.initial_loss_scale, self.increment_period, self.multiplier)) + else: + return ('DynamicLossScale(initial_loss_scale=%s, increment_period=%s, ' + 'multiplier=%s)' % + (self.initial_loss_scale, self.increment_period, self.multiplier)) + def get_config(self): return { 'initial_loss_scale': self.initial_loss_scale, diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py index c3e18a18422..e4a11144041 100644 --- a/tensorflow/python/training/experimental/loss_scale_test.py +++ b/tensorflow/python/training/experimental/loss_scale_test.py @@ -92,6 +92,11 @@ class FixedLossScaleTest(test.TestCase): scalar = loss_scale_module.FixedLossScale(123) self.assertIsInstance(scalar(), ops.Tensor) + @test_util.run_in_graph_and_eager_modes + def test_repr(self): + loss_scale = loss_scale_module.FixedLossScale(123) + self.assertEqual(repr(loss_scale), 'FixedLossScale(123.0)') + def _get_example_iter(inputs): dataset = dataset_ops.Dataset.from_tensor_slices(inputs) @@ -302,5 +307,22 @@ class DynamicLossScaleTest(test.TestCase, parameterized.TestCase): scalar = loss_scale_module.DynamicLossScale() self.assertIsInstance(scalar(), ops.Tensor) + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_repr(self, strategy_fn): + with strategy_fn().scope(): + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=1, increment_period=2, multiplier=3) + if context.executing_eagerly(): + self.assertEqual(repr(loss_scale), + 'DynamicLossScale(current_loss_scale=1.0, ' + 'num_good_steps=0, initial_loss_scale=1.0, ' + 'increment_period=2, multiplier=3.0)') + else: + self.assertEqual(repr(loss_scale), + 'DynamicLossScale(initial_loss_scale=1.0, ' + 'increment_period=2, multiplier=3.0)') + + if __name__ == '__main__': test.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt index 600f11bc8d9..7fc2f9c8d3f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "compute_dtype" mtype: "" } + member { + name: "loss_scale" + mtype: "" + } member { name: "name" mtype: "" @@ -20,6 +24,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt index 600f11bc8d9..7fc2f9c8d3f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt @@ -6,6 +6,10 @@ tf_class { name: "compute_dtype" mtype: "" } + member { + name: "loss_scale" + mtype: "" + } member { name: "name" mtype: "" @@ -20,6 +24,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'name\', \'loss_scale\'], varargs=None, keywords=None, defaults=[\'USE_DEFAULT\'], " } }