Merge pull request #2763 from reuben/transfer-learning-rebase

Transfer learning support
This commit is contained in:
Reuben Morais 2020-02-17 17:29:03 +01:00 committed by GitHub
commit 200a46711a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 454 additions and 165 deletions

View File

@ -6,7 +6,8 @@ import os
import sys import sys
LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3' DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
import absl.app import absl.app
import json import json
@ -17,18 +18,25 @@ import tensorflow as tf
import tensorflow.compat.v1 as tfv1 import tensorflow.compat.v1 as tfv1
import time import time
tfv1.logging.set_verbosity({
'0': tfv1.logging.DEBUG,
'1': tfv1.logging.INFO,
'2': tfv1.logging.WARN,
'3': tfv1.logging.ERROR
}.get(DESIRED_LOG_LEVEL))
from datetime import datetime from datetime import datetime
from ds_ctcdecoder import ctc_beam_search_decoder, Scorer from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from evaluate import evaluate from evaluate import evaluate
from six.moves import zip, range from six.moves import zip, range
from tensorflow.python.tools import freeze_graph, strip_unused_lib
from tensorflow.python.framework import errors_impl
from util.config import Config, initialize_globals from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from util.flags import create_flags, FLAGS from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
check_ctcdecoder_version()
# Graph Creation # Graph Creation
# ============== # ==============
@ -222,7 +230,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
# Obtain the next batch of data # Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next() batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
if FLAGS.use_cudnn_rnn: if FLAGS.train_cudnn:
rnn_impl = rnn_impl_cudnn_rnn rnn_impl = rnn_impl_cudnn_rnn
else: else:
rnn_impl = rnn_impl_lstmblockfusedcell rnn_impl = rnn_impl_lstmblockfusedcell
@ -397,30 +405,6 @@ def log_grads_and_vars(grads_and_vars):
log_variable(variable, gradient=gradient) log_variable(variable, gradient=gradient)
def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
try:
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
if not checkpoint:
return False
checkpoint_path = checkpoint.model_checkpoint_path
saver.restore(session, checkpoint_path)
if load_step:
restored_step = session.run(tfv1.train.get_global_step())
if log_success:
log_info('Restored variables from %s checkpoint at %s, step %d' %
(caption, checkpoint_path, restored_step))
elif log_success:
log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
return True
except tf.errors.InvalidArgumentError as e:
log_error(str(e))
log_error('The checkpoint in {0} does not match the shapes of the model.'
' Did you change alphabet.txt or the --n_hidden parameter'
' between train runs using the same checkpoint dir? Try moving'
' or removing the contents of {0}.'.format(checkpoint_path))
sys.exit(1)
def train(): def train():
do_cache_dataset = True do_cache_dataset = True
@ -494,76 +478,29 @@ def train():
# Checkpointing # Checkpointing
checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep) checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train') checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
best_dev_saver = tfv1.train.Saver(max_to_keep=1) best_dev_saver = tfv1.train.Saver(max_to_keep=1)
best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev') best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
# Save flags next to checkpoints # Save flags next to checkpoints
os.makedirs(FLAGS.checkpoint_dir, exist_ok=True) os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
flags_file = os.path.join(FLAGS.checkpoint_dir, 'flags.txt')
with open(flags_file, 'w') as fout: with open(flags_file, 'w') as fout:
fout.write(FLAGS.flags_into_string()) fout.write(FLAGS.flags_into_string())
initializer = tfv1.global_variables_initializer()
with tfv1.Session(config=Config.session_config) as session: with tfv1.Session(config=Config.session_config) as session:
log_debug('Session opened.') log_debug('Session opened.')
# Loading or initializing # Prevent further graph changes
loaded = False
# Initialize training from a CuDNN RNN checkpoint
if FLAGS.cudnn_checkpoint:
if FLAGS.use_cudnn_rnn:
log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
'was specified. The --cudnn_checkpoint flag is only '
'needed when converting a CuDNN RNN checkpoint to '
'a CPU-capable graph. If your system is capable of '
'using CuDNN RNN, you can just specify the CuDNN RNN '
'checkpoint normally with --checkpoint_dir.')
sys.exit(1)
log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
missing_variables = []
# Load compatible variables from checkpoint
for v in tfv1.global_variables():
try:
v.load(ckpt.get_tensor(v.op.name), session=session)
except tf.errors.NotFoundError:
missing_variables.append(v)
# Check that the only missing variables are the Adam moment tensors
if any('Adam' not in v.op.name for v in missing_variables):
log_error('Tried to load a CuDNN RNN checkpoint but there were '
'more missing variables than just the Adam moment '
'tensors.')
sys.exit(1)
# Initialize Adam moment tensors from scratch to allow use of CuDNN
# RNN checkpoints.
log_info('Initializing missing Adam moment tensors.')
init_op = tfv1.variables_initializer(missing_variables)
session.run(init_op)
loaded = True
tfv1.get_default_graph().finalize() tfv1.get_default_graph().finalize()
if not loaded and FLAGS.load in ['auto', 'last']: # Load checkpoint or initialize variables
loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent') if FLAGS.load == 'auto':
if not loaded and FLAGS.load in ['auto', 'best']: method_order = ['best', 'last', 'init']
loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation') else:
if not loaded: method_order = [FLAGS.load]
if FLAGS.load in ['auto', 'init']: load_or_init_graph(session, method_order)
log_info('Initializing variables...')
session.run(initializer)
else:
log_error('Unable to load %s model from specified checkpoint dir'
' - consider using load option "auto" or "init".' % FLAGS.load)
sys.exit(1)
def run_set(set_name, epoch, init_op, dataset=None): def run_set(set_name, epoch, init_op, dataset=None):
is_train = set_name == 'train' is_train = set_name == 'train'
@ -682,7 +619,7 @@ def train():
def test(): def test():
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading) samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file: if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats # Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float) json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
@ -803,49 +740,40 @@ def export():
if FLAGS.export_language: if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language') outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
# Prevent further graph changes
tfv1.get_default_graph().finalize()
output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)] output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)] output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
output_names = ",".join(output_names_tensors + output_names_ops) output_names = output_names_tensors + output_names_ops
# Create a saver using variables from the above newly created graph with tf.Session() as session:
saver = tfv1.train.Saver() # Restore variables from checkpoint
if FLAGS.load == 'auto':
method_order = ['best', 'last']
else:
method_order = [FLAGS.load]
load_or_init_graph(session, method_order)
# Restore variables from training checkpoint output_filename = FLAGS.export_name + '.pb'
checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if FLAGS.remove_export:
checkpoint_path = checkpoint.model_checkpoint_path if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
output_filename = FLAGS.export_name + '.pb'
if FLAGS.remove_export:
if os.path.isdir(FLAGS.export_dir):
log_info('Removing old export')
shutil.rmtree(FLAGS.export_dir)
try:
output_graph_path = os.path.join(FLAGS.export_dir, output_filename) output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
if not os.path.isdir(FLAGS.export_dir): if not os.path.isdir(FLAGS.export_dir):
os.makedirs(FLAGS.export_dir) os.makedirs(FLAGS.export_dir)
def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''): frozen_graph = tfv1.graph_util.convert_variables_to_constants(
frozen = freeze_graph.freeze_graph_with_def_protos( sess=session,
input_graph_def=tfv1.get_default_graph().as_graph_def(), input_graph_def=tfv1.get_default_graph().as_graph_def(),
input_saver_def=saver.as_saver_def(), output_node_names=output_names)
input_checkpoint=checkpoint_path,
output_node_names=output_node_names,
restore_op_name=None,
filename_tensor_name=None,
output_graph=output_file,
clear_devices=False,
variable_names_blacklist=variables_blacklist,
initializer_nodes='')
input_node_names = [] frozen_graph = tfv1.graph_util.extract_sub_graph(
return strip_unused_lib.strip_unused( graph_def=frozen_graph,
input_graph_def=frozen, dest_nodes=output_names)
input_node_names=input_node_names,
output_node_names=output_node_names.split(','),
placeholder_type_enum=tf.float32.as_datatype_enum)
frozen_graph = do_graph_freeze(output_node_names=output_names)
if not FLAGS.export_tflite: if not FLAGS.export_tflite:
with open(output_graph_path, 'wb') as fout: with open(output_graph_path, 'wb') as fout:
@ -854,7 +782,7 @@ def export():
output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite')) output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values()) converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
converter.optimizations = [ tf.lite.Optimize.DEFAULT ] converter.optimizations = [tf.lite.Optimize.DEFAULT]
# AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
converter.allow_custom_ops = True converter.allow_custom_ops = True
tflite_model = converter.convert() tflite_model = converter.convert()
@ -862,11 +790,7 @@ def export():
with open(output_tflite_path, 'wb') as fout: with open(output_tflite_path, 'wb') as fout:
fout.write(tflite_model) fout.write(tflite_model)
log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
log_info('Models exported at %s' % (FLAGS.export_dir)) log_info('Models exported at %s' % (FLAGS.export_dir))
except RuntimeError as e:
log_error(str(e))
def package_zip(): def package_zip():
# --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
@ -887,18 +811,12 @@ def do_single_file_inference(input_file_path):
with tfv1.Session(config=Config.session_config) as session: with tfv1.Session(config=Config.session_config) as session:
inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1) inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
# Create a saver using variables from the above newly created graph
saver = tfv1.train.Saver()
# Restore variables from training checkpoint # Restore variables from training checkpoint
loaded = False if FLAGS.load == 'auto':
if not loaded and FLAGS.load in ['auto', 'last']: method_order = ['best', 'last']
loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False) else:
if not loaded and FLAGS.load in ['auto', 'best']: method_order = [FLAGS.load]
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False) load_or_init_graph(session, method_order)
if not loaded:
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
sys.exit(1)
features, features_len = audiofile_to_features(input_file_path) features, features_len = audiofile_to_features(input_file_path)
previous_state_c = np.zeros([1, Config.n_cell_dim]) previous_state_c = np.zeros([1, Config.n_cell_dim])

