diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py index 2ac6e81c6d9..86d8cee16dd 100644 --- a/tensorflow/python/training/experimental/loss_scale.py +++ b/tensorflow/python/training/experimental/loss_scale.py @@ -197,10 +197,6 @@ class LossScale(trackable.Trackable): """Creates the LossScale from its config.""" return cls(**config) - def variables(self): - """Returns the variables defined in this LossScale.""" - return list(self._weights.values()) - def get_loss_scale_weights(loss_scale): return loss_scale._weights.values() # pylint: disable=protected-access diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer.py b/tensorflow/python/training/experimental/loss_scale_optimizer.py index edfad20dddd..28c9bb36b82 100644 --- a/tensorflow/python/training/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/training/experimental/loss_scale_optimizer.py @@ -246,4 +246,5 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer): def variables(self): """Returns the variables of the Optimizer.""" - return self._optimizer.variables() + self._loss_scale.variables() + return self._optimizer.variables() + \ + list(self._loss_scale._weights.values()) # pylint: disable=protected-access