diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index 5ce9d90e8a1..bb0dc226e87 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -30,6 +30,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
+from tensorflow.python.training.experimental import mixed_precision
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -387,6 +388,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   # optimizer being used.
 
 
+# pylint: disable=protected-access
+mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
+                                                LossScaleOptimizer)
+
+
 def _multiply_gradient(gradient, scale):
   """Multiply a (possibly sparse) gradient by the given scale factor."""
   scale = math_ops.cast(scale, gradient.dtype)
diff --git a/tensorflow/python/training/experimental/mixed_precision.py b/tensorflow/python/training/experimental/mixed_precision.py
index 8e1bf42ddb8..38377dd0600 100644
--- a/tensorflow/python/training/experimental/mixed_precision.py
+++ b/tensorflow/python/training/experimental/mixed_precision.py
@@ -23,49 +23,35 @@ from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import optimizer
 from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
 from tensorflow.python.training.experimental import mixed_precision_global_state
-from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
+# A mapping between optimizers and the corresponding wrapper class that will be
+# used for mixed precision.
+_REGISTERED_WRAPPER_OPTIMIZER_CLS = {
+    optimizer.Optimizer:
+        loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer,
+}
+
+
+def _register_wrapper_optimizer_cls(optimizer_cls, wrapper_optimizer_cls):
+  _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = wrapper_optimizer_cls
+
+
 def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
   """Wraps an optimizer with a LossScaleOptimizer."""
-  if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
-    raise ValueError('"opt" must not already be an instance of a '
-                     'MixedPrecisionLossScaleOptimizer. '
-                     '`enable_mixed_precision_graph_rewrite` will '
-                     'automatically wrap the optimizer with a '
-                     'MixedPrecisionLossScaleOptimizer.')
-  # To avoid a circular dependency, we cannot depend on tf.keras. Because
-  # LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check
-  # the class name.
-  if opt.__class__.__name__ == 'LossScaleOptimizer':
-    raise ValueError('"opt" must not already be an instance of a '
-                     'LossScaleOptimizer. '
-                     '`enable_mixed_precision_graph_rewrite` will '
-                     'automatically wrap the optimizer with a '
-                     'LossScaleOptimizer.')
+  for wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
+    if isinstance(opt, wrapper_optimizer):
+      raise ValueError('"opt" must not already be an instance of a {cls}. '
+                       '`enable_mixed_precision_graph_rewrite` will '
+                       'automatically wrap the optimizer with a '
+                       '{cls}.'
+                       .format(cls=wrapper_optimizer.__name__))
 
-  if isinstance(opt, optimizer.Optimizer):
-    # For convenience, we allow the V2 version of this function to wrap the V1
-    # optimizer, even though we do not document this.
-    return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(opt,
-                                                                    loss_scale)
-
-  # Because we cannot depend on tf.keras, we see if `opt` is an instance of the
-  # Keras OptimizerV2 class by checking the subclass names.
-  base_classes = tf_inspect.getmro(opt.__class__)
-  base_class_names = [cls.__name__ for cls in base_classes]
-  is_loss_scale_optimizer_v2 = 'OptimizerV2' in base_class_names
-
-  if is_loss_scale_optimizer_v2:
-    # Because we cannot depend on tf.keras, we cannot unconditionally do this
-    # import. But since `opt` is a Keras OptimizerV2, we know keras is
-    # importable, so it is safe to do this import. (Technically, it's possible
-    # to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this
-    # is not done in practice).
-    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
-    return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)
+  for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPER_OPTIMIZER_CLS.items():
+    if isinstance(opt, optimizer_cls):
+      return wrapper_cls(opt, loss_scale)
 
   if use_v1_behavior:
     raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
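For review context: the change replaces the fragile `__name__`-string checks (which existed only because `tensorflow.python.training` cannot import tf.keras) with a registry that the Keras module populates at import time, so plain `isinstance` dispatch becomes possible. Below is a minimal standalone sketch of that registration pattern; the `FakeOptimizer` and `FakeLossScaleWrapper` classes are hypothetical stand-ins for illustration, not TensorFlow APIs.

```python
# Sketch of the register-then-dispatch pattern introduced by this diff.
# FakeOptimizer / FakeLossScaleWrapper are illustrative, not real TF classes.
_REGISTERED_WRAPPER_OPTIMIZER_CLS = {}


def _register_wrapper_optimizer_cls(optimizer_cls, wrapper_optimizer_cls):
  """Associates an optimizer base class with its loss-scale wrapper class."""
  _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = wrapper_optimizer_cls


def _wrap_optimizer(opt):
  """Wraps `opt` with the wrapper registered for its base class."""
  # Refuse to double-wrap: `opt` must not already be a registered wrapper.
  for wrapper_cls in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
    if isinstance(opt, wrapper_cls):
      raise ValueError('"opt" is already a {}.'.format(wrapper_cls.__name__))
  # Dispatch on the registered base class.
  for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPER_OPTIMIZER_CLS.items():
    if isinstance(opt, optimizer_cls):
      return wrapper_cls(opt)
  raise ValueError('No wrapper registered for {}.'.format(type(opt).__name__))


class FakeOptimizer(object):
  """Stand-in for an optimizer base class such as OptimizerV2."""


class FakeLossScaleWrapper(object):
  """Stand-in for a loss-scale wrapper such as LossScaleOptimizer."""

  def __init__(self, inner):
    self.inner = inner


# In the diff, loss_scale_optimizer.py performs this call at import time,
# which is how the circular dependency on tf.keras is avoided.
_register_wrapper_optimizer_cls(FakeOptimizer, FakeLossScaleWrapper)

wrapped = _wrap_optimizer(FakeOptimizer())
assert isinstance(wrapped, FakeLossScaleWrapper)
```

With this shape, `mixed_precision.py` never imports Keras; instead, importing `loss_scale_optimizer.py` registers `LossScaleOptimizer` for `OptimizerV2`, and `_wrap_optimizer` can rely on `isinstance` against registered classes rather than comparing class names.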