diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py index 86d8cee16dd..2ac6e81c6d9 100644 --- a/tensorflow/python/training/experimental/loss_scale.py +++ b/tensorflow/python/training/experimental/loss_scale.py @@ -197,6 +197,10 @@ class LossScale(trackable.Trackable): """Creates the LossScale from its config.""" return cls(**config) + def variables(self): + """Returns the variables defined in this LossScale.""" + return list(self._weights.values()) + def get_loss_scale_weights(loss_scale): return loss_scale._weights.values() # pylint: disable=protected-access diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer.py b/tensorflow/python/training/experimental/loss_scale_optimizer.py index 15a1dd565ed..edfad20dddd 100644 --- a/tensorflow/python/training/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/training/experimental/loss_scale_optimizer.py @@ -243,3 +243,7 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer): def _resource_apply_dense(self, grad, handle): """This function should never be called.""" raise RuntimeError('This function should never be called') + + def variables(self): + """Returns the variables of the Optimizer.""" + return self._optimizer.variables() + self._loss_scale.variables()