Rewrite data/lm/generate_package.py into native_client/generate_scorer_package.cpp
This commit is contained in:
parent
03ca94887c
commit
f82c77392d
@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import ds_ctcdecoder
|
||||
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
|
||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||
|
||||
|
||||
def create_bundle(
|
||||
alphabet_path,
|
||||
lm_path,
|
||||
vocab_path,
|
||||
package_path,
|
||||
force_utf8,
|
||||
default_alpha,
|
||||
default_beta,
|
||||
):
|
||||
words = set()
|
||||
vocab_looks_char_based = True
|
||||
with open(vocab_path) as fin:
|
||||
for line in fin:
|
||||
for word in line.split():
|
||||
words.add(word.encode("utf-8"))
|
||||
if len(word) > 1:
|
||||
vocab_looks_char_based = False
|
||||
print("{} unique words read from vocabulary file.".format(len(words)))
|
||||
|
||||
cbm = "Looks" if vocab_looks_char_based else "Doesn't look"
|
||||
print("{} like a character based model.".format(cbm))
|
||||
|
||||
if force_utf8 != None: # pylint: disable=singleton-comparison
|
||||
use_utf8 = force_utf8.value
|
||||
else:
|
||||
use_utf8 = vocab_looks_char_based
|
||||
print("Using detected UTF-8 mode: {}".format(use_utf8))
|
||||
|
||||
if use_utf8:
|
||||
serialized_alphabet = UTF8Alphabet().serialize()
|
||||
else:
|
||||
if not alphabet_path:
|
||||
raise RuntimeError("No --alphabet path specified, can't continue.")
|
||||
serialized_alphabet = Alphabet(alphabet_path).serialize()
|
||||
|
||||
alphabet = NativeAlphabet()
|
||||
err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
|
||||
if err != 0:
|
||||
raise RuntimeError("Error loading alphabet: {}".format(err))
|
||||
|
||||
scorer = Scorer()
|
||||
scorer.set_alphabet(alphabet)
|
||||
scorer.set_utf8_mode(use_utf8)
|
||||
scorer.reset_params(default_alpha, default_beta)
|
||||
err = scorer.load_lm(lm_path)
|
||||
if err != ds_ctcdecoder.DS_ERR_SCORER_NO_TRIE:
|
||||
print('Error loading language model file: 0x{:X}.'.format(err))
|
||||
print('See the error codes section in https://deepspeech.readthedocs.io for a description.')
|
||||
sys.exit(1)
|
||||
scorer.fill_dictionary(list(words))
|
||||
shutil.copy(lm_path, package_path)
|
||||
# append, not overwrite
|
||||
if scorer.save_dictionary(package_path, True):
|
||||
print("Package created in {}".format(package_path))
|
||||
else:
|
||||
print("Error when creating {}".format(package_path))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class Tristate(object):
|
||||
def __init__(self, value=None):
|
||||
if any(value is v for v in (True, False, None)):
|
||||
self.value = value
|
||||
else:
|
||||
raise ValueError("Tristate value must be True, False, or None")
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.value is other.value
|
||||
if isinstance(other, Tristate)
|
||||
else self.value is other
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
|
||||
def __bool__(self):
|
||||
raise TypeError("Tristate object may not be used as a Boolean")
|
||||
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return "Tristate(%s)" % self.value
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate an external scorer package for DeepSpeech."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alphabet",
|
||||
help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
required=True,
|
||||
help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocab",
|
||||
required=True,
|
||||
help="Path of vocabulary file. Must contain words separated by whitespace.",
|
||||
)
|
||||
parser.add_argument("--package", required=True, help="Path to save scorer package.")
|
||||
parser.add_argument(
|
||||
"--default_alpha",
|
||||
type=float,
|
||||
required=True,
|
||||
help="Default value of alpha hyperparameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default_beta",
|
||||
type=float,
|
||||
required=True,
|
||||
help="Default value of beta hyperparameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_utf8",
|
||||
type=str,
|
||||
default="",
|
||||
help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://github.com/mozilla/DeepSpeech/blob/master/doc/Decoder.rst#utf-8-mode> for further explanation",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.force_utf8 in ("True", "1", "true", "yes", "y"):
|
||||
force_utf8 = Tristate(True)
|
||||
elif args.force_utf8 in ("False", "0", "false", "no", "n"):
|
||||
force_utf8 = Tristate(False)
|
||||
else:
|
||||
force_utf8 = Tristate(None)
|
||||
|
||||
create_bundle(
|
||||
args.alphabet,
|
||||
args.lm,
|
||||
args.vocab,
|
||||
args.package,
|
||||
force_utf8,
|
||||
args.default_alpha,
|
||||
args.default_beta,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -2,6 +2,7 @@
|
||||
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
|
||||
|
||||
load(
|
||||
"@org_tensorflow//tensorflow/lite:build_def.bzl",
|
||||
@ -78,6 +79,8 @@ cc_library(
|
||||
hdrs = [
|
||||
"ctcdecode/ctc_beam_search_decoder.h",
|
||||
"ctcdecode/scorer.h",
|
||||
"ctcdecode/decoder_utils.h",
|
||||
"alphabet.h",
|
||||
],
|
||||
includes = [
|
||||
".",
|
||||
@ -186,6 +189,22 @@ genrule(
|
||||
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "generate_scorer_package",
|
||||
srcs = [
|
||||
"generate_scorer_package.cpp",
|
||||
"deepspeech_errors.cc",
|
||||
],
|
||||
copts = ["-std=c++11"],
|
||||
deps = [
|
||||
":decoder",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/flags:parse",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@boost//:program_options",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "enumerate_kenlm_vocabulary",
|
||||
srcs = [
|
||||
|
@ -19,7 +19,7 @@ public:
|
||||
Alphabet(const Alphabet&) = default;
|
||||
Alphabet& operator=(const Alphabet&) = default;
|
||||
|
||||
int init(const char *config_file) {
|
||||
virtual int init(const char *config_file) {
|
||||
std::ifstream in(config_file, std::ios::in);
|
||||
if (!in) {
|
||||
return 1;
|
||||
@ -45,6 +45,30 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string serialize() {
|
||||
// Serialization format is a sequence of (key, value) pairs, where key is
|
||||
// a uint16_t and value is a uint16_t length followed by `length` UTF-8
|
||||
// encoded bytes with the label.
|
||||
std::stringstream out;
|
||||
|
||||
// We start by writing the number of pairs in the buffer as uint16_t.
|
||||
uint16_t size = size_;
|
||||
out.write(reinterpret_cast<char*>(&size), sizeof(size));
|
||||
|
||||
for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
|
||||
uint16_t key = it->first;
|
||||
string str = it->second;
|
||||
uint16_t len = str.length();
|
||||
// Then we write the key as uint16_t, followed by the length of the value
|
||||
// as uint16_t, followed by `length` bytes (the value itself).
|
||||
out.write(reinterpret_cast<char*>(&key), sizeof(key));
|
||||
out.write(reinterpret_cast<char*>(&len), sizeof(len));
|
||||
out.write(str.data(), len);
|
||||
}
|
||||
|
||||
return out.str();
|
||||
}
|
||||
|
||||
int deserialize(const char* buffer, const int buffer_size) {
|
||||
// See util/text.py for an explanation of the serialization format.
|
||||
int offset = 0;
|
||||
@ -126,11 +150,28 @@ public:
|
||||
return word;
|
||||
}
|
||||
|
||||
private:
|
||||
protected:
|
||||
size_t size_;
|
||||
unsigned int space_label_;
|
||||
std::unordered_map<unsigned int, std::string> label_to_str_;
|
||||
std::unordered_map<std::string, unsigned int> str_to_label_;
|
||||
};
|
||||
|
||||
class UTF8Alphabet : public Alphabet
|
||||
{
|
||||
public:
|
||||
UTF8Alphabet() {
|
||||
size_ = 255;
|
||||
space_label_ = ' ' - 1;
|
||||
for (int i = 0; i < size_; ++i) {
|
||||
std::string val(1, i+1);
|
||||
label_to_str_[i] = val;
|
||||
str_to_label_[val] = i;
|
||||
}
|
||||
}
|
||||
|
||||
int init(const char*) override {}
|
||||
};
|
||||
|
||||
|
||||
#endif //ALPHABET_H
|
||||
|
@ -357,7 +357,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
||||
return ngram;
|
||||
}
|
||||
|
||||
void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary)
|
||||
void Scorer::fill_dictionary(const std::unordered_set<std::string>& vocabulary)
|
||||
{
|
||||
// ConstFst is immutable, so we need to use a MutableFst to create the trie,
|
||||
// and then we convert to a ConstFst for the decoder and for storing on disk.
|
||||
|
@ -4,6 +4,7 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "lm/virtual_interface.hh"
|
||||
@ -83,7 +84,7 @@ public:
|
||||
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
||||
|
||||
// fill dictionary FST from a vocabulary
|
||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||
void fill_dictionary(const std::unordered_set<std::string> &vocabulary);
|
||||
|
||||
// load language model from given path
|
||||
int load_lm(const std::string &lm_path);
|
||||
|
146
native_client/generate_scorer_package.cpp
Normal file
146
native_client/generate_scorer_package.cpp
Normal file
@ -0,0 +1,146 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "boost/program_options.hpp"
|
||||
|
||||
#include "ctcdecode/decoder_utils.h"
|
||||
#include "ctcdecode/scorer.h"
|
||||
#include "alphabet.h"
|
||||
#include "deepspeech.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
int
|
||||
create_package(absl::optional<string> alphabet_path,
|
||||
string lm_path,
|
||||
string vocab_path,
|
||||
string package_path,
|
||||
absl::optional<bool> force_utf8,
|
||||
float default_alpha,
|
||||
float default_beta)
|
||||
{
|
||||
// Read vocabulary
|
||||
unordered_set<string> words;
|
||||
bool vocab_looks_char_based = true;
|
||||
ifstream fin(vocab_path);
|
||||
if (!fin) {
|
||||
cerr << "Invalid vocabulary file " << vocab_path << "\n";
|
||||
return 1;
|
||||
}
|
||||
string word;
|
||||
while (fin >> word) {
|
||||
words.insert(word);
|
||||
if (get_utf8_str_len(word) > 1) {
|
||||
vocab_looks_char_based = false;
|
||||
}
|
||||
}
|
||||
cerr << words.size() << " unique words read from vocabulary file.\n"
|
||||
<< (vocab_looks_char_based ? "Looks" : "Doesn't look")
|
||||
<< " like a character based (Bytes Are All You Need) model.\n";
|
||||
|
||||
if (!force_utf8.has_value()) {
|
||||
force_utf8 = vocab_looks_char_based;
|
||||
cerr << "--force_utf8 was not specified, using value "
|
||||
<< "infered from vocabulary contents: "
|
||||
<< (vocab_looks_char_based ? "true" : "false") << "\n";
|
||||
}
|
||||
|
||||
if (force_utf8.value() && !alphabet_path.has_value()) {
|
||||
cerr << "No --alphabet file specified, not using bytes output mode, can't continue.\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
Scorer scorer;
|
||||
if (force_utf8.value()) {
|
||||
scorer.set_alphabet(UTF8Alphabet());
|
||||
} else {
|
||||
Alphabet alphabet;
|
||||
alphabet.init(alphabet_path->c_str());
|
||||
scorer.set_alphabet(alphabet);
|
||||
}
|
||||
scorer.set_utf8_mode(force_utf8.value());
|
||||
scorer.reset_params(default_alpha, default_beta);
|
||||
int err = scorer.load_lm(lm_path);
|
||||
if (err != DS_ERR_SCORER_NO_TRIE) {
|
||||
cerr << "Error loading language model file: "
|
||||
<< DS_ErrorCodeToErrorMessage(err) << "\n";
|
||||
return 1;
|
||||
}
|
||||
scorer.fill_dictionary(words);
|
||||
|
||||
// Copy LM file to final package file destination
|
||||
{
|
||||
ifstream lm_src(lm_path, std::ios::binary);
|
||||
ofstream package_dest(package_path, std::ios::binary);
|
||||
package_dest << lm_src.rdbuf();
|
||||
}
|
||||
|
||||
// Save dictionary to package file, appending instead of overwriting
|
||||
if (!scorer.save_dictionary(package_path, true)) {
|
||||
cerr << "Error when saving package in " << package_path << ".\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
cerr << "Package created in " << package_path << ".\n";
|
||||
return 0;
|
||||
}
|
||||
|
||||
int
|
||||
main(int argc, char** argv)
|
||||
{
|
||||
po::options_description desc("Options");
|
||||
desc.add_options()
|
||||
("help", "show help message")
|
||||
("alphabet", po::value<string>(), "Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.")
|
||||
("lm", po::value<string>(), "Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.")
|
||||
("vocab", po::value<string>(), "Path of vocabulary file. Must contain words separated by whitespace.")
|
||||
("package", po::value<string>(), "Path to save scorer package.")
|
||||
("default_alpha", po::value<float>(), "Default value of alpha hyperparameter (float).")
|
||||
("default_beta", po::value<float>(), "Default value of beta hyperparameter (float).")
|
||||
("force_utf8", po::value<bool>(), "Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://deepspeech.readthedocs.io/en/master/Decoder.html#utf-8-mode> for further explanation.")
|
||||
;
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
po::notify(vm);
|
||||
|
||||
if (vm.count("help")) {
|
||||
cout << desc << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Check required flags.
|
||||
for (const string& flag : {"lm", "vocab", "package", "default_alpha", "default_beta"}) {
|
||||
if (!vm.count(flag)) {
|
||||
cerr << "--" << flag << " is a required flag. Pass --help for help.\n";
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Parse optional --force_utf8
|
||||
absl::optional<bool> force_utf8 = absl::nullopt;
|
||||
if (vm.count("force_utf8")) {
|
||||
force_utf8 = vm["force_utf8"].as<bool>();
|
||||
}
|
||||
|
||||
// Parse optional --alphabet
|
||||
absl::optional<string> alphabet = absl::nullopt;
|
||||
if (vm.count("alphabet")) {
|
||||
alphabet = vm["alphabet"].as<string>();
|
||||
}
|
||||
|
||||
create_package(alphabet,
|
||||
vm["lm"].as<string>(),
|
||||
vm["vocab"].as<string>(),
|
||||
vm["package"].as<string>(),
|
||||
force_utf8,
|
||||
vm["default_alpha"].as<float>(),
|
||||
vm["default_beta"].as<float>());
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user