Try to break the dependency between the mixed precision (MP) code in training and Keras.

PiperOrigin-RevId: 303034978
Change-Id: I39cdc250d81f7032ad7c018ff5469ca49910dd64
Author: Scott Zhu (2020-03-25 21:17:05 -07:00), committed by TensorFlower Gardener
parent b3212dd802
commit 97099b3610
2 changed files with 28 additions and 36 deletions
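
The commit replaces a class-name string check and a lazy Keras import with a registration pattern: the low-level training module owns a registry mapping optimizer base classes to their loss-scale wrapper classes, and the Keras module populates that registry when it is imported. A minimal, self-contained sketch of the pattern follows; all names in it are illustrative stand-ins, not TensorFlow's actual modules or classes.

# Sketch of the registry pattern this commit introduces. Names here are
# illustrative stand-ins, not TensorFlow's real classes or modules.

# --- low-level module (never imports the high-level one) ---
_REGISTERED_WRAPPERS = {}

def register_wrapper(optimizer_cls, wrapper_cls):
  """Lets higher-level modules make their wrapper class discoverable."""
  _REGISTERED_WRAPPERS[optimizer_cls] = wrapper_cls

def wrap(opt, loss_scale):
  """Wraps `opt` with whatever wrapper is registered for its class."""
  for wrapper_cls in _REGISTERED_WRAPPERS.values():
    if isinstance(opt, wrapper_cls):
      raise ValueError('already wrapped in a %s' % wrapper_cls.__name__)
  for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPERS.items():
    if isinstance(opt, optimizer_cls):
      return wrapper_cls(opt, loss_scale)
  raise ValueError('no wrapper registered for %r' % type(opt))

# --- high-level module (imports the low-level one; the reverse never happens) ---
class HighLevelOptimizer(object):  # stand-in for optimizer_v2.OptimizerV2
  pass

class HighLevelWrapper(object):  # stand-in for keras' LossScaleOptimizer
  def __init__(self, opt, loss_scale):
    self.opt, self.loss_scale = opt, loss_scale

register_wrapper(HighLevelOptimizer, HighLevelWrapper)  # runs at import time

assert isinstance(wrap(HighLevelOptimizer(), 'dynamic'), HighLevelWrapper)

After the change, the import edge points only from the high-level module to the low-level one, which is what the commit title means by breaking the dependency.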

tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py

@@ -30,6 +30,7 @@ from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.training.experimental import loss_scale as loss_scale_module
+from tensorflow.python.training.experimental import mixed_precision
 from tensorflow.python.util.tf_export import keras_export
@@ -387,6 +388,11 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   # optimizer being used.


+# pylint: disable=protected-access
+mixed_precision._register_wrapper_optimizer_cls(optimizer_v2.OptimizerV2,
+                                                LossScaleOptimizer)
+
+
 def _multiply_gradient(gradient, scale):
   """Multiply a (possibly sparse) gradient by the given scale factor."""
   scale = math_ops.cast(scale, gradient.dtype)

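With the registration above in place, wrapping a Keras optimizer no longer requires `tensorflow.python.training` to import Keras: importing the Keras module is what inserts `LossScaleOptimizer` into the registry, and the training-side `_wrap_optimizer` (shown in the next file) simply looks it up. A hedged usage sketch, assuming the public TF 2.x API of this era:

import tensorflow as tf

opt = tf.keras.optimizers.SGD(learning_rate=0.01)
# Importing tf.keras already ran the registration call added above, so the
# graph-rewrite helper can resolve LossScaleOptimizer through the registry
# rather than through a hard-coded import of Keras.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
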
tensorflow/python/training/experimental/mixed_precision.py

@@ -23,49 +23,35 @@ from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import optimizer
 from tensorflow.python.training.experimental import loss_scale_optimizer as loss_scale_optimizer_v1
 from tensorflow.python.training.experimental import mixed_precision_global_state
-from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export


+# A mapping between optimizers and the corresponding wrapper class that will be
+# used for mixed precision.
+_REGISTERED_WRAPPER_OPTIMIZER_CLS = {
+    optimizer.Optimizer:
+        loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer,
+}
+
+
+def _register_wrapper_optimizer_cls(optimizer_cls, wrapper_optimizer_cls):
+  _REGISTERED_WRAPPER_OPTIMIZER_CLS[optimizer_cls] = wrapper_optimizer_cls
+
+
 def _wrap_optimizer(opt, loss_scale, use_v1_behavior):
   """Wraps an optimizer with a LossScaleOptimizer."""
-  if isinstance(opt, loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer):
-    raise ValueError('"opt" must not already be an instance of a '
-                     'MixedPrecisionLossScaleOptimizer. '
-                     '`enable_mixed_precision_graph_rewrite` will '
-                     'automatically wrap the optimizer with a '
-                     'MixedPrecisionLossScaleOptimizer.')
-  # To avoid a circular dependency, we cannot depend on tf.keras. Because
-  # LossScaleOptimizer is in Keras, we cannot use isinstance, so instead check
-  # the class name.
-  if opt.__class__.__name__ == 'LossScaleOptimizer':
-    raise ValueError('"opt" must not already be an instance of a '
-                     'LossScaleOptimizer. '
-                     '`enable_mixed_precision_graph_rewrite` will '
-                     'automatically wrap the optimizer with a '
-                     'LossScaleOptimizer.')
+  for wrapper_optimizer in _REGISTERED_WRAPPER_OPTIMIZER_CLS.values():
+    if isinstance(opt, wrapper_optimizer):
+      raise ValueError('"opt" must not already be an instance of a {cls}. '
+                       '`enable_mixed_precision_graph_rewrite` will '
+                       'automatically wrap the optimizer with a '
+                       '{cls}.'
+                       .format(cls=wrapper_optimizer.__name__))

-  if isinstance(opt, optimizer.Optimizer):
-    # For convenience, we allow the V2 version of this function to wrap the V1
-    # optimizer, even though we do not document this.
-    return loss_scale_optimizer_v1.MixedPrecisionLossScaleOptimizer(opt,
-                                                                    loss_scale)
-
-  # Because we cannot depend on tf.keras, we see if `opt` is an instance of the
-  # Keras OptimizerV2 class by checking the subclass names.
-  base_classes = tf_inspect.getmro(opt.__class__)
-  base_class_names = [cls.__name__ for cls in base_classes]
-  is_loss_scale_optimizer_v2 = 'OptimizerV2' in base_class_names
-  if is_loss_scale_optimizer_v2:
-    # Because we cannot depend on tf.keras, we cannot unconditionally do this
-    # import. But since `opt` is a Keras OptimizerV2, we know keras is
-    # importable, so it is safe to do this import. (Technically, it's possible
-    # to have a dependency on OptimizerV2 and not LossScaleOptimizer, but this
-    # is not done in practice).
-    from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as loss_scale_optimizer_v2  # pylint: disable=g-import-not-at-top
-    return loss_scale_optimizer_v2.LossScaleOptimizer(opt, loss_scale)
+  for optimizer_cls, wrapper_cls in _REGISTERED_WRAPPER_OPTIMIZER_CLS.items():
+    if isinstance(opt, optimizer_cls):
+      return wrapper_cls(opt, loss_scale)

   if use_v1_behavior:
     raise ValueError('"opt" must be an instance of a tf.train.Optimizer or a '