View File

@ -23,7 +23,7 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--learning_rate 0.001 --dropout_rate 0.05 \ --learning_rate 0.001 --dropout_rate 0.05 \
--scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log --scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log
if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then
echo "Did not resume training from checkpoint" echo "Did not resume training from checkpoint"
exit 1 exit 1
else else

126
bin/run-tc-transfer.sh Executable file
View File

@ -0,0 +1,126 @@
#!/bin/sh
# This bash script is for running minimum working examples
# of transfer learning for continuous integration tests
# to be run on Taskcluster.
set -xe
ru_dir="./data/smoke_test/russian_sample_data"
ru_csv="${ru_dir}/ru.csv"
ldc93s1_dir="./data/smoke_test"
ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
python -u bin/import_ldc93s1.py ${ldc93s1_dir}
fi;
# Force only one visible device because we have a single-sample dataset
# and when trying to run on multiple devices (like GPUs), this will break
export CUDA_VISIBLE_DEVICES=0
# Force UTF-8 output
export PYTHONIOENCODING=utf-8
echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
echo "##### while iterating over loading logic #####"
for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
echo "#### Train ENGLISH model with just --checkpoint_dir ####"
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--scorer_path '' \
--checkpoint_dir '/tmp/ckpt/transfer/eng' \
--n_hidden 100 \
--epochs 10
echo "##############################################################################"
echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
echo "#################################################################################"
echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "#################################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--drop_source_layers 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
--load 'last' \
--train_files "${ru_csv}" --train_batch_size 1 \
--dev_files "${ru_csv}" --dev_batch_size 1 \
--test_files "${ru_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
done
echo "#######################################################"
echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
echo "##### while iterating over loading logic #####"
echo "#######################################################"
for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
echo "#### Train ENGLISH model with just --checkpoint_dir ####"
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
echo "##############################################################################"
echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
echo "####################################################################################"
echo "#### Transfer to RUSSIAN model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "####################################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--drop_source_layers 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
--load 'last' \
--train_files "${ru_csv}" --train_batch_size 1 \
--dev_files "${ru_csv}" --dev_batch_size 1 \
--test_files "${ru_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
done

