Embed default beam width into exported graph and remove param from DS_CreateModel
This commit is contained in:
parent
5366f90375
commit
8e9b6ef7b3
@ -797,6 +797,7 @@ def export():
|
|||||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||||
|
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||||
|
|
||||||
if FLAGS.export_language:
|
if FLAGS.export_language:
|
||||||
|
@ -257,7 +257,6 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
|
|||||||
|
|
||||||
int
|
int
|
||||||
DS_CreateModel(const char* aModelPath,
|
DS_CreateModel(const char* aModelPath,
|
||||||
unsigned int aBeamWidth,
|
|
||||||
ModelState** retval)
|
ModelState** retval)
|
||||||
{
|
{
|
||||||
*retval = nullptr;
|
*retval = nullptr;
|
||||||
@ -282,7 +281,7 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
return DS_ERR_FAIL_CREATE_MODEL;
|
return DS_ERR_FAIL_CREATE_MODEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
int err = model->init(aModelPath, aBeamWidth);
|
int err = model->init(aModelPath);
|
||||||
if (err != DS_ERR_OK) {
|
if (err != DS_ERR_OK) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
@ -291,6 +290,19 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
return DS_ERR_OK;
|
return DS_ERR_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsigned int
|
||||||
|
DS_GetModelBeamWidth(ModelState* aCtx)
|
||||||
|
{
|
||||||
|
return aCtx->beam_width_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
DS_SetModelBeamWidth(ModelState* aCtx, unsigned int aBeamWidth)
|
||||||
|
{
|
||||||
|
aCtx->beam_width_ = aBeamWidth;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
DS_GetModelSampleRate(ModelState* aCtx)
|
DS_GetModelSampleRate(ModelState* aCtx)
|
||||||
{
|
{
|
||||||
|
@ -78,18 +78,39 @@ enum DeepSpeech_Error_Codes
|
|||||||
* @brief An object providing an interface to a trained DeepSpeech model.
|
* @brief An object providing an interface to a trained DeepSpeech model.
|
||||||
*
|
*
|
||||||
* @param aModelPath The path to the frozen model graph.
|
* @param aModelPath The path to the frozen model graph.
|
||||||
* @param aBeamWidth The beam width used by the decoder. A larger beam
|
|
||||||
* width generates better results at the cost of decoding
|
|
||||||
* time.
|
|
||||||
* @param[out] retval a ModelState pointer
|
* @param[out] retval a ModelState pointer
|
||||||
*
|
*
|
||||||
* @return Zero on success, non-zero on failure.
|
* @return Zero on success, non-zero on failure.
|
||||||
*/
|
*/
|
||||||
DEEPSPEECH_EXPORT
|
DEEPSPEECH_EXPORT
|
||||||
int DS_CreateModel(const char* aModelPath,
|
int DS_CreateModel(const char* aModelPath,
|
||||||
unsigned int aBeamWidth,
|
|
||||||
ModelState** retval);
|
ModelState** retval);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get beam width value used by the model. If {@link DS_SetModelBeamWidth}
|
||||||
|
* was not called before, will return the default value loaded from the
|
||||||
|
* model file.
|
||||||
|
*
|
||||||
|
* @param aCtx A ModelState pointer created with {@link DS_CreateModel}.
|
||||||
|
*
|
||||||
|
* @return Beam width value used by the model.
|
||||||
|
*/
|
||||||
|
DEEPSPEECH_EXPORT
|
||||||
|
unsigned int DS_GetModelBeamWidth(ModelState* aCtx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set beam width value used by the model.
|
||||||
|
*
|
||||||
|
* @param aCtx A ModelState pointer created with {@link DS_CreateModel}.
|
||||||
|
* @param aBeamWidth The beam width used by the model. A larger beam width value
|
||||||
|
* generates better results at the cost of decoding time.
|
||||||
|
*
|
||||||
|
* @return Zero on success, non-zero on failure.
|
||||||
|
*/
|
||||||
|
DEEPSPEECH_EXPORT
|
||||||
|
int DS_SetModelBeamWidth(ModelState* aCtx,
|
||||||
|
unsigned int aBeamWidth);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return the sample rate expected by a model.
|
* @brief Return the sample rate expected by a model.
|
||||||
*
|
*
|
||||||
|
@ -24,10 +24,8 @@ ModelState::~ModelState()
|
|||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
ModelState::init(const char* model_path,
|
ModelState::init(const char* model_path)
|
||||||
unsigned int beam_width)
|
|
||||||
{
|
{
|
||||||
beam_width_ = beam_width;
|
|
||||||
return DS_ERR_OK;
|
return DS_ERR_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ struct ModelState {
|
|||||||
ModelState();
|
ModelState();
|
||||||
virtual ~ModelState();
|
virtual ~ModelState();
|
||||||
|
|
||||||
virtual int init(const char* model_path, unsigned int beam_width);
|
virtual int init(const char* model_path);
|
||||||
|
|
||||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
|
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
|
||||||
|
|
||||||
|
@ -91,10 +91,9 @@ TFLiteModelState::~TFLiteModelState()
|
|||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
TFLiteModelState::init(const char* model_path,
|
TFLiteModelState::init(const char* model_path)
|
||||||
unsigned int beam_width)
|
|
||||||
{
|
{
|
||||||
int err = ModelState::init(model_path, beam_width);
|
int err = ModelState::init(model_path);
|
||||||
if (err != DS_ERR_OK) {
|
if (err != DS_ERR_OK) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
@ -130,6 +129,7 @@ TFLiteModelState::init(const char* model_path,
|
|||||||
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
|
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
|
||||||
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
|
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
|
||||||
int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");
|
int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");
|
||||||
|
int metadata_beam_width_idx = get_output_tensor_by_name("metadata_beam_width");
|
||||||
|
|
||||||
std::vector<int> metadata_exec_plan;
|
std::vector<int> metadata_exec_plan;
|
||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
|
||||||
@ -137,6 +137,7 @@ TFLiteModelState::init(const char* model_path,
|
|||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
|
||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
|
||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);
|
||||||
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_beam_width_idx)[0]);
|
||||||
|
|
||||||
for (int i = 0; i < metadata_exec_plan.size(); ++i) {
|
for (int i = 0; i < metadata_exec_plan.size(); ++i) {
|
||||||
assert(metadata_exec_plan[i] > -1);
|
assert(metadata_exec_plan[i] > -1);
|
||||||
@ -207,9 +208,14 @@ TFLiteModelState::init(const char* model_path,
|
|||||||
return DS_ERR_INVALID_ALPHABET;
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int* const beam_width = interpreter_->typed_tensor<int>(metadata_beam_width_idx);
|
||||||
|
beam_width_ = (unsigned int)(*beam_width);
|
||||||
|
|
||||||
assert(sample_rate_ > 0);
|
assert(sample_rate_ > 0);
|
||||||
assert(audio_win_len_ > 0);
|
assert(audio_win_len_ > 0);
|
||||||
assert(audio_win_step_ > 0);
|
assert(audio_win_step_ > 0);
|
||||||
|
assert(alphabet_.GetSize() > 0);
|
||||||
|
assert(beam_width_ > 0);
|
||||||
|
|
||||||
TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;
|
TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;
|
||||||
|
|
||||||
|
@ -30,8 +30,7 @@ struct TFLiteModelState : public ModelState
|
|||||||
TFLiteModelState();
|
TFLiteModelState();
|
||||||
virtual ~TFLiteModelState();
|
virtual ~TFLiteModelState();
|
||||||
|
|
||||||
virtual int init(const char* model_path,
|
virtual int init(const char* model_path) override;
|
||||||
unsigned int beam_width) override;
|
|
||||||
|
|
||||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
||||||
std::vector<float>& mfcc_output) override;
|
std::vector<float>& mfcc_output) override;
|
||||||
|
@ -23,10 +23,9 @@ TFModelState::~TFModelState()
|
|||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
TFModelState::init(const char* model_path,
|
TFModelState::init(const char* model_path)
|
||||||
unsigned int beam_width)
|
|
||||||
{
|
{
|
||||||
int err = ModelState::init(model_path, beam_width);
|
int err = ModelState::init(model_path);
|
||||||
if (err != DS_ERR_OK) {
|
if (err != DS_ERR_OK) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
@ -104,6 +103,7 @@ TFModelState::init(const char* model_path,
|
|||||||
"metadata_feature_win_len",
|
"metadata_feature_win_len",
|
||||||
"metadata_feature_win_step",
|
"metadata_feature_win_step",
|
||||||
"metadata_alphabet",
|
"metadata_alphabet",
|
||||||
|
"metadata_beam_width",
|
||||||
}, {}, &metadata_outputs);
|
}, {}, &metadata_outputs);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
std::cout << "Unable to fetch metadata: " << status << std::endl;
|
std::cout << "Unable to fetch metadata: " << status << std::endl;
|
||||||
@ -122,9 +122,14 @@ TFModelState::init(const char* model_path,
|
|||||||
return DS_ERR_INVALID_ALPHABET;
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int beam_width = metadata_outputs[4].scalar<int>()();
|
||||||
|
beam_width_ = (unsigned int)(beam_width);
|
||||||
|
|
||||||
assert(sample_rate_ > 0);
|
assert(sample_rate_ > 0);
|
||||||
assert(audio_win_len_ > 0);
|
assert(audio_win_len_ > 0);
|
||||||
assert(audio_win_step_ > 0);
|
assert(audio_win_step_ > 0);
|
||||||
|
assert(alphabet_.GetSize() > 0);
|
||||||
|
assert(beam_width_ > 0);
|
||||||
|
|
||||||
for (int i = 0; i < graph_def_.node_size(); ++i) {
|
for (int i = 0; i < graph_def_.node_size(); ++i) {
|
||||||
NodeDef node = graph_def_.node(i);
|
NodeDef node = graph_def_.node(i);
|
||||||
|
@ -18,8 +18,7 @@ struct TFModelState : public ModelState
|
|||||||
TFModelState();
|
TFModelState();
|
||||||
virtual ~TFModelState();
|
virtual ~TFModelState();
|
||||||
|
|
||||||
virtual int init(const char* model_path,
|
virtual int init(const char* model_path) override;
|
||||||
unsigned int beam_width) override;
|
|
||||||
|
|
||||||
virtual void infer(const std::vector<float>& mfcc,
|
virtual void infer(const std::vector<float>& mfcc,
|
||||||
unsigned int n_frames,
|
unsigned int n_frames,
|
||||||
|
@ -111,6 +111,7 @@ def create_flags():
|
|||||||
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
|
f.DEFINE_string('export_language', '', 'language the model was trained on e.g. "en" or "English". Gets embedded into exported model.')
|
||||||
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
|
f.DEFINE_boolean('export_zip', False, 'export a TFLite model and package with LM and info.json')
|
||||||
f.DEFINE_string('export_name', 'output_graph', 'name for the export model')
|
f.DEFINE_string('export_name', 'output_graph', 'name for the export model')
|
||||||
|
# Integer flag: the value is embedded into the exported graph as the
# 'metadata_beam_width' constant, which the native client reads back as an int.
f.DEFINE_integer('export_beam_width', 500, 'default beam width to embed into exported graph')
|
||||||
|
|
||||||
# Reporting
|
# Reporting
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user