diff --git a/DeepSpeech.py b/DeepSpeech.py index 28af8de7..2bd24231 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -376,7 +376,7 @@ def train(): # Make initialization ops for switching between the two sets train_init_op = iterator.make_initializer(train_set) - if FLAGS.dev: + if FLAGS.dev_files: dev_set, dev_batches = create_dataset(FLAGS.dev_files.split(','), batch_size=FLAGS.dev_batch_size, cache_path=FLAGS.dev_cached_features_path) @@ -508,7 +508,7 @@ def train(): log_info('Finished training epoch %d - loss: %f' % (current_epoch, train_loss)) checkpoint_saver.save(session, checkpoint_path, global_step=global_step) - if FLAGS.dev: + if FLAGS.dev_files: # Validation log_info('Validating epoch %d ...' % current_epoch) dev_loss = run_set('dev', dev_init_op, dev_batches) @@ -794,12 +794,12 @@ def do_single_file_inference(input_file_path): def main(_): initialize_globals() - if FLAGS.train: + if FLAGS.train_files: tf.reset_default_graph() tf.set_random_seed(FLAGS.random_seed) train() - if FLAGS.test: + if FLAGS.test_files: tf.reset_default_graph() test() @@ -807,7 +807,7 @@ def main(_): tf.reset_default_graph() export() - if len(FLAGS.one_shot_infer): + if FLAGS.one_shot_infer: tf.reset_default_graph() do_single_file_inference(FLAGS.one_shot_infer) diff --git a/README.md b/README.md index cc8f5ed8..ab62fe26 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,7 @@ Refer to the corresponding [README.md](native_client/README.md) for information ### Exporting a model for TFLite -If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it, then use the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--notrain --notest --nouse_seq_length --export_tflite --export_dir /model/export/destination`. 
+If you want to experiment with the TF Lite engine, you need to export a model that is compatible with it by passing the `--nouse_seq_length --export_tflite` flags. If you already have a trained model, you can re-export it for TFLite by running `DeepSpeech.py` again and specifying the same `checkpoint_dir` that you used for training, as well as passing `--nouse_seq_length --export_tflite --export_dir /model/export/destination`. Leave `--train_files` and `--test_files` unset so that training and testing are skipped. ### Making a mmap-able model for inference diff --git a/bin/run-ldc93s1.sh b/bin/run-ldc93s1.sh index a50d1dfb..4a6527e8 100755 --- a/bin/run-ldc93s1.sh +++ b/bin/run-ldc93s1.sh @@ -16,12 +16,10 @@ else checkpoint_dir=$(python -c 'from xdg import BaseDirectory as xdg; print(xdg.save_data_path("deepspeech/ldc93s1"))') fi -python -u DeepSpeech.py --noshow_progressbar --nodev \ +python -u DeepSpeech.py --noshow_progressbar \ --train_files data/ldc93s1/ldc93s1.csv \ - --dev_files data/ldc93s1/ldc93s1.csv \ --test_files data/ldc93s1/ldc93s1.csv \ --train_batch_size 1 \ - --dev_batch_size 1 \ --test_batch_size 1 \ --n_hidden 100 \ --epoch 200 \ diff --git a/bin/run-tc-ldc93s1_singleshotinference.sh b/bin/run-tc-ldc93s1_singleshotinference.sh index 25e04f7d..07a4aab1 100755 --- a/bin/run-tc-ldc93s1_singleshotinference.sh +++ b/bin/run-tc-ldc93s1_singleshotinference.sh @@ -20,13 +20,9 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --lm_trie_path 'data/smoke_test/vocab.trie' -python -u DeepSpeech.py --noshow_progressbar --noearly_stop \ - --train_files ${ldc93s1_csv} --train_batch_size 1 \ - --dev_files ${ldc93s1_csv} --dev_batch_size 1 \ - --test_files ${ldc93s1_csv} --test_batch_size 1 \ - --n_hidden 100 --epoch 1 \ - --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \ - --learning_rate 0.001 --dropout_rate 0.05 \ +python -u DeepSpeech.py \ + --n_hidden 100 \ + --checkpoint_dir '/tmp/ckpt' \ --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --lm_trie_path 
'data/smoke_test/vocab.trie' \ --one_shot_infer 'data/smoke_test/LDC93S1.wav' diff --git a/bin/run-tc-ldc93s1_tflite.sh b/bin/run-tc-ldc93s1_tflite.sh index d99b4b59..e2e8fb61 100755 --- a/bin/run-tc-ldc93s1_tflite.sh +++ b/bin/run-tc-ldc93s1_tflite.sh @@ -16,5 +16,4 @@ python -u DeepSpeech.py --noshow_progressbar \ --export_dir '/tmp/train_tflite' \ --lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --lm_trie_path 'data/smoke_test/vocab.trie' \ - --notrain --notest \ --export_tflite --nouse_seq_length diff --git a/util/config.py b/util/config.py index 378bb8f2..b7765451 100644 --- a/util/config.py +++ b/util/config.py @@ -92,10 +92,7 @@ def initialize_globals(): # Units in the sixth layer = number of characters in the target language plus one c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label - if len(FLAGS.one_shot_infer) > 0: - FLAGS.train = False - FLAGS.test = False - FLAGS.export_dir = '' + if FLAGS.one_shot_infer: if not os.path.exists(FLAGS.one_shot_infer): log_error('Path specified in --one_shot_infer is not a valid file.') exit(1) diff --git a/util/flags.py b/util/flags.py index 25cc9b8d..0dcc7ab7 100644 --- a/util/flags.py +++ b/util/flags.py @@ -10,9 +10,9 @@ def create_flags(): # Importer # ======== - tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged') - tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged') - tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged') + tf.app.flags.DEFINE_string ('train_files', '', 'comma separated list of files specifying the dataset used for training. Multiple files will get merged. 
If empty, training will not be run.') + tf.app.flags.DEFINE_string ('dev_files', '', 'comma separated list of files specifying the dataset used for validation. Multiple files will get merged. If empty, validation will not be run.') + tf.app.flags.DEFINE_string ('test_files', '', 'comma separated list of files specifying the dataset used for testing. Multiple files will get merged. If empty, the model will not be tested.') tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training') tf.app.flags.DEFINE_string ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged') @@ -22,9 +22,6 @@ def create_flags(): # Global Constants # ================ - tf.app.flags.DEFINE_boolean ('train', True, 'whether to train the network') - tf.app.flags.DEFINE_boolean ('dev', True, 'whether to run validation epochs') - tf.app.flags.DEFINE_boolean ('test', True, 'whether to test the network') tf.app.flags.DEFINE_integer ('epoch', 75, 'target epoch to train - if negative, the absolute number of additional epochs will be trained') tf.app.flags.DEFINE_float ('dropout_rate', 0.05, 'dropout rate for feedforward layers') @@ -113,5 +110,5 @@ def create_flags(): # Inference mode - tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it. Disables training, testing and exporting.') + tf.app.flags.DEFINE_string ('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')