Add tfe.get_optimizer_variables for fetching a list of variables which an
optimizer has created. Useful for saving them if executing eagerly. PiperOrigin-RevId: 173726859
This commit is contained in:
parent
02f55400f8
commit
5426a3c93d
@ -23,6 +23,7 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.training import adam as _adam
|
||||||
from tensorflow.python.training import checkpoint_utils
|
from tensorflow.python.training import checkpoint_utils
|
||||||
from tensorflow.python.training import saver as _saver
|
from tensorflow.python.training import saver as _saver
|
||||||
|
|
||||||
@ -165,3 +166,25 @@ class Saver(object):
|
|||||||
"""
|
"""
|
||||||
with ops.device("/device:CPU:0"):
|
with ops.device("/device:CPU:0"):
|
||||||
self._saver.restore(None, file_prefix)
|
self._saver.restore(None, file_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
def get_optimizer_variables(optimizer):
|
||||||
|
"""Returns a list of variables for the given `tf.train.Optimizer`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: An instance of `tf.train.Optimizer` which has created variables
|
||||||
|
(typically after a call to `Optimizer.minimize`).
|
||||||
|
Returns:
|
||||||
|
A list of variables which have been created by the `Optimizer`. Currently
|
||||||
|
returns all variables even if they were not created in the default graph,
|
||||||
|
but this behavior may change.
|
||||||
|
"""
|
||||||
|
variables = []
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
for _, variable_dict in optimizer._slots.items():
|
||||||
|
for _, slot_for_variable in variable_dict.items():
|
||||||
|
variables.append(slot_for_variable)
|
||||||
|
if isinstance(optimizer, _adam.AdamOptimizer):
|
||||||
|
variables.append(optimizer._beta1_power)
|
||||||
|
variables.append(optimizer._beta2_power)
|
||||||
|
return variables
|
||||||
|
@ -30,6 +30,10 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import init_ops
|
from tensorflow.python.ops import init_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.training import adam
|
||||||
|
from tensorflow.python.training import gradient_descent
|
||||||
|
from tensorflow.python.training import momentum
|
||||||
|
from tensorflow.python.training import rmsprop
|
||||||
|
|
||||||
|
|
||||||
class SaverTest(test.TestCase):
|
class SaverTest(test.TestCase):
|
||||||
@ -204,5 +208,42 @@ class SaverTest(test.TestCase):
|
|||||||
3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
|
||||||
|
|
||||||
|
|
||||||
|
class GetOptimizerTests(test.TestCase):
|
||||||
|
|
||||||
|
def _optimizer_test_template(self, optimizer):
|
||||||
|
"""Checks save and restore. Returns the optimizer variables."""
|
||||||
|
v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v')
|
||||||
|
loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2
|
||||||
|
optimizer.minimize(loss_fn)
|
||||||
|
optimizer_variables = _saver.get_optimizer_variables(optimizer)
|
||||||
|
saver = _saver.Saver(optimizer_variables + [v])
|
||||||
|
checkpoint_path = saver.save(self.get_temp_dir())
|
||||||
|
optimizer.minimize(loss_fn)
|
||||||
|
after_first_minimize = v.numpy()
|
||||||
|
# After we restore, the next step should be exactly the same as the one we
|
||||||
|
# just did.
|
||||||
|
saver.restore(checkpoint_path)
|
||||||
|
optimizer.minimize(loss_fn)
|
||||||
|
self.assertAllEqual(after_first_minimize, v.numpy())
|
||||||
|
return optimizer_variables
|
||||||
|
|
||||||
|
def testAdam(self):
|
||||||
|
optimizer = adam.AdamOptimizer(0.1)
|
||||||
|
self._optimizer_test_template(optimizer)
|
||||||
|
|
||||||
|
def testGradientDescent(self):
|
||||||
|
optimizer = gradient_descent.GradientDescentOptimizer(0.02)
|
||||||
|
self.assertEqual(0, len(self._optimizer_test_template(optimizer)))
|
||||||
|
|
||||||
|
def testMomentum(self):
|
||||||
|
optimizer = momentum.MomentumOptimizer(
|
||||||
|
learning_rate=0.03,
|
||||||
|
momentum=0.5)
|
||||||
|
self._optimizer_test_template(optimizer)
|
||||||
|
|
||||||
|
def testRMSProp(self):
|
||||||
|
optimizer = rmsprop.RMSPropOptimizer(0.01)
|
||||||
|
self._optimizer_test_template(optimizer)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -51,6 +51,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
|
|||||||
@@SummaryWriter
|
@@SummaryWriter
|
||||||
@@restore_variables_on_create
|
@@restore_variables_on_create
|
||||||
@@Variable
|
@@Variable
|
||||||
|
@@get_optimizer_variables
|
||||||
|
|
||||||
@@in_eager_mode
|
@@in_eager_mode
|
||||||
@@in_graph_mode
|
@@in_graph_mode
|
||||||
@ -73,6 +74,7 @@ from __future__ import print_function
|
|||||||
from tensorflow.contrib.eager.python import metrics
|
from tensorflow.contrib.eager.python import metrics
|
||||||
from tensorflow.contrib.eager.python.datasets import Iterator
|
from tensorflow.contrib.eager.python.datasets import Iterator
|
||||||
from tensorflow.contrib.eager.python.network import Network
|
from tensorflow.contrib.eager.python.network import Network
|
||||||
|
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
|
||||||
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
|
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
|
||||||
from tensorflow.contrib.eager.python.saver import Saver
|
from tensorflow.contrib.eager.python.saver import Saver
|
||||||
from tensorflow.contrib.eager.python.summary_writer import SummaryWriter
|
from tensorflow.contrib.eager.python.summary_writer import SummaryWriter
|
||||||
|
Loading…
x
Reference in New Issue
Block a user