enable hot-word boosting (#3297)

* enable hot-word boosting

* more consistent ordering of CLI arguments

* progress on review

* use map instead of set for hot-words, move string logic to client.cc

* typo bug

* pointer things?

* use map for hotwords, better string splitting

* add the boost, not multiply

* cleaning up

* cleaning whitespace

* remove <set> inclusion

* change typo set-->map

* rename boost_coefficient to boost

X-DeepSpeech: NOBUILD

* add hot_words to python bindings

* missing hot_words

* include map in swigwrapper.i

* add Map template to swigwrapper.i

* emacs intermediate file

* map things

* map-->unordered_map

* typu

* typu

* use dict() not None

* error out if hot_words without scorer

* two new functions: remove hot-word and clear all hot-words

* starting to work on better error messages

X-DeepSpeech: NOBUILD

* better error handling + .Net ERR codes

* allow for negative boosts:)

* adding TC test for hot-words

* add hot-words to python client, make TC test hot-words everywhere

* only run TC tests for C++ and Python

* fully expose API in python bindings

* expose API in Java (thanks spectie!)

* expose API in dotnet (thanks spectie!)

* expose API in javascript (thanks spectie!)

* java lol

* typo in javascript

* commenting

* java error codes from swig

* java docs from SWIG

* java and dotnet issues

* add hotword test to android tests

* dotnet fixes from carlos

* add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command

* make sure lm is on android for hotword test

* path to android model + nit

* path

* path
This commit is contained in:
Josh Meyer 2020-09-24 14:58:41 -04:00 committed by GitHub
parent d466fb09d4
commit 1eb155ed93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 400 additions and 11 deletions

View File

