Split --load into two to avoid unexpected behavior at evaluation time

Reuben Morais 2020-04-06 13:56:15 +02:00
parent cc7a0ada46
commit 0c6e90868e
7 changed files with 62 additions and 104 deletions
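In practice the split looks like this (a minimal before/after sketch; the dataset flags and file names are illustrative, not part of this commit):

# Before: a single flag selected the checkpoint for every task.
python -u DeepSpeech.py --load 'best' --train_files train.csv --test_files test.csv

# After: training and evaluation-style tasks are controlled independently.
python -u DeepSpeech.py --load_train 'last' --train_files train.csv
python -u DeepSpeech.py --load_evaluate 'best' --test_files test.csv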

View File

@ -31,7 +31,7 @@ for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--load_train "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
@ -45,60 +45,7 @@ for LOAD in 'init' 'last' 'auto'; do
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
echo "#################################################################################"
echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "#################################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--drop_source_layers 1 \
--alphabet_config_path "${ru_dir}/alphabet.ru" \
--load 'last' \
--train_files "${ru_csv}" --train_batch_size 1 \
--dev_files "${ru_csv}" --dev_batch_size 1 \
--test_files "${ru_csv}" --test_batch_size 1 \
--save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
--load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
done
echo "#######################################################"
echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
echo "##### while iterating over loading logic #####"
echo "#######################################################"
for LOAD in 'init' 'last' 'auto'; do
echo "########################################################"
echo "#### Train ENGLISH model with just --checkpoint_dir ####"
echo "########################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
--checkpoint_dir '/tmp/ckpt/transfer/eng' \
--scorer_path '' \
--n_hidden 100 \
--epochs 10
echo "##############################################################################"
echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
echo "##############################################################################"
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--alphabet_config_path "./data/alphabet.txt" \
--load "$LOAD" \
--load_train "$LOAD" \
--train_files "${ldc93s1_csv}" --train_batch_size 1 \
--dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
--test_files "${ldc93s1_csv}" --test_batch_size 1 \
@ -114,13 +61,20 @@ for LOAD in 'init' 'last' 'auto'; do
  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
    --drop_source_layers 1 \
    --alphabet_config_path "${ru_dir}/alphabet.ru" \
    --load 'last' \
    --load_train 'last' \
    --train_files "${ru_csv}" --train_batch_size 1 \
    --dev_files "${ru_csv}" --dev_batch_size 1 \
    --test_files "${ru_csv}" --test_batch_size 1 \
    --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
    --scorer_path '' \
    --n_hidden 100 \
    --epochs 10
  # Test transfer learning checkpoint
  python -u evaluate.py --noshow_progressbar \
    --test_files "${ru_csv}" --test_batch_size 1 \
    --alphabet_config_path "${ru_dir}/alphabet.ru" \
    --load_checkpoint_dir '/tmp/ckpt/transfer/ru' \
    --scorer_path '' \
    --n_hidden 100
done

View File

@ -16,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
from six.moves import zip
from .util.config import Config, initialize_globals
from .util.checkpoints import load_graph
from .util.checkpoints import load_graph_for_evaluation
from .util.evaluate_tools import calculate_and_print_report
from .util.feeding import create_dataset
from .util.flags import create_flags, FLAGS
@ -82,11 +82,7 @@ def evaluate(test_csvs, create_model):
        num_processes = 1

    with tfv1.Session(config=Config.session_config) as session:
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_graph(session, method_order)
        load_graph_for_evaluation(session)

        def run_test(init_op, dataset):
            wav_filenames = []
View File

@ -30,7 +30,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
from .evaluate import evaluate
from six.moves import zip, range
from .util.config import Config, initialize_globals
from .util.checkpoints import load_or_init_graph_for_training, load_graph
from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
from .util.flags import create_flags, FLAGS
from .util.helpers import check_ctcdecoder_version, ExceptionBox
@ -508,11 +508,7 @@ def train():
        tfv1.get_default_graph().finalize()

        # Load checkpoint or initialize variables
        if FLAGS.load == 'auto':
            method_order = ['best', 'last', 'init']
        else:
            method_order = [FLAGS.load]
        load_or_init_graph_for_training(session, method_order)
        load_or_init_graph_for_training(session)

        def run_set(set_name, epoch, init_op, dataset=None):
            is_train = set_name == 'train'
@ -773,11 +769,7 @@ def export():
    with tf.Session() as session:
        # Restore variables from checkpoint
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_graph(session, method_order)
        load_graph_for_evaluation(session)

        output_filename = FLAGS.export_file_name + '.pb'
        if FLAGS.remove_export:
@ -857,11 +849,7 @@ def do_single_file_inference(input_file_path):
        inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)

        # Restore variables from training checkpoint
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_graph(session, method_order)
        load_graph_for_evaluation(session)

        features, features_len = audiofile_to_features(input_file_path)
        previous_state_c = np.zeros([1, Config.n_cell_dim])
@ -896,17 +884,26 @@ def do_single_file_inference(input_file_path):
        print(decoded[0][1])


def early_checks():
def early_training_checks():
    # Check for proper scorer early
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                        FLAGS.scorer_path, Config.alphabet)
        del scorer

    if FLAGS.train_files and FLAGS.test_files and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir:
        log_warn('WARNING: You specified different values for --load_checkpoint_dir '
                 'and --save_checkpoint_dir, but you are running training and testing '
                 'in a single invocation. The testing step will respect --load_checkpoint_dir, '
                 'and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. '
                 'Train and test in two separate invocations, specifying the correct '
                 '--load_checkpoint_dir in both cases, or use the same location '
                 'for loading and saving.')


def main(_):
    initialize_globals()
    early_checks()
    early_training_checks()

    if FLAGS.train_files:
        tfv1.reset_default_graph()

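The warning added above recommends splitting training and testing into two invocations whenever the load and save directories differ; a sketch of that workflow (directory names and CSV files are illustrative):

# 1) Train, writing checkpoints to a fresh directory.
python -u DeepSpeech.py \
    --train_files train.csv --dev_files dev.csv \
    --load_checkpoint_dir '/tmp/ckpt/base' \
    --save_checkpoint_dir '/tmp/ckpt/finetuned'

# 2) Test in a second invocation, loading what the training step just saved.
python -u DeepSpeech.py \
    --test_files test.csv \
    --load_checkpoint_dir '/tmp/ckpt/finetuned'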
View File

@ -87,13 +87,7 @@ def _initialize_all_variables(session):
        session.run(v.initializer)


def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
    '''
    Load variables from checkpoint or initialize variables following the method
    order specified in the method_order parameter.

    Valid methods are 'best', 'last' and 'init'.
    '''
def _load_or_init_impl(session, method_order, allow_drop_layers):
    for method in method_order:
        # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
        if method == 'best':
@ -124,12 +118,29 @@ def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
    sys.exit(1)
def load_graph(session, method_order):
    '''
    Load variables from checkpoint. Initialization is not allowed. Follows the
    method order specified in the method_order parameter.

    Valid methods are 'best' and 'last'.
    '''
    assert('init' not in method_order)
    load_or_init_graph_for_training(session, method_order, allow_drop_layers=False)
def load_or_init_graph_for_training(session):
    '''
    Load variables from checkpoint or initialize variables. By default this will
    try to load the best validating checkpoint, then try the last checkpoint,
    and finally initialize the weights from scratch. This can be overridden with
    the `--load_train` flag. See its documentation for more info.
    '''
    if FLAGS.load_train == 'auto':
        methods = ['best', 'last', 'init']
    else:
        methods = [FLAGS.load_train]
    _load_or_init_impl(session, methods, allow_drop_layers=True)


def load_graph_for_evaluation(session):
    '''
    Load variables from checkpoint. Initialization is not allowed. By default
    this will try to load the best validating checkpoint, then try the last
    checkpoint. This can be overridden with the `--load_evaluate` flag. See its
    documentation for more info.
    '''
    if FLAGS.load_evaluate == 'auto':
        methods = ['best', 'last']
    else:
        methods = [FLAGS.load_evaluate]
    _load_or_init_impl(session, methods, allow_drop_layers=False)
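Read together, the two wrappers give 'auto' different fallback chains; in invocation terms (dataset flags are illustrative, checkpoint directory flags elided):

# --load_train auto: try 'best', then 'last', then initialize fresh weights.
python -u DeepSpeech.py --load_train auto --train_files train.csv

# --load_evaluate auto: try 'best', then 'last'. 'init' is never attempted, so
# evaluation exits with an error (the sys.exit above) instead of scoring
# untrained weights.
python -u evaluate.py --load_evaluate auto --test_files test.csv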

View File

@ -10,7 +10,7 @@ from xdg import BaseDirectory as xdg
from .flags import FLAGS
from .gpu import get_available_gpus
from .logging import log_error
from .logging import log_error, log_warn
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size
@ -45,8 +45,11 @@ def initialize_globals():
    if not FLAGS.checkpoint_dir:
        FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))

    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
        FLAGS.load = 'auto'
    if FLAGS.load_train not in ['last', 'best', 'init', 'auto']:
        FLAGS.load_train = 'auto'
    if FLAGS.load_evaluate not in ['last', 'best', 'auto']:
        FLAGS.load_evaluate = 'auto'

    # Set default summary dir
    if not FLAGS.summary_dir:

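Note that initialize_globals() coerces out-of-range values rather than rejecting them; a sketch of the resulting behavior (file names are illustrative):

# 'init' is not a valid --load_evaluate value, so config.py resets it to 'auto';
# this invocation behaves as if --load_evaluate auto had been passed.
python -u evaluate.py --load_evaluate init --test_files test.csv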
View File

@ -101,7 +101,8 @@ def create_flags():
    f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
    f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
    f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
    f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a fresh model, "transfer" for transfer learning, "auto" for trying several options.')
    f.DEFINE_string('load_train', 'auto', 'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.')
    f.DEFINE_string('load_evaluate', 'auto', 'what checkpoint to load for evaluation tasks (test epochs, model export, single-file inference, etc.). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.')

    # Transfer Learning

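The --load_evaluate help string covers model export and single-file inference as well; hedged examples of both (--export_dir and --one_shot_infer are pre-existing DeepSpeech flags not shown in this diff, and all paths are illustrative):

# Export a frozen graph from the best-validating checkpoint.
python -u DeepSpeech.py --export_dir '/tmp/export' --load_evaluate best \
    --load_checkpoint_dir '/tmp/ckpt/finetuned'

# Run one-shot inference on a single WAV file from the same checkpoint.
python -u DeepSpeech.py --one_shot_infer audio.wav --load_evaluate best \
    --load_checkpoint_dir '/tmp/ckpt/finetuned'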
View File

@ -29,7 +29,7 @@ def fail(message, code=1):
def transcribe_file(audio_path, tlog_path):
    from deepspeech_training.train import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
    from deepspeech_training.util.checkpoints import load_graph
    from deepspeech_training.util.checkpoints import load_graph_for_evaluation
    initialize_globals()
    scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
    try:
@ -50,11 +50,7 @@ def transcribe_file(audio_path, tlog_path):
    transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
    tf.train.get_or_create_global_step()

    with tf.Session(config=Config.session_config) as session:
        if FLAGS.load == 'auto':
            method_order = ['best', 'last']
        else:
            method_order = [FLAGS.load]
        load_graph(session, method_order)
        load_graph_for_evaluation(session)
        session.run(iterator.make_initializer(data_set))
        transcripts = []
        while True: