Add tool to extract vocabulary from the old LM binary format
This commit is contained in:
parent
a156d28504
commit
708b21a63e
|
@ -27,20 +27,6 @@ genrule(
|
|||
tools = [":gen_workspace_status.sh"],
|
||||
)
|
||||
|
||||
KENLM_SOURCES = glob(
|
||||
[
|
||||
"kenlm/lm/*.cc",
|
||||
"kenlm/util/*.cc",
|
||||
"kenlm/util/double-conversion/*.cc",
|
||||
"kenlm/lm/*.hh",
|
||||
"kenlm/util/*.hh",
|
||||
"kenlm/util/double-conversion/*.h",
|
||||
],
|
||||
exclude = [
|
||||
"kenlm/*/*test.cc",
|
||||
"kenlm/*/*main.cc",
|
||||
],
|
||||
)
|
||||
|
||||
OPENFST_SOURCES_PLATFORM = select({
|
||||
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
|
||||
|
@ -60,6 +46,27 @@ LINUX_LINKOPTS = [
|
|||
"-Wl,-export-dynamic",
|
||||
]
|
||||
|
||||
# KenLM built as a standalone library so that several targets can depend on
# it instead of each compiling the KenLM sources themselves.
cc_library(
    name = "kenlm",
    # Implementation files plus the double-conversion headers; unit tests and
    # the stand-alone *main.cc programs are left out.
    srcs = glob(
        [
            "kenlm/lm/*.cc",
            "kenlm/util/*.cc",
            "kenlm/util/double-conversion/*.cc",
            "kenlm/util/double-conversion/*.h",
        ],
        exclude = [
            "kenlm/*/*test.cc",
            "kenlm/*/*main.cc",
        ],
    ),
    # Headers exposed to targets that depend on this library.
    hdrs = glob([
        "kenlm/lm/*.hh",
        "kenlm/util/*.hh",
    ]),
    copts = ["-std=c++11"],
    defines = ["KENLM_MAX_ORDER=6"],
    # Lets sources include KenLM headers as "lm/..." and "util/...".
    includes = ["kenlm"],
)
|
||||
|
||||
cc_library(
|
||||
name = "decoder",
|
||||
srcs = [
|
||||
|
@ -69,17 +76,16 @@ cc_library(
|
|||
"ctcdecode/scorer.cpp",
|
||||
"ctcdecode/path_trie.cpp",
|
||||
"ctcdecode/path_trie.h",
|
||||
] + KENLM_SOURCES + OPENFST_SOURCES_PLATFORM,
|
||||
] + OPENFST_SOURCES_PLATFORM,
|
||||
hdrs = [
|
||||
"ctcdecode/ctc_beam_search_decoder.h",
|
||||
"ctcdecode/scorer.h",
|
||||
],
|
||||
defines = ["KENLM_MAX_ORDER=6"],
|
||||
includes = [
|
||||
".",
|
||||
"ctcdecode/third_party/ThreadPool",
|
||||
"kenlm",
|
||||
] + OPENFST_INCLUDES_PLATFORM,
|
||||
deps = [":kenlm"]
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
|
@ -181,6 +187,15 @@ genrule(
|
|||
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
|
||||
)
|
||||
|
||||
# Command-line tool compiled from enumerate_kenlm_vocabulary.cpp and linked
# against the :kenlm library.
cc_binary(
    name = "enumerate_kenlm_vocabulary",
    srcs = ["enumerate_kenlm_vocabulary.cpp"],
    copts = ["-std=c++11"],
    deps = [":kenlm"],
)
|
||||
|
||||
cc_binary(
|
||||
name = "trie_load",
|
||||
srcs = [
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
#include <cstdlib>
#include <exception>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "lm/model.hh"
|
||||
|
||||
// String constants holding the "<s>", "<unk>" and "</s>" token spellings.
// NOTE(review): nothing in the visible part of this file references these
// constants -- confirm whether they are still needed.
const std::string START_TOKEN = "<s>";
|
||||
const std::string UNK_TOKEN = "<unk>";
|
||||
const std::string END_TOKEN = "</s>";
|
||||
|
||||
// Implement a callback to retrieve the dictionary of language model.
|
||||
class RetrieveStrEnumerateVocab : public lm::EnumerateVocab
|
||||
{
|
||||
public:
|
||||
RetrieveStrEnumerateVocab() {}
|
||||
|
||||
void Add(lm::WordIndex index, const StringPiece &str) {
|
||||
vocabulary.push_back(std::string(str.data(), str.length()));
|
||||
}
|
||||
|
||||
std::vector<std::string> vocabulary;
|
||||
};
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
if (argc != 3) {
|
||||
std::cerr << "Usage: " << argv[0] << " <kenlm_model> <output_path>" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
const char* kenlm_model = argv[1];
|
||||
const char* output_path = argv[2];
|
||||
|
||||
std::unique_ptr<lm::base::Model> language_model_;
|
||||
lm::ngram::Config config;
|
||||
RetrieveStrEnumerateVocab enumerate;
|
||||
config.enumerate_vocab = &enumerate;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(kenlm_model, config));
|
||||
|
||||
std::ofstream fout(output_path);
|
||||
for (const std::string& word : enumerate.vocabulary) {
|
||||
fout << word << "\n";
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue