STT/native_client/ctcdecode/ctc_beam_search_decoder.h
Josh Meyer 1eb155ed93
enable hot-word boosting (#3297)
* enable hot-word boosting

* more consistent ordering of CLI arguments

* progress on review

* use map instead of set for hot-words, move string logic to client.cc

* typo bug

* pointer things?

* use map for hotwords, better string splitting

* add the boost, not multiply

* cleaning up

* cleaning whitespace

* remove <set> inclusion

* change typo set-->map

* rename boost_coefficient to boost

X-DeepSpeech: NOBUILD

* add hot_words to python bindings

* missing hot_words

* include map in swigwrapper.i

* add Map template to swigwrapper.i

* emacs intermediate file

* map things

* map-->unordered_map

* typu

* typu

* use dict() not None

* error out if hot_words without scorer

* two new functions: remove hot-word and clear all hot-words

* starting to work on better error messages

X-DeepSpeech: NOBUILD

* better error handling + .Net ERR codes

* allow for negative boosts:)

* adding TC test for hot-words

* add hot-words to python client, make TC test hot-words everywhere

* only run TC tests for C++ and Python

* fully expose API in python bindings

* expose API in Java (thanks spectie!)

* expose API in dotnet (thanks spectie!)

* expose API in javascript (thanks spectie!)

* java lol

* typo in javascript

* commenting

* java error codes from swig

* java docs from SWIG

* java and dotnet issues

* add hotword test to android tests

* dotnet fixes from carlos

* add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command

* make sure lm is on android for hotword test

* path to android model + nit

* path

* path
2020-09-24 14:58:41 -04:00

150 lines
4.9 KiB
C++

#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "scorer.h"
#include "output.h"
#include "alphabet.h"
/* State of a CTC beam search that is kept between calls.
 *
 * The public interface consists of init(), next() and decode(); their
 * definitions live outside this header. Objects of this class can be neither
 * copied nor copy-assigned.
 */
class DecoderState {
  // Scalar members carry default initializers so that a default-constructed
  // object never holds indeterminate values. init() is expected to assign the
  // real values (its definition is not part of this header).
  int abs_time_step_{0};
  int space_id_{0};
  int blank_id_{0};
  size_t beam_size_{0};
  double cutoff_prob_{0.0};
  size_t cutoff_top_n_{0};
  bool start_expanding_{false};

  std::shared_ptr<Scorer> ext_scorer_;
  std::vector<PathTrie*> prefixes_;
  std::unique_ptr<PathTrie> prefix_root_;
  TimestepTreeNode timestep_tree_root_{nullptr, 0};
  std::unordered_map<std::string, float> hot_words_;

public:
  DecoderState() = default;
  ~DecoderState() = default;

  // Disallow copying. The assignment operator takes a const reference so that
  // it is the canonical copy assignment, matching the deleted copy constructor.
  DecoderState(const DecoderState&) = delete;
  DecoderState& operator=(const DecoderState&) = delete;

  /* Initialize CTC beam search decoder
   *
   * Parameters:
   *     alphabet: The alphabet.
   *     beam_size: The width of beam search.
   *     cutoff_prob: Cutoff probability for pruning.
   *     cutoff_top_n: Cutoff number for pruning.
   *     ext_scorer: External scorer to evaluate a prefix, which consists of
   *                 n-gram language model scoring and word insertion term.
   *                 Default null, decoding the input sample without scorer.
   *     hot_words: A map of hot-words and their corresponding boosts.
   *                The hot-word is a string and the boost is a float.
   * Return:
   *     Zero on success, non-zero on failure.
   */
  int init(const Alphabet& alphabet,
           size_t beam_size,
           double cutoff_prob,
           size_t cutoff_top_n,
           std::shared_ptr<Scorer> ext_scorer,
           std::unordered_map<std::string, float> hot_words);

