Check if train/dev/test files were passed in instead of having explicit flags

Reuben Morais 2019-04-04 18:33:08 -03:00
parent 232df740db
commit ed15caf3c5
7 changed files with 15 additions and 28 deletions
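
In effect, the per-stage boolean flags (--train, --dev, --test) are replaced by emptiness checks on --train_files, --dev_files and --test_files: a stage runs only when its file list is non-empty. A sketch of the difference, with illustrative dataset and checkpoint paths:

# Before this commit: stages had to be disabled explicitly.
python -u DeepSpeech.py --notrain --nodev \
  --test_files data/ldc93s1/ldc93s1.csv --test_batch_size 1 \
  --checkpoint_dir '/tmp/ckpt'

# After this commit: omitting --train_files and --dev_files skips those stages.
python -u DeepSpeech.py \
  --test_files data/ldc93s1/ldc93s1.csv --test_batch_size 1 \
  --checkpoint_dir '/tmp/ckpt'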

View File

@@ -376,7 +376,7 @@ def train():
     # Make initialization ops for switching between the two sets
     train_init_op = iterator.make_initializer(train_set)
 
-    if FLAGS.dev:
+    if FLAGS.dev_files:
         dev_set, dev_batches = create_dataset(FLAGS.dev_files.split(','),
                                               batch_size=FLAGS.dev_batch_size,
                                               cache_path=FLAGS.dev_cached_features_path)
@@ -508,7 +508,7 @@ def train():
            log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
            checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
 
-            if FLAGS.dev:
+            if FLAGS.dev_files:
                # Validation
                log_info('Validating epoch %d ...' % current_epoch)
                dev_loss = run_set('dev', dev_init_op, dev_batches)
@@ -794,12 +794,12 @@ def do_single_file_inference(input_file_path):
 def main(_):
     initialize_globals()
 
-    if FLAGS.train:
+    if FLAGS.train_files:
         tf.reset_default_graph()
         tf.set_random_seed(FLAGS.random_seed)
         train()
 
-    if FLAGS.test:
+    if FLAGS.test_files:
         tf.reset_default_graph()
         test()
@@ -807,7 +807,7 @@ def main(_):
         tf.reset_default_graph()
         export()
 
-    if len(FLAGS.one_shot_infer):
+    if FLAGS.one_shot_infer:
         tf.reset_default_graph()
         do_single_file_inference(FLAGS.one_shot_infer)

View File

@@ -340,7 +340,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information
 ### Exporting a model for TFLite
 
-If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --nouse_seq_length --export_tflite --export_dir /model/export/destination`.
+If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--nouse_seq_length --export_tflite --export_dir /model/export/destination`.
 
 ### Making a mmap-able model for inference
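
Spelled out, the re-export described in the changed paragraph amounts to an invocation like this sketch (both paths are placeholders):

# checkpoint_dir must be the same one used for training; paths are hypothetical
python -u DeepSpeech.py \
  --checkpoint_dir /model/checkpoints \
  --nouse_seq_length --export_tflite \
  --export_dir /model/export/destination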

View File

@@ -16,12 +16,10 @@ else
   checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/ldc93s1"))')
 fi
 
-python -u DeepSpeech.py --noshow_progressbar --nodev \
+python -u DeepSpeech.py --noshow_progressbar \
   --train_files data/ldc93s1/ldc93s1.csv \
-  --dev_files data/ldc93s1/ldc93s1.csv \
   --test_files data/ldc93s1/ldc93s1.csv \
   --train_batch_size 1 \
-  --dev_batch_size 1 \
   --test_batch_size 1 \
   --n_hidden 100 \
   --epoch 200 \

View File

@@ -20,13 +20,9 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
   --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
   --lm_trie_path 'data/smoke_test/vocab.trie'
 
-python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
-  --train_files ${ldc93s1_csv} --train_batch_size 1 \
-  --dev_files ${ldc93s1_csv} --dev_batch_size 1 \
-  --test_files ${ldc93s1_csv} --test_batch_size 1 \
-  --n_hidden 100 --epoch 1 \
-  --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \
-  --learning_rate 0.001 --dropout_rate 0.05 \
+python -u DeepSpeech.py \
+  --n_hidden 100 \
+  --checkpoint_dir '/tmp/ckpt' \
   --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
   --lm_trie_path 'data/smoke_test/vocab.trie' \
   --one_shot_infer 'data/smoke_test/LDC93S1.wav'

View File

@@ -16,5 +16,4 @@ python -u DeepSpeech.py --noshow_progressbar \
   --export_dir '/tmp/train_tflite' \
   --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
   --lm_trie_path 'data/smoke_test/vocab.trie' \
-  --notrain --notest \
   --export_tflite --nouse_seq_length

View File

@@ -92,10 +92,7 @@ def initialize_globals():
     # Units in the sixth layer = number of characters in the target language plus one
     c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
 
-    if len(FLAGS.one_shot_infer) > 0:
-        FLAGS.train = False
-        FLAGS.test = False
-        FLAGS.export_dir = ''
+    if FLAGS.one_shot_infer:
         if not os.path.exists(FLAGS.one_shot_infer):
             log_error('Path specified in --one_shot_infer is not a valid file.')
             exit(1)

View File

@@ -10,9 +10,9 @@ def create_flags():
     # Importer
     # ========
 
-    tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
-    tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
-    tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
+    tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
+    tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
+    tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
 
     tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')
     tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
@@ -22,9 +22,6 @@ def create_flags():
     # Global Constants
     # ================
 
-    tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network')
-    tf.app.flags.DEFINE_boolean ('dev', True, 'whether to run validation epochs')
-    tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network')
-
     tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')
     tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
@@ -113,5 +110,5 @@ def create_flags():
 
     # Inference mode
-    tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.')
+    tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')
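
With the boolean stage flags gone, --one_shot_infer no longer needs to disable anything explicitly. Mirroring the smoke test above, a minimal sketch (assuming a checkpoint already exists at the given path):

# Checkpoint path is illustrative; --n_hidden must match the trained model.
python -u DeepSpeech.py \
  --n_hidden 100 \
  --checkpoint_dir '/tmp/ckpt' \
  --one_shot_infer 'data/smoke_test/LDC93S1.wav'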