Override __getattribute__ and __setattr__ on Optimizers V2 to support hyperparameter access.

PiperOrigin-RevId: 221038643
Author: A. Unique TensorFlower 2018-11-11 22:23:03 -08:00
Committed by: TensorFlower Gardener
Parent: 6e60c730b2
Commit: f7bd506330
2 changed files with 42 additions and 0 deletions
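As a usage sketch of what this change enables (illustrative, not part of the diff; assumes eager execution and that the optimizer is reachable as tf.keras.optimizers.Adam):

# Hyperparameters become readable and writable as ordinary attributes,
# with "lr" kept as a backwards-compatible alias for "learning_rate".
import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=1.0)

# Reads fall through the overridden __getattribute__ into _hyper.
print(float(opt.lr))             # 1.0
print(float(opt.learning_rate))  # 1.0

# Writes to a known hyperparameter route through __setattr__/_set_hyper
# instead of shadowing it with a plain instance attribute.
opt.lr = 2.0

# Once the hyperparameter is backed by a variable (as in the test below,
# after minimize has run), it can also be assigned directly.
opt.lr.assign(3.0)

# Unknown attributes still raise AttributeError as before.
try:
  opt.not_an_attr += 3
except AttributeError:
  pass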


@@ -346,8 +346,26 @@ class OptimizerV2(optimizer_v1.Optimizer):
      value = self._hyper[name]
      return self._call_if_callable(value)

  def __getattribute__(self, name):
    """Overridden to support hyperparameter access."""
    try:
      return super(OptimizerV2, self).__getattribute__(name)
    except AttributeError as e:
      # Needed to avoid infinite recursion with __setattr__.
      if name == "_hyper":
        raise e
      # Backwards compatibility with Keras optimizers.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._hyper[name]
      raise e

  def __setattr__(self, name, value):
    """Override setattr to support dynamic hyperparameter setting."""
    # Backwards compatibility with Keras optimizers.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._set_hyper(name, value)
    else:
      super(OptimizerV2, self).__setattr__(name, value)

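The try/except structure above is what keeps the override safe: __getattribute__ runs on every attribute lookup, including the lookup of _hyper inside its own fallback, so the early re-raise for "_hyper" prevents infinite recursion when that dict does not exist yet. A standalone sketch of the same pattern (hypothetical class, not from the commit):

class HyperContainer(object):
  """Illustrative only: mimics the fallback-to-dict lookup pattern."""

  def __init__(self):
    # Direct superclass call avoids triggering our own __setattr__ logic
    # before _hyper exists.
    object.__setattr__(self, "_hyper", {"learning_rate": 0.1})

  def __getattribute__(self, name):
    try:
      return object.__getattribute__(self, name)
    except AttributeError:
      # Guard: without this, looking up _hyper below would raise
      # AttributeError again and recurse if _hyper were missing.
      if name == "_hyper":
        raise
      hyper = object.__getattribute__(self, "_hyper")
      if name in hyper:
        return hyper[name]
      raise

  def __setattr__(self, name, value):
    if name in self._hyper:
      self._hyper[name] = value
    else:
      object.__setattr__(self, name, value)

c = HyperContainer()
print(c.learning_rate)  # 0.1, served from the _hyper dict
c.learning_rate = 0.2   # routed into _hyper, not a new instance attribute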

@@ -370,6 +370,30 @@ class OptimizerTest(test.TestCase):
    self.assertAllClose(
        self.evaluate([var3, var4]), self.evaluate([var5, var6]))

  @test_util.run_in_graph_and_eager_modes
  def testGettingHyperParameters(self):
    opt = adam.Adam(learning_rate=1.0)
    var = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                 dtype=dtypes.float32)
    loss = lambda: 3 * var
    opt_op = opt.minimize(loss, [var])
    self.evaluate(variables.global_variables_initializer())
    self.evaluate(opt_op)

    lr = self.evaluate(opt.lr)
    self.assertEqual(1.0, lr)

    opt.lr = 2.0
    lr = self.evaluate(opt.lr)
    self.assertEqual(2.0, lr)

    self.evaluate(opt.lr.assign(3.0))
    lr = self.evaluate(opt.lr)
    self.assertEqual(3.0, lr)

    with self.assertRaises(AttributeError):
      opt.not_an_attr += 3

  def testOptimizerWithFunction(self):
    with context.eager_mode():
      var = resource_variable_ops.ResourceVariable([1.0, 2.0],