@ -38,6 +38,8 @@ int json_candidate_transcripts = 3;
int stream_size = 0; int stream_size = 0;
char* hot_words = NULL;
void PrintHelp(const char* bin) void PrintHelp(const char* bin)
{ {
std::cout << std::cout <<
@ -56,6 +58,7 @@ void PrintHelp(const char* bin)
"\t--json\t\t\t\tExtended output, shows word timings as JSON\n" "\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n" "\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n" "\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
"\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n"
"\t--help\t\t\t\tShow help\n" "\t--help\t\t\t\tShow help\n"
"\t--version\t\t\tPrint version and exits\n"; "\t--version\t\t\tPrint version and exits\n";
char* version = DS_Version(); char* version = DS_Version();
@ -66,7 +69,7 @@ void PrintHelp(const char* bin)
bool ProcessArgs(int argc, char** argv) bool ProcessArgs(int argc, char** argv)
{ {
const char* const short_opts = "m:l:a:b:c:d:tejs:vh"; const char* const short_opts = "m:l:a:b:c:d:tejs:w:vh";
const option long_opts[] = { const option long_opts[] = {
{"model", required_argument, nullptr, 'm'}, {"model", required_argument, nullptr, 'm'},
{"scorer", required_argument, nullptr, 'l'}, {"scorer", required_argument, nullptr, 'l'},
@ -79,6 +82,7 @@ bool ProcessArgs(int argc, char** argv)
{"json", no_argument, nullptr, 'j'}, {"json", no_argument, nullptr, 'j'},
{"candidate_transcripts", required_argument, nullptr, 150}, {"candidate_transcripts", required_argument, nullptr, 150},
{"stream", required_argument, nullptr, 's'}, {"stream", required_argument, nullptr, 's'},
{"hot_words", required_argument, nullptr, 'w'},
{"version", no_argument, nullptr, 'v'}, {"version", no_argument, nullptr, 'v'},
{"help", no_argument, nullptr, 'h'}, {"help", no_argument, nullptr, 'h'},
{nullptr, no_argument, nullptr, 0} {nullptr, no_argument, nullptr, 0}
@ -144,6 +148,10 @@ bool ProcessArgs(int argc, char** argv)
has_versions = true; has_versions = true;
break; break;
case 'w':
hot_words = optarg;
break;
case 'h': // -h or --help case 'h': // -h or --help
case '?': // Unrecognized option case '?': // Unrecognized option
default: default:

View File

@ -390,6 +390,22 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
} }
} }
// Split a string into tokens, treating every character of `delim` as a
// separator. Mirrors the strtok() semantics of the previous implementation:
// runs of consecutive delimiters yield no empty tokens, and leading or
// trailing delimiters are ignored.
//
// Unlike strtok(), this version is reentrant, performs no manual
// new[]/delete[] buffer management, and never mutates its input.
//
//  in_string: text to split (e.g. "foo:1.5,bar:-0.2").
//  delim:     set of single-character delimiters (e.g. "," or ":").
//  Returns:   the non-empty tokens, in order of appearance.
std::vector<std::string>
SplitStringOnDelim(std::string in_string, std::string delim)
{
  std::vector<std::string> out_vector;
  // Skip any leading delimiters, then alternate find-token / find-separator.
  std::string::size_type start = in_string.find_first_not_of(delim);
  while (start != std::string::npos) {
    std::string::size_type end = in_string.find_first_of(delim, start);
    // substr clamps when end == npos, so the last token needs no special case.
    out_vector.push_back(in_string.substr(start, end - start));
    start = in_string.find_first_not_of(delim, end);
  }
  return out_vector;
}
int int
main(int argc, char **argv) main(int argc, char **argv)
{ {
@ -432,6 +448,23 @@ main(int argc, char **argv)
} }
// sphinx-doc: c_ref_model_stop // sphinx-doc: c_ref_model_stop
if (hot_words) {
std::vector<std::string> hot_words_ = SplitStringOnDelim(hot_words, ",");
for ( std::string hot_word_ : hot_words_ ) {
std::vector<std::string> pair_ = SplitStringOnDelim(hot_word_, ":");
const char* word = (pair_[0]).c_str();
// the strtof function will return 0 in case of non numeric characters
// so, check the boost string before we turn it into a float
bool boost_is_valid = (pair_[1].find_first_not_of("-.0123456789") == std::string::npos);
float boost = strtof((pair_[1]).c_str(),0);
status = DS_AddHotWord(ctx, word, boost);
if (status != 0 || !boost_is_valid) {
fprintf(stderr, "Could not enable hot-word.\n");
return 1;
}
}
}
#ifndef NO_SOX #ifndef NO_SOX
// Initialise SOX // Initialise SOX
assert(sox_init() == SOX_SUCCESS); assert(sox_init() == SOX_SUCCESS);

View File

@ -96,6 +96,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
scorer=None, scorer=None,
hot_words=dict(),
num_results=1): num_results=1):
"""Wrapper for the CTC Beam Search Decoder. """Wrapper for the CTC Beam Search Decoder.
@ -116,6 +117,8 @@ def ctc_beam_search_decoder(probs_seq,
:param scorer: External scorer for partially decoded sentence, e.g. word :param scorer: External scorer for partially decoded sentence, e.g. word
count or language model. count or language model.
:type scorer: Scorer :type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:param num_results: Number of beams to return. :param num_results: Number of beams to return.
:type num_results: int :type num_results: int
:return: List of tuples of confidence and sentence as decoding :return: List of tuples of confidence and sentence as decoding
@ -124,7 +127,7 @@ def ctc_beam_search_decoder(probs_seq,
""" """
beam_results = swigwrapper.ctc_beam_search_decoder( beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n, probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
scorer, num_results) scorer, hot_words, num_results)
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
return beam_results return beam_results
@ -137,6 +140,7 @@ def ctc_beam_search_decoder_batch(probs_seq,
cutoff_prob=1.0, cutoff_prob=1.0,
cutoff_top_n=40, cutoff_top_n=40,
scorer=None, scorer=None,
hot_words=dict(),
num_results=1): num_results=1):
"""Wrapper for the batched CTC beam search decoder. """Wrapper for the batched CTC beam search decoder.
@ -161,13 +165,15 @@ def ctc_beam_search_decoder_batch(probs_seq,
:param scorer: External scorer for partially decoded sentence, e.g. word :param scorer: External scorer for partially decoded sentence, e.g. word
count or language model. count or language model.
:type scorer: Scorer :type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:param num_results: Number of beams to return. :param num_results: Number of beams to return.
:type num_results: int :type num_results: int
:return: List of tuples of confidence and sentence as decoding :return: List of tuples of confidence and sentence as decoding
results, in descending order of the confidence. results, in descending order of the confidence.
:rtype: list :rtype: list
""" """
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results) batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results)
batch_beam_results = [ batch_beam_results = [
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results for beam_results in batch_beam_results

View File

@ -4,7 +4,7 @@
#include <cmath> #include <cmath>
#include <iostream> #include <iostream>
#include <limits> #include <limits>
#include <map> #include <unordered_map>
#include <utility> #include <utility>
#include "decoder_utils.h" #include "decoder_utils.h"
@ -18,7 +18,8 @@ DecoderState::init(const Alphabet& alphabet,
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer) std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words)
{ {
// assign special ids // assign special ids
abs_time_step_ = 0; abs_time_step_ = 0;
@ -29,6 +30,7 @@ DecoderState::init(const Alphabet& alphabet,
cutoff_prob_ = cutoff_prob; cutoff_prob_ = cutoff_prob;
cutoff_top_n_ = cutoff_top_n; cutoff_top_n_ = cutoff_top_n;
ext_scorer_ = ext_scorer; ext_scorer_ = ext_scorer;
hot_words_ = hot_words;
start_expanding_ = false; start_expanding_ = false;
// init prefixes' root // init prefixes' root
@ -160,8 +162,23 @@ DecoderState::next(const double *probs,
float score = 0.0; float score = 0.0;
std::vector<std::string> ngram; std::vector<std::string> ngram;
ngram = ext_scorer_->make_ngram(prefix_to_score); ngram = ext_scorer_->make_ngram(prefix_to_score);
float hot_boost = 0.0;
if (!hot_words_.empty()) {
std::unordered_map<std::string, float>::iterator iter;
// increase prob of prefix for every word
// that matches a word in the hot-words list
for (std::string word : ngram) {
iter = hot_words_.find(word);
if ( iter != hot_words_.end() ) {
// increase the log_cond_prob(prefix|LM)
hot_boost += iter->second;
}
}
}
bool bos = ngram.size() < ext_scorer_->get_max_order(); bool bos = ngram.size() < ext_scorer_->get_max_order();
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha; score = ( ext_scorer_->get_log_cond_prob(ngram, bos) + hot_boost ) * ext_scorer_->alpha;
log_p += score; log_p += score;
log_p += ext_scorer_->beta; log_p += ext_scorer_->beta;
} }
@ -256,11 +273,12 @@ std::vector<Output> ctc_beam_search_decoder(
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer, std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results) size_t num_results)
{ {
VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model."); VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model.");
DecoderState state; DecoderState state;
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer); state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer, hot_words);
state.next(probs, time_dim, class_dim); state.next(probs, time_dim, class_dim);
return state.decode(num_results); return state.decode(num_results);
} }
@ -279,6 +297,7 @@ ctc_beam_search_decoder_batch(
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer, std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results) size_t num_results)
{ {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
@ -298,6 +317,7 @@ ctc_beam_search_decoder_batch(
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
ext_scorer, ext_scorer,
hot_words,
num_results)); num_results));
} }

View File

@ -22,6 +22,7 @@ class DecoderState {
std::vector<PathTrie*> prefixes_; std::vector<PathTrie*> prefixes_;
std::unique_ptr<PathTrie> prefix_root_; std::unique_ptr<PathTrie> prefix_root_;
TimestepTreeNode timestep_tree_root_{nullptr, 0}; TimestepTreeNode timestep_tree_root_{nullptr, 0};
std::unordered_map<std::string, float> hot_words_;
public: public:
DecoderState() = default; DecoderState() = default;
@ -48,7 +49,8 @@ public:
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer); std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words);
/* Send data to the decoder /* Send data to the decoder
* *
@ -88,6 +90,8 @@ public:
* ext_scorer: External scorer to evaluate a prefix, which consists of * ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term. * n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer. * Default null, decoding the input sample without scorer.
* hot_words: A map of hot-words and their corresponding boosts
* The hot-word is a string and the boost is a float.
* num_results: Number of beams to return. * num_results: Number of beams to return.
* Return: * Return:
* A vector where each element is a pair of score and decoding result, * A vector where each element is a pair of score and decoding result,
@ -103,6 +107,7 @@ std::vector<Output> ctc_beam_search_decoder(
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer, std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results=1); size_t num_results=1);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
@ -117,6 +122,8 @@ std::vector<Output> ctc_beam_search_decoder(
* ext_scorer: External scorer to evaluate a prefix, which consists of * ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term. * n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer. * Default null, decoding the input sample without scorer.
* hot_words: A map of hot-words and their corresponding boosts
* The hot-word is a string and the boost is a float.
* num_results: Number of beams to return. * num_results: Number of beams to return.
* Return: * Return:
* A 2-D vector where each element is a vector of beam search decoding * A 2-D vector where each element is a vector of beam search decoding
@ -136,6 +143,7 @@ ctc_beam_search_decoder_batch(
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer, std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results=1); size_t num_results=1);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

View File

@ -11,6 +11,7 @@
%include <std_string.i> %include <std_string.i>
%include <std_vector.i> %include <std_vector.i>
%include <std_shared_ptr.i> %include <std_shared_ptr.i>
%include <std_unordered_map.i>
%include "numpy.i" %include "numpy.i"
%init %{ %init %{
@ -22,6 +23,7 @@ namespace std {
%template(UnsignedIntVector) vector<unsigned int>; %template(UnsignedIntVector) vector<unsigned int>;
%template(OutputVector) vector<Output>; %template(OutputVector) vector<Output>;
%template(OutputVectorVector) vector<vector<Output>>; %template(OutputVectorVector) vector<vector<Output>>;
%template(Map) unordered_map<string, float>;
} }
%shared_ptr(Scorer); %shared_ptr(Scorer);

View File

@ -342,6 +342,53 @@ DS_EnableExternalScorer(ModelState* aCtx,
return DS_ERR_OK; return DS_ERR_OK;
} }
/**
 * @brief Register a hot-word and its boost on the model.
 *
 * A scorer must already be enabled; the boost is applied by the decoder's
 * external scorer. Boosts may be negative (to penalize a word).
 *
 * @param aCtx  Model whose hot-word map is modified.
 * @param word  The hot-word to add. Adding a word that is already present
 *              fails; erase it first to change its boost.
 * @param boost The boost value.
 *
 * @return DS_ERR_OK on success, DS_ERR_FAIL_INSERT_HOTWORD if the word is
 *         already in the map, DS_ERR_SCORER_NOT_ENABLED if no scorer is set.
 */