View File

@ -0,0 +1,34 @@
о
е
а
и
н
т
с
л
в
р
к
м
д
п
ы
у
б
я
ь
г
з
ч
й
ж
х
ш
ю
ц
э
щ
ф
ё
ъ

View File

@ -0,0 +1,2 @@
wav_filename,wav_filesize,transcript
ru.wav,0,бедняга ребят на его месте должен был быть я
1 wav_filename wav_filesize transcript
2 ru.wav 0 бедняга ребят на его месте должен был быть я

Binary file not shown.

View File

@ -16,12 +16,14 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip from six.moves import zip
from util.config import Config, initialize_globals from util.config import Config, initialize_globals
from util.checkpoints import load_or_init_graph
from util.evaluate_tools import calculate_and_print_report from util.evaluate_tools import calculate_and_print_report
from util.feeding import create_dataset from util.feeding import create_dataset
from util.flags import create_flags, FLAGS from util.flags import create_flags, FLAGS
from util.helpers import check_ctcdecoder_version
from util.logging import create_progressbar, log_error, log_progress from util.logging import create_progressbar, log_error, log_progress
from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
check_ctcdecoder_version()
def sparse_tensor_value_to_texts(value, alphabet): def sparse_tensor_value_to_texts(value, alphabet):
r""" r"""
@ -41,7 +43,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
return [alphabet.decode(res) for res in results] return [alphabet.decode(res) for res in results]
def evaluate(test_csvs, create_model, try_loading): def evaluate(test_csvs, create_model):
if FLAGS.scorer_path: if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.scorer_path, Config.alphabet) FLAGS.scorer_path, Config.alphabet)
@ -79,19 +81,12 @@ def evaluate(test_csvs, create_model, try_loading):
except NotImplementedError: except NotImplementedError:
num_processes = 1 num_processes = 1
# Create a saver using variables from the above newly created graph
saver = tfv1.train.Saver()
with tfv1.Session(config=Config.session_config) as session: with tfv1.Session(config=Config.session_config) as session:
# Restore variables from training checkpoint if FLAGS.load == 'auto':
loaded = False method_order = ['best', 'last']
if not loaded and FLAGS.load in ['auto', 'best']: else:
loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation') method_order = [FLAGS.load]
if not loaded and FLAGS.load in ['auto', 'last']: load_or_init_graph(session, method_order)
loaded = try_loading(session, saver, 'checkpoint', 'most recent')
if not loaded:
print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
sys.exit(1)
def run_test(init_op, dataset): def run_test(init_op, dataset):
wav_filenames = [] wav_filenames = []
@ -148,8 +143,8 @@ def main(_):
'the --test_files flag.') 'the --test_files flag.')
sys.exit(1) sys.exit(1)
from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading) samples = evaluate(FLAGS.test_files.split(','), create_model)
if FLAGS.test_output_file: if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats # Save decoded tuples as JSON, converting NumPy floats to Python floats

