We do not support complex dtypes with certain optimizers such as Ftrl, FtrlV2, AdamWithAmsgrad, AdaMax, AddSign & PowerSign, since they may use operations that are missing for complex values, such as sqrt. Fixes #32774 PiperOrigin-RevId: 277953548 Change-Id: Ia075aa5c3f944de932d71b9741d626f7ebe5416f
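RMSProp itself stays covered for complex dtypes (the test below iterates over both float and complex types); its NumPy reference update only needs element-wise multiplies, adds, and np.sqrt, all of which NumPy defines for complex arrays. A minimal sketch of one non-centered step on an illustrative complex gradient (the array values here are made up for illustration, not taken from the test):

import numpy as np

# One non-centered RMSProp step, mirroring _rmsprop_update_numpy below.
var = np.array([1.0 + 0.0j, 2.0 + 0.0j], dtype=np.complex64)  # illustrative variable
g = np.array([0.1 + 0.2j, 0.3 - 0.1j], dtype=np.complex64)    # illustrative gradient
rms, mom = np.ones_like(var), np.zeros_like(var)               # slot initial values
lr, decay, momentum, epsilon = 3.0, 0.9, 0.0, 1e-10

rms_t = rms * decay + (1 - decay) * g * g                      # accumulator decay
mom_t = momentum * mom + lr * g / np.sqrt(rms_t + epsilon)     # np.sqrt is defined for complex
var_t = var - mom_t                                            # parameter update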
133 lines · 5.2 KiB · Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RMSProp optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop


class RmspropTest(xla_test.XLATestCase):

  def _rmsprop_update_numpy(self,
                            var,
                            g,
                            mg,
                            rms,
                            mom,
                            lr,
                            decay=0.9,
                            momentum=0.0,
                            epsilon=1e-10,
                            centered=False):
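    """NumPy reference for one RMSProp step.

    Decays the squared-gradient accumulator, optionally applies the centered
    correction using the running gradient mean, and returns the updated
    (var, mg, rms, mom) values so the test can compare them element-wise.
    """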
    rms_t = rms * decay + (1 - decay) * g * g
    denom_t = rms_t + epsilon
    if centered:
      mg_t = mg * decay + (1 - decay) * g
      denom_t -= mg_t * mg_t
    else:
      mg_t = mg
    mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
    var_t = var - mom_t
    return var_t, mg_t, rms_t, mom_t

  def testBasic(self):
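    """Runs 3 RMSProp steps on XLA and checks them against the NumPy reference."""
    # The dtype loop below covers complex types in addition to floats.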
    for dtype in self.float_types | self.complex_types:
      for centered in [False, True]:
        with self.session(), self.test_scope():
          # Initialize variables for numpy implementation.
          var0_np = np.array([1.0, 2.0], dtype=dtype)
          grads0_np = np.array([0.1, 0.1], dtype=dtype)
          var1_np = np.array([3.0, 4.0], dtype=dtype)
          grads1_np = np.array([0.01, 0.01], dtype=dtype)
          mg0_np = np.array([0.0, 0.0], dtype=dtype)
          mg1_np = np.array([0.0, 0.0], dtype=dtype)
          rms0_np = np.array([1.0, 1.0], dtype=dtype)
          rms1_np = np.array([1.0, 1.0], dtype=dtype)
          mom0_np = np.array([0.0, 0.0], dtype=dtype)
          mom1_np = np.array([0.0, 0.0], dtype=dtype)

          var0 = resource_variable_ops.ResourceVariable(var0_np)
          var1 = resource_variable_ops.ResourceVariable(var1_np)
          grads0 = constant_op.constant(grads0_np)
          grads1 = constant_op.constant(grads1_np)
          learning_rate = 3.0
          rms_opt = rmsprop.RMSPropOptimizer(learning_rate, centered=centered)
          rms_update = rms_opt.apply_gradients(
              zip([grads0, grads1], [var0, var1]))
          self.evaluate(variables.global_variables_initializer())

          mg0 = rms_opt.get_slot(var0, "mg")
          self.assertEqual(mg0 is not None, centered)
          mg1 = rms_opt.get_slot(var1, "mg")
          self.assertEqual(mg1 is not None, centered)
          rms0 = rms_opt.get_slot(var0, "rms")
          self.assertIsNotNone(rms0)
          rms1 = rms_opt.get_slot(var1, "rms")
          self.assertIsNotNone(rms1)
          mom0 = rms_opt.get_slot(var0, "momentum")
          self.assertIsNotNone(mom0)
          mom1 = rms_opt.get_slot(var1, "momentum")
          self.assertIsNotNone(mom1)

          # Fetch params to validate initial values
          self.assertAllClose([1.0, 2.0], self.evaluate(var0))
          self.assertAllClose([3.0, 4.0], self.evaluate(var1))

          # Run 3 steps of RMSProp
          for _ in range(3):
            self.evaluate(rms_update)

            var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
                var0_np,
                grads0_np,
                mg0_np,
                rms0_np,
                mom0_np,
                learning_rate,
                centered=centered)
            var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
                var1_np,
                grads1_np,
                mg1_np,
                rms1_np,
                mom1_np,
                learning_rate,
                centered=centered)

            # Validate updated params
            if centered:
              self.assertAllCloseAccordingToType(mg0_np, self.evaluate(mg0))
              self.assertAllCloseAccordingToType(mg1_np, self.evaluate(mg1))
            self.assertAllCloseAccordingToType(rms0_np, self.evaluate(rms0))
            self.assertAllCloseAccordingToType(rms1_np, self.evaluate(rms1))
            self.assertAllCloseAccordingToType(mom0_np, self.evaluate(mom0))
            self.assertAllCloseAccordingToType(mom1_np, self.evaluate(mom1))
            self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
            self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))


if __name__ == "__main__":
  test.main()