Check if train/dev/test files were passed in instead of having explicit flags
This commit is contained in:
parent
232df740db
commit
ed15caf3c5
@ -376,7 +376,7 @@ def train():
|
||||
# Make initialization ops for switching between the two sets
|
||||
train_init_op = iterator.make_initializer(train_set)
|
||||
|
||||
if FLAGS.dev:
|
||||
if FLAGS.dev_files:
|
||||
dev_set, dev_batches = create_dataset(FLAGS.dev_files.split(','),
|
||||
batch_size=FLAGS.dev_batch_size,
|
||||
cache_path=FLAGS.dev_cached_features_path)
|
||||
@ -508,7 +508,7 @@ def train():
|
||||
log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss))
|
||||
checkpoint_saver.save(session, checkpoint_path, global_step=global_step)
|
||||
|
||||
if FLAGS.dev:
|
||||
if FLAGS.dev_files:
|
||||
# Validation
|
||||
log_info('Validating epoch %d ...' % current_epoch)
|
||||
dev_loss = run_set('dev', dev_init_op, dev_batches)
|
||||
@ -794,12 +794,12 @@ def do_single_file_inference(input_file_path):
|
||||
def main(_):
|
||||
initialize_globals()
|
||||
|
||||
if FLAGS.train:
|
||||
if FLAGS.train_files:
|
||||
tf.reset_default_graph()
|
||||
tf.set_random_seed(FLAGS.random_seed)
|
||||
train()
|
||||
|
||||
if FLAGS.test:
|
||||
if FLAGS.test_files:
|
||||
tf.reset_default_graph()
|
||||
test()
|
||||
|
||||
@ -807,7 +807,7 @@ def main(_):
|
||||
tf.reset_default_graph()
|
||||
export()
|
||||
|
||||
if len(FLAGS.one_shot_infer):
|
||||
if FLAGS.one_shot_infer:
|
||||
tf.reset_default_graph()
|
||||
do_single_file_inference(FLAGS.one_shot_infer)
|
||||
|
||||
|
||||
@ -340,7 +340,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information
|
||||
|
||||
### Exporting a model for TFLite
|
||||
|
||||
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --nouse_seq_length --export_tflite --export_dir /model/export/destination`.
|
||||
If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--nouse_seq_length --export_tflite --export_dir /model/export/destination`.
|
||||
|
||||
### Making a mmap-able model for inference
|
||||
|
||||
|
||||
@ -16,12 +16,10 @@ else
|
||||
checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/ldc93s1"))')
|
||||
fi
|
||||
|
||||
python -u DeepSpeech.py --noshow_progressbar --nodev \
|
||||
python -u DeepSpeech.py --noshow_progressbar \
|
||||
--train_files data/ldc93s1/ldc93s1.csv \
|
||||
--dev_files data/ldc93s1/ldc93s1.csv \
|
||||
--test_files data/ldc93s1/ldc93s1.csv \
|
||||
--train_batch_size 1 \
|
||||
--dev_batch_size 1 \
|
||||
--test_batch_size 1 \
|
||||
--n_hidden 100 \
|
||||
--epoch 200 \
|
||||
|
||||
@ -20,13 +20,9 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie'
|
||||
|
||||
python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||
--train_files ${ldc93s1_csv} --train_batch_size 1 \
|
||||
--dev_files ${ldc93s1_csv} --dev_batch_size 1 \
|
||||
--test_files ${ldc93s1_csv} --test_batch_size 1 \
|
||||
--n_hidden 100 --epoch 1 \
|
||||
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \
|
||||
--learning_rate 0.001 --dropout_rate 0.05 \
|
||||
python -u DeepSpeech.py \
|
||||
--n_hidden 100 \
|
||||
--checkpoint_dir '/tmp/ckpt' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--one_shot_infer 'data/smoke_test/LDC93S1.wav'
|
||||
|
||||
@ -16,5 +16,4 @@ python -u DeepSpeech.py --noshow_progressbar \
|
||||
--export_dir '/tmp/train_tflite' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--notrain --notest \
|
||||
--export_tflite --nouse_seq_length
|
||||
|
||||
@ -92,10 +92,7 @@ def initialize_globals():
|
||||
# Units in the sixth layer = number of characters in the target language plus one
|
||||
c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
|
||||
|
||||
if len(FLAGS.one_shot_infer) > 0:
|
||||
FLAGS.train = False
|
||||
FLAGS.test = False
|
||||
FLAGS.export_dir = ''
|
||||
if FLAGS.one_shot_infer:
|
||||
if not os.path.exists(FLAGS.one_shot_infer):
|
||||
log_error('Path specified in --one_shot_infer is not a valid file.')
|
||||
exit(1)
|
||||
|
||||
@ -10,9 +10,9 @@ def create_flags():
|
||||
# Importer
|
||||
# ========
|
||||
|
||||
tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
|
||||
tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
|
||||
tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
|
||||
tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. If empty, training will not be run.')
|
||||
tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.')
|
||||
tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.')
|
||||
tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')
|
||||
|
||||
tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
|
||||
@ -22,9 +22,6 @@ def create_flags():
|
||||
# Global Constants
|
||||
# ================
|
||||
|
||||
tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network')
|
||||
tf.app.flags.DEFINE_boolean ('dev', True, 'whether to run validation epochs')
|
||||
tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network')
|
||||
tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained')
|
||||
|
||||
tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers')
|
||||
@ -113,5 +110,5 @@ def create_flags():
|
||||
|
||||
# Inference mode
|
||||
|
||||
tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.')
|
||||
tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user