View File

@ -0,0 +1,65 @@
#!/bin/bash
set -xe
source $(dirname "$0")/tc-tests-utils.sh
pyver_full=$1
if [ -z "${pyver_full}" ]; then
echo "No python version given, aborting."
exit 1
fi;
pyver=$(echo "${pyver_full}" | cut -d':' -f1)
# 2.7.x => 27
pyver_pkg=$(echo "${pyver}" | cut -d'.' -f1,2 | tr -d '.')
py_unicode_type=$(echo "${pyver_full}" | cut -d':' -f2)
if [ "${py_unicode_type}" = "m" ]; then
pyconf="ucs2"
elif [ "${py_unicode_type}" = "mu" ]; then
pyconf="ucs4"
fi;
unset PYTHON_BIN_PATH
unset PYTHONPATH
export PYENV_ROOT="${HOME}/ds-train/.pyenv"
export PATH="${PYENV_ROOT}/bin:${HOME}/bin:$PATH"
mkdir -p ${PYENV_ROOT} || true
mkdir -p ${TASKCLUSTER_ARTIFACTS} || true
mkdir -p /tmp/train || true
mkdir -p /tmp/train_tflite || true
install_pyenv "${PYENV_ROOT}"
install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv"
PYENV_NAME=deepspeech-train
PYTHON_CONFIGURE_OPTS="--enable-unicode=${pyconf}" pyenv install ${pyver}
pyenv virtualenv ${pyver} ${PYENV_NAME}
source ${PYENV_ROOT}/versions/${pyver}/envs/${PYENV_NAME}/bin/activate
set -o pipefail
pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
set +o pipefail
pushd ${HOME}/DeepSpeech/ds/
verify_ctcdecoder_url
popd
platform=$(python -c 'import sys; import platform; plat = platform.system().lower(); arch = platform.machine().lower(); plat = "manylinux1" if plat == "linux" and arch == "x86_64" else plat; plat = "macosx_10_10" if plat == "darwin" else plat; sys.stdout.write("%s_%s" % (plat, platform.machine()));')
whl_ds_version="$(python -c 'from pkg_resources import parse_version; print(parse_version("'${DS_VERSION}'"))')"
decoder_pkg="ds_ctcdecoder-${whl_ds_version}-cp${pyver_pkg}-cp${pyver_pkg}${py_unicode_type}-${platform}.whl"
decoder_pkg_url=${DECODER_ARTIFACTS_ROOT}/${decoder_pkg}
LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} ${decoder_pkg_url} | cat
pushd ${HOME}/DeepSpeech/ds/
time ./bin/run-tc-transfer.sh
popd
deactivate