int
DS_AddHotWord(ModelState* aCtx,
              const char* word,
              float boost)
{
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }
  // emplace() reports directly whether the key was newly inserted; no need
  // to compare the container size before and after.
  const bool inserted = aCtx->hot_words_.emplace(word, boost).second;
  return inserted ? DS_ERR_OK : DS_ERR_FAIL_INSERT_HOTWORD;
}
/**
 * @brief Remove a hot-word from the model's hot-word map.
 *
 * @param aCtx Model whose hot-word map is modified.
 * @param word The hot-word to erase.
 *
 * @return DS_ERR_OK on success, DS_ERR_FAIL_ERASE_HOTWORD if the word was
 *         not present, DS_ERR_SCORER_NOT_ENABLED if no scorer is set.
 */
int
DS_EraseHotWord(ModelState* aCtx,
                const char* word)
{
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }
  // erase() returns the number of elements removed (0 or 1 for a map),
  // which already tells us whether the word existed — the previous
  // size-before/size-after comparison (and its unused `err` variable)
  // were redundant.
  return aCtx->hot_words_.erase(word) > 0 ? DS_ERR_OK
                                          : DS_ERR_FAIL_ERASE_HOTWORD;
}
/**
 * @brief Remove every entry from the model's hot-word map.
 *
 * @param aCtx Model whose hot-word map is cleared.
 *
 * @return DS_ERR_OK on success, DS_ERR_SCORER_NOT_ENABLED if no scorer
 *         is set.
 */