  /* Send data to the decoder
   *
   * Parameters:
   *     probs: Probabilities over the alphabet, one set per time step.
   *            Passed as a raw pointer; presumably a flattened
   *            time_dim x class_dim array - TODO confirm against the
   *            implementation.
   *     time_dim: Number of timesteps.
   *     class_dim: Number of classes (alphabet length + 1).
   *                NOTE(review): earlier documentation attributes the extra
   *                class to the space character, but this class tracks a
   *                separate blank_id_; it may be the CTC blank - confirm.
   */
  void next(const double *probs,
            int time_dim,
            int class_dim);

  /* Get up to num_results transcriptions from current decoder state.
   *
   * Parameters:
   *     num_results: Number of beams to return (1 when omitted).
   *
   * Return:
   *     A vector where each element is a pair of score and decoding result,
   *     in descending order.
   */
  std::vector<Output> decode(size_t num_results=1) const;
};
/* CTC beam search decoding of a single probability matrix.
 *
 * Arguments:
 *     probs        Per-time-step probabilities over the alphabet, described
 *                  as a 2-D structure (one vector of probabilities per time
 *                  step) and passed as a raw pointer.
 *     time_dim     Number of timesteps.
 *     class_dim    Alphabet length (plus 1 for space character).
 *     alphabet     The alphabet.
 *     beam_size    The width of beam search.
 *     cutoff_prob  Cutoff probability for pruning.
 *     cutoff_top_n Cutoff number for pruning.
 *     ext_scorer   External scorer used to evaluate a prefix; it combines
 *                  n-gram language model scoring with a word insertion term.
 *                  Null means the input is decoded without a scorer.
 *     hot_words    Map from hot-word (string) to its boost (float).
 *     num_results  Number of beams to return (1 when omitted).
 *
 * Result:
 *     A vector in descending order, each element being a pair of score and
 *     decoding result.
 */
std::vector<Output>
ctc_beam_search_decoder(const double* probs, int time_dim, int class_dim,
                        const Alphabet& alphabet,
                        size_t beam_size, double cutoff_prob, size_t cutoff_top_n,
                        std::shared_ptr<Scorer> ext_scorer,
                        std::unordered_map<std::string, float> hot_words,
                        size_t num_results = 1);
/* Batched variant of the CTC beam search decoder.
 *
 * Arguments:
 *     probs            Described as a 3-D structure in which every element is
 *                      a 2-D input usable by ctc_beam_search_decoder();
 *                      passed as a raw pointer.
 *     batch_size,
 *     time_dim,
 *     class_dim        Presumably the three extents of probs
 *                      (batch x time x class) - TODO confirm against the
 *                      implementation.
 *     seq_lengths,
 *     seq_lengths_size Presumably the per-sample sequence lengths and the
 *                      number of entries in that array - TODO confirm.
 *     alphabet         The alphabet.
 *     beam_size        The width of beam search.
 *     num_processes    Number of threads for beam search.
 *     cutoff_prob      Cutoff probability for pruning.
 *     cutoff_top_n     Cutoff number for pruning.
 *     ext_scorer       External scorer used to evaluate a prefix; it combines
 *                      n-gram language model scoring with a word insertion
 *                      term. Null means decoding without a scorer.
 *     hot_words        Map from hot-word (string) to its boost (float).
 *     num_results      Number of beams to return (1 when omitted).
 *
 * Result:
 *     A 2-D vector; each element is the vector of beam search decoding
 *     results for one audio sample.
 */
std::vector<std::vector<Output>> ctc_beam_search_decoder_batch(
    const double* probs, int batch_size, int time_dim, int class_dim,
    const int* seq_lengths, int seq_lengths_size,
    const Alphabet& alphabet,
    size_t beam_size, size_t num_processes,
    double cutoff_prob, size_t cutoff_top_n,
    std::shared_ptr<Scorer> ext_scorer,
    std::unordered_map<std::string, float> hot_words,
    size_t num_results = 1);
#endif // CTC_BEAM_SEARCH_DECODER_H_