Merge pull request #1810 from mozilla/update-new-hparams

Update decoder hyperparameters
This commit is contained in:
Reuben Morais 2019-01-02 13:44:43 +00:00 committed by GitHub
commit fc46f4382d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 46 additions and 51 deletions

View File

@ -11,11 +11,10 @@ const util = require('util');
const BEAM_WIDTH = 1024; const BEAM_WIDTH = 1024;
// The alpha hyperparameter of the CTC decoder. Language Model weight // The alpha hyperparameter of the CTC decoder. Language Model weight
const LM_WEIGHT = 1.50; const LM_ALPHA = 0.75;
// Valid word insertion weight. This is used to lessen the word insertion penalty // The beta hyperparameter of the CTC decoder. Word insertion bonus.
// when the inserted word is part of the vocabulary const LM_BETA = 1.85;
const VALID_WORD_COUNT_WEIGHT = 2.25;
// These constants are tied to the shape of the graph used (changing them changes // These constants are tied to the shape of the graph used (changing them changes
// the geometry of the first layer), so make sure you use the same constants that // the geometry of the first layer), so make sure you use the same constants that
@ -63,7 +62,7 @@ if (args['lm'] && args['trie']) {
console.error('Loading language model from files %s %s', args['lm'], args['trie']); console.error('Loading language model from files %s %s', args['lm'], args['trie']);
const lm_load_start = process.hrtime(); const lm_load_start = process.hrtime();
model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'], model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'],
LM_WEIGHT, VALID_WORD_COUNT_WEIGHT); LM_ALPHA, LM_BETA);
const lm_load_end = process.hrtime(lm_load_start); const lm_load_end = process.hrtime(lm_load_start);
console.error('Loaded language model in %ds.', totalTime(lm_load_end)); console.error('Loaded language model in %ds.', totalTime(lm_load_end));
} }

View File

@ -26,7 +26,7 @@ brew install portaudio
usage: mic_vad_streaming.py [-h] [-v VAD_AGGRESSIVENESS] [--nospinner] usage: mic_vad_streaming.py [-h] [-v VAD_AGGRESSIVENESS] [--nospinner]
[-w SAVEWAV] -m MODEL [-a ALPHABET] [-l LM] [-w SAVEWAV] -m MODEL [-a ALPHABET] [-l LM]
[-t TRIE] [-nf N_FEATURES] [-nc N_CONTEXT] [-t TRIE] [-nf N_FEATURES] [-nc N_CONTEXT]
[-lw LM_WEIGHT] [-vwcw VALID_WORD_COUNT_WEIGHT] [-la LM_ALPHA] [-lb LM_BETA]
[-bw BEAM_WIDTH] [-bw BEAM_WIDTH]
Stream from microphone to DeepSpeech using VAD Stream from microphone to DeepSpeech using VAD
@ -56,13 +56,12 @@ optional arguments:
-nc N_CONTEXT, --n_context N_CONTEXT -nc N_CONTEXT, --n_context N_CONTEXT
Size of the context window used for producing Size of the context window used for producing
timesteps in the input vector. Default: 9 timesteps in the input vector. Default: 9
-lw LM_WEIGHT, --lm_weight LM_WEIGHT -la LM_ALPHA, --lm_alpha LM_ALPHA
The alpha hyperparameter of the CTC decoder. Language The alpha hyperparameter of the CTC decoder. Language
Model weight. Default: 1.5 Model weight. Default: 0.75
-vwcw VALID_WORD_COUNT_WEIGHT, --valid_word_count_weight VALID_WORD_COUNT_WEIGHT -lb LM_BETA, --lm_beta LM_BETA
Valid word insertion weight. This is used to lessen The beta hyperparameter of the CTC decoder. Word insertion
the word insertion penalty when the inserted word is bonus. Default: 1.85
part of the vocabulary. Default: 2.1
-bw BEAM_WIDTH, --beam_width BEAM_WIDTH -bw BEAM_WIDTH, --beam_width BEAM_WIDTH
Beam width used in the CTC decoder when building Beam width used in the CTC decoder when building
candidate transcriptions. Default: 500 candidate transcriptions. Default: 500

