Add tfe.get_optimizer_variables for fetching a list of variables which an

optimizer has created. Useful for saving them if executing eagerly.

PiperOrigin-RevId: 173726859
This commit is contained in:
Allen Lavoie 2017-10-27 15:35:06 -07:00 committed by TensorFlower Gardener
parent 02f55400f8
commit 5426a3c93d
3 changed files with 66 additions and 0 deletions

View File

@ -23,6 +23,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import adam as _adam
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as _saver
@ -165,3 +166,25 @@ class Saver(object):
"""
with ops.device("/device:CPU:0"):
self._saver.restore(None, file_prefix)
def get_optimizer_variables(optimizer):
"""Returns a list of variables for the given `tf.train.Optimizer`.
Args:
optimizer: An instance of `tf.train.Optimizer` which has created variables
(typically after a call to `Optimizer.minimize`).
Returns:
A list of variables which have been created by the `Optimizer`. Currently
returns all variables even if they were not created in the default graph,
but this behavior may change.
"""
variables = []
# pylint: disable=protected-access
for _, variable_dict in optimizer._slots.items():
for _, slot_for_variable in variable_dict.items():
variables.append(slot_for_variable)
if isinstance(optimizer, _adam.AdamOptimizer):
variables.append(optimizer._beta1_power)
variables.append(optimizer._beta2_power)
return variables

View File

@ -30,6 +30,10 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import momentum
from tensorflow.python.training import rmsprop
class SaverTest(test.TestCase):
@ -204,5 +208,42 @@ class SaverTest(test.TestCase):
3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
class GetOptimizerTests(test.TestCase):
def _optimizer_test_template(self, optimizer):
"""Checks save and restore. Returns the optimizer variables."""
v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v')
loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2
optimizer.minimize(loss_fn)
optimizer_variables = _saver.get_optimizer_variables(optimizer)
saver = _saver.Saver(optimizer_variables + [v])
checkpoint_path = saver.save(self.get_temp_dir())
optimizer.minimize(loss_fn)
after_first_minimize = v.numpy()
# After we restore, the next step should be exactly the same as the one we
# just did.
saver.restore(checkpoint_path)
optimizer.minimize(loss_fn)
self.assertAllEqual(after_first_minimize, v.numpy())
return optimizer_variables
def testAdam(self):
optimizer = adam.AdamOptimizer(0.1)
self._optimizer_test_template(optimizer)
def testGradientDescent(self):
optimizer = gradient_descent.GradientDescentOptimizer(0.02)
self.assertEqual(0, len(self._optimizer_test_template(optimizer)))
def testMomentum(self):
optimizer = momentum.MomentumOptimizer(
learning_rate=0.03,
momentum=0.5)
self._optimizer_test_template(optimizer)
def testRMSProp(self):
optimizer = rmsprop.RMSPropOptimizer(0.01)
self._optimizer_test_template(optimizer)
if __name__ == '__main__':
test.main()

View File

@ -51,6 +51,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@SummaryWriter
@@restore_variables_on_create
@@Variable
@@get_optimizer_variables
@@in_eager_mode
@@in_graph_mode
@ -73,6 +74,7 @@ from __future__ import print_function
from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.eager.python.datasets import Iterator
from tensorflow.contrib.eager.python.network import Network
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
from tensorflow.contrib.eager.python.saver import Saver
from tensorflow.contrib.eager.python.summary_writer import SummaryWriter