From 51fbc48cef3df2b716acae321a52bd678d3fce62 Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Mon, 19 Oct 2020 14:46:43 -0700
Subject: [PATCH] Deprecate LossScale and modify Keras APIs to not use it.

LossScale and its subclasses are deprecated and will be removed from the TF 2 namespace in TensorFlow 2.5. It will still be accessible under the tf.compat.v1 namespace, and this change makes it non-experimental there, exporting it as `tf.compat.v1.mixed_precision.LossScale`. LossScale cannot be removed from the tf.compat.v1 namespace since it is used by the V1-only class tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer.

LossScaleOptimizer previously used a LossScale, but now it performs loss scaling directly within the class itself. Additionally, a new non-experimental `tf.keras.mixed_precision.LossScaleOptimizer` has been introduced. Unlike the experimental LossScaleOptimizer, the non-experimental LossScaleOptimizer does not accept a LossScale but instead has different constructor arguments to specify the type of loss scaling to be done. The old experimental LossScaleOptimizer will be removed in TensorFlow 2.5, at which point a LossScale can no longer be used with any Keras LossScaleOptimizer.

Internally, LossScaleOptimizer uses a fork of DynamicLossScale called _DynamicLossScaleState, but this is not exposed to the user. In the future, _DynamicLossScaleState will be merged into LossScaleOptimizer.

LossScaleOptimizer now exposes some attributes that DynamicLossScale previously exposed. "increment_period" is renamed to "dynamic_growth_steps" for consistency with `ExponentialDecay.decay_steps`. `num_good_steps` is replaced by `dynamic_counter`.

LossScaleOptimizer.loss_scale is now a tensor, not a LossScale. This means the previous way of getting the loss scale as a tensor (calling `optimizer.loss_scale()`) will now raise an error. I don't know of any users who do this, so I do not anticipate any breakages.

Policy previously held an instance of a LossScale, and optionally took a LossScale in the constructor. By default, the "mixed_float16" policy had a DynamicLossScale, while all other policies had no loss scale. Now, Policy no longer has a loss scale or takes an instance of a loss scale. To temporarily preserve backwards compatibility with the old API, the symbol `tf.keras.mixed_precision.experimental.Policy` still takes and holds a LossScale, as it did before. A new non-experimental symbol, `tf.keras.mixed_precision.Policy`, removes the use of the LossScale. The old experimental symbol will be removed in the future.

When deserializing a layer or model with an old experimental policy, it will be restored as the new policy and the loss scale will be silently dropped. This is to preserve SavedModel compatibility with models saved in TensorFlow 2.3 and restored in future versions of TensorFlow once the old experimental Policy is removed. Luckily, dropping the loss scale is unlikely to break anyone, as a bug in the mixed precision API causes models to not save their dtype policies at all when being serialized. Similarly, when deserializing a model with the old experimental LossScaleOptimizer, it will be restored as the new LossScaleOptimizer, but unlike the policy case, nothing is silently dropped.

This change is different from what is described in the mixed precision RFC (https://github.com/tensorflow/community/pull/293), but I think this API is a lot clearer and simpler than the API in the RFC.
The RFC forked the LossScale classes into Keras, but I now think its better to simply not use them and make LossScale exposed under tf.compat.v1 only. This new API was designed based on feedback from @fchollet and @omalleyt12. I will retroactively update the RFC to reflect this API. PiperOrigin-RevId: 337938270 Change-Id: Id7bb3bb89eb2143e5fadabeb2f57d1f8267379b3 --- tensorflow/python/keras/engine/base_layer.py | 17 +- .../python/keras/engine/base_layer_utils.py | 6 +- .../python/keras/engine/base_layer_v1.py | 17 +- tensorflow/python/keras/engine/training.py | 18 +- tensorflow/python/keras/engine/training_v1.py | 27 +- .../experimental/keras_test.py | 227 ++++--- .../experimental/loss_scale_optimizer.py | 600 ++++++++++++++++-- .../experimental/loss_scale_optimizer_test.py | 489 +++++++++++--- .../mixed_precision_graph_rewrite_test.py | 6 +- .../mixed_precision/experimental/policy.py | 190 ++++-- .../experimental/policy_test.py | 64 +- tensorflow/python/keras/optimizers.py | 4 + tensorflow/python/keras/type/types.py | 9 +- .../training/experimental/loss_scale.py | 82 ++- .../mixed_precision_global_state.py | 6 +- ...ixed_precision.-loss-scale-optimizer.pbtxt | 112 ++++ ...n.experimental.-loss-scale-optimizer.pbtxt | 17 + .../v1/tensorflow.keras.mixed_precision.pbtxt | 4 + ....mixed_precision.-dynamic-loss-scale.pbtxt | 35 + ...ow.mixed_precision.-fixed-loss-scale.pbtxt | 23 + ...nsorflow.mixed_precision.-loss-scale.pbtxt | 22 + .../v1/tensorflow.mixed_precision.pbtxt | 12 + ...ixed_precision.-loss-scale-optimizer.pbtxt | 112 ++++ ...orflow.keras.mixed_precision.-policy.pbtxt | 33 + ...n.experimental.-loss-scale-optimizer.pbtxt | 17 + ...mixed_precision.experimental.-policy.pbtxt | 1 + .../v2/tensorflow.keras.mixed_precision.pbtxt | 8 + tensorflow/tools/compatibility/renames_v2.py | 6 + 28 files changed, 1769 insertions(+), 395 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-dynamic-loss-scale.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-fixed-loss-scale.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-loss-scale.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-policy.pbtxt diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 3a3f6363e3c..5c431ad8e2d 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -137,9 +137,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): Attributes: name: The name of the layer (string). dtype: The dtype of the layer's computations and weights. If mixed - precision is used with a `tf.keras.mixed_precision.experimental.Policy`, - this is instead just the dtype of the layer's weights, as the computations - are done in a different dtype. + precision is used with a `tf.keras.mixed_precision.Policy`, this is + instead just the dtype of the layer's weights, as the computations are + done in a different dtype. trainable_weights: List of variables to be included in backprop. non_trainable_weights: List of variables that should not be included in backprop. @@ -279,8 +279,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector): if no dtype is passed. 
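To make the constructor change described in the commit message concrete, here is an illustrative sketch (not part of the patch) that puts the old experimental API next to its non-experimental replacement, using the equivalences this change documents in the LossScaleOptimizerV1 docstring further down; the variable names are arbitrary:

    import tensorflow as tf

    # TF 2.3-style experimental API (to be removed in TF 2.5):
    old_dynamic = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        tf.keras.optimizers.SGD(0.1), loss_scale='dynamic')
    old_fixed = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        tf.keras.optimizers.SGD(0.1), loss_scale=512)

    # Non-experimental API introduced by this change. Dynamic loss scaling is
    # the default, so no loss-scale argument is needed:
    new_dynamic = tf.keras.mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.SGD(0.1))
    # A fixed loss scale is requested with dynamic=False plus initial_scale:
    new_fixed = tf.keras.mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.SGD(0.1), dynamic=False, initial_scale=512)
    # The DynamicLossScale knobs map directly onto constructor arguments:
    new_tuned = tf.keras.mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.SGD(0.1), initial_scale=2**10,
        dynamic_growth_steps=500)  # formerly increment_period=500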
`floatx()` itself defaults to "float32". Additionally, layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed precision is used, layers may have different computation and variable dtypes. - See `tf.keras.mixed_precision.experimental.Policy` for details on layer - dtypes. + See `tf.keras.mixed_precision.Policy` for details on layer dtypes. """ # See tf.Module for the usage of this property. @@ -388,9 +387,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._metrics_lock = threading.Lock() # Both graph and subclassed networks have a dtype policy. For graph - # networks, the policy's compute and variable dtypes are ignored, but other - # fields, like the loss scale, are used by Models. For subclassed networks, - # the compute and variable dtypes are used as like any ordinary layer. + # networks, the policy's compute and variable dtypes are ignored. Such + # networks only use the policy if it is a PolicyV1, in which case it uses + # the PolicyV1's loss_scale (Policy does not have a loss_scale). For + # subclassed networks, the compute and variable dtypes are used as like any + # ordinary layer. self._set_dtype_policy(dtype) # Boolean indicating whether the layer automatically casts its inputs to the # layer's compute_dtype. diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py index d6b32907593..0c49bdcd413 100644 --- a/tensorflow/python/keras/engine/base_layer_utils.py +++ b/tensorflow/python/keras/engine/base_layer_utils.py @@ -741,9 +741,9 @@ def enable_v2_dtype_behavior(): autocasting part of the V2 behavior for that layer, but not the defaulting to floatx part of the V2 behavior. - When a global `tf.keras.mixed_precision.experimental.Policy` is set, a Keras - layer's dtype will default to the global policy instead of floatx. Layers - will automatically cast inputs to the policy's compute_dtype. + When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's dtype + will default to the global policy instead of floatx. Layers will automatically + cast inputs to the policy's compute_dtype. """ global V2_DTYPE_BEHAVIOR V2_DTYPE_BEHAVIOR = True diff --git a/tensorflow/python/keras/engine/base_layer_v1.py b/tensorflow/python/keras/engine/base_layer_v1.py index a190218bbcf..eaa2b0185b6 100644 --- a/tensorflow/python/keras/engine/base_layer_v1.py +++ b/tensorflow/python/keras/engine/base_layer_v1.py @@ -112,9 +112,9 @@ class Layer(base_layer.Layer): Attributes: name: The name of the layer (string). dtype: The dtype of the layer's computations and weights. If mixed - precision is used with a `tf.keras.mixed_precision.experimental.Policy`, - this is instead just the dtype of the layer's weights, as the computations - are done in a different dtype. + precision is used with a `tf.keras.mixed_precision.Policy`, this is + instead just the dtype of the layer's weights, as the computations are + done in a different dtype. updates: List of update ops of this layer. losses: List of losses added by this layer. trainable_weights: List of variables to be included in backprop. @@ -133,8 +133,7 @@ class Layer(base_layer.Layer): if no dtype is passed. `floatx()` itself defaults to "float32". Additionally, layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed precision is used, layers may have different computation and variable dtypes. - See `tf.keras.mixed_precision.experimental.Policy` for details on layer - dtypes. 
+ See `tf.keras.mixed_precision.Policy` for details on layer dtypes. """ # See tf.Module for the usage of this property. @@ -199,9 +198,11 @@ class Layer(base_layer.Layer): self._metrics = [] # Both graph and subclassed networks have a dtype policy. For graph - # networks, the policy's compute and variable dtypes are ignored, but other - # fields, like the loss scale, are used by Models. For subclassed networks, - # the compute and variable dtypes are used as like any ordinary layer. + # networks, the policy's compute and variable dtypes are ignored. Such + # networks only use the policy if it is a PolicyV1, in which case it uses + # the PolicyV1's loss_scale (Policy does not have a loss_scale). For + # subclassed networks, the compute and variable dtypes are used as like any + # ordinary layer. self._set_dtype_policy(dtype) # Boolean indicating whether the layer automatically casts its inputs to the # layer's compute_dtype. diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 02b62f235a1..a13d0b01718 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -52,6 +52,7 @@ from tensorflow.python.keras.engine import compile_utils from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso +from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.saving import hdf5_format from tensorflow.python.keras.saving import save from tensorflow.python.keras.saving.saved_model import json_utils @@ -549,12 +550,25 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector): def _get_optimizer(self, optimizer): """Wraps `optimizer` in `LossScaleOptimizer` if necessary.""" + # The deprecated PolicyV1 has a loss_scale, which we use for backwards + # compatibility to match TF 2.3 behavior. The new Policy does not have a + # loss_scale, so we use dynamic loss scaling if the mixed_float16 policy is + # used. 
+ if isinstance(self._dtype_policy, policy.PolicyV1): + loss_scale = self._dtype_policy.loss_scale + elif self._dtype_policy.name == 'mixed_float16': + loss_scale = 'dynamic' + else: + loss_scale = None def _get_single_optimizer(opt): opt = optimizers.get(opt) - if (self._dtype_policy.loss_scale is not None and + if (loss_scale is not None and not isinstance(opt, lso.LossScaleOptimizer)): - opt = lso.LossScaleOptimizer(opt, self._dtype_policy.loss_scale) + if loss_scale == 'dynamic': + opt = lso.LossScaleOptimizer(opt) + else: + opt = lso.LossScaleOptimizerV1(opt, loss_scale) return opt return nest.map_structure(_get_single_optimizer, optimizer) diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index 54969bb5e83..e4133b58372 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -53,6 +53,7 @@ from tensorflow.python.keras.engine import training_generator_v1 from tensorflow.python.keras.engine import training_utils from tensorflow.python.keras.engine import training_utils_v1 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision.experimental import policy from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.saving.saved_model import model_serialization from tensorflow.python.keras.utils import data_utils @@ -1342,7 +1343,14 @@ class Model(training_lib.Model): else: self.optimizer = optimizers.get(optimizer) - if (self._dtype_policy.loss_scale is not None and + if isinstance(self._dtype_policy, policy.PolicyV1): + loss_scale = self._dtype_policy.loss_scale + elif self._dtype_policy.name == 'mixed_float16': + loss_scale = 'dynamic' + else: + loss_scale = None + + if (loss_scale is not None and not isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer)): if isinstance(self.optimizer, list): @@ -1356,18 +1364,11 @@ class Model(training_lib.Model): 'with a loss scale used, but got: %s. Using policy: ' '%s' % (self.optimizer, self._dtype_policy)) - self.optimizer = loss_scale_optimizer.LossScaleOptimizer( - self.optimizer, self._dtype_policy.loss_scale) - if (isinstance(self.optimizer, loss_scale_optimizer.LossScaleOptimizer) and - self._dtype_policy.loss_scale and - self.optimizer.loss_scale != self._dtype_policy.loss_scale): - logging.warning('LossScale of LossScaleOptimizer passed to compile (%s) ' - 'is not the same as the dtype policy\'s loss scale (%s). 
' - 'Because the dtype policy has a loss scale, you should ' - 'pass an optimizer that is not wrapped with a ' - 'LossScaleOptimizer,' - % (self.optimizer.loss_scale, - self._dtype_policy.loss_scale)) + if loss_scale == 'dynamic': + self.optimizer = loss_scale_optimizer.LossScaleOptimizer(self.optimizer) + else: + self.optimizer = loss_scale_optimizer.LossScaleOptimizerV1( + self.optimizer, loss_scale) def _prepare_validation_data(self, validation_data, batch_size, validation_steps): diff --git a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py index dd754e87bb4..0668b7952cc 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/keras_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/keras_test.py @@ -344,32 +344,10 @@ class KerasLayerTest(keras_parameterized.TestCase): self.assertEqual(layer.dtype, 'float32') self.assertEqual(layer(x).dtype, 'float16') self.assertEqual(layer.v.dtype, 'float32') - - layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('mixed_float16', - loss_scale=None)) config = layer.get_config() self.assertEqual(config['dtype'], {'class_name': 'Policy', - 'config': {'name': 'mixed_float16', - 'loss_scale': None}}) - layer = mp_test_util.MultiplyLayer.from_config(config) - self.assertEqual(layer.dtype, 'float32') - self.assertEqual(layer(x).dtype, 'float16') - self.assertEqual(layer.v.dtype, 'float32') - - layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('float64', - loss_scale=2.)) - config = layer.get_config() - self.assertEqual(config['dtype'], - {'class_name': 'Policy', - 'config': {'name': 'float64', - 'loss_scale': { - 'class_name': 'FixedLossScale', - 'config': {'loss_scale_value': 2.0}}}}) - layer = mp_test_util.MultiplyLayer.from_config(config) - self.assertEqual(layer.dtype, 'float64') - self.assertEqual(layer(x).dtype, 'float64') - self.assertEqual(layer.v.dtype, 'float64') + 'config': {'name': 'mixed_float16'}}) layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer')) config = layer.get_config() @@ -383,11 +361,53 @@ class KerasLayerTest(keras_parameterized.TestCase): self.assertEqual(layer(x).dtype, 'float32') self.assertEqual(layer.v.dtype, 'float32') - layer = mp_test_util.MultiplyLayer(dtype=policy.Policy('_infer', - loss_scale=2.)) + @parameterized.named_parameters(*TESTCASES) + def test_config_policy_v1(self, strategy_fn): + x = constant_op.constant([1.], dtype=dtypes.float16) + with strategy_fn().scope(): + + layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('mixed_float16', + loss_scale=None)) config = layer.get_config() + self.assertEqual(config['dtype'], + {'class_name': 'PolicyV1', + 'config': {'name': 'mixed_float16', + 'loss_scale': None}}) + layer = mp_test_util.MultiplyLayer.from_config(config) + self.assertEqual(layer.dtype, 'float32') + self.assertEqual(layer(x).dtype, 'float16') + self.assertEqual(layer.v.dtype, 'float32') + # Restoring a PolicyV1 silently converts it to a Policy and drops the loss + # scale. 
+ self.assertEqual(type(layer._dtype_policy), policy.Policy) + config = layer.get_config() + # The loss_scale is silently dropped self.assertEqual(config['dtype'], {'class_name': 'Policy', + 'config': {'name': 'mixed_float16'}}) + + layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('float64', + loss_scale=2.)) + config = layer.get_config() + self.assertEqual(config['dtype'], + {'class_name': 'PolicyV1', + 'config': {'name': 'float64', + 'loss_scale': { + 'class_name': 'FixedLossScale', + 'config': {'loss_scale_value': 2.0}}}}) + layer = mp_test_util.MultiplyLayer.from_config(config) + self.assertEqual(layer.dtype, 'float64') + self.assertEqual(layer(x).dtype, 'float64') + self.assertEqual(layer.v.dtype, 'float64') + self.assertEqual(type(layer._dtype_policy), policy.Policy) + config = layer.get_config() + self.assertEqual(config['dtype'], 'float64') + + layer = mp_test_util.MultiplyLayer(dtype=policy.PolicyV1('_infer', + loss_scale=2.)) + config = layer.get_config() + self.assertEqual(config['dtype'], + {'class_name': 'PolicyV1', 'config': {'name': '_infer', 'loss_scale': { 'class_name': 'FixedLossScale', @@ -396,6 +416,9 @@ class KerasLayerTest(keras_parameterized.TestCase): self.assertEqual(layer.dtype, None) self.assertEqual(layer(x).dtype, 'float16') self.assertEqual(layer.v.dtype, 'float16') + self.assertEqual(type(layer._dtype_policy), policy.Policy) + config = layer.get_config() + self.assertEqual(config['dtype'], 'float16') def test_delete_variable(self): layer = base_layer.Layer(dtype=policy.Policy('mixed_float16')) @@ -501,6 +524,11 @@ class KerasModelTest(keras_parameterized.TestCase): 'strategy_fn': create_mirrored_strategy, 'save_format': 'h5', 'use_regularizer': True, + }, { + 'testcase_name': 'saved_model_v1_policy', + 'strategy_fn': create_mirrored_strategy, + 'use_v1_policy': True, + 'save_format': 'tf', }) def test_model(self, strategy_fn, @@ -509,15 +537,15 @@ class KerasModelTest(keras_parameterized.TestCase): policy_name='mixed_float16', get_config=False, save_format=None, - use_input_spec=False): + use_input_spec=False, + use_v1_policy=False): self._skip_if_strategy_unsupported(strategy_fn) self._skip_if_save_format_unsupported(save_format) regularizer = (mp_test_util.IdentityRegularizer() if use_regularizer else None) with strategy_fn().scope(): - # Pass loss_scale=None, as this test will fail if the DynamicLossScale - # skips applying gradients for a step - with policy.policy_scope(policy.Policy(policy_name, loss_scale=None)): + cls = policy.PolicyV1 if use_v1_policy else policy.Policy + with policy.policy_scope(cls(policy_name)): layer = mp_test_util.MultiplyLayer( assert_type=dtypes.float16, use_operator=use_operator, @@ -543,6 +571,10 @@ class KerasModelTest(keras_parameterized.TestCase): # the variable will not change. So this tests the learning rate not # applied to a float16 value, but instead the float32 variable. opt = gradient_descent.SGD(2**-14) + # Use a fixed loss scale, as this test will fail if gradients are + # skipped for a step due to dynamic loss scaling. + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=8) model.compile( opt, loss=loss_fn, @@ -597,6 +629,13 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertEqual(layer.v.dtype, 'float32') self.assertEqual(layer(np.ones((2, 1))).dtype, 'float16') + # Loading a model always loads with a v2 Policy, even if saved with a + # PolicyV1. 
+ self.assertEqual(type(model._dtype_policy), policy.Policy) + self.assertEqual(layer.get_config()['dtype'], + {'class_name': 'Policy', 'config': { + 'name': 'mixed_float16'}}) + @keras_parameterized.run_all_keras_modes @parameterized.named_parameters( { @@ -630,7 +669,8 @@ class KerasModelTest(keras_parameterized.TestCase): return math_ops.reduce_mean(y_pred) opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=loss_scale) model.compile( opt, loss=loss_fn, @@ -669,13 +709,10 @@ class KerasModelTest(keras_parameterized.TestCase): strategy = strategy_fn() if use_loss_scaling: loss_scale = 8. - else: - loss_scale = None learning_rate = 2**-14 with strategy.scope(): - with policy.policy_scope(policy.Policy('mixed_float16', - loss_scale=loss_scale)): + with policy.policy_scope(policy.Policy('mixed_float16')): x = layers.Input(shape=(1,), batch_size=2) layer1 = mp_test_util.MultiplyLayer( assert_type=dtypes.float16, @@ -710,6 +747,9 @@ class KerasModelTest(keras_parameterized.TestCase): return math_ops.reduce_mean(y_pred) opt = gradient_descent.SGD(learning_rate) + if use_loss_scaling: + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, dynamic=False, initial_scale=loss_scale) model.compile( opt, loss=loss_fn, @@ -743,6 +783,11 @@ class KerasModelTest(keras_parameterized.TestCase): 'testcase_name': 'get_config', 'strategy_fn': create_mirrored_strategy, 'get_config': True, + }, { + 'testcase_name': 'get_config_v1_lso', + 'strategy_fn': create_mirrored_strategy, + 'get_config': True, + 'use_v1_loss_scale_optimizer': True, }, { 'testcase_name': 'get_config_and_pass_loss_scale_to_policy', 'strategy_fn': create_mirrored_strategy, @@ -752,12 +797,11 @@ class KerasModelTest(keras_parameterized.TestCase): def test_dynamic_loss_scaling(self, strategy_fn, pass_loss_scale_to_policy=False, - get_config=False): + get_config=False, + use_v1_loss_scale_optimizer=False): strategy = strategy_fn() initial_loss_scale = 2. batch_size = 4 - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=initial_loss_scale, increment_period=2) expected_gradient = backend.variable([initial_loss_scale / batch_size], dtype=dtypes.float16) # If this variable is set to True, the model below will have NaN gradients @@ -765,10 +809,19 @@ class KerasModelTest(keras_parameterized.TestCase): with strategy.scope(): opt = gradient_descent.SGD(1.) 
if pass_loss_scale_to_policy: - p = policy.Policy('mixed_float16', loss_scale=loss_scale) + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=initial_loss_scale, increment_period=2) + p = policy.PolicyV1('mixed_float16', loss_scale=loss_scale) + elif use_v1_loss_scale_optimizer: + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=initial_loss_scale, increment_period=2) + p = policy.Policy('mixed_float16') + opt = loss_scale_optimizer.LossScaleOptimizerV1( + opt, loss_scale) else: - p = policy.Policy('mixed_float16', loss_scale=None) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + p = policy.Policy('mixed_float16') + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, initial_scale=initial_loss_scale, dynamic_growth_steps=2) with policy.policy_scope(p): x = layers.Input( shape=(1,), batch_size=batch_size, dtype=dtypes.float16) @@ -835,19 +888,32 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertEqual(backend.eval(layer.v), -3) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) - def test_loss_scale_optimizer_overrides_policy_loss_scale(self): - with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): + def test_loss_scale_optimizer_overrides_policy_v1_loss_scale(self): + with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)): opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=5.) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=5.) x = layers.Input(shape=(1,)) y = mp_test_util.MultiplyLayer()(x) model = models.Model(x, y) model.compile(opt, loss='mse') - self.assertEqual(self.evaluate(model.optimizer.loss_scale()), 5.) + self.assertEqual(self.evaluate(model.optimizer.loss_scale), 5.) + + @combinations.generate(combinations.combine(mode=['graph', 'eager'])) + def test_policy_v1_without_loss_scale(self): + with policy.policy_scope(policy.PolicyV1('mixed_float16', + loss_scale=None)): + opt = gradient_descent.SGD(1.) + x = layers.Input(shape=(1,)) + y = mp_test_util.MultiplyLayer()(x) + model = models.Model(x, y) + model.compile(opt, loss='mse') + self.assertNotIsInstance(model.optimizer, + loss_scale_optimizer.LossScaleOptimizer) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_pass_invalid_optimizer_with_loss_scaling(self): - with policy.policy_scope(policy.Policy('float32', loss_scale=10.)): + with policy.policy_scope(policy.PolicyV1('float32', loss_scale=10.)): x = layers.Input(shape=(1,)) y = mp_test_util.MultiplyLayer()(x) model = models.Model(x, y) @@ -926,7 +992,7 @@ class KerasModelTest(keras_parameterized.TestCase): def test_save_slot_variables_with_autocast_vars(self, strategy_fn, var_name='v'): - p = policy.Policy('mixed_float16', loss_scale=None) + p = policy.Policy('mixed_float16') with strategy_fn().scope(), policy.policy_scope(p): x = layers.Input(shape=(2,), batch_size=2) # Having a var_name other than 'v' tests that a fixed bug (b/134713714) @@ -938,6 +1004,8 @@ class KerasModelTest(keras_parameterized.TestCase): y = layer(x) model = models.Model(inputs=x, outputs=y) opt = gradient_descent.SGD(1., 1.) 
+ opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=1) model.compile( optimizer=opt, loss='mse', @@ -971,18 +1039,17 @@ class KerasModelTest(keras_parameterized.TestCase): y = mp_test_util.MultiplyLayer(assert_type=dtypes.float32)(x) model = models.Model(inputs=x, outputs=y) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., multiplier=2.) opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, initial_scale=1., dynamic_growth_steps=2.) model.compile( optimizer=opt, loss='mse', run_eagerly=testing_utils.should_run_eagerly()) # Run for 3 steps (6 examples with a batch size of 2) model.fit(np.zeros((6, 2)), np.zeros((6, 2)), batch_size=2) - self.assertEqual(backend.get_value(loss_scale()), 2) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) + self.assertEqual(backend.get_value(opt.loss_scale), 2) + self.assertEqual(backend.get_value(opt.dynamic_counter), 1) # Save model weights. save_prefix = os.path.join(self.get_temp_dir(), 'ckpt') @@ -990,20 +1057,20 @@ class KerasModelTest(keras_parameterized.TestCase): # Run model again for 1 step (2 examples with a batch size of 2) model.fit(np.zeros((2, 2)), np.zeros((2, 2)), batch_size=2) - self.assertEqual(backend.get_value(loss_scale()), 4) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 0) + self.assertEqual(backend.get_value(opt.loss_scale), 4) + self.assertEqual(backend.get_value(opt.dynamic_counter), 0) # Load model weights and ensure loss scale weights are restored. model.load_weights(save_prefix) - self.assertEqual(backend.get_value(loss_scale()), 2) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) + self.assertEqual(backend.get_value(opt.loss_scale), 2) + self.assertEqual(backend.get_value(opt.dynamic_counter), 1) @keras_parameterized.run_all_keras_modes def test_restore_old_loss_scale_checkpoint(self): # Ensure a checkpoint from TF 2.2 can be loaded. The checkpoint format # of LossScaleOptimizer changed, but old checkpoints can still be loaded opt = gradient_descent.SGD(0.1, momentum=0.1) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + opt = loss_scale_optimizer.LossScaleOptimizer(opt) model = sequential.Sequential([core.Dense(2,)]) # The checkpoint and expected values were obtained from the program in @@ -1024,8 +1091,8 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertAllClose( self.evaluate(opt.get_slot(model.weights[0], 'momentum')), expected_slot) - self.assertEqual(self.evaluate(opt.loss_scale()), 32768) - self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 32768) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) # Check restoring works even after the model is compiled and the weights # have been created. 
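The renamed loss-scale state read by these updated tests can be summarized in a short illustrative sketch (not part of the patch); the comments note the pre-change spellings mentioned in the commit message:

    import tensorflow as tf

    opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.SGD(0.1))

    scale = opt.loss_scale             # formerly opt.loss_scale(); now a float32 scalar tensor
    counter = opt.dynamic_counter      # formerly loss_scale._num_good_steps
    growth = opt.dynamic_growth_steps  # formerly DynamicLossScale.increment_period; 2000 by default

    # Because the loss-scale variables are tracked on the optimizer itself, they
    # travel with ordinary weight checkpoints, which is what the save/restore
    # assertions in the surrounding tests verify.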
@@ -1039,8 +1106,8 @@ class KerasModelTest(keras_parameterized.TestCase): self.assertAllClose( self.evaluate(opt.get_slot(model.weights[0], 'momentum')), expected_slot) - self.assertEqual(self.evaluate(opt.loss_scale()), 32768) - self.assertEqual(self.evaluate(opt.loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 32768) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) def test_restore_old_saved_model(self): saved_model_dir = os.path.join( @@ -1053,8 +1120,8 @@ class KerasModelTest(keras_parameterized.TestCase): model = save.load_model(saved_model_dir) expected_kernel = np.array([[9.229685, 10.901115], [10.370763, 9.757362]]) self.assertAllClose(backend.eval(model.weights[0]), expected_kernel) - self.assertIsInstance(model.optimizer, - loss_scale_optimizer.LossScaleOptimizer) + self.assertEqual(type(model.optimizer), + loss_scale_optimizer.LossScaleOptimizer) @keras_parameterized.run_all_keras_modes @parameterized.named_parameters( @@ -1064,6 +1131,10 @@ class KerasModelTest(keras_parameterized.TestCase): }, { 'testcase_name': 'distribute', 'strategy_fn': create_mirrored_strategy, + }, { + 'testcase_name': 'use_v1_lso', + 'strategy_fn': create_mirrored_strategy, + 'use_v1_loss_scale_optimizer': True }, { 'testcase_name': 'base_h5', 'strategy_fn': default_strategy_fn, @@ -1073,7 +1144,8 @@ class KerasModelTest(keras_parameterized.TestCase): 'strategy_fn': create_mirrored_strategy, 'h5': True, }) - def test_save_model_with_dynamic_loss_scaling(self, strategy_fn, h5=False): + def test_save_model_with_dynamic_loss_scaling( + self, strategy_fn, h5=False, use_v1_loss_scale_optimizer=False): # TODO(reedwm): Support and test saving model with a mixed_[b]float16 policy # as well. strategy = strategy_fn() @@ -1088,18 +1160,22 @@ class KerasModelTest(keras_parameterized.TestCase): y = mp_test_util.MultiplyLayer()(x) model = models.Model(inputs=x, outputs=y) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., multiplier=2.) opt = gradient_descent.SGD(1.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + if use_v1_loss_scale_optimizer: + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=1., increment_period=2.) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + else: + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1., + dynamic_growth_steps=2.) model.compile( optimizer=opt, loss='mse', run_eagerly=testing_utils.should_run_eagerly()) # Run for 3 steps (6 examples with a batch size of 2) model.fit(np.ones((6, 2)), np.zeros((6, 2)), batch_size=2) - self.assertEqual(backend.get_value(loss_scale()), 2) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 1) + self.assertEqual(backend.get_value(opt.loss_scale), 2) + self.assertEqual(backend.get_value(opt.dynamic_counter), 1) (weight,) = model.trainable_weights orig_weight = backend.get_value(weight) @@ -1111,13 +1187,12 @@ class KerasModelTest(keras_parameterized.TestCase): model.fit(np.ones((2, 2)), np.zeros((2, 2)), batch_size=2) new_weight = backend.get_value(weight) self.assertNotEqual(new_weight, orig_weight) - self.assertEqual(backend.get_value(loss_scale()), 4) - self.assertEqual(backend.get_value(loss_scale._num_good_steps), 0) + self.assertEqual(backend.get_value(opt.loss_scale), 4) + self.assertEqual(backend.get_value(opt.dynamic_counter), 0) # Load model weights and ensure loss scale weights are restored. 
model = save.load_model( save_path, custom_objects={'MultiplyLayer': mp_test_util.MultiplyLayer}) - loss_scale = model.optimizer.loss_scale (weight,) = model.trainable_weights loaded_weight = backend.get_value(weight) self.assertEqual(loaded_weight, orig_weight) @@ -1125,8 +1200,14 @@ class KerasModelTest(keras_parameterized.TestCase): # Model.save(). So we assert the loss scale either has the value when it was # saved, or the value it was initialized with. # TODO(reedwm): Always save/restore the loss scale with Model.save(). - self.assertIn(backend.get_value(loss_scale()), (1, 2)) - self.assertIn(backend.get_value(loss_scale._num_good_steps), (0, 1)) + self.assertIn(backend.get_value(model.optimizer.loss_scale), (1, 2)) + self.assertIn(backend.get_value(model.optimizer.dynamic_counter), (0, 1)) + + # Test optimizer attributes and type + self.assertEqual(model.optimizer.initial_scale, 1.) + self.assertEqual(model.optimizer.dynamic_growth_steps, 2.) + self.assertEqual(type(model.optimizer), + loss_scale_optimizer.LossScaleOptimizer) if __name__ == '__main__': diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py index dd7bf6a682d..e2f384fd756 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py @@ -21,8 +21,10 @@ from tensorflow.python.distribute import collective_all_reduce_strategy from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.distribute import one_device_strategy +from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import tpu_strategy from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond @@ -32,8 +34,13 @@ from tensorflow.python.keras.mixed_precision.experimental import loss_scale as k from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import tf_logging +from tensorflow.python.training.experimental import loss_scale as loss_scale_module from tensorflow.python.training.experimental import mixed_precision from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import keras_export @@ -173,7 +180,220 @@ class _DelegatingTrackableMixin(object): # pylint: enable=protected-access -@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') +def _is_all_finite(grads): + """Returns a scalar boolean tensor indicating if all gradients are finite.""" + is_finite_per_grad = [ + math_ops.reduce_all(math_ops.is_finite(g)) for g in grads if g is not None + ] + return math_ops.reduce_all(is_finite_per_grad) + + +def _op_in_graph_mode(tensor): + """Returns the tensor's op in graph mode, or the tensor in eager mode. + + This is useful because sometimes an op is needed in graph mode instead of a + tensor. In eager mode, there are no ops. + + Args: + tensor: A tensor. + + Returns: + The tensor's op in graph mode. 
The tensor in eager mode. + """ + if context.executing_eagerly(): + return tensor + return tensor.op + + +def _assign_if_finite(var, value): + """Assigns a value to a variable if the value is finite.""" + return control_flow_ops.cond( + math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)), + control_flow_ops.no_op) + + +class _DynamicLossScaleState(trackable.Trackable): + """The state of a dynamic loss scale.""" + + def __init__(self, + initial_loss_scale, + growth_steps, + multiplier): + """Creates the dynamic loss scale.""" + super(_DynamicLossScaleState, self).__init__() + self._initial_loss_scale = float(initial_loss_scale) + self._growth_steps = int(growth_steps) + self._multiplier = float(multiplier) + + self._weights = {} + self._current_loss_scale = self._add_weight( + name='current_loss_scale', + dtype=dtypes.float32, + initial_value=self._initial_loss_scale) + # The number of consecutive steps with finite gradients since the last + # nonfinite gradient or change in loss scale. The name is 'good_steps' for + # backwards compatibility with older checkpoints. + self._counter = self._add_weight( + name='good_steps', dtype=dtypes.int64, initial_value=0) + + def _add_weight(self, name, initial_value, dtype=None): + """Adds a weight to this loss scale. + + Args: + name: Variable name. + initial_value: The variable's initial value. + dtype: The type of the variable. + + Returns: + A variable. + + Raises: + RuntimeError: If a weight with `name` has already been added. + """ + variable = variable_scope.variable( + initial_value=initial_value, + name=name, + dtype=dtype, + trainable=False, + use_resource=True, + synchronization=variables.VariableSynchronization.AUTO, + # Set aggregation to NONE, as loss scaling variables should never be + # aggregated. + aggregation=variables.VariableAggregation.NONE) + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + + key = (name, graph_key) + self._weights[key] = variable + self._handle_deferred_dependencies(name=name, trackable=variable) + backend.track_variable(variable) + return variable + + @property + def _checkpoint_dependencies(self): + """From Trackable. Gather graph-specific weights to save.""" + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + weights = [] + for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]): + if g == graph_key: + weights.append(trackable.TrackableReference(name=name, ref=v)) + return (super(_DynamicLossScaleState, self)._checkpoint_dependencies + + weights) + + def _lookup_dependency(self, name): + """From Trackable. 
Find a weight in the current graph.""" + unconditional = super(_DynamicLossScaleState, self)._lookup_dependency(name) + if unconditional is not None: + return unconditional + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + return self._weights.get((name, graph_key), None) + + @property + def initial_loss_scale(self): + return self._initial_loss_scale + + @property + def growth_steps(self): + return self._growth_steps + + @property + def multiplier(self): + return self._multiplier + + @property + def current_loss_scale(self): + """Returns the current loss scale as a float32 `tf.Variable`.""" + return self._current_loss_scale + + @property + def counter(self): + """Returns the counter as a float32 `tf.Variable`.""" + return self._counter + + def __call__(self): + """Returns the current loss scale as a scalar `float32` tensor.""" + return ops.convert_to_tensor(self._current_loss_scale) + + def update(self, grads): + """Updates the value of the loss scale. + + Args: + grads: A nested structure of unscaled gradients, each which is the + gradient of the loss with respect to a weight. + + Returns: + update_op: In eager mode, None. In graph mode, an op to update the loss + scale. + should_apply_gradients: Either a bool or a scalar boolean tensor. If + False, the caller should skip applying `grads` to the variables this + step. + """ + grads = nest.flatten(grads) + if distribution_strategy_context.has_strategy(): + distribution = distribution_strategy_context.get_strategy() + + def get_is_finite(grads): + is_finite = _is_all_finite(grads) + # We cast to float, because we cannot reduce booleans with + # DistributionStrategy. + return math_ops.cast(is_finite, dtypes.float32) + + is_finite_float = distribution.extended.call_for_each_replica( + get_is_finite, args=(grads,)) + reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM, + is_finite_float, axis=None) + is_finite = math_ops.equal(reduced_is_finite_float, + distribution.num_replicas_in_sync) + else: + is_finite = _is_all_finite(grads) + + def update_if_finite_grads(): + """Update assuming the gradients are finite.""" + + def incr_loss_scale(): + new_loss_scale = self.current_loss_scale * self.multiplier + return control_flow_ops.group( + _assign_if_finite(self.current_loss_scale, new_loss_scale), + self.counter.assign(0)) + + return control_flow_ops.cond( + self.counter + 1 >= self.growth_steps, + incr_loss_scale, + lambda: _op_in_graph_mode(self.counter.assign_add(1))) + + def update_if_not_finite_grads(): + """Update assuming the gradients are nonfinite.""" + + new_loss_scale = math_ops.maximum( + self.current_loss_scale / self.multiplier, 1) + return control_flow_ops.group( + self.counter.assign(0), + self.current_loss_scale.assign(new_loss_scale)) + + update_op = control_flow_ops.cond(is_finite, update_if_finite_grads, + update_if_not_finite_grads) + should_apply_gradients = is_finite + return update_op, should_apply_gradients + + +# See LossScaleOptimizer docstring for why this is so big +_DEFAULT_INITIAL_SCALE = 2 ** 15 +_DEFAULT_GROWTH_STEPS = 2000 + + +# pylint: disable=g-classes-have-attributes +@keras_export('keras.mixed_precision.LossScaleOptimizer') class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): """An optimizer that applies loss scaling. @@ -194,19 +414,21 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): applied. 
The loss scale can either be a fixed constant, chosen by the user, or be - dynamically determined. Dynamically determining the loss scale is convenient - as a loss scale does not have to be explicitly chosen. However it reduces - performance. + dynamically determined. Using a dynamic loss scale is highly recommend and is + the default behavior, as choosing a specific fixed loss scale is difficult. + Every step, the dynamic loss scale is potentially updated to a new value. + Dynamic loss scaling sometimes causes the loss scale to be too high and cause + the gradients to overflow, in which case gradients are not applied to + variables that step. - This optimizer wraps another optimizer and applies loss scaling to it via a - `LossScale`. Loss scaling is applied whenever gradients are - computed, either through `minimize()` or `get_gradients()`. The loss scale is - updated via `LossScale.update()` whenever gradients are applied, either - through `minimize()` or `apply_gradients()`. For example: + `LossScaleOptimizer` wraps another optimizer and applies loss scaling to it. + Loss scaling is applied whenever gradients are computed, either through + `minimize()` or `get_gradients()`. If dynamic, the loss scale is updated + whenever gradients are applied, either through `minimize()` or + `apply_gradients()`. For example: >>> opt = tf.keras.optimizers.SGD(0.25) - >>> opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, - ... "dynamic") + >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(opt) >>> var = tf.Variable(1.) >>> loss_fn = lambda: var ** 2 >>> # 'minimize' applies loss scaling to the loss and updates the loss sale. @@ -230,12 +452,43 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): >>> var.numpy() 0.25 + Args: + optimizer: The Optimizer instance to wrap. + dynamic: Bool indicating whether dynamic loss scaling is used. Defaults to + True. If True, the loss scale will be dynamically updated over time using + an algorithm that keeps the loss scale at approximately its optimal value. + If False, a single fixed loss scale is used and `initial_scale` must be + specified, which is used as the loss scale. Recommended to keep as True, + as choosing a fixed loss scale can be tricky. Currently, there is a small + performance overhead to dynamic loss scaling compared to fixed loss + scaling. + initial_scale: The initial loss scale. If `dynamic` is True, this defaults + to 2 ** 15. If `dynamic` is False, this must be specified and acts as the + sole loss scale, as the loss scale does not change over time. When dynamic + loss scaling is used, is better for this to be a very high number, because + a loss scale that is too high gets lowered far more quickly than a loss + scale that is too low gets raised. + dynamic_growth_steps: With dynamic loss scaling, every + `dynamic_growth_steps` steps with finite gradients, the loss scale is + doubled. Defaults to 2000. If a nonfinite gradient is encountered, the + count is reset back to zero, gradients are skipped that step, and the loss + scale is halved. The count can be queried with + `LossScaleOptimizer.dynamic_counter`. This argument can only be specified + if `dynamic` is True. + + To use a fixed loss scale instead of dynamic loss scale, pass `dynamic=False` + and pass the loss scale to `initial_scale`. For example: + + >>> opt = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=1024) + >>> opt.loss_scale.numpy() + 1024. 
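  A sketch of the manual path, using the `get_scaled_loss` and
  `get_unscaled_gradients` methods shown further down in this file (illustrative
  only; the variables are arbitrary):

  >>> opt = tf.keras.mixed_precision.LossScaleOptimizer(
  ...     tf.keras.optimizers.SGD(1.0))
  >>> var = tf.Variable(1.0)
  >>> with tf.GradientTape() as tape:
  ...   loss = var ** 2
  ...   scaled_loss = opt.get_scaled_loss(loss)
  >>> scaled_grads = tape.gradient(scaled_loss, [var])
  >>> grads = opt.get_unscaled_gradients(scaled_grads)
  >>> opt.apply_gradients(list(zip(grads, [var])))  # also updates the loss scale
  >>> var.numpy()
  -1.0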
+ Hyperparameters can be accessed and set on the LossScaleOptimizer, which will be delegated to the wrapped optimizer. >>> opt = tf.keras.optimizers.Adam(beta_1=0.8, epsilon=1e-5) - >>> lso = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, - ... "dynamic") + >>> lso = tf.keras.mixed_precision.LossScaleOptimizer(opt) >>> opt.beta_1 0.8 >>> lso.beta_1 # Equivalent to `opt.beta_1` @@ -268,49 +521,98 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): _HAS_AGGREGATE_GRAD = True - def __init__(self, optimizer, loss_scale): - """Initializes this loss scale optimizer. - - Args: - optimizer: The Optimizer instance to wrap. - loss_scale: The loss scale to scale the loss and gradients. This can - either be an int/float to use a fixed loss scale, the string "dynamic" - to use dynamic loss scaling, or an instance of a LossScale. The string - "dynamic" equivalent to passing `DynamicLossScale()`, and passing an - int/float is equivalent to passing a FixedLossScale with the given loss - scale. - """ + def __init__(self, optimizer, dynamic=True, initial_scale=None, + dynamic_growth_steps=None): if not isinstance(optimizer, optimizer_v2.OptimizerV2): - raise ValueError('"optimizer" must be an instance of OptimizerV2, but ' - 'got: %s' % optimizer) + raise TypeError('"optimizer" must be an instance of OptimizerV2, but ' + 'got: %s' % optimizer) + if not isinstance(dynamic, bool): + # Catch errors if a user incorrectly passes a string or float to the + # second argument argument, as this is commonly done for + # LossScaleOptimizerV1. + raise TypeError('"dynamic" argument to LossScaleOptimizer.__init__ must ' + 'be a bool, but got: %r' % (dynamic,)) self._raise_if_strategy_unsupported() - self._optimizer = optimizer - self._loss_scale = keras_loss_scale_module.get(loss_scale) - if self._loss_scale is None: - raise ValueError('loss_scale cannot be None.') # We don't call super().__init__, since we do not want to call OptimizerV2's # constructor. _DelegatingTrackableMixin.__init__(self, self._optimizer) - for weight in self._loss_scale._weights.values(): # pylint: disable=protected-access - # We cannot call `track_variable` in the LossScale class itself, because a - # file outside of Keras cannot depend on a Keras file. Calling it here - # instead is OK, because a variable only needs to be tracked if used with - # a Keras class, and the only way to use LossScale with a Keras class is - # through the LossScaleOptimizer. - backend.track_variable(weight) - self._track_trackable(self._loss_scale, 'loss_scale') + if dynamic: + if initial_scale is None: + initial_scale = _DEFAULT_INITIAL_SCALE + if dynamic_growth_steps is None: + dynamic_growth_steps = _DEFAULT_GROWTH_STEPS + self._loss_scale = _DynamicLossScaleState( + initial_scale, dynamic_growth_steps, multiplier=2) + self._track_trackable(self._loss_scale, 'loss_scale') + else: + if initial_scale is None: + raise ValueError('"initial_scale" must be specified if "dynamic" is ' + 'False') + self._loss_scale = float(initial_scale) + if dynamic_growth_steps is not None: + raise ValueError('"dynamic_growth_steps" must be None if "dynamic" ' + 'is False, but got: %s' % (dynamic_growth_steps,)) # To support restoring TensorFlow 2.2 checkpoints. 
self._track_trackable(FakeOptimizerForRestoration(self._optimizer), 'base_optimizer') + @property + def dynamic(self): + return isinstance(self._loss_scale, _DynamicLossScaleState) + @property def loss_scale(self): - """The `LossScale` instance associated with this optimizer.""" - return self._loss_scale + """The current loss scale as a float32 scalar tensor.""" + if isinstance(self._loss_scale, _DynamicLossScaleState): + return ops.convert_to_tensor(self._loss_scale.current_loss_scale) + else: + return ops.convert_to_tensor(self._loss_scale) + + @property + def dynamic_counter(self): + """The number of steps since the loss scale was last increased or decreased. + + This is None if `LossScaleOptimizer.dynamic` is False. + + The counter is incremented every step. Once it reaches + `LossScaleOptimizer.dynamic_growth_steps`, the loss scale will be doubled + and the counter will be reset back to zero. If nonfinite gradients are + encountered, the loss scale will be halved and the counter will be reset + back to zero. + """ + if isinstance(self._loss_scale, _DynamicLossScaleState): + return self._loss_scale.counter + else: + return None + + @property + def initial_scale(self): + """The initial loss scale. + + This is None if `LossScaleOptimizer.dynamic` is False. + """ + if isinstance(self._loss_scale, _DynamicLossScaleState): + return self._loss_scale.initial_loss_scale + else: + return self._loss_scale + + @property + def dynamic_growth_steps(self): + """The number of steps it takes to increase the loss scale. + + This is None if `LossScaleOptimizer.dynamic` is False. + + Every `dynamic_growth_steps` consecutive steps with finite gradients, the + loss scale is increased. + """ + if isinstance(self._loss_scale, _DynamicLossScaleState): + return self._loss_scale.growth_steps + else: + return None def get_scaled_loss(self, loss): """Scales the loss by the loss scale. @@ -322,7 +624,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): scaling is automatically applied and this method is unneeded. If this method is called, `get_unscaled_gradients` should also be called. - See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for + See the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an example. Args: @@ -330,16 +632,15 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): a tensor or a callable returning a tensor. Returns: - `loss` multiplied by `LossScaleOptimizer.loss_scale()`. + `loss` multiplied by `LossScaleOptimizer.loss_scale`. """ - loss_scale = self._loss_scale() if callable(loss): def new_loss(): loss_val = loss() - return loss_val * math_ops.cast(loss_scale, loss_val.dtype) + return loss_val * math_ops.cast(self.loss_scale, loss_val.dtype) return new_loss else: - return loss * math_ops.cast(loss_scale, loss.dtype) + return loss * math_ops.cast(self.loss_scale, loss.dtype) def get_unscaled_gradients(self, grads): """Unscales the gradients by the loss scale. @@ -351,7 +652,7 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): scaling is automatically applied and this method is unneeded. If this method is called, `get_scaled_loss` should also be called. See - the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an + the `tf.keras.mixed_precision.LossScaleOptimizer` doc for an example. 
Args: @@ -360,10 +661,9 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): Returns: A new list the same size as `grads`, where every non-None value in `grads` - is divided by `LossScaleOptimizer.loss_scale()`. + is divided by `LossScaleOptimizer.loss_scale`. """ - loss_scale = self._loss_scale() - loss_scale_reciprocal = 1. / loss_scale + loss_scale_reciprocal = 1. / self.loss_scale return [ _multiply_gradient(g, loss_scale_reciprocal) if g is not None else None for g in grads @@ -379,9 +679,9 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): grad_loss, tape=tape) grads = [g for g, _ in grads_and_vars] - variables = [v for _, v in grads_and_vars] + weights = [v for _, v in grads_and_vars] unscaled_grads = self.get_unscaled_gradients(grads) - return list(zip(unscaled_grads, variables)) + return list(zip(unscaled_grads, weights)) def get_gradients(self, loss, params): loss = self.get_scaled_loss(loss) @@ -409,7 +709,11 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name, experimental_aggregate_gradients): grads = [g for g, _ in grads_and_vars] - loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads) + if isinstance(self._loss_scale, _DynamicLossScaleState): + loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads) + else: + loss_scale_update_op = control_flow_ops.no_op() + should_apply_grads = True def apply_fn(): # We do not want DistributionStrategy to unwrap any MirroredVariables in @@ -447,10 +751,11 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): def get_config(self): serialized_optimizer = optimizers.serialize(self._optimizer) - serialized_loss_scale = keras_loss_scale_module.serialize(self._loss_scale) return { 'optimizer': serialized_optimizer, - 'loss_scale': serialized_loss_scale, + 'dynamic': self.dynamic, + 'initial_scale': self.initial_scale, + 'dynamic_growth_steps': self.dynamic_growth_steps, } @classmethod @@ -458,8 +763,28 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): config = config.copy() # Make a copy, since we mutate config config['optimizer'] = optimizers.deserialize( config['optimizer'], custom_objects=custom_objects) - config['loss_scale'] = keras_loss_scale_module.deserialize( - config['loss_scale'], custom_objects=custom_objects) + if 'loss_scale' in config: + # If loss_scale is in config, we assume we are deserializing a + # LossScaleOptimizer from TF 2.3 or below. We convert the config so it + # can be deserialized in the current LossScaleOptimizer. + loss_scale = keras_loss_scale_module.deserialize( + config.pop('loss_scale')) + if isinstance(loss_scale, loss_scale_module.FixedLossScale): + config['dynamic'] = False + config['initial_scale'] = loss_scale._loss_scale_value # pylint: disable=protected-access + elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): + config['dynamic'] = True + config['initial_scale'] = loss_scale.initial_loss_scale + config['dynamic_growth_steps'] = loss_scale.increment_period + if loss_scale.multiplier != 2: + raise ValueError('Cannot deserialize LossScaleOptimizer with a ' + 'DynamicLossScale whose multiplier is not 2. 
Got ' + 'DynamicLossScale: %s' % (loss_scale,)) + else: + raise ValueError( + 'Serialized LossScaleOptimizers with a LossScale that is neither a ' + 'FixedLossScale nor a DynamicLossScale can no longer be ' + 'deserialized') return cls(**config) def _raise_if_strategy_unsupported(self): @@ -601,15 +926,162 @@ class LossScaleOptimizer(_DelegatingTrackableMixin, optimizer_v2.OptimizerV2): # both self._compute_gradients() and self.apply_gradients(), and both need # to have the LossScaleOptimizer version called. - # TODO(reedwm): Maybe merge this class's functionality into OptimizerV2. - # TODO(reedwm): Maybe throw an error if mixed precision is used without this # optimizer being used. - # Trackable delegations: Delegate all Trackable methods to the wrapped - # optimizer. This is so the checkpoint format for a LossScaleOptimizer is - # identical to the checkpoint format for a normal optimizer, except the loss - # scale is stored in the checkpoint. + +@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer') +class LossScaleOptimizerV1(LossScaleOptimizer): + """An deprecated optimizer that applies loss scaling. + + Warning: This class is deprecated and will be removed in TensorFlow 2.5. + Please use the non-experimental class + `tf.keras.mixed_precision.LossScaleOptimizer` instead. + + This class is identical to the non-experimental + `keras.mixed_precision.LossScaleOptimizer` except its constructor takes + different arguments. For this class (the experimental version), the + constructor takes a `loss_scale` argument. For the non-experimental class, + the constructor encodes the loss scaling information in multiple arguments. + Note that unlike this class, the non-experimental class does not accept a + `tf.compat.v1.mixed_precision.LossScale`, which is deprecated. + + If you currently use this class, you should switch to the non-experimental + `tf.keras.mixed_precision.LossScaleOptimizer` instead. We show several + examples of converting the use of the experimental class to the equivalent + non-experimental class. + + >>> # In all of the the examples below, `opt1` and `opt2` are identical + >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), loss_scale='dynamic') + >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD()) + >>> assert opt1.get_config() == opt2.get_config() + + >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), loss_scale=123) + >>> # dynamic=False indicates to use fixed loss scaling. initial_scale=123 + >>> # refers to the initial loss scale, which is the single fixed loss scale + >>> # when dynamic=False. + >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), dynamic=False, initial_scale=123) + >>> assert opt1.get_config() == opt2.get_config() + + >>> loss_scale = tf.compat.v1.mixed_precision.experimental.DynamicLossScale( + ... initial_loss_scale=2048, increment_period=500) + >>> opt1 = tf.keras.mixed_precision.experimental.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), loss_scale=loss_scale) + >>> opt2 = tf.keras.mixed_precision.LossScaleOptimizer( + ... tf.keras.optimizers.SGD(), initial_scale=2048, + ... dynamic_growth_steps=500) + >>> assert opt1.get_config() == opt2.get_config() + + Args: + optimizer: The Optimizer instance to wrap. + loss_scale: The loss scale to scale the loss and gradients. 
This can + either be an int/float to use a fixed loss scale, the string "dynamic" + to use dynamic loss scaling, or an instance of a LossScale. The string + "dynamic" is equivalent to passing `DynamicLossScale()`, and passing an + int/float is equivalent to passing a FixedLossScale with the given loss + scale. If a DynamicLossScale is passed, DynamicLossScale.multiplier must + be 2 (the default). + """ + + def __init__(self, optimizer, loss_scale): + warn_msg_prefix = ( + 'tf.keras.mixed_precision.experimental.LossScaleOptimizer is ' + 'deprecated. Please use tf.keras.mixed_precision.LossScaleOptimizer ' + 'instead. ') + + if isinstance(loss_scale, dict): + loss_scale = keras_loss_scale_module.deserialize(loss_scale) + + if isinstance(loss_scale, (int, float)): + tf_logging.warn( + warn_msg_prefix + 'For example\n' + ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' + 'opt, dynamic=False, initial_scale={})'.format(loss_scale)) + super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, + initial_scale=loss_scale) + elif isinstance(loss_scale, loss_scale_module.FixedLossScale): + ls_val = loss_scale._loss_scale_value # pylint: disable=protected-access + tf_logging.warn( + warn_msg_prefix + 'For example\n' + ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' + 'opt, dynamic=False, initial_scale={})'.format(ls_val)) + super(LossScaleOptimizerV1, self).__init__(optimizer, dynamic=False, + initial_scale=ls_val) + elif loss_scale == 'dynamic': + tf_logging.warn( + warn_msg_prefix + 'For example\n' + ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' + 'opt)') + super(LossScaleOptimizerV1, self).__init__(optimizer) + elif isinstance(loss_scale, loss_scale_module.DynamicLossScale): + kwargs = {} + extra_arguments = '' + if loss_scale.initial_loss_scale != _DEFAULT_INITIAL_SCALE: + kwargs['initial_scale'] = loss_scale.initial_loss_scale + extra_arguments += (', initial_scale=%s' % + loss_scale.initial_loss_scale) + if loss_scale.increment_period != _DEFAULT_GROWTH_STEPS: + kwargs['dynamic_growth_steps'] = loss_scale.increment_period + extra_arguments += (', dynamic_growth_steps=%s' % + loss_scale.increment_period) + if loss_scale.multiplier != 2: + raise ValueError('When passing a DynamicLossScale to "loss_scale", ' + 'DynamicLossScale.multiplier must be 2. Got: %s' + % (loss_scale,)) + tf_logging.warn( + warn_msg_prefix + + 'Note that the non-experimental LossScaleOptimizer does not take a ' + 'DynamicLossScale but instead takes the dynamic configuration ' + 'directly in the constructor. For example:\n' + ' opt = tf.keras.mixed_precision.LossScaleOptimizer(' + 'opt{})\n'.format(extra_arguments)) + super(LossScaleOptimizerV1, self).__init__(optimizer, **kwargs) + elif isinstance(loss_scale, loss_scale_module.LossScale): + raise TypeError('Passing a LossScale that is not a FixedLossScale or a ' + 'DynamicLossScale is no longer supported. Got: {}' + .format(loss_scale)) + else: + raise ValueError('Invalid value passed to loss_scale. loss_scale ' + 'must be the string "dynamic" (recommended), an int, ' + 'a float, a FixedLossScale, or a DynamicLossScale. 
Got ' + 'value: {}'.format(loss_scale)) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() # Make a copy, since we mutate config + config['optimizer'] = optimizers.deserialize( + config['optimizer'], custom_objects=custom_objects) + + # If loss_scale is in config, we assume we are deserializing a + # LossScaleOptimizer from TF 2.3 or below. Otherwise, we assume we are + # deserializing a LossScaleOptimizer from TF 2.4 or above. + if 'loss_scale' in config: + config['loss_scale'] = keras_loss_scale_module.deserialize( + config['loss_scale']) + if (isinstance(config['loss_scale'], loss_scale_module.DynamicLossScale) + and config['loss_scale'].multiplier != 2): + raise ValueError('Cannot deserialize LossScaleOptimizer with a ' + 'DynamicLossScale whose multiplier is not 2. Got ' + 'DynamicLossScale: %s' % (config['loss_scale'],)) + return cls(**config) + + # We convert the config, as generated by LossScaleOptimizer.get_config, to a + # version that can be passed to LossScaleOptimizerV1.__init__ + if config['dynamic']: + config['loss_scale'] = loss_scale_module.DynamicLossScale( + config['initial_scale'], config['dynamic_growth_steps'], multiplier=2) + else: + config['loss_scale'] = loss_scale_module.FixedLossScale( + config['initial_scale']) + + del config['dynamic'] + del config['initial_scale'] + del config['dynamic_growth_steps'] + return cls(**config) class FakeOptimizerForRestoration(trackable.Trackable): @@ -654,7 +1126,7 @@ class FakeOptimizerForRestoration(trackable.Trackable): # pylint: disable=protected-access mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2, - LossScaleOptimizer) + LossScaleOptimizerV1) def _multiply_gradient(gradient, scale): diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py index fe3a237ef83..8da89e7e72a 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer_test.py @@ -40,7 +40,7 @@ from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training.experimental import loss_scale as loss_scale_module +from tensorflow.python.training.experimental import loss_scale as tf_loss_scale_module from tensorflow.python.training.tracking import util as trackable_utils # Disable not-callable lint error, as the linter is unable to detect that @@ -93,7 +93,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([5.0]) opt = gradient_descent.SGD(2.0) loss_scale = 10. - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=loss_scale) + self.assertEqual(self.evaluate(opt.loss_scale), loss_scale) + self.assertIsInstance(opt.loss_scale, ops.Tensor) # We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale # / strategy.num_replicas_in_sync will not be exact, which could lead to # assertion failures due to rounding issues. @@ -112,7 +115,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([2.0]) opt = gradient_descent.SGD(1.0) loss_scale = 10. 
- opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=loss_scale) grad_check_fn = mp_test_util.create_identity_with_grad_check_fn( loss_scale) loss = grad_check_fn(var) @@ -122,9 +126,18 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # mp_test_util.create_identity_with_grad_check_fn added an assertion op. self.evaluate(run_op) + def testDynamicAttrsWithFixedLossScale(self): + opt = gradient_descent.SGD() + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2.) + self.assertFalse(opt.dynamic) + self.assertIsNone(opt.dynamic_counter) + self.assertIsNone(opt.dynamic_growth_steps) + def testGetScaledLoss(self): opt = gradient_descent.SGD(2.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2.) loss = ops.convert_to_tensor_v2_with_dispatch(5.) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(loss))) self.assertEqual(10., self.evaluate(opt.get_scaled_loss(lambda: loss)())) @@ -134,7 +147,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): def testGetUnscaledGradients(self): opt = gradient_descent.SGD(2.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2) scaled_grads = [ ops.convert_to_tensor_v2_with_dispatch(3.), None, ops.convert_to_tensor_v2_with_dispatch(-4., dtype='float16') @@ -145,7 +159,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): def testGetUnscaledSparseGradients(self): opt = gradient_descent.SGD(2.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=2) sparse_scaled_grad = ops.IndexedSlices( ops.convert_to_tensor_v2_with_dispatch([[4., 2.], [8., 5.]]), ops.convert_to_tensor_v2_with_dispatch([1, 3], dtype='int32'), @@ -165,12 +180,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy.scope(): var = variables.Variable([5.0]) opt = gradient_descent.SGD(learning_rate) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) - self.assertEqual( - loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) + self.assertEqual(opt.initial_scale, 2.) + self.assertIsInstance(opt.initial_scale, float) + self.assertEqual(opt.dynamic_growth_steps, 1) + self.assertIsInstance(opt.dynamic_growth_steps, int) + self.assertEqual(opt.initial_scale % strategy.num_replicas_in_sync, 0) run_fn = self._run_fn_with_grad_check(strategy, var, opt, expected_gradient) run_op = strategy.experimental_run(run_fn) @@ -189,6 +206,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # 1. 
self.assertAllClose([1.], self.evaluate(var)) + def testDynamicLossScaleDefaultValues(self): + opt = gradient_descent.SGD() + opt = loss_scale_optimizer.LossScaleOptimizer(opt) + self.assertEqual(opt.initial_scale, 2 ** 15) + self.assertEqual(opt.dynamic_growth_steps, 2000) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 15) + # pylint: disable=cell-var-from-loop @parameterized.named_parameters(*TESTCASES) def testClipping(self, strategy_fn): @@ -198,12 +223,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy.scope(), self.subTest(clip_type=clip_type): var = variables.Variable([5.0]) opt = gradient_descent.SGD(learning_rate, **{clip_type: 2.0}) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) self.assertEqual(getattr(opt, clip_type), 2.0) - self.assertEqual( - loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0) + self.assertEqual(opt.initial_scale % strategy.num_replicas_in_sync, 0) loss = lambda: var * 4 / strategy.num_replicas_in_sync run_fn = lambda: opt.minimize(loss, var_list=[var]) @@ -215,7 +238,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # The gradient is 4 but is clipped to 2, so the variable will be # init_val - clipped_grad * lr == 5 - 2 * 2 == 1 self.assertAllClose([1.], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), 4) + self.assertEqual(self.evaluate(opt.loss_scale), 4) # Test changing the clip amount and running again setattr(opt, clip_type, 3.0) @@ -224,7 +247,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # The gradient is 4 but is clipped to 3, so the variable will be # prev_var - clipped_grad * lr == 1 - 3 * 2 == -5 self.assertAllClose([-5.], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), 8) + self.assertEqual(self.evaluate(opt.loss_scale), 8) # Test Inf gradients are still skipped instead of being clipped loss = lambda: var * float('Inf') @@ -232,7 +255,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): run_op = strategy.experimental_run(run_fn) self._run_if_in_graph_mode(run_op) self.assertAllClose([-5.], self.evaluate(var)) # Var does not change - self.assertEqual(self.evaluate(opt.loss_scale()), 4) + self.assertEqual(self.evaluate(opt.loss_scale), 4) # pylint: enable=cell-var-from-loop @parameterized.named_parameters(*TESTCASES) @@ -240,9 +263,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy_fn().scope() as strategy: var = variables.Variable([1.0, 2.0]) opt = gradient_descent.SGD(1.0) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) # Test optimizer with finite gradients loss = lambda: var * 2.0 / strategy.num_replicas_in_sync @@ -253,7 +275,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # Gradient is 2, so variable will have 2 subtracted from it self.assertAllClose([-1.0, 0.0], self.evaluate(var)) # Loss scale has doubled from 2 to 4 - self.assertEqual(4., self.evaluate(opt.loss_scale())) + self.assertEqual(4., 
self.evaluate(opt.loss_scale)) # Test optimizer with NaN gradients loss = lambda: var * float('NaN') @@ -263,7 +285,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # Variable should not change from before, due to NaN gradients. self.assertAllClose(self.evaluate(var), [-1.0, 0.0]) # Loss scale should half due to NaN gradients. - self.assertEqual(2., self.evaluate(opt.loss_scale())) + self.assertEqual(2., self.evaluate(opt.loss_scale)) @parameterized.named_parameters(*TESTCASES) def testDynamicLossScaleWithFloat16Loss(self, strategy_fn): @@ -272,9 +294,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy.scope(): var = variables.Variable([5.0]) opt = gradient_descent.SGD(learning_rate) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2, increment_period=1, multiplier=2) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=1) def loss(): return math_ops.cast(var / strategy.num_replicas_in_sync, 'float16') @@ -297,11 +318,9 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([1.0, 2.0]) # An SGD optimizer with momentum has slot variables. opt = gradient_descent.SGD(1.0, momentum=1.) - initial_loss_scale = 2. - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=initial_loss_scale, increment_period=1, - multiplier=4) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + initial_scale = 2. + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, initial_scale=initial_scale, dynamic_growth_steps=1) loss = lambda: var / strategy.num_replicas_in_sync run_fn = lambda: opt.minimize(loss, var_list=[var]) run_op = strategy.experimental_run(run_fn) @@ -312,7 +331,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # variable is subtracted by the accumulator, so the variable is subtracted # by 1. self.assertAllClose([0.0, 1.0], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), initial_loss_scale * 4) + self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 2) run_op = strategy.experimental_run(run_fn) self._run_if_in_graph_mode(run_op) @@ -321,14 +340,14 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # variable is subtracted by the accumulator, so the variable is subtracted # by 2. self.assertAllClose([-2., -1.], self.evaluate(var)) - self.assertEqual(self.evaluate(opt.loss_scale()), - initial_loss_scale * 16) + self.assertEqual(self.evaluate(opt.loss_scale), initial_scale * 4) self.assertEqual(opt.get_slot_names(), ['momentum']) def testIterations(self): opt = gradient_descent.SGD(2.0) - lso = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=10.) + lso = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=10.) lso.iterations = 7 self.assertEqual(lso.iterations, 7) self.assertEqual(opt.iterations, 7) @@ -338,7 +357,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with strategy_fn().scope() as strategy: # Test iterations is incremented in opt.minimize. 
opt = gradient_descent.SGD(1.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale='dynamic') + opt = loss_scale_optimizer.LossScaleOptimizer(opt) var = variables.Variable([5.0]) loss = lambda: var * 2.0 / strategy.num_replicas_in_sync run_fn = lambda: opt.minimize(loss, [var]) @@ -361,11 +380,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with self.test_session(): var = variables.Variable([1.0]) opt = gradient_descent.SGD(1.0) - initial_loss_scale = 2. - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=initial_loss_scale, increment_period=1, - multiplier=4) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2., + dynamic_growth_steps=1) run_op = opt.minimize(lambda: var * 2, [var]) self.evaluate(variables.global_variables_initializer()) self._run_if_in_graph_mode(run_op) @@ -377,15 +393,10 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): opt.set_weights([np.array(2.)]) self.assertEqual(self.evaluate(opt.variables()[0]), 2) - def testPassingNoneToLossScale(self): - opt = gradient_descent.SGD() - with self.assertRaisesRegex(ValueError, r'loss_scale cannot be None'): - loss_scale_optimizer.LossScaleOptimizer(opt, None) - def testHyperParametersExposed(self): with self.cached_session(): opt = adam.Adam(learning_rate=1.0, beta_1=0.5, beta_2=0.9) - lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(opt) # Force hyperparameters to be created opt.lr # pylint: disable=pointless-statement self.evaluate(variables.global_variables_initializer()) @@ -420,13 +431,13 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self._set_hyper('loss_scale', 123.) opt = MyOpt() - lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(opt) with self.assertRaises(AttributeError): - lso.loss_scale = loss_scale_module.FixedLossScale(2.) + lso.loss_scale = 2. def testArbitraryAttributesNotExposed(self): opt = gradient_descent.SGD() - lso = loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(opt) self.assertFalse(opt.nesterov) with self.assertRaisesRegex( AttributeError, @@ -438,8 +449,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.assertFalse(opt.nesterov) def testDir(self): - lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), - 'dynamic') + lso = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD()) dir_result = dir(lso) self.assertIn('learning_rate', dir_result) # Hyperparameter self.assertIn('lr', dir_result) # Hyperparameter @@ -471,11 +481,125 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with create_mirrored_strategy().scope() as strategy: var = variables.Variable([5.0]) opt = MyOptimizer(learning_rate=1.0) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=1) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False, + initial_scale=1) loss = lambda: var * 2.0 run_fn = lambda: opt.minimize(loss, [var]) strategy.experimental_run(run_fn) + @parameterized.named_parameters(*TESTCASES) + def testV1Optimizer(self, strategy_fn): + strategy = strategy_fn() + learning_rate = 2. 
+ with strategy.scope(): + # Test FixedLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale=2) + self.assertIsInstance(opt.loss_scale, ops.Tensor) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2) + self.assertEqual(opt.initial_scale, 2) + self.assertIsNone(opt.dynamic_growth_steps) + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, 2 / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + # Test DynamicLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 'dynamic') + self.assertEqual(opt.initial_scale, 2 ** 15) + self.assertEqual(opt.dynamic_growth_steps, 2000) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 15) + for s in strategy.experimental_local_results(opt.dynamic_counter): + self.assertEqual(self.evaluate(s), 0) + + loss = lambda: var * float('NaN') + run_fn = lambda: opt.minimize(loss, var_list=[var]) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertAllClose([5.], self.evaluate(var)) + self.assertEqual(self.evaluate(opt.loss_scale), 2 ** 14) + for s in strategy.experimental_local_results(opt.dynamic_counter): + self.assertEqual(self.evaluate(s), 0) + + @parameterized.named_parameters(*TESTCASES) + def testPassingV1LossScale(self, strategy_fn): + strategy = strategy_fn() + learning_rate = 2. + with strategy.scope(): + # Test FixedLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + loss_scale = tf_loss_scale_module.FixedLossScale(2.) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + self.assertIsInstance(opt.loss_scale, ops.Tensor) + self.evaluate(variables.global_variables_initializer()) + self.assertEqual(self.evaluate(opt.loss_scale), 2) + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, 2 / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. 
Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + # Test DynamicLossScale + var = variables.Variable([5.0]) + opt = gradient_descent.SGD(learning_rate) + loss_scale = tf_loss_scale_module.DynamicLossScale( + initial_loss_scale=4, increment_period=1, multiplier=2) + loss_scale._current_loss_scale.assign(2) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + self.assertEqual(opt.initial_scale, 4) + self.assertEqual(opt.dynamic_growth_steps, 1) + self.evaluate(variables.global_variables_initializer()) + # Current loss scale is not copied so loss scale is reinitialized to 4 + self.assertEqual(self.evaluate(opt.loss_scale), 4) + for s in strategy.experimental_local_results(opt.dynamic_counter): + self.assertEqual(self.evaluate(s), 0) + + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, 4 / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertAllClose([3.], self.evaluate(var)) + + def testPassingV1LossScaleErrors(self): + opt = gradient_descent.SGD() + loss_scale = tf_loss_scale_module.DynamicLossScale(multiplier=4) + with self.assertRaisesRegex( + ValueError, 'When passing a DynamicLossScale to "loss_scale", ' + 'DynamicLossScale.multiplier must be 2. Got: ' + 'DynamicLossScale'): + loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + + class MyLossScale(tf_loss_scale_module.LossScale): + + def __call__(self): + return 1. + + def update(self, grads): + return None, True + + def get_config(self): + return {} + + with self.assertRaisesRegex( + TypeError, 'Passing a LossScale that is not a FixedLossScale or a ' + 'DynamicLossScale is no longer supported. Got:'): + loss_scale_optimizer.LossScaleOptimizerV1(opt, MyLossScale()) + @parameterized.named_parameters({ 'testcase_name': 'SaveAndRestoreBase', 'strategy_fn': default_strategy_fn, @@ -529,10 +653,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([2.0]) opt = inner_opt = MySGD(1., momentum=1.) if save_with_ls: - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., - multiplier=2.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1., + dynamic_growth_steps=2.) run_fn = lambda: opt.minimize(lambda: var / replicas + 1., var_list=[var]) opt_op = strategy.experimental_run(run_fn) self.evaluate(variables.global_variables_initializer()) @@ -541,8 +663,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): # Assert values. self.assertEqual(self.evaluate(var), 1.) if save_with_ls: - self.assertEqual(self.evaluate(loss_scale()), 1.) - self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 1.) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) slot_var = opt.get_slot(var, 'momentum') self.assertEqual(self.evaluate(slot_var).item(), -1) self.assertEqual(self.evaluate(opt.iterations), 1) @@ -560,10 +682,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): var = variables.Variable([2.0]) opt = inner_opt = MySGD(1., momentum=1.) if restore_with_ls: - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=1., increment_period=2., - multiplier=2.) 
- opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=1., + dynamic_growth_steps=2.) # Restore new model. checkpoint = trackable_utils.Checkpoint(optimizer=opt, var=var) @@ -578,11 +698,11 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): if context.executing_eagerly(): self.assertEqual(self.evaluate(var), 1.) if save_with_ls and restore_with_ls: - self.assertEqual(self.evaluate(loss_scale()), 1.) - self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) + self.assertEqual(self.evaluate(opt.loss_scale), 1.) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) elif restore_with_ls: - self.assertEqual(self.evaluate(loss_scale()), 1.) - self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0) + self.assertEqual(self.evaluate(opt.loss_scale), 1.) + self.assertEqual(self.evaluate(opt.dynamic_counter), 0) self.assertEqual(self.evaluate(opt.iterations), 1) # Run the model again. @@ -611,30 +731,180 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(var), 1) self.assertEqual(self.evaluate(slot_var).item(), -1) - def testGetConfig(self): + @combinations.generate(combinations.combine( + get_config=['v1', 'v2', 'tf2_3'], from_config=['v1', 'v2'])) + def testGetConfigFixed(self, get_config, from_config): + # Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the + # LossScaleOptimizer from TF 2.3. Then restore the config into a + # LossScaleOptimizerV1 or LossScaleOptimizer opt = gradient_descent.SGD(2., momentum=0.5) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2., increment_period=3., - multiplier=4.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) - config = opt.get_config() - opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config) + if get_config == 'v1': + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, 2) + config = opt.get_config() + elif get_config == 'v2': + opt = loss_scale_optimizer.LossScaleOptimizer( + opt, dynamic=False, initial_scale=2) + config = opt.get_config() + else: + self.assertEqual(get_config, 'tf2_3') + config = { + 'optimizer': { + 'class_name': 'SGD', + 'config': { + 'learning_rate': 2.0, + 'momentum': 0.5, + 'decay': 0.0, + 'nesterov': False, + 'name': 'SGD', + } + }, + 'loss_scale': { + 'class_name': 'FixedLossScale', + 'config': {'loss_scale_value': 2.0} + }, + } + + if from_config == 'v1': + opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config) + else: + self.assertEqual(from_config, 'v2') + opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config) + # Force hyperparameters to be created opt.lr # pylint: disable=pointless-statement self.evaluate(variables.global_variables_initializer()) + # Test attributes on the optimizer self.assertEqual(self.evaluate(opt.lr), 2.) - self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) - self.assertEqual(self.evaluate(opt.loss_scale()), 2.) - self.assertEqual(opt.loss_scale.increment_period, 3.) - self.assertEqual(opt.loss_scale.multiplier, 4.) + self.assertEqual(self.evaluate(opt._optimizer.lr), 2.) + self.assertEqual(self.evaluate(opt.momentum), 0.5) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.initial_scale, 2.) 
+ self.assertIsNone(opt.dynamic_growth_steps) + self.assertIsNone(opt.dynamic_counter) + self.assertFalse(opt.dynamic) - def testSerializationWithBuiltInOptimizer(self): + # Ensure the optimizer can be used + var = variables.Variable([5.0]) + run_op = self._run_fn_with_grad_check( + distribution_strategy_context.get_strategy(), var, opt, 2)() + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertEqual(self.evaluate(var), [3.]) + + @combinations.generate(combinations.combine( + get_config=['v1', 'v2', 'tf2_3'], from_config=['v1', 'v2'])) + def testGetConfigDynamic(self, get_config, from_config): + # Get a config from LossScaleOptimizerV1, LossScaleOptimizer, or the + # LossScaleOptimizer from TF 2.3. Then restore the config into a + # LossScaleOptimizerV1 or LossScaleOptimizer opt = gradient_descent.SGD(2., momentum=0.5) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2., increment_period=3., - multiplier=4.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + if get_config == 'v1': + loss_scale = tf_loss_scale_module.DynamicLossScale( + initial_loss_scale=2, increment_period=3) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + config = opt.get_config() + elif get_config == 'v2': + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2, + dynamic_growth_steps=3) + config = opt.get_config() + else: + self.assertEqual(get_config, 'tf2_3') + config = { + 'optimizer': { + 'class_name': 'SGD', + 'config': { + 'learning_rate': 2.0, + 'momentum': 0.5, + 'decay': 0.0, + 'nesterov': False, + 'name': 'SGD', + } + }, + 'loss_scale': { + 'class_name': 'DynamicLossScale', + 'config': { + 'initial_loss_scale': 2.0, + 'increment_period': 3, + 'multiplier': 2.0, + } + }, + } + + if from_config == 'v1': + opt = loss_scale_optimizer.LossScaleOptimizerV1.from_config(config) + else: + self.assertEqual(from_config, 'v2') + opt = loss_scale_optimizer.LossScaleOptimizer.from_config(config) + + # Force hyperparameters to be created + opt.lr # pylint: disable=pointless-statement + self.evaluate(variables.global_variables_initializer()) + + # Test attributes on the optimizer + self.assertEqual(self.evaluate(opt.lr), 2.) + self.assertEqual(self.evaluate(opt._optimizer.lr), 2.) + self.assertEqual(self.evaluate(opt.momentum), 0.5) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.initial_scale, 2.) + self.assertEqual(opt.dynamic_growth_steps, 3.) + self.assertTrue(opt.dynamic) + + # Ensure the optimizer can be used + var = variables.Variable([5.0]) + run_op = self._run_fn_with_grad_check( + distribution_strategy_context.get_strategy(), var, opt, 2)() + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertEqual(self.evaluate(var), [3.]) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) + + def test_from_config_with_invalid_multiplier(self): + config = { + 'optimizer': { + 'class_name': 'SGD', + 'config': { + 'learning_rate': 2.0, + 'momentum': 0.5, + 'decay': 0.0, + 'nesterov': False, + 'name': 'SGD', + } + }, + 'loss_scale': { + 'class_name': 'DynamicLossScale', + 'config': { + 'initial_loss_scale': 2.0, + 'increment_period': 3, + 'multiplier': 4.0, + } + }, + } + + expected_error = ('Cannot deserialize LossScaleOptimizer with a ' + 'DynamicLossScale whose multiplier is not 2. 
Got ' + 'DynamicLossScale: DynamicLossScale\\(') + with self.assertRaisesRegex(ValueError, expected_error): + loss_scale_optimizer.LossScaleOptimizer.from_config(config) + with self.assertRaisesRegex(ValueError, expected_error): + loss_scale_optimizer.LossScaleOptimizerV1.from_config(config) + + @parameterized.named_parameters({ + 'testcase_name': 'V2', + 'use_v1': False, + }, { + 'testcase_name': 'V1', + 'use_v1': True, + },) + def testSerializationWithBuiltInOptimizer(self, use_v1): + opt = gradient_descent.SGD(2., momentum=0.5) + if use_v1: + loss_scale = tf_loss_scale_module.DynamicLossScale( + initial_loss_scale=2., increment_period=3.) + opt = loss_scale_optimizer.LossScaleOptimizerV1(opt, loss_scale) + else: + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2., + dynamic_growth_steps=3.) config = optimizers.serialize(opt) opt = optimizers.deserialize(config) # Force hyperparameters to be created @@ -643,9 +913,21 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(opt.lr), 2.) self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) - self.assertEqual(self.evaluate(opt.loss_scale()), 2.) - self.assertEqual(opt.loss_scale.increment_period, 3.) - self.assertEqual(opt.loss_scale.multiplier, 4.) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.dynamic_growth_steps, 3.) + self.assertTrue(opt.dynamic, 4.) + # Deserializing a LossScaleOptimizer always always results in a V2 + # LossScaleOptimizer, even if serialized with a LossScaleOptimizerV1. + self.assertAllEqual(type(opt), loss_scale_optimizer.LossScaleOptimizer) + + # Ensure the optimizer can be used + var = variables.Variable([5.0]) + run_op = self._run_fn_with_grad_check( + distribution_strategy_context.get_strategy(), var, opt, 2)() + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + self.assertEqual(self.evaluate(var), [3.]) + self.assertEqual(self.evaluate(opt.dynamic_counter), 1) def testSerializationWithCustomOptimizer(self): class MySGD(gradient_descent.SGD): @@ -655,10 +937,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.my_attribute = 123 opt = MySGD(2., momentum=0.5) - loss_scale = loss_scale_module.DynamicLossScale( - initial_loss_scale=2., increment_period=3., - multiplier=4.) - opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, initial_scale=2., + dynamic_growth_steps=3.) config = optimizers.serialize(opt) custom_objects = {'MySGD': MySGD} opt = optimizers.deserialize(config, custom_objects=custom_objects) @@ -668,9 +948,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): self.assertEqual(self.evaluate(opt.lr), 2.) self.assertEqual(self.evaluate(opt._optimizer.momentum), 0.5) - self.assertEqual(self.evaluate(opt.loss_scale()), 2.) - self.assertEqual(opt.loss_scale.increment_period, 3.) - self.assertEqual(opt.loss_scale.multiplier, 4.) + self.assertEqual(self.evaluate(opt.loss_scale), 2.) + self.assertEqual(opt.dynamic_growth_steps, 3.) self.assertEqual(opt._optimizer.my_attribute, 123) def testUnsupportedStrategy(self): @@ -680,8 +959,8 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): 'CentralStorageStrategy. Try using a different Strategy, e.g. a ' 'MirroredStrategy') with strategy.scope(), self.assertRaisesRegex(ValueError, expected_error): - loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), 1.) 
- opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD(), 1.) + loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD()) + opt = loss_scale_optimizer.LossScaleOptimizer(gradient_descent.SGD()) with strategy.scope(): var = variables.Variable(1.0) loss = lambda: var * 2.0 @@ -689,6 +968,24 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): with self.assertRaisesRegex(ValueError, expected_error): strategy.experimental_run(run_fn) + def testInvalidArgsWithFixedLossScale(self): + opt = gradient_descent.SGD() + with self.assertRaisesRegex( + ValueError, '"initial_scale" must be specified if "dynamic" is False'): + loss_scale_optimizer.LossScaleOptimizer(opt, dynamic=False) + with self.assertRaisesRegex( + ValueError, '"dynamic_growth_steps" must be None if "dynamic" is ' + 'False, but got: 2'): + loss_scale_optimizer.LossScaleOptimizer( + opt, dynamic=False, initial_scale=1, dynamic_growth_steps=2) + + def testDynamicMustBeBool(self): + opt = gradient_descent.SGD() + with self.assertRaisesRegex( + TypeError, '"dynamic" argument to LossScaleOptimizer.__init__ must be ' + "a bool, but got: 'dynamic'"): + loss_scale_optimizer.LossScaleOptimizer(opt, 'dynamic') + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py b/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py index d0fea573bd0..a70620e4d6d 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/mixed_precision_graph_rewrite_test.py @@ -65,13 +65,13 @@ class MixedPrecisionTest(keras_parameterized.TestCase): opt = gradient_descent_v2.SGD(1.0) opt = enable_mixed_precision_graph_rewrite(opt, 123.) self.assertIsInstance( - opt, loss_scale_optimizer_v2.LossScaleOptimizer) - self.assertEqual(self.evaluate(opt._loss_scale()), 123.) + opt, loss_scale_optimizer_v2.LossScaleOptimizerV1) + self.assertEqual(self.evaluate(opt.loss_scale), 123.) @combinations.generate(combinations.combine(mode=['graph', 'eager'])) def test_optimizer_errors(self): opt = gradient_descent_v2.SGD(1.0) - opt = loss_scale_optimizer_v2.LossScaleOptimizer(opt, 'dynamic') + opt = loss_scale_optimizer_v2.LossScaleOptimizerV1(opt, 'dynamic') with self.assertRaisesRegex( ValueError, '"opt" must not already be an instance of a ' 'LossScaleOptimizer.'): diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy.py b/tensorflow/python/keras/mixed_precision/experimental/policy.py index 1e1ce24963f..e42d92b8419 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy.py @@ -32,15 +32,15 @@ from tensorflow.python.training.experimental import mixed_precision_global_state from tensorflow.python.util.tf_export import keras_export -@keras_export('keras.mixed_precision.experimental.Policy', v1=[]) +@keras_export('keras.mixed_precision.Policy', v1=[]) class Policy(object): """A dtype policy for a Keras layer. - A dtype policy determines dtype-related aspects of a layer, such as its - computation and variable dtypes. Each layer has a policy. Policies can be - passed to the `dtype` argument of layer constructors, or a global policy can - be set with `tf.keras.mixed_precision.experimental.set_policy`. A layer will - default to the global policy if no policy is passed to it's constructor. 
+ A dtype policy determines a layer's computation and variable dtypes. Each + layer has a policy. Policies can be passed to the `dtype` argument of layer + constructors, or a global policy can be set with + `tf.keras.mixed_precision.experimental.set_policy`. A layer will default to + the global policy if no policy is passed to its constructor. For many models, each layer's policy will have the same compute dtype and variable dtype, which will typically be float32. In this case, we refer to the @@ -56,24 +56,17 @@ class Policy(object): https://www.tensorflow.org/guide/keras/mixed_precision) for more information on how to use mixed precision. - Certain policies also have a `tf.mixed_precision.experimental.LossScale` - instance, which is used by `tf.keras.Model`s to performance loss scaling. Loss - scaling is a technique used with mixed precision to avoid numerical underflow - in float16 gradients. Loss scaling is only done by Models in `Model.fit`, - `Model.train_on_batch`, and similar methods. Layers which are not Models - ignore the loss scale. - Policies are constructed by passing a string to the constructor, e.g. - `tf.keras.mixed_precision.experimental.Policy('float32')`. The string - determines the compute and variable dtypes. It can be one of the following: + `tf.keras.mixed_precision.Policy('float32')`. The string determines the + compute and variable dtypes. It can be one of the following: - * Any dtype name, such as 'float32' or 'float64'. Both the variable and - compute dtypes will be that dtype. No loss scaling is done by default. - * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or - bfloat16, while the variable dtype is float32. These policies are used for - mixed precision training. With 'mixed_float16', a dynamic loss scale is - used by default. 'mixed_bfloat16' does no loss scaling by default, as loss - scaling is unnecessary with bfloat16. + * Any dtype name, such as 'float32' or 'float64'. Both the variable and + compute dtypes will be that dtype. + * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or + bfloat16, while the variable dtype is float32. With 'mixed_float16', + `tf.keras.Model.compile` will wrap the optimizer with a + `tf.keras.mixed_precision.LossScaleOptimizer`. These policies are used for + mixed precision training. ### How to use mixed precision in a Keras model @@ -97,7 +90,7 @@ class Policy(object): Alternatively, the policy can be passed to individual layers instead of setting the global policy with `set_policy`: - >>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') + >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') >>> model = tf.keras.models.Sequential([ ... tf.keras.layers.Input((100,)), ... tf.keras.layers.Dense(10, dtype=policy), @@ -110,7 +103,7 @@ class Policy(object): `Model.fit`, `Model.train_on_batch`, and other training methods. If no such method is used (e.g., a custom training loop is used) and `'mixed_float16'` is used, the loss scale must be manually applied. See - `tf.keras.mixed_precision.experimental.LossScaleOptimizer` for details. For + `tf.keras.mixed_precision.LossScaleOptimizer` for details. For `'mixed_bfloat16'`, no loss scaling is done and loss scaling never needs to be manually applied. @@ -227,7 +220,7 @@ class Policy(object): ... def build(self, input_shape): ... self.x = self.add_weight('x') ... 
self.y = self.add_weight('y', experimental_autocast=False) - >>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') + >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') >>> layer = MyLayer(dtype=policy) >>> layer.build((2, 2)) >>> layer.x @@ -258,7 +251,7 @@ class Policy(object): ... def call(self, inputs): ... return tf.matmul(inputs, self.kernel) - >>> policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16') + >>> policy = tf.keras.mixed_precision.Policy('mixed_float16') >>> layer = MyDense(dtype=policy) >>> x = np.random.rand(10, 10) >>> y = layer(x) @@ -289,29 +282,27 @@ class Policy(object): layer would only work if the inputs were float32. """ - def __init__(self, name, loss_scale='auto'): + def __init__(self, name): """Constructs the policy. - The `name` argument determines the compute and variable dtype, the default - loss scale, and has no additional effect on the Policy. The compute and - variable dtypes can only be specified through `name`, and cannot be + The `name` argument determines the compute and variable dtype. The compute + and variable dtypes can only be specified through `name`, and cannot be specified directly. + `name` is also used by `tf.keras.Model.compile`. If `name` is + `"mixed_float16"`, `tf.keras.Model.compile` will automatically wrap the + optimizer with a LossScaleOptimizer if it is not already a + LossScaleOptimizer. + Args: name: A string. Can be one of the following values: * Any dtype name, such as 'float32' or 'float64'. Both the variable and compute dtypes will be that dtype. * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or bfloat16, while the variable dtype is float32. With 'mixed_float16', - a dynamic loss scale is used. These policies are used for mixed - precision training. - loss_scale: A `tf.mixed_precision.experimental.LossScale`, an int (which - uses a `FixedLossScale`), the string "dynamic" (which uses a - `DynamicLossScale`), or None (which uses no loss scale). Defaults to - `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then - use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only - `tf.keras.Model`s, not layers, use the loss scale, and it is only used - during `Model.fit`, `Model.train_on_batch`, and other similar methods. + `tf.keras.Model.compile` will wrap the optimizer with a + `tf.keras.mixed_precision.LossScaleOptimizer. These policies are used + for mixed precision training. """ if isinstance(name, dtypes.DType): raise TypeError("'name' must be a string, not a DType. " @@ -320,19 +311,6 @@ class Policy(object): raise TypeError("'name' must be a string, but got: %s" % (name,)) self._name = name self._compute_dtype, self._variable_dtype = self._parse_name(name) - - if loss_scale == 'auto': - loss_scale = 'dynamic' if name == 'mixed_float16' else None - self._using_default_loss_scale = True - else: - self._using_default_loss_scale = False - if loss_scale and self._compute_dtype not in (None, 'float16'): - tf_logging.warn('Creating a Policy with a loss scale is only useful for ' - 'float16 policies. You passed loss_scale=%r for policy ' - '%s. Consider not passing any loss_scale instead.' 
% - (loss_scale, name)) - self._loss_scale = keras_loss_scale_module.get(loss_scale) - if name in ('mixed_float16', 'mixed_bloat16'): device_compatibility_check.log_device_compatibility_check(name) @@ -438,22 +416,99 @@ class Policy(object): """ return self.variable_dtype != self.compute_dtype - @property - def loss_scale(self): - """Returns the loss scale of this Policy. - - Returns: - A `tf.mixed_precision.experimental.LossScale`, or None. - """ - return self._loss_scale - @property def name(self): """Returns the name of this policy.""" return self._name def __repr__(self): - return '' % (self._name, self.loss_scale) + return '' % self._name + + def get_config(self): + return {'name': self.name} + + @classmethod + def from_config(cls, config, custom_objects=None): + del custom_objects + if 'loss_scale' in config: + config = config.copy() + # Policy.get_config in TensorFlow 2.3 and below had a loss_scale. We + # silently drop it. + del config['loss_scale'] + return cls(**config) + + +@keras_export('keras.mixed_precision.experimental.Policy', v1=[]) +class PolicyV1(Policy): + """A deprecated dtype policy for a Keras layer. + + Warning: This class is now deprecated and will be removed soon. Please use the + non-experimental class `tf.keras.mixed_precision.Policy` instead. + + The difference between this class and the non-experimental class is that this + class has a `loss_scale` field and the non-experimental class does not. The + loss scale is only used by `tf.keras.Model.compile`, which automatically wraps + the optimizer with a `LossScaleOptimizer` if the optimizer is not already a + `LossScaleOptimizer`. For the non-experimental Policy class, `Model.compile` + instead wraps the optimizer with a `LossScaleOptimizer` if `Policy.name` is + "mixed_float16". + + When deserializing objects with an experimental policy using functions like + `tf.keras.utils.deserialize_keras_object`, the policy will be deserialized as + the non-experimental `tf.keras.mixed_precision.Policy`, and the loss scale + will silently be dropped. This is so that SavedModels that are generated + with an experimental policy can be restored after the experimental policy is + removed. + """ + + def __init__(self, name, loss_scale='auto'): + """Constructs the policy. + + The `name` argument determines the compute and variable dtype, the default + loss scale, and has no additional effect on the Policy. The compute and + variable dtypes can only be specified through `name`, and cannot be + specified directly. + + Args: + name: A string. Can be one of the following values: + * Any dtype name, such as 'float32' or 'float64'. Both the variable and + compute dtypes will be that dtype. + * 'mixed_float16' or 'mixed_bfloat16': The compute dtype is float16 or + bfloat16, while the variable dtype is float32. With 'mixed_float16', + a dynamic loss scale is used. These policies are used for mixed + precision training. + loss_scale: A `tf.compat.v1.mixed_precision.LossScale`, an int (which + uses a `FixedLossScale`), the string "dynamic" (which uses a + `DynamicLossScale`), or None (which uses no loss scale). Defaults to + `"auto"`. In the `"auto"` case: 1) if `name` is `"mixed_float16"`, then + use `loss_scale="dynamic"`. 2) otherwise, do not use a loss scale. Only + `tf.keras.Model`s, not layers, use the loss scale, and it is only used + during `Model.fit`, `Model.train_on_batch`, and other similar methods. 
+ """ + super(PolicyV1, self).__init__(name) + if loss_scale == 'auto': + loss_scale = 'dynamic' if name == 'mixed_float16' else None + self._using_default_loss_scale = True + else: + self._using_default_loss_scale = False + if loss_scale and self._compute_dtype not in (None, 'float16'): + tf_logging.warn('Creating a Policy with a loss scale is only useful for ' + 'float16 policies. You passed loss_scale=%r for policy ' + '%s. Consider not passing any loss_scale instead.' % + (loss_scale, name)) + self._loss_scale = keras_loss_scale_module.get(loss_scale) + + @property + def loss_scale(self): + """Returns the loss scale of this Policy. + + Returns: + A `tf.compat.v1.mixed_precision.experimental.LossScale`, or None. + """ + return self._loss_scale + + def __repr__(self): + return '' % (self._name, self.loss_scale) def get_config(self): config = { @@ -497,8 +552,7 @@ def global_policy(): first time the layer is called. This behavior matches the behavior that existed in TensorFlow 1. - See `tf.keras.mixed_precision.experimental.Policy` for more information on - policies. + See `tf.keras.mixed_precision.Policy` for more information on policies. Returns: The global Policy. @@ -540,7 +594,7 @@ def set_policy(policy): `'int32'` and `'complex64'` cannot be set as the global policy because most layers do not support such policies. - See `tf.keras.mixed_precision.experimental.Policy` for more information. + See `tf.keras.mixed_precision.Policy` for more information. Args: policy: A Policy, or a string that will be converted to a Policy.. @@ -597,8 +651,8 @@ def _policy_equivalent_to_dtype(policy): """Returns True if the Policy is equivalent to a single dtype. A policy is equivalent to a single dtype if the policy's compute and variable - dtypes are the same and the policy does not cause the layer/model to have - additional behavior, such as loss scaling. + dtypes are the same and the policy's type is Policy and not a subclass of + Policy (such as PolicyV1). The "_infer" policy is considered equivalent to a single dtype. 
@@ -629,7 +683,7 @@ def deserialize(config, custom_objects=None): return Policy(config) if config is None: return Policy('_infer') - module_objects = {'Policy': Policy} + module_objects = {'Policy': Policy, 'PolicyV1': Policy} return generic_utils.deserialize_keras_object( config, module_objects=module_objects, diff --git a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py index 9ebcc3558e6..07096d64e6e 100644 --- a/tensorflow/python/keras/mixed_precision/experimental/policy_test.py +++ b/tensorflow/python/keras/mixed_precision/experimental/policy_test.py @@ -60,14 +60,21 @@ class PolicyTest(test.TestCase, parameterized.TestCase): @testing_utils.enable_v2_dtype_behavior def test_repr(self): - for policy in ('float32', 'int8', 'mixed_bfloat16', '_infer'): + # Test Policy repr + for policy in ('float32', 'int8', 'mixed_float16', 'mixed_bfloat16', + '_infer'): self.assertEqual(repr(mp_policy.Policy(policy)), - '' % policy) - self.assertEqual(repr(mp_policy.Policy('float16', loss_scale=2)), - '') + '' % policy) + + # Test PolicyV1 repr + for policy in ('float32', 'int8', 'mixed_bfloat16', '_infer'): + self.assertEqual(repr(mp_policy.PolicyV1(policy)), + '' % policy) + self.assertEqual(repr(mp_policy.PolicyV1('float16', loss_scale=2)), + '') self.assertStartsWith( - repr(mp_policy.Policy('mixed_float16')), - '" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "clipnorm" + mtype: "" + } + member { + name: "clipvalue" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } + member { + name: "global_clipnorm" + mtype: "" + } + member { + name: "initial_scale" + mtype: "" + } + member { + name: "iterations" + mtype: "" + } + member { + name: "loss_scale" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'optimizer\', \'dynamic\', \'initial_scale\', \'dynamic_growth_steps\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], " + } + member_method { + name: "add_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'name\', \'experimental_aggregate_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_scaled_loss" + argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, 
defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_unscaled_gradients" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\', \'tape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt index 3c016d331de..6c3a15cf128 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt @@ -1,5 +1,6 @@ path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer" tf_class { + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -13,10 +14,26 @@ tf_class { name: "clipvalue" mtype: "" } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } member { name: "global_clipnorm" mtype: "" } + member { + name: "initial_scale" + mtype: "" + } member { name: "iterations" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt index e8648afb5f7..502aef38b75 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.mixed_precision.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.mixed_precision" tf_module { + member { + name: "LossScaleOptimizer" + mtype: "" + } member { name: "experimental" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-dynamic-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-dynamic-loss-scale.pbtxt new file mode 100644 index 00000000000..c744ae30e11 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-dynamic-loss-scale.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.mixed_precision.DynamicLossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "increment_period" + mtype: "" + } + member { + name: "initial_loss_scale" + mtype: "" + } + member { + name: "multiplier" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_loss_scale\', \'increment_period\', \'multiplier\'], varargs=None, keywords=None, defaults=[\'32768\', \'2000\', \'2.0\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], 
varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-fixed-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-fixed-loss-scale.pbtxt new file mode 100644 index 00000000000..7393181eb85 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-fixed-loss-scale.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.mixed_precision.FixedLossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loss_scale_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-loss-scale.pbtxt new file mode 100644 index 00000000000..044b49a9999 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.-loss-scale.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.mixed_precision.LossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt index 475c4a2ccde..e4db638806a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.mixed_precision.pbtxt @@ -1,5 +1,17 @@ path: "tensorflow.mixed_precision" tf_module { + member { + name: "DynamicLossScale" + mtype: "" + } + member { + name: "FixedLossScale" + mtype: "" + } + member { + name: "LossScale" + mtype: "" + } member { name: "experimental" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt new file mode 100644 index 00000000000..67929db2aab --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-loss-scale-optimizer.pbtxt @@ -0,0 +1,112 @@ +path: "tensorflow.keras.mixed_precision.LossScaleOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "clipnorm" + mtype: "" + } + member { + name: "clipvalue" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } + member { + name: "global_clipnorm" + mtype: "" + } + member { + name: "initial_scale" + 
mtype: "" + } + member { + name: "iterations" + mtype: "" + } + member { + name: "loss_scale" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'optimizer\', \'dynamic\', \'initial_scale\', \'dynamic_growth_steps\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], " + } + member_method { + name: "add_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\', \'initializer\'], varargs=None, keywords=None, defaults=[\'zeros\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'trainable\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'zeros\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'name\', \'experimental_aggregate_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], " + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_gradients" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_scaled_loss" + argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_unscaled_gradients" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates" + argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'grad_loss\', \'name\', \'tape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-policy.pbtxt new file mode 100644 index 00000000000..927b78518d8 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.-policy.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.keras.mixed_precision.Policy" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "compute_dtype" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "should_cast_variables" + mtype: "" + } + member { + name: "variable_dtype" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\', 
\'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt index 3c016d331de..6c3a15cf128 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-loss-scale-optimizer.pbtxt @@ -1,5 +1,6 @@ path: "tensorflow.keras.mixed_precision.experimental.LossScaleOptimizer" tf_class { + is_instance: "" is_instance: "" is_instance: "" is_instance: "" @@ -13,10 +14,26 @@ tf_class { name: "clipvalue" mtype: "" } + member { + name: "dynamic" + mtype: "" + } + member { + name: "dynamic_counter" + mtype: "" + } + member { + name: "dynamic_growth_steps" + mtype: "" + } member { name: "global_clipnorm" mtype: "" } + member { + name: "initial_scale" + mtype: "" + } member { name: "iterations" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt index e3435a32bef..5b6529c98f8 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.experimental.-policy.pbtxt @@ -1,5 +1,6 @@ path: "tensorflow.keras.mixed_precision.experimental.Policy" tf_class { + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt index e8648afb5f7..651cca8e210 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.mixed_precision.pbtxt @@ -1,5 +1,13 @@ path: "tensorflow.keras.mixed_precision" tf_module { + member { + name: "LossScaleOptimizer" + mtype: "" + } + member { + name: "Policy" + mtype: "" + } member { name: "experimental" mtype: "" diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index e2aa438cb1a..320823097ec 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -770,6 +770,12 @@ renames = { 'tf.linalg.matrix_transpose', 'tf.matrix_triangular_solve': 'tf.linalg.triangular_solve', + 'tf.mixed_precision.DynamicLossScale': + 'tf.compat.v1.mixed_precision.DynamicLossScale', + 'tf.mixed_precision.FixedLossScale': + 'tf.compat.v1.mixed_precision.FixedLossScale', + 'tf.mixed_precision.LossScale': + 'tf.compat.v1.mixed_precision.LossScale', 'tf.metrics.accuracy': 'tf.compat.v1.metrics.accuracy', 'tf.metrics.auc':