diff --git a/native_client/BUILD b/native_client/BUILD index 250bc450..1e4a66eb 100644 --- a/native_client/BUILD +++ b/native_client/BUILD @@ -27,20 +27,6 @@ genrule( tools = [":gen_workspace_status.sh"], ) -KENLM_SOURCES = glob( - [ - "kenlm/lm/*.cc", - "kenlm/util/*.cc", - "kenlm/util/double-conversion/*.cc", - "kenlm/lm/*.hh", - "kenlm/util/*.hh", - "kenlm/util/double-conversion/*.h", - ], - exclude = [ - "kenlm/*/*test.cc", - "kenlm/*/*main.cc", - ], -) OPENFST_SOURCES_PLATFORM = select({ "//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]), @@ -60,6 +46,27 @@ LINUX_LINKOPTS = [ "-Wl,-export-dynamic", ] +cc_library( + name = "kenlm", + srcs = glob([ + "kenlm/lm/*.cc", + "kenlm/util/*.cc", + "kenlm/util/double-conversion/*.cc", + "kenlm/util/double-conversion/*.h", + ], + exclude = [ + "kenlm/*/*test.cc", + "kenlm/*/*main.cc", + ],), + hdrs = glob([ + "kenlm/lm/*.hh", + "kenlm/util/*.hh", + ]), + copts = ["-std=c++11"], + defines = ["KENLM_MAX_ORDER=6"], + includes = ["kenlm"], +) + cc_library( name = "decoder", srcs = [ @@ -69,17 +76,16 @@ cc_library( "ctcdecode/scorer.cpp", "ctcdecode/path_trie.cpp", "ctcdecode/path_trie.h", - ] + KENLM_SOURCES + OPENFST_SOURCES_PLATFORM, + ] + OPENFST_SOURCES_PLATFORM, hdrs = [ "ctcdecode/ctc_beam_search_decoder.h", "ctcdecode/scorer.h", ], - defines = ["KENLM_MAX_ORDER=6"], includes = [ ".", "ctcdecode/third_party/ThreadPool", - "kenlm", ] + OPENFST_INCLUDES_PLATFORM, + deps = [":kenlm"] ) tf_cc_shared_object( @@ -181,6 +187,15 @@ genrule( cmd = "dsymutil $(location :libdeepspeech.so) -o $@" ) +cc_binary( + name = "enumerate_kenlm_vocabulary", + srcs = [ + "enumerate_kenlm_vocabulary.cpp", + ], + deps = [":kenlm"], + copts = ["-std=c++11"], +) + cc_binary( name = "trie_load", srcs = [ diff --git a/native_client/enumerate_kenlm_vocabulary.cpp b/native_client/enumerate_kenlm_vocabulary.cpp new file mode 100644 index 00000000..79a8cab6 --- /dev/null +++ b/native_client/enumerate_kenlm_vocabulary.cpp @@ -0,0 +1,50 @@ +#include +#include +#include +#include + +#include "lm/enumerate_vocab.hh" +#include "lm/virtual_interface.hh" +#include "lm/word_index.hh" +#include "lm/model.hh" + +const std::string START_TOKEN = ""; +const std::string UNK_TOKEN = ""; +const std::string END_TOKEN = ""; + +// Implement a callback to retrieve the dictionary of language model. +class RetrieveStrEnumerateVocab : public lm::EnumerateVocab +{ +public: + RetrieveStrEnumerateVocab() {} + + void Add(lm::WordIndex index, const StringPiece &str) { + vocabulary.push_back(std::string(str.data(), str.length())); + } + + std::vector vocabulary; +}; + +int main(int argc, char** argv) +{ + if (argc != 3) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return -1; + } + + const char* kenlm_model = argv[1]; + const char* output_path = argv[2]; + + std::unique_ptr language_model_; + lm::ngram::Config config; + RetrieveStrEnumerateVocab enumerate; + config.enumerate_vocab = &enumerate; + language_model_.reset(lm::ngram::LoadVirtual(kenlm_model, config)); + + std::ofstream fout(output_path); + for (const std::string& word : enumerate.vocabulary) { + fout << word << "\n"; + } + + return 0; +}