View File

@ -118,7 +118,7 @@ def main(ARGS):
if ARGS.lm and ARGS.trie: if ARGS.lm and ARGS.trie:
logging.info("ARGS.lm: %s", ARGS.lm) logging.info("ARGS.lm: %s", ARGS.lm)
logging.info("ARGS.trie: %s", ARGS.trie) logging.info("ARGS.trie: %s", ARGS.trie)
model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_weight, ARGS.valid_word_count_weight) model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_alpha, ARGS.lm_beta)
# Start audio with VAD # Start audio with VAD
vad_audio = VADAudio(aggressiveness=ARGS.vad_aggressiveness) vad_audio = VADAudio(aggressiveness=ARGS.vad_aggressiveness)
@ -148,8 +148,8 @@ def main(ARGS):
if __name__ == '__main__': if __name__ == '__main__':
BEAM_WIDTH = 500 BEAM_WIDTH = 500
LM_WEIGHT = 1.50 LM_ALPHA = 0.75
VALID_WORD_COUNT_WEIGHT = 2.10 LM_BETA = 1.85
N_FEATURES = 26 N_FEATURES = 26
N_CONTEXT = 9 N_CONTEXT = 9
@ -175,10 +175,10 @@ if __name__ == '__main__':
help=f"Number of MFCC features to use. Default: {N_FEATURES}") help=f"Number of MFCC features to use. Default: {N_FEATURES}")
parser.add_argument('-nc', '--n_context', type=int, default=N_CONTEXT, parser.add_argument('-nc', '--n_context', type=int, default=N_CONTEXT,
help=f"Size of the context window used for producing timesteps in the input vector. Default: {N_CONTEXT}") help=f"Size of the context window used for producing timesteps in the input vector. Default: {N_CONTEXT}")
parser.add_argument('-lw', '--lm_weight', type=float, default=LM_WEIGHT, parser.add_argument('-la', '--lm_alpha', type=float, default=LM_ALPHA,
help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_WEIGHT}") help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_ALPHA}")
parser.add_argument('-vwcw', '--valid_word_count_weight', type=float, default=VALID_WORD_COUNT_WEIGHT, parser.add_argument('-lb', '--lm_beta', type=float, default=LM_BETA,
help=f"Valid word insertion weight. This is used to lessen the word insertion penalty when the inserted word is part of the vocabulary. Default: {VALID_WORD_COUNT_WEIGHT}") help=f"The beta hyperparameter of the CTC decoder. Word insertion bonus. Default: {LM_BETA}")
parser.add_argument('-bw', '--beam_width', type=int, default=BEAM_WIDTH, parser.add_argument('-bw', '--beam_width', type=int, default=BEAM_WIDTH,
help=f"Beam width used in the CTC decoder when building candidate transcriptions. Default: {BEAM_WIDTH}") help=f"Beam width used in the CTC decoder when building candidate transcriptions. Default: {BEAM_WIDTH}")

View File

