From 755fb81a627f7cc455f959bc2de48b011016e7ef Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Tue, 7 Sep 2021 18:10:06 +0200 Subject: [PATCH 1/4] Expose Flashlight LexiconDecoder/LexiconFreeDecoder --- native_client/BUILD | 76 +++- native_client/alphabet.cc | 76 ++-- native_client/alphabet.h | 30 +- native_client/coqui-stt.h | 43 ++- native_client/ctcdecode/__init__.py | 76 ++++ native_client/ctcdecode/build_archive.py | 21 +- .../ctcdecode/ctc_beam_search_decoder.cpp | 282 +++++++++++++++ .../ctcdecode/ctc_beam_search_decoder.h | 170 ++++++++- native_client/ctcdecode/output.h | 8 + native_client/ctcdecode/scorer.cpp | 70 +++- native_client/ctcdecode/scorer.h | 33 +- native_client/ctcdecode/setup.py | 2 +- native_client/ctcdecode/swigwrapper.i | 14 +- .../ctcdecode/third_party/flashlight/LICENSE | 21 ++ .../flashlight/lib/common/String.cpp | 115 ++++++ .../flashlight/flashlight/lib/common/String.h | 129 +++++++ .../flashlight/lib/common/System.cpp | 250 +++++++++++++ .../flashlight/flashlight/lib/common/System.h | 96 +++++ .../flashlight/lib/text/decoder/Decoder.h | 77 ++++ .../lib/text/decoder/LexiconDecoder.cpp | 328 ++++++++++++++++++ .../lib/text/decoder/LexiconDecoder.h | 187 ++++++++++ .../lib/text/decoder/LexiconFreeDecoder.cpp | 207 +++++++++++ .../lib/text/decoder/LexiconFreeDecoder.h | 160 +++++++++ .../decoder/LexiconFreeSeq2SeqDecoder.cpp | 179 ++++++++++ .../text/decoder/LexiconFreeSeq2SeqDecoder.h | 141 ++++++++ .../text/decoder/LexiconSeq2SeqDecoder.cpp | 243 +++++++++++++ .../lib/text/decoder/LexiconSeq2SeqDecoder.h | 165 +++++++++ .../flashlight/lib/text/decoder/Trie.cpp | 103 ++++++ .../flashlight/lib/text/decoder/Trie.h | 95 +++++ .../flashlight/lib/text/decoder/Utils.cpp | 15 + .../flashlight/lib/text/decoder/Utils.h | 275 +++++++++++++++ .../flashlight/lib/text/decoder/lm/ConvLM.cpp | 239 +++++++++++++ .../flashlight/lib/text/decoder/lm/ConvLM.h | 73 ++++ .../flashlight/lib/text/decoder/lm/KenLM.cpp | 74 ++++ 
.../flashlight/lib/text/decoder/lm/KenLM.h | 70 ++++ .../flashlight/lib/text/decoder/lm/LM.h | 90 +++++ .../flashlight/lib/text/decoder/lm/ZeroLM.cpp | 31 ++ .../flashlight/lib/text/decoder/lm/ZeroLM.h | 32 ++ .../flashlight/lib/text/dictionary/Defines.h | 21 ++ .../lib/text/dictionary/Dictionary.cpp | 152 ++++++++ .../lib/text/dictionary/Dictionary.h | 66 ++++ .../flashlight/lib/text/dictionary/Utils.cpp | 147 ++++++++ .../flashlight/lib/text/dictionary/Utils.h | 52 +++ .../coqui_stt_training/deepspeech_model.py | 2 +- .../coqui_stt_training/evaluate_flashlight.py | 201 +++++++++++ .../training_graph_inference_flashlight.py | 96 +++++ training/coqui_stt_training/util/config.py | 7 + 47 files changed, 4925 insertions(+), 115 deletions(-) create mode 100644 native_client/ctcdecode/third_party/flashlight/LICENSE create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Decoder.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp create mode 100644 
native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/LM.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Defines.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h create mode 100644 native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp create mode 100644 
native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.h create mode 100644 training/coqui_stt_training/evaluate_flashlight.py create mode 100644 training/coqui_stt_training/training_graph_inference_flashlight.py diff --git a/native_client/BUILD b/native_client/BUILD index 54a5d993..334934a5 100644 --- a/native_client/BUILD +++ b/native_client/BUILD @@ -86,11 +86,9 @@ cc_binary( "kenlm/*/*test.cc", "kenlm/*/*main.cc", ],), - copts = [ - "-std=c++11" - ] + select({ - "//tensorflow:windows": [], - "//conditions:default": ["-fvisibility=hidden"], + copts = select({ + "//tensorflow:windows": ["/std:c++14"], + "//conditions:default": ["-std=c++14", "-fwrapv", "-fvisibility=hidden"], }), defines = ["KENLM_MAX_ORDER=6"], includes = ["kenlm"], @@ -110,24 +108,62 @@ cc_binary( ) cc_library( - name = "kenlm", + name="kenlm", hdrs = glob([ "kenlm/lm/*.hh", "kenlm/util/*.hh", ]), srcs = [":libkenlm.so"], - copts = ["-std=c++11"], + copts = ["-std=c++14"], defines = ["KENLM_MAX_ORDER=6"], - includes = ["kenlm"], + includes = [".", "kenlm"], +) + +cc_library( + name = "flashlight", + hdrs = [ + "ctcdecode/third_party/flashlight/flashlight/lib/common/String.h", + "ctcdecode/third_party/flashlight/flashlight/lib/common/System.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Decoder.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/LM.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.h", + 
"ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Defines.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h", + "ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.h", + ], + srcs = [ + "ctcdecode/third_party/flashlight/flashlight/lib/common/String.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp", + "ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp", + ], + includes = ["ctcdecode/third_party/flashlight"], + deps = [":kenlm"], ) cc_library( name = "decoder", srcs = DECODER_SOURCES, includes = DECODER_INCLUDES, - deps = [":kenlm"], + deps = [":kenlm", ":flashlight"], linkopts = DECODER_LINKOPTS, - copts = ["-fexceptions"], + copts = select({ + "//tensorflow:windows": ["/std:c++14"], + "//conditions:default": ["-std=c++14", "-fexceptions", "-fwrapv"], + }), ) cc_library( @@ -195,10 +231,12 @@ cc_library( ] + DECODER_SOURCES, copts = select({ # -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default - "//tensorflow:windows": ["/w"], + "//tensorflow:windows": ["/std:c++14", "/w"], # -Wno-sign-compare to silent a lot of warnings from tensorflow itself, # which makes it harder to see our own warnings "//conditions:default": [ + "-std=c++14", + "-fwrapv", "-Wno-sign-compare", 
"-fvisibility=hidden", ], @@ -220,7 +258,7 @@ cc_library( "//conditions:default": [], }) + DECODER_LINKOPTS, includes = DECODER_INCLUDES, - deps = [":kenlm", ":tflite", ":tflitedelegates"], + deps = [":kenlm", ":tflite", ":tflitedelegates", ":flashlight"], ) cc_binary( @@ -264,8 +302,8 @@ cc_binary( "stt_errors.cc", ], copts = select({ - "//tensorflow:windows": [], - "//conditions:default": ["-std=c++11"], + "//tensorflow:windows": ["/std:c++14"], + "//conditions:default": ["-std=c++14"], }), deps = [ ":decoder", @@ -297,7 +335,10 @@ cc_binary( "enumerate_kenlm_vocabulary.cpp", ], deps = [":kenlm"], - copts = ["-std=c++11"], + copts = select({ + "//tensorflow:windows": ["/std:c++14"], + "//conditions:default": ["-std=c++14"], + }), ) cc_binary( @@ -305,6 +346,9 @@ cc_binary( srcs = [ "trie_load.cc", ] + DECODER_SOURCES, - copts = ["-std=c++11"], + copts = select({ + "//tensorflow:windows": ["/std:c++14"], + "//conditions:default": ["-std=c++14"], + }), linkopts = DECODER_LINKOPTS, ) diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc index a2a00dc0..4193eef2 100644 --- a/native_client/alphabet.cc +++ b/native_client/alphabet.cc @@ -45,8 +45,8 @@ Alphabet::init(const char *config_file) if (!in) { return 1; } - unsigned int label = 0; - space_label_ = -2; + int index = 0; + space_index_ = -2; for (std::string line; getline_crossplatform(in, line);) { if (line.size() == 2 && line[0] == '\\' && line[1] == '#') { line = '#'; @@ -55,16 +55,14 @@ Alphabet::init(const char *config_file) } //TODO: we should probably do something more i18n-aware here if (line == " ") { - space_label_ = label; + space_index_ = index; } if (line.length() == 0) { continue; } - label_to_str_[label] = line; - str_to_label_[line] = label; - ++label; + addEntry(line, index); + ++index; } - size_ = label; in.close(); return 0; } @@ -72,15 +70,13 @@ Alphabet::init(const char *config_file) void Alphabet::InitFromLabels(const std::vector& labels) { - space_label_ = -2; - size_ = 
labels.size(); - for (int i = 0; i < size_; ++i) { - const std::string& label = labels[i]; + space_index_ = -2; + for (int idx = 0; idx < labels.size(); ++idx) { + const std::string& label = labels[idx]; if (label == " ") { - space_label_ = i; + space_index_ = idx; } - label_to_str_[i] = label; - str_to_label_[label] = i; + addEntry(label, idx); } } @@ -90,12 +86,12 @@ Alphabet::SerializeText() std::stringstream out; out << "# Each line in this file represents the Unicode codepoint (UTF-8 encoded)\n" - << "# associated with a numeric label.\n" + << "# associated with a numeric index.\n" << "# A line that starts with # is a comment. You can escape it with \\# if you wish\n" - << "# to use '#' as a label.\n"; + << "# to use '#' in the Alphabet.\n"; - for (int label = 0; label < size_; ++label) { - out << label_to_str_[label] << "\n"; + for (int idx = 0; idx < entrySize(); ++idx) { + out << getEntry(idx) << "\n"; } out << "# The last (non-comment) line needs to end with a newline.\n"; @@ -105,18 +101,22 @@ Alphabet::SerializeText() std::string Alphabet::Serialize() { + // Should always be true in our usage, but this method will crash if for some + // mystical reason it doesn't hold, so defensively assert it here. + assert(isContiguous()); + // Serialization format is a sequence of (key, value) pairs, where key is // a uint16_t and value is a uint16_t length followed by `length` UTF-8 // encoded bytes with the label. std::stringstream out; // We start by writing the number of pairs in the buffer as uint16_t. 
- uint16_t size = size_; + uint16_t size = entrySize(); out.write(reinterpret_cast(&size), sizeof(size)); - for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) { - uint16_t key = it->first; - string str = it->second; + for (int i = 0; i < GetSize(); ++i) { + uint16_t key = i; + string str = DecodeSingle(i); uint16_t len = str.length(); // Then we write the key as uint16_t, followed by the length of the value // as uint16_t, followed by `length` bytes (the value itself). @@ -138,7 +138,6 @@ Alphabet::Deserialize(const char* buffer, const int buffer_size) } uint16_t size = *(uint16_t*)(buffer + offset); offset += sizeof(uint16_t); - size_ = size; for (int i = 0; i < size; ++i) { if (buffer_size - offset < sizeof(uint16_t)) { @@ -159,22 +158,26 @@ Alphabet::Deserialize(const char* buffer, const int buffer_size) std::string val(buffer+offset, val_len); offset += val_len; - label_to_str_[label] = val; - str_to_label_[val] = label; + addEntry(val, label); if (val == " ") { - space_label_ = label; + space_index_ = label; } } return 0; } +size_t +Alphabet::GetSize() const +{ + return entrySize(); +} + bool Alphabet::CanEncodeSingle(const std::string& input) const { - auto it = str_to_label_.find(input); - return it != str_to_label_.end(); + return contains(input); } bool @@ -191,25 +194,14 @@ Alphabet::CanEncode(const std::string& input) const std::string Alphabet::DecodeSingle(unsigned int label) const { - auto it = label_to_str_.find(label); - if (it != label_to_str_.end()) { - return it->second; - } else { - std::cerr << "Invalid label " << label << std::endl; - abort(); - } + assert(label <= INT_MAX); + return getEntry(label); } unsigned int Alphabet::EncodeSingle(const std::string& string) const { - auto it = str_to_label_.find(string); - if (it != str_to_label_.end()) { - return it->second; - } else { - std::cerr << "Invalid string " << string << std::endl; - abort(); - } + return getIndex(string); } std::string diff --git 
a/native_client/alphabet.h b/native_client/alphabet.h index ad75dfc1..166e1042 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -5,12 +5,15 @@ #include #include +#include "flashlight/lib/text/dictionary/Dictionary.h" + /* * Loads a text file describing a mapping of labels to strings, one string per * line. This is used by the decoder, client and Python scripts to convert the * output of the decoder to a human-readable string and vice-versa. */ -class Alphabet { +class Alphabet : public fl::lib::text::Dictionary +{ public: Alphabet() = default; Alphabet(const Alphabet&) = default; @@ -31,16 +34,14 @@ public: // Deserialize alphabet from a binary buffer. int Deserialize(const char* buffer, const int buffer_size); - size_t GetSize() const { - return size_; - } + size_t GetSize() const; bool IsSpace(unsigned int label) const { - return label == space_label_; + return label == space_index_; } unsigned int GetSpaceLabel() const { - return space_label_; + return space_index_; } // Returns true if the single character/output class has a corresponding label @@ -72,23 +73,20 @@ public: virtual std::vector Encode(const std::string& input) const; protected: - size_t size_; - unsigned int space_label_; - std::unordered_map label_to_str_; - std::unordered_map str_to_label_; + unsigned int space_index_; }; class UTF8Alphabet : public Alphabet { public: UTF8Alphabet() { - size_ = 255; - space_label_ = ' ' - 1; - for (size_t i = 0; i < size_; ++i) { - std::string val(1, i+1); - label_to_str_[i] = val; - str_to_label_[val] = i; + // 255 byte values, index n -> byte value n+1 + // because NUL is never used, we don't use up an index in the maps for it + for (int idx = 0; idx < 255; ++idx) { + std::string val(1, idx+1); + addEntry(val, idx); } + space_index_ = ' ' - 1; } int init(const char*) override { diff --git a/native_client/coqui-stt.h b/native_client/coqui-stt.h index 7794bc79..7e187108 100644 --- a/native_client/coqui-stt.h +++ b/native_client/coqui-stt.h 
@@ -15,6 +15,10 @@ extern "C" { #define STT_EXPORT #endif +// For the decoder package we include this header but should only expose +// the error info, so guard all the other definitions out. +#ifndef SWIG_ERRORS_ONLY + typedef struct ModelState ModelState; typedef struct StreamingState StreamingState; @@ -59,6 +63,8 @@ typedef struct Metadata { const unsigned int num_transcripts; } Metadata; +#endif /* SWIG_ERRORS_ONLY */ + // sphinx-doc: error_code_listing_start #define STT_FOR_EACH_ERROR(APPLY) \ @@ -95,6 +101,8 @@ STT_FOR_EACH_ERROR(DEFINE) #undef DEFINE }; +#ifndef SWIG_ERRORS_ONLY + /** * @brief An object providing an interface to a trained Coqui STT model. * @@ -105,7 +113,7 @@ STT_FOR_EACH_ERROR(DEFINE) */ STT_EXPORT int STT_CreateModel(const char* aModelPath, - ModelState** retval); + ModelState** retval); /** * @brief Get beam width value used by the model. If {@link STT_SetModelBeamWidth} @@ -130,7 +138,7 @@ unsigned int STT_GetModelBeamWidth(const ModelState* aCtx); */ STT_EXPORT int STT_SetModelBeamWidth(ModelState* aCtx, - unsigned int aBeamWidth); + unsigned int aBeamWidth); /** * @brief Return the sample rate expected by a model. @@ -158,7 +166,7 @@ void STT_FreeModel(ModelState* ctx); */ STT_EXPORT int STT_EnableExternalScorer(ModelState* aCtx, - const char* aScorerPath); + const char* aScorerPath); /** * @brief Add a hot-word and its boost. @@ -173,8 +181,8 @@ int STT_EnableExternalScorer(ModelState* aCtx, */ STT_EXPORT int STT_AddHotWord(ModelState* aCtx, - const char* word, - float boost); + const char* word, + float boost); /** * @brief Remove entry for a hot-word from the hot-words map. @@ -186,7 +194,7 @@ int STT_AddHotWord(ModelState* aCtx, */ STT_EXPORT int STT_EraseHotWord(ModelState* aCtx, - const char* word); + const char* word); /** * @brief Removes all elements from the hot-words map. 
@@ -219,8 +227,8 @@ int STT_DisableExternalScorer(ModelState* aCtx); */ STT_EXPORT int STT_SetScorerAlphaBeta(ModelState* aCtx, - float aAlpha, - float aBeta); + float aAlpha, + float aBeta); /** * @brief Use the Coqui STT model to convert speech to text. @@ -235,8 +243,8 @@ int STT_SetScorerAlphaBeta(ModelState* aCtx, */ STT_EXPORT char* STT_SpeechToText(ModelState* aCtx, - const short* aBuffer, - unsigned int aBufferSize); + const short* aBuffer, + unsigned int aBufferSize); /** * @brief Use the Coqui STT model to convert speech to text and output results @@ -255,9 +263,9 @@ char* STT_SpeechToText(ModelState* aCtx, */ STT_EXPORT Metadata* STT_SpeechToTextWithMetadata(ModelState* aCtx, - const short* aBuffer, - unsigned int aBufferSize, - unsigned int aNumResults); + const short* aBuffer, + unsigned int aBufferSize, + unsigned int aNumResults); /** * @brief Create a new streaming inference state. The streaming state returned @@ -284,8 +292,8 @@ int STT_CreateStream(ModelState* aCtx, */ STT_EXPORT void STT_FeedAudioContent(StreamingState* aSctx, - const short* aBuffer, - unsigned int aBufferSize); + const short* aBuffer, + unsigned int aBufferSize); /** * @brief Compute the intermediate decoding of an ongoing streaming inference. @@ -312,7 +320,7 @@ char* STT_IntermediateDecode(const StreamingState* aSctx); */ STT_EXPORT Metadata* STT_IntermediateDecodeWithMetadata(const StreamingState* aSctx, - unsigned int aNumResults); + unsigned int aNumResults); /** * @brief Compute the final decoding of an ongoing streaming inference and return @@ -345,7 +353,7 @@ char* STT_FinishStream(StreamingState* aSctx); */ STT_EXPORT Metadata* STT_FinishStreamWithMetadata(StreamingState* aSctx, - unsigned int aNumResults); + unsigned int aNumResults); /** * @brief Destroy a streaming state without decoding the computed logits. 
This @@ -389,6 +397,7 @@ char* STT_Version(); STT_EXPORT char* STT_ErrorCodeToErrorMessage(int aErrorCode); +#endif /* SWIG_ERRORS_ONLY */ #undef STT_EXPORT #ifdef __cplusplus diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index 82cdd308..7676e6ad 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -13,6 +13,10 @@ for symbol in dir(swigwrapper): globals()[symbol] = getattr(swigwrapper, symbol) +class FlashlightDecoderState(swigwrapper.FlashlightDecoderState): + pass + + class Scorer(swigwrapper.Scorer): """Wrapper for Scorer. @@ -265,3 +269,75 @@ def ctc_beam_search_decoder_batch( for beam_results in batch_beam_results ] return batch_beam_results + + +def flashlight_beam_search_decoder( + probs_seq, + alphabet, + beam_size, + decoder_type, + token_type, + lm_tokens, + scorer=None, + beam_threshold=25.0, + cutoff_top_n=40, + silence_score=0.0, + merge_with_log_add=False, + criterion_type=swigwrapper.FlashlightDecoderState.CTC, + transitions=[], + num_results=1, +): + return swigwrapper.flashlight_beam_search_decoder( + probs_seq, + alphabet, + beam_size, + beam_threshold, + cutoff_top_n, + scorer, + token_type, + lm_tokens, + decoder_type, + silence_score, + merge_with_log_add, + criterion_type, + transitions, + num_results, + ) + + +def flashlight_beam_search_decoder_batch( + probs_seq, + seq_lengths, + alphabet, + beam_size, + decoder_type, + token_type, + lm_tokens, + num_processes, + scorer=None, + beam_threshold=25.0, + cutoff_top_n=40, + silence_score=0.0, + merge_with_log_add=False, + criterion_type=swigwrapper.FlashlightDecoderState.CTC, + transitions=[], + num_results=1, +): + return swigwrapper.flashlight_beam_search_decoder_batch( + probs_seq, + seq_lengths, + alphabet, + beam_size, + beam_threshold, + cutoff_top_n, + scorer, + token_type, + lm_tokens, + decoder_type, + silence_score, + merge_with_log_add, + criterion_type, + transitions, + num_results, + num_processes, + ) 
diff --git a/native_client/ctcdecode/build_archive.py b/native_client/ctcdecode/build_archive.py index a4f13c1a..ecf86dd6 100644 --- a/native_client/ctcdecode/build_archive.py +++ b/native_client/ctcdecode/build_archive.py @@ -17,7 +17,7 @@ else: ARGS = [ "-fPIC", "-DKENLM_MAX_ORDER=6", - "-std=c++11", + "-std=c++14", "-Wno-unused-local-typedefs", "-Wno-sign-compare", ] @@ -32,6 +32,7 @@ INCLUDES = [ OPENFST_DIR + "/src/include", "third_party/ThreadPool", "third_party/object_pool", + "third_party/flashlight", ] KENLM_FILES = ( @@ -40,7 +41,7 @@ KENLM_FILES = ( + glob.glob("../kenlm/util/double-conversion/*.cc") ) -KENLM_FILES += glob.glob(OPENFST_DIR + "/src/lib/*.cc") +OPENFST_FILES = glob.glob(OPENFST_DIR + "/src/lib/*.cc") KENLM_FILES = [ fn @@ -50,6 +51,22 @@ KENLM_FILES = [ ) ] +FLASHLIGHT_FILES = [ + "third_party/flashlight/flashlight/lib/common/String.cpp", + "third_party/flashlight/flashlight/lib/common/System.cpp", + "third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.cpp", + "third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp", + "third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp", + "third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp", + "third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp", + "third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp", + "third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp", + "third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp", + "third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp", +] + +THIRD_PARTY_FILES = KENLM_FILES + OPENFST_FILES + FLASHLIGHT_FILES + CTC_DECODER_FILES = [ "ctc_beam_search_decoder.cpp", "scorer.cpp", diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 2f6dd17a..af029b9d 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -12,6 +12,12 
@@ #include "fst/fstlib.h" #include "path_trie.h" +#include "flashlight/lib/text/dictionary/Dictionary.h" +#include "flashlight/lib/text/decoder/Trie.h" +#include "flashlight/lib/text/decoder/LexiconDecoder.h" +#include "flashlight/lib/text/decoder/LexiconFreeDecoder.h" + +namespace flt = fl::lib::text; int DecoderState::init(const Alphabet& alphabet, @@ -264,6 +270,181 @@ DecoderState::decode(size_t num_results) const return outputs; } +int +FlashlightDecoderState::init( + const Alphabet& alphabet, + size_t beam_size, + double beam_threshold, + size_t cutoff_top_n, + std::shared_ptr ext_scorer, + FlashlightDecoderState::LMTokenType token_type, + flt::Dictionary lm_tokens, + FlashlightDecoderState::DecoderType decoder_type, + double silence_score, + bool merge_with_log_add, + FlashlightDecoderState::CriterionType criterion_type, + std::vector transitions) +{ + // Lexicon-free decoder must use single-token based LM + if (decoder_type == LexiconFree) { + assert(token_type == Single); + } + + // Build lexicon index to LM index map + if (!lm_tokens.contains("")) { + lm_tokens.addEntry(""); + } + ext_scorer->load_words(lm_tokens); + lm_tokens_ = lm_tokens; + + // Convert our criterion type to Flashlight type + flt::CriterionType flt_criterion; + switch (criterion_type) { + case ASG: flt_criterion = flt::CriterionType::ASG; break; + case CTC: flt_criterion = flt::CriterionType::CTC; break; + case S2S: flt_criterion = flt::CriterionType::S2S; break; + default: assert(false); + } + + // Build Trie + std::shared_ptr trie = nullptr; + auto startState = ext_scorer->start(false); + if (token_type == Aggregate || decoder_type == LexiconBased) { + trie = std::make_shared(lm_tokens.indexSize(), alphabet.GetSpaceLabel()); + for (int i = 0; i < lm_tokens.entrySize(); ++i) { + const std::string entry = lm_tokens.getEntry(i); + if (entry[0] == '<') { // don't insert , and + continue; + } + float score = -1; + if (token_type == Aggregate) { + flt::LMStatePtr dummyState; + 
std::tie(dummyState, score) = ext_scorer->score(startState, i); + } + std::vector encoded = alphabet.Encode(entry); + std::vector encoded_s(encoded.begin(), encoded.end()); + trie->insert(encoded_s, i, score); + } + + // Smear trie + trie->smear(flt::SmearingMode::MAX); + } + + // Query unknown token score + int unknown_word_index = lm_tokens.getIndex(""); + float unknown_score = -std::numeric_limits::infinity(); + if (token_type == Aggregate) { + std::tie(std::ignore, unknown_score) = + ext_scorer->score(startState, unknown_word_index); + } + + // Make sure conversions from uint to int below don't trip us + assert(beam_size < INT_MAX); + assert(cutoff_top_n < INT_MAX); + + if (decoder_type == LexiconBased) { + flt::LexiconDecoderOptions opts; + opts.beamSize = static_cast(beam_size); + opts.beamSizeToken = static_cast(cutoff_top_n); + opts.beamThreshold = beam_threshold; + opts.lmWeight = ext_scorer->alpha; + opts.wordScore = ext_scorer->beta; + opts.unkScore = unknown_score; + opts.silScore = silence_score; + opts.logAdd = merge_with_log_add; + opts.criterionType = flt_criterion; + decoder_impl_.reset(new flt::LexiconDecoder( + opts, + trie, + ext_scorer, + alphabet.GetSpaceLabel(), // silence index + alphabet.GetSize(), // blank index + unknown_word_index, + transitions, + token_type == Single) + ); + } else { + flt::LexiconFreeDecoderOptions opts; + opts.beamSize = static_cast(beam_size); + opts.beamSizeToken = static_cast(cutoff_top_n); + opts.beamThreshold = beam_threshold; + opts.lmWeight = ext_scorer->alpha; + opts.silScore = silence_score; + opts.logAdd = merge_with_log_add; + opts.criterionType = flt_criterion; + decoder_impl_.reset(new flt::LexiconFreeDecoder( + opts, + ext_scorer, + alphabet.GetSpaceLabel(), // silence index + alphabet.GetSize(), // blank index + transitions) + ); + } + + // Init decoder for stream + decoder_impl_->decodeBegin(); + + return 0; +} + +void +FlashlightDecoderState::next( + const double *probs, + int time_dim, + int 
class_dim) +{ + std::vector probs_f(probs, probs + (time_dim * class_dim) + 1); + decoder_impl_->decodeStep(probs_f.data(), time_dim, class_dim); +} + +FlashlightOutput +FlashlightDecoderState::intermediate(bool prune) +{ + flt::DecodeResult result = decoder_impl_->getBestHypothesis(); + std::vector valid_words; + for (int w : result.words) { + if (w != -1) { + valid_words.push_back(w); + } + } + FlashlightOutput ret { + .aggregate_score = result.score, + .acoustic_model_score = result.amScore, + .language_model_score = result.lmScore, + .words = lm_tokens_.mapIndicesToEntries(valid_words), // how does this interact with token-based decoding? + .tokens = result.tokens + }; + if (prune) { + decoder_impl_->prune(); + } + return ret; +} + +std::vector +FlashlightDecoderState::decode(size_t num_results) +{ + decoder_impl_->decodeEnd(); + std::vector flt_results = decoder_impl_->getAllFinalHypothesis(); + std::vector ret; + for (auto result : flt_results) { + std::vector valid_words; + for (int w : result.words) { + if (w != -1) { + valid_words.push_back(w); + } + } + ret.push_back({ + .aggregate_score = result.score, + .acoustic_model_score = result.amScore, + .language_model_score = result.lmScore, + .words = lm_tokens_.mapIndicesToEntries(valid_words), // how does this interact with token-based decoding? 
+ .tokens = result.tokens + }); + } + decoder_impl_.reset(nullptr); + return ret; +} + std::vector ctc_beam_search_decoder( const double *probs, int time_dim, @@ -328,3 +509,104 @@ ctc_beam_search_decoder_batch( } return batch_results; } + +std::vector +flashlight_beam_search_decoder( + const double* probs, + int time_dim, + int class_dim, + const Alphabet& alphabet, + size_t beam_size, + double beam_threshold, + size_t cutoff_top_n, + std::shared_ptr ext_scorer, + FlashlightDecoderState::LMTokenType token_type, + const std::vector& lm_tokens, + FlashlightDecoderState::DecoderType decoder_type, + double silence_score, + bool merge_with_log_add, + FlashlightDecoderState::CriterionType criterion_type, + std::vector transitions, + size_t num_results) +{ + VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model."); + flt::Dictionary tokens_dict; + for (auto str : lm_tokens) { + tokens_dict.addEntry(str); + } + FlashlightDecoderState state; + state.init( + alphabet, + beam_size, + beam_threshold, + cutoff_top_n, + ext_scorer, + token_type, + tokens_dict, + decoder_type, + silence_score, + merge_with_log_add, + criterion_type, + transitions); + state.next(probs, time_dim, class_dim); + return state.decode(num_results); +} + +std::vector> +flashlight_beam_search_decoder_batch( + const double *probs, + int batch_size, + int time_dim, + int class_dim, + const int* seq_lengths, + int seq_lengths_size, + const Alphabet& alphabet, + size_t beam_size, + double beam_threshold, + size_t cutoff_top_n, + std::shared_ptr ext_scorer, + FlashlightDecoderState::LMTokenType token_type, + const std::vector& lm_tokens, + FlashlightDecoderState::DecoderType decoder_type, + double silence_score, + bool merge_with_log_add, + FlashlightDecoderState::CriterionType criterion_type, + std::vector transitions, + size_t num_processes, 
+ size_t num_results) +{ + VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); + VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element"); + + ThreadPool pool(num_processes); + + // enqueue the tasks of decoding + std::vector>> res; + for (size_t i = 0; i < batch_size; ++i) { + res.emplace_back(pool.enqueue(flashlight_beam_search_decoder, + &probs[i*time_dim*class_dim], + seq_lengths[i], + class_dim, + alphabet, + beam_size, + beam_threshold, + cutoff_top_n, + ext_scorer, + token_type, + lm_tokens, + decoder_type, + silence_score, + merge_with_log_add, + criterion_type, + transitions, + num_results)); + } + + // get decoding results + std::vector> batch_results; + for (size_t i = 0; i < batch_size; ++i) { + batch_results.emplace_back(res[i].get()); + } + + return batch_results; +} diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index dc19555c..2176565c 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -9,7 +9,10 @@ #include "output.h" #include "alphabet.h" -class DecoderState { +#include "flashlight/lib/text/decoder/Decoder.h" + +class DecoderState +{ int abs_time_step_; int space_id_; int blank_id_; @@ -76,6 +79,89 @@ public: std::vector decode(size_t num_results=1) const; }; +class FlashlightDecoderState +{ +public: + FlashlightDecoderState() = default; + ~FlashlightDecoderState() = default; + + // Disallow copying + FlashlightDecoderState(const FlashlightDecoderState&) = delete; + FlashlightDecoderState& operator=(FlashlightDecoderState&) = delete; + + enum LMTokenType { + Single // LM units == AM units (character/byte LM) + ,Aggregate // LM units != AM units (word LM) + }; + + enum DecoderType { + LexiconBased + ,LexiconFree + }; + + enum CriterionType { + ASG = 0 + ,CTC = 1 + ,S2S = 2 + }; + + /* Initialize beam search decoder + * + * Parameters: + * alphabet: The 
alphabet. + * beam_size: The width of beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * Return: + * Zero on success, non-zero on failure. + */ + int init(const Alphabet& alphabet, + size_t beam_size, + double beam_threshold, + size_t cutoff_top_n, + std::shared_ptr ext_scorer, + FlashlightDecoderState::LMTokenType token_type, + fl::lib::text::Dictionary lm_tokens, + FlashlightDecoderState::DecoderType decoder_type, + double silence_score, + bool merge_with_log_add, + FlashlightDecoderState::CriterionType criterion_type, + std::vector transitions); + + /* Send data to the decoder + * + * Parameters: + * probs: 2-D vector where each element is a vector of probabilities + * over alphabet of one time step. + * time_dim: Number of timesteps. + * class_dim: Number of classes (alphabet length + 1 for space character). + */ + void next(const double *probs, + int time_dim, + int class_dim); + + /* Return current best hypothesis, optinoally pruning hypothesis space */ + FlashlightOutput intermediate(bool prune = true); + + /* Get up to num_results transcriptions from current decoder state. + * + * Parameters: + * num_results: Number of hypotheses to return. + * + * Return: + * A vector where each element is a pair of score and decoding result, + * in descending order. + */ + std::vector decode(size_t num_results = 1); + +private: + fl::lib::text::Dictionary lm_tokens_; + std::unique_ptr decoder_impl_; +}; + /* CTC Beam Search Decoder * Parameters: @@ -146,4 +232,86 @@ ctc_beam_search_decoder_batch( std::unordered_map hot_words, size_t num_results=1); +/* Flashlight Beam Search Decoder + * Parameters: + * probs: 2-D vector where each element is a vector of probabilities + * over alphabet of one time step. 
+ * time_dim: Number of timesteps. + * class_dim: Alphabet length (plus 1 for space character). + * alphabet: The alphabet. + * beam_size: The width of beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * hot_words: A map of hot-words and their corresponding boosts + * The hot-word is a string and the boost is a float. + * num_results: Number of beams to return. + * Return: + * A vector where each element is a pair of score and decoding result, + * in descending order. +*/ + +std::vector +flashlight_beam_search_decoder( + const double* probs, + int time_dim, + int class_dim, + const Alphabet& alphabet, + size_t beam_size, + double beam_threshold, + size_t cutoff_top_n, + std::shared_ptr ext_scorer, + FlashlightDecoderState::LMTokenType token_type, + const std::vector& lm_tokens, + FlashlightDecoderState::DecoderType decoder_type, + double silence_score, + bool merge_with_log_add, + FlashlightDecoderState::CriterionType criterion_type, + std::vector transitions, + size_t num_results); + +/* Flashlight Beam Search Decoder for batch data + * Parameters: + * probs: 3-D vector where each element is a 2-D vector that can be used + * by flashlight_beam_search_decoder(). + * alphabet: The alphabet. + * beam_size: The width of beam search. + * num_processes: Number of threads for beam search. + * cutoff_prob: Cutoff probability for pruning. + * cutoff_top_n: Cutoff number for pruning. + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * hot_words: A map of hot-words and their corresponding boosts + * The hot-word is a string and the boost is a float. 
+ * num_results: Number of beams to return. + * Return: + * A 2-D vector where each element is a vector of beam search decoding + * result for one audio sample. +*/ +std::vector> +flashlight_beam_search_decoder_batch( + const double* probs, + int batch_size, + int time_dim, + int class_dim, + const int* seq_lengths, + int seq_lengths_size, + const Alphabet& alphabet, + size_t beam_size, + double beam_threshold, + size_t cutoff_top_n, + std::shared_ptr ext_scorer, + FlashlightDecoderState::LMTokenType token_type, + const std::vector& lm_tokens, + FlashlightDecoderState::DecoderType decoder_type, + double silence_score, + bool merge_with_log_add, + FlashlightDecoderState::CriterionType criterion_type, + std::vector transitions, + size_t num_results, + size_t num_processes); + #endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/native_client/ctcdecode/output.h b/native_client/ctcdecode/output.h index bdfc8ee9..a9967d16 100644 --- a/native_client/ctcdecode/output.h +++ b/native_client/ctcdecode/output.h @@ -12,4 +12,12 @@ struct Output { std::vector timesteps; }; +struct FlashlightOutput { + double aggregate_score; + double acoustic_model_score; + double language_model_score; + std::vector words; + std::vector tokens; +}; + #endif // OUTPUT_H_ diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index eb059fdf..d5ca6cbc 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -17,16 +17,27 @@ #include #include -#include "lm/config.hh" -#include "lm/model.hh" -#include "lm/state.hh" -#include "util/string_piece.hh" +#include "kenlm/lm/config.hh" +#include "kenlm/lm/model.hh" +#include "kenlm/lm/state.hh" +#include "kenlm/lm/word_index.hh" +#include "kenlm/util/string_piece.hh" #include "decoder_utils.h" +using namespace fl::lib::text; + static const int32_t MAGIC = 'TRIE'; static const int32_t FILE_VERSION = 6; +Scorer::Scorer() +{ +} + +Scorer::~Scorer() +{ +} + int Scorer::init(const std::string& 
lm_path, const Alphabet& alphabet) @@ -347,3 +358,54 @@ void Scorer::fill_dictionary(const std::unordered_set& vocabulary) std::unique_ptr converted(new FstType(*new_dict)); this->dictionary = std::move(converted); } + +LMStatePtr +Scorer::start(bool startWithNothing) +{ + auto outState = std::make_shared(); + if (startWithNothing) { + language_model_->NullContextWrite(outState->ken()); + } else { + language_model_->BeginSentenceWrite(outState->ken()); + } + + return outState; +} + +std::pair +Scorer::score(const LMStatePtr& state, + const int usrTokenIdx) +{ + if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) { + throw std::runtime_error( + "[Scorer] Invalid user token index: " + std::to_string(usrTokenIdx)); + } + auto inState = std::static_pointer_cast(state); + auto outState = inState->child(usrTokenIdx); + float score = language_model_->BaseScore( + inState->ken(), usrToLmIdxMap_[usrTokenIdx], outState->ken()); + return std::make_pair(std::move(outState), score); +} + +std::pair +Scorer::finish(const LMStatePtr& state) +{ + auto inState = std::static_pointer_cast(state); + auto outState = inState->child(-1); + float score = language_model_->BaseScore( + inState->ken(), + language_model_->BaseVocabulary().EndSentence(), + outState->ken() + ); + return std::make_pair(std::move(outState), score); +} + +void +Scorer::load_words(const Dictionary& word_dict) +{ + const auto& vocab = language_model_->BaseVocabulary(); + usrToLmIdxMap_.resize(word_dict.indexSize()); + for (int i = 0; i < word_dict.indexSize(); ++i) { + usrToLmIdxMap_[i] = vocab.Index(word_dict.getEntry(i)); + } +} diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index 67ea96d3..eaf789db 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -7,9 +7,7 @@ #include #include -#include "lm/virtual_interface.hh" -#include "lm/word_index.hh" -#include "util/string_piece.hh" +#include "flashlight/lib/text/decoder/lm/KenLM.h" #include 
"path_trie.h" #include "alphabet.h" @@ -27,12 +25,12 @@ const std::string END_TOKEN = ""; * Scorer scorer(alpha, beta, "path_of_language_model"); * scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" }); */ -class Scorer { +class Scorer : public fl::lib::text::LM { public: using FstType = PathTrie::FstType; - Scorer() = default; - ~Scorer() = default; + Scorer(); + ~Scorer(); // disallow copying Scorer(const Scorer&) = delete; @@ -94,6 +92,29 @@ public: // pointer to the dictionary of FST std::unique_ptr dictionary; + // --------------- + // fl::lib::text::LM methods + + /* Initialize or reset language model state */ + fl::lib::text::LMStatePtr start(bool startWithNothing); + + /** + * Query the language model given input state and a specific token, return a + * new language model state and score. + */ + std::pair score( + const fl::lib::text::LMStatePtr& state, + const int usrTokenIdx); + + /* Query the language model and finish decoding. */ + std::pair finish(const fl::lib::text::LMStatePtr& state); + + // --------------- + // fl::lib::text helper + + // Must be called before use of this Scorer with Flashlight APIs. 
+ void load_words(const fl::lib::text::Dictionary& word_dict); + protected: // necessary setup after setting alphabet void setup_char_map(); diff --git a/native_client/ctcdecode/setup.py b/native_client/ctcdecode/setup.py index 6b987b7b..5131f365 100644 --- a/native_client/ctcdecode/setup.py +++ b/native_client/ctcdecode/setup.py @@ -70,7 +70,7 @@ third_party_build = "third_party.{}".format(archive_ext) ctc_decoder_build = "first_party.{}".format(archive_ext) -maybe_rebuild(KENLM_FILES, third_party_build, build_dir) +maybe_rebuild(THIRD_PARTY_FILES, third_party_build, build_dir) maybe_rebuild(CTC_DECODER_FILES, ctc_decoder_build, build_dir) decoder_module = Extension( diff --git a/native_client/ctcdecode/swigwrapper.i b/native_client/ctcdecode/swigwrapper.i index facc83eb..fa176342 100644 --- a/native_client/ctcdecode/swigwrapper.i +++ b/native_client/ctcdecode/swigwrapper.i @@ -20,9 +20,12 @@ import_array(); namespace std { %template(StringVector) vector; + %template(FloatVector) vector; %template(UnsignedIntVector) vector; %template(OutputVector) vector; %template(OutputVectorVector) vector>; + %template(FlashlightOutputVector) vector; + %template(FlashlightOutputVectorVector) vector>; %template(Map) unordered_map; } @@ -36,6 +39,7 @@ namespace std { %ignore Scorer::dictionary; +%include "third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h" %include "../alphabet.h" %include "output.h" %include "scorer.h" @@ -45,13 +49,5 @@ namespace std { %constant const char* __git_version__ = ds_git_version(); // Import only the error code enum definitions from coqui-stt.h -// We can't just do |%ignore "";| here because it affects this file globally (even -// files %include'd above). That causes SWIG to lose destructor information and -// leads to leaks of the wrapper objects. -// Instead we ignore functions and classes (structs), which are the only other -// things in coqui-stt.h. 
If we add some new construct to coqui-stt.h we need -// to update the ignore rules here to avoid exposing unwanted APIs in the decoder -// package. -%rename("$ignore", %$isfunction) ""; -%rename("$ignore", %$isclass) ""; +#define SWIG_ERRORS_ONLY %include "../coqui-stt.h" diff --git a/native_client/ctcdecode/third_party/flashlight/LICENSE b/native_client/ctcdecode/third_party/flashlight/LICENSE new file mode 100644 index 00000000..b96dcb04 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.cpp new file mode 100644 index 00000000..ebfbc2bd --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.cpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "flashlight/lib/common/String.h" + +#include + +#include +#include +#include +#include + +static constexpr const char* kSpaceChars = "\t\n\v\f\r "; + +namespace fl { +namespace lib { + +std::string trim(const std::string& str) { + auto i = str.find_first_not_of(kSpaceChars); + if (i == std::string::npos) { + return ""; + } + auto j = str.find_last_not_of(kSpaceChars); + if (j == std::string::npos || i > j) { + return ""; + } + return str.substr(i, j - i + 1); +} + +void replaceAll( + std::string& str, + const std::string& from, + const std::string& repl) { + if (from.empty()) { + return; + } + size_t pos = 0; + while ((pos = str.find(from, pos)) != std::string::npos) { + str.replace(pos, from.length(), repl); + pos += repl.length(); + } +} + +bool startsWith(const std::string& input, const std::string& pattern) { + return (input.find(pattern) == 0); +} + +bool endsWith(const std::string& input, const std::string& pattern) { + if (pattern.size() > input.size()) { + return false; + } + return std::equal(pattern.rbegin(), pattern.rend(), input.rbegin()); +} + +template +static std::vector splitImpl( + const Delim& delim, + std::string::size_type delimSize, + const std::string& input, + bool ignoreEmpty = false) { + std::vector result; + std::string::size_type i = 0; + while (true) { + auto j = Any ? 
input.find_first_of(delim, i) : input.find(delim, i); + if (j == std::string::npos) { + break; + } + if (!(ignoreEmpty && i == j)) { + result.emplace_back(input.begin() + i, input.begin() + j); + } + i = j + delimSize; + } + if (!(ignoreEmpty && i == input.size())) { + result.emplace_back(input.begin() + i, input.end()); + } + return result; +} + +std::vector +split(char delim, const std::string& input, bool ignoreEmpty) { + return splitImpl(delim, 1, input, ignoreEmpty); +} + +std::vector +split(const std::string& delim, const std::string& input, bool ignoreEmpty) { + if (delim.empty()) { + throw std::invalid_argument("delimiter is empty string"); + } + return splitImpl(delim, delim.size(), input, ignoreEmpty); +} + +std::vector splitOnAnyOf( + const std::string& delim, + const std::string& input, + bool ignoreEmpty) { + return splitImpl(delim, 1, input, ignoreEmpty); +} + +std::vector splitOnWhitespace( + const std::string& input, + bool ignoreEmpty) { + return splitOnAnyOf(kSpaceChars, input, ignoreEmpty); +} + +std::string join( + const std::string& delim, + const std::vector& vec) { + return join(delim, vec.begin(), vec.end()); +} +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h new file mode 100644 index 00000000..492c710a --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h @@ -0,0 +1,129 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fl { +namespace lib { + +// ============================ Types and Templates ============================ + +template +using DecayDereference = + typename std::decay())>::type; + +template +using EnableIfSame = typename std::enable_if::value>::type; + +// ================================== Functions +// ================================== + +std::string trim(const std::string& str); + +void replaceAll( + std::string& str, + const std::string& from, + const std::string& repl); + +bool startsWith(const std::string& input, const std::string& pattern); +bool endsWith(const std::string& input, const std::string& pattern); + +std::vector +split(char delim, const std::string& input, bool ignoreEmpty = false); + +std::vector split( + const std::string& delim, + const std::string& input, + bool ignoreEmpty = false); + +std::vector splitOnAnyOf( + const std::string& delim, + const std::string& input, + bool ignoreEmpty = false); + +std::vector splitOnWhitespace( + const std::string& input, + bool ignoreEmpty = false); + +/** + * Join a vector of `std::string` inserting `delim` in between. + */ +std::string join(const std::string& delim, const std::vector& vec); + +/** + * Join a range of `std::string` specified by iterators. + */ +template < + typename FwdIt, + typename = EnableIfSame, std::string>> +std::string join(const std::string& delim, FwdIt begin, FwdIt end) { + if (begin == end) { + return ""; + } + + size_t totalSize = begin->size(); + for (auto it = std::next(begin); it != end; ++it) { + totalSize += delim.size() + it->size(); + } + + std::string result; + result.reserve(totalSize); + + result.append(*begin); + for (auto it = std::next(begin); it != end; ++it) { + result.append(delim); + result.append(*it); + } + return result; +} + +/** + * Create an output string using a `printf`-style format string and arguments. 
+ * Safer than `sprintf` which is vulnerable to buffer overflow. + */ +template +std::string format(const char* fmt, Args&&... args) { + auto res = std::snprintf(nullptr, 0, fmt, std::forward(args)...); + if (res < 0) { + throw std::runtime_error(std::strerror(errno)); + } + std::string buf(res, '\0'); + // the size here is fine -- it's legal to write '\0' to buf[res] + auto res2 = std::snprintf(&buf[0], res + 1, fmt, std::forward(args)...); + if (res2 < 0) { + throw std::runtime_error(std::strerror(errno)); + } + + if (res2 != res) { + throw std::runtime_error( + "The size of the formated string is not equal to what it is expected."); + } + return buf; +} + +/** + * Dedup the elements in a vector. + */ +template +void dedup(std::vector& in) { + if (in.empty()) { + return; + } + auto it = std::unique(in.begin(), in.end()); + in.resize(std::distance(in.begin(), it)); +} +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp new file mode 100644 index 00000000..fd89ac58 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp @@ -0,0 +1,250 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "flashlight/lib/common/System.h" + +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#include +#else +#include +#endif + +#include "flashlight/lib/common/String.h" + +namespace fl { +namespace lib { + +size_t getProcessId() { +#ifdef _WIN32 + return GetCurrentProcessId(); +#else + return ::getpid(); +#endif +} + +size_t getThreadId() { +#ifdef _WIN32 + return GetCurrentThreadId(); +#else + return std::hash()(std::this_thread::get_id()); +#endif +} + +std::string pathSeperator() { +#ifdef _WIN32 + return "\\"; +#else + return "/"; +#endif +} + +std::string pathsConcat(const std::string& p1, const std::string& p2) { + if (!p1.empty() && p1[p1.length() - 1] != pathSeperator()[0]) { + return ( + trim(p1) + pathSeperator() + trim(p2)); // Need to add a path separator + } else { + return (trim(p1) + trim(p2)); + } +} + +namespace { + +/** + * @path contains directories separated by path separator. + * Returns a vector with the directores in the original order. Vector with a + * Special cases: a vector with a single entry containing the input is returned + * when path is one of the following special cases: empty, “.”, “..” and “/” + */ +std::vector getDirsOnPath(const std::string& path) { + const std::string trimPath = trim(path); + + if (trimPath.empty() || trimPath == pathSeperator() || trimPath == "." || + trimPath == "..") { + return {trimPath}; + } + const std::vector tokens = split(pathSeperator(), trimPath); + std::vector dirs; + for (const std::string& token : tokens) { + const std::string dir = trim(token); + if (!dir.empty()) { + dirs.push_back(dir); + } + } + return dirs; +} + +} // namespace + +std::string dirname(const std::string& path) { + std::vector dirsOnPath = getDirsOnPath(path); + if (dirsOnPath.size() < 2) { + return "."; + } else { + dirsOnPath.pop_back(); + const std::string root = + ((trim(path))[0] == pathSeperator()[0]) ? 
pathSeperator() : ""; + return root + join(pathSeperator(), dirsOnPath); + } +} + +std::string basename(const std::string& path) { + std::vector dirsOnPath = getDirsOnPath(path); + if (dirsOnPath.empty()) { + return ""; + } else { + return dirsOnPath.back(); + } +} + +bool dirExists(const std::string& path) { + struct stat info; + if (stat(path.c_str(), &info) != 0) { + return false; + } else if (info.st_mode & S_IFDIR) { + return true; + } else { + return false; + } +} + +void dirCreate(const std::string& path) { + if (dirExists(path)) { + return; + } + mode_t nMode = 0755; + int nError = 0; +#ifdef _WIN32 + nError = _mkdir(path.c_str()); +#else + nError = mkdir(path.c_str(), nMode); +#endif + if (nError != 0) { + throw std::runtime_error( + std::string() + "Unable to create directory - " + path); + } +} + +void dirCreateRecursive(const std::string& path) { + if (dirExists(path)) { + return; + } + std::vector dirsOnPath = getDirsOnPath(path); + std::string pathFromStart; + if (path[0] == pathSeperator()[0]) { + pathFromStart = pathSeperator(); + } + for (std::string& dir : dirsOnPath) { + if (pathFromStart.empty()) { + pathFromStart = dir; + } else { + pathFromStart = pathsConcat(pathFromStart, dir); + } + + if (!dirExists(pathFromStart)) { + dirCreate(pathFromStart); + } + } +} + +bool fileExists(const std::string& path) { + std::ifstream fs(path, std::ifstream::in); + return fs.good(); +} + +std::string getEnvVar( + const std::string& key, + const std::string& dflt /*= "" */) { + char* val = getenv(key.c_str()); + return val ? 
std::string(val) : dflt; +} + +std::string getCurrentDate() { + time_t now = time(nullptr); + struct tm tmbuf; + struct tm* tstruct; + tstruct = localtime_r(&now, &tmbuf); + + std::array buf; + strftime(buf.data(), buf.size(), "%Y-%m-%d", tstruct); + return std::string(buf.data()); +} + +std::string getCurrentTime() { + time_t now = time(nullptr); + struct tm tmbuf; + struct tm* tstruct; + tstruct = localtime_r(&now, &tmbuf); + + std::array buf; + strftime(buf.data(), buf.size(), "%X", tstruct); + return std::string(buf.data()); +} + +std::string getTmpPath(const std::string& filename) { + std::string tmpDir = "/tmp"; + auto getTmpDir = [&tmpDir](const std::string& env) { + char* dir = std::getenv(env.c_str()); + if (dir != nullptr) { + tmpDir = std::string(dir); + } + }; + getTmpDir("TMPDIR"); + getTmpDir("TEMP"); + getTmpDir("TMP"); + return tmpDir + "/fl_tmp_" + getEnvVar("USER", "unknown") + "_" + filename; +} + +std::vector getFileContent(const std::string& file) { + std::vector data; + std::ifstream in = createInputStream(file); + std::string str; + while (std::getline(in, str)) { + data.emplace_back(str); + } + in.close(); + return data; +} + +std::vector fileGlob(const std::string& pat) { + glob_t result; + glob(pat.c_str(), GLOB_TILDE, nullptr, &result); + std::vector ret; + for (unsigned int i = 0; i < result.gl_pathc; ++i) { + ret.push_back(std::string(result.gl_pathv[i])); + } + globfree(&result); + return ret; +} + +std::ifstream createInputStream(const std::string& filename) { + std::ifstream file(filename); + if (!file.is_open()) { + throw std::runtime_error("Failed to open file for reading: " + filename); + } + return file; +} + +std::ofstream createOutputStream( + const std::string& filename, + std::ios_base::openmode mode) { + std::ofstream file(filename, mode); + if (!file.is_open()) { + throw std::runtime_error("Failed to open file for writing: " + filename); + } + return file; +} + +} // namespace lib +} // namespace fl diff --git 
a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h new file mode 100644 index 00000000..c63ed1bb --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace fl { +namespace lib { + +size_t getProcessId(); + +size_t getThreadId(); + +std::string pathsConcat(const std::string& p1, const std::string& p2); + +std::string pathSeperator(); + +std::string dirname(const std::string& path); + +std::string basename(const std::string& path); + +bool dirExists(const std::string& path); + +void dirCreate(const std::string& path); + +void dirCreateRecursive(const std::string& path); + +bool fileExists(const std::string& path); + +std::string getEnvVar(const std::string& key, const std::string& dflt = ""); + +std::string getCurrentDate(); + +std::string getCurrentTime(); + +std::string getTmpPath(const std::string& filename); + +std::vector getFileContent(const std::string& file); + +std::vector fileGlob(const std::string& pat); + +std::ifstream createInputStream(const std::string& filename); + +std::ofstream createOutputStream( + const std::string& filename, + std::ios_base::openmode mode = std::ios_base::out); + +/** + * Calls `f(args...)` repeatedly, retrying if an exception is thrown. + * Supports sleeps between retries, with duration starting at `initial` and + * multiplying by `factor` each retry. At most `maxIters` calls are made. + */ +template +typename std::result_of::type retryWithBackoff( + std::chrono::duration initial, + double factor, + int64_t maxIters, + Fn&& f, + Args&&... 
args) { + if (!(initial.count() >= 0.0)) { + throw std::invalid_argument("retryWithBackoff: bad initial"); + } else if (!(factor >= 0.0)) { + throw std::invalid_argument("retryWithBackoff: bad factor"); + } else if (maxIters <= 0) { + throw std::invalid_argument("retryWithBackoff: bad maxIters"); + } + auto sleepSecs = initial.count(); + for (int64_t i = 0; i < maxIters; ++i) { + try { + return f(std::forward(args)...); + } catch (...) { + if (i >= maxIters - 1) { + throw; + } + } + if (sleepSecs > 0.0) { + /* sleep override */ + std::this_thread::sleep_for( + std::chrono::duration(std::min(1e7, sleepSecs))); + } + sleepSecs *= factor; + } + throw std::logic_error("retryWithBackoff: hit unreachable"); +} +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Decoder.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Decoder.h new file mode 100644 index 00000000..e495e282 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Decoder.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include "flashlight/lib/text/decoder/Utils.h" + +namespace fl { +namespace lib { +namespace text { + +enum class CriterionType { ASG = 0, CTC = 1, S2S = 2 }; + +/** + * Decoder support two typical use cases: + * Offline manner: + * decoder.decode(someData) [returns all hypothesis (transcription)] + * + * Online manner: + * decoder.decodeBegin() [called only at the beginning of the stream] + * while (stream) + * decoder.decodeStep(someData) [one or more calls] + * decoder.getBestHypothesis() [returns the best hypothesis (transcription)] + * decoder.prune() [prunes the hypothesis space] + * decoder.decodeEnd() [called only at the end of the stream] + * + * Note: function decoder.prune() deletes hypothesis up until time when called + * to supports online decoding. It will also add a offset to the scores in beam + * to avoid underflow/overflow. + * + */ +class Decoder { + public: + Decoder() = default; + virtual ~Decoder() = default; + + /* Initialize decoder before starting consume emissions */ + virtual void decodeBegin() {} + + /* Consume emissions in T x N chunks and increase the hypothesis space */ + virtual void decodeStep(const float* emissions, int T, int N) = 0; + + /* Finish up decoding after consuming all emissions */ + virtual void decodeEnd() {} + + /* Offline decode function, which consume all emissions at once */ + virtual std::vector + decode(const float* emissions, int T, int N) { + decodeBegin(); + decodeStep(emissions, T, N); + decodeEnd(); + return getAllFinalHypothesis(); + } + + /* Prune the hypothesis space */ + virtual void prune(int lookBack = 0) = 0; + + /* Get the number of decoded frame in buffer */ + virtual int nDecodedFramesInBuffer() const = 0; + + /* + * Get the best completed hypothesis which is `lookBack` frames ahead the last + * one in buffer. For lexicon requiredd LMs, completed hypothesis means no + * partial word appears at the end. 
 */
  virtual DecodeResult getBestHypothesis(int lookBack = 0) const = 0;

  /* Get all the final hypotheses */
  virtual std::vector<DecodeResult> getAllFinalHypothesis() const = 0;
};
} // namespace text
} // namespace lib
} // namespace fl

/* ==== file: flashlight/lib/text/decoder/LexiconDecoder.cpp ====
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cstdlib>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>

#include "flashlight/lib/text/decoder/LexiconDecoder.h"

namespace fl {
namespace lib {
namespace text {

// Reset the decoder to a single empty hypothesis rooted at the lexicon trie
// root, with a fresh LM start state. Must be called before decodeStep().
void LexiconDecoder::decodeBegin() {
  hyp_.clear();
  hyp_.emplace(0, std::vector<LexiconDecoderState>());

  /* note: the lm resets itself with :start() */
  hyp_[0].emplace_back(
      0.0, lm_->start(0), lexicon_->getRoot(), nullptr, sil_, -1);
  nDecodedFrames_ = 0;
  nPrunedFrames_ = 0;
}

// Consume a T x N chunk of emissions (T frames, N tokens per frame) and extend
// every hypothesis in the beam by one of: (1) a child trie node (new token of a
// word), (2) the same trie node (token repetition / silence), or (3) the blank
// token (CTC only).
void LexiconDecoder::decodeStep(const float* emissions, int T, int N) {
  int startFrame = nDecodedFrames_ - nPrunedFrames_;
  // Extend hyp_ buffer so frames [startFrame, startFrame + T + 1] exist.
  if (hyp_.size() < startFrame + T + 2) {
    for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
      hyp_.emplace(i, std::vector<LexiconDecoderState>());
    }
  }

  std::vector<size_t> idx(N);
  for (int t = 0; t < T; t++) {
    // Keep only the beamSizeToken highest-emission tokens of this frame;
    // partial_sort leaves them (sorted) in idx[0..beamSizeToken).
    std::iota(idx.begin(), idx.end(), 0);
    if (N > opt_.beamSizeToken) {
      std::partial_sort(
          idx.begin(),
          idx.begin() + opt_.beamSizeToken,
          idx.end(),
          [&t, &N, &emissions](const size_t& l, const size_t& r) {
            return emissions[t * N + l] > emissions[t * N + r];
          });
    }

    candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
    for (const LexiconDecoderState& prevHyp : hyp_[startFrame + t]) {
      const TrieNode* prevLex = prevHyp.lex;
      const int prevIdx = prevHyp.token;
      // maxScore is used as an admissible LM estimate for partial words; it is
      // subtracted back out once a full word is emitted.
      const float lexMaxScore =
          prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore;

      /* (1) Try children */
      for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
        int n = idx[r];
        auto iter = prevLex->children.find(n);
        if (iter == prevLex->children.end()) {
          continue;
        }
        const TrieNodePtr& lex = iter->second;
        double amScore = emissions[t * N + n];
        if (nDecodedFrames_ + t > 0 &&
            opt_.criterionType == CriterionType::ASG) {
          // ASG adds a learned token-transition score.
          amScore += transitions_[n * N + prevIdx];
        }
        double score = prevHyp.score + amScore;
        if (n == sil_) {
          score += opt_.silScore;
        }

        LMStatePtr lmState;
        double lmScore = 0.;

        if (isLmToken_) {
          // Token-level LM: score every consumed token directly.
          auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
          lmState = lmStateScorePair.first;
          lmScore = lmStateScorePair.second;
        }

        // We eat up a new token
        if (opt_.criterionType != CriterionType::CTC || prevHyp.prevBlank ||
            n != prevIdx) {
          if (!lex->children.empty()) {
            if (!isLmToken_) {
              // Word-level LM: use the trie's maxScore delta as an estimate
              // until a complete word is reached.
              lmState = prevHyp.lmState;
              lmScore = lex->maxScore - lexMaxScore;
            }
            candidatesAdd(
                candidates_,
                candidatesBestScore_,
                opt_.beamThreshold,
                score + opt_.lmWeight * lmScore,
                lmState,
                lex.get(),
                &prevHyp,
                n,
                -1,
                false, // prevBlank
                prevHyp.amScore + amScore,
                prevHyp.lmScore + lmScore);
          }
        }

        // If we got a true word
        for (auto label : lex->labels) {
          if (prevLex == lexicon_->getRoot() && prevHyp.token == n) {
            // This is to avoid a situation where, when there is a word with
            // single token spelling (e.g. X -> x) in the lexicon and token `x`
            // is predicted in several consecutive frames, multiple words `X`
            // would be emitted. This violates the property of CTC, where
            // there must be a blank token in between to predict 2 identical
            // tokens consecutively.
            continue;
          }

          if (!isLmToken_) {
            auto lmStateScorePair = lm_->score(prevHyp.lmState, label);
            lmState = lmStateScorePair.first;
            // Replace the partial-word estimate with the true word score.
            lmScore = lmStateScorePair.second - lexMaxScore;
          }
          candidatesAdd(
              candidates_,
              candidatesBestScore_,
              opt_.beamThreshold,
              score + opt_.lmWeight * lmScore + opt_.wordScore,
              lmState,
              lexicon_->getRoot(),
              &prevHyp,
              n,
              label,
              false, // prevBlank
              prevHyp.amScore + amScore,
              prevHyp.lmScore + lmScore);
        }

        // If we got an unknown word
        if (lex->labels.empty() && (opt_.unkScore > kNegativeInfinity)) {
          if (!isLmToken_) {
            auto lmStateScorePair = lm_->score(prevHyp.lmState, unk_);
            lmState = lmStateScorePair.first;
            lmScore = lmStateScorePair.second - lexMaxScore;
          }
          candidatesAdd(
              candidates_,
              candidatesBestScore_,
              opt_.beamThreshold,
              score + opt_.lmWeight * lmScore + opt_.unkScore,
              lmState,
              lexicon_->getRoot(),
              &prevHyp,
              n,
              unk_,
              false, // prevBlank
              prevHyp.amScore + amScore,
              prevHyp.lmScore + lmScore);
        }
      }

      /* (2) Try same lexicon node */
      if (opt_.criterionType != CriterionType::CTC || !prevHyp.prevBlank ||
          prevLex == lexicon_->getRoot()) {
        int n = prevLex == lexicon_->getRoot() ? sil_ : prevIdx;
        double amScore = emissions[t * N + n];
        if (nDecodedFrames_ + t > 0 &&
            opt_.criterionType == CriterionType::ASG) {
          amScore += transitions_[n * N + prevIdx];
        }
        double score = prevHyp.score + amScore;
        if (n == sil_) {
          score += opt_.silScore;
        }

        candidatesAdd(
            candidates_,
            candidatesBestScore_,
            opt_.beamThreshold,
            score,
            prevHyp.lmState,
            prevLex,
            &prevHyp,
            n,
            -1,
            false, // prevBlank
            prevHyp.amScore + amScore,
            prevHyp.lmScore);
      }

      /* (3) CTC only, try blank */
      if (opt_.criterionType == CriterionType::CTC) {
        int n = blank_;
        double amScore = emissions[t * N + n];
        candidatesAdd(
            candidates_,
            candidatesBestScore_,
            opt_.beamThreshold,
            prevHyp.score + amScore,
            prevHyp.lmState,
            prevLex,
            &prevHyp,
            n,
            -1,
            true, // prevBlank
            prevHyp.amScore + amScore,
            prevHyp.lmScore);
      }
      // finish proposing
    }

    // Merge/prune candidates into the beam for frame t + 1 and refresh the LM
    // cache for the surviving states.
    candidatesStore(
        candidates_,
        candidatePtrs_,
        hyp_[startFrame + t + 1],
        opt_.beamSize,
        candidatesBestScore_ - opt_.beamThreshold,
        opt_.logAdd,
        false);
    updateLMCache(lm_, hyp_[startFrame + t + 1]);
  }

  nDecodedFrames_ += T;
}

