Improve tf.train.list_variables API doc.
PiperOrigin-RevId: 320285328 Change-Id: I52d1d61e2c181259699024a8197d5364d3ba747e
This commit is contained in:
parent
353af3ea66
commit
7f357c4447
@ -87,13 +87,27 @@ def load_variable(ckpt_dir_or_file, name):
|
|||||||
|
|
||||||
@tf_export("train.list_variables")
|
@tf_export("train.list_variables")
|
||||||
def list_variables(ckpt_dir_or_file):
|
def list_variables(ckpt_dir_or_file):
|
||||||
"""Returns list of all variables in the checkpoint.
|
"""Lists the checkpoint keys and shapes of variables in a checkpoint.
|
||||||
|
|
||||||
|
Checkpoint keys are paths in a checkpoint graph.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import tensorflow as tf
|
||||||
|
import os
|
||||||
|
ckpt_directory = "/tmp/training_checkpoints/ckpt"
|
||||||
|
ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
|
||||||
|
manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
|
||||||
|
train_and_checkpoint(model, manager)
|
||||||
|
tf.train.list_variables(manager.latest_checkpoint)
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
|
ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of tuples `(name, shape)`.
|
List of tuples `(key, shape)`.
|
||||||
"""
|
"""
|
||||||
reader = load_checkpoint(ckpt_dir_or_file)
|
reader = load_checkpoint(ckpt_dir_or_file)
|
||||||
variable_map = reader.get_variable_to_shape_map()
|
variable_map = reader.get_variable_to_shape_map()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user