Override __getattr__ on Optimizers V2 to support hyperparameter access.
PiperOrigin-RevId: 221038643
This commit is contained in:
parent
6e60c730b2
commit
f7bd506330
@ -346,8 +346,26 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
||||
value = self._hyper[name]
|
||||
return self._call_if_callable(value)
|
||||
|
||||
def __getattribute__(self, name):
|
||||
"""Overridden to support hyperparameter access."""
|
||||
try:
|
||||
return super(OptimizerV2, self).__getattribute__(name)
|
||||
except AttributeError as e:
|
||||
# Needed to avoid infinite recursion with __setattr__.
|
||||
if name == "_hyper":
|
||||
raise e
|
||||
# Backwards compatibility with Keras optimizers.
|
||||
if name == "lr":
|
||||
name = "learning_rate"
|
||||
if name in self._hyper:
|
||||
return self._hyper[name]
|
||||
raise e
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Override setattr to support dynamic hyperparameter setting."""
|
||||
# Backwards compatibility with Keras optimizers.
|
||||
if name == "lr":
|
||||
name = "learning_rate"
|
||||
if hasattr(self, "_hyper") and name in self._hyper:
|
||||
self._set_hyper(name, value)
|
||||
else:
|
||||
|
@ -370,6 +370,30 @@ class OptimizerTest(test.TestCase):
|
||||
self.assertAllClose(
|
||||
self.evaluate([var3, var4]), self.evaluate([var5, var6]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testGettingHyperParameters(self):
|
||||
opt = adam.Adam(learning_rate=1.0)
|
||||
var = resource_variable_ops.ResourceVariable([1.0, 2.0],
|
||||
dtype=dtypes.float32)
|
||||
loss = lambda: 3 * var
|
||||
opt_op = opt.minimize(loss, [var])
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.evaluate(opt_op)
|
||||
|
||||
lr = self.evaluate(opt.lr)
|
||||
self.assertEqual(1.0, lr)
|
||||
|
||||
opt.lr = 2.0
|
||||
lr = self.evaluate(opt.lr)
|
||||
self.assertEqual(2.0, lr)
|
||||
|
||||
self.evaluate(opt.lr.assign(3.0))
|
||||
lr = self.evaluate(opt.lr)
|
||||
self.assertEqual(3.0, lr)
|
||||
|
||||
with self.assertRaises(AttributeError):
|
||||
opt.not_an_attr += 3
|
||||
|
||||
def testOptimizerWithFunction(self):
|
||||
with context.eager_mode():
|
||||
var = resource_variable_ops.ResourceVariable([1.0, 2.0],
|
||||
|
Loading…
Reference in New Issue
Block a user