Update the eager user guide to use object-based saving (and Model)
PiperOrigin-RevId: 188332858
### Checkpointing trained variables

TensorFlow Variables (`tfe.Variable`) provide a way to represent shared,
persistent state of your model. The `tfe.Checkpoint` class provides a means to
save and restore variables to and from _checkpoints_.

For example:

```python
# Create variables.
x = tfe.Variable(10.)
y = tfe.Variable(5.)

# Indicate that the variables should be saved as "x" and "y".
checkpoint = tfe.Checkpoint(x=x, y=y)

# Assign new values to the variables and save.
x.assign(2.)
checkpoint.save('/tmp/ckpt')

# Change the variable after saving.
x.assign(11.)
assert 16. == (x + y).numpy()  # 11 + 5

# Restore the values in the checkpoint.
checkpoint.restore('/tmp/ckpt-1')

assert 7. == (x + y).numpy()  # 2 + 5
```
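
Each call to `save()` writes a new, numbered checkpoint, which is why the
restore above names `/tmp/ckpt-1` rather than `/tmp/ckpt`. As a minimal sketch
of how to avoid hard-coding that suffix (reusing the `checkpoint` object from
above), `tf.train.latest_checkpoint` can look up the most recent save in a
directory:

```python
# A second save produces a new file; save() returns the path it wrote,
# e.g. '/tmp/ckpt-2'.
path = checkpoint.save('/tmp/ckpt')

# Restore from the newest checkpoint without hard-coding the numeric suffix.
checkpoint.restore(tf.train.latest_checkpoint('/tmp'))
```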

### `tf.keras.Model`

You may often want to organize your models using classes, like the `MNISTModel`
class described above. We recommend inheriting from the `tf.keras.Model` class,
as it provides conveniences like keeping track of all model variables.

Sub-classes of `tf.keras.Model` may register `Layer`s (like classes in
[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), or [Keras
layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) by
assigning them to attributes (e.g. `self.layer1 = layer_object`) and define the
computation in an implementation of `call()`.

Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables
lazily, when the first input is encountered.

For example, consider the following two-layer neural network:

```python
class TwoLayerNet(tf.keras.Model):

  def __init__(self):
    super(TwoLayerNet, self).__init__()
    self.layer1 = tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False)
    self.layer2 = tf.layers.Dense(3, use_bias=False)

  def call(self, x):
    return self.layer2(self.layer1(x))

# ...
assert [1, 2] == net.variables[0].shape.as_list()  # weights of layer1.
assert [2, 3] == net.variables[1].shape.as_list()  # weights of layer2.
```
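
To make the lazy creation concrete, here is a short sketch (not from the
original guide; the input is an assumption chosen to match the weight shapes
asserted above):

```python
net = TwoLayerNet()
assert 0 == len(net.variables)  # No variables exist before the first call.

inp = tf.constant([[1.]])  # A batch of one example with a single feature.
net(inp)                   # The first input triggers variable creation.

# One kernel per Dense layer; there are no biases since use_bias=False.
assert 2 == len(net.variables)
```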

The `tf.keras.Model` class is itself a sub-class of `tf.layers.Layer`. This
allows instances of `tf.keras.Model` to be embedded in other models. For
example:

```python
class ThreeLayerNet(tf.keras.Model):

  def __init__(self):
    super(ThreeLayerNet, self).__init__()
    self.a = TwoLayerNet()
    self.b = tf.layers.Dense(4, use_bias=False)

  def call(self, x):
    return self.b(self.a(x))

# ...
assert [3, 4] == net.variables[2].shape.as_list()
```
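
Variable tracking recurses through the nesting: the outer model's `variables`
list includes the variables of the embedded `TwoLayerNet`. A small sketch (the
input shape is an assumption consistent with the asserts above):

```python
net = ThreeLayerNet()
net(tf.constant([[1.]]))  # One forward pass builds every nested layer.

# Two kernels come from the embedded TwoLayerNet (net.a), one from net.b.
assert 3 == len(net.variables)
assert [3, 4] == net.variables[2].shape.as_list()  # weights of layer b.
```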

See more examples in
[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples).

`tfe.Checkpoint` provides a convenient way to save and load training
checkpoints. Let's define something simple to train: we set an objective for
the output of our network, and choose an optimizer and a location for the
checkpoint:

```python
# ...
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
net = ThreeLayerNet()
```
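
The elided lines above define the training input, the objective, the
optimizer, and the checkpoint location. A hypothetical reconstruction, with
names taken from the surrounding code but values chosen only for illustration:

```python
import os

# Hypothetical values; the original guide's elided setup may differ.
inp = tf.constant([[1.]])                    # One training example.
objective = tf.constant([[2., 3., 4., 5.]])  # Target for the 4-unit output.
optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
checkpoint_directory = '/tmp/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt')
```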

We group them in a `tfe.Checkpoint` and request that it be restored. This
ensures that variables created by these objects are restored before their
values are used. Our training loop is the same whether starting training or
resuming from a previous checkpoint:

```python
global_step = tf.train.get_or_create_global_step()
checkpoint = tfe.Checkpoint(
    global_step=global_step, optimizer=optimizer, network=net)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(100):
  loss_fn = lambda: tf.norm(net(inp) - objective)
  optimizer.minimize(loss_fn, global_step=global_step)
  if tf.equal(global_step % 20, 0):
    print("Step %d, output %s" % (global_step.numpy(),
                                  net(inp).numpy()))
# Save the checkpoint.
checkpoint.save(checkpoint_prefix)
```
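
Note that on the very first run no checkpoint exists yet, so
`tf.train.latest_checkpoint` returns `None`, and passing `None` to `restore`
is effectively a no-op; this is what lets the same loop serve both fresh
starts and resumption. A sketch of that branch (assuming the objects defined
above):

```python
path = tf.train.latest_checkpoint(checkpoint_directory)
if path is None:
  print('No checkpoint found; variables will be initialized from scratch.')
else:
  print('Resuming from %s' % path)
checkpoint.restore(path)  # Values are applied as matching variables are created.
```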

The first time it runs, `Model` variables are initialized randomly. Then the
output is trained to match the objective we've set:

```
...
```