Fix bug in MixedPrecisionLossScaleOptimizer

This commit is contained in:
zilinzhu 2020-03-27 14:17:53 +08:00
parent e8dbf1de1a
commit df6e5c6b59
2 changed files with 8 additions and 0 deletions

View File

@ -197,6 +197,10 @@ class LossScale(trackable.Trackable):
"""Creates the LossScale from its config.""" """Creates the LossScale from its config."""
return cls(**config) return cls(**config)
def variables(self):
"""Returns the variables defined in this LossScale."""
return list(self._weights.values())
def get_loss_scale_weights(loss_scale): def get_loss_scale_weights(loss_scale):
return loss_scale._weights.values() # pylint: disable=protected-access return loss_scale._weights.values() # pylint: disable=protected-access

View File

@ -243,3 +243,7 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
def _resource_apply_dense(self, grad, handle): def _resource_apply_dense(self, grad, handle):
"""This function should never be called.""" """This function should never be called."""
raise RuntimeError('This function should never be called') raise RuntimeError('This function should never be called')
def variables(self):
"""Returns the variables of the Optimizer."""
return self._optimizer.variables() + self._loss_scale.variables()