diff --git a/DeepSpeech.py b/DeepSpeech.py index e6d3a929..d25c192b 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -797,6 +797,7 @@ def export(): outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate') outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len') outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step') + outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width') outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet') if FLAGS.export_language: @@ -875,11 +876,6 @@ def package_zip(): with open(os.path.join(export_dir, 'info.json'), 'w') as f: json.dump({ 'name': FLAGS.export_language, - 'parameters': { - 'beamWidth': FLAGS.beam_width, - 'lmAlpha': FLAGS.lm_alpha, - 'lmBeta': FLAGS.lm_beta - } }, f) shutil.copy(FLAGS.scorer_path, export_dir) diff --git a/evaluate_tflite.py b/evaluate_tflite.py index aba6fb68..f105702f 100644 --- a/evaluate_tflite.py +++ b/evaluate_tflite.py @@ -30,15 +30,9 @@ This module should be self-contained: Then run with a TF Lite model, a scorer and a CSV test file ''' -BEAM_WIDTH = 500 -LM_ALPHA = 0.75 -LM_BETA = 1.85 - def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask): os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask) - ds = Model(model, BEAM_WIDTH) - ds.enableExternalScorer(scorer) - ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA) + ds = Model(model) while True: try: diff --git a/native_client/args.h b/native_client/args.h index d5a0f869..2e7306c7 100644 --- a/native_client/args.h +++ b/native_client/args.h @@ -16,7 +16,9 @@ char* scorer = NULL; char* audio = NULL; -int beam_width = 500; +bool set_beamwidth = false; + +int beam_width = 0; bool set_alphabeta = false; @@ -98,6 +100,7 @@ bool ProcessArgs(int argc, char** argv) break; case 'b': + set_beamwidth = true; beam_width = 
atoi(optarg); break; diff --git a/native_client/client.cc b/native_client/client.cc index 718fba75..abcadd8d 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -368,14 +368,22 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; - int status = DS_CreateModel(model, beam_width, &ctx); + int status = DS_CreateModel(model, &ctx); if (status != 0) { fprintf(stderr, "Could not create model.\n"); return 1; } + if (set_beamwidth) { + status = DS_SetModelBeamWidth(ctx, beam_width); + if (status != 0) { + fprintf(stderr, "Could not set model beam width.\n"); + return 1; + } + } + if (scorer) { - int status = DS_EnableExternalScorer(ctx, scorer); + status = DS_EnableExternalScorer(ctx, scorer); if (status != 0) { fprintf(stderr, "Could not enable external scorer.\n"); return 1; diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index 274ce41f..9f0ebe42 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -257,7 +257,6 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) int DS_CreateModel(const char* aModelPath, - unsigned int aBeamWidth, ModelState** retval) { *retval = nullptr; @@ -282,7 +281,7 @@ DS_CreateModel(const char* aModelPath, return DS_ERR_FAIL_CREATE_MODEL; } - int err = model->init(aModelPath, aBeamWidth); + int err = model->init(aModelPath); if (err != DS_ERR_OK) { return err; } @@ -291,6 +290,19 @@ DS_CreateModel(const char* aModelPath, return DS_ERR_OK; } +unsigned int +DS_GetModelBeamWidth(ModelState* aCtx) +{ + return aCtx->beam_width_; +} + +int +DS_SetModelBeamWidth(ModelState* aCtx, unsigned int aBeamWidth) +{ + aCtx->beam_width_ = aBeamWidth; + return 0; +} + int DS_GetModelSampleRate(ModelState* aCtx) { diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 4e017653..c5e330cb 100644 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -78,18 +78,39 @@ enum DeepSpeech_Error_Codes * @brief An object providing an 
interface to a trained DeepSpeech model. * * @param aModelPath The path to the frozen model graph. - * @param aBeamWidth The beam width used by the decoder. A larger beam - * width generates better results at the cost of decoding - * time. * @param[out] retval a ModelState pointer * * @return Zero on success, non-zero on failure. */ DEEPSPEECH_EXPORT int DS_CreateModel(const char* aModelPath, - unsigned int aBeamWidth, ModelState** retval); +/** + * @brief Get beam width value used by the model. If {@link DS_SetModelBeamWidth} + * was not called before, will return the default value loaded from the + * model file. + * + * @param aCtx A ModelState pointer created with {@link DS_CreateModel}. + * + * @return Beam width value used by the model. + */ +DEEPSPEECH_EXPORT +unsigned int DS_GetModelBeamWidth(ModelState* aCtx); + +/** + * @brief Set beam width value used by the model. + * + * @param aCtx A ModelState pointer created with {@link DS_CreateModel}. + * @param aBeamWidth The beam width used by the model. A larger beam width value + * generates better results at the cost of decoding time. + * + * @return Zero on success, non-zero on failure. + */ +DEEPSPEECH_EXPORT +int DS_SetModelBeamWidth(ModelState* aCtx, + unsigned int aBeamWidth); + /** * @brief Return the sample rate expected by a model. * diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs index e5e33370..15c2212c 100644 --- a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs @@ -19,11 +19,10 @@ namespace DeepSpeechClient /// Initializes a new instance of class and creates a new acoustic model. /// /// The path to the frozen model graph. - /// The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time. /// Thrown when the native binary failed to create the model. 
- public DeepSpeech(string aModelPath, uint aBeamWidth) + public DeepSpeech(string aModelPath) { - CreateModel(aModelPath, aBeamWidth); + CreateModel(aModelPath); } #region IDeepSpeech @@ -32,10 +31,8 @@ namespace DeepSpeechClient /// Create an object providing an interface to a trained DeepSpeech model. /// /// The path to the frozen model graph. - /// The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time. /// Thrown when the native binary failed to create the model. - private unsafe void CreateModel(string aModelPath, - uint aBeamWidth) + private unsafe void CreateModel(string aModelPath) { string exceptionMessage = null; if (string.IsNullOrWhiteSpace(aModelPath)) @@ -52,11 +49,31 @@ namespace DeepSpeechClient throw new FileNotFoundException(exceptionMessage); } var resultCode = NativeImp.DS_CreateModel(aModelPath, - aBeamWidth, ref _modelStatePP); EvaluateResultCode(resultCode); } + /// + /// Get beam width value used by the model. If SetModelBeamWidth was not + /// called before, will return the default value loaded from the model file. + /// + /// Beam width value used by the model. + public unsafe uint GetModelBeamWidth() + { + return NativeImp.DS_GetModelBeamWidth(_modelStatePP); + } + + /// + /// Set beam width value used by the model. + /// + /// The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time. + /// Thrown on failure. + public unsafe void SetModelBeamWidth(uint aBeamWidth) + { + var resultCode = NativeImp.DS_SetModelBeamWidth(_modelStatePP, aBeamWidth); + EvaluateResultCode(resultCode); + } + /// /// Return the sample rate expected by the model. 
/// diff --git a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs index ecbfb7e9..f00c188d 100644 --- a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs @@ -20,6 +20,21 @@ namespace DeepSpeechClient.Interfaces /// Sample rate. unsafe int GetModelSampleRate(); + /// + /// Get beam width value used by the model. If SetModelBeamWidth was not + /// called before, will return the default value loaded from the model + /// file. + /// + /// Beam width value used by the model. + unsafe uint GetModelBeamWidth(); + + /// + /// Set beam width value used by the model. + /// + /// The beam width used by the decoder. A larger beam width value generates better results at the cost of decoding time. + /// Thrown on failure. + unsafe void SetModelBeamWidth(uint aBeamWidth); + /// /// Enable decoding using an external scorer. /// diff --git a/native_client/dotnet/DeepSpeechClient/NativeImp.cs b/native_client/dotnet/DeepSpeechClient/NativeImp.cs index 1c49feec..af28618c 100644 --- a/native_client/dotnet/DeepSpeechClient/NativeImp.cs +++ b/native_client/dotnet/DeepSpeechClient/NativeImp.cs @@ -14,6 +14,17 @@ namespace DeepSpeechClient [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static extern void DS_PrintVersions(); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, + ref IntPtr** pint); + + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx); + + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx, + uint aBeamWidth); + [DllImport("libdeepspeech.so", CallingConvention = 
CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, uint aBeamWidth, diff --git a/native_client/dotnet/DeepSpeechConsole/Program.cs b/native_client/dotnet/DeepSpeechConsole/Program.cs index 1f6e299b..b35c7046 100644 --- a/native_client/dotnet/DeepSpeechConsole/Program.cs +++ b/native_client/dotnet/DeepSpeechConsole/Program.cs @@ -46,15 +46,12 @@ namespace CSharpExamples extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended")); } - const uint BEAM_WIDTH = 500; - Stopwatch stopwatch = new Stopwatch(); try { Console.WriteLine("Loading model..."); stopwatch.Start(); - using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm", - BEAM_WIDTH)) + using (IDeepSpeech sttClient = new DeepSpeech(model ?? "output_graph.pbmm")) { stopwatch.Stop(); diff --git a/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java b/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java index 12e758df..f9d9a11e 100644 --- a/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java +++ b/native_client/java/app/src/main/java/org/mozilla/deepspeech/DeepSpeechActivity.java @@ -49,7 +49,8 @@ public class DeepSpeechActivity extends AppCompatActivity { private void newModel(String tfliteModel) { this._tfliteStatus.setText("Creating model"); if (this._m == null) { - this._m = new DeepSpeechModel(tfliteModel, BEAM_WIDTH); + this._m = new DeepSpeechModel(tfliteModel); + this._m.setBeamWidth(BEAM_WIDTH); } } diff --git a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java index bb6bbe42..2957b2e7 100644 --- a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java +++ 
b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java @@ -33,8 +33,6 @@ public class BasicTest { public static final String scorerFile = "/data/local/tmp/test/kenlm.scorer"; public static final String wavFile = "/data/local/tmp/test/LDC93S1.wav"; - public static final int BEAM_WIDTH = 50; - private char readLEChar(RandomAccessFile f) throws IOException { byte b1 = f.readByte(); byte b2 = f.readByte(); @@ -59,7 +57,7 @@ public class BasicTest { @Test public void loadDeepSpeech_basic() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); m.freeModel(); } @@ -116,7 +114,7 @@ public class BasicTest { @Test public void loadDeepSpeech_stt_noLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); String decoded = doSTT(m, false); assertEquals("she had your dark suit in greasy wash water all year", decoded); @@ -125,7 +123,7 @@ public class BasicTest { @Test public void loadDeepSpeech_stt_withLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); m.enableExternalScorer(scorerFile); String decoded = doSTT(m, false); @@ -135,7 +133,7 @@ public class BasicTest { @Test public void loadDeepSpeech_sttWithMetadata_noLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); String decoded = doSTT(m, true); assertEquals("she had your dark suit in greasy wash water all year", decoded); @@ -144,7 +142,7 @@ public class BasicTest { @Test public void loadDeepSpeech_sttWithMetadata_withLM() { - DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); + DeepSpeechModel m = new DeepSpeechModel(modelFile); m.enableExternalScorer(scorerFile); String decoded = doSTT(m, true); diff --git 
a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java index 0438ac10..6d0a316b 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java @@ -20,16 +20,35 @@ public class DeepSpeechModel { * @constructor * * @param modelPath The path to the frozen model graph. - * @param beam_width The beam width used by the decoder. A larger beam - * width generates better results at the cost of decoding - * time. */ - public DeepSpeechModel(String modelPath, int beam_width) { + public DeepSpeechModel(String modelPath) { this._mspp = impl.new_modelstatep(); - impl.CreateModel(modelPath, beam_width, this._mspp); + impl.CreateModel(modelPath, this._mspp); this._msp = impl.modelstatep_value(this._mspp); } + /** + * @brief Get beam width value used by the model. If setModelBeamWidth was not + * called before, will return the default value loaded from the model file. + * + * @return Beam width value used by the model. + */ + public long beamWidth() { + return impl.GetModelBeamWidth(this._msp); + } + + /** + * @brief Set beam width value used by the model. + * + * @param aBeamWidth The beam width used by the model. A larger beam width value + * generates better results at the cost of decoding time. + * + * @return Zero on success, non-zero on failure. + */ + public int setBeamWidth(long beamWidth) { + return impl.SetModelBeamWidth(this._msp, beamWidth); + } + /** * @brief Return the sample rate expected by the model. 
* diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js index 7266b85d..09406ccc 100644 --- a/native_client/javascript/client.js +++ b/native_client/javascript/client.js @@ -31,7 +31,7 @@ var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running D parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'}); parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'}); parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'}); -parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'}); +parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', type: 'int'}); parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'}); parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). 
If not specified, use default from the scorer package.', type: 'float'}); parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'}); @@ -53,10 +53,14 @@ function metadataToString(metadata) { console.error('Loading model from file %s', args['model']); const model_load_start = process.hrtime(); -var model = new Ds.Model(args['model'], args['beam_width']); +var model = new Ds.Model(args['model']); const model_load_end = process.hrtime(model_load_start); console.error('Loaded model in %ds.', totalTime(model_load_end)); +if (args['beam_width']) { + model.setBeamWidth(args['beam_width']); +} + var desired_sample_rate = model.sampleRate(); if (args['scorer']) { diff --git a/native_client/javascript/index.js b/native_client/javascript/index.js index 772b1a82..58697033 100644 --- a/native_client/javascript/index.js +++ b/native_client/javascript/index.js @@ -25,14 +25,13 @@ if (process.platform === 'win32') { * An object providing an interface to a trained DeepSpeech model. * * @param {string} aModelPath The path to the frozen model graph. - * @param {number} aBeamWidth The beam width used by the decoder. A larger beam width generates better results at the cost of decoding time. * * @throws on error */ -function Model() { +function Model(aModelPath) { this._impl = null; - const rets = binding.CreateModel.apply(null, arguments); + const rets = binding.CreateModel(aModelPath); const status = rets[0]; const impl = rets[1]; if (status !== 0) { @@ -42,6 +41,27 @@ function Model() { this._impl = impl; } +/** + * Get beam width value used by the model. If :js:func:Model.setBeamWidth was + * not called before, will return the default value loaded from the model file. + * + * @return {number} Beam width value used by the model. + */ +Model.prototype.beamWidth = function() { + return binding.GetModelBeamWidth(this._impl); +} + +/** + * Set beam width value used by the model. + * + * @param {number} The beam width used by the model. 
A larger beam width value generates better results at the cost of decoding time. + * + * @return {number} Zero on success, non-zero on failure. + */ +Model.prototype.setBeamWidth = function(aBeamWidth) { + return binding.SetModelBeamWidth(this._impl, aBeamWidth); +} + /** * Return the sample rate expected by the model. * diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index c7fc46a0..4bc0e953 100644 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -24,10 +24,8 @@ ModelState::~ModelState() } int -ModelState::init(const char* model_path, - unsigned int beam_width) +ModelState::init(const char* model_path) { - beam_width_ = beam_width; return DS_ERR_OK; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h index d4f11c1c..c296c003 100644 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -30,7 +30,7 @@ struct ModelState { ModelState(); virtual ~ModelState(); - virtual int init(const char* model_path, unsigned int beam_width); + virtual int init(const char* model_path); virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) = 0; diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index ccb53fc4..960305be 100644 --- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -28,15 +28,12 @@ class Model(object): :param aModelPath: Path to model file to load :type aModelPath: str - - :param aBeamWidth: Decoder beam width - :type aBeamWidth: int """ - def __init__(self, *args, **kwargs): + def __init__(self, model_path): # make sure the attribute is there if CreateModel fails self._impl = None - status, impl = deepspeech.impl.CreateModel(*args, **kwargs) + status, impl = deepspeech.impl.CreateModel(model_path) if status != 0: raise RuntimeError("CreateModel failed with error code {}".format(status)) self._impl = impl @@ -46,6 +43,28 @@ class Model(object): deepspeech.impl.FreeModel(self._impl) self._impl = None + def 
beamWidth(self): + """ + Get beam width value used by the model. If setModelBeamWidth was not + called before, will return the default value loaded from the model file. + + :return: Beam width value used by the model. + :type: int + """ + return deepspeech.impl.GetModelBeamWidth(self._impl) + + def setBeamWidth(self, beam_width): + """ + Set beam width value used by the model. + + :param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time. + :type beam_width: int + + :return: Zero on success, non-zero on failure. + :type: int + """ + return deepspeech.impl.SetModelBeamWidth(self._impl, beam_width) + def sampleRate(self): """ Return the sample rate expected by the model. diff --git a/native_client/python/client.py b/native_client/python/client.py index 2ef88caf..26db1e00 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -92,7 +92,7 @@ def main(): help='Path to the external scorer file') parser.add_argument('--audio', required=True, help='Path to the audio file to run (WAV format)') - parser.add_argument('--beam_width', type=int, default=500, + parser.add_argument('--beam_width', type=int, help='Beam width for the CTC decoder') parser.add_argument('--lm_alpha', type=float, help='Language model weight (lm_alpha). 
If not specified, use default from the scorer package.') @@ -108,10 +108,13 @@ def main(): print('Loading model from file {}'.format(args.model), file=sys.stderr) model_load_start = timer() - ds = Model(args.model, args.beam_width) + ds = Model(args.model) model_load_end = timer() - model_load_start print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr) + if args.beam_width: + ds.setBeamWidth(args.beam_width) + desired_sample_rate = ds.sampleRate() if args.scorer: diff --git a/native_client/test/concurrent_streams.py b/native_client/test/concurrent_streams.py index d799de36..e435b43f 100644 --- a/native_client/test/concurrent_streams.py +++ b/native_client/test/concurrent_streams.py @@ -9,12 +9,6 @@ import wave from deepspeech import Model -# These constants control the beam search decoder - -# Beam width used in the CTC decoder when building candidate transcriptions -BEAM_WIDTH = 500 - - def main(): parser = argparse.ArgumentParser(description='Running DeepSpeech inference.') parser.add_argument('--model', required=True, @@ -27,7 +21,7 @@ help='Second audio file to use in interleaved streams') args = parser.parse_args() - ds = Model(args.model, BEAM_WIDTH) + ds = Model(args.model) if args.scorer: ds.enableExternalScorer(args.scorer) diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc index 2135a571..4836ed0b 100644 --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -91,10 +91,9 @@ TFLiteModelState::~TFLiteModelState() } int -TFLiteModelState::init(const char* model_path, - unsigned int beam_width) +TFLiteModelState::init(const char* model_path) { - int err = ModelState::init(model_path, beam_width); + int err = ModelState::init(model_path); if (err != DS_ERR_OK) { return err; } @@ -129,6 +128,7 @@ TFLiteModelState::init(const char* model_path, int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate"); int metadata_feature_win_len_idx = 
get_output_tensor_by_name("metadata_feature_win_len"); int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step"); + int metadata_beam_width_idx = get_output_tensor_by_name("metadata_beam_width"); int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet"); std::vector metadata_exec_plan; @@ -136,6 +136,7 @@ TFLiteModelState::init(const char* model_path, metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]); + metadata_exec_plan.push_back(find_parent_node_ids(metadata_beam_width_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]); for (int i = 0; i < metadata_exec_plan.size(); ++i) { @@ -201,6 +202,9 @@ TFLiteModelState::init(const char* model_path, audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0); audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0); + int* const beam_width = interpreter_->typed_tensor(metadata_beam_width_idx); + beam_width_ = (unsigned int)(*beam_width); + tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0); err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len); if (err != 0) { @@ -210,6 +214,8 @@ TFLiteModelState::init(const char* model_path, assert(sample_rate_ > 0); assert(audio_win_len_ > 0); assert(audio_win_step_ > 0); + assert(beam_width_ > 0); + assert(alphabet_.GetSize() > 0); TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims; diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h index 77137751..11532e64 100644 --- a/native_client/tflitemodelstate.h +++ b/native_client/tflitemodelstate.h @@ -30,8 +30,7 @@ struct TFLiteModelState : public ModelState TFLiteModelState(); virtual ~TFLiteModelState(); - virtual int init(const 
char* model_path, - unsigned int beam_width) override; + virtual int init(const char* model_path) override; virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) override; diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc index c7c1a688..5b1e1675 100644 --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -23,10 +23,9 @@ TFModelState::~TFModelState() } int -TFModelState::init(const char* model_path, - unsigned int beam_width) +TFModelState::init(const char* model_path) { - int err = ModelState::init(model_path, beam_width); + int err = ModelState::init(model_path); if (err != DS_ERR_OK) { return err; } @@ -103,6 +102,7 @@ TFModelState::init(const char* model_path, "metadata_sample_rate", "metadata_feature_win_len", "metadata_feature_win_step", + "metadata_beam_width", "metadata_alphabet", }, {}, &metadata_outputs); if (!status.ok()) { @@ -115,8 +115,10 @@ TFModelState::init(const char* model_path, int win_step_ms = metadata_outputs[2].scalar()(); audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0); audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0); + int beam_width = metadata_outputs[3].scalar()(); + beam_width_ = (unsigned int)(beam_width); - string serialized_alphabet = metadata_outputs[3].scalar()(); + string serialized_alphabet = metadata_outputs[4].scalar()(); err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size()); if (err != 0) { return DS_ERR_INVALID_ALPHABET; @@ -125,6 +127,8 @@ TFModelState::init(const char* model_path, assert(sample_rate_ > 0); assert(audio_win_len_ > 0); assert(audio_win_step_ > 0); + assert(beam_width_ > 0); + assert(alphabet_.GetSize() > 0); for (int i = 0; i < graph_def_.node_size(); ++i) { NodeDef node = graph_def_.node(i); diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h index 2f6edf49..2a8db699 100644 --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -18,8 +18,7 @@ struct 
TFModelState : public ModelState TFModelState(); virtual ~TFModelState(); - virtual int init(const char* model_path, - unsigned int beam_width) override; + virtual int init(const char* model_path) override; virtual void infer(const std::vector& mfcc, unsigned int n_frames, diff --git a/taskcluster/examples-base.tyml b/taskcluster/examples-base.tyml index 2af1c1f1..381e9284 100644 --- a/taskcluster/examples-base.tyml +++ b/taskcluster/examples-base.tyml @@ -30,11 +30,11 @@ then: image: ${build.docker_image} env: - DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models.tar.gz" + DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models_beam_width.tar.gz" DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz" PIP_DEFAULT_TIMEOUT: "60" EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples" - EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a" + EXAMPLES_CHECKOUT_TARGET: "embedded-beam-width" command: - "/bin/bash" diff --git a/taskcluster/win-opt-base.tyml b/taskcluster/win-opt-base.tyml index 6bcc0acd..e892ec70 100644 --- a/taskcluster/win-opt-base.tyml +++ b/taskcluster/win-opt-base.tyml @@ -44,7 +44,7 @@ payload: MSYS: 'winsymlinks:nativestrict' TENSORFLOW_BUILD_ARTIFACT: ${build.tensorflow} EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples" - EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a" + EXAMPLES_CHECKOUT_TARGET: "embedded-beam-width" command: - >- diff --git a/util/flags.py b/util/flags.py index c3ed2af8..9f31aae4 100644 --- a/util/flags.py +++ b/util/flags.py @@ -111,6 +111,7 @@ def create_flags(): f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". 
Gets embedded into exported model.') f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json') f.DEFINE_string('export_name', 'output_graph', 'name for the export model') + f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph') # Reporting