* enable hot-word boosting * more consistent ordering of CLI arguments * progress on review * use map instead of set for hot-words, move string logic to client.cc * typo bug * pointer things? * use map for hotwords, better string splitting * add the boost, not multiply * cleaning up * cleaning whitespace * remove <set> inclusion * change typo set-->map * rename boost_coefficient to boost X-DeepSpeech: NOBUILD * add hot_words to python bindings * missing hot_words * include map in swigwrapper.i * add Map template to swigwrapper.i * emacs intermediate file * map things * map-->unordered_map * typu * typu * use dict() not None * error out if hot_words without scorer * two new functions: remove hot-word and clear all hot-words * starting to work on better error messages X-DeepSpeech: NOBUILD * better error handling + .Net ERR codes * allow for negative boosts:) * adding TC test for hot-words * add hot-words to python client, make TC test hot-words everywhere * only run TC tests for C++ and Python * fully expose API in python bindings * expose API in Java (thanks spectie!) * expose API in dotnet (thanks spectie!) * expose API in javascript (thanks spectie!) * java lol * typo in javascript * commenting * java error codes from swig * java docs from SWIG * java and dotnet issues * add hotword test to android tests * dotnet fixes from carlos * add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command * make sure lm is on android for hotword test * path to android model + nit * path * path
81 lines
2.4 KiB
C++
81 lines
2.4 KiB
C++
#ifndef MODELSTATE_H
|
|
#define MODELSTATE_H
|
|
|
|
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "deepspeech.h"
#include "alphabet.h"
#include "ctcdecode/scorer.h"
#include "ctcdecode/output.h"
|
|
|
|
class DecoderState;
|
|
|
|
struct ModelState {
|
|
//TODO: infer batch size from model/use dynamic batch size
|
|
static constexpr unsigned int BATCH_SIZE = 1;
|
|
|
|
Alphabet alphabet_;
|
|
std::shared_ptr<Scorer> scorer_;
|
|
std::unordered_map<std::string, float> hot_words_;
|
|
unsigned int beam_width_;
|
|
unsigned int n_steps_;
|
|
unsigned int n_context_;
|
|
unsigned int n_features_;
|
|
unsigned int mfcc_feats_per_timestep_;
|
|
unsigned int sample_rate_;
|
|
unsigned int audio_win_len_;
|
|
unsigned int audio_win_step_;
|
|
unsigned int state_size_;
|
|
|
|
ModelState();
|
|
virtual ~ModelState();
|
|
|
|
virtual int init(const char* model_path);
|
|
|
|
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
|
|
|
|
/**
|
|
* @brief Do a single inference step in the acoustic model, with:
|
|
* input=mfcc
|
|
* input_lengths=[n_frames]
|
|
*
|
|
* @param mfcc batch input data
|
|
* @param n_frames number of timesteps in the data
|
|
*
|
|
* @param[out] output_logits Where to store computed logits.
|
|
*/
|
|
virtual void infer(const std::vector<float>& mfcc,
|
|
unsigned int n_frames,
|
|
const std::vector<float>& previous_state_c,
|
|
const std::vector<float>& previous_state_h,
|
|
std::vector<float>& logits_output,
|
|
std::vector<float>& state_c_output,
|
|
std::vector<float>& state_h_output) = 0;
|
|
|
|
/**
|
|
* @brief Perform decoding of the logits, using basic CTC decoder or
|
|
* CTC decoder with KenLM enabled
|
|
*
|
|
* @param state Decoder state to use when decoding.
|
|
*
|
|
* @return String representing the decoded text.
|
|
*/
|
|
virtual char* decode(const DecoderState& state) const;
|
|
|
|
/**
|
|
* @brief Return character-level metadata including letter timings.
|
|
*
|
|
* @param state Decoder state to use when decoding.
|
|
* @param num_results Maximum number of candidate results to return.
|
|
*
|
|
* @return A Metadata struct containing CandidateTranscript structs.
|
|
* Each represents an candidate transcript, with the first ranked most probable.
|
|
* The user is responsible for freeing Result by calling DS_FreeMetadata().
|
|
*/
|
|
virtual Metadata* decode_metadata(const DecoderState& state,
|
|
size_t num_results);
|
|
};
|
|
|
|
#endif // MODELSTATE_H
|