Merge pull request #2763 from reuben/transfer-learning-rebase
Transfer learning support
commit 200a46711a

DeepSpeech.py (178 lines changed)
@@ -6,7 +6,8 @@ import os
 import sys
 
 LOG_LEVEL_INDEX = sys.argv.index('--log_level') + 1 if '--log_level' in sys.argv else 0
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
+DESIRED_LOG_LEVEL = sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else '3'
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL
 
 import absl.app
 import json
@@ -17,18 +18,25 @@ import tensorflow as tf
 import tensorflow.compat.v1 as tfv1
 import time
 
+tfv1.logging.set_verbosity({
+    '0': tfv1.logging.DEBUG,
+    '1': tfv1.logging.INFO,
+    '2': tfv1.logging.WARN,
+    '3': tfv1.logging.ERROR
+}.get(DESIRED_LOG_LEVEL))
+
 from datetime import datetime
 from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
 from evaluate import evaluate
 from six.moves import zip, range
-from tensorflow.python.tools import freeze_graph, strip_unused_lib
-from tensorflow.python.framework import errors_impl
 from util.config import Config, initialize_globals
+from util.checkpoints import load_or_init_graph
 from util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from util.flags import create_flags, FLAGS
+from util.helpers import check_ctcdecoder_version
 from util.logging import log_info, log_error, log_debug, log_progress, create_progressbar
-from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
 
+check_ctcdecoder_version()
 
 # Graph Creation
 # ==============
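A note on the ordering here: TF_CPP_MIN_LOG_LEVEL only takes effect if it is
set before TensorFlow is imported, which is why DESIRED_LOG_LEVEL is computed
and exported at the very top of the file, while tfv1.logging.set_verbosity can
only run after the import. A minimal sketch of the pattern (names taken from
the diff, values illustrative):

    import os

    DESIRED_LOG_LEVEL = '3'  # would normally be parsed from --log_level
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = DESIRED_LOG_LEVEL  # before the import

    import tensorflow.compat.v1 as tfv1  # C++ log filtering now applies

    tfv1.logging.set_verbosity({
        '0': tfv1.logging.DEBUG,
        '1': tfv1.logging.INFO,
        '2': tfv1.logging.WARN,
        '3': tfv1.logging.ERROR
    }.get(DESIRED_LOG_LEVEL))  # Python-side logging mirrors the same flag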
@@ -222,7 +230,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     # Obtain the next batch of data
     batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
 
-    if FLAGS.use_cudnn_rnn:
+    if FLAGS.train_cudnn:
         rnn_impl = rnn_impl_cudnn_rnn
     else:
         rnn_impl = rnn_impl_lstmblockfusedcell
@@ -397,30 +405,6 @@ def log_grads_and_vars(grads_and_vars):
         log_variable(variable, gradient=gradient)
 
 
-def try_loading(session, saver, checkpoint_filename, caption, load_step=True, log_success=True):
-    try:
-        checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir, checkpoint_filename)
-        if not checkpoint:
-            return False
-        checkpoint_path = checkpoint.model_checkpoint_path
-        saver.restore(session, checkpoint_path)
-        if load_step:
-            restored_step = session.run(tfv1.train.get_global_step())
-            if log_success:
-                log_info('Restored variables from %s checkpoint at %s, step %d' %
-                         (caption, checkpoint_path, restored_step))
-        elif log_success:
-            log_info('Restored variables from %s checkpoint at %s' % (caption, checkpoint_path))
-        return True
-    except tf.errors.InvalidArgumentError as e:
-        log_error(str(e))
-        log_error('The checkpoint in {0} does not match the shapes of the model.'
-                  ' Did you change alphabet.txt or the --n_hidden parameter'
-                  ' between train runs using the same checkpoint dir? Try moving'
-                  ' or removing the contents of {0}.'.format(checkpoint_path))
-        sys.exit(1)
-
-
 def train():
     do_cache_dataset = True
 
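(try_loading disappears entirely in this commit: its checkpoint-discovery role
moves to load_or_init_graph in the new util/checkpoints.py shown further down,
which also absorbs the CuDNN conversion logic removed from train() in the next
hunk.)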
@@ -494,76 +478,29 @@ def train():
 
     # Checkpointing
     checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep)
-    checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'train')
+    checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train')
 
     best_dev_saver = tfv1.train.Saver(max_to_keep=1)
-    best_dev_path = os.path.join(FLAGS.checkpoint_dir, 'best_dev')
+    best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev')
 
     # Save flags next to checkpoints
-    os.makedirs(FLAGS.checkpoint_dir, exist_ok=True)
-    flags_file = os.path.join(FLAGS.checkpoint_dir, 'flags.txt')
+    os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True)
+    flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt')
     with open(flags_file, 'w') as fout:
         fout.write(FLAGS.flags_into_string())
 
-    initializer = tfv1.global_variables_initializer()
-
     with tfv1.Session(config=Config.session_config) as session:
         log_debug('Session opened.')
 
-        # Loading or initializing
-        loaded = False
-
-        # Initialize training from a CuDNN RNN checkpoint
-        if FLAGS.cudnn_checkpoint:
-            if FLAGS.use_cudnn_rnn:
-                log_error('Trying to use --cudnn_checkpoint but --use_cudnn_rnn '
-                          'was specified. The --cudnn_checkpoint flag is only '
-                          'needed when converting a CuDNN RNN checkpoint to '
-                          'a CPU-capable graph. If your system is capable of '
-                          'using CuDNN RNN, you can just specify the CuDNN RNN '
-                          'checkpoint normally with --checkpoint_dir.')
-                sys.exit(1)
-
-            log_info('Converting CuDNN RNN checkpoint from {}'.format(FLAGS.cudnn_checkpoint))
-            ckpt = tfv1.train.load_checkpoint(FLAGS.cudnn_checkpoint)
-            missing_variables = []
-
-            # Load compatible variables from checkpoint
-            for v in tfv1.global_variables():
-                try:
-                    v.load(ckpt.get_tensor(v.op.name), session=session)
-                except tf.errors.NotFoundError:
-                    missing_variables.append(v)
-
-            # Check that the only missing variables are the Adam moment tensors
-            if any('Adam' not in v.op.name for v in missing_variables):
-                log_error('Tried to load a CuDNN RNN checkpoint but there were '
-                          'more missing variables than just the Adam moment '
-                          'tensors.')
-                sys.exit(1)
-
-            # Initialize Adam moment tensors from scratch to allow use of CuDNN
-            # RNN checkpoints.
-            log_info('Initializing missing Adam moment tensors.')
-            init_op = tfv1.variables_initializer(missing_variables)
-            session.run(init_op)
-            loaded = True
-
+        # Prevent further graph changes
        tfv1.get_default_graph().finalize()
 
-        if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, checkpoint_saver, 'checkpoint', 'most recent')
-        if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, best_dev_saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded:
-            if FLAGS.load in ['auto', 'init']:
-                log_info('Initializing variables...')
-                session.run(initializer)
-            else:
-                log_error('Unable to load %s model from specified checkpoint dir'
-                          ' - consider using load option "auto" or "init".' % FLAGS.load)
-                sys.exit(1)
+        # Load checkpoint or initialize variables
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last', 'init']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)
 
         def run_set(set_name, epoch, init_op, dataset=None):
             is_train = set_name == 'train'
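All four call sites that previously chained try_loading by hand (train,
evaluate, export, single-file inference) now reduce FLAGS.load to a method
order and delegate to load_or_init_graph. A standalone sketch of that
resolution; the helper name and the allow_init parameter are mine, not part
of the diff:

    def resolve_method_order(load_flag, allow_init=False):
        # Only training may fall back to fresh initialization; evaluation,
        # export and inference try 'best' then 'last' and otherwise fail.
        if load_flag == 'auto':
            return ['best', 'last', 'init'] if allow_init else ['best', 'last']
        return [load_flag]

    resolve_method_order('auto', allow_init=True)  # -> ['best', 'last', 'init']
    resolve_method_order('last')                   # -> ['last']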
@@ -682,7 +619,7 @@ def train():
 
 
 def test():
-    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    samples = evaluate(FLAGS.test_files.split(','), create_model)
     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
         json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
@@ -803,49 +740,40 @@ def export():
     if FLAGS.export_language:
         outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
 
+    # Prevent further graph changes
+    tfv1.get_default_graph().finalize()
+
     output_names_tensors = [tensor.op.name for tensor in outputs.values() if isinstance(tensor, Tensor)]
     output_names_ops = [op.name for op in outputs.values() if isinstance(op, Operation)]
-    output_names = ",".join(output_names_tensors + output_names_ops)
+    output_names = output_names_tensors + output_names_ops
 
-    # Create a saver using variables from the above newly created graph
-    saver = tfv1.train.Saver()
-
-    # Restore variables from training checkpoint
-    checkpoint = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
-    checkpoint_path = checkpoint.model_checkpoint_path
-
-    output_filename = FLAGS.export_name + '.pb'
-    if FLAGS.remove_export:
-        if os.path.isdir(FLAGS.export_dir):
-            log_info('Removing old export')
-            shutil.rmtree(FLAGS.export_dir)
-    try:
+    with tf.Session() as session:
+        # Restore variables from checkpoint
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)
+
+        output_filename = FLAGS.export_name + '.pb'
+        if FLAGS.remove_export:
+            if os.path.isdir(FLAGS.export_dir):
+                log_info('Removing old export')
+                shutil.rmtree(FLAGS.export_dir)
+
         output_graph_path = os.path.join(FLAGS.export_dir, output_filename)
 
         if not os.path.isdir(FLAGS.export_dir):
             os.makedirs(FLAGS.export_dir)
 
-        def do_graph_freeze(output_file=None, output_node_names=None, variables_blacklist=''):
-            frozen = freeze_graph.freeze_graph_with_def_protos(
-                input_graph_def=tfv1.get_default_graph().as_graph_def(),
-                input_saver_def=saver.as_saver_def(),
-                input_checkpoint=checkpoint_path,
-                output_node_names=output_node_names,
-                restore_op_name=None,
-                filename_tensor_name=None,
-                output_graph=output_file,
-                clear_devices=False,
-                variable_names_blacklist=variables_blacklist,
-                initializer_nodes='')
-
-            input_node_names = []
-            return strip_unused_lib.strip_unused(
-                input_graph_def=frozen,
-                input_node_names=input_node_names,
-                output_node_names=output_node_names.split(','),
-                placeholder_type_enum=tf.float32.as_datatype_enum)
-
-        frozen_graph = do_graph_freeze(output_node_names=output_names)
+        frozen_graph = tfv1.graph_util.convert_variables_to_constants(
+            sess=session,
+            input_graph_def=tfv1.get_default_graph().as_graph_def(),
+            output_node_names=output_names)
+
+        frozen_graph = tfv1.graph_util.extract_sub_graph(
+            graph_def=frozen_graph,
+            dest_nodes=output_names)
 
         if not FLAGS.export_tflite:
             with open(output_graph_path, 'wb') as fout:
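Two details worth noting in the rewritten export path: the graph is finalized
before any session work begins, and output_names becomes a Python list because
tfv1.graph_util.convert_variables_to_constants expects a list of node names,
whereas the removed freeze_graph helper consumed a comma-separated string (the
old code called output_node_names.split(',')). Tiny sketch, node names for
illustration only:

    output_names = ['logits', 'metadata_language']  # illustrative node names
    legacy_arg = ','.join(output_names)             # what freeze_graph wanted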
@@ -854,7 +782,7 @@ def export():
             output_tflite_path = os.path.join(FLAGS.export_dir, output_filename.replace('.pb', '.tflite'))
 
             converter = tf.lite.TFLiteConverter(frozen_graph, input_tensors=inputs.values(), output_tensors=outputs.values())
-            converter.optimizations = [ tf.lite.Optimize.DEFAULT ]
+            converter.optimizations = [tf.lite.Optimize.DEFAULT]
             # AudioSpectrogram and Mfcc ops are custom but have built-in kernels in TFLite
             converter.allow_custom_ops = True
             tflite_model = converter.convert()
@@ -862,11 +790,7 @@ def export():
             with open(output_tflite_path, 'wb') as fout:
                 fout.write(tflite_model)
 
-            log_info('Exported model for TF Lite engine as {}'.format(os.path.basename(output_tflite_path)))
-
         log_info('Models exported at %s' % (FLAGS.export_dir))
-    except RuntimeError as e:
-        log_error(str(e))
 
 def package_zip():
     # --export_dir path/to/export/LANG_CODE/ => path/to/export/LANG_CODE.zip
@@ -887,18 +811,12 @@ def do_single_file_inference(input_file_path):
     with tfv1.Session(config=Config.session_config) as session:
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
 
-        # Create a saver using variables from the above newly created graph
-        saver = tfv1.train.Saver()
-
         # Restore variables from training checkpoint
-        loaded = False
-        if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, saver, 'checkpoint', 'most recent', load_step=False)
-        if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation', load_step=False)
-        if not loaded:
-            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
-            sys.exit(1)
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)
 
         features, features_len = audiofile_to_features(input_file_path)
         previous_state_c = np.zeros([1, Config.n_cell_dim])
@@ -23,7 +23,7 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
     --learning_rate 0.001 --dropout_rate 0.05 \
     --scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log
 
-if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then
+if ! grep "Loading best validating checkpoint from" /tmp/resume.log; then
     echo "Did not resume training from checkpoint"
     exit 1
 else
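(The new grep target matches the message load_or_init_graph logs on success,
'Loading best validating checkpoint from {}', defined in the new
util/checkpoints.py below; the old message came from the removed try_loading
helper.)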
New file: bin/run-tc-transfer.sh (invoked by the TaskCluster test script below)

@@ -0,0 +1,126 @@
+#!/bin/sh
+# This bash script is for running minimum working examples
+# of transfer learning for continuous integration tests
+# to be run on Taskcluster.
+set -xe
+
+ru_dir="./data/smoke_test/russian_sample_data"
+ru_csv="${ru_dir}/ru.csv"
+
+ldc93s1_dir="./data/smoke_test"
+ldc93s1_csv="${ldc93s1_dir}/ldc93s1.csv"
+
+if [ ! -f "${ldc93s1_dir}/ldc93s1.csv" ]; then
+    echo "Downloading and preprocessing LDC93S1 example data, saving in ${ldc93s1_dir}."
+    python -u bin/import_ldc93s1.py ${ldc93s1_dir}
+fi;
+
+# Force only one visible device because we have a single-sample dataset
+# and when trying to run on multiple devices (like GPUs), this will break
+export CUDA_VISIBLE_DEVICES=0
+
+# Force UTF-8 output
+export PYTHONIOENCODING=utf-8
+
+echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
+echo "##### while iterating over loading logic #####"
+
+for LOAD in 'init' 'last' 'auto'; do
+    echo "########################################################"
+    echo "#### Train ENGLISH model with just --checkpoint_dir ####"
+    echo "########################################################"
+    python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+        --alphabet_config_path "./data/alphabet.txt" \
+        --load "$LOAD" \
+        --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+        --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+        --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+        --scorer_path '' \
+        --checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --n_hidden 100 \
+        --epochs 10
+
+    echo "##############################################################################"
+    echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
+    echo "##############################################################################"
+    python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+        --alphabet_config_path "./data/alphabet.txt" \
+        --load "$LOAD" \
+        --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+        --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+        --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+        --save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --scorer_path '' \
+        --n_hidden 100 \
+        --epochs 10
+
+    echo "#################################################################################"
+    echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
+    echo "#################################################################################"
+    python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+        --drop_source_layers 1 \
+        --alphabet_config_path "${ru_dir}/alphabet.ru" \
+        --load 'last' \
+        --train_files "${ru_csv}" --train_batch_size 1 \
+        --dev_files "${ru_csv}" --dev_batch_size 1 \
+        --test_files "${ru_csv}" --test_batch_size 1 \
+        --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
+        --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --scorer_path '' \
+        --n_hidden 100 \
+        --epochs 10
+done
+
+echo "#######################################################"
+echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
+echo "##### while iterating over loading logic #####"
+echo "#######################################################"
+
+for LOAD in 'init' 'last' 'auto'; do
+    echo "########################################################"
+    echo "#### Train ENGLISH model with just --checkpoint_dir ####"
+    echo "########################################################"
+    python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+        --alphabet_config_path "./data/alphabet.txt" \
+        --load "$LOAD" \
+        --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+        --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+        --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+        --checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --scorer_path '' \
+        --n_hidden 100 \
+        --epochs 10
+
+
+    echo "##############################################################################"
+    echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
+    echo "##############################################################################"
+    python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+        --alphabet_config_path "./data/alphabet.txt" \
+        --load "$LOAD" \
+        --train_files "${ldc93s1_csv}" --train_batch_size 1 \
+        --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
+        --test_files "${ldc93s1_csv}" --test_batch_size 1 \
+        --save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --scorer_path '' \
+        --n_hidden 100 \
+        --epochs 10
+
+    echo "####################################################################################"
+    echo "#### Transfer to RUSSIAN model with --save_checkpoint_dir --load_checkpoint_dir ####"
+    echo "####################################################################################"
+    python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
+        --drop_source_layers 1 \
+        --alphabet_config_path "${ru_dir}/alphabet.ru" \
+        --load 'last' \
+        --train_files "${ru_csv}" --train_batch_size 1 \
+        --dev_files "${ru_csv}" --dev_batch_size 1 \
+        --test_files "${ru_csv}" --test_batch_size 1 \
+        --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
+        --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
+        --scorer_path '' \
+        --n_hidden 100 \
+        --epochs 10
+done
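The third invocation in each loop is the actual transfer-learning step: it
reads the English checkpoint via --load_checkpoint_dir, drops the output layer
(--drop_source_layers 1) so that the layer sized for the new Russian alphabet
is freshly initialized, and writes to a separate --save_checkpoint_dir so the
source checkpoint stays untouched.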
New file: data/smoke_test/russian_sample_data/alphabet.ru (referenced above as
"${ru_dir}/alphabet.ru"; 34 lines, a leading space character followed by the
33 letters of the Russian alphabet)

@@ -0,0 +1,34 @@
+ 
+о
+е
+а
+и
+н
+т
+с
+л
+в
+р
+к
+м
+д
+п
+ы
+у
+б
+я
+ь
+г
+з
+ч
+й
+ж
+х
+ш
+ю
+ц
+э
+щ
+ф
+ё
+ъ
New file: data/smoke_test/russian_sample_data/ru.csv (referenced above as "${ru_csv}")

@@ -0,0 +1,2 @@
+wav_filename,wav_filesize,transcript
+ru.wav,0,бедняга ребят на его месте должен был быть я

Binary file not shown.
evaluate.py (27 lines changed)

@@ -16,12 +16,14 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
 
 from util.config import Config, initialize_globals
+from util.checkpoints import load_or_init_graph
 from util.evaluate_tools import calculate_and_print_report
 from util.feeding import create_dataset
 from util.flags import create_flags, FLAGS
+from util.helpers import check_ctcdecoder_version
 from util.logging import create_progressbar, log_error, log_progress
-from util.helpers import check_ctcdecoder_version; check_ctcdecoder_version()
 
+check_ctcdecoder_version()
 
 def sparse_tensor_value_to_texts(value, alphabet):
     r"""
@@ -41,7 +43,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
     return [alphabet.decode(res) for res in results]
 
 
-def evaluate(test_csvs, create_model, try_loading):
+def evaluate(test_csvs, create_model):
     if FLAGS.scorer_path:
         scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                         FLAGS.scorer_path, Config.alphabet)
@@ -79,19 +81,12 @@ def evaluate(test_csvs, create_model):
     except NotImplementedError:
         num_processes = 1
 
-    # Create a saver using variables from the above newly created graph
-    saver = tfv1.train.Saver()
-
     with tfv1.Session(config=Config.session_config) as session:
-        # Restore variables from training checkpoint
-        loaded = False
-        if not loaded and FLAGS.load in ['auto', 'best']:
-            loaded = try_loading(session, saver, 'best_dev_checkpoint', 'best validation')
-        if not loaded and FLAGS.load in ['auto', 'last']:
-            loaded = try_loading(session, saver, 'checkpoint', 'most recent')
-        if not loaded:
-            print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
-            sys.exit(1)
+        if FLAGS.load == 'auto':
+            method_order = ['best', 'last']
+        else:
+            method_order = [FLAGS.load]
+        load_or_init_graph(session, method_order)
 
         def run_test(init_op, dataset):
             wav_filenames = []
@@ -148,8 +143,8 @@ def main(_):
                   'the --test_files flag.')
         sys.exit(1)
 
-    from DeepSpeech import create_model, try_loading # pylint: disable=cyclic-import
-    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    from DeepSpeech import create_model # pylint: disable=cyclic-import,import-outside-toplevel
+    samples = evaluate(FLAGS.test_files.split(','), create_model)
 
     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
New file: taskcluster/tc-transfer-tests.sh (referenced by the job definition below)

@@ -0,0 +1,65 @@
+#!/bin/bash
+
+set -xe
+
+source $(dirname "$0")/tc-tests-utils.sh
+
+pyver_full=$1
+
+if [ -z "${pyver_full}" ]; then
+    echo "No python version given, aborting."
+    exit 1
+fi;
+
+pyver=$(echo "${pyver_full}" | cut -d':' -f1)
+
+# 2.7.x => 27
+pyver_pkg=$(echo "${pyver}" | cut -d'.' -f1,2 | tr -d '.')
+
+py_unicode_type=$(echo "${pyver_full}" | cut -d':' -f2)
+if [ "${py_unicode_type}" = "m" ]; then
+    pyconf="ucs2"
+elif [ "${py_unicode_type}" = "mu" ]; then
+    pyconf="ucs4"
+fi;
+
+unset PYTHON_BIN_PATH
+unset PYTHONPATH
+export PYENV_ROOT="${HOME}/ds-train/.pyenv"
+export PATH="${PYENV_ROOT}/bin:${HOME}/bin:$PATH"
+
+mkdir -p ${PYENV_ROOT} || true
+mkdir -p ${TASKCLUSTER_ARTIFACTS} || true
+mkdir -p /tmp/train || true
+mkdir -p /tmp/train_tflite || true
+
+install_pyenv "${PYENV_ROOT}"
+install_pyenv_virtualenv "$(pyenv root)/plugins/pyenv-virtualenv"
+
+PYENV_NAME=deepspeech-train
+PYTHON_CONFIGURE_OPTS="--enable-unicode=${pyconf}" pyenv install ${pyver}
+pyenv virtualenv ${pyver} ${PYENV_NAME}
+source ${PYENV_ROOT}/versions/${pyver}/envs/${PYENV_NAME}/bin/activate
+
+set -o pipefail
+pip install --upgrade pip==19.3.1 setuptools==45.0.0 wheel==0.33.6 | cat
+pip install --upgrade -r ${HOME}/DeepSpeech/ds/requirements.txt | cat
+set +o pipefail
+
+pushd ${HOME}/DeepSpeech/ds/
+    verify_ctcdecoder_url
+popd
+
+platform=$(python -c 'import sys; import platform; plat = platform.system().lower(); arch = platform.machine().lower(); plat = "manylinux1" if plat == "linux" and arch == "x86_64" else plat; plat = "macosx_10_10" if plat == "darwin" else plat; sys.stdout.write("%s_%s" % (plat, platform.machine()));')
+whl_ds_version="$(python -c 'from pkg_resources import parse_version; print(parse_version("'${DS_VERSION}'"))')"
+decoder_pkg="ds_ctcdecoder-${whl_ds_version}-cp${pyver_pkg}-cp${pyver_pkg}${py_unicode_type}-${platform}.whl"
+
+decoder_pkg_url=${DECODER_ARTIFACTS_ROOT}/${decoder_pkg}
+
+LD_LIBRARY_PATH=${PY37_LDPATH}:$LD_LIBRARY_PATH pip install --verbose --only-binary :all: ${PY37_SOURCE_PACKAGE} ${decoder_pkg_url} | cat
+
+pushd ${HOME}/DeepSpeech/ds/
+    time ./bin/run-tc-transfer.sh
+popd
+
+deactivate
New file: TaskCluster job definition (.yml)

@@ -0,0 +1,12 @@
+build:
+  template_file: test-linux-opt-base.tyml
+  dependencies:
+    - "linux-amd64-ctc-opt"
+  system_setup:
+    >
+      apt-get -qq -y install ${python.packages_trusty.apt}
+  args:
+    tests_cmdline: "${system.homedir.linux}/DeepSpeech/ds/taskcluster/tc-transfer-tests.sh 3.6.4:m"
+  metadata:
+    name: "DeepSpeech Linux AMD64 CPU transfer learning Py3.6"
+    description: "Training a DeepSpeech LDC93S1 model with transfer learning for Linux/AMD64 16kHz Python 3.6, CPU only, optimized version"
New file: util/checkpoints.py

@@ -0,0 +1,115 @@
+import sys
+
+import tensorflow as tf
+import tensorflow.compat.v1 as tfv1
+
+from util.flags import FLAGS
+from util.logging import log_info, log_error, log_warn
+
+
+def _load_checkpoint(session, checkpoint_path):
+    # Load the checkpoint and put all variables into loading list
+    # we will exclude variables we do not wish to load and then
+    # we will initialize them instead
+    ckpt = tfv1.train.load_checkpoint(checkpoint_path)
+    load_vars = set(tfv1.global_variables())
+    init_vars = set()
+
+    if FLAGS.load_cudnn:
+        # Initialize training from a CuDNN RNN checkpoint
+        # Identify the variables which we cannot load, and set them
+        # for initialization
+        for v in load_vars:
+            try:
+                ckpt.get_tensor(v.op.name)
+            except tf.errors.NotFoundError:
+                log_error('CUDNN variable not found: %s' % (v.op.name))
+                init_vars.add(v)
+
+        load_vars -= init_vars
+
+        # Check that the only missing variables (i.e. those to be initialised)
+        # are the Adam moment tensors, if they aren't then we have an issue
+        init_var_names = [v.op.name for v in init_vars]
+        if any('Adam' not in v for v in init_var_names):
+            log_error('Tried to load a CuDNN RNN checkpoint but there were '
+                      'more missing variables than just the Adam moment '
+                      'tensors. Missing variables: {}'.format(init_var_names))
+            sys.exit(1)
+
+    if FLAGS.drop_source_layers > 0:
+        # This transfer learning approach requires supplying
+        # the layers which we exclude from the source model.
+        # Say we want to exclude all layers except for the first one,
+        # then we are dropping five layers total, so: drop_source_layers=5
+        # If we want to use all layers from the source model except
+        # the last one, we use this: drop_source_layers=1
+        if FLAGS.drop_source_layers >= 6:
+            log_warn('The checkpoint only has 6 layers, but you are trying to drop '
+                     'all of them or more than all of them. Continuing and '
+                     'dropping only 5 layers.')
+            FLAGS.drop_source_layers = 5
+
+        dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
+        # Initialize all variables needed for DS, but not loaded from ckpt
+        for v in load_vars:
+            if any(layer in v.op.name for layer in dropped_layers):
+                init_vars.add(v)
+        load_vars -= init_vars
+
+    for v in sorted(load_vars, key=lambda v: v.op.name):
+        log_info('Loading variable from checkpoint: %s' % (v.op.name))
+        v.load(ckpt.get_tensor(v.op.name), session=session)
+
+    for v in sorted(init_vars, key=lambda v: v.op.name):
+        log_info('Initializing variable: %s' % (v.op.name))
+        session.run(v.initializer)
+
+
+def _checkpoint_path_or_none(checkpoint_filename):
+    checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
+    if not checkpoint:
+        return None
+    return checkpoint.model_checkpoint_path
+
+
+def _initialize_all_variables(session):
+    init_vars = tfv1.global_variables()
+    for v in init_vars:
+        session.run(v.initializer)
+
+
+def load_or_init_graph(session, method_order):
+    '''
+    Load variables from checkpoint or initialize variables following the method
+    order specified in the method_order parameter.
+
+    Valid methods are 'best', 'last' and 'init'.
+    '''
+    for method in method_order:
+        # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
+        if method == 'best':
+            ckpt_path = _checkpoint_path_or_none('best_dev_checkpoint')
+            if ckpt_path:
+                log_info('Loading best validating checkpoint from {}'.format(ckpt_path))
+                return _load_checkpoint(session, ckpt_path)
+            log_info('Could not find best validating checkpoint.')
+
+        # Load most recent checkpoint, saved in checkpoint file 'checkpoint'
+        elif method == 'last':
+            ckpt_path = _checkpoint_path_or_none('checkpoint')
+            if ckpt_path:
+                log_info('Loading most recent checkpoint from {}'.format(ckpt_path))
+                return _load_checkpoint(session, ckpt_path)
+            log_info('Could not find most recent checkpoint.')
+
+        # Initialize all variables
+        elif method == 'init':
+            log_info('Initializing all variables.')
+            return _initialize_all_variables(session)
+
+        else:
+            log_error('Unknown initialization method: {}'.format(method))
+            sys.exit(1)
+
+    log_error('All initialization methods failed ({}).'.format(method_order))
+    sys.exit(1)
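A standalone illustration of the dropped-layer selection in _load_checkpoint
above (plain Python, runnable on its own):

    # Negative slicing picks the last n entries, i.e. the n layers closest
    # to the output.
    layers = ['2', '3', 'lstm', '5', '6']
    for n in range(1, 6):
        print(n, layers[-1 * n:])
    # 1 ['6']                          drop only the output layer
    # 2 ['5', '6']                     drop penultimate and output layers
    # 3 ['lstm', '5', '6']
    # 4 ['3', 'lstm', '5', '6']
    # 5 ['2', '3', 'lstm', '5', '6']   keep only the first layer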
util/config.py:

@@ -124,4 +124,21 @@ def initialize_globals():
         log_error('Path specified in --one_shot_infer is not a valid file.')
         sys.exit(1)
 
+    if FLAGS.train_cudnn and FLAGS.load_cudnn:
+        log_error('Trying to use --train_cudnn, but --load_cudnn '
+                  'was also specified. The --load_cudnn flag is only '
+                  'needed when converting a CuDNN RNN checkpoint to '
+                  'a CPU-capable graph. If your system is capable of '
+                  'using CuDNN RNN, you can just specify the CuDNN RNN '
+                  'checkpoint normally with --save_checkpoint_dir.')
+        sys.exit(1)
+
+    # If separate save and load flags were not specified, default to load and save
+    # from the same dir.
+    if not FLAGS.save_checkpoint_dir:
+        FLAGS.save_checkpoint_dir = FLAGS.checkpoint_dir
+
+    if not FLAGS.load_checkpoint_dir:
+        FLAGS.load_checkpoint_dir = FLAGS.checkpoint_dir
+
     ConfigSingleton._config = c # pylint: disable=protected-access
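This defaulting rule is what keeps existing command lines working: with only
--checkpoint_dir given, loading and saving resolve to the same directory,
while transfer learning can split them. A sketch (the helper name is mine,
not part of the diff):

    def effective_checkpoint_dirs(checkpoint_dir, load_dir='', save_dir=''):
        # Mirrors the fallback in initialize_globals() above.
        return (load_dir or checkpoint_dir, save_dir or checkpoint_dir)

    effective_checkpoint_dirs('/tmp/ckpt')
    # -> ('/tmp/ckpt', '/tmp/ckpt'): classic single-directory behavior
    effective_checkpoint_dirs('', '/tmp/ckpt/transfer/eng', '/tmp/ckpt/transfer/ru')
    # -> split load/save directories, as used by bin/run-tc-transfer.sh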
util/flags.py:

@@ -49,7 +49,6 @@ def create_flags():
     f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
     f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max vlaue of tempo scaling')
 
-
     # Global Constants
     # ================
@@ -84,9 +83,8 @@ def create_flags():
     f.DEFINE_integer('inter_op_parallelism_threads', 0, 'number of inter-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_integer('intra_op_parallelism_threads', 0, 'number of intra-op parallelism threads - see tf.ConfigProto for more details. USE OF THIS FLAG IS UNSUPPORTED')
     f.DEFINE_boolean('use_allow_growth', False, 'use Allow Growth flag which will allocate only required amount of GPU memory and prevent full allocation of available GPU memory')
-    f.DEFINE_boolean('use_cudnn_rnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
-    f.DEFINE_string('cudnn_checkpoint', '', 'path to a checkpoint created using --use_cudnn_rnn. Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
+    f.DEFINE_boolean('load_cudnn', False, 'Specifying this flag allows one to convert a CuDNN RNN checkpoint to a checkpoint capable of running on a CPU graph.')
+    f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work')
 
     f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.')
 
     # Sample limits
@@ -97,10 +95,16 @@ def create_flags():
 
     # Checkpointing
 
-    f.DEFINE_string('checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
+    f.DEFINE_string('checkpoint_dir', '', 'directory from which checkpoints are loaded and to which they are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
+    f.DEFINE_string('load_checkpoint_dir', '', 'directory in which checkpoints are stored - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
+    f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
     f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
     f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
-    f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validated checkpoint, "init" for initializing a fresh model, "auto" for trying the other options in order last > best > init')
+    f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a fresh model, "transfer" for transfer learning, "auto" for trying several options.')
+
+    # Transfer Learning
+
+    f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
 
     # Exporting
@@ -115,7 +119,7 @@ def create_flags():
 
     # Reporting
 
-    f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: INFO, 1: WARN, 2: ERROR, 3: FATAL')
+    f.DEFINE_integer('log_level', 1, 'log level for console logs - 0: DEBUG, 1: INFO, 2: WARN, 3: ERROR')
     f.DEFINE_boolean('show_progressbar', True, 'Show progress for training, validation and testing processes. Log level should be > 0.')
 
     f.DEFINE_boolean('log_placement', False, 'whether to log device placement of the operators to the console')
util/helpers.py:

@@ -1,3 +1,7 @@
+import os
+import semver
+import sys
+
 def keep_only_digits(txt):
     return ''.join(filter(lambda c: c.isdigit(), txt))
@@ -8,15 +12,12 @@ def secs_to_hours(secs):
     minutes, seconds = divmod(remainder, 60)
     return '%d:%02d:%02d' % (hours, minutes, seconds)
 
-# pylint: disable=import-outside-toplevel
-def check_ctcdecoder_version():
-    import sys
-    import os
-    import semver
 
+def check_ctcdecoder_version():
     ds_version_s = open(os.path.join(os.path.dirname(__file__), '../VERSION')).read().strip()
 
     try:
+        # pylint: disable=import-outside-toplevel
         from ds_ctcdecoder import __version__ as decoder_version
     except ImportError as e:
         if e.msg.find('__version__') > 0: