Merge pull request #2433 from lissyx/cli_lm
Expose beam width, lm_alpha and lm_beta in CLI args
This commit is contained in:
commit
818ea6a40f
@ -20,6 +20,12 @@ char* trie = NULL;
|
||||
|
||||
char* audio = NULL;
|
||||
|
||||
int beam_width = 500;
|
||||
|
||||
float lm_alpha = 0.75f;
|
||||
|
||||
float lm_beta = 1.85f;
|
||||
|
||||
bool load_without_trie = false;
|
||||
|
||||
bool show_times = false;
|
||||
@ -44,6 +50,9 @@ void PrintHelp(const char* bin)
|
||||
" --lm LM Path to the language model binary file\n"
|
||||
" --trie TRIE Path to the language model trie file created with native_client/generate_trie\n"
|
||||
" --audio AUDIO Path to the audio file to run (WAV format)\n"
|
||||
" --beam_width BEAM_WIDTH Value for decoder beam width (int)\n"
|
||||
" --lm_alpha LM_ALPHA Value for language model alpha param (float)\n"
|
||||
" --lm_beta LM_BETA Value for language model beta param (float)\n"
|
||||
" -t Run in benchmark mode, output mfcc & inference time\n"
|
||||
" --extended Output string from extended metadata\n"
|
||||
" --json Extended output, shows word timings as JSON\n"
|
||||
@ -56,13 +65,16 @@ void PrintHelp(const char* bin)
|
||||
|
||||
bool ProcessArgs(int argc, char** argv)
|
||||
{
|
||||
const char* const short_opts = "m:a:l:r:w:tehv";
|
||||
const char* const short_opts = "m:a:l:r:w:c:d:b:tehv";
|
||||
const option long_opts[] = {
|
||||
{"model", required_argument, nullptr, 'm'},
|
||||
{"alphabet", required_argument, nullptr, 'a'},
|
||||
{"lm", required_argument, nullptr, 'l'},
|
||||
{"trie", required_argument, nullptr, 'r'},
|
||||
{"audio", required_argument, nullptr, 'w'},
|
||||
{"beam_width", required_argument, nullptr, 'b'},
|
||||
{"lm_alpha", required_argument, nullptr, 'c'},
|
||||
{"lm_beta", required_argument, nullptr, 'd'},
|
||||
{"run_very_slowly_without_trie_I_really_know_what_Im_doing", no_argument, nullptr, 999},
|
||||
{"t", no_argument, nullptr, 't'},
|
||||
{"extended", no_argument, nullptr, 'e'},
|
||||
@ -102,6 +114,18 @@ bool ProcessArgs(int argc, char** argv)
|
||||
audio = optarg;
|
||||
break;
|
||||
|
||||
case 'b':
|
||||
beam_width = atoi(optarg);
|
||||
break;
|
||||
|
||||
case 'c':
|
||||
lm_alpha = atof(optarg);
|
||||
break;
|
||||
|
||||
case 'd':
|
||||
lm_beta = atof(optarg);
|
||||
break;
|
||||
|
||||
case 999:
|
||||
load_without_trie = true;
|
||||
break;
|
||||
|
@ -33,10 +33,6 @@
|
||||
#include "deepspeech.h"
|
||||
#include "args.h"
|
||||
|
||||
#define BEAM_WIDTH 500
|
||||
#define LM_ALPHA 0.75f
|
||||
#define LM_BETA 1.85f
|
||||
|
||||
typedef struct {
|
||||
const char* string;
|
||||
double cpu_time_overall;
|
||||
@ -373,7 +369,7 @@ main(int argc, char **argv)
|
||||
|
||||
// Initialise DeepSpeech
|
||||
ModelState* ctx;
|
||||
int status = DS_CreateModel(model, alphabet, BEAM_WIDTH, &ctx);
|
||||
int status = DS_CreateModel(model, alphabet, beam_width, &ctx);
|
||||
if (status != 0) {
|
||||
fprintf(stderr, "Could not create model.\n");
|
||||
return 1;
|
||||
@ -383,8 +379,8 @@ main(int argc, char **argv)
|
||||
int status = DS_EnableDecoderWithLM(ctx,
|
||||
lm,
|
||||
trie,
|
||||
LM_ALPHA,
|
||||
LM_BETA);
|
||||
lm_alpha,
|
||||
lm_beta);
|
||||
if (status != 0) {
|
||||
fprintf(stderr, "Could not enable CTC decoder with LM.\n");
|
||||
return 1;
|
||||
|
@ -10,18 +10,6 @@ const Wav = require('node-wav');
|
||||
const Duplex = require('stream').Duplex;
|
||||
const util = require('util');
|
||||
|
||||
// These constants control the beam search decoder
|
||||
|
||||
// Beam width used in the CTC decoder when building candidate transcriptions
|
||||
const BEAM_WIDTH = 500;
|
||||
|
||||
// The alpha hyperparameter of the CTC decoder. Language Model weight
|
||||
const LM_ALPHA = 0.75;
|
||||
|
||||
// The beta hyperparameter of the CTC decoder. Word insertion bonus.
|
||||
const LM_BETA = 1.85;
|
||||
|
||||
|
||||
var VersionAction = function VersionAction(options) {
|
||||
options = options || {};
|
||||
options.nargs = 0;
|
||||
@ -45,6 +33,9 @@ parser.addArgument(['--alphabet'], {required: true, help: 'Path to the configura
|
||||
parser.addArgument(['--lm'], {help: 'Path to the language model binary file', nargs: '?'});
|
||||
parser.addArgument(['--trie'], {help: 'Path to the language model trie file created with native_client/generate_trie', nargs: '?'});
|
||||
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
|
||||
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
|
||||
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha)', defaultValue: 0.75, type: 'float'});
|
||||
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta)', defaultValue: 1.85, type: 'float'});
|
||||
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
|
||||
parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'});
|
||||
var args = parser.parseArgs();
|
||||
@ -64,7 +55,7 @@ function metadataToString(metadata) {
|
||||
|
||||
console.error('Loading model from file %s', args['model']);
|
||||
const model_load_start = process.hrtime();
|
||||
var model = new Ds.Model(args['model'], args['alphabet'], BEAM_WIDTH);
|
||||
var model = new Ds.Model(args['model'], args['alphabet'], args['beam_width']);
|
||||
const model_load_end = process.hrtime(model_load_start);
|
||||
console.error('Loaded model in %ds.', totalTime(model_load_end));
|
||||
|
||||
@ -73,7 +64,7 @@ var desired_sample_rate = model.sampleRate();
|
||||
if (args['lm'] && args['trie']) {
|
||||
console.error('Loading language model from files %s %s', args['lm'], args['trie']);
|
||||
const lm_load_start = process.hrtime();
|
||||
model.enableDecoderWithLM(args['lm'], args['trie'], LM_ALPHA, LM_BETA);
|
||||
model.enableDecoderWithLM(args['lm'], args['trie'], args['lm_alpha'], args['lm_beta']);
|
||||
const lm_load_end = process.hrtime(lm_load_start);
|
||||
console.error('Loaded language model in %ds.', totalTime(lm_load_end));
|
||||
}
|
||||
|
@ -17,18 +17,6 @@ try:
|
||||
except ImportError:
|
||||
from pipes import quote
|
||||
|
||||
# These constants control the beam search decoder
|
||||
|
||||
# Beam width used in the CTC decoder when building candidate transcriptions
|
||||
BEAM_WIDTH = 500
|
||||
|
||||
# The alpha hyperparameter of the CTC decoder. Language Model weight
|
||||
LM_ALPHA = 0.75
|
||||
|
||||
# The beta hyperparameter of the CTC decoder. Word insertion bonus.
|
||||
LM_BETA = 1.85
|
||||
|
||||
|
||||
def convert_samplerate(audio_path, desired_sample_rate):
|
||||
sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
|
||||
try:
|
||||
@ -66,6 +54,12 @@ def main():
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--audio', required=True,
|
||||
help='Path to the audio file to run (WAV format)')
|
||||
parser.add_argument('--beam_width', type=int, default=500,
|
||||
help='Beam width for the CTC decoder')
|
||||
parser.add_argument('--lm_alpha', type=float, default=0.75,
|
||||
help='Language model weight (lm_alpha)')
|
||||
parser.add_argument('--lm_beta', type=float, default=1.85,
|
||||
help='Word insertion bonus (lm_beta)')
|
||||
parser.add_argument('--version', action=VersionAction,
|
||||
help='Print version and exits')
|
||||
parser.add_argument('--extended', required=False, action='store_true',
|
||||
@ -74,7 +68,7 @@ def main():
|
||||
|
||||
print('Loading model from file {}'.format(args.model), file=sys.stderr)
|
||||
model_load_start = timer()
|
||||
ds = Model(args.model, args.alphabet, BEAM_WIDTH)
|
||||
ds = Model(args.model, args.alphabet, args.beam_width)
|
||||
model_load_end = timer() - model_load_start
|
||||
print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
|
||||
|
||||
@ -83,7 +77,7 @@ def main():
|
||||
if args.lm and args.trie:
|
||||
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
|
||||
lm_load_start = timer()
|
||||
ds.enableDecoderWithLM(args.lm, args.trie, LM_ALPHA, LM_BETA)
|
||||
ds.enableDecoderWithLM(args.lm, args.trie, args.lm_alpha, args.lm_beta)
|
||||
lm_load_end = timer() - lm_load_start
|
||||
print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user