Merge pull request #37965 from zhuzilin:keras-amp-variables-dev

PiperOrigin-RevId: 305096445
Change-Id: I2001cc63b77e99ea6aff819e94e8c7f366282da2

commit 6cbb2f3405
Author: TensorFlower Gardener
Date:   2020-04-06 12:54:33 -07:00

2 changed files with 26 additions and 0 deletions

tensorflow/python/keras/optimizers_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.utils import np_utils
 from tensorflow.python.platform import test
 from tensorflow.python.training.adam import AdamOptimizer
+from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer
 
 
 def _get_model(input_dim, num_hidden, output_dim):
@@ -232,6 +233,26 @@ class KerasOptimizersTest(keras_parameterized.TestCase):
     with self.assertRaises(ValueError):
       _ = keras.optimizers.Adam(clipnorm=-2.0)
 
+  def test_mixed_precision_loss_scale_optimizer(self):
+    if context.executing_eagerly():
+      self.skipTest('v1 optimizer does not run in eager mode')
+    optimizer = MixedPrecisionLossScaleOptimizer(AdamOptimizer(), 'dynamic')
+    model = keras.models.Sequential()
+    model.add(
+        keras.layers.Dense(
+            2, input_shape=(3,),
+            kernel_constraint=keras.constraints.MaxNorm(1)))
+    model.compile(
+        loss='mean_squared_error',
+        optimizer=optimizer,
+        run_eagerly=testing_utils.should_run_eagerly())
+    model.fit(
+        np.random.random((5, 3)),
+        np.random.random((5, 2)),
+        epochs=1,
+        batch_size=5,
+        verbose=0)
+
 
 if __name__ == '__main__':
   test.main()

tensorflow/python/training/experimental/loss_scale_optimizer.py
@@ -243,3 +243,8 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
   def _resource_apply_dense(self, grad, handle):
     """This function should never be called."""
     raise RuntimeError('This function should never be called')
+
+  def variables(self):
+    """Returns the variables of the Optimizer."""
+    return (self._optimizer.variables() +
+            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access
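
For context, a minimal usage sketch (not part of the diff) of what the new variables() method exposes: the wrapped optimizer's slot and non-slot variables together with the loss scale's internal weights, so that checkpointing or inspection utilities can capture both through one call. The sketch assumes TF 2.x with v1 compatibility mode; the variable and loss names below are illustrative, not taken from the PR.

# Hypothetical usage sketch; setup names (x, loss) are assumptions.
import numpy as np
import tensorflow.compat.v1 as tf

from tensorflow.python.training.adam import AdamOptimizer
from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer

tf.disable_eager_execution()  # the v1 wrapper is exercised in graph mode

x = tf.Variable(np.ones(3, dtype=np.float32))
loss = tf.reduce_sum(x * x)

opt = MixedPrecisionLossScaleOptimizer(AdamOptimizer(), 'dynamic')
train_op = opt.minimize(loss, var_list=[x])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)
  # Before this change the wrapper did not expose variables(); now it returns
  # Adam's variables plus the DynamicLossScale weights (e.g. the current loss
  # scale and its good-steps counter).
  print([v.name for v in opt.variables()])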