Merge pull request #2763 from reuben/transfer-learning-rebase
Transfer learning support
commit 200a46711a

DeepSpeech.py (176 lines changed)
@@ -6,7 +6,8 @@ import os
 import sys

 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
+DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL

 import absl.app
 import json
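A quick self-contained illustration of the pre-absl argv scan above, which exists because TF_CPP_MIN_LOG_LEVEL only takes effect when set before TensorFlow is imported (plain Python, simulated argv):

    import sys

    # Simulate an invocation such as: python DeepSpeech.py --log_level 2
    sys.argv = ['DeepSpeech.py', '--log_level', '2']
    LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
    DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
    print(DESIRED_LOG_LEVEL)  # -> '2'; defaults to '3' when the flag is absent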
@@ -17,18 +18,25 @@ import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
 import time

+tfv1.logging.set_verbosity({
+    '0': tfv1.logging.DEBUG,
+    '1': tfv1.logging.INFO,
+    '2': tfv1.logging.WARN,
+    '3': tfv1.logging.ERROR
+}.get(DESIRED_LOG_LEVEL))
+
 from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
 from evaluate import evaluate
 from six.moves import zip, range
 from tensorflow.python.tools import freeze_graph, strip_unused_lib
 from tensorflow.python.framework import errors_impl
 from util.config import Config, initialize_globals
+from util.checkpoints import load_or_init_graph
 from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
+from util.helpers import check_ctcdecoder_version
 from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
-from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
+
+check_ctcdecoder_version()

 # Graph Creation
 # ==============
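One subtlety in the verbosity mapping above: dict.get falls back to None for any value outside '0'..'3', so an out-of-range --log_level would hand None to set_verbosity. A minimal sketch of just the dict behavior (no TensorFlow needed):

    levels = {'0': 'DEBUG', '1': 'INFO', '2': 'WARN', '3': 'ERROR'}
    print(levels.get('2'))  # -> WARN
    print(levels.get('9'))  # -> None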
@@ -222,7 +230,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     # Obtain the next batch of data
     batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()

-    if FLAGS.use_cudnn_rnn:
+    if FLAGS.train_cudnn:
         rnn_impl = rnn_impl_cudnn_rnn
     else:
         rnn_impl = rnn_impl_lstmblockfusedcell
@@ -397,30 +405,6 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)


-def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
-    try:
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
-        if not checkpoint:
-            return False
-        checkpoint_path = checkpoint.model_checkpoint_path
-        saver.restore(session, checkpoint_path)
-        if load_step:
-            restored_step = session.run(tfv1.train.get_global_step())
-            if log_success:
-                log_info('Restored variables from %s checkpoint at %s, step %d' %
-                         (caption, checkpoint_path, restored_step))
-        elif log_success:
-            log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
-        return True
-    except tf.errors.InvalidArgumentError as e:
-        log_error(str(e))
-        log_error('The checkpoint in {0} does not match the shapes of the model.'
-                  ' Did you change alphabet.txt or the --n_hidden parameter'
-                  ' between train runs using the same checkpoint dir? Try moving'
-                  ' or removing the contents of {0}.'.format(checkpoint_path))
-        sys.exit(1)
-
-
 def train():
     do_cache_dataset = True
@@ -494,76 +478,29 @@ def train():

     # Checkpointing
     checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
-    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
+    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')

     best_dev_saver = tfv1.train.Saver(max_to_keep=1)
-    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
+    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')

     # Save flags next to checkpoints
-    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
-    flags_file = os.path.join(FLAGS.checkpoint_dir, 'flags.txt')
+    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
+    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
     with open(flags_file, 'w') as fout:
         fout.write(FLAGS.flags_into_string())

-    initializer = tfv1.global_variables_initializer()
-
     with tfv1.Session(config=Config.session_config) as session:
         log_debug('Session opened.')

-        # Loading or initializing
-        loaded = False
-
-        # Initialize training from a CuDNN RNN checkpoint
-        if FLAGS.cudnn_checkpoint:
-            if FLAGS.use_cudnn_rnn:
-                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
-                          'was specified. The --cudnn_checkpoint flag is only '
-                          'needed when converting a CuDNN RNN checkpoint to '
-                          'a CPU-capable graph. If your system is capable of '
-                          'using CuDNN RNN, you can just specify the CuDNN RNN '
-                          'checkpoint normally with --checkpoint_dir.')
-                sys.exit(1)
-
-            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
-            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
-            missing_variables = []
-
-            # Load compatible variables from checkpoint
-            for v in tfv1.global_variables():
-                try:
-                    v.load(ckpt.get_tensor(v.op.name), session=session)
-                except tf.errors.NotFoundError:
-                    missing_variables.append(v)
-
-            # Check that the only missing variables are the Adam moment tensors
-            if any('Adam' not in v.op.name for v in missing_variables):
-                log_error('Tried to load a CuDNN RNN checkpoint but there were '
-                          'more missing variables than just the Adam moment '
-                          'tensors.')
-                sys.exit(1)
-
-            # Initialize Adam moment tensors from scratch to allow use of CuDNN
-            # RNN checkpoints.
-            log_info('Initializing missing Adam moment tensors.')
-            init_op = tfv1.variables_initializer(missing_variables)
-            session.run(init_op)
-            loaded = True
-
         # Prevent further graph changes
         tfv1.get_default_graph().finalize()

-        if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
-        if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded:
-            if FLAGS.load in ['auto', 'init']:
-                log_info('Initializing variables...')
-                session.run(initializer)
-            else:
-                log_error('Unable to load %s model from specified checkpoint dir'
-                          ' - consider using load option "auto" or "init".' % FLAGS.load)
-                sys.exit(1)
+        # Load checkpoint or initialize variables
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last', 'init']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)

         def run_set(set_name, epoch, init_op, dataset=None):
             is_train = set_name == 'train'
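The same load-or-init pattern replaces the ad-hoc try_loading calls at every call site: train() here, and test(), export() and do_single_file_inference() below, which pass ['best', 'last'] without 'init'. A minimal sketch of the sequence, assuming the graph has already been built; load_or_init_graph comes from the new util/checkpoints.py shown near the end of this diff:

    import tensorflow.compat.v1 as tfv1
    from util.checkpoints import load_or_init_graph
    from util.flags import FLAGS

    with tfv1.Session() as session:
        # 'auto' falls back through the listed methods in order; only
        # training includes 'init', so inference never runs on random weights.
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph(session, method_order)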
@@ -682,7 +619,7 @@ def train():


 def test():
-    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    samples = evaluate(FLAGS.test_files.split(','), create_model)
     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
         json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
@@ -803,49 +740,40 @@ def export():
     if FLAGS.export_language:
         outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')

+    # Prevent further graph changes
+    tfv1.get_default_graph().finalize()
+
     output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
     output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
-    output_names = ",".join(output_names_tensors + output_names_ops)
+    output_names = output_names_tensors + output_names_ops

-    # Create a saver using variables from the above newly created graph
-    saver = tfv1.train.Saver()
-
-    # Restore variables from training checkpoint
-    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-    checkpoint_path = checkpoint.model_checkpoint_path
+    with tf.Session() as session:
+        # Restore variables from checkpoint
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)

         output_filename = FLAGS.export_name + '.pb'
         if FLAGS.remove_export:
             if os.path.isdir(FLAGS.export_dir):
                 log_info('Removing old export')
                 shutil.rmtree(FLAGS.export_dir)
-    try:

         output_graph_path = os.path.join(FLAGS.export_dir, output_filename)

         if not os.path.isdir(FLAGS.export_dir):
             os.makedirs(FLAGS.export_dir)

-        def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
-            frozen = freeze_graph.freeze_graph_with_def_protos(
-                input_graph_def=tfv1.get_default_graph().as_graph_def(),
-                input_saver_def=saver.as_saver_def(),
-                input_checkpoint=checkpoint_path,
-                output_node_names=output_node_names,
-                restore_op_name=None,
-                filename_tensor_name=None,
-                output_graph=output_file,
-                clear_devices=False,
-                variable_names_blacklist=variables_blacklist,
-                initializer_nodes='')
-
-            input_node_names = []
-            return strip_unused_lib.strip_unused(
-                input_graph_def=frozen,
-                input_node_names=input_node_names,
-                output_node_names=output_node_names.split(','),
-                placeholder_type_enum=tf.float32.as_datatype_enum)
-
-        frozen_graph = do_graph_freeze(output_node_names=output_names)
+        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
+            sess=session,
+            input_graph_def=tfv1.get_default_graph().as_graph_def(),
+            output_node_names=output_names)
+        frozen_graph = tfv1.graph_util.extract_sub_graph(
+            graph_def=frozen_graph,
+            dest_nodes=output_names)

         if not FLAGS.export_tflite:
             with open(output_graph_path, 'wb') as fout:
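For context, the new export path freezes variables straight out of the live session instead of round-tripping through freeze_graph and a checkpoint on disk. A minimal toy sketch of the same two calls, using a throwaway graph rather than the DeepSpeech model (TF1-compat API):

    import tensorflow.compat.v1 as tfv1
    tfv1.disable_eager_execution()

    x = tfv1.placeholder(tfv1.float32, name='x')
    w = tfv1.Variable(2.0, name='w')
    y = tfv1.multiply(x, w, name='y')

    with tfv1.Session() as session:
        session.run(tfv1.global_variables_initializer())
        # Bake current variable values into constants...
        frozen = tfv1.graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=tfv1.get_default_graph().as_graph_def(),
            output_node_names=['y'])
        # ...then prune anything not needed to compute the outputs.
        pruned = tfv1.graph_util.extract_sub_graph(graph_def=frozen, dest_nodes=['y'])
        print(len(pruned.node))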
@@ -862,11 +790,7 @@ def export():
         with open(output_tflite_path, 'wb') as fout:
             fout.write(tflite_model)

-            log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
-
         log_info('Models exported at %s' % (FLAGS.export_dir))
-    except RuntimeError as e:
-        log_error(str(e))


 def package_zip():
     # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
@@ -887,18 +811,12 @@ def do_single_file_inference(input_file_path):
     with tfv1.Session(config=Config.session_config) as session:
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

-        # Create a saver using variables from the above newly created graph
-        saver = tfv1.train.Saver()
-
-        # Restore variables from training checkpoint
-        loaded = False
-        if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
-        if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
-        if not loaded:
-            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
-            sys.exit(1)
+        # Restore variables from checkpoint
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)

         features, features_len = audiofile_to_features(input_file_path)
         previous_state_c = np.zeros([1, Config.n_cell_dim])
(checkpoint-resume smoke-test script; filename not shown)
@@ -23,7 +23,7 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
     --learning_rate 0.001 --dropout_rate 0.05 \
     --scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log

-if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then
+if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then
     echo "Did not resume training from checkpoint"
     exit 1
 else
(new file; likely bin/run-tc-transfer.sh, invoked by the Taskcluster test script further down)
@@ -0,0 +1,126 @@
+#!/bin/sh
+# This bash script is for running minimum working examples
+# of transfer learning for continuous integration tests
+# to be run on Taskcluster.
+set -xe
+
+ru_dir="./data/smoke_test/russian_sample_data"
+ru_csv="${ru_dir}/ru.csv"
+
+ldc93s1_dir="./data/smoke_test"
+ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
+    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
+fi;
+
+# Force only one visible device because we have a single-sample dataset
+# and when trying to run on multiple devices (like GPUs), this will break
+export CUDA_VISIBLE_DEVICES=0
+
+# Force UTF-8 output
+export PYTHONIOENCODING=utf-8
+
+echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
+echo "##### while iterating over loading logic #####"
+
+for LOAD in 'init' 'last' 'auto'; do
+  echo "########################################################"
+  echo "#### Train ENGLISH model with just --checkpoint_dir ####"
+  echo "########################################################"
+  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --alphabet_config_path "./data/alphabet.txt" \
+    --load "$LOAD" \
+    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+    --scorer_path '' \
+    --checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --n_hidden 100 \
+    --epochs 10
+
+  echo "##############################################################################"
+  echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
+  echo "##############################################################################"
+  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --alphabet_config_path "./data/alphabet.txt" \
+    --load "$LOAD" \
+    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+    --save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --scorer_path '' \
+    --n_hidden 100 \
+    --epochs 10
+
+  echo "#################################################################################"
+  echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
+  echo "#################################################################################"
+  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --drop_source_layers 1 \
+    --alphabet_config_path "${ru_dir}/alphabet.ru" \
+    --load 'last' \
+    --train_files "${ru_csv}" --train_batch_size 1 \
+    --dev_files "${ru_csv}" --dev_batch_size 1 \
+    --test_files "${ru_csv}" --test_batch_size 1 \
+    --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
+    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --scorer_path '' \
+    --n_hidden 100 \
+    --epochs 10
+done
+
+echo "#######################################################"
+echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
+echo "##### while iterating over loading logic #####"
+echo "#######################################################"
+
+for LOAD in 'init' 'last' 'auto'; do
+  echo "########################################################"
+  echo "#### Train ENGLISH model with just --checkpoint_dir ####"
+  echo "########################################################"
+  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --alphabet_config_path "./data/alphabet.txt" \
+    --load "$LOAD" \
+    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+    --checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --scorer_path '' \
+    --n_hidden 100 \
+    --epochs 10
+
+
+  echo "##############################################################################"
+  echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
+  echo "##############################################################################"
+  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --alphabet_config_path "./data/alphabet.txt" \
+    --load "$LOAD" \
+    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+    --save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --scorer_path '' \
+    --n_hidden 100 \
+    --epochs 10
+
+  echo "####################################################################################"
+  echo "#### Transfer to RUSSIAN model with --save_checkpoint_dir --load_checkpoint_dir ####"
+  echo "####################################################################################"
+  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+    --drop_source_layers 1 \
+    --alphabet_config_path "${ru_dir}/alphabet.ru" \
+    --load 'last' \
+    --train_files "${ru_csv}" --train_batch_size 1 \
+    --dev_files "${ru_csv}" --dev_batch_size 1 \
+    --test_files "${ru_csv}" --test_batch_size 1 \
+    --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
+    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+    --scorer_path '' \
+    --n_hidden 100 \
+    --epochs 10
+done
(new file; likely ${ru_dir}/alphabet.ru, referenced above; the first entry is the space character)
@@ -0,0 +1,34 @@
+ 
+о
+е
+а
+и
+н
+т
+с
+л
+в
+р
+к
+м
+д
+п
+ы
+у
+б
+я
+ь
+г
+з
+ч
+й
+ж
+х
+ш
+ю
+ц
+э
+щ
+ф
+ё
+ъ
(new file; likely ${ru_dir}/ru.csv, referenced above)
@@ -0,0 +1,2 @@
+wav_filename,wav_filesize,transcript
+ru.wav,0,бедняга ребят на его месте должен был быть я
Binary file not shown (likely the ru.wav sample referenced by ru.csv above).

evaluate.py (27 lines changed)
@@ -16,12 +16,14 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip

 from util.config import Config, initialize_globals
+from util.checkpoints import load_or_init_graph
 from util.evaluate_tools import calculate_and_print_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
+from util.helpers import check_ctcdecoder_version
 from util.logging import create_progressbar, log_error, log_progress
-from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
+
+check_ctcdecoder_version()


 def sparse_tensor_value_to_texts(value, alphabet):
     r"""
@@ -41,7 +43,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
     return [alphabet.decode(res) for res in results]


-def evaluate(test_csvs, create_model, try_loading):
+def evaluate(test_csvs, create_model):
     if FLAGS.scorer_path:
         scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                         FLAGS.scorer_path, Config.alphabet)
@@ -79,19 +81,12 @@ def evaluate(test_csvs, create_model, try_loading):
     except NotImplementedError:
         num_processes = 1

-    # Create a saver using variables from the above newly created graph
-    saver = tfv1.train.Saver()
-
     with tfv1.Session(config=Config.session_config) as session:
-        # Restore variables from training checkpoint
-        loaded = False
-        if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
-        if not loaded:
-            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
-            sys.exit(1)
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)

         def run_test(init_op, dataset):
             wav_filenames = []
@@ -148,8 +143,8 @@ def main(_):
                   'the --test_files flag.')
         sys.exit(1)

-    from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
-    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
+    samples = evaluate(FLAGS.test_files.split(','), create_model)

     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
(new file; likely taskcluster/tc-transfer-tests.sh, referenced by the task definition below)
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+set -xe
+
+source $(dirname "$0")/tc-tests-utils.sh
+
+pyver_full=$1
+
+if [ -z "${pyver_full}" ]; then
+    echo "No python version given, aborting."
+    exit 1
+fi;
+
+pyver=$(echo "${pyver_full}" | cut -d':' -f1)
+
+# 2.7.x => 27
+pyver_pkg=$(echo "${pyver}" | cut -d'.' -f1,2 | tr -d '.')
+
+py_unicode_type=$(echo "${pyver_full}" | cut -d':' -f2)
+if [ "${py_unicode_type}" = "m" ]; then
+    pyconf="ucs2"
+elif [ "${py_unicode_type}" = "mu" ]; then
+    pyconf="ucs4"
+fi;
+
+unset PYTHON_BIN_PATH
+unset PYTHONPATH
+export PYENV_ROOT="${HOME}/ds-train/.pyenv"
+export PATH="${PYENV_ROOT}/bin:${HOME}/bin:$PATH"
+
+mkdir -p ${PYENV_ROOT} || true
+mkdir -p ${TASKCLUSTER_ARTIFACTS} || true
+mkdir -p /tmp/train || true
+mkdir -p /tmp/train_tflite || true
+
+install_pyenv "${PYENV_ROOT}"
+install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv"
+
+PYENV_NAME=deepspeech-train
+PYTHON_CONFIGURE_OPTS="--enable-unicode=${pyconf}" pyenv install ${pyver}
+pyenv virtualenv ${pyver} ${PYENV_NAME}
+source ${PYENV_ROOT}/versions/${pyver}/envs/${PYENV_NAME}/bin/activate
+
+set -o pipefail
+pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
+pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
+set +o pipefail
+
+pushd ${HOME}/DeepSpeech/ds/
+verify_ctcdecoder_url
+popd
+
+platform=$(python -c 'import sys; import platform; plat = platform.system().lower(); arch = platform.machine().lower(); plat = "manylinux1" if plat == "linux" and arch == "x86_64" else plat; plat = "macosx_10_10" if plat == "darwin" else plat; sys.stdout.write("%s_%s" % (plat, platform.machine()));')
+whl_ds_version="$(python -c 'from pkg_resources import parse_version; print(parse_version("'${DS_VERSION}'"))')"
+decoder_pkg="ds_ctcdecoder-${whl_ds_version}-cp${pyver_pkg}-cp${pyver_pkg}${py_unicode_type}-${platform}.whl"
+
+decoder_pkg_url=${DECODER_ARTIFACTS_ROOT}/${decoder_pkg}
+
+LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} ${decoder_pkg_url} | cat
+
+pushd ${HOME}/DeepSpeech/ds/
+time ./bin/run-tc-transfer.sh
+popd
+
+deactivate
(new file; a Taskcluster task definition, filename not shown)
@@ -0,0 +1,12 @@
+build:
+  template_file: test-linux-opt-base.tyml
+  dependencies:
+    - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq -y install ${python.packages_trusty.apt}
+  args:
+    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-transfer-tests.sh 3.6.4:m"
+  metadata:
+    name: "DeepSpeech Linux AMD64 CPU transfer learning Py3.6"
+    description: "Training a DeepSpeech LDC93S1 model with transfer learning for Linux/AMD64 16kHz Python 3.6, CPU only, optimized version"
(new file; likely util/checkpoints.py, imported above as util.checkpoints)
@@ -0,0 +1,115 @@
+import sys
+
+import tensorflow as tf
+import tensorflow.compat.v1 as tfv1
+
+from util.flags import FLAGS
+from util.logging import log_info, log_error, log_warn
+
+
+def _load_checkpoint(session, checkpoint_path):
+    # Load the checkpoint and put all variables into loading list
+    # we will exclude variables we do not wish to load and then
+    # we will initialize them instead
+    ckpt = tfv1.train.load_checkpoint(checkpoint_path)
+    load_vars = set(tfv1.global_variables())
+    init_vars = set()
+
+    if FLAGS.load_cudnn:
+        # Initialize training from a CuDNN RNN checkpoint
+        # Identify the variables which we cannot load, and set them
+        # for initialization
+        for v in load_vars:
+            try:
+                ckpt.get_tensor(v.op.name)
+            except tf.errors.NotFoundError:
+                log_error('CUDNN variable not found: %s' % (v.op.name))
+                init_vars.add(v)
+
+        load_vars -= init_vars
+
+        # Check that the only missing variables (i.e. those to be initialised)
+        # are the Adam moment tensors, if they aren't then we have an issue
+        init_var_names = [v.op.name for v in init_vars]
+        if any('Adam' not in v for v in init_var_names):
+            log_error('Tried to load a CuDNN RNN checkpoint but there were '
+                      'more missing variables than just the Adam moment '
+                      'tensors. Missing variables: {}'.format(init_var_names))
+            sys.exit(1)
+
+    if FLAGS.drop_source_layers > 0:
+        # This transfer learning approach requires supplying
+        # the layers which we exclude from the source model.
+        # Say we want to exclude all layers except for the first one,
+        # then we are dropping five layers total, so: drop_source_layers=5
+        # If we want to use all layers from the source model except
+        # the last one, we use this: drop_source_layers=1
+        if FLAGS.drop_source_layers >= 6:
+            log_warn('The checkpoint only has 6 layers, but you are trying to drop '
+                     'all of them or more than all of them. Continuing and '
+                     'dropping only 5 layers.')
+            FLAGS.drop_source_layers = 5
+
+        dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
+        # Initialize all variables needed for DS, but not loaded from ckpt
+        for v in load_vars:
+            if any(layer in v.op.name for layer in dropped_layers):
+                init_vars.add(v)
+        load_vars -= init_vars
+
+    for v in sorted(load_vars, key=lambda v: v.op.name):
+        log_info('Loading variable from checkpoint: %s' % (v.op.name))
+        v.load(ckpt.get_tensor(v.op.name), session=session)
+
+    for v in sorted(init_vars, key=lambda v: v.op.name):
+        log_info('Initializing variable: %s' % (v.op.name))
+        session.run(v.initializer)
+
+
+def _checkpoint_path_or_none(checkpoint_filename):
+    checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
+    if not checkpoint:
+        return None
+    return checkpoint.model_checkpoint_path
+
+
+def _initialize_all_variables(session):
+    init_vars = tfv1.global_variables()
+    for v in init_vars:
+        session.run(v.initializer)
+
+
+def load_or_init_graph(session, method_order):
+    '''
+    Load variables from checkpoint or initialize variables following the method
+    order specified in the method_order parameter.
+
+    Valid methods are 'best', 'last' and 'init'.
+    '''
+    for method in method_order:
+        # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
+        if method == 'best':
+            ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
+            if ckpt_path:
+                log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
+                return _load_checkpoint(session, ckpt_path)
+            log_info('Could not find best validating checkpoint.')
+
+        # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
+        elif method == 'last':
+            ckpt_path = _checkpoint_path_or_none('checkpoint')
+            if ckpt_path:
+                log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
+                return _load_checkpoint(session, ckpt_path)
+            log_info('Could not find most recent checkpoint.')
+
+        # Initialize all variables
+        elif method == 'init':
+            log_info('Initializing all variables.')
+            return _initialize_all_variables(session)
+
+        else:
+            log_error('Unknown initialization method: {}'.format(method))
+            sys.exit(1)
+
+    log_error('All initialization methods failed ({}).'.format(method_order))
+    sys.exit(1)
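A pure-Python illustration of the dropped_layers slicing above: the list holds variable-name fragments for layers 2 through 6 of the model, and the negative slice keeps the last N entries, i.e. the N layers closest to the output:

    layer_keys = ['2', '3', 'lstm', '5', '6']
    for n in (1, 2, 5):
        print(n, layer_keys[-1 * n:])
    # 1 -> ['6']                         drop only the output layer
    # 2 -> ['5', '6']                    drop penultimate and output layers
    # 5 -> ['2', '3', 'lstm', '5', '6']  re-initialize everything except layer 1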
(likely util/config.py)
@@ -124,4 +124,21 @@ def initialize_globals():
         log_error('Path specified in --one_shot_infer is not a valid file.')
         sys.exit(1)

+    if FLAGS.train_cudnn and FLAGS.load_cudnn:
+        log_error('Trying to use --train_cudnn, but --load_cudnn '
+                  'was also specified. The --load_cudnn flag is only '
+                  'needed when converting a CuDNN RNN checkpoint to '
+                  'a CPU-capable graph. If your system is capable of '
+                  'using CuDNN RNN, you can just specify the CuDNN RNN '
+                  'checkpoint normally with --save_checkpoint_dir.')
+        sys.exit(1)
+
+    # If separate save and load flags were not specified, default to load and save
+    # from the same dir.
+    if not FLAGS.save_checkpoint_dir:
+        FLAGS.save_checkpoint_dir = FLAGS.checkpoint_dir
+
+    if not FLAGS.load_checkpoint_dir:
+        FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir
+
     ConfigSingleton._config = c # pylint: disable=protected-access
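A tiny runnable sketch of the defaulting above, using a stand-in class instead of absl's FLAGS (the class is hypothetical; the defaulting logic is verbatim):

    class _Flags:  # stand-in for absl FLAGS in this sketch
        checkpoint_dir = '/tmp/ckpt'
        save_checkpoint_dir = ''
        load_checkpoint_dir = ''

    FLAGS = _Flags()
    if not FLAGS.save_checkpoint_dir:
        FLAGS.save_checkpoint_dir = FLAGS.checkpoint_dir
    if not FLAGS.load_checkpoint_dir:
        FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir
    assert FLAGS.save_checkpoint_dir == FLAGS.load_checkpoint_dir == '/tmp/ckpt'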
(likely util/flags.py)
@@ -49,7 +49,6 @@ def create_flags():
     f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
     f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling')

-
     # Global Constants
     # ================
@@ -84,9 +83,8 @@ def create_flags():
     f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory')
-    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
-    f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
-
+    f.DEFINE_boolean('load_cudnn', False, 'Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
+    f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
     f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')

     # Sample limits
@@ -97,10 +95,16 @@ def create_flags():

     # Checkpointing

-    f.DEFINE_string('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
+    f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
+    f.DEFINE_string('load_checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
+    f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
     f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
     f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
-    f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')
+    f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a fresh model, "transfer" for transfer learning, "auto" for trying several options.')
+
+    # Transfer Learning
+
+    f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')

     # Exporting
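Worth noting: the 'auto' search order changes with this PR. The old help text tried last > best > init, while the new call sites resolve best > last, with 'init' only as a final fallback during training:

    # New resolution used by train(); evaluate/export pass only ['best', 'last'].
    method_order = ['best', 'last', 'init'] if FLAGS.load == 'auto' else [FLAGS.load]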
@@ -115,7 +119,7 @@ def create_flags():

     # Reporting

-    f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
+    f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR')
     f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')

     f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
(likely util/helpers.py)
@@ -1,3 +1,7 @@
 import os
+import semver
+import sys


 def keep_only_digits(txt):
     return ''.join(filter(lambda c: c.isdigit(), txt))

@@ -8,15 +12,12 @@ def secs_to_hours(secs):
     minutes, seconds = divmod(remainder, 60)
     return '%d:%02d:%02d' % (hours, minutes, seconds)

-# pylint: disable=import-outside-toplevel
-def check_ctcdecoder_version():
-    import sys
-    import os
-    import semver

+def check_ctcdecoder_version():
     ds_version_s = open(os.path.join(os.path.dirname(__file__), '../VERSION')).read().strip()

     try:
+        # pylint: disable=import-outside-toplevel
         from ds_ctcdecoder import __version__ as decoder_version
     except ImportError as e:
         if e.msg.find('__version__') > 0: