diff --git a/DeepSpeech.py b/DeepSpeech.py index e6d3a929..48b8edb6 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -797,6 +797,7 @@ def export(): outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate') outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len') outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step') + outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width') outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet') if FLAGS.export_language: diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index 274ce41f..9f0ebe42 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -257,7 +257,6 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) int DS_CreateModel(const char* aModelPath, - unsigned int aBeamWidth, ModelState** retval) { *retval = nullptr; @@ -282,7 +281,7 @@ DS_CreateModel(const char* aModelPath, return DS_ERR_FAIL_CREATE_MODEL; } - int err = model->init(aModelPath, aBeamWidth); + int err = model->init(aModelPath); if (err != DS_ERR_OK) { return err; } @@ -291,6 +290,19 @@ DS_CreateModel(const char* aModelPath, return DS_ERR_OK; } +unsigned int +DS_GetModelBeamWidth(ModelState* aCtx) +{ + return aCtx->beam_width_; +} + +int +DS_SetModelBeamWidth(ModelState* aCtx, unsigned int aBeamWidth) +{ + aCtx->beam_width_ = aBeamWidth; + return 0; +} + int DS_GetModelSampleRate(ModelState* aCtx) { diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 4e017653..c5e330cb 100644 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -78,18 +78,39 @@ enum DeepSpeech_Error_Codes * @brief An object providing an interface to a trained DeepSpeech model. * * @param aModelPath The path to the frozen model graph. 
- * @param aBeamWidth The beam width used by the decoder. A larger beam - * width generates better results at the cost of decoding - * time. * @param[out] retval a ModelState pointer * * @return Zero on success, non-zero on failure. */ DEEPSPEECH_EXPORT int DS_CreateModel(const char* aModelPath, - unsigned int aBeamWidth, ModelState** retval); +/** + * @brief Get beam width value used by the model. If {@link DS_SetModelBeamWidth} + * was not called before, will return the default value loaded from the + * model file. + * + * @param aCtx A ModelState pointer created with {@link DS_CreateModel}. + * + * @return Beam width value used by the model. + */ +DEEPSPEECH_EXPORT +unsigned int DS_GetModelBeamWidth(ModelState* aCtx); + +/** + * @brief Set beam width value used by the model. + * + * @param aCtx A ModelState pointer created with {@link DS_CreateModel}. + * @param aBeamWidth The beam width used by the model. A larger beam width value + * generates better results at the cost of decoding time. + * + * @return Zero on success, non-zero on failure. + */ +DEEPSPEECH_EXPORT +int DS_SetModelBeamWidth(ModelState* aCtx, + unsigned int aBeamWidth); + /** * @brief Return the sample rate expected by a model. 
* diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index c7fc46a0..4bc0e953 100644 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -24,10 +24,8 @@ ModelState::~ModelState() } int -ModelState::init(const char* model_path, - unsigned int beam_width) +ModelState::init(const char* model_path) { - beam_width_ = beam_width; return DS_ERR_OK; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h index d4f11c1c..c296c003 100644 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -30,7 +30,7 @@ struct ModelState { ModelState(); virtual ~ModelState(); - virtual int init(const char* model_path, unsigned int beam_width); + virtual int init(const char* model_path); virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) = 0; diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc index 2135a571..5e0c71f3 100644 --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -91,10 +91,9 @@ TFLiteModelState::~TFLiteModelState() } int -TFLiteModelState::init(const char* model_path, - unsigned int beam_width) +TFLiteModelState::init(const char* model_path) { - int err = ModelState::init(model_path, beam_width); + int err = ModelState::init(model_path); if (err != DS_ERR_OK) { return err; } @@ -130,6 +129,7 @@ TFLiteModelState::init(const char* model_path, int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len"); int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step"); int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet"); + int metadata_beam_width_idx = get_output_tensor_by_name("metadata_beam_width"); std::vector metadata_exec_plan; metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]); @@ -137,6 +137,7 @@ TFLiteModelState::init(const char* model_path, 
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]); + metadata_exec_plan.push_back(find_parent_node_ids(metadata_beam_width_idx)[0]); for (int i = 0; i < metadata_exec_plan.size(); ++i) { assert(metadata_exec_plan[i] > -1); @@ -207,9 +208,14 @@ TFLiteModelState::init(const char* model_path, return DS_ERR_INVALID_ALPHABET; } + int* const beam_width = interpreter_->typed_tensor(metadata_beam_width_idx); + beam_width_ = (unsigned int)(*beam_width); + assert(sample_rate_ > 0); assert(audio_win_len_ > 0); assert(audio_win_step_ > 0); + assert(alphabet_.GetSize() > 0); + assert(beam_width_ > 0); TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims; diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h index 77137751..11532e64 100644 --- a/native_client/tflitemodelstate.h +++ b/native_client/tflitemodelstate.h @@ -30,8 +30,7 @@ struct TFLiteModelState : public ModelState TFLiteModelState(); virtual ~TFLiteModelState(); - virtual int init(const char* model_path, - unsigned int beam_width) override; + virtual int init(const char* model_path) override; virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) override; diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc index c7c1a688..ab7cc136 100644 --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -23,10 +23,9 @@ TFModelState::~TFModelState() } int -TFModelState::init(const char* model_path, - unsigned int beam_width) +TFModelState::init(const char* model_path) { - int err = ModelState::init(model_path, beam_width); + int err = ModelState::init(model_path); if (err != DS_ERR_OK) { return err; } @@ -104,6 +103,7 @@ TFModelState::init(const char* model_path, "metadata_feature_win_len", 
"metadata_feature_win_step", "metadata_alphabet", + "metadata_beam_width", }, {}, &metadata_outputs); if (!status.ok()) { std::cout << "Unable to fetch metadata: " << status << std::endl; @@ -122,9 +122,14 @@ TFModelState::init(const char* model_path, return DS_ERR_INVALID_ALPHABET; } + int beam_width = metadata_outputs[4].scalar()(); + beam_width_ = (unsigned int)(beam_width); + assert(sample_rate_ > 0); assert(audio_win_len_ > 0); assert(audio_win_step_ > 0); + assert(alphabet_.GetSize() > 0); + assert(beam_width_ > 0); for (int i = 0; i < graph_def_.node_size(); ++i) { NodeDef node = graph_def_.node(i); diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h index 2f6edf49..2a8db699 100644 --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -18,8 +18,7 @@ struct TFModelState : public ModelState TFModelState(); virtual ~TFModelState(); - virtual int init(const char* model_path, - unsigned int beam_width) override; + virtual int init(const char* model_path) override; virtual void infer(const std::vector& mfcc, unsigned int n_frames, diff --git a/util/flags.py b/util/flags.py index c3ed2af8..a465c9fc 100644 --- a/util/flags.py +++ b/util/flags.py @@ -111,6 +111,7 @@ def create_flags(): f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.') f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json') f.DEFINE_string('export_name', 'output_graph', 'name for the export model') + f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph') # Reporting