Merge pull request #37965 from zhuzilin:keras-amp-variables-dev
PiperOrigin-RevId: 305096445 Change-Id: I2001cc63b77e99ea6aff819e94e8c7f366282da2
This commit is contained in:
commit
6cbb2f3405
@ -31,6 +31,7 @@ from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.utils import np_utils
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training.adam import AdamOptimizer
|
||||
from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer
|
||||
|
||||
|
||||
def _get_model(input_dim, num_hidden, output_dim):
|
||||
@ -232,6 +233,26 @@ class KerasOptimizersTest(keras_parameterized.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = keras.optimizers.Adam(clipnorm=-2.0)
|
||||
|
||||
def test_mixed_precision_loss_scale_optimizer(self):
|
||||
if context.executing_eagerly():
|
||||
self.skipTest('v1 optimizer does not run in eager mode')
|
||||
optimizer = MixedPrecisionLossScaleOptimizer(AdamOptimizer(), 'dynamic')
|
||||
model = keras.models.Sequential()
|
||||
model.add(
|
||||
keras.layers.Dense(
|
||||
2, input_shape=(3,),
|
||||
kernel_constraint=keras.constraints.MaxNorm(1)))
|
||||
model.compile(
|
||||
loss='mean_squared_error',
|
||||
optimizer=optimizer,
|
||||
run_eagerly=testing_utils.should_run_eagerly())
|
||||
model.fit(
|
||||
np.random.random((5, 3)),
|
||||
np.random.random((5, 2)),
|
||||
epochs=1,
|
||||
batch_size=5,
|
||||
verbose=0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -243,3 +243,8 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
|
||||
def _resource_apply_dense(self, grad, handle):
|
||||
"""This function should never be called."""
|
||||
raise RuntimeError('This function should never be called')
|
||||
|
||||
def variables(self):
|
||||
"""Returns the variables of the Optimizer."""
|
||||
return (self._optimizer.variables() +
|
||||
list(self._loss_scale._weights.values())) # pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user