enable hot-word boosting (#3297)
* enable hot-word boosting * more consistent ordering of CLI arguments * progress on review * use map instead of set for hot-words, move string logic to client.cc * typo bug * pointer things? * use map for hotwords, better string splitting * add the boost, not multiply * cleaning up * cleaning whitespace * remove <set> inclusion * change typo set-->map * rename boost_coefficient to boost X-DeepSpeech: NOBUILD * add hot_words to python bindings * missing hot_words * include map in swigwrapper.i * add Map template to swigwrapper.i * emacs intermediate file * map things * map-->unordered_map * typu * typu * use dict() not None * error out if hot_words without scorer * two new functions: remove hot-word and clear all hot-words * starting to work on better error messages X-DeepSpeech: NOBUILD * better error handling + .Net ERR codes * allow for negative boosts:) * adding TC test for hot-words * add hot-words to python client, make TC test hot-words everywhere * only run TC tests for C++ and Python * fully expose API in python bindings * expose API in Java (thanks spectie!) * expose API in dotnet (thanks spectie!) * expose API in javascript (thanks spectie!) * java lol * typo in javascript * commenting * java error codes from swig * java docs from SWIG * java and dotnet issues * add hotword test to android tests * dotnet fixes from carlos * add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command * make sure lm is on android for hotword test * path to android model + nit * path * path
This commit is contained in:
parent
d466fb09d4
commit
1eb155ed93
|
@ -38,6 +38,8 @@ int json_candidate_transcripts = 3;
|
|||
|
||||
int stream_size = 0;
|
||||
|
||||
char* hot_words = NULL;
|
||||
|
||||
void PrintHelp(const char* bin)
|
||||
{
|
||||
std::cout <<
|
||||
|
@ -56,6 +58,7 @@ void PrintHelp(const char* bin)
|
|||
"\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
|
||||
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
|
||||
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
|
||||
"\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n"
|
||||
"\t--help\t\t\t\tShow help\n"
|
||||
"\t--version\t\t\tPrint version and exits\n";
|
||||
char* version = DS_Version();
|
||||
|
@ -66,7 +69,7 @@ void PrintHelp(const char* bin)
|
|||
|
||||
bool ProcessArgs(int argc, char** argv)
|
||||
{
|
||||
const char* const short_opts = "m:l:a:b:c:d:tejs:vh";
|
||||
const char* const short_opts = "m:l:a:b:c:d:tejs:w:vh";
|
||||
const option long_opts[] = {
|
||||
{"model", required_argument, nullptr, 'm'},
|
||||
{"scorer", required_argument, nullptr, 'l'},
|
||||
|
@ -79,6 +82,7 @@ bool ProcessArgs(int argc, char** argv)
|
|||
{"json", no_argument, nullptr, 'j'},
|
||||
{"candidate_transcripts", required_argument, nullptr, 150},
|
||||
{"stream", required_argument, nullptr, 's'},
|
||||
{"hot_words", required_argument, nullptr, 'w'},
|
||||
{"version", no_argument, nullptr, 'v'},
|
||||
{"help", no_argument, nullptr, 'h'},
|
||||
{nullptr, no_argument, nullptr, 0}
|
||||
|
@ -144,6 +148,10 @@ bool ProcessArgs(int argc, char** argv)
|
|||
has_versions = true;
|
||||
break;
|
||||
|
||||
case 'w':
|
||||
hot_words = optarg;
|
||||
break;
|
||||
|
||||
case 'h': // -h or --help
|
||||
case '?': // Unrecognized option
|
||||
default:
|
||||
|
|
|
@ -390,6 +390,22 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
|
|||
}
|
||||
}
|
||||
|
||||
// Splits in_string into tokens, treating every character of delim as a
// separator (the same contract strtok() has).
//
// Empty fields are dropped: leading, trailing and consecutive separators never
// produce empty tokens, and an empty or separator-only input yields an empty
// vector. An empty delim returns the whole (non-empty) input as one token.
//
// Implemented with std::string searches instead of strtok(): strtok() keeps
// hidden static state (not reentrant / not thread-safe) and required a manually
// managed scratch buffer that leaked if push_back() threw.
std::vector<std::string>
SplitStringOnDelim(std::string in_string, std::string delim)
{
  std::vector<std::string> out_vector;

  // Start of the first token: first character that is not a separator.
  std::string::size_type token_begin = in_string.find_first_not_of(delim);
  while (token_begin != std::string::npos) {
    // One past the end of the token; npos means "runs to end of string",
    // which substr() clamps for us.
    const std::string::size_type token_end =
        in_string.find_first_of(delim, token_begin);
    out_vector.push_back(in_string.substr(token_begin, token_end - token_begin));
    // Skip the run of separators following the token (npos stays npos).
    token_begin = in_string.find_first_not_of(delim, token_end);
  }

  return out_vector;
}
|
||||
|
||||
int
|
||||
main(int argc, char **argv)
|
||||
{
|
||||
|
@ -432,6 +448,23 @@ main(int argc, char **argv)
|
|||
}
|
||||
// sphinx-doc: c_ref_model_stop
|
||||
|
||||
if (hot_words) {
|
||||
std::vector<std::string> hot_words_ = SplitStringOnDelim(hot_words, ",");
|
||||
for ( std::string hot_word_ : hot_words_ ) {
|
||||
std::vector<std::string> pair_ = SplitStringOnDelim(hot_word_, ":");
|
||||
const char* word = (pair_[0]).c_str();
|
||||
// the strtof function will return 0 in case of non numeric characters
|
||||
// so, check the boost string before we turn it into a float
|
||||
bool boost_is_valid = (pair_[1].find_first_not_of("-.0123456789") == std::string::npos);
|
||||
float boost = strtof((pair_[1]).c_str(),0);
|
||||
status = DS_AddHotWord(ctx, word, boost);
|
||||
if (status != 0 || !boost_is_valid) {
|
||||
fprintf(stderr, "Could not enable hot-word.\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef NO_SOX
|
||||
// Initialise SOX
|
||||
assert(sox_init() == SOX_SUCCESS);
|
||||
|
|
|
@ -96,6 +96,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|||
cutoff_prob=1.0,
|
||||
cutoff_top_n=40,
|
||||
scorer=None,
|
||||
hot_words=dict(),
|
||||
num_results=1):
|
||||
"""Wrapper for the CTC Beam Search Decoder.
|
||||
|
||||
|
@ -116,6 +117,8 @@ def ctc_beam_search_decoder(probs_seq,
|
|||
:param scorer: External scorer for partially decoded sentence, e.g. word
|
||||
count or language model.
|
||||
:type scorer: Scorer
|
||||
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
||||
:type hot_words: map{string:float}
|
||||
:param num_results: Number of beams to return.
|
||||
:type num_results: int
|
||||
:return: List of tuples of confidence and sentence as decoding
|
||||
|
@ -124,7 +127,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|||
"""
|
||||
beam_results = swigwrapper.ctc_beam_search_decoder(
|
||||
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
|
||||
scorer, num_results)
|
||||
scorer, hot_words, num_results)
|
||||
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||
return beam_results
|
||||
|
||||
|
@ -137,6 +140,7 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
|||
cutoff_prob=1.0,
|
||||
cutoff_top_n=40,
|
||||
scorer=None,
|
||||
hot_words=dict(),
|
||||
num_results=1):
|
||||
"""Wrapper for the batched CTC beam search decoder.
|
||||
|
||||
|
@ -161,13 +165,15 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
|||
:param scorer: External scorer for partially decoded sentence, e.g. word
|
||||
count or language model.
|
||||
:type scorer: Scorer
|
||||
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
||||
:type hot_words: map{string:float}
|
||||
:param num_results: Number of beams to return.
|
||||
:type num_results: int
|
||||
:return: List of tuples of confidence and sentence as decoding
|
||||
results, in descending order of the confidence.
|
||||
:rtype: list
|
||||
"""
|
||||
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results)
|
||||
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results)
|
||||
batch_beam_results = [
|
||||
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||
for beam_results in batch_beam_results
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
@ -18,7 +18,8 @@ DecoderState::init(const Alphabet& alphabet,
|
|||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer)
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
std::unordered_map<std::string, float> hot_words)
|
||||
{
|
||||
// assign special ids
|
||||
abs_time_step_ = 0;
|
||||
|
@ -29,6 +30,7 @@ DecoderState::init(const Alphabet& alphabet,
|
|||
cutoff_prob_ = cutoff_prob;
|
||||
cutoff_top_n_ = cutoff_top_n;
|
||||
ext_scorer_ = ext_scorer;
|
||||
hot_words_ = hot_words;
|
||||
start_expanding_ = false;
|
||||
|
||||
// init prefixes' root
|
||||
|
@ -160,8 +162,23 @@ DecoderState::next(const double *probs,
|
|||
float score = 0.0;
|
||||
std::vector<std::string> ngram;
|
||||
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
||||
|
||||
float hot_boost = 0.0;
|
||||
if (!hot_words_.empty()) {
|
||||
std::unordered_map<std::string, float>::iterator iter;
|
||||
// increase prob of prefix for every word
|
||||
// that matches a word in the hot-words list
|
||||
for (std::string word : ngram) {
|
||||
iter = hot_words_.find(word);
|
||||
if ( iter != hot_words_.end() ) {
|
||||
// increase the log_cond_prob(prefix|LM)
|
||||
hot_boost += iter->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool bos = ngram.size() < ext_scorer_->get_max_order();
|
||||
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
|
||||
score = ( ext_scorer_->get_log_cond_prob(ngram, bos) + hot_boost ) * ext_scorer_->alpha;
|
||||
log_p += score;
|
||||
log_p += ext_scorer_->beta;
|
||||
}
|
||||
|
@ -256,11 +273,12 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
std::unordered_map<std::string, float> hot_words,
|
||||
size_t num_results)
|
||||
{
|
||||
VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model.");
|
||||
DecoderState state;
|
||||
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
||||
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer, hot_words);
|
||||
state.next(probs, time_dim, class_dim);
|
||||
return state.decode(num_results);
|
||||
}
|
||||
|
@ -279,6 +297,7 @@ ctc_beam_search_decoder_batch(
|
|||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
std::unordered_map<std::string, float> hot_words,
|
||||
size_t num_results)
|
||||
{
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
|
@ -298,6 +317,7 @@ ctc_beam_search_decoder_batch(
|
|||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
ext_scorer,
|
||||
hot_words,
|
||||
num_results));
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ class DecoderState {
|
|||
std::vector<PathTrie*> prefixes_;
|
||||
std::unique_ptr<PathTrie> prefix_root_;
|
||||
TimestepTreeNode timestep_tree_root_{nullptr, 0};
|
||||
std::unordered_map<std::string, float> hot_words_;
|
||||
|
||||
public:
|
||||
DecoderState() = default;
|
||||
|
@ -48,7 +49,8 @@ public:
|
|||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer);
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
std::unordered_map<std::string, float> hot_words);
|
||||
|
||||
/* Send data to the decoder
|
||||
*
|
||||
|
@ -88,6 +90,8 @@ public:
|
|||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* hot_words: A map of hot-words and their corresponding boosts
|
||||
* The hot-word is a string and the boost is a float.
|
||||
* num_results: Number of beams to return.
|
||||
* Return:
|
||||
* A vector where each element is a pair of score and decoding result,
|
||||
|
@ -103,6 +107,7 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
std::unordered_map<std::string, float> hot_words,
|
||||
size_t num_results=1);
|
||||
|
||||
/* CTC Beam Search Decoder for batch data
|
||||
|
@ -117,6 +122,8 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* hot_words: A map of hot-words and their corresponding boosts
|
||||
* The hot-word is a string and the boost is a float.
|
||||
* num_results: Number of beams to return.
|
||||
* Return:
|
||||
* A 2-D vector where each element is a vector of beam search decoding
|
||||
|
@ -136,6 +143,7 @@ ctc_beam_search_decoder_batch(
|
|||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
std::shared_ptr<Scorer> ext_scorer,
|
||||
std::unordered_map<std::string, float> hot_words,
|
||||
size_t num_results=1);
|
||||
|
||||
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
%include <std_string.i>
|
||||
%include <std_vector.i>
|
||||
%include <std_shared_ptr.i>
|
||||
%include <std_unordered_map.i>
|
||||
%include "numpy.i"
|
||||
|
||||
%init %{
|
||||
|
@ -22,6 +23,7 @@ namespace std {
|
|||
%template(UnsignedIntVector) vector<unsigned int>;
|
||||
%template(OutputVector) vector<Output>;
|
||||
%template(OutputVectorVector) vector<vector<Output>>;
|
||||
%template(Map) unordered_map<string, float>;
|
||||
}
|
||||
|
||||
%shared_ptr(Scorer);
|
||||
|
|
|
@ -342,6 +342,53 @@ DS_EnableExternalScorer(ModelState* aCtx,
|
|||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
int
|
||||
DS_AddHotWord(ModelState* aCtx,
|
||||
const char* word,
|
||||
float boost)
|
||||
{
|
||||
if (aCtx->scorer_) {
|
||||
const int size_before = aCtx->hot_words_.size();
|
||||
aCtx->hot_words_.insert( std::pair<std::string,float> (word, boost) );
|
||||
const int size_after = aCtx->hot_words_.size();
|
||||
if (size_before == size_after) {
|
||||
return DS_ERR_FAIL_INSERT_HOTWORD;
|
||||
}
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
return DS_ERR_SCORER_NOT_ENABLED;
|
||||
}
|
||||
|
||||
int
|
||||
DS_EraseHotWord(ModelState* aCtx,
|
||||
const char* word)
|
||||
{
|
||||
if (aCtx->scorer_) {
|
||||
const int size_before = aCtx->hot_words_.size();
|
||||
int err = aCtx->hot_words_.erase(word);
|
||||
const int size_after = aCtx->hot_words_.size();
|
||||
if (size_before == size_after) {
|
||||
return DS_ERR_FAIL_ERASE_HOTWORD;
|
||||
}
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
return DS_ERR_SCORER_NOT_ENABLED;
|
||||
}
|
||||
|
||||
int
|
||||
DS_ClearHotWords(ModelState* aCtx)
|
||||
{
|
||||
if (aCtx->scorer_) {
|
||||
aCtx->hot_words_.clear();
|
||||
const int size_after = aCtx->hot_words_.size();
|
||||
if (size_after != 0) {
|
||||
return DS_ERR_FAIL_CLEAR_HOTWORD;
|
||||
}
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
return DS_ERR_SCORER_NOT_ENABLED;
|
||||
}
|
||||
|
||||
int
|
||||
DS_DisableExternalScorer(ModelState* aCtx)
|
||||
{
|
||||
|
@ -390,7 +437,8 @@ DS_CreateStream(ModelState* aCtx,
|
|||
aCtx->beam_width_,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
aCtx->scorer_);
|
||||
aCtx->scorer_,
|
||||
aCtx->hot_words_);
|
||||
|
||||
*retval = ctx.release();
|
||||
return DS_ERR_OK;
|
||||
|
|
|
@ -81,7 +81,10 @@ typedef struct Metadata {
|
|||
APPLY(DS_ERR_FAIL_CREATE_STREAM, 0x3004, "Error creating the stream.") \
|
||||
APPLY(DS_ERR_FAIL_READ_PROTOBUF, 0x3005, "Error reading the proto buffer model file.") \
|
||||
APPLY(DS_ERR_FAIL_CREATE_SESS, 0x3006, "Failed to create session.") \
|
||||
APPLY(DS_ERR_FAIL_CREATE_MODEL, 0x3007, "Could not allocate model state.")
|
||||
APPLY(DS_ERR_FAIL_CREATE_MODEL, 0x3007, "Could not allocate model state.") \
|
||||
APPLY(DS_ERR_FAIL_INSERT_HOTWORD, 0x3008, "Could not insert hot-word.") \
|
||||
APPLY(DS_ERR_FAIL_CLEAR_HOTWORD, 0x3009, "Could not clear hot-words.") \
|
||||
APPLY(DS_ERR_FAIL_ERASE_HOTWORD, 0x3010, "Could not erase hot-word.")
|
||||
|
||||
// sphinx-doc: error_code_listing_end
|
||||
|
||||
|
@ -157,6 +160,42 @@ DEEPSPEECH_EXPORT
|
|||
int DS_EnableExternalScorer(ModelState* aCtx,
|
||||
const char* aScorerPath);
|
||||
|
||||
/**
|
||||
* @brief Add a hot-word and its boost.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
* @param word The hot-word.
|
||||
* @param boost The boost.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (e.g. no scorer is enabled, or the word is already registered as a hot-word).
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_AddHotWord(ModelState* aCtx,
|
||||
const char* word,
|
||||
float boost);
|
||||
|
||||
/**
|
||||
* @brief Remove entry for a hot-word from the hot-words map.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
* @param word The hot-word.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (e.g. no scorer is enabled, or the word is not a registered hot-word).
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_EraseHotWord(ModelState* aCtx,
|
||||
const char* word);
|
||||
|
||||
/**
|
||||
* @brief Removes all elements from the hot-words map.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (e.g. no scorer is enabled).
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_ClearHotWords(ModelState* aCtx);
|
||||
|
||||
/**
|
||||
* @brief Disable decoding using an external scorer.
|
||||
*
|
||||
|
|
|
@ -74,6 +74,39 @@ namespace DeepSpeechClient
|
|||
EvaluateResultCode(resultCode);
|
||||
}
|
||||
|
||||
/// <summary>
/// Registers a hot-word and its boost with the native model.
/// </summary>
/// <param name="aWord">The hot-word to register.</param>
/// <param name="aBoost">The boost to associate with the hot-word.</param>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
public unsafe void AddHotWord(string aWord, float aBoost)
{
    // The native status code is handed straight to the shared result check.
    EvaluateResultCode(NativeImp.DS_AddHotWord(_modelStatePP, aWord, aBoost));
}
|
||||
|
||||
/// <summary>
/// Removes the entry for a single hot-word from the native model.
/// </summary>
/// <param name="aWord">The hot-word whose entry should be removed.</param>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
public unsafe void EraseHotWord(string aWord)
{
    // The native status code is handed straight to the shared result check.
    EvaluateResultCode(NativeImp.DS_EraseHotWord(_modelStatePP, aWord));
}
|
||||
|
||||
/// <summary>
/// Removes every hot-word entry from the native model.
/// </summary>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
public unsafe void ClearHotWords()
{
    // The native status code is handed straight to the shared result check.
    EvaluateResultCode(NativeImp.DS_ClearHotWords(_modelStatePP));
}
|
||||
|
||||
/// <summary>
|
||||
/// Return the sample rate expected by the model.
|
||||
/// </summary>
|
||||
|
|
|
@ -26,5 +26,8 @@
|
|||
DS_ERR_FAIL_CREATE_STREAM = 0x3004,
|
||||
DS_ERR_FAIL_READ_PROTOBUF = 0x3005,
|
||||
DS_ERR_FAIL_CREATE_SESS = 0x3006,
|
||||
DS_ERR_FAIL_INSERT_HOTWORD = 0x3008,
|
||||
DS_ERR_FAIL_CLEAR_HOTWORD = 0x3009,
|
||||
DS_ERR_FAIL_ERASE_HOTWORD = 0x3010
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,6 +44,27 @@ namespace DeepSpeechClient.Interfaces
|
|||
/// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception>
|
||||
unsafe void EnableExternalScorer(string aScorerPath);
|
||||
|
||||
/// <summary>
|
||||
/// Add a hot-word.
|
||||
/// </summary>
|
||||
/// <param name="aWord">The hot-word to register.</param>
|
||||
/// <param name="aBoost">The boost to associate with the hot-word.</param>
|
||||
/// <exception cref="ArgumentException">Thrown on failure.</exception>
|
||||
unsafe void AddHotWord(string aWord, float aBoost);
|
||||
|
||||
/// <summary>
|
||||
/// Erase entry for a hot-word.
|
||||
/// </summary>
|
||||
/// <param name="aWord">The hot-word whose entry should be removed.</param>
|
||||
/// <exception cref="ArgumentException">Thrown on failure.</exception>
|
||||
unsafe void EraseHotWord(string aWord);
|
||||
|
||||
/// <summary>
|
||||
/// Clear all hot-words.
|
||||
/// </summary>
|
||||
/// <exception cref="ArgumentException">Thrown on failure.</exception>
|
||||
unsafe void ClearHotWords();
|
||||
|
||||
/// <summary>
|
||||
/// Disable decoding using an external scorer.
|
||||
/// </summary>
|
||||
|
|
|
@ -41,6 +41,18 @@ namespace DeepSpeechClient
|
|||
internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx,
|
||||
string aScorerPath);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static unsafe extern ErrorCodes DS_AddHotWord(IntPtr** aCtx,
|
||||
string aWord,
|
||||
float aBoost);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static unsafe extern ErrorCodes DS_EraseHotWord(IntPtr** aCtx,
|
||||
string aWord);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static unsafe extern ErrorCodes DS_ClearHotWords(IntPtr** aCtx);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx);
|
||||
|
||||
|
|
|
@ -215,4 +215,36 @@ public class DeepSpeechModel {
|
|||
public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx, int num_results) {
|
||||
return impl.FinishStreamWithMetadata(ctx.get(), num_results);
|
||||
}
|
||||
/**
|
||||
* @brief Add a hot-word
|
||||
*
|
||||
* @param word
|
||||
* @param boost
|
||||
*
|
||||
* @throws RuntimeException on failure.
|
||||
*
|
||||
*/
|
||||
public void addHotWord(String word, float boost) {
|
||||
evaluateErrorCode(impl.AddHotWord(this._msp, word, boost));
|
||||
}
|
||||
/**
|
||||
* @brief Erase a hot-word
|
||||
*
|
||||
* @param word
|
||||
*
|
||||
* @throws RuntimeException on failure.
|
||||
*
|
||||
*/
|
||||
public void eraseHotWord(String word) {
|
||||
evaluateErrorCode(impl.EraseHotWord(this._msp, word));
|
||||
}
|
||||
/**
|
||||
* @brief Clear all hot-words.
|
||||
*
|
||||
* @throws RuntimeException on failure.
|
||||
*
|
||||
*/
|
||||
public void clearHotWords() {
|
||||
evaluateErrorCode(impl.ClearHotWords(this._msp));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,7 +28,10 @@ public enum DeepSpeech_Error_Codes {
|
|||
ERR_FAIL_CREATE_STREAM(0x3004),
|
||||
ERR_FAIL_READ_PROTOBUF(0x3005),
|
||||
ERR_FAIL_CREATE_SESS(0x3006),
|
||||
ERR_FAIL_CREATE_MODEL(0x3007);
|
||||
ERR_FAIL_CREATE_MODEL(0x3007),
|
||||
ERR_FAIL_INSERT_HOTWORD(0x3008),
|
||||
ERR_FAIL_CLEAR_HOTWORD(0x3009),
|
||||
ERR_FAIL_ERASE_HOTWORD(0x3010);
|
||||
|
||||
public final int swigValue() {
|
||||
return swigValue;
|
||||
|
|
|
@ -182,6 +182,47 @@ export class Model {
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Register a hot-word together with its boost.
 *
 * @param aWord  The hot-word to register.
 * @param aBoost The boost to associate with the hot-word.
 *
 * @throws on error (any non-zero status returned by the native binding)
 */
addHotWord(aWord: string, aBoost: number): void {
    const rc = binding.addHotWord(this._impl, aWord, aBoost);
    if (rc === 0) {
        return;
    }
    throw `addHotWord failed: ${binding.ErrorCodeToErrorMessage(rc)} (0x${rc.toString(16)})`;
}
|
||||
|
||||
/**
 * Remove the entry for a single hot-word.
 *
 * @param aWord The hot-word whose entry should be removed.
 *
 * @throws on error (any non-zero status returned by the native binding)
 */
eraseHotWord(aWord: string): void {
    const rc = binding.eraseHotWord(this._impl, aWord);
    if (rc === 0) {
        return;
    }
    throw `eraseHotWord failed: ${binding.ErrorCodeToErrorMessage(rc)} (0x${rc.toString(16)})`;
}
|
||||
|
||||
/**
 * Remove all hot-word entries.
 *
 * @throws on error (any non-zero status returned by the native binding)
 */
clearHotWords(): void {
    const status = binding.clearHotWords(this._impl);
    if (status !== 0) {
        // The message names the method that actually failed; it previously
        // said "clearHotWord", which does not exist.
        throw `clearHotWords failed: ${binding.ErrorCodeToErrorMessage(status)} (0x${status.toString(16)})`;
    }
}
|
||||
|
||||
/**
|
||||
* Return the sample rate expected by the model.
|
||||
*
|
||||
|
|
|
@ -17,6 +17,7 @@ struct ModelState {
|
|||
|
||||
Alphabet alphabet_;
|
||||
std::shared_ptr<Scorer> scorer_;
|
||||
std::unordered_map<std::string, float> hot_words_;
|
||||
unsigned int beam_width_;
|
||||
unsigned int n_steps_;
|
||||
unsigned int n_context_;
|
||||
|
|
|
@ -95,6 +95,45 @@ class Model(object):
|
|||
"""
|
||||
return deepspeech.impl.DisableExternalScorer(self._impl)
|
||||
|
||||
def addHotWord(self, word, boost):
    """
    Add a word and its boost for decoding.

    :param word: the hot-word
    :type word: str

    :param boost: the boost
    :type boost: float

    :throws: RuntimeError on error
    """
    status = deepspeech.impl.AddHotWord(self._impl, word, boost)
    # Any non-zero status from the native call is surfaced as an exception.
    if status != 0:
        raise RuntimeError("AddHotWord failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status))
|
||||
|
||||
def eraseHotWord(self, word):
    """
    Delete a single word's entry from the hot-words dict.

    :param word: the hot-word
    :type word: str

    :throws: RuntimeError on error
    """
    rc = deepspeech.impl.EraseHotWord(self._impl, word)
    if rc == 0:
        return
    # Non-zero status: look up the native error text and raise.
    reason = deepspeech.impl.ErrorCodeToErrorMessage(rc)
    raise RuntimeError("EraseHotWord failed with '{}' (0x{:X})".format(reason, rc))
|
||||
|
||||
def clearHotWords(self):
    """
    Delete every entry from the hot-words dict.

    :throws: RuntimeError on error
    """
    rc = deepspeech.impl.ClearHotWords(self._impl)
    if rc == 0:
        return
    # Non-zero status: look up the native error text and raise.
    reason = deepspeech.impl.ErrorCodeToErrorMessage(rc)
    raise RuntimeError("ClearHotWords failed with '{}' (0x{:X})".format(reason, rc))
|
||||
|
||||
def setScorerAlphaBeta(self, alpha, beta):
|
||||
"""
|
||||
Set hyperparameters alpha and beta of the external scorer.
|
||||
|
|
|
@ -109,6 +109,8 @@ def main():
|
|||
help='Output json from metadata with timestamp of each word')
|
||||
parser.add_argument('--candidate_transcripts', type=int, default=3,
|
||||
help='Number of candidate transcripts to include in JSON output')
|
||||
parser.add_argument('--hot_words', type=str,
|
||||
help='Hot-words and their boosts.')
|
||||
args = parser.parse_args()
|
||||
|
||||
print('Loading model from file {}'.format(args.model), file=sys.stderr)
|
||||
|
@ -134,6 +136,12 @@ def main():
|
|||
if args.lm_alpha and args.lm_beta:
|
||||
ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)
|
||||
|
||||
if args.hot_words:
|
||||
print('Adding hot-words', file=sys.stderr)
|
||||
for word_boost in args.hot_words.split(','):
|
||||
word,boost = word_boost.split(':')
|
||||
ds.addHotWord(word,float(boost))
|
||||
|
||||
fin = wave.open(args.audio, 'rb')
|
||||
fs_orig = fin.getframerate()
|
||||
if fs_orig != desired_sample_rate:
|
||||
|
|
|
@ -30,4 +30,6 @@ android_setup_ndk_data
|
|||
|
||||
run_tflite_basic_inference_tests
|
||||
|
||||
run_android_hotword_tests
|
||||
|
||||
android_stop_emulator
|
||||
|
|
|
@ -206,6 +206,10 @@ android_setup_ndk_data()
|
|||
${TASKCLUSTER_TMP_DIR}/${model_name} \
|
||||
${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \
|
||||
${ANDROID_TMP_DIR}/ds/
|
||||
|
||||
if [ -f "${TASKCLUSTER_TMP_DIR}/kenlm.scorer" ]; then
|
||||
adb push ${TASKCLUSTER_TMP_DIR}/kenlm.scorer ${ANDROID_TMP_DIR}/ds/
|
||||
fi
|
||||
}
|
||||
|
||||
android_setup_apk_data()
|
||||
|
|
|
@ -524,6 +524,24 @@ run_multi_inference_tests()
|
|||
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status"
|
||||
}
|
||||
|
||||
# Runs the deepspeech client with a scorer and a --hot_words list, then checks
# the transcript and exit status with assert_correct_ldc93s1_lm.
run_hotword_tests()
{
  # Let the command fail without aborting the script so its status can be
  # captured and asserted on below.
  set +e
  hotwords_decode=$(${DS_BINARY_PREFIX}deepspeech \
    --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} \
    --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer \
    --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \
    --hot_words "foo:0.0,bar:-0.1" \
    2>${TASKCLUSTER_TMP_DIR}/stderr)
  status=$?
  set -e
  assert_correct_ldc93s1_lm "${hotwords_decode}" "$status"
}
|
||||
|
||||
# Android variant of the hot-word test: same invocation, but the model, scorer
# and audio are read from DATA_TMP_DIR.
run_android_hotword_tests()
{
  # Let the command fail without aborting the script so its status can be
  # captured and asserted on below.
  set +e
  hotwords_decode=$(${DS_BINARY_PREFIX}deepspeech \
    --model ${DATA_TMP_DIR}/${model_name} \
    --scorer ${DATA_TMP_DIR}/kenlm.scorer \
    --audio ${DATA_TMP_DIR}/${ldc93s1_sample_filename} \
    --hot_words "foo:0.0,bar:-0.1" \
    2>${TASKCLUSTER_TMP_DIR}/stderr)
  status=$?
  set -e
  assert_correct_ldc93s1_lm "${hotwords_decode}" "$status"
}
|
||||
|
||||
run_cpp_only_inference_tests()
|
||||
{
|
||||
set +e
|
||||
|
|
|
@ -18,3 +18,5 @@ run_all_inference_tests
|
|||
run_multi_inference_tests
|
||||
|
||||
run_cpp_only_inference_tests
|
||||
|
||||
run_hotword_tests
|
||||
|
|
|
@ -23,3 +23,5 @@ run_all_inference_tests
|
|||
run_multi_inference_tests
|
||||
|
||||
run_cpp_only_inference_tests
|
||||
|
||||
run_hotword_tests
|
||||
|
|
|
@ -28,4 +28,6 @@ ensure_cuda_usage "$3"
|
|||
|
||||
run_all_inference_tests
|
||||
|
||||
run_hotword_tests
|
||||
|
||||
virtualenv_deactivate "${pyalias}" "deepspeech"
|
||||
|
|
|
@ -33,4 +33,6 @@ deepspeech --version
|
|||
|
||||
run_all_inference_tests
|
||||
|
||||
run_hotword_tests
|
||||
|
||||
virtualenv_deactivate "${pyalias}" "deepspeech"
|
||||
|
|
Loading…
Reference in New Issue