int
DS_ClearHotWords(ModelState* aCtx)
{
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }
  // clear() cannot fail and is guaranteed to leave the map empty, so the
  // previous post-clear size check (returning DS_ERR_FAIL_CLEAR_HOTWORD)
  // was unreachable dead code.
  aCtx->hot_words_.clear();
  return DS_ERR_OK;
}
int int
DS_DisableExternalScorer(ModelState* aCtx) DS_DisableExternalScorer(ModelState* aCtx)
{ {
@ -390,7 +437,8 @@ DS_CreateStream(ModelState* aCtx,
aCtx->beam_width_, aCtx->beam_width_,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
aCtx->scorer_); aCtx->scorer_,
aCtx->hot_words_);
*retval = ctx.release(); *retval = ctx.release();
return DS_ERR_OK; return DS_ERR_OK;

View File

@ -81,7 +81,10 @@ typedef struct Metadata {
APPLY(DS_ERR_FAIL_CREATE_STREAM, 0x3004, "Error creating the stream.") \ APPLY(DS_ERR_FAIL_CREATE_STREAM, 0x3004, "Error creating the stream.") \
APPLY(DS_ERR_FAIL_READ_PROTOBUF, 0x3005, "Error reading the proto buffer model file.") \ APPLY(DS_ERR_FAIL_READ_PROTOBUF, 0x3005, "Error reading the proto buffer model file.") \
APPLY(DS_ERR_FAIL_CREATE_SESS, 0x3006, "Failed to create session.") \ APPLY(DS_ERR_FAIL_CREATE_SESS, 0x3006, "Failed to create session.") \
APPLY(DS_ERR_FAIL_CREATE_MODEL, 0x3007, "Could not allocate model state.") APPLY(DS_ERR_FAIL_CREATE_MODEL, 0x3007, "Could not allocate model state.") \
APPLY(DS_ERR_FAIL_INSERT_HOTWORD, 0x3008, "Could not insert hot-word.") \
APPLY(DS_ERR_FAIL_CLEAR_HOTWORD, 0x3009, "Could not clear hot-words.") \
APPLY(DS_ERR_FAIL_ERASE_HOTWORD, 0x3010, "Could not erase hot-word.")
// sphinx-doc: error_code_listing_end // sphinx-doc: error_code_listing_end
@ -157,6 +160,42 @@ DEEPSPEECH_EXPORT
int DS_EnableExternalScorer(ModelState* aCtx, int DS_EnableExternalScorer(ModelState* aCtx,
const char* aScorerPath); const char* aScorerPath);
/**
* @brief Add a hot-word and its boost.
*
* @param aCtx The ModelState pointer for the model being changed.
* @param word The hot-word.
* @param boost The boost.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
DEEPSPEECH_EXPORT
int DS_AddHotWord(ModelState* aCtx,
const char* word,
float boost);
/**
* @brief Remove entry for a hot-word from the hot-words map.
*
* @param aCtx The ModelState pointer for the model being changed.
* @param word The hot-word.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
DEEPSPEECH_EXPORT
int DS_EraseHotWord(ModelState* aCtx,
const char* word);
/**
* @brief Removes all elements from the hot-words map.
*
* @param aCtx The ModelState pointer for the model being changed.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
DEEPSPEECH_EXPORT
int DS_ClearHotWords(ModelState* aCtx);
/** /**
* @brief Disable decoding using an external scorer. * @brief Disable decoding using an external scorer.
* *

View File

@ -74,6 +74,39 @@ namespace DeepSpeechClient
EvaluateResultCode(resultCode); EvaluateResultCode(resultCode);
} }
/// <summary>
/// Add a hot-word and its boost to the model.
/// </summary>
/// <param name="aWord">The hot-word to boost.</param>
/// <param name="aBoost">The boost value applied to the word during decoding.</param>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
public unsafe void AddHotWord(string aWord, float aBoost)
{
    EvaluateResultCode(NativeImp.DS_AddHotWord(_modelStatePP, aWord, aBoost));
}
/// <summary>
/// Erase the entry for a hot-word.
/// </summary>
/// <param name="aWord">The hot-word whose entry is removed.</param>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
public unsafe void EraseHotWord(string aWord)
{
    EvaluateResultCode(NativeImp.DS_EraseHotWord(_modelStatePP, aWord));
}
/// <summary>
/// Remove every entry from the model's hot-word map.
/// </summary>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
public unsafe void ClearHotWords()
{
    EvaluateResultCode(NativeImp.DS_ClearHotWords(_modelStatePP));
}
/// <summary> /// <summary>
/// Return the sample rate expected by the model. /// Return the sample rate expected by the model.
/// </summary> /// </summary>

View File

@ -26,5 +26,8 @@
DS_ERR_FAIL_CREATE_STREAM = 0x3004, DS_ERR_FAIL_CREATE_STREAM = 0x3004,
DS_ERR_FAIL_READ_PROTOBUF = 0x3005, DS_ERR_FAIL_READ_PROTOBUF = 0x3005,
DS_ERR_FAIL_CREATE_SESS = 0x3006, DS_ERR_FAIL_CREATE_SESS = 0x3006,
DS_ERR_FAIL_INSERT_HOTWORD = 0x3008,
DS_ERR_FAIL_CLEAR_HOTWORD = 0x3009,
DS_ERR_FAIL_ERASE_HOTWORD = 0x3010
} }
} }

View File

@ -44,6 +44,27 @@ namespace DeepSpeechClient.Interfaces
/// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception> /// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception>
unsafe void EnableExternalScorer(string aScorerPath); unsafe void EnableExternalScorer(string aScorerPath);
/// <summary>
/// Add a hot-word.
/// </summary>
/// <param name="aWord">Some word</param>
/// <param name="aBoost">Some boost</param>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
unsafe void AddHotWord(string aWord, float aBoost);
/// <summary>
/// Erase entry for a hot-word.
/// </summary>
/// <param name="aWord">Some word</param>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
unsafe void EraseHotWord(string aWord);
/// <summary>
/// Clear all hot-words.
/// </summary>
/// <exception cref="ArgumentException">Thrown on failure.</exception>
unsafe void ClearHotWords();
/// <summary> /// <summary>
/// Disable decoding using an external scorer. /// Disable decoding using an external scorer.
/// </summary> /// </summary>

View File

@ -41,6 +41,18 @@ namespace DeepSpeechClient
internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx, internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx,
string aScorerPath); string aScorerPath);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_AddHotWord(IntPtr** aCtx,
string aWord,
float aBoost);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_EraseHotWord(IntPtr** aCtx,
string aWord);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_ClearHotWords(IntPtr** aCtx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx); internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx);

View File

@ -215,4 +215,36 @@ public class DeepSpeechModel {
public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx, int num_results) { public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx, int num_results) {
return impl.FinishStreamWithMetadata(ctx.get(), num_results); return impl.FinishStreamWithMetadata(ctx.get(), num_results);
} }
/**
 * @brief Add a hot-word and its boost
 *
 * @param word The hot-word to boost during decoding.
 * @param boost The boost value (may be negative to penalize the word).
 *
 * @throws RuntimeException on failure.
 *
 */
public void addHotWord(String word, float boost) {
    evaluateErrorCode(impl.AddHotWord(this._msp, word, boost));
}
/**
 * @brief Erase a hot-word
 *
 * @param word The hot-word whose entry is removed.
 *
 * @throws RuntimeException on failure.
 *
 */
public void eraseHotWord(String word) {
    evaluateErrorCode(impl.EraseHotWord(this._msp, word));
}
/**
 * @brief Clear all hot-words.
 *
 * Removes every entry from the model's hot-word map.
 *
 * @throws RuntimeException on failure.
 *
 */
public void clearHotWords() {
    evaluateErrorCode(impl.ClearHotWords(this._msp));
}
} }