// Finalize decoding: apply the LM end-of-sentence score. If at least one
// hypothesis ends on a word boundary (trie root), only those are finalized;
// otherwise all hypotheses are kept.
void LexiconDecoder::decodeEnd() {
  candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
  bool hasNiceEnding = false;
  for (const LexiconDecoderState& prevHyp :
       hyp_[nDecodedFrames_ - nPrunedFrames_]) {
    if (prevHyp.lex == lexicon_->getRoot()) {
      hasNiceEnding = true;
      break;
    }
  }
  for (const LexiconDecoderState& prevHyp :
       hyp_[nDecodedFrames_ - nPrunedFrames_]) {
    const TrieNode* prevLex = prevHyp.lex;
    const LMStatePtr& prevLmState = prevHyp.lmState;

    if (!hasNiceEnding || prevHyp.lex == lexicon_->getRoot()) {
      auto lmStateScorePair = lm_->finish(prevLmState);
      auto lmScore = lmStateScorePair.second;
      candidatesAdd(
          candidates_,
          candidatesBestScore_,
          opt_.beamThreshold,
          prevHyp.score + opt_.lmWeight * lmScore,
          lmStateScorePair.first,
          prevLex,
          &prevHyp,
          sil_,
          -1,
          false, // prevBlank
          prevHyp.amScore,
          prevHyp.lmScore + lmScore);
    }
  }

  candidatesStore(
      candidates_,
      candidatePtrs_,
      hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
      opt_.beamSize,
      candidatesBestScore_ - opt_.beamThreshold,
      opt_.logAdd,
      true);
  ++nDecodedFrames_;
}

// Return every hypothesis in the final beam, backtracked into DecodeResults.
std::vector<DecodeResult> LexiconDecoder::getAllFinalHypothesis() const {
  int finalFrame = nDecodedFrames_ - nPrunedFrames_;
  if (finalFrame < 1) {
    return std::vector<DecodeResult>{};
  }

  return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}

// Return the best hypothesis `lookBack` frames behind the newest frame, or an
// empty result if not enough frames have been decoded yet.
DecodeResult LexiconDecoder::getBestHypothesis(int lookBack) const {
  if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
    return DecodeResult();
  }

  const LexiconDecoderState* bestNode = findBestAncestor(
      hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
  return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}

// Number of hypotheses currently in the beam for the newest frame.
int LexiconDecoder::nHypothesis() const {
  int finalFrame = nDecodedFrames_ - nPrunedFrames_;
  return hyp_.find(finalFrame)->second.size();
}

int LexiconDecoder::nDecodedFramesInBuffer() const {
  return nDecodedFrames_ - nPrunedFrames_ + 1;
}

// Drop history older than `lookBack` frames and renormalize beam scores, so
// the hypothesis buffer stays bounded during online decoding.
void LexiconDecoder::prune(int lookBack) {
  if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
    return; // Not enough decoded frames to prune
  }

  /* (1) Find the last emitted word in the best path */
  const LexiconDecoderState* bestNode = findBestAncestor(
      hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
  if (!bestNode) {
    return; // Not enough decoded frames to prune
  }

  int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
  if (startFrame < 1) {
    return; // Not enough decoded frames to prune
  }

  /* (2) Move things from back of hyp_ to front and normalize scores */
  pruneAndNormalize(hyp_, startFrame, lookBack);

  nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace fl
diff --git 
a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.h new file mode 100644 index 00000000..c4b35807 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconDecoder.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "flashlight/lib/text/decoder/Decoder.h" +#include "flashlight/lib/text/decoder/Trie.h" +#include "flashlight/lib/text/decoder/lm/LM.h" + +namespace fl { +namespace lib { +namespace text { + +struct LexiconDecoderOptions { + int beamSize; // Maximum number of hypothesis we hold after each step + int beamSizeToken; // Maximum number of tokens we consider at each step + double beamThreshold; // Threshold to prune hypothesis + double lmWeight; // Weight of lm + double wordScore; // Word insertion score + double unkScore; // Unknown word insertion score + double silScore; // Silence insertion score + bool logAdd; // If or not use logadd when merging hypothesis + CriterionType criterionType; // CTC or ASG +}; + +/** + * LexiconDecoderState stores information for each hypothesis in the beam. 
+ */ +struct LexiconDecoderState { + double score; // Accumulated total score so far + LMStatePtr lmState; // Language model state + const TrieNode* lex; // Trie node in the lexicon + const LexiconDecoderState* parent; // Parent hypothesis + int token; // Label of token + int word; // Label of word (-1 if incomplete) + bool prevBlank; // If previous hypothesis is blank (for CTC only) + + double amScore; // Accumulated AM score so far + double lmScore; // Accumulated LM score so far + + LexiconDecoderState( + const double score, + const LMStatePtr& lmState, + const TrieNode* lex, + const LexiconDecoderState* parent, + const int token, + const int word, + const bool prevBlank = false, + const double amScore = 0, + const double lmScore = 0) + : score(score), + lmState(lmState), + lex(lex), + parent(parent), + token(token), + word(word), + prevBlank(prevBlank), + amScore(amScore), + lmScore(lmScore) {} + + LexiconDecoderState() + : score(0.), + lmState(nullptr), + lex(nullptr), + parent(nullptr), + token(-1), + word(-1), + prevBlank(false), + amScore(0.), + lmScore(0.) {} + + int compareNoScoreStates(const LexiconDecoderState* node) const { + int lmCmp = lmState->compare(node->lmState); + if (lmCmp != 0) { + return lmCmp > 0 ? 1 : -1; + } else if (lex != node->lex) { + return lex > node->lex ? 1 : -1; + } else if (token != node->token) { + return token > node->token ? 1 : -1; + } else if (prevBlank != node->prevBlank) { + return prevBlank > node->prevBlank ? 
1 : -1; + } + return 0; + } + + int getWord() const { + return word; + } + + bool isComplete() const { + return !parent || parent->word >= 0; + } +}; + +/** + * Decoder implements a beam seach decoder that finds the word transcription + * W maximizing: + * + * AM(W) + lmWeight_ * log(P_{lm}(W)) + wordScore_ * |W_known| + unkScore_ * + * |W_unknown| + silScore_ * |{i| pi_i = }| + * + * where P_{lm}(W) is the language model score, pi_i is the value for the i-th + * frame in the path leading to W and AM(W) is the (unnormalized) acoustic model + * score of the transcription W. Note that the lexicon is used to limit the + * search space and all candidate words are generated from it if unkScore is + * -inf, otherwise will be generated for OOVs. + */ +class LexiconDecoder : public Decoder { + public: + LexiconDecoder( + LexiconDecoderOptions opt, + const TriePtr& lexicon, + const LMPtr& lm, + const int sil, + const int blank, + const int unk, + const std::vector& transitions, + const bool isLmToken) + : opt_(std::move(opt)), + lexicon_(lexicon), + lm_(lm), + sil_(sil), + blank_(blank), + unk_(unk), + transitions_(transitions), + isLmToken_(isLmToken) {} + + void decodeBegin() override; + + void decodeStep(const float* emissions, int T, int N) override; + + void decodeEnd() override; + + int nHypothesis() const; + + void prune(int lookBack = 0) override; + + int nDecodedFramesInBuffer() const override; + + DecodeResult getBestHypothesis(int lookBack = 0) const override; + + std::vector getAllFinalHypothesis() const override; + + protected: + LexiconDecoderOptions opt_; + // Lexicon trie to restrict beam-search decoder + TriePtr lexicon_; + LMPtr lm_; + // Index of silence label + int sil_; + // Index of blank label (for CTC) + int blank_; + // Index of unknown word + int unk_; + // matrix of transitions (for ASG criterion) + std::vector transitions_; + // if LM is token-level (operates on the same level as acoustic model) + // or it is word-level (in case of false) + bool 
isLmToken_; + + // All the hypothesis new candidates (can be larger than beamsize) proposed + // based on the ones from previous frame + std::vector candidates_; + + // This vector is designed for efficient sorting and merging the candidates_, + // so instead of moving around objects, we only need to sort pointers + std::vector candidatePtrs_; + + // Best candidate score of current frame + double candidatesBestScore_; + + // Vector of hypothesis for all the frames so far + std::unordered_map> hyp_; + + // These 2 variables are used for online decoding, for hypothesis pruning + int nDecodedFrames_; // Total number of decoded frames. + int nPrunedFrames_; // Total number of pruned frames from hyp_. +}; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp new file mode 100644 index 00000000..01becd47 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp @@ -0,0 +1,207 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include + +#include "flashlight/lib/text/decoder/LexiconFreeDecoder.h" + +namespace fl { +namespace lib { +namespace text { + +void LexiconFreeDecoder::decodeBegin() { + hyp_.clear(); + hyp_.emplace(0, std::vector()); + + /* note: the lm reset itself with :start() */ + hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, sil_); + nDecodedFrames_ = 0; + nPrunedFrames_ = 0; +} + +void LexiconFreeDecoder::decodeStep(const float* emissions, int T, int N) { + int startFrame = nDecodedFrames_ - nPrunedFrames_; + // Extend hyp_ buffer + if (hyp_.size() < startFrame + T + 2) { + for (int i = hyp_.size(); i < startFrame + T + 2; i++) { + hyp_.emplace(i, std::vector()); + } + } + + std::vector idx(N); + // Looping over all the frames + for (int t = 0; t < T; t++) { + std::iota(idx.begin(), idx.end(), 0); + if (N > opt_.beamSizeToken) { + std::partial_sort( + idx.begin(), + idx.begin() + opt_.beamSizeToken, + idx.end(), + [&t, &N, &emissions](const size_t& l, const size_t& r) { + return emissions[t * N + l] > emissions[t * N + r]; + }); + } + + candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); + for (const LexiconFreeDecoderState& prevHyp : hyp_[startFrame + t]) { + const int prevIdx = prevHyp.token; + + for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) { + int n = idx[r]; + double amScore = emissions[t * N + n]; + if (nDecodedFrames_ + t > 0 && + opt_.criterionType == CriterionType::ASG) { + amScore += transitions_[n * N + prevIdx]; + } + double score = prevHyp.score + emissions[t * N + n]; + if (n == sil_) { + score += opt_.silScore; + } + + if ((opt_.criterionType == CriterionType::ASG && n != prevIdx) || + (opt_.criterionType == CriterionType::CTC && n != blank_ && + (n != prevIdx || prevHyp.prevBlank))) { + auto lmStateScorePair = lm_->score(prevHyp.lmState, n); + auto lmScore = lmStateScorePair.second; + + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + score + 
opt_.lmWeight * lmScore, + lmStateScorePair.first, + &prevHyp, + n, + false, // prevBlank + prevHyp.amScore + amScore, + prevHyp.lmScore + lmScore); + } else if (opt_.criterionType == CriterionType::CTC && n == blank_) { + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + score, + prevHyp.lmState, + &prevHyp, + n, + true, // prevBlank + prevHyp.amScore + amScore, + prevHyp.lmScore); + } else { + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + score, + prevHyp.lmState, + &prevHyp, + n, + false, // prevBlank + prevHyp.amScore + amScore, + prevHyp.lmScore); + } + } + } + + candidatesStore( + candidates_, + candidatePtrs_, + hyp_[startFrame + t + 1], + opt_.beamSize, + candidatesBestScore_ - opt_.beamThreshold, + opt_.logAdd, + false); + updateLMCache(lm_, hyp_[startFrame + t + 1]); + } + nDecodedFrames_ += T; +} + +void LexiconFreeDecoder::decodeEnd() { + candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); + for (const LexiconFreeDecoderState& prevHyp : + hyp_[nDecodedFrames_ - nPrunedFrames_]) { + const LMStatePtr& prevLmState = prevHyp.lmState; + + auto lmStateScorePair = lm_->finish(prevLmState); + auto lmScore = lmStateScorePair.second; + + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score + opt_.lmWeight * lmScore, + lmStateScorePair.first, + &prevHyp, + sil_, + false, // prevBlank + prevHyp.amScore, + prevHyp.lmScore + lmScore); + } + + candidatesStore( + candidates_, + candidatePtrs_, + hyp_[nDecodedFrames_ - nPrunedFrames_ + 1], + opt_.beamSize, + candidatesBestScore_ - opt_.beamThreshold, + opt_.logAdd, + true); + ++nDecodedFrames_; +} + +std::vector LexiconFreeDecoder::getAllFinalHypothesis() const { + int finalFrame = nDecodedFrames_ - nPrunedFrames_; + return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame); +} + +DecodeResult LexiconFreeDecoder::getBestHypothesis(int lookBack) const { + int finalFrame = nDecodedFrames_ - 
nPrunedFrames_; + const LexiconFreeDecoderState* bestNode = + findBestAncestor(hyp_.find(finalFrame)->second, lookBack); + + return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack); +} + +int LexiconFreeDecoder::nHypothesis() const { + int finalFrame = nDecodedFrames_ - nPrunedFrames_; + return hyp_.find(finalFrame)->second.size(); +} + +int LexiconFreeDecoder::nDecodedFramesInBuffer() const { + return nDecodedFrames_ - nPrunedFrames_ + 1; +} + +void LexiconFreeDecoder::prune(int lookBack) { + if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) { + return; // Not enough decoded frames to prune + } + + /* (1) Find the last emitted word in the best path */ + int finalFrame = nDecodedFrames_ - nPrunedFrames_; + const LexiconFreeDecoderState* bestNode = + findBestAncestor(hyp_.find(finalFrame)->second, lookBack); + if (!bestNode) { + return; // Not enough decoded frames to prune + } + + int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack; + if (startFrame < 1) { + return; // Not enough decoded frames to prune + } + + /* (2) Move things from back of hyp_ to front and normalize scores */ + pruneAndNormalize(hyp_, startFrame, lookBack); + + nPrunedFrames_ = nDecodedFrames_ - lookBack; +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.h new file mode 100644 index 00000000..62985812 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeDecoder.h @@ -0,0 +1,160 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
 */

#pragma once

#include <unordered_map>

#include "flashlight/lib/text/decoder/Decoder.h"
#include "flashlight/lib/text/decoder/lm/LM.h"

namespace fl {
namespace lib {
namespace text {

struct LexiconFreeDecoderOptions {
  int beamSize; // Maximum number of hypothesis we hold after each step
  int beamSizeToken; // Maximum number of tokens we consider at each step
  double beamThreshold; // Threshold to prune hypothesis
  double lmWeight; // Weight of lm
  double silScore; // Silence insertion score
  bool logAdd; // Whether to use logadd when merging hypotheses
  CriterionType criterionType; // CTC or ASG
};

/**
 * LexiconFreeDecoderState stores information for each hypothesis in the beam.
 */
struct LexiconFreeDecoderState {
  double score; // Accumulated total score so far
  LMStatePtr lmState; // Language model state
  const LexiconFreeDecoderState* parent; // Parent hypothesis
  int token; // Label of token
  bool prevBlank; // If previous hypothesis is blank (for CTC only)

  double amScore; // Accumulated AM score so far
  double lmScore; // Accumulated LM score so far

  LexiconFreeDecoderState(
      const double score,
      const LMStatePtr& lmState,
      const LexiconFreeDecoderState* parent,
      const int token,
      const bool prevBlank = false,
      const double amScore = 0,
      const double lmScore = 0)
      : score(score),
        lmState(lmState),
        parent(parent),
        token(token),
        prevBlank(prevBlank),
        amScore(amScore),
        lmScore(lmScore) {}

  LexiconFreeDecoderState()
      : score(0),
        lmState(nullptr),
        parent(nullptr),
        token(-1),
        prevBlank(false),
        amScore(0.),
        lmScore(0.) {}

  // Three-way ordering of everything except the scores; used by candidate
  // merging to detect hypotheses that should be combined.
  int compareNoScoreStates(const LexiconFreeDecoderState* node) const {
    int lmCmp = lmState->compare(node->lmState);
    if (lmCmp != 0) {
      return lmCmp > 0 ? 1 : -1;
    } else if (token != node->token) {
      return token > node->token ? 1 : -1;
    } else if (prevBlank != node->prevBlank) {
      return prevBlank > node->prevBlank ? 1 : -1;
    }
    return 0;
  }

  // Lexicon-free decoding emits no word labels.
  int getWord() const {
    return -1;
  }

  // Without a lexicon there are no partial words, so every hypothesis is
  // complete.
  bool isComplete() const {
    return true;
  }
};

/**
 * Decoder implements a beam search decoder that finds the word transcription
 * W maximizing:
 *
 * AM(W) + lmWeight_ * log(P_{lm}(W)) + silScore_ * |{i| pi_i = }|
 *
 * where P_{lm}(W) is the language model score, pi_i is the value for the i-th
 * frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
 * score of the transcription W. We are allowed to generate words from all the
 * possible combination of tokens.
 */
class LexiconFreeDecoder : public Decoder {
 public:
  LexiconFreeDecoder(
      LexiconFreeDecoderOptions opt,
      const LMPtr& lm,
      const int sil,
      const int blank,
      const std::vector<float>& transitions)
      : opt_(std::move(opt)),
        lm_(lm),
        transitions_(transitions),
        sil_(sil),
        blank_(blank) {}

  void decodeBegin() override;

  void decodeStep(const float* emissions, int T, int N) override;

  void decodeEnd() override;

  int nHypothesis() const;

  void prune(int lookBack = 0) override;

  int nDecodedFramesInBuffer() const override;

  DecodeResult getBestHypothesis(int lookBack = 0) const override;

  std::vector<DecodeResult> getAllFinalHypothesis() const override;

 protected:
  LexiconFreeDecoderOptions opt_;
  LMPtr lm_;
  // matrix of transitions (for ASG criterion)
  std::vector<float> transitions_;

  // All the hypothesis new candidates (can be larger than beamsize) proposed
  // based on the ones from previous frame
  std::vector<LexiconFreeDecoderState> candidates_;

  // This vector is designed for efficient sorting and merging the candidates_,
  // so instead of moving around objects, we only need to sort pointers
  std::vector<LexiconFreeDecoderState*> candidatePtrs_;

  // Best candidate score of current frame
  double candidatesBestScore_;

  // Index of silence label
  int sil_;

  // Index of blank label (for CTC)
  int blank_;

  // Vector of hypothesis for all the frames so far
  std::unordered_map<int, std::vector<LexiconFreeDecoderState>> hyp_;

  // These 2 variables are used for online decoding, for 
hypothesis pruning + int nDecodedFrames_; // Total number of decoded frames. + int nPrunedFrames_; // Total number of pruned frames from hyp_. +}; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp new file mode 100644 index 00000000..0cda9688 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include "flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h" + +namespace fl { +namespace lib { +namespace text { + +void LexiconFreeSeq2SeqDecoder::decodeStep( + const float* emissions, + int T, + int N) { + // Extend hyp_ buffer + if (hyp_.size() < maxOutputLength_ + 2) { + for (int i = hyp_.size(); i < maxOutputLength_ + 2; i++) { + hyp_.emplace(i, std::vector()); + } + } + + // Start from here. 
+ hyp_[0].clear(); + hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, -1, nullptr); + + // Decode frame by frame + int t = 0; + for (; t < maxOutputLength_; t++) { + candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); + + // Batch forwarding + rawY_.clear(); + rawPrevStates_.clear(); + for (const LexiconFreeSeq2SeqDecoderState& prevHyp : hyp_[t]) { + const AMStatePtr& prevState = prevHyp.amState; + if (prevHyp.token == eos_) { + continue; + } + rawY_.push_back(prevHyp.token); + rawPrevStates_.push_back(prevState); + } + if (rawY_.size() == 0) { + break; + } + + std::vector> amScores; + std::vector outStates; + + std::tie(amScores, outStates) = + amUpdateFunc_(emissions, N, T, rawY_, rawPrevStates_, t); + + std::vector idx(amScores.back().size()); + + // Generate new hypothesis + for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) { + const LexiconFreeSeq2SeqDecoderState& prevHyp = hyp_[t][hypo]; + // Change nothing for completed hypothesis + if (prevHyp.token == eos_) { + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score, + prevHyp.lmState, + &prevHyp, + eos_, + nullptr, + prevHyp.amScore, + prevHyp.lmScore); + continue; + } + + const AMStatePtr& outState = outStates[validHypo]; + if (!outState) { + validHypo++; + continue; + } + + std::iota(idx.begin(), idx.end(), 0); + if (amScores[validHypo].size() > opt_.beamSizeToken) { + std::partial_sort( + idx.begin(), + idx.begin() + opt_.beamSizeToken, + idx.end(), + [&amScores, &validHypo](const size_t& l, const size_t& r) { + return amScores[validHypo][l] > amScores[validHypo][r]; + }); + } + + for (int r = 0; + r < std::min(amScores[validHypo].size(), (size_t)opt_.beamSizeToken); + r++) { + int n = idx[r]; + double amScore = amScores[validHypo][n]; + + if (n == eos_) { /* (1) Try eos */ + auto lmStateScorePair = lm_->finish(prevHyp.lmState); + auto lmScore = lmStateScorePair.second; + + candidatesAdd( + candidates_, + candidatesBestScore_, + 
opt_.beamThreshold, + prevHyp.score + amScore + opt_.eosScore + opt_.lmWeight * lmScore, + lmStateScorePair.first, + &prevHyp, + n, + nullptr, + prevHyp.amScore + amScore, + prevHyp.lmScore + lmScore); + } else { /* (2) Try normal token */ + auto lmStateScorePair = lm_->score(prevHyp.lmState, n); + auto lmScore = lmStateScorePair.second; + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score + amScore + opt_.lmWeight * lmScore, + lmStateScorePair.first, + &prevHyp, + n, + outState, + prevHyp.amScore + amScore, + prevHyp.lmScore + lmScore); + } + } + validHypo++; + } + candidatesStore( + candidates_, + candidatePtrs_, + hyp_[t + 1], + opt_.beamSize, + candidatesBestScore_ - opt_.beamThreshold, + opt_.logAdd, + true); + updateLMCache(lm_, hyp_[t + 1]); + } // End of decoding + + while (t > 0 && hyp_[t].empty()) { + --t; + } + hyp_[maxOutputLength_ + 1].resize(hyp_[t].size()); + for (int i = 0; i < hyp_[t].size(); i++) { + hyp_[maxOutputLength_ + 1][i] = std::move(hyp_[t][i]); + } +} + +std::vector LexiconFreeSeq2SeqDecoder::getAllFinalHypothesis() + const { + return getAllHypothesis(hyp_.find(maxOutputLength_ + 1)->second, hyp_.size()); +} + +DecodeResult LexiconFreeSeq2SeqDecoder::getBestHypothesis( + int /* unused */) const { + return getHypothesis( + hyp_.find(maxOutputLength_ + 1)->second.data(), hyp_.size()); +} + +void LexiconFreeSeq2SeqDecoder::prune(int /* unused */) { + return; +} + +int LexiconFreeSeq2SeqDecoder::nDecodedFramesInBuffer() const { + /* unused function */ + return -1; +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h new file mode 100644 index 00000000..c10c1deb --- /dev/null +++ 
b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.h @@ -0,0 +1,141 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include "flashlight/lib/text/decoder/Decoder.h" +#include "flashlight/lib/text/decoder/lm/LM.h" + +namespace fl { +namespace lib { +namespace text { + +using AMStatePtr = std::shared_ptr; +using AMUpdateFunc = std::function< + std::pair>, std::vector>( + const float*, + const int, + const int, + const std::vector&, + const std::vector&, + int&)>; + +struct LexiconFreeSeq2SeqDecoderOptions { + int beamSize; // Maximum number of hypothesis we hold after each step + int beamSizeToken; // Maximum number of tokens we consider at each step + double beamThreshold; // Threshold to prune hypothesis + double lmWeight; // Weight of lm + double eosScore; // Score for inserting an EOS + bool logAdd; // If or not use logadd when merging hypothesis +}; + +/** + * LexiconFreeSeq2SeqDecoderState stores information for each hypothesis in the + * beam. 
+ */ +struct LexiconFreeSeq2SeqDecoderState { + double score; // Accumulated total score so far + LMStatePtr lmState; // Language model state + const LexiconFreeSeq2SeqDecoderState* parent; // Parent hypothesis + int token; // Label of token + AMStatePtr amState; // Acoustic model state + + double amScore; // Accumulated AM score so far + double lmScore; // Accumulated LM score so far + + LexiconFreeSeq2SeqDecoderState( + const double score, + const LMStatePtr& lmState, + const LexiconFreeSeq2SeqDecoderState* parent, + const int token, + const AMStatePtr& amState = nullptr, + const double amScore = 0, + const double lmScore = 0) + : score(score), + lmState(lmState), + parent(parent), + token(token), + amState(amState), + amScore(amScore), + lmScore(lmScore) {} + + LexiconFreeSeq2SeqDecoderState() + : score(0), + lmState(nullptr), + parent(nullptr), + token(-1), + amState(nullptr), + amScore(0.), + lmScore(0.) {} + + int compareNoScoreStates(const LexiconFreeSeq2SeqDecoderState* node) const { + return lmState->compare(node->lmState); + } + + int getWord() const { + return -1; + } +}; + +/** + * Decoder implements a beam seach decoder that finds the token transcription + * W maximizing: + * + * AM(W) + lmWeight_ * log(P_{lm}(W)) + eosScore_ * |W_last == EOS| + * + * where P_{lm}(W) is the language model score. The sequence of tokens is not + * constrained by a lexicon, and thus the language model must operate at + * token-level. + * + * TODO: Doesn't support online decoding now. 
+ * + */ +class LexiconFreeSeq2SeqDecoder : public Decoder { + public: + LexiconFreeSeq2SeqDecoder( + LexiconFreeSeq2SeqDecoderOptions opt, + const LMPtr& lm, + const int eos, + AMUpdateFunc amUpdateFunc, + const int maxOutputLength) + : opt_(std::move(opt)), + lm_(lm), + eos_(eos), + amUpdateFunc_(amUpdateFunc), + maxOutputLength_(maxOutputLength) {} + + void decodeStep(const float* emissions, int T, int N) override; + + void prune(int lookBack = 0) override; + + int nDecodedFramesInBuffer() const override; + + DecodeResult getBestHypothesis(int lookBack = 0) const override; + + std::vector getAllFinalHypothesis() const override; + + protected: + LexiconFreeSeq2SeqDecoderOptions opt_; + LMPtr lm_; + int eos_; + AMUpdateFunc amUpdateFunc_; + std::vector rawY_; + std::vector rawPrevStates_; + int maxOutputLength_; + + std::vector candidates_; + std::vector candidatePtrs_; + double candidatesBestScore_; + + std::unordered_map> hyp_; +}; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.cpp new file mode 100644 index 00000000..fa332d24 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.cpp @@ -0,0 +1,243 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include "flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.h" + +namespace fl { +namespace lib { +namespace text { + +void LexiconSeq2SeqDecoder::decodeStep(const float* emissions, int T, int N) { + // Extend hyp_ buffer + if (hyp_.size() < maxOutputLength_ + 2) { + for (int i = hyp_.size(); i < maxOutputLength_ + 2; i++) { + hyp_.emplace(i, std::vector()); + } + } + + // Start from here. + hyp_[0].clear(); + hyp_[0].emplace_back( + 0.0, lm_->start(0), lexicon_->getRoot(), nullptr, -1, -1, nullptr); + + auto compare = [](const LexiconSeq2SeqDecoderState& n1, + const LexiconSeq2SeqDecoderState& n2) { + return n1.score > n2.score; + }; + + // Decode frame by frame + int t = 0; + for (; t < maxOutputLength_; t++) { + candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_); + + // Batch forwarding + rawY_.clear(); + rawPrevStates_.clear(); + for (const LexiconSeq2SeqDecoderState& prevHyp : hyp_[t]) { + const AMStatePtr& prevState = prevHyp.amState; + if (prevHyp.token == eos_) { + continue; + } + rawY_.push_back(prevHyp.token); + rawPrevStates_.push_back(prevState); + } + if (rawY_.size() == 0) { + break; + } + + std::vector> amScores; + std::vector outStates; + + std::tie(amScores, outStates) = + amUpdateFunc_(emissions, N, T, rawY_, rawPrevStates_, t); + + std::vector idx(amScores.back().size()); + + // Generate new hypothesis + for (int hypo = 0, validHypo = 0; hypo < hyp_[t].size(); hypo++) { + const LexiconSeq2SeqDecoderState& prevHyp = hyp_[t][hypo]; + // Change nothing for completed hypothesis + if (prevHyp.token == eos_) { + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score, + prevHyp.lmState, + prevHyp.lex, + &prevHyp, + eos_, + -1, + nullptr, + prevHyp.amScore, + prevHyp.lmScore); + continue; + } + + const AMStatePtr& outState = outStates[validHypo]; + if (!outState) { + validHypo++; + continue; + } + + const TrieNode* prevLex 
= prevHyp.lex; + const float lexMaxScore = + prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore; + + std::iota(idx.begin(), idx.end(), 0); + if (amScores[validHypo].size() > opt_.beamSizeToken) { + std::partial_sort( + idx.begin(), + idx.begin() + opt_.beamSizeToken, + idx.end(), + [&amScores, &validHypo](const size_t& l, const size_t& r) { + return amScores[validHypo][l] > amScores[validHypo][r]; + }); + } + + for (int r = 0; + r < std::min(amScores[validHypo].size(), (size_t)opt_.beamSizeToken); + r++) { + int n = idx[r]; + double amScore = amScores[validHypo][n]; + + /* (1) Try eos */ + if (n == eos_ && (prevLex == lexicon_->getRoot())) { + auto lmStateScorePair = lm_->finish(prevHyp.lmState); + LMStatePtr lmState = lmStateScorePair.first; + double lmScore; + if (isLmToken_) { + lmScore = lmStateScorePair.second; + } else { + lmScore = lmStateScorePair.second - lexMaxScore; + } + + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score + amScore + opt_.eosScore + opt_.lmWeight * lmScore, + lmState, + lexicon_->getRoot(), + &prevHyp, + n, + -1, + nullptr, + prevHyp.amScore + amScore, + prevHyp.lmScore + lmScore); + } + + /* (2) Try normal token */ + if (n != eos_) { + auto searchLex = prevLex->children.find(n); + if (searchLex != prevLex->children.end()) { + auto lex = searchLex->second; + LMStatePtr lmState; + double lmScore; + if (isLmToken_) { + auto lmStateScorePair = lm_->score(prevHyp.lmState, n); + lmState = lmStateScorePair.first; + lmScore = lmStateScorePair.second; + } else { + // smearing + lmState = prevHyp.lmState; + lmScore = lex->maxScore - lexMaxScore; + } + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score + amScore + opt_.lmWeight * lmScore, + lmState, + lex.get(), + &prevHyp, + n, + -1, + outState, + prevHyp.amScore + amScore, + prevHyp.lmScore + lmScore); + + // If we got a true word + if (lex->labels.size() > 0) { + for (auto word : lex->labels) { + if 
(!isLmToken_) { + auto lmStateScorePair = lm_->score(prevHyp.lmState, word); + lmState = lmStateScorePair.first; + lmScore = lmStateScorePair.second - lexMaxScore; + } + candidatesAdd( + candidates_, + candidatesBestScore_, + opt_.beamThreshold, + prevHyp.score + amScore + opt_.wordScore + + opt_.lmWeight * lmScore, + lmState, + lexicon_->getRoot(), + &prevHyp, + n, + word, + outState, + prevHyp.amScore + amScore, + prevHyp.lmScore + lmScore); + if (isLmToken_) { + break; + } + } + } + } + } + } + validHypo++; + } + candidatesStore( + candidates_, + candidatePtrs_, + hyp_[t + 1], + opt_.beamSize, + candidatesBestScore_ - opt_.beamThreshold, + opt_.logAdd, + true); + updateLMCache(lm_, hyp_[t + 1]); + } // End of decoding + + while (t > 0 && hyp_[t].empty()) { + --t; + } + hyp_[maxOutputLength_ + 1].resize(hyp_[t].size()); + for (int i = 0; i < hyp_[t].size(); i++) { + hyp_[maxOutputLength_ + 1][i] = std::move(hyp_[t][i]); + } +} + +std::vector LexiconSeq2SeqDecoder::getAllFinalHypothesis() const { + return getAllHypothesis(hyp_.find(maxOutputLength_ + 1)->second, hyp_.size()); +} + +DecodeResult LexiconSeq2SeqDecoder::getBestHypothesis(int /* unused */) const { + return getHypothesis( + hyp_.find(maxOutputLength_ + 1)->second.data(), hyp_.size()); +} + +void LexiconSeq2SeqDecoder::prune(int /* unused */) { + return; +} + +int LexiconSeq2SeqDecoder::nDecodedFramesInBuffer() const { + /* unused function */ + return -1; +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.h new file mode 100644 index 00000000..3e8e2d8f --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.h @@ -0,0 +1,165 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include "flashlight/lib/text/decoder/Decoder.h" +#include "flashlight/lib/text/decoder/Trie.h" +#include "flashlight/lib/text/decoder/lm/LM.h" + +namespace fl { +namespace lib { +namespace text { + +using AMStatePtr = std::shared_ptr; +using AMUpdateFunc = std::function< + std::pair>, std::vector>( + const float*, + const int, + const int, + const std::vector&, + const std::vector&, + int&)>; + +struct LexiconSeq2SeqDecoderOptions { + int beamSize; // Maximum number of hypothesis we hold after each step + int beamSizeToken; // Maximum number of tokens we consider at each step + double beamThreshold; // Threshold to prune hypothesis + double lmWeight; // Weight of lm + double wordScore; // Word insertion score + double eosScore; // Score for inserting an EOS + bool logAdd; // If or not use logadd when merging hypothesis +}; + +/** + * LexiconSeq2SeqDecoderState stores information for each hypothesis in the + * beam. 
+ */ +struct LexiconSeq2SeqDecoderState { + double score; // Accumulated total score so far + LMStatePtr lmState; // Language model state + const TrieNode* lex; + const LexiconSeq2SeqDecoderState* parent; // Parent hypothesis + int token; // Label of token + int word; + AMStatePtr amState; // Acoustic model state + + double amScore; // Accumulated AM score so far + double lmScore; // Accumulated LM score so far + + LexiconSeq2SeqDecoderState( + const double score, + const LMStatePtr& lmState, + const TrieNode* lex, + const LexiconSeq2SeqDecoderState* parent, + const int token, + const int word, + const AMStatePtr& amState, + const double amScore = 0, + const double lmScore = 0) + : score(score), + lmState(lmState), + lex(lex), + parent(parent), + token(token), + word(word), + amState(amState), + amScore(amScore), + lmScore(lmScore) {} + + LexiconSeq2SeqDecoderState() + : score(0), + lmState(nullptr), + lex(nullptr), + parent(nullptr), + token(-1), + word(-1), + amState(nullptr), + amScore(0.), + lmScore(0.) {} + + int compareNoScoreStates(const LexiconSeq2SeqDecoderState* node) const { + int lmCmp = lmState->compare(node->lmState); + if (lmCmp != 0) { + return lmCmp > 0 ? 1 : -1; + } else if (lex != node->lex) { + return lex > node->lex ? 1 : -1; + } else if (token != node->token) { + return token > node->token ? 1 : -1; + } + return 0; + } + + int getWord() const { + return word; + } +}; + +/** + * Decoder implements a beam seach decoder that finds the token transcription + * W maximizing: + * + * AM(W) + lmWeight_ * log(P_{lm}(W)) + eosScore_ * |W_last == EOS| + * + * where P_{lm}(W) is the language model score. The transcription W is + * constrained by a lexicon. The language model may operate at word-level + * (isLmToken=false) or token-level (isLmToken=true). + * + * TODO: Doesn't support online decoding now. 
+ * + */ +class LexiconSeq2SeqDecoder : public Decoder { + public: + LexiconSeq2SeqDecoder( + LexiconSeq2SeqDecoderOptions opt, + const TriePtr& lexicon, + const LMPtr& lm, + const int eos, + AMUpdateFunc amUpdateFunc, + const int maxOutputLength, + const bool isLmToken) + : opt_(std::move(opt)), + lm_(lm), + lexicon_(lexicon), + eos_(eos), + amUpdateFunc_(amUpdateFunc), + maxOutputLength_(maxOutputLength), + isLmToken_(isLmToken) {} + + void decodeStep(const float* emissions, int T, int N) override; + + void prune(int lookBack = 0) override; + + int nDecodedFramesInBuffer() const override; + + DecodeResult getBestHypothesis(int lookBack = 0) const override; + + std::vector getAllFinalHypothesis() const override; + + protected: + LexiconSeq2SeqDecoderOptions opt_; + LMPtr lm_; + TriePtr lexicon_; + int eos_; + AMUpdateFunc amUpdateFunc_; + std::vector rawY_; + std::vector rawPrevStates_; + int maxOutputLength_; + bool isLmToken_; + + std::vector candidates_; + std::vector candidatePtrs_; + double candidatesBestScore_; + + std::unordered_map> hyp_; +}; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp new file mode 100644 index 00000000..df037e1b --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include "flashlight/lib/text/decoder/Trie.h" + +namespace fl { +namespace lib { +namespace text { + +const double kMinusLogThreshold = -39.14; + +const TrieNode* Trie::getRoot() const { + return root_.get(); +} + +TrieNodePtr +Trie::insert(const std::vector& indices, int label, float score) { + TrieNodePtr node = root_; + for (int i = 0; i < indices.size(); i++) { + int idx = indices[i]; + if (idx < 0 || idx >= maxChildren_) { + throw std::out_of_range( + "[Trie] Invalid letter index: " + std::to_string(idx)); + } + if (node->children.find(idx) == node->children.end()) { + node->children[idx] = std::make_shared(idx); + } + node = node->children[idx]; + } + if (node->labels.size() < kTrieMaxLabel) { + node->labels.push_back(label); + node->scores.push_back(score); + } else { + std::cerr << "[Trie] Trie label number reached limit: " << kTrieMaxLabel + << "\n"; + } + return node; +} + +TrieNodePtr Trie::search(const std::vector& indices) { + TrieNodePtr node = root_; + for (auto idx : indices) { + if (idx < 0 || idx >= maxChildren_) { + throw std::out_of_range( + "[Trie] Invalid letter index: " + std::to_string(idx)); + } + if (node->children.find(idx) == node->children.end()) { + return nullptr; + } + node = node->children[idx]; + } + return node; +} + +/* logadd */ +double TrieLogAdd(double log_a, double log_b) { + double minusdif; + if (log_a < log_b) { + std::swap(log_a, log_b); + } + minusdif = log_b - log_a; + if (minusdif < kMinusLogThreshold) { + return log_a; + } else { + return log_a + log1p(exp(minusdif)); + } +} + +void smearNode(TrieNodePtr node, SmearingMode smearMode) { + node->maxScore = -std::numeric_limits::infinity(); + for (auto score : node->scores) { + node->maxScore = TrieLogAdd(node->maxScore, score); + } + for (auto child : node->children) { + auto childNode = child.second; + smearNode(childNode, smearMode); + if (smearMode == SmearingMode::LOGADD) { + node->maxScore = 
TrieLogAdd(node->maxScore, childNode->maxScore); + } else if ( + smearMode == SmearingMode::MAX && + childNode->maxScore > node->maxScore) { + node->maxScore = childNode->maxScore; + } + } +} + +void Trie::smear(SmearingMode smearMode) { + if (smearMode != SmearingMode::NONE) { + smearNode(root_, smearMode); + } +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.h new file mode 100644 index 00000000..e26100e7 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.h @@ -0,0 +1,95 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include + +namespace fl { +namespace lib { +namespace text { + +constexpr int kTrieMaxLabel = 6; + +enum class SmearingMode { + NONE = 0, + MAX = 1, + LOGADD = 2, +}; + +/** + * TrieNode is the trie node structure in Trie. + */ +struct TrieNode { + explicit TrieNode(int idx) + : children(std::unordered_map>()), + idx(idx), + maxScore(0) { + labels.reserve(kTrieMaxLabel); + scores.reserve(kTrieMaxLabel); + } + + // Pointers to the children of a node + std::unordered_map> children; + + // Node index + int idx; + + // Labels of words that are constructed from the given path. Note that + // `labels` is nonempty only if the current node represents a completed token. + std::vector labels; + + // Scores (`scores` should have the same size as `labels`) + std::vector scores; + + // Maximum score of all the labels if this node is a leaf, + // otherwise it will be the value after trie smearing. + float maxScore; +}; + +using TrieNodePtr = std::shared_ptr; + +/** + * Trie is used to store the lexicon in language model.
We use it to limit + * the search space in decoder and quickly look up scores for a given token + * (completed word) or make prediction for incomplete ones based on smearing. + */ +class Trie { + public: + Trie(int maxChildren, int rootIdx) + : root_(std::make_shared(rootIdx)), maxChildren_(maxChildren) {} + + /* Return the root node pointer */ + const TrieNode* getRoot() const; + + /* Insert a token into trie with label */ + TrieNodePtr insert(const std::vector& indices, int label, float score); + + /* Get the labels for a given token */ + TrieNodePtr search(const std::vector& indices); + + /** + * Smearing the trie using the valid labels inserted in the trie so as to get + * score on each node (incomplete token). + * For example, if smear_mode is MAX, then for node "a" in path "c"->"a", we + * will select the maximum score from all its children like "c"->"a"->"t", + * "c"->"a"->"n", "c"->"a"->"r"->"e" and so on. + * This process will be carried out recursively on all the nodes. + */ + void smear(const SmearingMode smear_mode); + + private: + TrieNodePtr root_; + int maxChildren_; // The maximum number of children for each node. It is + // usually the size of letters or phonemes. +}; + +using TriePtr = std::shared_ptr; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp new file mode 100644 index 00000000..d6b8d554 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.cpp @@ -0,0 +1,15 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + +namespace fl { +namespace lib { +namespace text { + +// Place holder +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.h new file mode 100644 index 00000000..d8d50457 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Utils.h @@ -0,0 +1,275 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include "flashlight/lib/text/decoder/lm/LM.h" + +namespace fl { +namespace lib { +namespace text { + +/* ===================== Definitions ===================== */ + +const double kNegativeInfinity = -std::numeric_limits::infinity(); +const int kLookBackLimit = 100; + +struct DecodeResult { + double score; + double amScore; + double lmScore; + std::vector words; + std::vector tokens; + + explicit DecodeResult(int length = 0) + : score(0), words(length, -1), tokens(length, -1) {} +}; + +/* ===================== Candidate-related operations ===================== */ + +template +void candidatesReset( + double& candidatesBestScore, + std::vector& candidates, + std::vector& candidatePtrs) { + candidatesBestScore = kNegativeInfinity; + candidates.clear(); + candidatePtrs.clear(); +} + +template +void candidatesAdd( + std::vector& candidates, + double& candidatesBestScore, + const double beamThreshold, + const double score, + const Args&... 
args) { + if (score >= candidatesBestScore) { + candidatesBestScore = score; + } + if (score >= candidatesBestScore - beamThreshold) { + candidates.emplace_back(score, args...); + } +} + +template +void candidatesStore( + std::vector& candidates, + std::vector& candidatePtrs, + std::vector& outputs, + const int beamSize, + const double threshold, + const bool logAdd, + const bool returnSorted) { + outputs.clear(); + if (candidates.empty()) { + return; + } + + /* 1. Select valid candidates */ + for (auto& candidate : candidates) { + if (candidate.score >= threshold) { + candidatePtrs.emplace_back(&candidate); + } + } + + /* 2. Merge candidates */ + std::sort( + candidatePtrs.begin(), + candidatePtrs.end(), + [](const DecoderState* node1, const DecoderState* node2) { + int cmp = node1->compareNoScoreStates(node2); + return cmp == 0 ? node1->score > node2->score : cmp > 0; + }); + + int nHypAfterMerging = 1; + for (int i = 1; i < candidatePtrs.size(); i++) { + if (candidatePtrs[i]->compareNoScoreStates( + candidatePtrs[nHypAfterMerging - 1]) != 0) { + // Distinct candidate + candidatePtrs[nHypAfterMerging] = candidatePtrs[i]; + nHypAfterMerging++; + } else { + // Same candidate + double maxScore = std::max( + candidatePtrs[nHypAfterMerging - 1]->score, candidatePtrs[i]->score); + if (logAdd) { + double minScore = std::min( + candidatePtrs[nHypAfterMerging - 1]->score, + candidatePtrs[i]->score); + candidatePtrs[nHypAfterMerging - 1]->score = + maxScore + std::log1p(std::exp(minScore - maxScore)); + } else { + candidatePtrs[nHypAfterMerging - 1]->score = maxScore; + } + } + } + candidatePtrs.resize(nHypAfterMerging); + + /* 3. 
Sort and prune */ + auto compareNodeScore = [](const DecoderState* node1, + const DecoderState* node2) { + return node1->score > node2->score; + }; + + int nValidHyp = candidatePtrs.size(); + int finalSize = std::min(nValidHyp, beamSize); + if (!returnSorted && nValidHyp > beamSize) { + std::nth_element( + candidatePtrs.begin(), + candidatePtrs.begin() + finalSize, + candidatePtrs.begin() + nValidHyp, + compareNodeScore); + } else if (returnSorted) { + std::partial_sort( + candidatePtrs.begin(), + candidatePtrs.begin() + finalSize, + candidatePtrs.begin() + nValidHyp, + compareNodeScore); + } + + for (int i = 0; i < finalSize; i++) { + outputs.emplace_back(std::move(*candidatePtrs[i])); + } +} + +/* ===================== Result-related operations ===================== */ + +template +DecodeResult getHypothesis(const DecoderState* node, const int finalFrame) { + const DecoderState* node_ = node; + if (!node_) { + return DecodeResult(); + } + + DecodeResult res(finalFrame + 1); + res.score = node_->score; + res.amScore = node_->amScore; + res.lmScore = node_->lmScore; + + int i = 0; + while (node_) { + res.words[finalFrame - i] = node_->getWord(); + res.tokens[finalFrame - i] = node_->token; + node_ = node_->parent; + i++; + } + + return res; +} + +template +std::vector getAllHypothesis( + const std::vector& finalHyps, + const int finalFrame) { + int nHyp = finalHyps.size(); + + std::vector res(nHyp); + + for (int r = 0; r < nHyp; r++) { + const DecoderState* node = &finalHyps[r]; + res[r] = getHypothesis(node, finalFrame); + } + + return res; +} + +template +const DecoderState* findBestAncestor( + const std::vector& finalHyps, + int& lookBack) { + int nHyp = finalHyps.size(); + if (nHyp == 0) { + return nullptr; + } + + double bestScore = finalHyps.front().score; + const DecoderState* bestNode = finalHyps.data(); + for (int r = 1; r < nHyp; r++) { + const DecoderState* node = &finalHyps[r]; + if (node->score > bestScore) { + bestScore = node->score; + bestNode = 
node; + } + } + + int n = 0; + while (bestNode && n < lookBack) { + n++; + bestNode = bestNode->parent; + } + + const int maxLookBack = lookBack + kLookBackLimit; + while (bestNode) { + // Check for first emitted word. + if (bestNode->isComplete()) { + break; + } + + n++; + bestNode = bestNode->parent; + + if (n == maxLookBack) { + break; + } + } + + lookBack = n; + return bestNode; +} + +template +void pruneAndNormalize( + std::unordered_map>& hypothesis, + const int startFrame, + const int lookBack) { + /* 1. Move things from back of hypothesis to front. */ + for (int i = 0; i < hypothesis.size(); i++) { + if (i <= lookBack) { + hypothesis[i].swap(hypothesis[i + startFrame]); + } else { + hypothesis[i].clear(); + } + } + + /* 2. Avoid further back-tracking */ + for (DecoderState& hyp : hypothesis[0]) { + hyp.parent = nullptr; + } + + /* 3. Avoid score underflow/overflow. */ + double largestScore = hypothesis[lookBack].front().score; + for (int i = 1; i < hypothesis[lookBack].size(); i++) { + if (largestScore < hypothesis[lookBack][i].score) { + largestScore = hypothesis[lookBack][i].score; + } + } + + for (int i = 0; i < hypothesis[lookBack].size(); i++) { + hypothesis[lookBack][i].score -= largestScore; + } +} + +/* ===================== LM-related operations ===================== */ + +template +void updateLMCache(const LMPtr& lm, std::vector& hypothesis) { + // For ConvLM update cache + std::vector states; + for (const auto& hyp : hypothesis) { + states.emplace_back(hyp.lmState); + } + lm->updateCache(states); +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp new file mode 100644 index 00000000..e313668c --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.cpp @@ -0,0 +1,239 @@ +/* + * Copyright (c) 
Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. + */ + +#include +#include +#include + +#include "flashlight/lib/text/decoder/lm/ConvLM.h" + +namespace fl { +namespace lib { +namespace text { + +ConvLM::ConvLM( + const GetConvLmScoreFunc& getConvLmScoreFunc, + const std::string& tokenVocabPath, + const Dictionary& usrTknDict, + int lmMemory, + int beamSize, + int historySize) + : lmMemory_(lmMemory), + beamSize_(beamSize), + getConvLmScoreFunc_(getConvLmScoreFunc), + maxHistorySize_(historySize) { + if (historySize < 1) { + throw std::invalid_argument("[ConvLM] History size is too small."); + } + + /* Load token vocabulary */ + // Note: fairseq vocab should start with: + // - 0 - 1, - 2, - 3 + std::cerr << "[ConvLM]: Loading vocabulary from " << tokenVocabPath << "\n"; + vocab_ = Dictionary(tokenVocabPath); + vocab_.setDefaultIndex(vocab_.getIndex(kUnkToken)); + vocabSize_ = vocab_.indexSize(); + std::cerr << "[ConvLM]: vocabulary size of convLM " << vocabSize_ << "\n"; + + /* Create index map */ + usrToLmIdxMap_.resize(usrTknDict.indexSize()); + for (int i = 0; i < usrTknDict.indexSize(); i++) { + auto token = usrTknDict.getEntry(i); + int lmIdx = vocab_.getIndex(token.c_str()); + usrToLmIdxMap_[i] = lmIdx; + } + + /* Refresh cache */ + cacheIndices_.reserve(beamSize_); + cache_.resize(beamSize_, std::vector(vocabSize_)); + slot_.reserve(beamSize_); + batchedTokens_.resize(beamSize_ * maxHistorySize_); +} + +LMStatePtr ConvLM::start(bool startWithNothing) { + cacheIndices_.clear(); + auto outState = std::make_shared(1); + if (!startWithNothing) { + outState->length = 1; + outState->tokens[0] = vocab_.getIndex(kEosToken); + } else { + throw std::invalid_argument( + "[ConvLM] Only support using EOS to start the sentence"); + } + return outState; +} 
+ +std::pair ConvLM::scoreWithLmIdx( + const LMStatePtr& state, + const int tokenIdx) { + auto rawInState = std::static_pointer_cast(state).get(); + int inStateLength = rawInState->length; + std::shared_ptr outState; + + // Prepare output state + if (inStateLength == maxHistorySize_) { + outState = std::make_shared(maxHistorySize_); + std::copy( + rawInState->tokens.begin() + 1, + rawInState->tokens.end(), + outState->tokens.begin()); + outState->tokens[maxHistorySize_ - 1] = tokenIdx; + } else { + outState = std::make_shared(inStateLength + 1); + std::copy( + rawInState->tokens.begin(), + rawInState->tokens.end(), + outState->tokens.begin()); + outState->tokens[inStateLength] = tokenIdx; + } + + // Prepare score + float score = 0; + if (tokenIdx < 0 || tokenIdx >= vocabSize_) { + throw std::out_of_range( + "[ConvLM] Invalid query word: " + std::to_string(tokenIdx)); + } + + if (cacheIndices_.find(rawInState) != cacheIndices_.end()) { + // Cache hit + auto cacheInd = cacheIndices_[rawInState]; + if (cacheInd < 0 || cacheInd >= beamSize_) { + throw std::logic_error( + "[ConvLM] Invalid cache access: " + std::to_string(cacheInd)); + } + score = cache_[cacheInd][tokenIdx]; + } else { + // Cache miss + if (cacheIndices_.size() == beamSize_) { + cacheIndices_.clear(); + } + int newIdx = cacheIndices_.size(); + cacheIndices_[rawInState] = newIdx; + + std::vector lastTokenPositions = {rawInState->length - 1}; + cache_[newIdx] = + getConvLmScoreFunc_(rawInState->tokens, lastTokenPositions, -1, 1); + score = cache_[newIdx][tokenIdx]; + } + if (std::isnan(score) || !std::isfinite(score)) { + throw std::runtime_error( + "[ConvLM] Bad scoring from ConvLM: " + std::to_string(score)); + } + return std::make_pair(std::move(outState), score); +} + +std::pair ConvLM::score( + const LMStatePtr& state, + const int usrTokenIdx) { + if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) { + throw std::out_of_range( + "[KenLM] Invalid user token index: " + 
std::to_string(usrTokenIdx)); + } + return scoreWithLmIdx(state, usrToLmIdxMap_[usrTokenIdx]); +} + +std::pair ConvLM::finish(const LMStatePtr& state) { + return scoreWithLmIdx(state, vocab_.getIndex(kEosToken)); +} + +void ConvLM::updateCache(std::vector states) { + int longestHistory = -1, nStates = states.size(); + if (nStates > beamSize_) { + throw std::invalid_argument( + "[ConvLM] Cache size too small (consider larger than beam size)."); + } + + // Refresh cache, store LM states that did not changed + slot_.clear(); + slot_.resize(beamSize_, nullptr); + for (const auto& state : states) { + auto rawState = std::static_pointer_cast(state).get(); + if (cacheIndices_.find(rawState) != cacheIndices_.end()) { + slot_[cacheIndices_[rawState]] = rawState; + } else if (rawState->length > longestHistory) { + // prepare intest history only for those which should be predicted + longestHistory = rawState->length; + } + } + cacheIndices_.clear(); + int cacheSize = 0; + for (int i = 0; i < beamSize_; i++) { + if (!slot_[i]) { + continue; + } + cache_[cacheSize] = cache_[i]; + cacheIndices_[slot_[i]] = cacheSize; + ++cacheSize; + } + + // Determine batchsize + if (longestHistory <= 0) { + return; + } + // batchSize * longestHistory = cacheSize; + int maxBatchSize = lmMemory_ / longestHistory; + if (maxBatchSize > nStates) { + maxBatchSize = nStates; + } + + // Run batch forward + int batchStart = 0; + while (batchStart < nStates) { + // Select batch + int nBatchStates = 0; + std::vector lastTokenPositions; + for (int i = batchStart; (nBatchStates < maxBatchSize) && (i < nStates); + i++, batchStart++) { + auto rawState = std::static_pointer_cast(states[i]).get(); + if (cacheIndices_.find(rawState) != cacheIndices_.end()) { + continue; + } + cacheIndices_[rawState] = cacheSize + nBatchStates; + int start = nBatchStates * longestHistory; + + for (int j = 0; j < rawState->length; j++) { + batchedTokens_[start + j] = rawState->tokens[j]; + } + start += rawState->length; + for 
(int j = 0; j < longestHistory - rawState->length; j++) { + batchedTokens_[start + j] = vocab_.getIndex(kPadToken); + } + lastTokenPositions.push_back(rawState->length - 1); + ++nBatchStates; + } + if (nBatchStates == 0 && batchStart >= nStates) { + // if all states were skipped + break; + } + + // Feed forward + if (nBatchStates < 1 || longestHistory < 1) { + throw std::logic_error( + "[ConvLM] Invalid batch: [" + std::to_string(nBatchStates) + " x " + + std::to_string(longestHistory) + "]"); + } + auto batchedProb = getConvLmScoreFunc_( + batchedTokens_, lastTokenPositions, longestHistory, nBatchStates); + + if (batchedProb.size() != vocabSize_ * nBatchStates) { + throw std::logic_error( + "[ConvLM] Batch X Vocab size " + std::to_string(batchedProb.size()) + + " mismatch with " + std::to_string(vocabSize_ * nBatchStates)); + } + // Place probabilities in cache + for (int i = 0; i < nBatchStates; i++, cacheSize++) { + std::memcpy( + cache_[cacheSize].data(), + batchedProb.data() + vocabSize_ * i, + vocabSize_ * sizeof(float)); + } + } +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.h new file mode 100644 index 00000000..ae627412 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ConvLM.h @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. An additional grant + * of patent rights can be found in the PATENTS file in the same directory. 
+ */ +#pragma once + +#include + +#include "flashlight/lib/text/decoder/lm/LM.h" +#include "flashlight/lib/text/dictionary/Defines.h" +#include "flashlight/lib/text/dictionary/Dictionary.h" + +namespace fl { +namespace lib { +namespace text { + +using GetConvLmScoreFunc = std::function(const std::vector&, const std::vector&, int, int)>; + +struct ConvLMState : LMState { + std::vector tokens; + int length; + + ConvLMState() : length(0) {} + explicit ConvLMState(int size) + : tokens(std::vector(size)), length(size) {} +}; + +class ConvLM : public LM { + public: + ConvLM( + const GetConvLmScoreFunc& getConvLmScoreFunc, + const std::string& tokenVocabPath, + const Dictionary& usrTknDict, + int lmMemory = 10000, + int beamSize = 2500, + int historySize = 49); + + LMStatePtr start(bool startWithNothing) override; + + std::pair score( + const LMStatePtr& state, + const int usrTokenIdx) override; + + std::pair finish(const LMStatePtr& state) override; + + void updateCache(std::vector states) override; + + private: + // This cache is also not thread-safe! + int lmMemory_; + int beamSize_; + std::unordered_map cacheIndices_; + std::vector> cache_; + std::vector slot_; + std::vector batchedTokens_; + + Dictionary vocab_; + GetConvLmScoreFunc getConvLmScoreFunc_; + + int vocabSize_; + int maxHistorySize_; + + std::pair scoreWithLmIdx( + const LMStatePtr& state, + const int tokenIdx); +}; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp new file mode 100644 index 00000000..e89f2131 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "flashlight/lib/text/decoder/lm/KenLM.h" + +#include + +#include + +namespace fl { +namespace lib { +namespace text { + +KenLMState::KenLMState() : ken_(std::make_unique()) {} + +KenLM::KenLM(const std::string& path, const Dictionary& usrTknDict) { + // Load LM + model_.reset(lm::ngram::LoadVirtual(path.c_str())); + if (!model_) { + throw std::runtime_error("[KenLM] LM loading failed."); + } + vocab_ = &model_->BaseVocabulary(); + if (!vocab_) { + throw std::runtime_error("[KenLM] LM vocabulary loading failed."); + } + + // Create index map + usrToLmIdxMap_.resize(usrTknDict.indexSize()); + for (int i = 0; i < usrTknDict.indexSize(); i++) { + auto token = usrTknDict.getEntry(i); + int lmIdx = vocab_->Index(token.c_str()); + usrToLmIdxMap_[i] = lmIdx; + } +} + +LMStatePtr KenLM::start(bool startWithNothing) { + auto outState = std::make_shared(); + if (startWithNothing) { + model_->NullContextWrite(outState->ken()); + } else { + model_->BeginSentenceWrite(outState->ken()); + } + + return outState; +} + +std::pair KenLM::score( + const LMStatePtr& state, + const int usrTokenIdx) { + if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) { + throw std::runtime_error( + "[KenLM] Invalid user token index: " + std::to_string(usrTokenIdx)); + } + auto inState = std::static_pointer_cast(state); + auto outState = inState->child(usrTokenIdx); + float score = model_->BaseScore( + inState->ken(), usrToLmIdxMap_[usrTokenIdx], outState->ken()); + return std::make_pair(std::move(outState), score); +} + +std::pair KenLM::finish(const LMStatePtr& state) { + auto inState = std::static_pointer_cast(state); + auto outState = inState->child(-1); + float score = + model_->BaseScore(inState->ken(), vocab_->EndSentence(), outState->ken()); + return std::make_pair(std::move(outState), score); +} +} // namespace text +} // namespace 
lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.h new file mode 100644 index 00000000..8b667ecb --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/KenLM.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "flashlight/lib/text/decoder/lm/LM.h" +#include "flashlight/lib/text/dictionary/Dictionary.h" + +// Forward declarations to avoid including KenLM headers +namespace lm { +namespace base { + +struct Vocabulary; +struct Model; + +} // namespace base +namespace ngram { + +struct State; + +} // namespace ngram +} // namespace lm + +namespace fl { +namespace lib { +namespace text { + +/** + * KenLMState is a state object from KenLM, which contains context length, + * indicies and compare functions + * https://github.com/kpu/kenlm/blob/master/lm/state.hh. + */ +struct KenLMState : LMState { + KenLMState(); + std::unique_ptr ken_; + lm::ngram::State* ken() { + return ken_.get(); + } +}; + +/** + * KenLM extends LM by using the toolkit https://kheafield.com/code/kenlm/. 
+ */ +class KenLM : public LM { + public: + KenLM(const std::string& path, const Dictionary& usrTknDict); + + LMStatePtr start(bool startWithNothing) override; + + std::pair score( + const LMStatePtr& state, + const int usrTokenIdx) override; + + std::pair finish(const LMStatePtr& state) override; + + private: + std::shared_ptr model_; + const lm::base::Vocabulary* vocab_; +}; + +using KenLMPtr = std::shared_ptr; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/LM.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/LM.h new file mode 100644 index 00000000..a993839f --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/LM.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace fl { +namespace lib { +namespace text { + +struct LMState { + std::unordered_map> children; + + template + std::shared_ptr child(int usrIdx) { + auto s = children.find(usrIdx); + if (s == children.end()) { + auto state = std::make_shared(); + children[usrIdx] = state; + return state; + } else { + return std::static_pointer_cast(s->second); + } + } + + /* Compare two language model states. */ + int compare(const std::shared_ptr& state) const { + LMState* inState = state.get(); + if (!state) { + throw std::runtime_error("a state is null"); + } + if (this == inState) { + return 0; + } else if (this < inState) { + return -1; + } else { + return 1; + } + }; +}; + +/** + * LMStatePtr is a shared LMState* tracking LM states generated during decoding. + */ +using LMStatePtr = std::shared_ptr; + +/** + * LM is a thin wrapper for laguage models. 
We abstrct several common methods + * here which can be shared for KenLM, ConvLM, RNNLM, etc. + */ +class LM { + public: + /* Initialize or reset language model */ + virtual LMStatePtr start(bool startWithNothing) = 0; + + /** + * Query the language model given input language model state and a specific + * token, return a new language model state and score. + */ + virtual std::pair score( + const LMStatePtr& state, + const int usrTokenIdx) = 0; + + /* Query the language model and finish decoding. */ + virtual std::pair finish(const LMStatePtr& state) = 0; + + /* Update LM caches (optional) given a bunch of new states generated */ + virtual void updateCache(std::vector stateIdices) {} + + virtual ~LM() = default; + + protected: + /* Map indices from acoustic model to LM for each valid token. */ + std::vector usrToLmIdxMap_; +}; + +using LMPtr = std::shared_ptr; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp new file mode 100644 index 00000000..5edbfddb --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "flashlight/lib/text/decoder/lm/ZeroLM.h" + +#include + +namespace fl { +namespace lib { +namespace text { + +LMStatePtr ZeroLM::start(bool /* unused */) { + return std::make_shared(); +} + +std::pair ZeroLM::score( + const LMStatePtr& state /* unused */, + const int usrTokenIdx) { + return std::make_pair(state->child(usrTokenIdx), 0.0); +} + +std::pair ZeroLM::finish(const LMStatePtr& state) { + return std::make_pair(state, 0.0); +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.h new file mode 100644 index 00000000..36509a65 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/lm/ZeroLM.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include "flashlight/lib/text/decoder/lm/LM.h" + +namespace fl { +namespace lib { +namespace text { + +/** + * ZeroLM is a dummy language model class, which mimics the behavious of a + * uni-gram language model but always returns 0 as score. 
+ */ +class ZeroLM : public LM { + public: + LMStatePtr start(bool startWithNothing) override; + + std::pair score( + const LMStatePtr& state, + const int usrTokenIdx) override; + + std::pair finish(const LMStatePtr& state) override; +}; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Defines.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Defines.h new file mode 100644 index 00000000..084457a8 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Defines.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +namespace fl { +namespace lib { +namespace text { + +constexpr const char* kUnkToken = ""; +constexpr const char* kEosToken = ""; +constexpr const char* kPadToken = ""; +constexpr const char* kMaskToken = ""; + +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp new file mode 100644 index 00000000..d90de960 --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.cpp @@ -0,0 +1,152 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include + +#include "flashlight/lib/common/String.h" +#include "flashlight/lib/common/System.h" +#include "flashlight/lib/text/dictionary/Dictionary.h" +#include "flashlight/lib/text/dictionary/Utils.h" + +namespace fl { +namespace lib { +namespace text { + +Dictionary::Dictionary(std::istream& stream) { + createFromStream(stream); +} + +Dictionary::Dictionary(const std::string& filename) { + std::ifstream stream = createInputStream(filename); + createFromStream(stream); +} + +void Dictionary::createFromStream(std::istream& stream) { + if (!stream) { + throw std::runtime_error("Unable to open dictionary input stream."); + } + std::string line; + while (std::getline(stream, line)) { + if (line.empty()) { + continue; + } + auto tkns = splitOnWhitespace(line, true); + auto idx = idx2entry_.size(); + // All entries on the same line map to the same index + for (const auto& tkn : tkns) { + addEntry(tkn, idx); + } + } + if (!isContiguous()) { + throw std::runtime_error("Invalid dictionary format - not contiguous"); + } +} + +void Dictionary::addEntry(const std::string& entry, int idx) { + if (entry2idx_.find(entry) != entry2idx_.end()) { + throw std::invalid_argument( + "Duplicate entry name in dictionary '" + entry + "'"); + } + entry2idx_[entry] = idx; + if (idx2entry_.find(idx) == idx2entry_.end()) { + idx2entry_[idx] = entry; + } +} + +void Dictionary::addEntry(const std::string& entry) { + // Check if the entry already exists in the dictionary + if (entry2idx_.find(entry) != entry2idx_.end()) { + throw std::invalid_argument( + "Duplicate entry in dictionary '" + entry + "'"); + } + int idx = idx2entry_.size(); + // Find first available index. 
+ while (idx2entry_.find(idx) != idx2entry_.end()) { + ++idx; + } + addEntry(entry, idx); +} + +std::string Dictionary::getEntry(int idx) const { + auto iter = idx2entry_.find(idx); + if (iter == idx2entry_.end()) { + throw std::invalid_argument( + "Unknown index in dictionary '" + std::to_string(idx) + "'"); + } + return iter->second; +} + +void Dictionary::setDefaultIndex(int idx) { + defaultIndex_ = idx; +} + +int Dictionary::getIndex(const std::string& entry) const { + auto iter = entry2idx_.find(entry); + if (iter == entry2idx_.end()) { + if (defaultIndex_ < 0) { + throw std::invalid_argument( + "Unknown entry in dictionary: '" + entry + "'"); + } else { + return defaultIndex_; + } + } + return iter->second; +} + +bool Dictionary::contains(const std::string& entry) const { + auto iter = entry2idx_.find(entry); + if (iter == entry2idx_.end()) { + return false; + } + return true; +} + +size_t Dictionary::entrySize() const { + return entry2idx_.size(); +} + +bool Dictionary::isContiguous() const { + for (size_t i = 0; i < indexSize(); ++i) { + if (idx2entry_.find(i) == idx2entry_.end()) { + return false; + } + } + for (const auto& tknidx : entry2idx_) { + if (idx2entry_.find(tknidx.second) == idx2entry_.end()) { + return false; + } + } + return true; +} + +std::vector Dictionary::mapEntriesToIndices( + const std::vector& entries) const { + std::vector indices; + indices.reserve(entries.size()); + for (const auto& tkn : entries) { + indices.emplace_back(getIndex(tkn)); + } + return indices; +} + +std::vector Dictionary::mapIndicesToEntries( + const std::vector& indices) const { + std::vector entries; + entries.reserve(indices.size()); + for (const auto& idx : indices) { + entries.emplace_back(getEntry(idx)); + } + return entries; +} + +size_t Dictionary::indexSize() const { + return idx2entry_.size(); +} +} // namespace text +} // namespace lib +} // namespace fl diff --git 
a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h new file mode 100644 index 00000000..a37c61ee --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Dictionary.h @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace fl { +namespace lib { +namespace text { +// A simple dictionary class which holds a bidirectional map +// entry (strings) <--> integer indices. Not thread-safe ! +class Dictionary { + public: + // Creates an empty dictionary + Dictionary() {} + + explicit Dictionary(std::istream& stream); + + explicit Dictionary(const std::string& filename); + + size_t entrySize() const; + + size_t indexSize() const; + + void addEntry(const std::string& entry, int idx); + + void addEntry(const std::string& entry); + + std::string getEntry(int idx) const; + + void setDefaultIndex(int idx); + + int getIndex(const std::string& entry) const; + + bool contains(const std::string& entry) const; + + // checks if all the indices are contiguous + bool isContiguous() const; + + std::vector mapEntriesToIndices( + const std::vector& entries) const; + + std::vector mapIndicesToEntries( + const std::vector& indices) const; + + private: + // Creates a dictionary from an input stream + void createFromStream(std::istream& stream); + + std::unordered_map entry2idx_; + std::unordered_map idx2entry_; + int defaultIndex_ = -1; +}; + +typedef std::unordered_map DictionaryMap; +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp 
b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp new file mode 100644 index 00000000..375a81ca --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "flashlight/lib/text/dictionary/Utils.h" +#include "flashlight/lib/common/String.h" +#include "flashlight/lib/common/System.h" +#include "flashlight/lib/text/dictionary/Defines.h" + +namespace fl { +namespace lib { +namespace text { + +Dictionary createWordDict(const LexiconMap& lexicon) { + Dictionary dict; + for (const auto& it : lexicon) { + dict.addEntry(it.first); + } + dict.setDefaultIndex(dict.getIndex(kUnkToken)); + return dict; +} + +LexiconMap loadWords(const std::string& filename, int maxWords) { + LexiconMap lexicon; + + std::string line; + std::ifstream infile = createInputStream(filename); + + // Add at most `maxWords` words into the lexicon. + // If `maxWords` is negative then no limit is applied. + while (maxWords != lexicon.size() && std::getline(infile, line)) { + // Parse the line into two strings: word and spelling. + auto fields = splitOnWhitespace(line, true); + if (fields.size() < 2) { + throw std::runtime_error("[loadWords] Invalid line: " + line); + } + const std::string& word = fields[0]; + std::vector spelling(fields.size() - 1); + std::copy(fields.begin() + 1, fields.end(), spelling.begin()); + + // Add the word into the dictionary. + if (lexicon.find(word) == lexicon.end()) { + lexicon[word] = {}; + } + + // Add the current spelling of the words to the list of spellings. + lexicon[word].push_back(spelling); + } + + // Insert unknown word. 
+ lexicon[kUnkToken] = {}; + return lexicon; +} + +std::vector splitWrd(const std::string& word) { + std::vector tokens; + tokens.reserve(word.size()); + int len = word.length(); + for (int i = 0; i < len;) { + auto c = static_cast(word[i]); + int curTknBytes = -1; + // UTF-8 checks, works for ASCII automatically + if ((c & 0x80) == 0) { + curTknBytes = 1; + } else if ((c & 0xE0) == 0xC0) { + curTknBytes = 2; + } else if ((c & 0xF0) == 0xE0) { + curTknBytes = 3; + } else if ((c & 0xF8) == 0xF0) { + curTknBytes = 4; + } + if (curTknBytes == -1 || i + curTknBytes > len) { + throw std::runtime_error("splitWrd: invalid UTF-8 : " + word); + } + tokens.emplace_back(word.begin() + i, word.begin() + i + curTknBytes); + i += curTknBytes; + } + return tokens; +} + +std::vector packReplabels( + const std::vector& tokens, + const Dictionary& dict, + int maxReps) { + if (tokens.empty() || maxReps <= 0) { + return tokens; + } + + std::vector replabelValueToIdx(maxReps + 1); + for (int i = 1; i <= maxReps; ++i) { + replabelValueToIdx[i] = dict.getIndex("<" + std::to_string(i) + ">"); + } + + std::vector result; + int prevToken = -1; + int numReps = 0; + for (int token : tokens) { + if (token == prevToken && numReps < maxReps) { + numReps++; + } else { + if (numReps > 0) { + result.push_back(replabelValueToIdx[numReps]); + numReps = 0; + } + result.push_back(token); + prevToken = token; + } + } + if (numReps > 0) { + result.push_back(replabelValueToIdx[numReps]); + } + return result; +} + +std::vector unpackReplabels( + const std::vector& tokens, + const Dictionary& dict, + int maxReps) { + if (tokens.empty() || maxReps <= 0) { + return tokens; + } + + std::unordered_map replabelIdxToValue; + for (int i = 1; i <= maxReps; ++i) { + replabelIdxToValue.emplace(dict.getIndex("<" + std::to_string(i) + ">"), i); + } + + std::vector result; + int prevToken = -1; + for (int token : tokens) { + auto it = replabelIdxToValue.find(token); + if (it == replabelIdxToValue.end()) { + 
result.push_back(token); + prevToken = token; + } else if (prevToken != -1) { + result.insert(result.end(), it->second, prevToken); + prevToken = -1; + } + } + return result; +} +} // namespace text +} // namespace lib +} // namespace fl diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.h new file mode 100644 index 00000000..6309fe1d --- /dev/null +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/dictionary/Utils.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include "flashlight/lib/text/dictionary/Dictionary.h" + +namespace fl { +namespace lib { +namespace text { + +using LexiconMap = + std::unordered_map>>; + +Dictionary createWordDict(const LexiconMap& lexicon); + +LexiconMap loadWords(const std::string& filename, int maxWords = -1); + +// split word into tokens abc -> {"a", "b", "c"} +// Works with ASCII, UTF-8 encodings +std::vector splitWrd(const std::string& word); + +/** + * Pack a token sequence by replacing consecutive repeats with replabels, + * e.g. "abbccc" -> "ab1c2". The tokens "1", "2", ..., `to_string(maxReps)` + * must already be in `dict`. + */ +std::vector packReplabels( + const std::vector& tokens, + const Dictionary& dict, + int maxReps); + +/** + * Unpack a token sequence by replacing replabels with repeated tokens, + * e.g. "ab1c2" -> "abbccc". The tokens "1", "2", ..., `to_string(maxReps)` + * must already be in `dict`. 
+ */ +std::vector unpackReplabels( + const std::vector& tokens, + const Dictionary& dict, + int maxReps); +} // namespace text +} // namespace lib +} // namespace fl diff --git a/training/coqui_stt_training/deepspeech_model.py b/training/coqui_stt_training/deepspeech_model.py index c0579f63..4f4ce701 100644 --- a/training/coqui_stt_training/deepspeech_model.py +++ b/training/coqui_stt_training/deepspeech_model.py @@ -387,7 +387,7 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False): "input_samples": input_samples, } - if not Config.export_tflite: + if not tflite: inputs["input_lengths"] = seq_length outputs = { diff --git a/training/coqui_stt_training/evaluate_flashlight.py b/training/coqui_stt_training/evaluate_flashlight.py new file mode 100644 index 00000000..6ac96ce0 --- /dev/null +++ b/training/coqui_stt_training/evaluate_flashlight.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from __future__ import absolute_import, division, print_function + +import json +import sys +from multiprocessing import cpu_count + +import progressbar +import tensorflow.compat.v1 as tfv1 +from coqui_stt_ctcdecoder import ( + Scorer, + flashlight_beam_search_decoder_batch, + FlashlightDecoderState, +) +from six.moves import zip + +import tensorflow as tf + +from .deepspeech_model import create_model +from .util.augmentations import NormalizeSampleRate +from .util.checkpoints import load_graph_for_evaluation +from .util.config import ( + Config, + create_progressbar, + initialize_globals_from_cli, + log_error, + log_progress, +) +from .util.evaluate_tools import calculate_and_print_report, save_samples_json +from .util.feeding import create_dataset +from .util.helpers import check_ctcdecoder_version + + +def sparse_tensor_value_to_texts(value, alphabet): + r""" + Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings + representing its values, converting tokens to strings using ``alphabet``. 
+ """ + return sparse_tuple_to_texts( + (value.indices, value.values, value.dense_shape), alphabet + ) + + +def sparse_tuple_to_texts(sp_tuple, alphabet): + indices = sp_tuple[0] + values = sp_tuple[1] + results = [[] for _ in range(sp_tuple[2][0])] + for i, index in enumerate(indices): + results[index[0]].append(values[i]) + # List of strings + return [alphabet.Decode(res) for res in results] + + +def evaluate(test_csvs, create_model): + if Config.scorer_path: + scorer = Scorer( + Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet + ) + else: + scorer = None + + test_sets = [ + create_dataset( + [csv], + batch_size=Config.test_batch_size, + train_phase=False, + augmentations=[NormalizeSampleRate(Config.audio_sample_rate)], + reverse=Config.reverse_test, + limit=Config.limit_test, + ) + for csv in test_csvs + ] + iterator = tfv1.data.Iterator.from_structure( + tfv1.data.get_output_types(test_sets[0]), + tfv1.data.get_output_shapes(test_sets[0]), + output_classes=tfv1.data.get_output_classes(test_sets[0]), + ) + test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets] + + batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next() + + # One rate per layer + no_dropout = [None] * 6 + logits, _ = create_model( + batch_x=batch_x, seq_length=batch_x_len, dropout=no_dropout + ) + + # Transpose to batch major and apply softmax for decoder + transposed = tf.nn.softmax(tf.transpose(a=logits, perm=[1, 0, 2])) + + loss = tfv1.nn.ctc_loss(labels=batch_y, inputs=logits, sequence_length=batch_x_len) + + tfv1.train.get_or_create_global_step() + + # Get number of accessible CPU cores for this process + try: + num_processes = cpu_count() + except NotImplementedError: + num_processes = 1 + + with open(Config.vocab_file) as fin: + vocab = [l.strip().encode("utf-8") for l in fin] + + with tfv1.Session(config=Config.session_config) as session: + load_graph_for_evaluation(session) + + def run_test(init_op, dataset): + wav_filenames 
= [] + losses = [] + predictions = [] + ground_truths = [] + + bar = create_progressbar( + prefix="Test epoch | ", + widgets=["Steps: ", progressbar.Counter(), " | ", progressbar.Timer()], + ).start() + log_progress("Test epoch...") + + step_count = 0 + + # Initialize iterator to the appropriate dataset + session.run(init_op) + + # First pass, compute losses and transposed logits for decoding + while True: + try: + ( + batch_wav_filenames, + batch_logits, + batch_loss, + batch_lengths, + batch_transcripts, + ) = session.run( + [batch_wav_filename, transposed, loss, batch_x_len, batch_y] + ) + except tf.errors.OutOfRangeError: + break + + decoded = flashlight_beam_search_decoder_batch( + batch_logits, + batch_lengths, + Config.alphabet, + beam_size=Config.export_beam_width, + decoder_type=FlashlightDecoderState.LexiconBased, + token_type=FlashlightDecoderState.Aggregate, + lm_tokens=vocab, + num_processes=num_processes, + scorer=scorer, + cutoff_top_n=Config.cutoff_top_n, + ) + predictions.extend(" ".join(d[0].words) for d in decoded) + ground_truths.extend( + sparse_tensor_value_to_texts(batch_transcripts, Config.alphabet) + ) + wav_filenames.extend( + wav_filename.decode("UTF-8") for wav_filename in batch_wav_filenames + ) + losses.extend(batch_loss) + + step_count += 1 + bar.update(step_count) + + bar.finish() + + # Print test summary + test_samples = calculate_and_print_report( + wav_filenames, ground_truths, predictions, losses, dataset + ) + return test_samples + + samples = [] + for csv, init_op in zip(test_csvs, test_init_ops): + print("Testing model on {}".format(csv)) + samples.extend(run_test(init_op, dataset=csv)) + return samples + + +def test(): + tfv1.reset_default_graph() + + samples = evaluate(Config.test_files, create_model) + if Config.test_output_file: + save_samples_json(samples, Config.test_output_file) + + +def main(): + initialize_globals_from_cli() + check_ctcdecoder_version() + + if not Config.test_files: + raise RuntimeError( + "You need 
to specify what files to use for evaluation via " + "the --test_files flag." + ) + + test() + + +if __name__ == "__main__": + main() diff --git a/training/coqui_stt_training/training_graph_inference_flashlight.py b/training/coqui_stt_training/training_graph_inference_flashlight.py new file mode 100644 index 00000000..cb3ab273 --- /dev/null +++ b/training/coqui_stt_training/training_graph_inference_flashlight.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import os +import sys + +LOG_LEVEL_INDEX = sys.argv.index("--log_level") + 1 if "--log_level" in sys.argv else 0 +DESIRED_LOG_LEVEL = ( + sys.argv[LOG_LEVEL_INDEX] if 0 < LOG_LEVEL_INDEX < len(sys.argv) else "3" +) +os.environ["TF_CPP_MIN_LOG_LEVEL"] = DESIRED_LOG_LEVEL + +import numpy as np +import tensorflow as tf +import tensorflow.compat.v1 as tfv1 + +from coqui_stt_ctcdecoder import ( + flashlight_beam_search_decoder, + Scorer, + FlashlightDecoderState, +) +from .deepspeech_model import create_inference_graph, create_overlapping_windows +from .util.checkpoints import load_graph_for_evaluation +from .util.config import Config, initialize_globals_from_cli, log_error +from .util.feeding import audiofile_to_features + + +def do_single_file_inference(input_file_path): + tfv1.reset_default_graph() + + with open(Config.vocab_file) as fin: + vocab = [w.encode("utf-8") for w in [l.strip() for l in fin]] + + with tfv1.Session(config=Config.session_config) as session: + inputs, outputs, layers = create_inference_graph(batch_size=1, n_steps=-1) + + # Restore variables from training checkpoint + load_graph_for_evaluation(session) + + features, features_len = audiofile_to_features(input_file_path) + previous_state_c = np.zeros([1, Config.n_cell_dim]) + previous_state_h = np.zeros([1, Config.n_cell_dim]) + + # Add batch dimension + features = tf.expand_dims(features, 0) + features_len = tf.expand_dims(features_len, 0) + + # Evaluate + features = create_overlapping_windows(features).eval(session=session) 
+ features_len = features_len.eval(session=session) + + probs = layers["raw_logits"].eval( + feed_dict={ + inputs["input"]: features, + inputs["input_lengths"]: features_len, + inputs["previous_state_c"]: previous_state_c, + inputs["previous_state_h"]: previous_state_h, + }, + session=session, + ) + + probs = np.squeeze(probs) + + if Config.scorer_path: + scorer = Scorer( + Config.lm_alpha, Config.lm_beta, Config.scorer_path, Config.alphabet + ) + else: + scorer = None + decoded = flashlight_beam_search_decoder( + probs, + Config.alphabet, + beam_size=Config.export_beam_width, + decoder_type=FlashlightDecoderState.LexiconBased, + token_type=FlashlightDecoderState.Aggregate, + lm_tokens=vocab, + scorer=scorer, + cutoff_top_n=Config.cutoff_top_n, + ) + # Print highest probability result + print(" ".join(d.decode("utf-8") for d in decoded[0].words)) + + +def main(): + initialize_globals_from_cli() + + if Config.one_shot_infer: + tfv1.reset_default_graph() + do_single_file_inference(Config.one_shot_infer) + else: + raise RuntimeError( + "Calling training_graph_inference script directly but no --one_shot_infer input audio file specified" + ) + + +if __name__ == "__main__": + main() diff --git a/training/coqui_stt_training/util/config.py b/training/coqui_stt_training/util/config.py index 4ddc62b2..008953a8 100644 --- a/training/coqui_stt_training/util/config.py +++ b/training/coqui_stt_training/util/config.py @@ -322,6 +322,13 @@ class _SttConfig(Coqpit): ), ) + vocab_file: str = field( + default="", + metadata=dict( + help="For use with evaluate_flashlight - text file containing vocabulary of scorer, one word per line." 
+ ), + ) + read_buffer: str = field( default="1MB", metadata=dict( From 04f62ac9f7f56f9953b335db9b5e30085e7f5f6b Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Sat, 30 Oct 2021 14:59:22 +0200 Subject: [PATCH 2/4] Exercise training graph inference/Flashlight decoder in extra training tests --- bin/run-ci-ldc93s1_singleshotinference.sh | 12 ++++++++++-- ci_scripts/train-extra-tests.sh | 3 +++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/bin/run-ci-ldc93s1_singleshotinference.sh b/bin/run-ci-ldc93s1_singleshotinference.sh index 699b09cb..3589166b 100755 --- a/bin/run-ci-ldc93s1_singleshotinference.sh +++ b/bin/run-ci-ldc93s1_singleshotinference.sh @@ -14,7 +14,8 @@ fi; # and when trying to run on multiple devices (like GPUs), this will break export CUDA_VISIBLE_DEVICES=0 -python -u train.py --alphabet_config_path "data/alphabet.txt" \ +python -m coqui_stt_training.train \ + --alphabet_config_path "data/alphabet.txt" \ --show_progressbar false --early_stop false \ --train_files ${ldc93s1_csv} --train_batch_size 1 \ --dev_files ${ldc93s1_csv} --dev_batch_size 1 \ @@ -24,8 +25,15 @@ python -u train.py --alphabet_config_path "data/alphabet.txt" \ --learning_rate 0.001 --dropout_rate 0.05 \ --scorer_path 'data/smoke_test/pruned_lm.scorer' -python -u train.py --alphabet_config_path "data/alphabet.txt" \ +python -m coqui_stt_training.training_graph_inference \ --n_hidden 100 \ --checkpoint_dir '/tmp/ckpt' \ --scorer_path 'data/smoke_test/pruned_lm.scorer' \ --one_shot_infer 'data/smoke_test/LDC93S1.wav' + +python -m coqui_stt_training.training_graph_inference_flashlight \ + --n_hidden 100 \ + --checkpoint_dir '/tmp/ckpt' \ + --scorer_path 'data/smoke_test/pruned_lm.scorer' \ + --vocab_file 'data/smoke_test/vocab.pruned.txt' \ + --one_shot_infer 'data/smoke_test/LDC93S1.wav' diff --git a/ci_scripts/train-extra-tests.sh b/ci_scripts/train-extra-tests.sh index 50265afc..f538110d 100755 --- a/ci_scripts/train-extra-tests.sh +++ 
b/ci_scripts/train-extra-tests.sh @@ -69,3 +69,6 @@ time ./bin/run-ci-ldc93s1_checkpoint_bytes.sh # Training with args set via initialize_globals_from_args() time python ./bin/run-ldc93s1.py + +# Training graph inference +time ./bin/run-ci-ldc93s1_singleshotinference.sh From 391036643c723dce42c1dc2861e5dbdf8664ba2d Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Sat, 30 Oct 2021 15:07:05 +0200 Subject: [PATCH 3/4] debug --- .github/workflows/build-and-test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 82f83a5d..0b6ccf1a 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -1713,6 +1713,9 @@ jobs: make -C native_client/ctcdecode/ \ NUM_PROCESSES=$(nproc) \ bindings + - name: Setup tmate session + uses: mxschmitt/action-tmate@v3 + if: failure() - uses: actions/upload-artifact@v2 with: name: "coqui_stt_ctcdecoder-windows-test.whl" From a61180aeaed7aed2c5f15e5672e9f4804bada21e Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Sat, 30 Oct 2021 14:48:32 +0200 Subject: [PATCH 4/4] Fix Flashlight multiplatform build --- .../ctcdecode/ctc_beam_search_decoder.cpp | 27 ++++--- native_client/ctcdecode/scorer.cpp | 1 + .../flashlight/flashlight/lib/common/String.h | 1 + .../flashlight/lib/common/System.cpp | 73 ------------------- .../flashlight/flashlight/lib/common/System.h | 10 --- .../flashlight/lib/text/decoder/Trie.cpp | 1 + 6 files changed, 16 insertions(+), 97 deletions(-) diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index af029b9d..179ec467 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -407,13 +407,12 @@ FlashlightDecoderState::intermediate(bool prune) valid_words.push_back(w); } } - FlashlightOutput ret { - .aggregate_score = result.score, - .acoustic_model_score = 
result.amScore, - .language_model_score = result.lmScore, - .words = lm_tokens_.mapIndicesToEntries(valid_words), // how does this interact with token-based decoding? - .tokens = result.tokens - }; + FlashlightOutput ret; + ret.aggregate_score = result.score; + ret.acoustic_model_score = result.amScore; + ret.language_model_score = result.lmScore; + ret.words = lm_tokens_.mapIndicesToEntries(valid_words); // how does this interact with token-based decoding + ret.tokens = result.tokens; if (prune) { decoder_impl_->prune(); } @@ -433,13 +432,13 @@ FlashlightDecoderState::decode(size_t num_results) valid_words.push_back(w); } } - ret.push_back({ - .aggregate_score = result.score, - .acoustic_model_score = result.amScore, - .language_model_score = result.lmScore, - .words = lm_tokens_.mapIndicesToEntries(valid_words), // how does this interact with token-based decoding? - .tokens = result.tokens - }); + FlashlightOutput out; + out.aggregate_score = result.score; + out.acoustic_model_score = result.amScore; + out.language_model_score = result.lmScore; + out.words = lm_tokens_.mapIndicesToEntries(valid_words); // how does this interact with token-based decoding + out.tokens = result.tokens; + ret.push_back(out); } decoder_impl_.reset(nullptr); return ret; diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index d5ca6cbc..34ad90fb 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -1,6 +1,7 @@ #ifdef _MSC_VER #include #include + #define NOMINMAX #include #define R_OK 4 /* Read permission. 
*/ diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h index 492c710a..9ce89949 100644 --- a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/String.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp index fd89ac58..1afd13ac 100644 --- a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.cpp @@ -7,7 +7,6 @@ #include "flashlight/lib/common/System.h" -#include #include #include #include @@ -119,45 +118,6 @@ bool dirExists(const std::string& path) { } } -void dirCreate(const std::string& path) { - if (dirExists(path)) { - return; - } - mode_t nMode = 0755; - int nError = 0; -#ifdef _WIN32 - nError = _mkdir(path.c_str()); -#else - nError = mkdir(path.c_str(), nMode); -#endif - if (nError != 0) { - throw std::runtime_error( - std::string() + "Unable to create directory - " + path); - } -} - -void dirCreateRecursive(const std::string& path) { - if (dirExists(path)) { - return; - } - std::vector dirsOnPath = getDirsOnPath(path); - std::string pathFromStart; - if (path[0] == pathSeperator()[0]) { - pathFromStart = pathSeperator(); - } - for (std::string& dir : dirsOnPath) { - if (pathFromStart.empty()) { - pathFromStart = dir; - } else { - pathFromStart = pathsConcat(pathFromStart, dir); - } - - if (!dirExists(pathFromStart)) { - dirCreate(pathFromStart); - } - } -} - bool fileExists(const std::string& path) { std::ifstream fs(path, std::ifstream::in); return fs.good(); @@ -170,28 +130,6 @@ std::string getEnvVar( return val ? 
std::string(val) : dflt; } -std::string getCurrentDate() { - time_t now = time(nullptr); - struct tm tmbuf; - struct tm* tstruct; - tstruct = localtime_r(&now, &tmbuf); - - std::array buf; - strftime(buf.data(), buf.size(), "%Y-%m-%d", tstruct); - return std::string(buf.data()); -} - -std::string getCurrentTime() { - time_t now = time(nullptr); - struct tm tmbuf; - struct tm* tstruct; - tstruct = localtime_r(&now, &tmbuf); - - std::array buf; - strftime(buf.data(), buf.size(), "%X", tstruct); - return std::string(buf.data()); -} - std::string getTmpPath(const std::string& filename) { std::string tmpDir = "/tmp"; auto getTmpDir = [&tmpDir](const std::string& env) { @@ -217,17 +155,6 @@ std::vector getFileContent(const std::string& file) { return data; } -std::vector fileGlob(const std::string& pat) { - glob_t result; - glob(pat.c_str(), GLOB_TILDE, nullptr, &result); - std::vector ret; - for (unsigned int i = 0; i < result.gl_pathc; ++i) { - ret.push_back(std::string(result.gl_pathv[i])); - } - globfree(&result); - return ret; -} - std::ifstream createInputStream(const std::string& filename) { std::ifstream file(filename); if (!file.is_open()) { diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h index c63ed1bb..761c9173 100644 --- a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/common/System.h @@ -31,24 +31,14 @@ std::string basename(const std::string& path); bool dirExists(const std::string& path); -void dirCreate(const std::string& path); - -void dirCreateRecursive(const std::string& path); - bool fileExists(const std::string& path); std::string getEnvVar(const std::string& key, const std::string& dflt = ""); -std::string getCurrentDate(); - -std::string getCurrentTime(); - std::string getTmpPath(const std::string& filename); std::vector 
getFileContent(const std::string& file); -std::vector fileGlob(const std::string& pat); - std::ifstream createInputStream(const std::string& filename); std::ofstream createOutputStream( diff --git a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp index df037e1b..2779a5b4 100644 --- a/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp +++ b/native_client/ctcdecode/third_party/flashlight/flashlight/lib/text/decoder/Trie.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "flashlight/lib/text/decoder/Trie.h"