@ -38,8 +38,8 @@ namespace CSharpExamples
const uint N_CEP = 26; const uint N_CEP = 26;
const uint N_CONTEXT = 9; const uint N_CONTEXT = 9;
const uint BEAM_WIDTH = 200; const uint BEAM_WIDTH = 200;
const float LM_WEIGHT = 1.50f; const float LM_ALPHA = 0.75f;
const float VALID_WORD_COUNT_WEIGHT = 2.10f; const float LM_BETA = 1.85f;
Stopwatch stopwatch = new Stopwatch(); Stopwatch stopwatch = new Stopwatch();
@ -76,7 +76,7 @@ namespace CSharpExamples
alphabet ?? "alphabet.txt", alphabet ?? "alphabet.txt",
lm ?? "lm.binary", lm ?? "lm.binary",
trie ?? "trie", trie ?? "trie",
LM_WEIGHT, VALID_WORD_COUNT_WEIGHT); LM_ALPHA, LM_BETA);
} }
catch (IOException ex) catch (IOException ex)
{ {

View File

@ -24,8 +24,8 @@ namespace DeepSpeechWPF
private const uint N_CEP = 26; private const uint N_CEP = 26;
private const uint N_CONTEXT = 9; private const uint N_CONTEXT = 9;
private const uint BEAM_WIDTH = 500; private const uint BEAM_WIDTH = 500;
private const float LM_WEIGHT = 1.50f; private const float LM_ALPHA = 0.75f;
private const float VALID_WORD_COUNT_WEIGHT = 2.10f; private const float LM_BETA = 1.85f;
@ -160,7 +160,7 @@ namespace DeepSpeechWPF
{ {
try try
{ {
if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_WEIGHT, VALID_WORD_COUNT_WEIGHT) != 0) if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_ALPHA, LM_BETA) != 0)
{ {
MessageBox.Show("Error loading LM."); MessageBox.Show("Error loading LM.");
Dispatcher.Invoke(() => btnEnableLM.IsEnabled = true); Dispatcher.Invoke(() => btnEnableLM.IsEnabled = true);

View File

@ -19,8 +19,8 @@ def load_model(models, alphabet, lm, trie):
N_FEATURES = 26 N_FEATURES = 26
N_CONTEXT = 9 N_CONTEXT = 9
BEAM_WIDTH = 500 BEAM_WIDTH = 500
LM_WEIGHT = 1.50 LM_ALPHA = 0.75
VALID_WORD_COUNT_WEIGHT = 2.10 LM_BETA = 1.85
model_load_start = timer() model_load_start = timer()
ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH) ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
@ -28,7 +28,7 @@ def load_model(models, alphabet, lm, trie):
logging.debug("Loaded model in %0.3fs." % (model_load_end)) logging.debug("Loaded model in %0.3fs." % (model_load_end))
lm_load_start = timer() lm_load_start = timer()
ds.enableDecoderWithLM(alphabet, lm, trie, LM_WEIGHT, VALID_WORD_COUNT_WEIGHT) ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)
lm_load_end = timer() - lm_load_start lm_load_end = timer() - lm_load_start
logging.debug('Loaded language model in %0.3fs.' % (lm_load_end)) logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))

View File

@ -24,8 +24,8 @@
#define N_CEP 26 #define N_CEP 26
#define N_CONTEXT 9 #define N_CONTEXT 9
#define BEAM_WIDTH 500 #define BEAM_WIDTH 500
#define LM_WEIGHT 1.50f #define LM_ALPHA 0.75f
#define VALID_WORD_COUNT_WEIGHT 2.10f #define LM_BETA 1.85f
typedef struct { typedef struct {
const char* string; const char* string;
@ -253,8 +253,8 @@ main(int argc, char **argv)
alphabet, alphabet,
lm, lm,
trie, trie,
LM_WEIGHT, LM_ALPHA,
VALID_WORD_COUNT_WEIGHT); LM_BETA);
if (status != 0) { if (status != 0) {
fprintf(stderr, "Could not enable CTC decoder with LM.\n"); fprintf(stderr, "Could not enable CTC decoder with LM.\n");
return 1; return 1;

View File

@ -38,8 +38,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
final int N_CEP = 26; final int N_CEP = 26;
final int N_CONTEXT = 9; final int N_CONTEXT = 9;
final int BEAM_WIDTH = 50; final int BEAM_WIDTH = 50;
final float LM_WEIGHT = 1.50f; final float LM_ALPHA = 0.75f;
final float VALID_WORD_COUNT_WEIGHT = 2.10f; final float LM_BETA = 1.85f;
private char readLEChar(RandomAccessFile f) throws IOException { private char readLEChar(RandomAccessFile f) throws IOException {
byte b1 = f.readByte(); byte b1 = f.readByte();

View File

@ -16,8 +16,8 @@ public class Model {
impl.DestroyModel(this._msp); impl.DestroyModel(this._msp);
} }
public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_weight, float valid_word_count_weight) { public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_alpha, float lm_beta) {
impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_weight, valid_word_count_weight); impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_alpha, lm_beta);
} }
public String stt(short[] buffer, int buffer_size, int sample_rate) { public String stt(short[] buffer, int buffer_size, int sample_rate) {

View File

@ -15,11 +15,10 @@ const util = require('util');
const BEAM_WIDTH = 500; const BEAM_WIDTH = 500;
// The alpha hyperparameter of the CTC decoder. Language Model weight // The alpha hyperparameter of the CTC decoder. Language Model weight
const LM_WEIGHT = 1.50; const LM_ALPHA = 0.75;
// Valid word insertion weight. This is used to lessen the word insertion penalty // The beta hyperparameter of the CTC decoder. Word insertion bonus.
// when the inserted word is part of the vocabulary const LM_BETA = 1.85;
const VALID_WORD_COUNT_WEIGHT = 2.10;
// These constants are tied to the shape of the graph used (changing them changes // These constants are tied to the shape of the graph used (changing them changes
@ -102,7 +101,7 @@ audioStream.on('finish', () => {
console.error('Loading language model from files %s %s', args['lm'], args['trie']); console.error('Loading language model from files %s %s', args['lm'], args['trie']);
const lm_load_start = process.hrtime(); const lm_load_start = process.hrtime();
model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'], model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'],
LM_WEIGHT, VALID_WORD_COUNT_WEIGHT); LM_ALPHA, LM_BETA);
const lm_load_end = process.hrtime(lm_load_start); const lm_load_end = process.hrtime(lm_load_start);
console.error('Loaded language model in %ds.', totalTime(lm_load_end)); console.error('Loaded language model in %ds.', totalTime(lm_load_end));
} }

View File

@ -23,11 +23,10 @@ except ImportError:
BEAM_WIDTH = 500 BEAM_WIDTH = 500
# The alpha hyperparameter of the CTC decoder. Language Model weight # The alpha hyperparameter of the CTC decoder. Language Model weight
LM_WEIGHT = 1.50 LM_ALPHA = 0.75
# Valid word insertion weight. This is used to lessen the word insertion penalty # The beta hyperparameter of the CTC decoder. Word insertion bonus.
# when the inserted word is part of the vocabulary LM_BETA = 1.85
VALID_WORD_COUNT_WEIGHT = 2.10
# These constants are tied to the shape of the graph used (changing them changes # These constants are tied to the shape of the graph used (changing them changes
@ -85,8 +84,7 @@ def main():
if args.lm and args.trie: if args.lm and args.trie:
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr) print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
lm_load_start = timer() lm_load_start = timer()
ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_WEIGHT, ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_ALPHA, LM_BETA)
VALID_WORD_COUNT_WEIGHT)
lm_load_end = timer() - lm_load_start lm_load_end = timer() - lm_load_start
print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr) print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr)

View File

@ -151,12 +151,12 @@ assert_correct_multi_ldc93s1()
assert_correct_ldc93s1_prodmodel() assert_correct_ldc93s1_prodmodel()
{ {
assert_correct_inference "$1" "she had a due in greasy wash water year" assert_correct_inference "$1" "she had a due and greasy wash water year"
} }
assert_correct_ldc93s1_prodmodel_stereo_44k() assert_correct_ldc93s1_prodmodel_stereo_44k()
{ {
assert_correct_inference "$1" "she had a due in greasy wash water year" assert_correct_inference "$1" "she had a due and greasy wash water year"
} }
assert_correct_warning_upsampling() assert_correct_warning_upsampling()

View File

@ -131,8 +131,8 @@ def create_flags():
tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM') tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie') tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions') tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
tf.app.flags.DEFINE_float ('lm_alpha', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.') tf.app.flags.DEFINE_float ('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float ('lm_beta', 2.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight.') tf.app.flags.DEFINE_float ('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
# Inference mode # Inference mode