Update decoder hyperparameters
commit fa7cb1a983 (parent e549f8978e)
@@ -11,11 +11,10 @@ const util = require('util');
 const BEAM_WIDTH = 1024;
 
 // The alpha hyperparameter of the CTC decoder. Language Model weight
-const LM_WEIGHT = 1.50;
+const LM_ALPHA = 0.75;
 
-// Valid word insertion weight. This is used to lessen the word insertion penalty
-// when the inserted word is part of the vocabulary
-const VALID_WORD_COUNT_WEIGHT = 2.25;
+// The beta hyperparameter of the CTC decoder. Word insertion bonus.
+const LM_BETA = 1.85;
 
 // These constants are tied to the shape of the graph used (changing them changes
 // the geometry of the first layer), so make sure you use the same constants that
@@ -63,7 +62,7 @@ if (args['lm'] && args['trie']) {
   console.error('Loading language model from files %s %s', args['lm'], args['trie']);
   const lm_load_start = process.hrtime();
   model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'],
-                            LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
+                            LM_ALPHA, LM_BETA);
   const lm_load_end = process.hrtime(lm_load_start);
   console.error('Loaded language model in %ds.', totalTime(lm_load_end));
 }
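For orientation: as the comments above state, LM_ALPHA weights the language model score and LM_BETA is a per-word insertion bonus in the CTC beam search. The Python sketch below shows the standard shallow-fusion scoring shape these two hyperparameters control, with made-up probabilities; the actual scoring lives in the native decoder and may differ in detail.

import math

LM_ALPHA = 0.75  # language model weight (alpha)
LM_BETA = 1.85   # word insertion bonus (beta)

def candidate_score(log_p_acoustic, log_p_lm, word_count):
    # Standard shallow-fusion score of one candidate transcription:
    # acoustic log-probability, plus alpha-weighted LM log-probability,
    # plus a beta bonus for every word in the candidate.
    return log_p_acoustic + LM_ALPHA * log_p_lm + LM_BETA * word_count

# Two hypothetical candidates with identical acoustics: the one the LM
# prefers wins unless the word-count bonus tips the balance.
print(candidate_score(math.log(0.20), math.log(0.05), 4))
print(candidate_score(math.log(0.20), math.log(0.01), 5))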
@@ -26,7 +26,7 @@ brew install portaudio
 usage: mic_vad_streaming.py [-h] [-v VAD_AGGRESSIVENESS] [--nospinner]
                             [-w SAVEWAV] -m MODEL [-a ALPHABET] [-l LM]
                             [-t TRIE] [-nf N_FEATURES] [-nc N_CONTEXT]
-                            [-lw LM_WEIGHT] [-vwcw VALID_WORD_COUNT_WEIGHT]
+                            [-la LM_ALPHA] [-lb LM_BETA]
                             [-bw BEAM_WIDTH]
 
 Stream from microphone to DeepSpeech using VAD
@@ -56,13 +56,12 @@ optional arguments:
   -nc N_CONTEXT, --n_context N_CONTEXT
                         Size of the context window used for producing
                         timesteps in the input vector. Default: 9
-  -lw LM_WEIGHT, --lm_weight LM_WEIGHT
+  -la LM_ALPHA, --lm_alpha LM_ALPHA
                         The alpha hyperparameter of the CTC decoder. Language
-                        Model weight. Default: 1.5
-  -vwcw VALID_WORD_COUNT_WEIGHT, --valid_word_count_weight VALID_WORD_COUNT_WEIGHT
-                        Valid word insertion weight. This is used to lessen
-                        the word insertion penalty when the inserted word is
-                        part of the vocabulary. Default: 2.1
+                        Model weight. Default: 0.75
+  -lb LM_BETA, --lm_beta LM_BETA
+                        The beta hyperparameter of the CTC decoder. Word insertion
+                        bonus. Default: 1.85
   -bw BEAM_WIDTH, --beam_width BEAM_WIDTH
                         Beam width used in the CTC decoder when building
                         candidate transcriptions. Default: 500
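Concretely, an invocation that previously passed -lw/-vwcw now uses the renamed flags; for example (the file names here are placeholders, not part of this change):

python mic_vad_streaming.py -m output_graph.pb -a alphabet.txt -l lm.binary -t trie -la 0.75 -lb 1.85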
@@ -118,7 +118,7 @@ def main(ARGS):
     if ARGS.lm and ARGS.trie:
         logging.info("ARGS.lm: %s", ARGS.lm)
         logging.info("ARGS.trie: %s", ARGS.trie)
-        model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_weight, ARGS.valid_word_count_weight)
+        model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_alpha, ARGS.lm_beta)
 
     # Start audio with VAD
     vad_audio = VADAudio(aggressiveness=ARGS.vad_aggressiveness)
@@ -148,8 +148,8 @@ def main(ARGS):
 
 if __name__ == '__main__':
     BEAM_WIDTH = 500
-    LM_WEIGHT = 1.50
-    VALID_WORD_COUNT_WEIGHT = 2.10
+    LM_ALPHA = 0.75
+    LM_BETA = 1.85
     N_FEATURES = 26
     N_CONTEXT = 9
 
@@ -175,10 +175,10 @@ if __name__ == '__main__':
                         help=f"Number of MFCC features to use. Default: {N_FEATURES}")
     parser.add_argument('-nc', '--n_context', type=int, default=N_CONTEXT,
                         help=f"Size of the context window used for producing timesteps in the input vector. Default: {N_CONTEXT}")
-    parser.add_argument('-lw', '--lm_weight', type=float, default=LM_WEIGHT,
-                        help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_WEIGHT}")
-    parser.add_argument('-vwcw', '--valid_word_count_weight', type=float, default=VALID_WORD_COUNT_WEIGHT,
-                        help=f"Valid word insertion weight. This is used to lessen the word insertion penalty when the inserted word is part of the vocabulary. Default: {VALID_WORD_COUNT_WEIGHT}")
+    parser.add_argument('-la', '--lm_alpha', type=float, default=LM_ALPHA,
+                        help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_ALPHA}")
+    parser.add_argument('-lb', '--lm_beta', type=float, default=LM_BETA,
+                        help=f"The beta hyperparameter of the CTC decoder. Word insertion bonus. Default: {LM_BETA}")
     parser.add_argument('-bw', '--beam_width', type=int, default=BEAM_WIDTH,
                         help=f"Beam width used in the CTC decoder when building candidate transcriptions. Default: {BEAM_WIDTH}")
 
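The rename is mechanical: flag, destination attribute, and default all change together. A self-contained sketch of just the renamed arguments (values are the new defaults from this commit; everything outside the two add_argument calls is illustrative):

import argparse

LM_ALPHA = 0.75
LM_BETA = 1.85

parser = argparse.ArgumentParser()
parser.add_argument('-la', '--lm_alpha', type=float, default=LM_ALPHA,
                    help="The alpha hyperparameter of the CTC decoder. Language Model weight.")
parser.add_argument('-lb', '--lm_beta', type=float, default=LM_BETA,
                    help="The beta hyperparameter of the CTC decoder. Word insertion bonus.")
ARGS = parser.parse_args([])  # empty list: take the defaults, for illustration

# The parsed values are then forwarded exactly as in the hunk above:
#   model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_alpha, ARGS.lm_beta)
print(ARGS.lm_alpha, ARGS.lm_beta)  # 0.75 1.85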
@@ -38,8 +38,8 @@ namespace CSharpExamples
             const uint N_CEP = 26;
             const uint N_CONTEXT = 9;
             const uint BEAM_WIDTH = 200;
-            const float LM_WEIGHT = 1.50f;
-            const float VALID_WORD_COUNT_WEIGHT = 2.10f;
+            const float LM_ALPHA = 0.75f;
+            const float LM_BETA = 1.85f;
 
             Stopwatch stopwatch = new Stopwatch();
 
@@ -76,7 +76,7 @@ namespace CSharpExamples
                     alphabet ?? "alphabet.txt",
                     lm ?? "lm.binary",
                     trie ?? "trie",
-                    LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
+                    LM_ALPHA, LM_BETA);
             }
             catch (IOException ex)
             {
@@ -24,8 +24,8 @@ namespace DeepSpeechWPF
         private const uint N_CEP = 26;
         private const uint N_CONTEXT = 9;
         private const uint BEAM_WIDTH = 500;
-        private const float LM_WEIGHT = 1.50f;
-        private const float VALID_WORD_COUNT_WEIGHT = 2.10f;
+        private const float LM_ALPHA = 0.75f;
+        private const float LM_BETA = 1.85f;
 
 
 
@@ -160,7 +160,7 @@ namespace DeepSpeechWPF
         {
             try
             {
-                if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_WEIGHT, VALID_WORD_COUNT_WEIGHT) != 0)
+                if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_ALPHA, LM_BETA) != 0)
                 {
                     MessageBox.Show("Error loading LM.");
                     Dispatcher.Invoke(() => btnEnableLM.IsEnabled = true);
@@ -19,8 +19,8 @@ def load_model(models, alphabet, lm, trie):
     N_FEATURES = 26
     N_CONTEXT = 9
     BEAM_WIDTH = 500
-    LM_WEIGHT = 1.50
-    VALID_WORD_COUNT_WEIGHT = 2.10
+    LM_ALPHA = 0.75
+    LM_BETA = 1.85
 
     model_load_start = timer()
     ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
@@ -28,7 +28,7 @@ def load_model(models, alphabet, lm, trie):
     logging.debug("Loaded model in %0.3fs." % (model_load_end))
 
     lm_load_start = timer()
-    ds.enableDecoderWithLM(alphabet, lm, trie, LM_WEIGHT, VALID_WORD_COUNT_WEIGHT)
+    ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)
     lm_load_end = timer() - lm_load_start
     logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))
 
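This hunk shows the canonical load sequence in one place: construct the Model, then attach the LM-augmented decoder with the new hyperparameters. A standalone sketch, assuming the deepspeech Python package of this era and placeholder file paths:

from deepspeech import Model

# Graph-shape constants and the new decoder hyperparameters from this commit.
N_FEATURES = 26
N_CONTEXT = 9
BEAM_WIDTH = 500
LM_ALPHA = 0.75
LM_BETA = 1.85

# Placeholder paths; substitute your exported model and KenLM files.
ds = Model('output_graph.pb', N_FEATURES, N_CONTEXT, 'alphabet.txt', BEAM_WIDTH)
ds.enableDecoderWithLM('alphabet.txt', 'lm.binary', 'trie', LM_ALPHA, LM_BETA)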
@@ -24,8 +24,8 @@
 #define N_CEP 26
 #define N_CONTEXT 9
 #define BEAM_WIDTH 500
-#define LM_WEIGHT 1.50f
-#define VALID_WORD_COUNT_WEIGHT 2.10f
+#define LM_ALPHA 0.75f
+#define LM_BETA 1.85f
 
 typedef struct {
   const char* string;
@@ -253,8 +253,8 @@ main(int argc, char **argv)
                                        alphabet,
                                        lm,
                                        trie,
-                                       LM_WEIGHT,
-                                       VALID_WORD_COUNT_WEIGHT);
+                                       LM_ALPHA,
+                                       LM_BETA);
     if (status != 0) {
       fprintf(stderr, "Could not enable CTC decoder with LM.\n");
       return 1;
@@ -38,8 +38,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
     final int N_CEP = 26;
     final int N_CONTEXT = 9;
     final int BEAM_WIDTH = 50;
-    final float LM_WEIGHT = 1.50f;
-    final float VALID_WORD_COUNT_WEIGHT = 2.10f;
+    final float LM_ALPHA = 0.75f;
+    final float LM_BETA = 1.85f;
 
     private char readLEChar(RandomAccessFile f) throws IOException {
         byte b1 = f.readByte();
@@ -16,8 +16,8 @@ public class Model {
         impl.DestroyModel(this._msp);
     }
 
-    public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_weight, float valid_word_count_weight) {
-        impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_weight, valid_word_count_weight);
+    public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_alpha, float lm_beta) {
+        impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_alpha, lm_beta);
     }
 
     public String stt(short[] buffer, int buffer_size, int sample_rate) {
@@ -15,11 +15,10 @@ const util = require('util');
 const BEAM_WIDTH = 500;
 
 // The alpha hyperparameter of the CTC decoder. Language Model weight
-const LM_WEIGHT = 1.50;
+const LM_ALPHA = 0.75;
 
-// Valid word insertion weight. This is used to lessen the word insertion penalty
-// when the inserted word is part of the vocabulary
-const VALID_WORD_COUNT_WEIGHT = 2.10;
+// The beta hyperparameter of the CTC decoder. Word insertion bonus.
+const LM_BETA = 1.85;
 
 
 // These constants are tied to the shape of the graph used (changing them changes
@@ -102,7 +101,7 @@ audioStream.on('finish', () => {
   console.error('Loading language model from files %s %s', args['lm'], args['trie']);
   const lm_load_start = process.hrtime();
   model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'],
-                            LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
+                            LM_ALPHA, LM_BETA);
   const lm_load_end = process.hrtime(lm_load_start);
   console.error('Loaded language model in %ds.', totalTime(lm_load_end));
 }
@@ -23,11 +23,10 @@ except ImportError:
 BEAM_WIDTH = 500
 
 # The alpha hyperparameter of the CTC decoder. Language Model weight
-LM_WEIGHT = 1.50
+LM_ALPHA = 0.75
 
-# Valid word insertion weight. This is used to lessen the word insertion penalty
-# when the inserted word is part of the vocabulary
-VALID_WORD_COUNT_WEIGHT = 2.10
+# The beta hyperparameter of the CTC decoder. Word insertion bonus.
+LM_BETA = 1.85
 
 
 # These constants are tied to the shape of the graph used (changing them changes
@@ -85,8 +84,7 @@ def main():
     if args.lm and args.trie:
         print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
         lm_load_start = timer()
-        ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_WEIGHT,
-                               VALID_WORD_COUNT_WEIGHT)
+        ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_ALPHA, LM_BETA)
         lm_load_end = timer() - lm_load_start
         print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr)
 
@@ -131,8 +131,8 @@ def create_flags():
     tf.app.flags.DEFINE_string  ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
     tf.app.flags.DEFINE_string  ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
     tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
-    tf.app.flags.DEFINE_float   ('lm_alpha', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
-    tf.app.flags.DEFINE_float   ('lm_beta', 2.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
+    tf.app.flags.DEFINE_float   ('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
+    tf.app.flags.DEFINE_float   ('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
 
     # Inference mode
 
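On the training side the same values live behind TensorFlow flags rather than per-client constants. A minimal sketch of how such flags are declared and read, assuming TensorFlow 1.x's tf.app.flags as used in the hunk above:

import tensorflow as tf

# Same declarations as the hunk above, with the new defaults.
tf.app.flags.DEFINE_float('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
tf.app.flags.DEFINE_float('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
FLAGS = tf.app.flags.FLAGS

# Downstream code reads FLAGS.lm_alpha / FLAGS.lm_beta when it builds the
# decoder; the actual call sites are outside this diff.
print(FLAGS.lm_alpha, FLAGS.lm_beta)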