Try to break the dependency of MP in training and keras.

PiperOrigin-RevId: 303034978
Change-Id: I39cdc250d81f7032ad7c018ff5469ca49910dd64
This commit is contained in:
Scott Zhu 2020-03-25 21:17:05 -07:00 committed by TensorFlower Gardener
parent b3212dd802
commit 97099b3610
2 changed files with 28 additions and 36 deletions

View File

@ -30,6 +30,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import mixed_precision
from tensorflow.python.util.tf_export import keras_export
@ -387,6 +388,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
# optimizer being used.
# pylint: disable=protected-access
mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
LossScaleOptimizer)
def _multiply_gradient(gradient, scale):
"""Multiply a (possibly sparse) gradient by the given scale factor."""
scale = math_ops.cast(scale, gradient.dtype)

View File

@ -23,49 +23,35 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
from tensorflow.python.training.experimental import mixed_precision_global_state
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
# A mapping between optimizers and the corresponding wrapper class that will be
# used for mixed precision.
_REGISTERED_WRAPPER_OPTIMIZER_CLS = {
optimizer.Optimizer:
loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer,
}
def _register_wrapper_optimizer_cls(optimizer_cls, wrapper_optimizer_cls):
_REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = wrapper_optimizer_cls
def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
"""Wraps an optimizer with a LossScaleOptimizer."""
if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
raise ValueError('"opt" must not already be an instance of a '
'MixedPrecisionLossScaleOptimizer. '
'`enable_mixed_precision_graph_rewrite` will '
'automatically wrap the optimizer with a '
'MixedPrecisionLossScaleOptimizer.')
# To avoid a circular dependency, we cannot depend on tf.keras. Because
# LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check
# the class name.
if opt.__class__.__name__ == 'LossScaleOptimizer':
raise ValueError('"opt" must not already be an instance of a '
'LossScaleOptimizer. '
'`enable_mixed_precision_graph_rewrite` will '
'automatically wrap the optimizer with a '
'LossScaleOptimizer.')
for wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
if isinstance(opt, wrapper_optimizer):
raise ValueError('"opt" must not already be an instance of a {cls}. '
'`enable_mixed_precision_graph_rewrite` will '
'automatically wrap the optimizer with a '
'{cls}.'
.format(cls=wrapper_optimizer.__name__))
if isinstance(opt, optimizer.Optimizer):
# For convenience, we allow the V2 version of this function to wrap the V1
# optimizer, even though we do not document this.
return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(opt,
loss_scale)
# Because we cannot depend on tf.keras, we see if `opt` is an instance of the
# Keras OptimizerV2 class by checking the subclass names.
base_classes = tf_inspect.getmro(opt.__class__)
base_class_names = [cls.__name__ for cls in base_classes]
is_loss_scale_optimizer_v2 = 'OptimizerV2' in base_class_names
if is_loss_scale_optimizer_v2:
# Because we cannot depend on tf.keras, we cannot unconditionally do this
# import. But since `opt` is a Keras OptimizerV2, we know keras is
# importable, so it is safe to do this import. (Technically, it's possible
# to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this
# is not done in practice).
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2 # pylint: disable=g-import-not-at-top
return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)
for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPER_OPTIMIZER_CLS.items():
if isinstance(opt, optimizer_cls):
return wrapper_cls(opt, loss_scale)
if use_v1_behavior:
raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '