Split --load into two to avoid unexpected behavior at evaluation time

parent cc7a0ada46
commit 0c6e90868e
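In short: the single --load flag is split into --load_train, consulted when training starts (it may initialize fresh weights), and --load_evaluate, consulted by evaluation-style tasks (test epochs, model export, single-file inference, transcription), which never initializes. A hypothetical before/after invocation (file names are placeholders, not part of this commit):

    # Before: one flag selected the checkpoint for both training and the test epoch
    python -u DeepSpeech.py --train_files train.csv --test_files test.csv --load 'auto'

    # After: each phase selects its checkpoint independently
    python -u DeepSpeech.py --train_files train.csv --test_files test.csv \
        --load_train 'auto' --load_evaluate 'best'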
@@ -31,7 +31,7 @@ for LOAD in 'init' 'last' 'auto'; do
   echo "########################################################"
   python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
     --alphabet_config_path "./data/alphabet.txt" \
-    --load "$LOAD" \
+    --load_train "$LOAD" \
     --train_files "${ldc93s1_csv}" --train_batch_size 1 \
     --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
     --test_files "${ldc93s1_csv}" --test_batch_size 1 \
@@ -45,60 +45,7 @@ for LOAD in 'init' 'last' 'auto'; do
   echo "##############################################################################"
   python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
     --alphabet_config_path "./data/alphabet.txt" \
-    --load "$LOAD" \
-    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
-    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
-    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
-    --save_checkpoint_dir '/tmp/ckpt/transfer/eng' \
-    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
-    --scorer_path '' \
-    --n_hidden 100 \
-    --epochs 10
-
-  echo "#################################################################################"
-  echo "#### Transfer Russian model with --save_checkpoint_dir --load_checkpoint_dir ####"
-  echo "#################################################################################"
-  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
-    --drop_source_layers 1 \
-    --alphabet_config_path "${ru_dir}/alphabet.ru" \
-    --load 'last' \
-    --train_files "${ru_csv}" --train_batch_size 1 \
-    --dev_files "${ru_csv}" --dev_batch_size 1 \
-    --test_files "${ru_csv}" --test_batch_size 1 \
-    --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
-    --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
-    --scorer_path '' \
-    --n_hidden 100 \
-    --epochs 10
-done
-
-echo "#######################################################"
-echo "##### Train ENGLISH model and transfer to RUSSIAN #####"
-echo "##### while iterating over loading logic #####"
-echo "#######################################################"
-
-for LOAD in 'init' 'last' 'auto'; do
-  echo "########################################################"
-  echo "#### Train ENGLISH model with just --checkpoint_dir ####"
-  echo "########################################################"
-  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
-    --alphabet_config_path "./data/alphabet.txt" \
-    --load "$LOAD" \
-    --train_files "${ldc93s1_csv}" --train_batch_size 1 \
-    --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
-    --test_files "${ldc93s1_csv}" --test_batch_size 1 \
-    --checkpoint_dir '/tmp/ckpt/transfer/eng' \
-    --scorer_path '' \
-    --n_hidden 100 \
-    --epochs 10
-
-
-  echo "##############################################################################"
-  echo "#### Train ENGLISH model with --save_checkpoint_dir --load_checkpoint_dir ####"
-  echo "##############################################################################"
-  python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
-    --alphabet_config_path "./data/alphabet.txt" \
-    --load "$LOAD" \
+    --load_train "$LOAD" \
     --train_files "${ldc93s1_csv}" --train_batch_size 1 \
     --dev_files "${ldc93s1_csv}" --dev_batch_size 1 \
     --test_files "${ldc93s1_csv}" --test_batch_size 1 \
@@ -114,13 +61,20 @@ for LOAD in 'init' 'last' 'auto'; do
   python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
     --drop_source_layers 1 \
     --alphabet_config_path "${ru_dir}/alphabet.ru" \
-    --load 'last' \
+    --load_train 'last' \
     --train_files "${ru_csv}" --train_batch_size 1 \
     --dev_files "${ru_csv}" --dev_batch_size 1 \
     --test_files "${ru_csv}" --test_batch_size 1 \
     --save_checkpoint_dir '/tmp/ckpt/transfer/ru' \
     --load_checkpoint_dir '/tmp/ckpt/transfer/eng' \
     --scorer_path '' \
     --n_hidden 100 \
     --epochs 10
 
+  # Test transfer learning checkpoint
+  python -u evaluate.py --noshow_progressbar \
+    --test_files "${ru_csv}" --test_batch_size 1 \
+    --alphabet_config_path "${ru_dir}/alphabet.ru" \
+    --load_checkpoint_dir '/tmp/ckpt/transfer/ru' \
+    --scorer_path '' \
+    --n_hidden 100
 done
@@ -16,7 +16,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder_batch, Scorer
 from six.moves import zip
 
 from .util.config import Config, initialize_globals
-from .util.checkpoints import load_graph
+from .util.checkpoints import load_graph_for_evaluation
 from .util.evaluate_tools import calculate_and_print_report
 from .util.feeding import create_dataset
 from .util.flags import create_flags, FLAGS
@@ -82,11 +82,7 @@ def evaluate(test_csvs, create_model):
     num_processes = 1
 
     with tfv1.Session(config=Config.session_config) as session:
-        if FLAGS.load == 'auto':
-            method_order = ['best', 'last']
-        else:
-            method_order = [FLAGS.load]
-        load_graph(session, method_order)
+        load_graph_for_evaluation(session)
 
         def run_test(init_op, dataset):
             wav_filenames = []
@@ -30,7 +30,7 @@ from ds_ctcdecoder import ctc_beam_search_decoder, Scorer
 from .evaluate import evaluate
 from six.moves import zip, range
 from .util.config import Config, initialize_globals
-from .util.checkpoints import load_or_init_graph_for_training, load_graph
+from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation
 from .util.feeding import create_dataset, samples_to_mfccs, audiofile_to_features
 from .util.flags import create_flags, FLAGS
 from .util.helpers import check_ctcdecoder_version, ExceptionBox
@@ -508,11 +508,7 @@ def train():
         tfv1.get_default_graph().finalize()
 
         # Load checkpoint or initialize variables
-        if FLAGS.load == 'auto':
-            method_order = ['best', 'last', 'init']
-        else:
-            method_order = [FLAGS.load]
-        load_or_init_graph_for_training(session, method_order)
+        load_or_init_graph_for_training(session)
 
         def run_set(set_name, epoch, init_op, dataset=None):
             is_train = set_name == 'train'
@@ -773,11 +769,7 @@ def export():
 
     with tf.Session() as session:
         # Restore variables from checkpoint
-        if FLAGS.load == 'auto':
-            method_order = ['best', 'last']
-        else:
-            method_order = [FLAGS.load]
-        load_graph(session, method_order)
+        load_graph_for_evaluation(session)
 
         output_filename = FLAGS.export_file_name + '.pb'
         if FLAGS.remove_export:
@@ -857,11 +849,7 @@ def do_single_file_inference(input_file_path):
         inputs, outputs, _ = create_inference_graph(batch_size=1, n_steps=-1)
 
         # Restore variables from training checkpoint
-        if FLAGS.load == 'auto':
-            method_order = ['best', 'last']
-        else:
-            method_order = [FLAGS.load]
-        load_graph(session, method_order)
+        load_graph_for_evaluation(session)
 
         features, features_len = audiofile_to_features(input_file_path)
         previous_state_c = np.zeros([1, Config.n_cell_dim])
@@ -896,17 +884,26 @@ def do_single_file_inference(input_file_path):
     print(decoded[0][1])
 
 
-def early_checks():
+def early_training_checks():
     # Check for proper scorer early
     if FLAGS.scorer_path:
         scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                         FLAGS.scorer_path, Config.alphabet)
        del scorer
 
+    if FLAGS.train_files and FLAGS.test_files and FLAGS.load_checkpoint_dir != FLAGS.save_checkpoint_dir:
+        log_warn('WARNING: You specified different values for --load_checkpoint_dir '
+                 'and --save_checkpoint_dir, but you are running training and testing '
+                 'in a single invocation. The testing step will respect --load_checkpoint_dir, '
+                 'and thus WILL NOT TEST THE CHECKPOINT CREATED BY THE TRAINING STEP. '
+                 'Train and test in two separate invocations, specifying the correct '
+                 '--load_checkpoint_dir in both cases, or use the same location '
+                 'for loading and saving.')
+
 
 def main(_):
     initialize_globals()
-    early_checks()
+    early_training_checks()
 
     if FLAGS.train_files:
         tfv1.reset_default_graph()
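The warning added above recommends separate invocations; a minimal sketch of that workflow (checkpoint directories are placeholders):

    # Invocation 1: train only, loading a base model and saving elsewhere
    python -u DeepSpeech.py --train_files train.csv \
        --load_checkpoint_dir /ckpt/base --save_checkpoint_dir /ckpt/finetuned

    # Invocation 2: test only, loading the checkpoint the training step just wrote
    python -u DeepSpeech.py --test_files test.csv \
        --load_checkpoint_dir /ckpt/finetuned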
@@ -87,13 +87,7 @@ def _initialize_all_variables(session):
         session.run(v.initializer)
 
 
-def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
-    '''
-    Load variables from checkpoint or initialize variables following the method
-    order specified in the method_order parameter.
-
-    Valid methods are 'best', 'last' and 'init'.
-    '''
+def _load_or_init_impl(session, method_order, allow_drop_layers):
     for method in method_order:
         # Load best validating checkpoint, saved in checkpoint file 'best_dev_checkpoint'
         if method == 'best':
@@ -124,12 +118,29 @@ def load_or_init_graph_for_training(session, method_order, allow_drop_layers=True):
             sys.exit(1)
 
 
-def load_graph(session, method_order):
+def load_or_init_graph_for_training(session):
     '''
-    Load variables from checkpoint. Initialization is not allowed. Follows the
-    method order specified in the method_order parameter.
+    Load variables from checkpoint or initialize variables. By default this will
+    try to load the best validating checkpoint, then try the last checkpoint,
+    and finally initialize the weights from scratch. This can be overridden with
+    the `--load_train` flag. See its documentation for more info.
+    '''
+    if FLAGS.load_train == 'auto':
+        methods = ['best', 'last', 'init']
+    else:
+        methods = [FLAGS.load_train]
+    _load_or_init_impl(session, methods, allow_drop_layers=True)
 
-    Valid methods are 'best' and 'last'.
+
+def load_graph_for_evaluation(session):
+    '''
+    Load variables from checkpoint. Initialization is not allowed. By default
+    this will try to load the best validating checkpoint, then try the last
+    checkpoint. This can be overridden with the `--load_evaluate` flag. See its
+    documentation for more info.
     '''
-    assert('init' not in method_order)
-    load_or_init_graph_for_training(session, method_order, allow_drop_layers=False)
+    if FLAGS.load_evaluate == 'auto':
+        methods = ['best', 'last']
+    else:
+        methods = [FLAGS.load_evaluate]
+    _load_or_init_impl(session, methods, allow_drop_layers=False)
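For illustration, how the two wrappers surface on the command line (paths are hypothetical): --load_train 'auto' tries 'best', then 'last', then 'init', while --load_evaluate 'auto' tries only 'best' then 'last' and fails rather than initializing.

    # Force training to start from freshly initialized weights
    python -u DeepSpeech.py --train_files train.csv --load_train 'init' \
        --save_checkpoint_dir /ckpt/run1

    # Evaluate strictly from the best-validating checkpoint
    python -u evaluate.py --test_files test.csv --load_evaluate 'best' \
        --load_checkpoint_dir /ckpt/run1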
@@ -10,7 +10,7 @@ from xdg import BaseDirectory as xdg
 
 from .flags import FLAGS
 from .gpu import get_available_gpus
-from .logging import log_error
+from .logging import log_error, log_warn
 from .text import Alphabet, UTF8Alphabet
 from .helpers import parse_file_size
 
@@ -45,8 +45,11 @@ def initialize_globals():
     if not FLAGS.checkpoint_dir:
         FLAGS.checkpoint_dir = xdg.save_data_path(os.path.join('deepspeech', 'checkpoints'))
 
-    if FLAGS.load not in ['last', 'best', 'init', 'auto']:
-        FLAGS.load = 'auto'
+    if FLAGS.load_train not in ['last', 'best', 'init', 'auto']:
+        FLAGS.load_train = 'auto'
+
+    if FLAGS.load_evaluate not in ['last', 'best', 'auto']:
+        FLAGS.load_evaluate = 'auto'
 
     # Set default summary dir
     if not FLAGS.summary_dir:
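Note that in the lines shown here an unrecognized value is coerced to 'auto' rather than rejected; for example, a misspelled method simply falls back to the default behavior:

    # 'bets' is not a valid method, so initialize_globals() resets it to 'auto'
    python -u DeepSpeech.py --train_files train.csv --load_train 'bets'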
@@ -101,7 +101,8 @@ def create_flags():
     f.DEFINE_string('save_checkpoint_dir', '', 'directory to which checkpoints are saved - defaults to directory "deepspeech/checkpoints" within user\'s data home specified by the XDG Base Directory Specification')
     f.DEFINE_integer('checkpoint_secs', 600, 'checkpoint saving interval in seconds')
     f.DEFINE_integer('max_to_keep', 5, 'number of checkpoint files to keep - default value is 5')
-    f.DEFINE_string('load', 'auto', '"last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a fresh model, "transfer" for transfer learning, "auto" for trying several options.')
+    f.DEFINE_string('load_train', 'auto', 'what checkpoint to load before starting the training process. "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "init" for initializing a new checkpoint, "auto" for trying several options.')
+    f.DEFINE_string('load_evaluate', 'auto', 'what checkpoint to load for evaluation tasks (test epochs, model export, single file inference, etc). "last" for loading most recent epoch checkpoint, "best" for loading best validation loss checkpoint, "auto" for trying several options.')
 
     # Transfer Learning
 
@@ -29,7 +29,7 @@ def fail(message, code=1):
 
 def transcribe_file(audio_path, tlog_path):
     from deepspeech_training.train import create_model  # pylint: disable=cyclic-import,import-outside-toplevel
-    from deepspeech_training.util.checkpoints import load_graph
+    from deepspeech_training.util.checkpoints import load_graph_for_evaluation
     initialize_globals()
     scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
     try:
@@ -50,11 +50,7 @@ def transcribe_file(audio_path, tlog_path):
     transposed = tf.nn.softmax(tf.transpose(logits, [1, 0, 2]))
     tf.train.get_or_create_global_step()
     with tf.Session(config=Config.session_config) as session:
-        if FLAGS.load == 'auto':
-            method_order = ['best', 'last']
-        else:
-            method_order = [FLAGS.load]
-        load_graph(session, method_order)
+        load_graph_for_evaluation(session)
         session.run(iterator.make_initializer(data_set))
         transcripts = []
         while True: