Add tool to extract vocabulary from the old LM binary format
This commit is contained in:
parent
a156d28504
commit
708b21a63e
|
@ -27,20 +27,6 @@ genrule(
|
||||||
tools = [":gen_workspace_status.sh"],
|
tools = [":gen_workspace_status.sh"],
|
||||||
)
|
)
|
||||||
|
|
||||||
KENLM_SOURCES = glob(
|
|
||||||
[
|
|
||||||
"kenlm/lm/*.cc",
|
|
||||||
"kenlm/util/*.cc",
|
|
||||||
"kenlm/util/double-conversion/*.cc",
|
|
||||||
"kenlm/lm/*.hh",
|
|
||||||
"kenlm/util/*.hh",
|
|
||||||
"kenlm/util/double-conversion/*.h",
|
|
||||||
],
|
|
||||||
exclude = [
|
|
||||||
"kenlm/*/*test.cc",
|
|
||||||
"kenlm/*/*main.cc",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
OPENFST_SOURCES_PLATFORM = select({
|
OPENFST_SOURCES_PLATFORM = select({
|
||||||
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
|
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
|
||||||
|
@ -60,6 +46,27 @@ LINUX_LINKOPTS = [
|
||||||
"-Wl,-export-dynamic",
|
"-Wl,-export-dynamic",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# KenLM built as a standalone library target so that several targets in this
# package can depend on it instead of compiling the sources into each of them.
cc_library(
    name = "kenlm",
    srcs = glob(
        include = [
            "kenlm/lm/*.cc",
            "kenlm/util/*.cc",
            "kenlm/util/double-conversion/*.cc",
            "kenlm/util/double-conversion/*.h",
        ],
        # Leave out unit tests and the stand-alone command line entry points.
        exclude = [
            "kenlm/*/*test.cc",
            "kenlm/*/*main.cc",
        ],
    ),
    hdrs = glob(
        include = [
            "kenlm/lm/*.hh",
            "kenlm/util/*.hh",
        ],
    ),
    copts = ["-std=c++11"],
    # `defines` and `includes` are propagated to every target depending on
    # this library.
    defines = ["KENLM_MAX_ORDER=6"],
    includes = ["kenlm"],
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "decoder",
|
name = "decoder",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -69,17 +76,16 @@ cc_library(
|
||||||
"ctcdecode/scorer.cpp",
|
"ctcdecode/scorer.cpp",
|
||||||
"ctcdecode/path_trie.cpp",
|
"ctcdecode/path_trie.cpp",
|
||||||
"ctcdecode/path_trie.h",
|
"ctcdecode/path_trie.h",
|
||||||
] + KENLM_SOURCES + OPENFST_SOURCES_PLATFORM,
|
] + OPENFST_SOURCES_PLATFORM,
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"ctcdecode/ctc_beam_search_decoder.h",
|
"ctcdecode/ctc_beam_search_decoder.h",
|
||||||
"ctcdecode/scorer.h",
|
"ctcdecode/scorer.h",
|
||||||
],
|
],
|
||||||
defines = ["KENLM_MAX_ORDER=6"],
|
|
||||||
includes = [
|
includes = [
|
||||||
".",
|
".",
|
||||||
"ctcdecode/third_party/ThreadPool",
|
"ctcdecode/third_party/ThreadPool",
|
||||||
"kenlm",
|
|
||||||
] + OPENFST_INCLUDES_PLATFORM,
|
] + OPENFST_INCLUDES_PLATFORM,
|
||||||
|
deps = [":kenlm"]
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_shared_object(
|
tf_cc_shared_object(
|
||||||
|
@ -181,6 +187,15 @@ genrule(
|
||||||
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
|
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Command line tool built from enumerate_kenlm_vocabulary.cpp and linked
# against the :kenlm library.
cc_binary(
    name = "enumerate_kenlm_vocabulary",
    srcs = ["enumerate_kenlm_vocabulary.cpp"],
    copts = ["-std=c++11"],
    deps = [":kenlm"],
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "trie_load",
|
name = "trie_load",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
|
||||||
|
|
||||||
|
#include "lm/enumerate_vocab.hh"
|
||||||
|
#include "lm/virtual_interface.hh"
|
||||||
|
#include "lm/word_index.hh"
|
||||||
|
#include "lm/model.hh"
|
||||||
|
|
||||||
|
const std::string START_TOKEN = "<s>";
|
||||||
|
const std::string UNK_TOKEN = "<unk>";
|
||||||
|
const std::string END_TOKEN = "</s>";
|
||||||
|
|
||||||
|
// Implement a callback to retrieve the dictionary of language model.
|
||||||
|
class RetrieveStrEnumerateVocab : public lm::EnumerateVocab
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
RetrieveStrEnumerateVocab() {}
|
||||||
|
|
||||||
|
void Add(lm::WordIndex index, const StringPiece &str) {
|
||||||
|
vocabulary.push_back(std::string(str.data(), str.length()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> vocabulary;
|
||||||
|
};
|
||||||
|
|
||||||
|
int main(int argc, char** argv)
|
||||||
|
{
|
||||||
|
if (argc != 3) {
|
||||||
|
std::cerr << "Usage: " << argv[0] << " <kenlm_model> <output_path>" << std::endl;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* kenlm_model = argv[1];
|
||||||
|
const char* output_path = argv[2];
|
||||||
|
|
||||||
|
std::unique_ptr<lm::base::Model> language_model_;
|
||||||
|
lm::ngram::Config config;
|
||||||
|
RetrieveStrEnumerateVocab enumerate;
|
||||||
|
config.enumerate_vocab = &enumerate;
|
||||||
|
language_model_.reset(lm::ngram::LoadVirtual(kenlm_model, config));
|
||||||
|
|
||||||
|
std::ofstream fout(output_path);
|
||||||
|
for (const std::string& word : enumerate.vocabulary) {
|
||||||
|
fout << word << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
Loading…
Reference in New Issue