diff --git a/native_client/args.h b/native_client/args.h index 3af3f54b..549f6419 100644 --- a/native_client/args.h +++ b/native_client/args.h @@ -20,6 +20,12 @@ char* trie = NULL; char* audio = NULL; +int beam_width = 500; + +float lm_alpha = 0.75f; + +float lm_beta = 1.85f; + bool load_without_trie = false; bool show_times = false; @@ -44,6 +50,9 @@ void PrintHelp(const char* bin) " --lm LM Path to the language model binary file\n" " --trie TRIE Path to the language model trie file created with native_client/generate_trie\n" " --audio AUDIO Path to the audio file to run (WAV format)\n" + " --beam_width BEAM_WIDTH Value for decoder beam width (int)\n" + " --lm_alpha LM_ALPHA Value for language model alpha param (float)\n" + " --lm_beta LM_BETA Value for language model beta param (float)\n" " -t Run in benchmark mode, output mfcc & inference time\n" " --extended Output string from extended metadata\n" " --json Extended output, shows word timings as JSON\n" @@ -56,13 +65,16 @@ void PrintHelp(const char* bin) bool ProcessArgs(int argc, char** argv) { - const char* const short_opts = "m:a:l:r:w:tehv"; + const char* const short_opts = "m:a:l:r:w:c:d:b:tehv"; const option long_opts[] = { {"model", required_argument, nullptr, 'm'}, {"alphabet", required_argument, nullptr, 'a'}, {"lm", required_argument, nullptr, 'l'}, {"trie", required_argument, nullptr, 'r'}, {"audio", required_argument, nullptr, 'w'}, + {"beam_width", required_argument, nullptr, 'b'}, + {"lm_alpha", required_argument, nullptr, 'c'}, + {"lm_beta", required_argument, nullptr, 'd'}, {"run_very_slowly_without_trie_I_really_know_what_Im_doing", no_argument, nullptr, 999}, {"t", no_argument, nullptr, 't'}, {"extended", no_argument, nullptr, 'e'}, @@ -102,6 +114,18 @@ bool ProcessArgs(int argc, char** argv) audio = optarg; break; + case 'b': + beam_width = atoi(optarg); + break; + + case 'c': + lm_alpha = atof(optarg); + break; + + case 'd': + lm_beta = atof(optarg); + break; + case 999: load_without_trie = true; break; diff --git a/native_client/client.cc b/native_client/client.cc index 80663fe6..ec2d7162 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -33,10 +33,6 @@ #include "deepspeech.h" #include "args.h" -#define BEAM_WIDTH 500 -#define LM_ALPHA 0.75f -#define LM_BETA 1.85f - typedef struct { const char* string; double cpu_time_overall; @@ -373,7 +369,7 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; - int status = DS_CreateModel(model, alphabet, BEAM_WIDTH, &ctx); + int status = DS_CreateModel(model, alphabet, beam_width, &ctx); if (status != 0) { fprintf(stderr, "Could not create model.\n"); return 1; @@ -383,8 +379,8 @@ main(int argc, char **argv) int status = DS_EnableDecoderWithLM(ctx, lm, trie, - LM_ALPHA, - LM_BETA); + lm_alpha, + lm_beta); if (status != 0) { fprintf(stderr, "Could not enable CTC decoder with LM.\n"); return 1; diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js index 8bbdce12..e4aecea8 100644 --- a/native_client/javascript/client.js +++ b/native_client/javascript/client.js @@ -10,18 +10,6 @@ const Wav = require('node-wav'); const Duplex = require('stream').Duplex; const util = require('util'); -// These constants control the beam search decoder - -// Beam width used in the CTC decoder when building candidate transcriptions -const BEAM_WIDTH = 500; - -// The alpha hyperparameter of the CTC decoder. Language Model weight -const LM_ALPHA = 0.75; - -// The beta hyperparameter of the CTC decoder. Word insertion bonus. -const LM_BETA = 1.85; - - var VersionAction = function VersionAction(options) { options = options || {}; options.nargs = 0; @@ -45,6 +33,9 @@ parser.addArgument(['--alphabet'], {required: true, help: 'Path to the configura parser.addArgument(['--lm'], {help: 'Path to the language model binary file', nargs: '?'}); parser.addArgument(['--trie'], {help: 'Path to the language model trie file created with native_client/generate_trie', nargs: '?'}); parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'}); +parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'}); +parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha)', defaultValue: 0.75, type: 'float'}); +parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta)', defaultValue: 1.85, type: 'float'}); parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'}); parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'}); var args = parser.parseArgs(); @@ -64,7 +55,7 @@ function metadataToString(metadata) { console.error('Loading model from file %s', args['model']); const model_load_start = process.hrtime(); -var model = new Ds.Model(args['model'], args['alphabet'], BEAM_WIDTH); +var model = new Ds.Model(args['model'], args['alphabet'], args['beam_width']); const model_load_end = process.hrtime(model_load_start); console.error('Loaded model in %ds.', totalTime(model_load_end)); @@ -73,7 +64,7 @@ var desired_sample_rate = model.sampleRate(); if (args['lm'] && args['trie']) { console.error('Loading language model from files %s %s', args['lm'], args['trie']); const lm_load_start = process.hrtime(); - model.enableDecoderWithLM(args['lm'], args['trie'], LM_ALPHA, LM_BETA); + model.enableDecoderWithLM(args['lm'], args['trie'], args['lm_alpha'], args['lm_beta']); const lm_load_end = process.hrtime(lm_load_start); console.error('Loaded language model in %ds.', totalTime(lm_load_end)); } diff --git a/native_client/python/client.py b/native_client/python/client.py index 3792a406..12b3aebc 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -17,18 +17,6 @@ try: except ImportError: from pipes import quote -# These constants control the beam search decoder - -# Beam width used in the CTC decoder when building candidate transcriptions -BEAM_WIDTH = 500 - -# The alpha hyperparameter of the CTC decoder. Language Model weight -LM_ALPHA = 0.75 - -# The beta hyperparameter of the CTC decoder. Word insertion bonus. -LM_BETA = 1.85 - - def convert_samplerate(audio_path, desired_sample_rate): sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate) try: @@ -66,6 +54,12 @@ def main(): help='Path to the language model trie file created with native_client/generate_trie') parser.add_argument('--audio', required=True, help='Path to the audio file to run (WAV format)') + parser.add_argument('--beam_width', type=int, default=500, + help='Beam width for the CTC decoder') + parser.add_argument('--lm_alpha', type=float, default=0.75, + help='Language model weight (lm_alpha)') + parser.add_argument('--lm_beta', type=float, default=1.85, + help='Word insertion bonus (lm_beta)') parser.add_argument('--version', action=VersionAction, help='Print version and exits') parser.add_argument('--extended', required=False, action='store_true', @@ -74,7 +68,7 @@ def main(): print('Loading model from file {}'.format(args.model), file=sys.stderr) model_load_start = timer() - ds = Model(args.model, args.alphabet, BEAM_WIDTH) + ds = Model(args.model, args.alphabet, args.beam_width) model_load_end = timer() - model_load_start print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr) @@ -83,7 +77,7 @@ def main(): if args.lm and args.trie: print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr) lm_load_start = timer() - ds.enableDecoderWithLM(args.lm, args.trie, LM_ALPHA, LM_BETA) + ds.enableDecoderWithLM(args.lm, args.trie, args.lm_alpha, args.lm_beta) lm_load_end = timer() - lm_load_start print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr)