diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 8569f55d164..c6e1d57c5e4 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -346,8 +346,26 @@ class OptimizerV2(optimizer_v1.Optimizer):
     value = self._hyper[name]
     return self._call_if_callable(value)
 
+  def __getattribute__(self, name):
+    """Overridden to support hyperparameter access."""
+    try:
+      return super(OptimizerV2, self).__getattribute__(name)
+    except AttributeError as e:
+      # Needed to avoid infinite recursion with __setattr__.
+      if name == "_hyper":
+        raise e
+      # Backwards compatibility with Keras optimizers.
+      if name == "lr":
+        name = "learning_rate"
+      if name in self._hyper:
+        return self._hyper[name]
+      raise e
+
   def __setattr__(self, name, value):
     """Override setattr to support dynamic hyperparameter setting."""
+    # Backwards compatibility with Keras optimizers.
+    if name == "lr":
+      name = "learning_rate"
     if hasattr(self, "_hyper") and name in self._hyper:
       self._set_hyper(name, value)
     else:
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
index e5d1a104ca4..682deda23f0 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -370,6 +370,30 @@ class OptimizerTest(test.TestCase):
     self.assertAllClose(
         self.evaluate([var3, var4]), self.evaluate([var5, var6]))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testGettingHyperParameters(self):
+    opt = adam.Adam(learning_rate=1.0)
+    var = resource_variable_ops.ResourceVariable([1.0, 2.0],
+                                                 dtype=dtypes.float32)
+    loss = lambda: 3 * var
+    opt_op = opt.minimize(loss, [var])
+    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(opt_op)
+
+    lr = self.evaluate(opt.lr)
+    self.assertEqual(1.0, lr)
+
+    opt.lr = 2.0
+    lr = self.evaluate(opt.lr)
+    self.assertEqual(2.0, lr)
+
+    self.evaluate(opt.lr.assign(3.0))
+    lr = self.evaluate(opt.lr)
+    self.assertEqual(3.0, lr)
+
+    with self.assertRaises(AttributeError):
+      opt.not_an_attr += 3
+
   def testOptimizerWithFunction(self):
     with context.eager_mode():
       var = resource_variable_ops.ResourceVariable([1.0, 2.0],
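
A minimal, standalone sketch (not part of the patch) of the attribute-aliasing pattern the diff implements: reads of unknown attributes fall through to a _hyper dict, the legacy name "lr" is mapped to "learning_rate" on both get and set, and a guard on "_hyper" itself avoids infinite recursion. The class name HyperContainer is hypothetical; its plain dict stands in for OptimizerV2's hyperparameter and variable machinery.

class HyperContainer(object):
  """Toy stand-in (hypothetical) for OptimizerV2's hyperparameter handling."""

  def __init__(self):
    # Routed through __setattr__ below; the hasattr guard there is False at
    # this point, so the dict is stored as a regular instance attribute.
    self._hyper = {"learning_rate": 1.0}

  def __getattribute__(self, name):
    try:
      return super(HyperContainer, self).__getattribute__(name)
    except AttributeError as e:
      # Guard: a failed lookup of _hyper itself must propagate, otherwise
      # the self._hyper access below would recurse forever.
      if name == "_hyper":
        raise e
      # Backwards-compatible alias, as in the patch.
      if name == "lr":
        name = "learning_rate"
      if name in self._hyper:
        return self._hyper[name]
      raise e

  def __setattr__(self, name, value):
    # Apply the same alias on writes so c.lr = x updates learning_rate.
    if name == "lr":
      name = "learning_rate"
    if hasattr(self, "_hyper") and name in self._hyper:
      self._hyper[name] = value
    else:
      super(HyperContainer, self).__setattr__(name, value)


c = HyperContainer()
assert c.lr == 1.0             # read through the alias
c.lr = 2.0                     # write through the alias updates the dict
assert c.learning_rate == 2.0
try:
  c.not_an_attr += 3           # unknown names still raise AttributeError
except AttributeError:
  pass

Run as a script, the assertions pass and the final read raises AttributeError, mirroring the behavior exercised by the new testGettingHyperParameters test above.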