Expose beam width, lm_alpha and lm_beta in CLI args

This commit is contained in:
Alexandre Lissy 2019-10-14 13:57:56 +02:00
parent 65a53aeb4a
commit ef5ae5c0b4
4 changed files with 41 additions and 36 deletions

View File

@ -20,6 +20,12 @@ char* trie = NULL;
// Path to the input audio file (WAV format), supplied via --audio.
char* audio = NULL;
// Decoder beam width, overridable via --beam_width (parsed with atoi).
// Defaults to 500 and is handed to DS_CreateModel by the client.
int beam_width = 500;
// Language model weight (alpha hyperparameter of the CTC decoder),
// overridable via --lm_alpha (parsed with atof). Defaults to 0.75 and is
// handed to DS_EnableDecoderWithLM by the client.
float lm_alpha = 0.75f;
// Word insertion bonus (beta hyperparameter of the CTC decoder),
// overridable via --lm_beta (parsed with atof). Defaults to 1.85 and is
// handed to DS_EnableDecoderWithLM by the client.
float lm_beta = 1.85f;
// Set to true by --run_very_slowly_without_trie_I_really_know_what_Im_doing.
bool load_without_trie = false;
// NOTE(review): presumably toggled by -t (benchmark mode: output mfcc &
// inference time) -- the assignment is outside this view, confirm.
bool show_times = false;
@ -44,6 +50,9 @@ void PrintHelp(const char* bin)
" --lm LM Path to the language model binary file\n"
" --trie TRIE Path to the language model trie file created with native_client/generate_trie\n"
" --audio AUDIO Path to the audio file to run (WAV format)\n"
" --beam_width BEAM_WIDTH Value for decoder beam width (int)\n"
" --lm_alpha LM_ALPHA Value for language model alpha param (float)\n"
" --lm_beta LM_BETA Value for language model beta param (float)\n"
" -t Run in benchmark mode, output mfcc & inference time\n"
" --extended Output string from extended metadata\n"
" --json Extended output, shows word timings as JSON\n"
@ -56,13 +65,16 @@ void PrintHelp(const char* bin)
bool ProcessArgs(int argc, char** argv)
{
const char* const short_opts = "m:a:l:r:w:tehv";
const char* const short_opts = "m:a:l:r:w:c:d:b:tehv";
const option long_opts[] = {
{"model", required_argument, nullptr, 'm'},
{"alphabet", required_argument, nullptr, 'a'},
{"lm", required_argument, nullptr, 'l'},
{"trie", required_argument, nullptr, 'r'},
{"audio", required_argument, nullptr, 'w'},
{"beam_width", required_argument, nullptr, 'b'},
{"lm_alpha", required_argument, nullptr, 'c'},
{"lm_beta", required_argument, nullptr, 'd'},
{"run_very_slowly_without_trie_I_really_know_what_Im_doing", no_argument, nullptr, 999},
{"t", no_argument, nullptr, 't'},
{"extended", no_argument, nullptr, 'e'},
@ -102,6 +114,18 @@ bool ProcessArgs(int argc, char** argv)
audio = optarg;
break;
case 'b':
beam_width = atoi(optarg);
break;
case 'c':
lm_alpha = atof(optarg);
break;
case 'd':
lm_beta = atof(optarg);
break;
case 999:
load_without_trie = true;
break;

View File

@ -33,10 +33,6 @@
#include "deepspeech.h"
#include "args.h"
#define BEAM_WIDTH 500
#define LM_ALPHA 0.75f
#define LM_BETA 1.85f
typedef struct {
const char* string;
double cpu_time_overall;
@ -373,7 +369,7 @@ main(int argc, char **argv)
// Initialise DeepSpeech
ModelState* ctx;
int status = DS_CreateModel(model, alphabet, BEAM_WIDTH, &ctx);
int status = DS_CreateModel(model, alphabet, beam_width, &ctx);
if (status != 0) {
fprintf(stderr, "Could not create model.\n");
return 1;
@ -383,8 +379,8 @@ main(int argc, char **argv)
int status = DS_EnableDecoderWithLM(ctx,
lm,
trie,
LM_ALPHA,
LM_BETA);
lm_alpha,
lm_beta);
if (status != 0) {
fprintf(stderr, "Could not enable CTC decoder with LM.\n");
return 1;

View File

@ -10,18 +10,6 @@ const Wav = require('node-wav');
const Duplex = require('stream').Duplex;
const util = require('util');
// These constants control the beam search decoder
// Beam width used in the CTC decoder when building candidate transcriptions
const BEAM_WIDTH = 500;
// The alpha hyperparameter of the CTC decoder. Language Model weight
const LM_ALPHA = 0.75;
// The beta hyperparameter of the CTC decoder. Word insertion bonus.
const LM_BETA = 1.85;
var VersionAction = function VersionAction(options) {
options = options || {};
options.nargs = 0;
@ -45,6 +33,9 @@ parser.addArgument(['--alphabet'], {required: true, help: 'Path to the configura
// Optional language model inputs: the LM-backed decoder is only enabled
// further down when both --lm and --trie were supplied.
parser.addArgument(['--lm'], {help: 'Path to the language model binary file', nargs: '?'});
parser.addArgument(['--trie'], {help: 'Path to the language model trie file created with native_client/generate_trie', nargs: '?'});
// The audio file to transcribe is mandatory.
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
// Decoder tuning options. --beam_width is passed to the Ds.Model constructor;
// --lm_alpha and --lm_beta are passed to model.enableDecoderWithLM(). Each one
// falls back to its defaultValue when omitted on the command line.
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha)', defaultValue: 0.75, type: 'float'});
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta)', defaultValue: 1.85, type: 'float'});
// --version is handled by the custom VersionAction defined earlier in this file.
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
// --extended is a boolean flag ('storeTrue' action).
parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'});
// Parse the command line; the result is read below as args['<option name>'].
var args = parser.parseArgs();
@ -64,7 +55,7 @@ function metadataToString(metadata) {
console.error('Loading model from file %s', args['model']);
const model_load_start = process.hrtime();
var model = new Ds.Model(args['model'], args['alphabet'], BEAM_WIDTH);
var model = new Ds.Model(args['model'], args['alphabet'], args['beam_width']);
const model_load_end = process.hrtime(model_load_start);
console.error('Loaded model in %ds.', totalTime(model_load_end));
@ -73,7 +64,7 @@ var desired_sample_rate = model.sampleRate();
if (args['lm'] && args['trie']) {
console.error('Loading language model from files %s %s', args['lm'], args['trie']);
const lm_load_start = process.hrtime();
model.enableDecoderWithLM(args['lm'], args['trie'], LM_ALPHA, LM_BETA);
model.enableDecoderWithLM(args['lm'], args['trie'], args['lm_alpha'], args['lm_beta']);
const lm_load_end = process.hrtime(lm_load_start);
console.error('Loaded language model in %ds.', totalTime(lm_load_end));
}

View File

@ -17,18 +17,6 @@ try:
except ImportError:
from pipes import quote
# These constants control the beam search decoder
# Beam width used in the CTC decoder when building candidate transcriptions
BEAM_WIDTH = 500
# The alpha hyperparameter of the CTC decoder. Language Model weight
LM_ALPHA = 0.75
# The beta hyperparameter of the CTC decoder. Word insertion bonus.
LM_BETA = 1.85
def convert_samplerate(audio_path, desired_sample_rate):
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
try:
@ -66,6 +54,12 @@ def main():
help='Path to the language model trie file created with native_client/generate_trie')
parser.add_argument('--audio', required=True,
help='Path to the audio file to run (WAV format)')
parser.add_argument('--beam_width', type=int, default=500,
help='Beam width for the CTC decoder')
parser.add_argument('--lm_alpha', type=float, default=0.75,
help='Language model weight (lm_alpha)')
parser.add_argument('--lm_beta', type=float, default=1.85,
help='Word insertion bonus (lm_beta)')
parser.add_argument('--version', action=VersionAction,
help='Print version and exits')
parser.add_argument('--extended', required=False, action='store_true',
@ -74,7 +68,7 @@ def main():
print('Loading model from file {}'.format(args.model), file=sys.stderr)
model_load_start = timer()
ds = Model(args.model, args.alphabet, BEAM_WIDTH)
ds = Model(args.model, args.alphabet, args.beam_width)
model_load_end = timer() - model_load_start
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
@ -83,7 +77,7 @@ def main():
if args.lm and args.trie:
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
lm_load_start = timer()
ds.enableDecoderWithLM(args.lm, args.trie, LM_ALPHA, LM_BETA)
ds.enableDecoderWithLM(args.lm, args.trie, args.lm_alpha, args.lm_beta)
lm_load_end = timer() - lm_load_start
print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr)