STT/native_client/modelstate.h
Josh Meyer 1eb155ed93
enable hot-word boosting (#3297)
* enable hot-word boosting

* more consistent ordering of CLI arguments

* progress on review

* use map instead of set for hot-words, move string logic to client.cc

* typo bug

* pointer things?

* use map for hotwords, better string splitting

* add the boost, not multiply

* cleaning up

* cleaning whitespace

* remove <set> inclusion

* change typo set-->map

* rename boost_coefficient to boost

X-DeepSpeech: NOBUILD

* add hot_words to python bindings

* missing hot_words

* include map in swigwrapper.i

* add Map template to swigwrapper.i

* emacs intermediate file

* map things

* map-->unordered_map

* typo

* typo

* use dict() not None

* error out if hot_words without scorer

* two new functions: remove hot-word and clear all hot-words

* starting to work on better error messages

X-DeepSpeech: NOBUILD

* better error handling + .Net ERR codes

* allow for negative boosts:)

* adding TC test for hot-words

* add hot-words to python client, make TC test hot-words everywhere

* only run TC tests for C++ and Python

* fully expose API in python bindings

* expose API in Java (thanks spectie!)

* expose API in dotnet (thanks spectie!)

* expose API in javascript (thanks spectie!)

* java lol

* typo in javascript

* commenting

* java error codes from swig

* java docs from SWIG

* java and dotnet issues

* add hotword test to android tests

* dotnet fixes from carlos

* add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command

* make sure lm is on android for hotword test

* path to android model + nit

* path

* path
2020-09-24 14:58:41 -04:00

81 lines
2.4 KiB
C++

#ifndef MODELSTATE_H
#define MODELSTATE_H
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "deepspeech.h"
#include "alphabet.h"
#include "ctcdecode/scorer.h"
#include "ctcdecode/output.h"
// Forward declaration only: this header refers to DecoderState solely by
// const reference (see decode() and decode_metadata() below), so the full
// definition is not required here.
class DecoderState;
/**
* @brief Abstract base type bundling decoding configuration (alphabet, scorer,
* hot-words, beam width) and model/audio dimensions, together with the
* virtual interface for feature extraction, inference and decoding.
*
* compute_mfcc() and infer() are pure virtual, so this struct cannot be
* instantiated directly; a concrete subclass must implement them.
*
* NOTE(review): only declarations are visible in this header. Where the data
* members receive their values is not shown here (presumably the constructor
* and init() implementations) - confirm before relying on any default.
*/
struct ModelState {
// Batch size, currently hard-coded to 1 (see the TODO below).
//TODO: infer batch size from model/use dynamic batch size
static constexpr unsigned int BATCH_SIZE = 1;
// Alphabet associated with this model (type presumably declared in alphabet.h).
Alphabet alphabet_;
// Shared handle to the external scorer (type presumably declared in
// ctcdecode/scorer.h).
// NOTE(review): presumably empty when no scorer is enabled - confirm.
std::shared_ptr<Scorer> scorer_;
// Hot-words keyed by word, each mapped to its boost value.
// NOTE(review): per the change history the boost is added rather than
// multiplied and may be negative; how the decoder applies it is not visible
// in this header.
std::unordered_map<std::string, float> hot_words_;
// Decoder and model/audio parameters.
// NOTE(review): the descriptions below are inferred from the member names
// only; units and valid ranges are not visible here - confirm against the
// code that populates them.
// Presumably the beam width used by the CTC decoder.
unsigned int beam_width_;
// Presumably the number of timesteps processed per inference step.
unsigned int n_steps_;
// Presumably the number of context frames on each side of a timestep.
unsigned int n_context_;
// Presumably the number of features per audio frame.
unsigned int n_features_;
// Presumably the number of MFCC values fed to the model per timestep.
unsigned int mfcc_feats_per_timestep_;
// Presumably the audio sample rate expected by the model.
unsigned int sample_rate_;
// Presumably the length of the audio analysis window.
unsigned int audio_win_len_;
// Presumably the step between successive audio analysis windows.
unsigned int audio_win_step_;
// Presumably the size of each recurrent state vector passed to infer().
unsigned int state_size_;
ModelState();
// Virtual so that subclass objects are destroyed correctly when deleted
// through a ModelState pointer.
virtual ~ModelState();
/**
* @brief Initialise this state from the model at the given path.
*
* Declared virtual but not pure, so a base implementation is expected to be
* defined outside this header.
*
* @param model_path Path identifying the model to load.
*
* @return Integer status code.
* NOTE(review): presumably zero on success and an error code from
* deepspeech.h otherwise - confirm against the implementation.
*/
virtual int init(const char* model_path);
/**
* @brief Compute MFCC features for a buffer of audio samples.
*
* Pure virtual: must be implemented by a concrete subclass.
*
* @param audio_buffer Input audio samples (read-only).
*
* @param[out] mfcc_output Receives the computed features.
* NOTE(review): whether existing contents are overwritten or
* appended to is not visible here - confirm.
*/
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
/**
* @brief Do a single inference step in the acoustic model, with:
* input=mfcc
* input_lengths=[n_frames]
*
* Pure virtual: must be implemented by a concrete subclass.
*
* @param mfcc batch input data
* @param n_frames number of timesteps in the data
* @param previous_state_c Input state "c" carried over from the previous step.
* @param previous_state_h Input state "h" carried over from the previous step.
* NOTE(review): c/h are presumably the recurrent cell and
* hidden states - confirm.
*
* @param[out] logits_output Where to store computed logits.
* @param[out] state_c_output Where to store the resulting state "c".
* @param[out] state_h_output Where to store the resulting state "h".
*/
virtual void infer(const std::vector<float>& mfcc,
unsigned int n_frames,
const std::vector<float>& previous_state_c,
const std::vector<float>& previous_state_h,
std::vector<float>& logits_output,
std::vector<float>& state_c_output,
std::vector<float>& state_h_output) = 0;
/**
* @brief Perform decoding of the logits, using basic CTC decoder or
* CTC decoder with KenLM enabled
*
* @param state Decoder state to use when decoding.
*
* @return String representing the decoded text.
* NOTE(review): a raw char* is returned; the ownership and freeing
* convention is not visible in this header - confirm.
*/
virtual char* decode(const DecoderState& state) const;
/**
* @brief Return character-level metadata including letter timings.
*
* @param state Decoder state to use when decoding.
* @param num_results Maximum number of candidate results to return.
*
* @return A Metadata struct containing CandidateTranscript structs.
* Each represents a candidate transcript, with the first ranked most probable.
* The user is responsible for freeing the returned Metadata by calling
* DS_FreeMetadata().
*/
virtual Metadata* decode_metadata(const DecoderState& state,
size_t num_results);
};
#endif // MODELSTATE_H