diff --git a/examples/ffmpeg_vad_streaming/index.js b/examples/ffmpeg_vad_streaming/index.js index 37f6d871..b209d3b9 100644 --- a/examples/ffmpeg_vad_streaming/index.js +++ b/examples/ffmpeg_vad_streaming/index.js @@ -11,11 +11,10 @@ const util = require('util'); const BEAM_WIDTH = 1024; // The alpha hyperparameter of the CTC decoder. Language Model weight -const LM_WEIGHT = 1.50; +const LM_ALPHA = 0.75; -// Valid word insertion weight. This is used to lessen the word insertion penalty -// when the inserted word is part of the vocabulary -const VALID_WORD_COUNT_WEIGHT = 2.25; +// The beta hyperparameter of the CTC decoder. Word insertion bonus. +const LM_BETA = 1.85; // These constants are tied to the shape of the graph used (changing them changes // the geometry of the first layer), so make sure you use the same constants that @@ -63,7 +62,7 @@ if (args['lm'] && args['trie']) { console.error('Loading language model from files %s %s', args['lm'], args['trie']); const lm_load_start = process.hrtime(); model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'], - LM_WEIGHT, VALID_WORD_COUNT_WEIGHT); + LM_ALPHA, LM_BETA); const lm_load_end = process.hrtime(lm_load_start); console.error('Loaded language model in %ds.', totalTime(lm_load_end)); } diff --git a/examples/mic_vad_streaming/README.md b/examples/mic_vad_streaming/README.md index 211a9b61..2ac40be0 100644 --- a/examples/mic_vad_streaming/README.md +++ b/examples/mic_vad_streaming/README.md @@ -26,7 +26,7 @@ brew install portaudio usage: mic_vad_streaming.py [-h] [-v VAD_AGGRESSIVENESS] [--nospinner] [-w SAVEWAV] -m MODEL [-a ALPHABET] [-l LM] [-t TRIE] [-nf N_FEATURES] [-nc N_CONTEXT] - [-lw LM_WEIGHT] [-vwcw VALID_WORD_COUNT_WEIGHT] + [-la LM_ALPHA] [-lb LM_BETA] [-bw BEAM_WIDTH] Stream from microphone to DeepSpeech using VAD @@ -56,13 +56,12 @@ optional arguments: -nc N_CONTEXT, --n_context N_CONTEXT Size of the context window used for producing timesteps in the input vector. Default: 9 - -lw LM_WEIGHT, --lm_weight LM_WEIGHT + -la LM_ALPHA, --lm_alpha LM_ALPHA The alpha hyperparameter of the CTC decoder. Language - Model weight. Default: 1.5 - -vwcw VALID_WORD_COUNT_WEIGHT, --valid_word_count_weight VALID_WORD_COUNT_WEIGHT - Valid word insertion weight. This is used to lessen - the word insertion penalty when the inserted word is - part of the vocabulary. Default: 2.1 + Model weight. Default: 0.75 + -lb LM_BETA, --lm_beta LM_BETA + The beta hyperparameter of the CTC decoder. Word insertion + bonus. Default: 1.85 -bw BEAM_WIDTH, --beam_width BEAM_WIDTH Beam width used in the CTC decoder when building candidate transcriptions. Default: 500 diff --git a/examples/mic_vad_streaming/mic_vad_streaming.py b/examples/mic_vad_streaming/mic_vad_streaming.py index 4a7244c8..259d4fef 100644 --- a/examples/mic_vad_streaming/mic_vad_streaming.py +++ b/examples/mic_vad_streaming/mic_vad_streaming.py @@ -118,7 +118,7 @@ def main(ARGS): if ARGS.lm and ARGS.trie: logging.info("ARGS.lm: %s", ARGS.lm) logging.info("ARGS.trie: %s", ARGS.trie) - model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_weight, ARGS.valid_word_count_weight) + model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_alpha, ARGS.lm_beta) # Start audio with VAD vad_audio = VADAudio(aggressiveness=ARGS.vad_aggressiveness) @@ -148,8 +148,8 @@ def main(ARGS): if __name__ == '__main__': BEAM_WIDTH = 500 - LM_WEIGHT = 1.50 - VALID_WORD_COUNT_WEIGHT = 2.10 + LM_ALPHA = 0.75 + LM_BETA = 1.85 N_FEATURES = 26 N_CONTEXT = 9 @@ -175,10 +175,10 @@ if __name__ == '__main__': help=f"Number of MFCC features to use. Default: {N_FEATURES}") parser.add_argument('-nc', '--n_context', type=int, default=N_CONTEXT, help=f"Size of the context window used for producing timesteps in the input vector. Default: {N_CONTEXT}") - parser.add_argument('-lw', '--lm_weight', type=float, default=LM_WEIGHT, - help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_WEIGHT}") - parser.add_argument('-vwcw', '--valid_word_count_weight', type=float, default=VALID_WORD_COUNT_WEIGHT, - help=f"Valid word insertion weight. This is used to lessen the word insertion penalty when the inserted word is part of the vocabulary. Default: {VALID_WORD_COUNT_WEIGHT}") + parser.add_argument('-la', '--lm_alpha', type=float, default=LM_ALPHA, + help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_ALPHA}") + parser.add_argument('-lb', '--lm_beta', type=float, default=LM_BETA, + help=f"The beta hyperparameter of the CTC decoder. Word insertion bonus. Default: {LM_BETA}") parser.add_argument('-bw', '--beam_width', type=int, default=BEAM_WIDTH, help=f"Beam width used in the CTC decoder when building candidate transcriptions. Default: {BEAM_WIDTH}") diff --git a/examples/net_framework/CSharpExamples/DeepSpeechConsole/Program.cs b/examples/net_framework/CSharpExamples/DeepSpeechConsole/Program.cs index 6cfa554b..f7c0684c 100644 --- a/examples/net_framework/CSharpExamples/DeepSpeechConsole/Program.cs +++ b/examples/net_framework/CSharpExamples/DeepSpeechConsole/Program.cs @@ -38,8 +38,8 @@ namespace CSharpExamples const uint N_CEP = 26; const uint N_CONTEXT = 9; const uint BEAM_WIDTH = 200; - const float LM_WEIGHT = 1.50f; - const float VALID_WORD_COUNT_WEIGHT = 2.10f; + const float LM_ALPHA = 0.75f; + const float LM_BETA = 1.85f; Stopwatch stopwatch = new Stopwatch(); @@ -76,7 +76,7 @@ namespace CSharpExamples alphabet ?? "alphabet.txt", lm ?? "lm.binary", trie ?? "trie", - LM_WEIGHT, VALID_WORD_COUNT_WEIGHT); + LM_ALPHA, LM_BETA); } catch (IOException ex) { diff --git a/examples/net_framework/CSharpExamples/DeepSpeechWPF/MainWindow.xaml.cs b/examples/net_framework/CSharpExamples/DeepSpeechWPF/MainWindow.xaml.cs index b59d34a9..b0e2a74e 100644 --- a/examples/net_framework/CSharpExamples/DeepSpeechWPF/MainWindow.xaml.cs +++ b/examples/net_framework/CSharpExamples/DeepSpeechWPF/MainWindow.xaml.cs @@ -24,8 +24,8 @@ namespace DeepSpeechWPF private const uint N_CEP = 26; private const uint N_CONTEXT = 9; private const uint BEAM_WIDTH = 500; - private const float LM_WEIGHT = 1.50f; - private const float VALID_WORD_COUNT_WEIGHT = 2.10f; + private const float LM_ALPHA = 0.75f; + private const float LM_BETA = 1.85f; @@ -160,7 +160,7 @@ namespace DeepSpeechWPF { try { - if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_WEIGHT, VALID_WORD_COUNT_WEIGHT) != 0) + if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_ALPHA, LM_BETA) != 0) { MessageBox.Show("Error loading LM."); Dispatcher.Invoke(() => btnEnableLM.IsEnabled = true); diff --git a/examples/vad_transcriber/wavTranscriber.py b/examples/vad_transcriber/wavTranscriber.py index b3f0d272..2735879f 100644 --- a/examples/vad_transcriber/wavTranscriber.py +++ b/examples/vad_transcriber/wavTranscriber.py @@ -19,8 +19,8 @@ def load_model(models, alphabet, lm, trie): N_FEATURES = 26 N_CONTEXT = 9 BEAM_WIDTH = 500 - LM_WEIGHT = 1.50 - VALID_WORD_COUNT_WEIGHT = 2.10 + LM_ALPHA = 0.75 + LM_BETA = 1.85 model_load_start = timer() ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH) @@ -28,7 +28,7 @@ def load_model(models, alphabet, lm, trie): logging.debug("Loaded model in %0.3fs." % (model_load_end)) lm_load_start = timer() - ds.enableDecoderWithLM(alphabet, lm, trie, LM_WEIGHT, VALID_WORD_COUNT_WEIGHT) + ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA) lm_load_end = timer() - lm_load_start logging.debug('Loaded language model in %0.3fs.' % (lm_load_end)) diff --git a/native_client/client.cc b/native_client/client.cc index ca8c9f94..f8be754d 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -24,8 +24,8 @@ #define N_CEP 26 #define N_CONTEXT 9 #define BEAM_WIDTH 500 -#define LM_WEIGHT 1.50f -#define VALID_WORD_COUNT_WEIGHT 2.10f +#define LM_ALPHA 0.75f +#define LM_BETA 1.85f typedef struct { const char* string; @@ -253,8 +253,8 @@ main(int argc, char **argv) alphabet, lm, trie, - LM_WEIGHT, - VALID_WORD_COUNT_WEIGHT); + LM_ALPHA, + LM_BETA); if (status != 0) { fprintf(stderr, "Could not enable CTC decoder with LM.\n"); return 1; diff --git a/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/DeepSpeechActivity.java b/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/DeepSpeechActivity.java index 0635f604..159403cd 100644 --- a/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/DeepSpeechActivity.java +++ b/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/DeepSpeechActivity.java @@ -38,8 +38,8 @@ public class DeepSpeechActivity extends AppCompatActivity { final int N_CEP = 26; final int N_CONTEXT = 9; final int BEAM_WIDTH = 50; - final float LM_WEIGHT = 1.50f; - final float VALID_WORD_COUNT_WEIGHT = 2.10f; + final float LM_ALPHA = 0.75f; + final float LM_BETA = 1.85f; private char readLEChar(RandomAccessFile f) throws IOException { byte b1 = f.readByte(); diff --git a/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/Model.java b/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/Model.java index 303f03d7..a3b147c6 100644 --- a/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/Model.java +++ b/native_client/java/app/src/main/java/deepspeech/mozilla/org/deepspeech/Model.java @@ -16,8 +16,8 @@ public class Model { impl.DestroyModel(this._msp); } - public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_weight, float valid_word_count_weight) { - impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_weight, valid_word_count_weight); + public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_alpha, float lm_beta) { + impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_alpha, lm_beta); } public String stt(short[] buffer, int buffer_size, int sample_rate) { diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js index 1e92379f..a7a72de3 100644 --- a/native_client/javascript/client.js +++ b/native_client/javascript/client.js @@ -15,11 +15,10 @@ const util = require('util'); const BEAM_WIDTH = 500; // The alpha hyperparameter of the CTC decoder. Language Model weight -const LM_WEIGHT = 1.50; +const LM_ALPHA = 0.75; -// Valid word insertion weight. This is used to lessen the word insertion penalty -// when the inserted word is part of the vocabulary -const VALID_WORD_COUNT_WEIGHT = 2.10; +// The beta hyperparameter of the CTC decoder. Word insertion bonus. +const LM_BETA = 1.85; // These constants are tied to the shape of the graph used (changing them changes @@ -102,7 +101,7 @@ audioStream.on('finish', () => { console.error('Loading language model from files %s %s', args['lm'], args['trie']); const lm_load_start = process.hrtime(); model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'], - LM_WEIGHT, VALID_WORD_COUNT_WEIGHT); + LM_ALPHA, LM_BETA); const lm_load_end = process.hrtime(lm_load_start); console.error('Loaded language model in %ds.', totalTime(lm_load_end)); } diff --git a/native_client/python/client.py b/native_client/python/client.py index 60e8097c..4db4c7a9 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -23,11 +23,10 @@ except ImportError: BEAM_WIDTH = 500 # The alpha hyperparameter of the CTC decoder. Language Model weight -LM_WEIGHT = 1.50 +LM_ALPHA = 0.75 -# Valid word insertion weight. This is used to lessen the word insertion penalty -# when the inserted word is part of the vocabulary -VALID_WORD_COUNT_WEIGHT = 2.10 +# The beta hyperparameter of the CTC decoder. Word insertion bonus. +LM_BETA = 1.85 # These constants are tied to the shape of the graph used (changing them changes @@ -85,8 +84,7 @@ def main(): if args.lm and args.trie: print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr) lm_load_start = timer() - ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_WEIGHT, - VALID_WORD_COUNT_WEIGHT) + ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_ALPHA, LM_BETA) lm_load_end = timer() - lm_load_start print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr) diff --git a/tc-tests-utils.sh b/tc-tests-utils.sh index c6685bfa..efb86859 100755 --- a/tc-tests-utils.sh +++ b/tc-tests-utils.sh @@ -151,12 +151,12 @@ assert_correct_multi_ldc93s1() assert_correct_ldc93s1_prodmodel() { - assert_correct_inference "$1" "she had a due in greasy wash water year" + assert_correct_inference "$1" "she had a due and greasy wash water year" } assert_correct_ldc93s1_prodmodel_stereo_44k() { - assert_correct_inference "$1" "she had a due in greasy wash water year" + assert_correct_inference "$1" "she had a due and greasy wash water year" } assert_correct_warning_upsampling() diff --git a/util/flags.py b/util/flags.py index 4925b733..dc8e6a3e 100644 --- a/util/flags.py +++ b/util/flags.py @@ -131,8 +131,8 @@ def create_flags(): tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM') tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie') tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions') - tf.app.flags.DEFINE_float ('lm_alpha', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') - tf.app.flags.DEFINE_float ('lm_beta', 2.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight.') + tf.app.flags.DEFINE_float ('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') + tf.app.flags.DEFINE_float ('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.') # Inference mode