From 1eb155ed93d04eb7d38feb01b68b9e0a86b8eba5 Mon Sep 17 00:00:00 2001 From: Josh Meyer Date: Thu, 24 Sep 2020 14:58:41 -0400 Subject: [PATCH] enable hot-word boosting (#3297) * enable hot-word boosting * more consistent ordering of CLI arguments * progress on review * use map instead of set for hot-words, move string logic to client.cc * typo bug * pointer things? * use map for hotwords, better string splitting * add the boost, not multiply * cleaning up * cleaning whitespace * remove inclusion * change typo set-->map * rename boost_coefficient to boost X-DeepSpeech: NOBUILD * add hot_words to python bindings * missing hot_words * include map in swigwrapper.i * add Map template to swigwrapper.i * emacs intermediate file * map things * map-->unordered_map * typu * typu * use dict() not None * error out if hot_words without scorer * two new functions: remove hot-word and clear all hot-words * starting to work on better error messages X-DeepSpeech: NOBUILD * better error handling + .Net ERR codes * allow for negative boosts:) * adding TC test for hot-words * add hot-words to python client, make TC test hot-words everywhere * only run TC tests for C++ and Python * fully expose API in python bindings * expose API in Java (thanks spectie!) * expose API in dotnet (thanks spectie!) * expose API in javascript (thanks spectie!) * java lol * typo in javascript * commenting * java error codes from swig * java docs from SWIG * java and dotnet issues * add hotword test to android tests * dotnet fixes from carlos * add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command * make sure lm is on android for hotword test * path to android model + nit * path * path --- native_client/args.h | 10 +++- native_client/client.cc | 33 ++++++++++++ native_client/ctcdecode/__init__.py | 10 +++- .../ctcdecode/ctc_beam_search_decoder.cpp | 28 +++++++++-- .../ctcdecode/ctc_beam_search_decoder.h | 10 +++- native_client/ctcdecode/swigwrapper.i | 2 + native_client/deepspeech.cc | 50 ++++++++++++++++++- native_client/deepspeech.h | 41 ++++++++++++++- .../dotnet/DeepSpeechClient/DeepSpeech.cs | 33 ++++++++++++ .../DeepSpeechClient/Enums/ErrorCodes.cs | 3 ++ .../Interfaces/IDeepSpeech.cs | 21 ++++++++ .../dotnet/DeepSpeechClient/NativeImp.cs | 12 +++++ .../libdeepspeech/DeepSpeechModel.java | 32 ++++++++++++ .../DeepSpeech_Error_Codes.java | 5 +- native_client/javascript/index.ts | 41 +++++++++++++++ native_client/modelstate.h | 1 + native_client/python/__init__.py | 39 +++++++++++++++ native_client/python/client.py | 8 +++ taskcluster/tc-android-ds-tests.sh | 2 + taskcluster/tc-android-utils.sh | 4 ++ taskcluster/tc-asserts.sh | 18 +++++++ taskcluster/tc-cpp-ds-tests.sh | 2 + taskcluster/tc-cpp_tflite-ds-tests.sh | 2 + taskcluster/tc-python-tests.sh | 2 + taskcluster/tc-python_tflite-tests.sh | 2 + 25 files changed, 400 insertions(+), 11 deletions(-) diff --git a/native_client/args.h b/native_client/args.h index baa9b7ff..856988dd 100644 --- a/native_client/args.h +++ b/native_client/args.h @@ -38,6 +38,8 @@ int json_candidate_transcripts = 3; int stream_size = 0; +char* hot_words = NULL; + void PrintHelp(const char* bin) { std::cout << @@ -56,6 +58,7 @@ void PrintHelp(const char* bin) "\t--json\t\t\t\tExtended output, shows word timings as JSON\n" "\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n" "\t--stream size\t\t\tRun in stream mode, output intermediate results\n" + "\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n" "\t--help\t\t\t\tShow help\n" "\t--version\t\t\tPrint version and exits\n"; char* version = DS_Version(); @@ -66,7 +69,7 @@ void PrintHelp(const char* bin) bool ProcessArgs(int argc, char** argv) { - const char* const short_opts = "m:l:a:b:c:d:tejs:vh"; + const char* const short_opts = "m:l:a:b:c:d:tejs:w:vh"; const option long_opts[] = { {"model", required_argument, nullptr, 'm'}, {"scorer", required_argument, nullptr, 'l'}, @@ -79,6 +82,7 @@ bool ProcessArgs(int argc, char** argv) {"json", no_argument, nullptr, 'j'}, {"candidate_transcripts", required_argument, nullptr, 150}, {"stream", required_argument, nullptr, 's'}, + {"hot_words", required_argument, nullptr, 'w'}, {"version", no_argument, nullptr, 'v'}, {"help", no_argument, nullptr, 'h'}, {nullptr, no_argument, nullptr, 0} @@ -144,6 +148,10 @@ bool ProcessArgs(int argc, char** argv) has_versions = true; break; + case 'w': + hot_words = optarg; + break; + case 'h': // -h or --help case '?': // Unrecognized option default: diff --git a/native_client/client.cc b/native_client/client.cc index 46a16115..96e1ff39 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -390,6 +390,22 @@ ProcessFile(ModelState* context, const char* path, bool show_times) } } +std::vector +SplitStringOnDelim(std::string in_string, std::string delim) +{ + std::vector out_vector; + char * tmp_str = new char[in_string.size() + 1]; + std::copy(in_string.begin(), in_string.end(), tmp_str); + tmp_str[in_string.size()] = '\0'; + const char* token = strtok(tmp_str, delim.c_str()); + while( token != NULL ) { + out_vector.push_back(token); + token = strtok(NULL, delim.c_str()); + } + delete[] tmp_str; + return out_vector; +} + int main(int argc, char **argv) { @@ -432,6 +448,23 @@ main(int argc, char **argv) } // sphinx-doc: c_ref_model_stop + if (hot_words) { + std::vector hot_words_ = SplitStringOnDelim(hot_words, ","); + for ( std::string hot_word_ : hot_words_ ) { + std::vector pair_ = SplitStringOnDelim(hot_word_, ":"); + const char* word = (pair_[0]).c_str(); + // the strtof function will return 0 in case of non numeric characters + // so, check the boost string before we turn it into a float + bool boost_is_valid = (pair_[1].find_first_not_of("-.0123456789") == std::string::npos); + float boost = strtof((pair_[1]).c_str(),0); + status = DS_AddHotWord(ctx, word, boost); + if (status != 0 || !boost_is_valid) { + fprintf(stderr, "Could not enable hot-word.\n"); + return 1; + } + } + } + #ifndef NO_SOX // Initialise SOX assert(sox_init() == SOX_SUCCESS); diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index fd897b3b..94e03b15 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -96,6 +96,7 @@ def ctc_beam_search_decoder(probs_seq, cutoff_prob=1.0, cutoff_top_n=40, scorer=None, + hot_words=dict(), num_results=1): """Wrapper for the CTC Beam Search Decoder. @@ -116,6 +117,8 @@ def ctc_beam_search_decoder(probs_seq, :param scorer: External scorer for partially decoded sentence, e.g. word count or language model. :type scorer: Scorer + :param hot_words: Map of words (keys) to their assigned boosts (values) + :type hot_words: map{string:float} :param num_results: Number of beams to return. :type num_results: int :return: List of tuples of confidence and sentence as decoding @@ -124,7 +127,7 @@ def ctc_beam_search_decoder(probs_seq, """ beam_results = swigwrapper.ctc_beam_search_decoder( probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n, - scorer, num_results) + scorer, hot_words, num_results) beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] return beam_results @@ -137,6 +140,7 @@ def ctc_beam_search_decoder_batch(probs_seq, cutoff_prob=1.0, cutoff_top_n=40, scorer=None, + hot_words=dict(), num_results=1): """Wrapper for the batched CTC beam search decoder. @@ -161,13 +165,15 @@ def ctc_beam_search_decoder_batch(probs_seq, :param scorer: External scorer for partially decoded sentence, e.g. word count or language model. :type scorer: Scorer + :param hot_words: Map of words (keys) to their assigned boosts (values) + :type hot_words: map{string:float} :param num_results: Number of beams to return. :type num_results: int :return: List of tuples of confidence and sentence as decoding results, in descending order of the confidence. :rtype: list """ - batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results) + batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results) batch_beam_results = [ [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results] for beam_results in batch_beam_results diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 580ee51c..2f6dd17a 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include "decoder_utils.h" @@ -18,7 +18,8 @@ DecoderState::init(const Alphabet& alphabet, size_t beam_size, double cutoff_prob, size_t cutoff_top_n, - std::shared_ptr ext_scorer) + std::shared_ptr ext_scorer, + std::unordered_map hot_words) { // assign special ids abs_time_step_ = 0; @@ -29,6 +30,7 @@ DecoderState::init(const Alphabet& alphabet, cutoff_prob_ = cutoff_prob; cutoff_top_n_ = cutoff_top_n; ext_scorer_ = ext_scorer; + hot_words_ = hot_words; start_expanding_ = false; // init prefixes' root @@ -160,8 +162,23 @@ DecoderState::next(const double *probs, float score = 0.0; std::vector ngram; ngram = ext_scorer_->make_ngram(prefix_to_score); + + float hot_boost = 0.0; + if (!hot_words_.empty()) { + std::unordered_map::iterator iter; + // increase prob of prefix for every word + // that matches a word in the hot-words list + for (std::string word : ngram) { + iter = hot_words_.find(word); + if ( iter != hot_words_.end() ) { + // increase the log_cond_prob(prefix|LM) + hot_boost += iter->second; + } + } + } + bool bos = ngram.size() < ext_scorer_->get_max_order(); - score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha; + score = ( ext_scorer_->get_log_cond_prob(ngram, bos) + hot_boost ) * ext_scorer_->alpha; log_p += score; log_p += ext_scorer_->beta; } @@ -256,11 +273,12 @@ std::vector ctc_beam_search_decoder( double cutoff_prob, size_t cutoff_top_n, std::shared_ptr ext_scorer, + std::unordered_map hot_words, size_t num_results) { VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model."); DecoderState state; - state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer); + state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer, hot_words); state.next(probs, time_dim, class_dim); return state.decode(num_results); } @@ -279,6 +297,7 @@ ctc_beam_search_decoder_batch( double cutoff_prob, size_t cutoff_top_n, std::shared_ptr ext_scorer, + std::unordered_map hot_words, size_t num_results) { VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); @@ -298,6 +317,7 @@ ctc_beam_search_decoder_batch( cutoff_prob, cutoff_top_n, ext_scorer, + hot_words, num_results)); } diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index 65e7497d..dc19555c 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -22,6 +22,7 @@ class DecoderState { std::vector prefixes_; std::unique_ptr prefix_root_; TimestepTreeNode timestep_tree_root_{nullptr, 0}; + std::unordered_map hot_words_; public: DecoderState() = default; @@ -48,7 +49,8 @@ public: size_t beam_size, double cutoff_prob, size_t cutoff_top_n, - std::shared_ptr ext_scorer); + std::shared_ptr ext_scorer, + std::unordered_map hot_words); /* Send data to the decoder * @@ -88,6 +90,8 @@ public: * ext_scorer: External scorer to evaluate a prefix, which consists of * n-gram language model scoring and word insertion term. * Default null, decoding the input sample without scorer. + * hot_words: A map of hot-words and their corresponding boosts + * The hot-word is a string and the boost is a float. * num_results: Number of beams to return. * Return: * A vector where each element is a pair of score and decoding result, @@ -103,6 +107,7 @@ std::vector ctc_beam_search_decoder( double cutoff_prob, size_t cutoff_top_n, std::shared_ptr ext_scorer, + std::unordered_map hot_words, size_t num_results=1); /* CTC Beam Search Decoder for batch data @@ -117,6 +122,8 @@ std::vector ctc_beam_search_decoder( * ext_scorer: External scorer to evaluate a prefix, which consists of * n-gram language model scoring and word insertion term. * Default null, decoding the input sample without scorer. + * hot_words: A map of hot-words and their corresponding boosts + * The hot-word is a string and the boost is a float. * num_results: Number of beams to return. * Return: * A 2-D vector where each element is a vector of beam search decoding @@ -136,6 +143,7 @@ ctc_beam_search_decoder_batch( double cutoff_prob, size_t cutoff_top_n, std::shared_ptr ext_scorer, + std::unordered_map hot_words, size_t num_results=1); #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/native_client/ctcdecode/swigwrapper.i b/native_client/ctcdecode/swigwrapper.i index dbe67c68..683a3426 100644 --- a/native_client/ctcdecode/swigwrapper.i +++ b/native_client/ctcdecode/swigwrapper.i @@ -11,6 +11,7 @@ %include %include %include +%include %include "numpy.i" %init %{ @@ -22,6 +23,7 @@ namespace std { %template(UnsignedIntVector) vector; %template(OutputVector) vector; %template(OutputVectorVector) vector>; + %template(Map) unordered_map; } %shared_ptr(Scorer); diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index 38868d4b..57f77ba1 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -342,6 +342,53 @@ DS_EnableExternalScorer(ModelState* aCtx, return DS_ERR_OK; } +int +DS_AddHotWord(ModelState* aCtx, + const char* word, + float boost) +{ + if (aCtx->scorer_) { + const int size_before = aCtx->hot_words_.size(); + aCtx->hot_words_.insert( std::pair (word, boost) ); + const int size_after = aCtx->hot_words_.size(); + if (size_before == size_after) { + return DS_ERR_FAIL_INSERT_HOTWORD; + } + return DS_ERR_OK; + } + return DS_ERR_SCORER_NOT_ENABLED; +} + +int +DS_EraseHotWord(ModelState* aCtx, + const char* word) +{ + if (aCtx->scorer_) { + const int size_before = aCtx->hot_words_.size(); + int err = aCtx->hot_words_.erase(word); + const int size_after = aCtx->hot_words_.size(); + if (size_before == size_after) { + return DS_ERR_FAIL_ERASE_HOTWORD; + } + return DS_ERR_OK; + } + return DS_ERR_SCORER_NOT_ENABLED; +} + +int +DS_ClearHotWords(ModelState* aCtx) +{ + if (aCtx->scorer_) { + aCtx->hot_words_.clear(); + const int size_after = aCtx->hot_words_.size(); + if (size_after != 0) { + return DS_ERR_FAIL_CLEAR_HOTWORD; + } + return DS_ERR_OK; + } + return DS_ERR_SCORER_NOT_ENABLED; +} + int DS_DisableExternalScorer(ModelState* aCtx) { @@ -390,7 +437,8 @@ DS_CreateStream(ModelState* aCtx, aCtx->beam_width_, cutoff_prob, cutoff_top_n, - aCtx->scorer_); + aCtx->scorer_, + aCtx->hot_words_); *retval = ctx.release(); return DS_ERR_OK; diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 1df3cf2e..35e9289a 100644 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -81,7 +81,10 @@ typedef struct Metadata { APPLY(DS_ERR_FAIL_CREATE_STREAM, 0x3004, "Error creating the stream.") \ APPLY(DS_ERR_FAIL_READ_PROTOBUF, 0x3005, "Error reading the proto buffer model file.") \ APPLY(DS_ERR_FAIL_CREATE_SESS, 0x3006, "Failed to create session.") \ - APPLY(DS_ERR_FAIL_CREATE_MODEL, 0x3007, "Could not allocate model state.") + APPLY(DS_ERR_FAIL_CREATE_MODEL, 0x3007, "Could not allocate model state.") \ + APPLY(DS_ERR_FAIL_INSERT_HOTWORD, 0x3008, "Could not insert hot-word.") \ + APPLY(DS_ERR_FAIL_CLEAR_HOTWORD, 0x3009, "Could not clear hot-words.") \ + APPLY(DS_ERR_FAIL_ERASE_HOTWORD, 0x3010, "Could not erase hot-word.") // sphinx-doc: error_code_listing_end @@ -157,6 +160,42 @@ DEEPSPEECH_EXPORT int DS_EnableExternalScorer(ModelState* aCtx, const char* aScorerPath); +/** + * @brief Add a hot-word and its boost. + * + * @param aCtx The ModelState pointer for the model being changed. + * @param word The hot-word. + * @param boost The boost. + * + * @return Zero on success, non-zero on failure (invalid arguments). + */ +DEEPSPEECH_EXPORT +int DS_AddHotWord(ModelState* aCtx, + const char* word, + float boost); + +/** + * @brief Remove entry for a hot-word from the hot-words map. + * + * @param aCtx The ModelState pointer for the model being changed. + * @param word The hot-word. + * + * @return Zero on success, non-zero on failure (invalid arguments). + */ +DEEPSPEECH_EXPORT +int DS_EraseHotWord(ModelState* aCtx, + const char* word); + +/** + * @brief Removes all elements from the hot-words map. + * + * @param aCtx The ModelState pointer for the model being changed. + * + * @return Zero on success, non-zero on failure (invalid arguments). + */ +DEEPSPEECH_EXPORT +int DS_ClearHotWords(ModelState* aCtx); + /** * @brief Disable decoding using an external scorer. * diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs index 08a3808b..b9b8e237 100644 --- a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs @@ -74,6 +74,39 @@ namespace DeepSpeechClient EvaluateResultCode(resultCode); } + /// + /// Add a hot-word. + /// + /// Some word + /// Some boost + /// Thrown on failure. + public unsafe void AddHotWord(string aWord, float aBoost) + { + var resultCode = NativeImp.DS_AddHotWord(_modelStatePP, aWord, aBoost); + EvaluateResultCode(resultCode); + } + + /// + /// Erase entry for a hot-word. + /// + /// Some word + /// Thrown on failure. + public unsafe void EraseHotWord(string aWord) + { + var resultCode = NativeImp.DS_EraseHotWord(_modelStatePP, aWord); + EvaluateResultCode(resultCode); + } + + /// + /// Clear all hot-words. + /// + /// Thrown on failure. + public unsafe void ClearHotWords() + { + var resultCode = NativeImp.DS_ClearHotWords(_modelStatePP); + EvaluateResultCode(resultCode); + } + /// /// Return the sample rate expected by the model. /// diff --git a/native_client/dotnet/DeepSpeechClient/Enums/ErrorCodes.cs b/native_client/dotnet/DeepSpeechClient/Enums/ErrorCodes.cs index 30660add..cbcb8f43 100644 --- a/native_client/dotnet/DeepSpeechClient/Enums/ErrorCodes.cs +++ b/native_client/dotnet/DeepSpeechClient/Enums/ErrorCodes.cs @@ -26,5 +26,8 @@ DS_ERR_FAIL_CREATE_STREAM = 0x3004, DS_ERR_FAIL_READ_PROTOBUF = 0x3005, DS_ERR_FAIL_CREATE_SESS = 0x3006, + DS_ERR_FAIL_INSERT_HOTWORD = 0x3008, + DS_ERR_FAIL_CLEAR_HOTWORD = 0x3009, + DS_ERR_FAIL_ERASE_HOTWORD = 0x3010 } } diff --git a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs index e1ed9cad..344b758e 100644 --- a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs @@ -44,6 +44,27 @@ namespace DeepSpeechClient.Interfaces /// Thrown when cannot find the scorer file. unsafe void EnableExternalScorer(string aScorerPath); + /// + /// Add a hot-word. + /// + /// Some word + /// Some boost + /// Thrown on failure. + unsafe void AddHotWord(string aWord, float aBoost); + + /// + /// Erase entry for a hot-word. + /// + /// Some word + /// Thrown on failure. + unsafe void EraseHotWord(string aWord); + + /// + /// Clear all hot-words. + /// + /// Thrown on failure. + unsafe void ClearHotWords(); + /// /// Disable decoding using an external scorer. /// diff --git a/native_client/dotnet/DeepSpeechClient/NativeImp.cs b/native_client/dotnet/DeepSpeechClient/NativeImp.cs index bc77cf1b..1a7dacac 100644 --- a/native_client/dotnet/DeepSpeechClient/NativeImp.cs +++ b/native_client/dotnet/DeepSpeechClient/NativeImp.cs @@ -41,6 +41,18 @@ namespace DeepSpeechClient internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx, string aScorerPath); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal static unsafe extern ErrorCodes DS_AddHotWord(IntPtr** aCtx, + string aWord, + float aBoost); + + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal static unsafe extern ErrorCodes DS_EraseHotWord(IntPtr** aCtx, + string aWord); + + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal static unsafe extern ErrorCodes DS_ClearHotWords(IntPtr** aCtx); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx); diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java index eafa11e2..ce313d20 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java @@ -215,4 +215,36 @@ public class DeepSpeechModel { public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx, int num_results) { return impl.FinishStreamWithMetadata(ctx.get(), num_results); } + /** + * @brief Add a hot-word + * + * @param word + * @param boost + * + * @throws RuntimeException on failure. + * + */ + public void addHotWord(String word, float boost) { + evaluateErrorCode(impl.AddHotWord(this._msp, word, boost)); + } + /** + * @brief Erase a hot-word + * + * @param word + * + * @throws RuntimeException on failure. + * + */ + public void eraseHotWord(String word) { + evaluateErrorCode(impl.EraseHotWord(this._msp, word)); + } + /** + * @brief Clear all hot-words. + * + * @throws RuntimeException on failure. + * + */ + public void clearHotWords() { + evaluateErrorCode(impl.ClearHotWords(this._msp)); + } } diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java index f93f3e8c..3fad4553 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java @@ -28,7 +28,10 @@ public enum DeepSpeech_Error_Codes { ERR_FAIL_CREATE_STREAM(0x3004), ERR_FAIL_READ_PROTOBUF(0x3005), ERR_FAIL_CREATE_SESS(0x3006), - ERR_FAIL_CREATE_MODEL(0x3007); + ERR_FAIL_CREATE_MODEL(0x3007), + ERR_FAIL_INSERT_HOTWORD(0x3008), + ERR_FAIL_CLEAR_HOTWORD(0x3009), + ERR_FAIL_ERASE_HOTWORD(0x3010); public final int swigValue() { return swigValue; diff --git a/native_client/javascript/index.ts b/native_client/javascript/index.ts index 988cbfd5..ec7f5686 100644 --- a/native_client/javascript/index.ts +++ b/native_client/javascript/index.ts @@ -182,6 +182,47 @@ export class Model { } } + /** + * Add a hot-word and its boost + * + * @param aWord word + * @param aBoost boost + * + * @throws on error + */ + addHotWord(aWord: string, aBoost: number): void { + const status = binding.addHotWord(this._impl, aWord, aBoost); + if (status !== 0) { + throw `addHotWord failed: ${binding.ErrorCodeToErrorMessage(status)} (0x${status.toString(16)})`; + } + } + + /** + * Erase entry for hot-word + * + * @param aWord word + * + * @throws on error + */ + eraseHotWord(aWord: string): void { + const status = binding.eraseHotWord(this._impl, aWord); + if (status !== 0) { + throw `eraseHotWord failed: ${binding.ErrorCodeToErrorMessage(status)} (0x${status.toString(16)})`; + } + } + + /** + * Clear all hot-word entries + * + * @throws on error + */ + clearHotWords(): void { + const status = binding.clearHotWords(this._impl); + if (status !== 0) { + throw `clearHotWord failed: ${binding.ErrorCodeToErrorMessage(status)} (0x${status.toString(16)})`; + } + } + /** * Return the sample rate expected by the model. * diff --git a/native_client/modelstate.h b/native_client/modelstate.h index 0dbe108a..4beb78b4 100644 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -17,6 +17,7 @@ struct ModelState { Alphabet alphabet_; std::shared_ptr scorer_; + std::unordered_map hot_words_; unsigned int beam_width_; unsigned int n_steps_; unsigned int n_context_; diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index 8dec3f0c..4b55da5a 100644 --- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -95,6 +95,45 @@ class Model(object): """ return deepspeech.impl.DisableExternalScorer(self._impl) + def addHotWord(self, word, boost): + """ + Add a word and its boost for decoding. + + :param word: the hot-word + :type word: str + + :param word: the boost + :type word: float + + :throws: RuntimeError on error + """ + status = deepspeech.impl.AddHotWord(self._impl, word, boost) + if status != 0: + raise RuntimeError("AddHotWord failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status)) + + def eraseHotWord(self, word): + """ + Remove entry for word from hot-words dict. + + :param word: the hot-word + :type word: str + + :throws: RuntimeError on error + """ + status = deepspeech.impl.EraseHotWord(self._impl, word) + if status != 0: + raise RuntimeError("EraseHotWord failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status)) + + def clearHotWords(self): + """ + Remove all entries from hot-words dict. + + :throws: RuntimeError on error + """ + status = deepspeech.impl.ClearHotWords(self._impl) + if status != 0: + raise RuntimeError("ClearHotWords failed with '{}' (0x{:X})".format(deepspeech.impl.ErrorCodeToErrorMessage(status),status)) + def setScorerAlphaBeta(self, alpha, beta): """ Set hyperparameters alpha and beta of the external scorer. diff --git a/native_client/python/client.py b/native_client/python/client.py index 6ebf7bcd..ca1c8e92 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -109,6 +109,8 @@ def main(): help='Output json from metadata with timestamp of each word') parser.add_argument('--candidate_transcripts', type=int, default=3, help='Number of candidate transcripts to include in JSON output') + parser.add_argument('--hot_words', type=str, + help='Hot-words and their boosts.') args = parser.parse_args() print('Loading model from file {}'.format(args.model), file=sys.stderr) @@ -134,6 +136,12 @@ def main(): if args.lm_alpha and args.lm_beta: ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta) + if args.hot_words: + print('Adding hot-words', file=sys.stderr) + for word_boost in args.hot_words.split(','): + word,boost = word_boost.split(':') + ds.addHotWord(word,float(boost)) + fin = wave.open(args.audio, 'rb') fs_orig = fin.getframerate() if fs_orig != desired_sample_rate: diff --git a/taskcluster/tc-android-ds-tests.sh b/taskcluster/tc-android-ds-tests.sh index 257a9496..c0f2d36e 100755 --- a/taskcluster/tc-android-ds-tests.sh +++ b/taskcluster/tc-android-ds-tests.sh @@ -30,4 +30,6 @@ android_setup_ndk_data run_tflite_basic_inference_tests +run_android_hotword_tests + android_stop_emulator diff --git a/taskcluster/tc-android-utils.sh b/taskcluster/tc-android-utils.sh index 3bf66927..752d5765 100755 --- a/taskcluster/tc-android-utils.sh +++ b/taskcluster/tc-android-utils.sh @@ -206,6 +206,10 @@ android_setup_ndk_data() ${TASKCLUSTER_TMP_DIR}/${model_name} \ ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \ ${ANDROID_TMP_DIR}/ds/ + + if [ -f "${TASKCLUSTER_TMP_DIR}/kenlm.scorer" ]; then + adb push ${TASKCLUSTER_TMP_DIR}/kenlm.scorer ${ANDROID_TMP_DIR}/ds/ + fi } android_setup_apk_data() diff --git a/taskcluster/tc-asserts.sh b/taskcluster/tc-asserts.sh index 7a164b07..d485846e 100755 --- a/taskcluster/tc-asserts.sh +++ b/taskcluster/tc-asserts.sh @@ -524,6 +524,24 @@ run_multi_inference_tests() assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status" } +run_hotword_tests() +{ + set +e + hotwords_decode=$(${DS_BINARY_PREFIX}deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} --hot_words "foo:0.0,bar:-0.1" 2>${TASKCLUSTER_TMP_DIR}/stderr) + status=$? + set -e + assert_correct_ldc93s1_lm "${hotwords_decode}" "$status" +} + +run_android_hotword_tests() +{ + set +e + hotwords_decode=$(${DS_BINARY_PREFIX}deepspeech --model ${DATA_TMP_DIR}/${model_name} --scorer ${DATA_TMP_DIR}/kenlm.scorer --audio ${DATA_TMP_DIR}/${ldc93s1_sample_filename} --hot_words "foo:0.0,bar:-0.1" 2>${TASKCLUSTER_TMP_DIR}/stderr) + status=$? + set -e + assert_correct_ldc93s1_lm "${hotwords_decode}" "$status" +} + run_cpp_only_inference_tests() { set +e diff --git a/taskcluster/tc-cpp-ds-tests.sh b/taskcluster/tc-cpp-ds-tests.sh index 67d5d92f..eabfcfa8 100644 --- a/taskcluster/tc-cpp-ds-tests.sh +++ b/taskcluster/tc-cpp-ds-tests.sh @@ -18,3 +18,5 @@ run_all_inference_tests run_multi_inference_tests run_cpp_only_inference_tests + +run_hotword_tests diff --git a/taskcluster/tc-cpp_tflite-ds-tests.sh b/taskcluster/tc-cpp_tflite-ds-tests.sh index 313475ef..6e7f9d8c 100644 --- a/taskcluster/tc-cpp_tflite-ds-tests.sh +++ b/taskcluster/tc-cpp_tflite-ds-tests.sh @@ -23,3 +23,5 @@ run_all_inference_tests run_multi_inference_tests run_cpp_only_inference_tests + +run_hotword_tests diff --git a/taskcluster/tc-python-tests.sh b/taskcluster/tc-python-tests.sh index d55a3097..4e6b2840 100644 --- a/taskcluster/tc-python-tests.sh +++ b/taskcluster/tc-python-tests.sh @@ -28,4 +28,6 @@ ensure_cuda_usage "$3" run_all_inference_tests +run_hotword_tests + virtualenv_deactivate "${pyalias}" "deepspeech" diff --git a/taskcluster/tc-python_tflite-tests.sh b/taskcluster/tc-python_tflite-tests.sh index a95adf40..7f00ea12 100644 --- a/taskcluster/tc-python_tflite-tests.sh +++ b/taskcluster/tc-python_tflite-tests.sh @@ -33,4 +33,6 @@ deepspeech --version run_all_inference_tests +run_hotword_tests + virtualenv_deactivate "${pyalias}" "deepspeech"