Improve tf.train.list_variables API doc.

PiperOrigin-RevId: 320285328
Change-Id: I52d1d61e2c181259699024a8197d5364d3ba747e
This commit is contained in:
Pavithra Vijay 2020-07-08 16:07:24 -07:00 committed by TensorFlower Gardener
parent 353af3ea66
commit 7f357c4447

View File

@ -87,13 +87,27 @@ def load_variable(ckpt_dir_or_file, name):
@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
"""Returns list of all variables in the checkpoint.
"""Lists the checkpoint keys and shapes of variables in a checkpoint.
Checkpoint keys are paths in a checkpoint graph.
Example usage:
```python
import tensorflow as tf
import os
ckpt_directory = "/tmp/training_checkpoints/ckpt"
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
train_and_checkpoint(model, manager)
tf.train.list_variables(manager.latest_checkpoint)
```
Args:
ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
Returns:
List of tuples `(name, shape)`.
List of tuples `(key, shape)`.
"""
reader = load_checkpoint(ckpt_dir_or_file)
variable_map = reader.get_variable_to_shape_map()