Improve tf.train.list_variables API doc.
PiperOrigin-RevId: 320285328 Change-Id: I52d1d61e2c181259699024a8197d5364d3ba747e
This commit is contained in:
parent
353af3ea66
commit
7f357c4447
@ -87,13 +87,27 @@ def load_variable(ckpt_dir_or_file, name):
|
||||
|
||||
@tf_export("train.list_variables")
|
||||
def list_variables(ckpt_dir_or_file):
|
||||
"""Returns list of all variables in the checkpoint.
|
||||
"""Lists the checkpoint keys and shapes of variables in a checkpoint.
|
||||
|
||||
Checkpoint keys are paths in a checkpoint graph.
|
||||
|
||||
Example usage:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
import os
|
||||
ckpt_directory = "/tmp/training_checkpoints/ckpt"
|
||||
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
|
||||
manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
|
||||
train_and_checkpoint(model, manager)
|
||||
tf.train.list_variables(manager.latest_checkpoint)
|
||||
```
|
||||
|
||||
Args:
|
||||
ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
|
||||
|
||||
Returns:
|
||||
List of tuples `(name, shape)`.
|
||||
List of tuples `(key, shape)`.
|
||||
"""
|
||||
reader = load_checkpoint(ckpt_dir_or_file)
|
||||
variable_map = reader.get_variable_to_shape_map()
|
||||
|
Loading…
x
Reference in New Issue
Block a user