From ce592f32ee51bf2fd0e9c0659aae3a226aa75c00 Mon Sep 17 00:00:00 2001 From: zilinzhu <zilinzhu@tencent.com> Date: Mon, 30 Mar 2020 09:25:43 +0800 Subject: [PATCH] remove the added api on loss_scale --- tensorflow/python/training/experimental/loss_scale.py | 4 ---- .../python/training/experimental/loss_scale_optimizer.py | 3 ++- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py index 2ac6e81c6d9..86d8cee16dd 100644 --- a/tensorflow/python/training/experimental/loss_scale.py +++ b/tensorflow/python/training/experimental/loss_scale.py @@ -197,10 +197,6 @@ class LossScale(trackable.Trackable): """Creates the LossScale from its config.""" return cls(**config) - def variables(self): - """Returns the variables defined in this LossScale.""" - return list(self._weights.values()) - def get_loss_scale_weights(loss_scale): return loss_scale._weights.values() # pylint: disable=protected-access diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer.py b/tensorflow/python/training/experimental/loss_scale_optimizer.py index edfad20dddd..28c9bb36b82 100644 --- a/tensorflow/python/training/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/training/experimental/loss_scale_optimizer.py @@ -246,4 +246,5 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer): def variables(self): """Returns the variables of the Optimizer.""" - return self._optimizer.variables() + self._loss_scale.variables() + return self._optimizer.variables() + \ + list(self._loss_scale._weights.values()) # pylint: disable=protected-access