View File

@ -28,7 +28,10 @@ public enum DeepSpeech_Error_Codes {
ERR_FAIL_CREATE_STREAM(0x3004), ERR_FAIL_CREATE_STREAM(0x3004),
ERR_FAIL_READ_PROTOBUF(0x3005), ERR_FAIL_READ_PROTOBUF(0x3005),
ERR_FAIL_CREATE_SESS(0x3006), ERR_FAIL_CREATE_SESS(0x3006),
ERR_FAIL_CREATE_MODEL(0x3007); ERR_FAIL_CREATE_MODEL(0x3007),
ERR_FAIL_INSERT_HOTWORD(0x3008),
ERR_FAIL_CLEAR_HOTWORD(0x3009),
ERR_FAIL_ERASE_HOTWORD(0x3010);
public final int swigValue() { public final int swigValue() {
return swigValue; return swigValue;

View File

@ -182,6 +182,47 @@ export class Model {
} }
} }
/**
 * Add a hot-word and its boost
 *
 * @param aWord word to boost
 * @param aBoost boost value
 *
 * @throws on error
 */
addHotWord(aWord: string, aBoost: number): void {
    const rc = binding.addHotWord(this._impl, aWord, aBoost);
    if (rc === 0) {
        return;
    }
    throw `addHotWord failed: ${binding.ErrorCodeToErrorMessage(rc)} (0x${rc.toString(16)})`;
}
/**
 * Erase entry for hot-word
 *
 * @param aWord word whose entry is removed
 *
 * @throws on error
 */
