Try to break the mixed precision (MP) dependency of training on keras.
PiperOrigin-RevId: 303034978
Change-Id: I39cdc250d81f7032ad7c018ff5469ca49910dd64
parent b3212dd802
commit 97099b3610
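The change replaces `_wrap_optimizer`'s hard-coded knowledge of the Keras `LossScaleOptimizer` with a registry that the Keras module populates when it is imported, so the low-level training code no longer has to reach into keras. A minimal standalone sketch of this late-registration pattern; the module and class names here are illustrative, not TensorFlow's:

# Conceptually two modules. `low_level` must not import `high_level`,
# but `high_level` may import `low_level` and register itself.

# --- low_level (stands in for training/experimental/mixed_precision.py) ---
_REGISTERED_WRAPPERS = {}  # maps optimizer class -> wrapper class

def register_wrapper(optimizer_cls, wrapper_cls):
  _REGISTERED_WRAPPERS[optimizer_cls] = wrapper_cls

def wrap(opt, loss_scale):
  # Dispatch on isinstance, in registration order.
  for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPERS.items():
    if isinstance(opt, optimizer_cls):
      return wrapper_cls(opt, loss_scale)
  raise ValueError('no wrapper registered for %r' % (opt,))

# --- high_level (stands in for the Keras loss_scale_optimizer module) ---
class FancyOptimizer(object):
  pass

class FancyWrapper(object):
  def __init__(self, opt, loss_scale):
    self.opt, self.loss_scale = opt, loss_scale

# Runs as a side effect of importing the high-level module.
register_wrapper(FancyOptimizer, FancyWrapper)

assert isinstance(wrap(FancyOptimizer(), 'dynamic'), FancyWrapper)

Registration at import time means the registry only knows about wrappers whose modules have actually been imported, which is exactly the dependency direction the commit wants: the training-level code never imports keras.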
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -30,6 +30,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
+from tensorflow.python.training.experimental import mixed_precision
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -387,6 +388,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   # optimizer being used.
 
 
+# pylint: disable=protected-access
+mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
+                                                LossScaleOptimizer)
+
+
 def _multiply_gradient(gradient, scale):
   """Multiply a (possibly sparse) gradient by the given scale factor."""
   scale = math_ops.cast(scale, gradient.dtype)
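The registration above runs as a side effect of importing the Keras module, which is what lets `_wrap_optimizer` in the training package (next hunk) dispatch on Keras's `OptimizerV2` without importing keras itself. A hedged illustration of the intended registry state once both modules are imported; this pokes at private TensorFlow internals as they stand at this revision, so treat it as a sketch rather than supported usage:

from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import mixed_precision
# Importing the Keras module triggers the _register_wrapper_optimizer_cls call.
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
from tensorflow.python.keras.optimizer_v2 import optimizer_v2

# Sketch only: _REGISTERED_WRAPPER_OPTIMIZER_CLS is a private module attribute.
registry = mixed_precision._REGISTERED_WRAPPER_OPTIMIZER_CLS  # pylint: disable=protected-access
# The V1 entry is present from module definition; the Keras entry was added above.
assert registry[optimizer.Optimizer].__name__ == 'MixedPrecisionLossScaleOptimizer'
assert registry[optimizer_v2.OptimizerV2] is loss_scale_optimizer.LossScaleOptimizer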
--- a/tensorflow/python/training/experimental/mixed_precision.py
+++ b/tensorflow/python/training/experimental/mixed_precision.py
@@ -23,49 +23,35 @@ from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import optimizer
 from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
 from tensorflow.python.training.experimental import mixed_precision_global_state
-from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
+# A mapping between optimizers and the corresponding wrapper class that will be
+# used for mixed precision.
+_REGISTERED_WRAPPER_OPTIMIZER_CLS = {
+    optimizer.Optimizer:
+        loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer,
+}
+
+
+def _register_wrapper_optimizer_cls(optimizer_cls, wrapper_optimizer_cls):
+  _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = wrapper_optimizer_cls
+
+
 def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
   """Wraps an optimizer with a LossScaleOptimizer."""
 
-  if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
-    raise ValueError('"opt" must not already be an instance of a '
-                     'MixedPrecisionLossScaleOptimizer. '
-                     '`enable_mixed_precision_graph_rewrite` will '
-                     'automatically wrap the optimizer with a '
-                     'MixedPrecisionLossScaleOptimizer.')
-  # To avoid a circular dependency, we cannot depend on tf.keras. Because
-  # LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check
-  # the class name.
-  if opt.__class__.__name__ == 'LossScaleOptimizer':
-    raise ValueError('"opt" must not already be an instance of a '
-                     'LossScaleOptimizer. '
-                     '`enable_mixed_precision_graph_rewrite` will '
-                     'automatically wrap the optimizer with a '
-                     'LossScaleOptimizer.')
+  for wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
+    if isinstance(opt, wrapper_optimizer):
+      raise ValueError('"opt" must not already be an instance of a {cls}. '
+                       '`enable_mixed_precision_graph_rewrite` will '
+                       'automatically wrap the optimizer with a '
+                       '{cls}.'
+                       .format(cls=wrapper_optimizer.__name__))
 
-  if isinstance(opt, optimizer.Optimizer):
-    # For convenience, we allow the V2 version of this function to wrap the V1
-    # optimizer, even though we do not document this.
-    return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(opt,
-                                                                    loss_scale)
-
-  # Because we cannot depend on tf.keras, we see if `opt` is an instance of the
-  # Keras OptimizerV2 class by checking the subclass names.
-  base_classes = tf_inspect.getmro(opt.__class__)
-  base_class_names = [cls.__name__ for cls in base_classes]
-  is_loss_scale_optimizer_v2 = 'OptimizerV2' in base_class_names
-
-  if is_loss_scale_optimizer_v2:
-    # Because we cannot depend on tf.keras, we cannot unconditionally do this
-    # import. But since `opt` is a Keras OptimizerV2, we know keras is
-    # importable, so it is safe to do this import. (Technically, it's possible
-    # to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this
-    # is not done in practice).
-    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
-    return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)
+  for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPER_OPTIMIZER_CLS.items():
+    if isinstance(opt, optimizer_cls):
+      return wrapper_cls(opt, loss_scale)
 
   if use_v1_behavior:
     raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '
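For context, the error messages above name `enable_mixed_precision_graph_rewrite`, the public entry point that funnels into `_wrap_optimizer`. A hedged usage sketch, assuming a TensorFlow build from around this revision (the API was experimental and was later deprecated and removed):

import tensorflow as tf

opt = tf.keras.optimizers.SGD(0.01)
# Dispatches through the registry: a Keras OptimizerV2 gets wrapped in the
# Keras LossScaleOptimizer registered above.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

# Wrapping an already-wrapped optimizer trips the values() check in
# _wrap_optimizer and raises ValueError.
try:
  tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
except ValueError as e:
  print(e)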