View File

@ -0,0 +1,12 @@
build:
template_file: test-linux-opt-base.tyml
dependencies:
- "linux-amd64-ctc-opt"
system_setup:
>
apt-get -qq -y install ${python.packages_trusty.apt}
args:
tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-transfer-tests.sh 3.6.4:m"
metadata:
name: "DeepSpeech Linux AMD64 CPU transfer learning Py3.6"
description: "Training a DeepSpeech LDC93S1 model with transfer learning for Linux/AMD64 16kHz Python 3.6, CPU only, optimized version"

115
util/checkpoints.py Normal file
View File

@ -0,0 +1,115 @@
import sys
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from util.flags import FLAGS
from util.logging import log_info, log_error, log_warn
def _load_checkpoint(session, checkpoint_path):
# Load the checkpoint and put all variables into loading list
# we will exclude variables we do not wish to load and then
# we will initialize them instead
ckpt = tfv1.train.load_checkpoint(checkpoint_path)
load_vars = set(tfv1.global_variables())
init_vars = set()
if FLAGS.load_cudnn:
# Initialize training from a CuDNN RNN checkpoint
# Identify the variables which we cannot load, and set them
# for initialization
for v in load_vars:
try:
ckpt.get_tensor(v.op.name)
except tf.errors.NotFoundError:
log_error('CUDNN variable not found: %s' % (v.op.name))
init_vars.add(v)
load_vars -= init_vars
# Check that the only missing variables (i.e. those to be initialised)
# are the Adam moment tensors, if they aren't then we have an issue
init_var_names = [v.op.name for v in init_vars]
if any('Adam' not in v for v in init_var_names):
log_error('Tried to load a CuDNN RNN checkpoint but there were '
'more missing variables than just the Adam moment '
'tensors. Missing variables: {}'.format(init_var_names))
sys.exit(1)
if FLAGS.drop_source_layers > 0:
# This transfer learning approach requires supplying
# the layers which we exclude from the source model.
# Say we want to exclude all layers except for the first one,
# then we are dropping five layers total, so: drop_source_layers=5
# If we want to use all layers from the source model except
# the last one, we use this: drop_source_layers=1
if FLAGS.drop_source_layers >= 6:
log_warn('The checkpoint only has 6 layers, but you are trying to drop '
'all of them or more than all of them. Continuing and '
'dropping only 5 layers.')
FLAGS.drop_source_layers = 5
dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
# Initialize all variables needed for DS, but not loaded from ckpt
for v in load_vars:
if any(layer in v.op.name for layer in dropped_layers):
init_vars.add(v)
load_vars -= init_vars
for v in sorted(load_vars, key=lambda v: v.op.name):
log_info('Loading variable from checkpoint: %s' % (v.op.name))
v.load(ckpt.get_tensor(v.op.name), session=session)
for v in sorted(init_vars, key=lambda v: v.op.name):
log_info('Initializing variable: %s' % (v.op.name))
session.run(v.initializer)
def _checkpoint_path_or_none(checkpoint_filename):
checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
if not checkpoint:
return None
return checkpoint.model_checkpoint_path
def _initialize_all_variables(session):
init_vars = tfv1.global_variables()
for v in init_vars:
session.run(v.initializer)
def load_or_init_graph(session, method_order):
'''
Load variables from checkpoint or initialize variables following the method
order specified in the method_order parameter.
Valid methods are 'best', 'last' and 'init'.
'''
for method in method_order:
# Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
if method == 'best':
ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
if ckpt_path:
log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path)
log_info('Could not find best validating checkpoint.')
# Load most recent checkpoint, saved in checkpoint file 'checkpoint'
elif method == 'last':
ckpt_path = _checkpoint_path_or_none('checkpoint')
if ckpt_path:
log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
return _load_checkpoint(session, ckpt_path)
log_info('Could not find most recent checkpoint.')
# Initialize all variables
elif method == 'init':
log_info('Initializing all variables.')
return _initialize_all_variables(session)
else:
log_error('Unknown initialization method: {}'.format(method))
sys.exit(1)
log_error('All initialization methods failed ({}).'.format(method_order))
sys.exit(1)

