Add tfe.get_optimizer_variables for fetching the list of variables which an
optimizer has created. Useful for saving them when executing eagerly.

PiperOrigin-RevId: 173726859
commit 5426a3c93d (parent 02f55400f8)
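A minimal usage sketch distilled from the tests added below; the checkpoint
path is illustrative, and eager execution is assumed to be enabled at program
startup as the module docs describe:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

v = tfe.Variable([[2., 3.]], name='v')
optimizer = tf.train.AdamOptimizer(0.1)
loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2
optimizer.minimize(loss_fn)  # first step also creates the optimizer's slot variables

# Collect the optimizer's state so it can be checkpointed alongside the model.
optimizer_variables = tfe.get_optimizer_variables(optimizer)
saver = tfe.Saver(optimizer_variables + [v])
checkpoint_path = saver.save('/tmp/ckpt')  # illustrative prefix, not from the commit
saver.restore(checkpoint_path)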
@@ -23,6 +23,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import adam as _adam
 from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.training import saver as _saver

@@ -165,3 +166,25 @@ class Saver(object):
     """
     with ops.device("/device:CPU:0"):
       self._saver.restore(None, file_prefix)
+
+
+def get_optimizer_variables(optimizer):
+  """Returns a list of variables for the given `tf.train.Optimizer`.
+
+  Args:
+    optimizer: An instance of `tf.train.Optimizer` which has created variables
+      (typically after a call to `Optimizer.minimize`).
+  Returns:
+    A list of variables which have been created by the `Optimizer`. Currently
+      returns all variables even if they were not created in the default graph,
+      but this behavior may change.
+  """
+  variables = []
+  # pylint: disable=protected-access
+  for _, variable_dict in optimizer._slots.items():
+    for _, slot_for_variable in variable_dict.items():
+      variables.append(slot_for_variable)
+  if isinstance(optimizer, _adam.AdamOptimizer):
+    variables.append(optimizer._beta1_power)
+    variables.append(optimizer._beta2_power)
+  return variables
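What the list contains depends on the optimizer, as the tests below exercise:
plain gradient descent keeps no per-variable state, so the list is empty,
while Adam contributes its slot variables plus the two beta-power
accumulators. A hedged sketch, under the same assumptions as the sketch above:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

v = tfe.Variable([[2., 3.]], name='v')
loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2

sgd = tf.train.GradientDescentOptimizer(0.02)
sgd.minimize(loss_fn)
assert not tfe.get_optimizer_variables(sgd)  # no slots, no extra state

adam_opt = tf.train.AdamOptimizer(0.1)
adam_opt.minimize(loss_fn)
# Expect the 'm' and 'v' slots for `v`, plus beta1_power and beta2_power.
print([var.name for var in tfe.get_optimizer_variables(adam_opt)])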
@@ -30,6 +30,10 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import adam
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import momentum
+from tensorflow.python.training import rmsprop


 class SaverTest(test.TestCase):
@@ -204,5 +208,42 @@ class SaverTest(test.TestCase):
         3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())


+class GetOptimizerTests(test.TestCase):
+
+  def _optimizer_test_template(self, optimizer):
+    """Checks save and restore. Returns the optimizer variables."""
+    v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v')
+    loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2
+    optimizer.minimize(loss_fn)
+    optimizer_variables = _saver.get_optimizer_variables(optimizer)
+    saver = _saver.Saver(optimizer_variables + [v])
+    checkpoint_path = saver.save(self.get_temp_dir())
+    optimizer.minimize(loss_fn)
+    after_first_minimize = v.numpy()
+    # After we restore, the next step should be exactly the same as the one we
+    # just did.
+    saver.restore(checkpoint_path)
+    optimizer.minimize(loss_fn)
+    self.assertAllEqual(after_first_minimize, v.numpy())
+    return optimizer_variables
+
+  def testAdam(self):
+    optimizer = adam.AdamOptimizer(0.1)
+    self._optimizer_test_template(optimizer)
+
+  def testGradientDescent(self):
+    optimizer = gradient_descent.GradientDescentOptimizer(0.02)
+    self.assertEqual(0, len(self._optimizer_test_template(optimizer)))
+
+  def testMomentum(self):
+    optimizer = momentum.MomentumOptimizer(
+        learning_rate=0.03,
+        momentum=0.5)
+    self._optimizer_test_template(optimizer)
+
+  def testRMSProp(self):
+    optimizer = rmsprop.RMSPropOptimizer(0.01)
+    self._optimizer_test_template(optimizer)
+
 if __name__ == '__main__':
   test.main()
@@ -51,6 +51,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
 @@SummaryWriter
 @@restore_variables_on_create
 @@Variable
+@@get_optimizer_variables

 @@in_eager_mode
 @@in_graph_mode
@@ -73,6 +74,7 @@ from __future__ import print_function
 from tensorflow.contrib.eager.python import metrics
 from tensorflow.contrib.eager.python.datasets import Iterator
 from tensorflow.contrib.eager.python.network import Network
+from tensorflow.contrib.eager.python.saver import get_optimizer_variables
 from tensorflow.contrib.eager.python.saver import restore_variables_on_create
 from tensorflow.contrib.eager.python.saver import Saver
 from tensorflow.contrib.eager.python.summary_writer import SummaryWriter