Add tool to extract vocabulary from the old LM binary format

This commit is contained in:
Reuben Morais 2020-01-20 17:25:20 +01:00
parent a156d28504
commit 708b21a63e
2 changed files with 82 additions and 17 deletions

View File

@ -27,20 +27,6 @@ genrule(
tools = [":gen_workspace_status.sh"],
)
KENLM_SOURCES = glob(
[
"kenlm/lm/*.cc",
"kenlm/util/*.cc",
"kenlm/util/double-conversion/*.cc",
"kenlm/lm/*.hh",
"kenlm/util/*.hh",
"kenlm/util/double-conversion/*.h",
],
exclude = [
"kenlm/*/*test.cc",
"kenlm/*/*main.cc",
],
)
OPENFST_SOURCES_PLATFORM = select({
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
@ -60,6 +46,27 @@ LINUX_LINKOPTS = [
"-Wl,-export-dynamic",
]
cc_library(
name = "kenlm",
srcs = glob([
"kenlm/lm/*.cc",
"kenlm/util/*.cc",
"kenlm/util/double-conversion/*.cc",
"kenlm/util/double-conversion/*.h",
],
exclude = [
"kenlm/*/*test.cc",
"kenlm/*/*main.cc",
],),
hdrs = glob([
"kenlm/lm/*.hh",
"kenlm/util/*.hh",
]),
copts = ["-std=c++11"],
defines = ["KENLM_MAX_ORDER=6"],
includes = ["kenlm"],
)
cc_library(
name = "decoder",
srcs = [
@ -69,17 +76,16 @@ cc_library(
"ctcdecode/scorer.cpp",
"ctcdecode/path_trie.cpp",
"ctcdecode/path_trie.h",
] + KENLM_SOURCES + OPENFST_SOURCES_PLATFORM,
] + OPENFST_SOURCES_PLATFORM,
hdrs = [
"ctcdecode/ctc_beam_search_decoder.h",
"ctcdecode/scorer.h",
],
defines = ["KENLM_MAX_ORDER=6"],
includes = [
".",
"ctcdecode/third_party/ThreadPool",
"kenlm",
] + OPENFST_INCLUDES_PLATFORM,
deps = [":kenlm"]
)
tf_cc_shared_object(
@ -181,6 +187,15 @@ genrule(
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
)
cc_binary(
name = "enumerate_kenlm_vocabulary",
srcs = [
"enumerate_kenlm_vocabulary.cpp",
],
deps = [":kenlm"],
copts = ["-std=c++11"],
)
cc_binary(
name = "trie_load",
srcs = [

View File

@ -0,0 +1,50 @@
#include <string>
#include <vector>
#include <iostream>
#include <fstream>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "lm/model.hh"
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
// Implement a callback to retrieve the dictionary of language model.
class RetrieveStrEnumerateVocab : public lm::EnumerateVocab
{
public:
RetrieveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
int main(int argc, char** argv)
{
if (argc != 3) {
std::cerr << "Usage: " << argv[0] << " <kenlm_model> <output_path>" << std::endl;
return -1;
}
const char* kenlm_model = argv[1];
const char* output_path = argv[2];
std::unique_ptr<lm::base::Model> language_model_;
lm::ngram::Config config;
RetrieveStrEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
language_model_.reset(lm::ngram::LoadVirtual(kenlm_model, config));
std::ofstream fout(output_path);
for (const std::string& word : enumerate.vocabulary) {
fout << word << "\n";
}
return 0;
}