Update the eager user guide to use object-based saving (and Model)

PiperOrigin-RevId: 188332858
Allen Lavoie 2018-03-08 08:02:35 -08:00 committed by TensorFlower Gardener
parent 18ca16d73a
commit 6a619489c6


### Checkpointing trained variables
TensorFlow Variables (`tfe.Variable`) provide a way to represent shared,
persistent state of your model. The `tfe.Checkpoint` class provides a means to
save and restore variables to and from _checkpoints_.
For example:
```python
# Create variables.
x = tfe.Variable(10.)
y = tfe.Variable(5.)
# Indicate that the variables should be saved as "x" and "y".
checkpoint = tfe.Checkpoint(x=x, y=y)
# Assign new values to the variables and save.
x.assign(2.)
checkpoint.save('/tmp/ckpt')
# Change the variable after saving.
x.assign(11.)
assert 16. == (x + y).numpy() # 11 + 5
# Restore the values in the checkpoint.
checkpoint.restore('/tmp/ckpt-1')  # save() numbers checkpoints; 'ckpt-1' is the first.
assert 7. == (x + y).numpy() # 2 + 5
```
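
Because a `tfe.Checkpoint` matches values by the keyword arguments it was
constructed with (`x=`, `y=`) rather than by Python variable names, the same
checkpoint can be restored into newly created objects. A minimal sketch, with
illustrative names:

```python
# Fresh variables, e.g. in a separate program run.
new_x = tfe.Variable(0.)
new_y = tfe.Variable(0.)
# Matching keyword arguments tell the checkpoint which value goes where.
restorer = tfe.Checkpoint(x=new_x, y=new_y)
restorer.restore('/tmp/ckpt-1')
assert 2. == new_x.numpy()
assert 5. == new_y.numpy()
```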
### `tf.keras.Model`
You may often want to organize your models using classes, like the `MNISTModel`
class described above. We recommend inheriting from the `tf.keras.Model` class
as it provides conveniences like keeping track of all model variables.
Sub-classes of `tf.keras.Model` may register `Layer`s (like classes in
[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), or [Keras
layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) by
assigning them to attributes (e.g., `self.layer1 = layer_object`) and define
the computation in an implementation of `call()`.
Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables
lazily, when the first input is encountered.
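
For instance, a quick sketch of this lazy creation (assuming the default
`use_bias=True`, so the layer owns a kernel and a bias):

```python
layer = tf.layers.Dense(2)
assert 0 == len(layer.variables)    # Nothing is created at construction time.
layer(tf.constant([[1., 2., 3.]]))  # The first call builds the variables.
assert 2 == len(layer.variables)    # A [3, 2] kernel and a [2] bias.
```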
For example, consider the following two-layer neural network:
```python
class TwoLayerNet(tf.keras.Model):
  def __init__(self):
    super(TwoLayerNet, self).__init__()
    self.layer1 = tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False)
    self.layer2 = tf.layers.Dense(3, use_bias=False)

  def call(self, x):
    return self.layer2(self.layer1(x))

net = TwoLayerNet()
# Variables are created on the first call; a one-feature input gives
# layer1 a [1, 2] kernel and layer2 a [2, 3] kernel.
output = net(tf.constant([[1.]]))
assert [1, 2] == net.variables[0].shape.as_list()  # weights of layer1.
assert [2, 3] == net.variables[1].shape.as_list()  # weights of layer2.
```
The `tf.keras.Model` class is itself a sub-class of `tf.layers.Layer`. This
allows instances of `tf.keras.Model` to be embedded in other models. For
example:
```python
class ThreeLayerNet(tf.keras.Model):
  def __init__(self):
    super(ThreeLayerNet, self).__init__()
    self.a = TwoLayerNet()
    self.b = tf.layers.Dense(4, use_bias=False)

  def call(self, x):
    return self.b(self.a(x))

net = ThreeLayerNet()
output = net(tf.constant([[1.]]))
assert [1, 2] == net.variables[0].shape.as_list()  # weights of a.layer1.
assert [2, 3] == net.variables[1].shape.as_list()  # weights of a.layer2.
assert [3, 4] == net.variables[2].shape.as_list()  # weights of b.
```
See more examples in
[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples).
`tfe.Checkpoint` provides a convenient way to save and load training
checkpoints. Let's define something simple to train: we set an objective for
the output of our network, choose an optimizer, and a location for the
checkpoint (the setup values below are illustrative):
```python
import os

inp = tf.constant([[1.]])
# ThreeLayerNet produces 4 outputs per example, so the objective has shape [1, 4].
objective = tf.constant([[2., 3., 4., 5.]])
optimizer = tf.train.AdamOptimizer(0.01)  # The optimizer choice is illustrative.
checkpoint_directory = '/tmp/tfe_example'
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
net = ThreeLayerNet()
```
We group them in a `tfe.Checkpoint` and request that it be restored. This
ensures that variables created by these objects are restored before their values
are used. Our training loop is the same whether starting training or resuming
from a previous checkpoint:
```python
global_step = tf.train.get_or_create_global_step()
checkpoint = tfe.Checkpoint(
    global_step=global_step, optimizer=optimizer, network=net)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(100):
  loss_fn = lambda: tf.norm(net(inp) - objective)
  optimizer.minimize(loss_fn, global_step=global_step)
  if tf.equal(global_step % 20, 0):
    print("Step %d, output %s" % (global_step.numpy(),
                                  net(inp).numpy()))
    # Save the checkpoint.
    checkpoint.save(checkpoint_prefix)
```
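
Each call to `checkpoint.save()` writes a new numbered checkpoint (`ckpt-1`,
`ckpt-2`, ...) under `checkpoint_directory`, and `tf.train.latest_checkpoint`
returns the newest one, or `None` when none exists yet. That is why the loop
above works both for a fresh start and for resuming; a short sketch (the path
is illustrative):

```python
latest = tf.train.latest_checkpoint(checkpoint_directory)
# After some training this is e.g. '/tmp/tfe_example/ckpt-5'; on a fresh
# start it is None, and restore(None) is effectively a no-op, leaving the
# variables to their random initial values.
checkpoint.restore(latest)
```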
The first time it runs, `Model` variables are initialized randomly. Then the
output is trained to match the objective we've set:
```