Merge pull request #1661 from mozilla/remove-init-frozen-model
Remove initialize_from_frozen_model flag and support code (Fixes #1659)
commit f11ccbe39b
DeepSpeech.py
@@ -168,10 +168,6 @@ def create_flags():
tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.')

# Initialize from frozen model

tf.app.flags.DEFINE_string ('initialize_from_frozen_model', '', 'path to frozen model to initialize from. This behaves like a checkpoint, loading the weights from the frozen model and starting training with those weights. The optimizer parameters aren\'t restored, so remember to adjust the learning rate.')

FLAGS = tf.app.flags.FLAGS

def initialize_globals():
@@ -1579,26 +1575,6 @@ def train(server=None):
    saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
    hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir=FLAGS.checkpoint_dir, save_secs=FLAGS.checkpoint_secs, saver=saver))

    if len(FLAGS.initialize_from_frozen_model) > 0:
        with tf.gfile.FastGFile(FLAGS.initialize_from_frozen_model, 'rb') as fin:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(fin.read())

        var_names = [v.name for v in tf.trainable_variables()]
        var_tensors = tf.import_graph_def(graph_def, return_elements=var_names)

        # build a { var_name: var_tensor } dict
        var_tensors = dict(zip(var_names, var_tensors))

        training_graph = tf.get_default_graph()

        assign_ops = []
        for name, restored_tensor in var_tensors.items():
            training_tensor = training_graph.get_tensor_by_name(name)
            assign_ops.append(tf.assign(training_tensor, restored_tensor))

        init_from_frozen_model_op = tf.group(*assign_ops)

    no_dropout_feed_dict = {
        dropout_rates[0]: 0.,
        dropout_rates[1]: 0.,
@@ -1661,11 +1637,6 @@ def train(server=None):
                                  config=session_config) as session:
        tf.get_default_graph().finalize()

        if len(FLAGS.initialize_from_frozen_model) > 0:
            log_info('Initializing from frozen model: {}'.format(FLAGS.initialize_from_frozen_model))
            model_feeder.set_data_set(no_dropout_feed_dict, model_feeder.train)
            session.run(init_from_frozen_model_op, feed_dict=no_dropout_feed_dict)

        try:
            if is_chief:
                # Retrieving global_step from the (potentially restored) model
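For comparison, the checkpoint-based path that remains after this change needs none of the GraphDef import and `tf.assign` plumbing deleted above: a checkpoint restore covers it. Below is a minimal, self-contained sketch (not part of this patch) using the plain TF 1.x `tf.train.Saver` API; the variable, shapes, and checkpoint path are illustrative, and DeepSpeech itself drives this through `--checkpoint_dir` and its monitored training session rather than an explicit `Saver.restore` call.

```python
import tensorflow as tf

checkpoint_dir = '/tmp/ckpt'  # illustrative; DeepSpeech reads FLAGS.checkpoint_dir

# Stand-in for the real training graph; DeepSpeech builds its full model here.
weights = tf.get_variable('h1', shape=[494, 2048])

saver = tf.train.Saver()  # by default covers every saveable variable in the graph
latest = tf.train.latest_checkpoint(checkpoint_dir)

with tf.Session() as session:
    if latest is not None:
        # A single restore replaces the import_graph_def / tf.assign machinery
        # deleted above, and also brings back optimizer slots if they were saved.
        saver.restore(session, latest)
    else:
        session.run(tf.global_variables_initializer())
```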
README.md (10 changed lines)
@@ -46,7 +46,7 @@ See the output of `deepspeech -h` for more information on the use of `deepspeech
- [Checkpointing](#checkpointing)
- [Exporting a model for inference](#exporting-a-model-for-inference)
- [Distributed computing across more than one machine](#distributed-training-across-more-than-one-machine)
- [Continuing training from a frozen graph](#continuing-training-from-a-frozen-graph)
- [Continuing training from a release model](#continuing-training-from-a-release-model)
- [Code documentation](#code-documentation)
- [Contact/Getting Help](#contactgetting-help)

@@ -353,18 +353,18 @@ $ run-cluster.sh 1:2:1 --epoch 10
Be aware that for the help example to run, you need at least two `CUDA`-capable GPUs (2 workers times 1 GPU). The script uses the environment variable `CUDA_VISIBLE_DEVICES` so that `DeepSpeech.py` sees only the provided number of GPUs per worker.
The script is meant to be a template for your own distributed computing instrumentation. Just modify the startup code for the different servers (workers and parameter servers) accordingly. You could use SSH or something similar for running them on your remote hosts.

### Continuing training from a frozen graph
### Continuing training from a release model

If you'd like to use one of the pre-trained models released by Mozilla to bootstrap your training process (transfer learning, fine tuning), you can do so by using the `--initialize_from_frozen_model` flag in `DeepSpeech.py`. For best results, make sure you're passing an empty `--checkpoint_dir` when resuming from a frozen model.
If you'd like to use one of the pre-trained models released by Mozilla to bootstrap your training process (transfer learning, fine tuning), you can do so by using the `--checkpoint_dir` flag in `DeepSpeech.py`. Specify the path where you downloaded the checkpoint from the release, and training will resume from the pre-trained model.

For example, if you want to fine-tune the entire graph using your own data in `my-train.csv`, `my-dev.csv` and `my-test.csv`, for three epochs, you can run something like the following, tuning the hyperparameters as needed:

```bash
mkdir fine_tuning_checkpoints
python3 DeepSpeech.py --n_hidden 2048 --initialize_from_frozen_model path/to/model/output_graph.pb --checkpoint_dir fine_tuning_checkpoints --epoch 3 --train_files my-train.csv --dev_files my-dev.csv --test_files my-test.csv --learning_rate 0.0001
python3 DeepSpeech.py --n_hidden 2048 --checkpoint_dir path/to/checkpoint/folder --epoch -3 --train_files my-train.csv --dev_files my-dev.csv --test_files my-test.csv --learning_rate 0.0001
```

Note: the released models were trained with `--n_hidden 2048`, so you need to use that same value when initializing from the release models.
Note: the released models were trained with `--n_hidden 2048`, so you need to use that same value when initializing from the release models. Note as well the use of a negative epoch count `-3` (meaning 3 more epochs), since the checkpoint you're loading from was already trained for several epochs.
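The note above treats a negative `--epoch` value as a relative count: "3 more epochs" on top of what the restored checkpoint has already trained. A minimal sketch of that interpretation, using a hypothetical helper name rather than the actual code in `DeepSpeech.py`:

```python
def resolve_target_epoch(epoch_flag, epochs_already_trained):
    # Hypothetical helper (not the actual DeepSpeech.py implementation): a
    # negative --epoch value means "train this many additional epochs" on top
    # of the epochs recorded in the restored checkpoint; a positive value is
    # an absolute target.
    if epoch_flag < 0:
        return epochs_already_trained + abs(epoch_flag)
    return epoch_flag

# e.g. a checkpoint already trained for 30 epochs, fine-tuned with --epoch -3:
assert resolve_target_epoch(-3, 30) == 33
```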
## Code documentation
bin/run-tc-ldc93s1_checkpoint.sh (new executable file, 31 lines)
@@ -0,0 +1,31 @@
#!/bin/sh

set -xe

ldc93s1_dir="./data/ldc93s1-tc"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"

epoch_count=$1

if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;

python -u DeepSpeech.py --noshow_progressbar \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
  --test_files ${ldc93s1_csv} --test_batch_size 1 \
  --n_hidden 494 --epoch -1 --random_seed 4567 --default_stddev 0.046875 \
  --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
  --learning_rate 0.001 --dropout_rate 0.05 \
  --decoder_library_path '/tmp/ds/libctc_decoder_with_kenlm.so' \
  --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
  --lm_trie_path 'data/smoke_test/vocab.trie' | tee /tmp/resume.log

if ! grep "Training of Epoch $epoch_count" /tmp/resume.log; then
    echo "Did not resume training from checkpoint"
    exit 1
else
    exit 0
fi
bin/run-tc-ldc93s1_frozen.sh (deleted file, 23 lines)
@@ -1,23 +0,0 @@
#!/bin/sh

set -xe

ldc93s1_dir="./data/ldc93s1-tc"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"

if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;

python -u DeepSpeech.py \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
  --test_files ${ldc93s1_csv} --test_batch_size 1 \
  --n_hidden 494 --epoch 1 --random_seed 4567 --default_stddev 0.046875 \
  --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \
  --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train' \
  --nouse_seq_length --decoder_library_path '/tmp/ds/libctc_decoder_with_kenlm.so' \
  --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
  --lm_trie_path 'data/smoke_test/vocab.trie' \
  --initialize_from_frozen_model '/tmp/frozen_model.pb'
bin/run-tc-ldc93s1_new.sh
@@ -5,6 +5,8 @@ set -xe
ldc93s1_dir="./data/ldc93s1-tc"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"

epoch_count=$1

if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
@@ -14,8 +16,9 @@ python -u DeepSpeech.py \
  --train_files ${ldc93s1_csv} --train_batch_size 1 \
  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
  --test_files ${ldc93s1_csv} --test_batch_size 1 \
  --n_hidden 494 --epoch 105 --random_seed 4567 --default_stddev 0.046875 \
  --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \
  --n_hidden 494 --epoch $epoch_count --random_seed 4567 \
  --default_stddev 0.046875 --max_to_keep 1 \
  --checkpoint_dir '/tmp/ckpt' \
  --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train' \
  --decoder_library_path '/tmp/ds/libctc_decoder_with_kenlm.so' \
  --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
@@ -1,13 +0,0 @@
build:
  template_file: test-linux-opt-base.tyml
  dependencies:
    - "linux-amd64-ctc-opt"
    - "test-training_upstream-linux-amd64-py27mu-opt"
  system_setup:
    >
      apt-get -qq -y install ${python.packages_trusty.apt}
  args:
    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/tc-train-tests.sh 3.6.4:m frozen"
  metadata:
    name: "DeepSpeech Linux AMD64 CPU upstream training frozen Py3.6"
    description: "Training a DeepSpeech LDC93S1 model from frozen model file for Linux/AMD64 using upstream TensorFlow Python 3.6, CPU only, optimized version"
tc-tests-utils.sh
@@ -265,11 +265,6 @@ download_data()
  cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/vocab.trie ${TASKCLUSTER_TMP_DIR}/trie
}

download_for_frozen()
{
  wget -O "${TASKCLUSTER_TMP_DIR}/frozen_model.pb" "${DEEPSPEECH_TEST_MODEL}"
}

download_material()
{
  target_dir=$1
tc-train-tests.sh
@@ -6,7 +6,6 @@ source $(dirname "$0")/tc-tests-utils.sh

pyver_full=$1
ds=$2
frozen=$2

if [ -z "${pyver_full}" ]; then
    echo "No python version given, aborting."
@@ -62,12 +61,8 @@ else
fi;

pushd ${HOME}/DeepSpeech/ds/
if [ "${frozen}" = "frozen" ]; then
    download_for_frozen
    time ./bin/run-tc-ldc93s1_frozen.sh
else
    time ./bin/run-tc-ldc93s1_new.sh
fi;
time ./bin/run-tc-ldc93s1_new.sh 105
time ./bin/run-tc-ldc93s1_checkpoint.sh 105
popd

deactivate