eraseHotWord(aWord: string): void {
    const rc = binding.eraseHotWord(this._impl, aWord);
    if (rc === 0) {
        return;
    }
    throw `eraseHotWord failed: ${binding.ErrorCodeToErrorMessage(rc)} (0x${rc.toString(16)})`;
}
/**
 * Clear all hot-word entries
 *
 * @throws on error
 */
clearHotWords(): void {
    const status = binding.clearHotWords(this._impl);
    if (status !== 0) {
        // fixed: message previously said "clearHotWord failed" (singular),
        // inconsistent with the method name and the sibling error messages
        throw `clearHotWords failed: ${binding.ErrorCodeToErrorMessage(status)} (0x${status.toString(16)})`;
    }
}
/** /**
* Return the sample rate expected by the model. * Return the sample rate expected by the model.
* *

View File

@ -17,6 +17,7 @@ struct ModelState {
Alphabet alphabet_; Alphabet alphabet_;
std::shared_ptr<Scorer> scorer_; std::shared_ptr<Scorer> scorer_;
std::unordered_map<std::string, float> hot_words_;
unsigned int beam_width_; unsigned int beam_width_;
unsigned int n_steps_; unsigned int n_steps_;
unsigned int n_context_; unsigned int n_context_;

View File

@ -95,6 +95,45 @@ class Model(object):
""" """
return deepspeech.impl.DisableExternalScorer(self._impl) return deepspeech.impl.DisableExternalScorer(self._impl)
def addHotWord(self, word, boost):
    """
    Add a word and its boost for decoding.

    :param word: the hot-word
    :type word: str

    :param boost: the boost
    :type boost: float

    :throws: RuntimeError on error
    """
    status = deepspeech.impl.AddHotWord(self._impl, word, boost)
    if status != 0:
        raise RuntimeError("AddHotWord failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status))
def eraseHotWord(self, word):
    """
    Delete the entry for a given word from the hot-words dict.

    :param word: the hot-word whose entry is removed
    :type word: str

    :throws: RuntimeError on error
    """
    err = deepspeech.impl.EraseHotWord(self._impl, word)
    if err == 0:
        return
    raise RuntimeError("EraseHotWord failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(err),err))
def clearHotWords(self):
    """
    Drop every entry from the hot-words dict.

    :throws: RuntimeError on error
    """
    err = deepspeech.impl.ClearHotWords(self._impl)
    if err == 0:
        return
    raise RuntimeError("ClearHotWords failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(err),err))
def setScorerAlphaBeta(self, alpha, beta): def setScorerAlphaBeta(self, alpha, beta):
""" """
Set hyperparameters alpha and beta of the external scorer. Set hyperparameters alpha and beta of the external scorer.

View File

@ -109,6 +109,8 @@ def main():
help='Output json from metadata with timestamp of each word') help='Output json from metadata with timestamp of each word')
parser.add_argument('--candidate_transcripts', type=int, default=3, parser.add_argument('--candidate_transcripts', type=int, default=3,
help='Number of candidate transcripts to include in JSON output') help='Number of candidate transcripts to include in JSON output')
parser.add_argument('--hot_words', type=str,
help='Hot-words and their boosts.')
args = parser.parse_args() args = parser.parse_args()
print('Loading model from file {}'.format(args.model), file=sys.stderr) print('Loading model from file {}'.format(args.model), file=sys.stderr)
@ -134,6 +136,12 @@ def main():
if args.lm_alpha and args.lm_beta: if args.lm_alpha and args.lm_beta:
ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta) ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)
if args.hot_words:
print('Adding hot-words', file=sys.stderr)
for word_boost in args.hot_words.split(','):
word,boost = word_boost.split(':')
ds.addHotWord(word,float(boost))
fin = wave.open(args.audio, 'rb') fin = wave.open(args.audio, 'rb')
fs_orig = fin.getframerate() fs_orig = fin.getframerate()
if fs_orig != desired_sample_rate: if fs_orig != desired_sample_rate:

View File

@ -30,4 +30,6 @@ android_setup_ndk_data
run_tflite_basic_inference_tests run_tflite_basic_inference_tests
run_android_hotword_tests
android_stop_emulator android_stop_emulator

View File

@ -206,6 +206,10 @@ android_setup_ndk_data()
${TASKCLUSTER_TMP_DIR}/${model_name} \ ${TASKCLUSTER_TMP_DIR}/${model_name} \
${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \ ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \
${ANDROID_TMP_DIR}/ds/ ${ANDROID_TMP_DIR}/ds/
if [ -f "${TASKCLUSTER_TMP_DIR}/kenlm.scorer" ]; then
adb push ${TASKCLUSTER_TMP_DIR}/kenlm.scorer ${ANDROID_TMP_DIR}/ds/
fi
} }
android_setup_apk_data() android_setup_apk_data()

View File

@ -524,6 +524,24 @@ run_multi_inference_tests()
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status" assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status"
} }
# Decode one LDC93S1 sample with hot-word boosting enabled
# (--hot_words "word:boost,..."). The boosts used (0.0 and -0.1 on words
# absent from the transcript) are expected to leave the LM transcript
# unchanged, which assert_correct_ldc93s1_lm verifies.
run_hotword_tests()
{
  set +e
  hotwords_decode=$(${DS_BINARY_PREFIX}deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} --hot_words "foo:0.0,bar:-0.1" 2>${TASKCLUSTER_TMP_DIR}/stderr)
  status=$?
  set -e
  assert_correct_ldc93s1_lm "${hotwords_decode}" "$status"
}
# Android variant of run_hotword_tests: same hot-word decode, but model,
# scorer and audio live under DATA_TMP_DIR (pushed to the device) instead
# of TASKCLUSTER_TMP_DIR.
run_android_hotword_tests()
{
  set +e
  hotwords_decode=$(${DS_BINARY_PREFIX}deepspeech --model ${DATA_TMP_DIR}/${model_name} --scorer ${DATA_TMP_DIR}/kenlm.scorer --audio ${DATA_TMP_DIR}/${ldc93s1_sample_filename} --hot_words "foo:0.0,bar:-0.1" 2>${TASKCLUSTER_TMP_DIR}/stderr)
  status=$?
  set -e
  assert_correct_ldc93s1_lm "${hotwords_decode}" "$status"
}
run_cpp_only_inference_tests() run_cpp_only_inference_tests()
{ {
set +e set +e

View File

@ -18,3 +18,5 @@ run_all_inference_tests
run_multi_inference_tests run_multi_inference_tests
run_cpp_only_inference_tests run_cpp_only_inference_tests
run_hotword_tests

View File

@ -23,3 +23,5 @@ run_all_inference_tests
run_multi_inference_tests run_multi_inference_tests
run_cpp_only_inference_tests run_cpp_only_inference_tests
run_hotword_tests

View File

@ -28,4 +28,6 @@ ensure_cuda_usage "$3"
run_all_inference_tests run_all_inference_tests
run_hotword_tests
virtualenv_deactivate "${pyalias}" "deepspeech" virtualenv_deactivate "${pyalias}" "deepspeech"

View File

@ -33,4 +33,6 @@ deepspeech --version
run_all_inference_tests run_all_inference_tests
run_hotword_tests
virtualenv_deactivate "${pyalias}" "deepspeech" virtualenv_deactivate "${pyalias}" "deepspeech"