Override __getattr__ on Optimizers V2 to support hyperparameter access.
PiperOrigin-RevId: 221038643
This commit is contained in:
parent
6e60c730b2
commit
f7bd506330
@ -346,8 +346,26 @@ class OptimizerV2(optimizer_v1.Optimizer):
|
|||||||
value = self._hyper[name]
|
value = self._hyper[name]
|
||||||
return self._call_if_callable(value)
|
return self._call_if_callable(value)
|
||||||
|
|
||||||
|
def __getattribute__(self, name):
|
||||||
|
"""Overridden to support hyperparameter access."""
|
||||||
|
try:
|
||||||
|
return super(OptimizerV2, self).__getattribute__(name)
|
||||||
|
except AttributeError as e:
|
||||||
|
# Needed to avoid infinite recursion with __setattr__.
|
||||||
|
if name == "_hyper":
|
||||||
|
raise e
|
||||||
|
# Backwards compatibility with Keras optimizers.
|
||||||
|
if name == "lr":
|
||||||
|
name = "learning_rate"
|
||||||
|
if name in self._hyper:
|
||||||
|
return self._hyper[name]
|
||||||
|
raise e
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
"""Override setattr to support dynamic hyperparameter setting."""
|
"""Override setattr to support dynamic hyperparameter setting."""
|
||||||
|
# Backwards compatibility with Keras optimizers.
|
||||||
|
if name == "lr":
|
||||||
|
name = "learning_rate"
|
||||||
if hasattr(self, "_hyper") and name in self._hyper:
|
if hasattr(self, "_hyper") and name in self._hyper:
|
||||||
self._set_hyper(name, value)
|
self._set_hyper(name, value)
|
||||||
else:
|
else:
|
||||||
|
@ -370,6 +370,30 @@ class OptimizerTest(test.TestCase):
|
|||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
self.evaluate([var3, var4]), self.evaluate([var5, var6]))
|
self.evaluate([var3, var4]), self.evaluate([var5, var6]))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testGettingHyperParameters(self):
|
||||||
|
opt = adam.Adam(learning_rate=1.0)
|
||||||
|
var = resource_variable_ops.ResourceVariable([1.0, 2.0],
|
||||||
|
dtype=dtypes.float32)
|
||||||
|
loss = lambda: 3 * var
|
||||||
|
opt_op = opt.minimize(loss, [var])
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(opt_op)
|
||||||
|
|
||||||
|
lr = self.evaluate(opt.lr)
|
||||||
|
self.assertEqual(1.0, lr)
|
||||||
|
|
||||||
|
opt.lr = 2.0
|
||||||
|
lr = self.evaluate(opt.lr)
|
||||||
|
self.assertEqual(2.0, lr)
|
||||||
|
|
||||||
|
self.evaluate(opt.lr.assign(3.0))
|
||||||
|
lr = self.evaluate(opt.lr)
|
||||||
|
self.assertEqual(3.0, lr)
|
||||||
|
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
|
opt.not_an_attr += 3
|
||||||
|
|
||||||
def testOptimizerWithFunction(self):
|
def testOptimizerWithFunction(self):
|
||||||
with context.eager_mode():
|
with context.eager_mode():
|
||||||
var = resource_variable_ops.ResourceVariable([1.0, 2.0],
|
var = resource_variable_ops.ResourceVariable([1.0, 2.0],
|
||||||
|
Loading…
Reference in New Issue
Block a user