View File

@ -124,4 +124,21 @@ def initialize_globals():
log_error('Path specified in --one_shot_infer is not a valid file.') log_error('Path specified in --one_shot_infer is not a valid file.')
sys.exit(1) sys.exit(1)
if FLAGS.train_cudnn and FLAGS.load_cudnn:
log_error('Trying to use --train_cudnn, but --load_cudnn '
'was also specified. The --load_cudnn flag is only '
'needed when converting a CuDNN RNN checkpoint to '
'a CPU-capable graph. If your system is capable of '
'using CuDNN RNN, you can just specify the CuDNN RNN '
'checkpoint normally with --save_checkpoint_dir.')
sys.exit(1)
# If separate save and load flags were not specified, default to load and save
# from the same dir.
if not FLAGS.save_checkpoint_dir:
FLAGS.save_checkpoint_dir = FLAGS.checkpoint_dir
if not FLAGS.load_checkpoint_dir:
FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir
ConfigSingleton._config = c # pylint: disable=protected-access ConfigSingleton._config = c # pylint: disable=protected-access

View File

@ -49,7 +49,6 @@ def create_flags():
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling') f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling') f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling')
# Global Constants # Global Constants
# ================ # ================
@ -84,9 +83,8 @@ def create_flags():
f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED') f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory') f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory')
f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work') f.DEFINE_boolean('load_cudnn', False, 'Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.') f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.') f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
# Sample limits # Sample limits
@ -97,10 +95,16 @@ def create_flags():
# Checkpointing # Checkpointing
f.DEFINE_string('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification') f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_string('load_checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds') f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5') f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init') f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a fresh model, "transfer" for transfer learning, "auto" for trying several options.')
# Transfer Learning
f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
# Exporting # Exporting
@ -115,7 +119,7 @@ def create_flags():
# Reporting # Reporting
f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL') f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR')
f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.') f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console') f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')

View File

@ -1,3 +1,7 @@
import os
import semver
import sys
def keep_only_digits(txt): def keep_only_digits(txt):
return ''.join(filter(lambda c: c.isdigit(), txt)) return ''.join(filter(lambda c: c.isdigit(), txt))
@ -8,15 +12,12 @@ def secs_to_hours(secs):
minutes, seconds = divmod(remainder, 60) minutes, seconds = divmod(remainder, 60)
return '%d:%02d:%02d' % (hours, minutes, seconds) return '%d:%02d:%02d' % (hours, minutes, seconds)
# pylint: disable=import-outside-toplevel
def check_ctcdecoder_version():
import sys
import os
import semver
def check_ctcdecoder_version():
ds_version_s = open(os.path.join(os.path.dirname(__file__), '../VERSION')).read().strip() ds_version_s = open(os.path.join(os.path.dirname(__file__), '../VERSION')).read().strip()
try: try:
# pylint: disable=import-outside-toplevel
from ds_ctcdecoder import __version__ as decoder_version from ds_ctcdecoder import __version__ as decoder_version
except ImportError as e: except ImportError as e:
if e.msg.find('__version__') > 0: if e.msg.find('__version__') > 0: