Rewrite data/lm/generate_package.py into native_client/generate_scorer_package.cpp

This commit is contained in:
Reuben Morais 2020-06-26 10:27:35 +02:00
parent 03ca94887c
commit f82c77392d
6 changed files with 211 additions and 161 deletions

View File

@ -1,157 +0,0 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
import argparse
import shutil
import sys
import ds_ctcdecoder
from deepspeech_training.util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
def create_bundle(
alphabet_path,
lm_path,
vocab_path,
package_path,
force_utf8,
default_alpha,
default_beta,
):
words = set()
vocab_looks_char_based = True
with open(vocab_path) as fin:
for line in fin:
for word in line.split():
words.add(word.encode("utf-8"))
if len(word) > 1:
vocab_looks_char_based = False
print("{} unique words read from vocabulary file.".format(len(words)))
cbm = "Looks" if vocab_looks_char_based else "Doesn't look"
print("{} like a character based model.".format(cbm))
if force_utf8 != None: # pylint: disable=singleton-comparison
use_utf8 = force_utf8.value
else:
use_utf8 = vocab_looks_char_based
print("Using detected UTF-8 mode: {}".format(use_utf8))
if use_utf8:
serialized_alphabet = UTF8Alphabet().serialize()
else:
if not alphabet_path:
raise RuntimeError("No --alphabet path specified, can't continue.")
serialized_alphabet = Alphabet(alphabet_path).serialize()
alphabet = NativeAlphabet()
err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
if err != 0:
raise RuntimeError("Error loading alphabet: {}".format(err))
scorer = Scorer()
scorer.set_alphabet(alphabet)
scorer.set_utf8_mode(use_utf8)
scorer.reset_params(default_alpha, default_beta)
err = scorer.load_lm(lm_path)
if err != ds_ctcdecoder.DS_ERR_SCORER_NO_TRIE:
print('Error loading language model file: 0x{:X}.'.format(err))
print('See the error codes section in https://deepspeech.readthedocs.io for a description.')
sys.exit(1)
scorer.fill_dictionary(list(words))
shutil.copy(lm_path, package_path)
# append, not overwrite
if scorer.save_dictionary(package_path, True):
print("Package created in {}".format(package_path))
else:
print("Error when creating {}".format(package_path))
sys.exit(1)
class Tristate(object):
def __init__(self, value=None):
if any(value is v for v in (True, False, None)):
self.value = value
else:
raise ValueError("Tristate value must be True, False, or None")
def __eq__(self, other):
return (
self.value is other.value
if isinstance(other, Tristate)
else self.value is other
)
def __ne__(self, other):
return not self == other
def __bool__(self):
raise TypeError("Tristate object may not be used as a Boolean")
def __str__(self):
return str(self.value)
def __repr__(self):
return "Tristate(%s)" % self.value
def main():
parser = argparse.ArgumentParser(
description="Generate an external scorer package for DeepSpeech."
)
parser.add_argument(
"--alphabet",
help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
)
parser.add_argument(
"--lm",
required=True,
help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
)
parser.add_argument(
"--vocab",
required=True,
help="Path of vocabulary file. Must contain words separated by whitespace.",
)
parser.add_argument("--package", required=True, help="Path to save scorer package.")
parser.add_argument(
"--default_alpha",
type=float,
required=True,
help="Default value of alpha hyperparameter.",
)
parser.add_argument(
"--default_beta",
type=float,
required=True,
help="Default value of beta hyperparameter.",
)
parser.add_argument(
"--force_utf8",
type=str,
default="",
help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://github.com/mozilla/DeepSpeech/blob/master/doc/Decoder.rst#utf-8-mode> for further explanation",
)
args = parser.parse_args()
if args.force_utf8 in ("True", "1", "true", "yes", "y"):
force_utf8 = Tristate(True)
elif args.force_utf8 in ("False", "0", "false", "no", "n"):
force_utf8 = Tristate(False)
else:
force_utf8 = Tristate(None)
create_bundle(
args.alphabet,
args.lm,
args.vocab,
args.package,
force_utf8,
args.default_alpha,
args.default_beta,
)
if __name__ == "__main__":
main()

View File

@ -2,6 +2,7 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
load(
"@org_tensorflow//tensorflow/lite:build_def.bzl",
@ -78,6 +79,8 @@ cc_library(
hdrs = [
"ctcdecode/ctc_beam_search_decoder.h",
"ctcdecode/scorer.h",
"ctcdecode/decoder_utils.h",
"alphabet.h",
],
includes = [
".",
@ -186,6 +189,22 @@ genrule(
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
)
cc_binary(
name = "generate_scorer_package",
srcs = [
"generate_scorer_package.cpp",
"deepspeech_errors.cc",
],
copts = ["-std=c++11"],
deps = [
":decoder",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/types:optional",
"@boost//:program_options",
],
)
cc_binary(
name = "enumerate_kenlm_vocabulary",
srcs = [

View File

@ -19,7 +19,7 @@ public:
Alphabet(const Alphabet&) = default;
Alphabet& operator=(const Alphabet&) = default;
int init(const char *config_file) {
virtual int init(const char *config_file) {
std::ifstream in(config_file, std::ios::in);
if (!in) {
return 1;
@ -45,6 +45,30 @@ public:
return 0;
}
std::string serialize() {
// Serialization format is a sequence of (key, value) pairs, where key is
// a uint16_t and value is a uint16_t length followed by `length` UTF-8
// encoded bytes with the label.
std::stringstream out;
// We start by writing the number of pairs in the buffer as uint16_t.
uint16_t size = size_;
out.write(reinterpret_cast<char*>(&size), sizeof(size));
for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
uint16_t key = it->first;
string str = it->second;
uint16_t len = str.length();
// Then we write the key as uint16_t, followed by the length of the value
// as uint16_t, followed by `length` bytes (the value itself).
out.write(reinterpret_cast<char*>(&key), sizeof(key));
out.write(reinterpret_cast<char*>(&len), sizeof(len));
out.write(str.data(), len);
}
return out.str();
}
int deserialize(const char* buffer, const int buffer_size) {
// See util/text.py for an explanation of the serialization format.
int offset = 0;
@ -126,11 +150,28 @@ public:
return word;
}
private:
protected:
size_t size_;
unsigned int space_label_;
std::unordered_map<unsigned int, std::string> label_to_str_;
std::unordered_map<std::string, unsigned int> str_to_label_;
};
class UTF8Alphabet : public Alphabet
{
public:
UTF8Alphabet() {
size_ = 255;
space_label_ = ' ' - 1;
for (int i = 0; i < size_; ++i) {
std::string val(1, i+1);
label_to_str_[i] = val;
str_to_label_[val] = i;
}
}
int init(const char*) override {}
};
#endif //ALPHABET_H

View File

@ -357,7 +357,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
return ngram;
}
void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary)
void Scorer::fill_dictionary(const std::unordered_set<std::string>& vocabulary)
{
// ConstFst is immutable, so we need to use a MutableFst to create the trie,
// and then we convert to a ConstFst for the decoder and for storing on disk.

View File

@ -4,6 +4,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lm/virtual_interface.hh"
@ -83,7 +84,7 @@ public:
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
// fill dictionary FST from a vocabulary
void fill_dictionary(const std::vector<std::string> &vocabulary);
void fill_dictionary(const std::unordered_set<std::string> &vocabulary);
// load language model from given path
int load_lm(const std::string &lm_path);

View File

@ -0,0 +1,146 @@
#include <string>
#include <vector>
#include <fstream>
#include <unordered_set>
#include <iostream>
using namespace std;
#include "absl/types/optional.h"
#include "boost/program_options.hpp"
#include "ctcdecode/decoder_utils.h"
#include "ctcdecode/scorer.h"
#include "alphabet.h"
#include "deepspeech.h"
namespace po = boost::program_options;
int
create_package(absl::optional<string> alphabet_path,
string lm_path,
string vocab_path,
string package_path,
absl::optional<bool> force_utf8,
float default_alpha,
float default_beta)
{
// Read vocabulary
unordered_set<string> words;
bool vocab_looks_char_based = true;
ifstream fin(vocab_path);
if (!fin) {
cerr << "Invalid vocabulary file " << vocab_path << "\n";
return 1;
}
string word;
while (fin >> word) {
words.insert(word);
if (get_utf8_str_len(word) > 1) {
vocab_looks_char_based = false;
}
}
cerr << words.size() << " unique words read from vocabulary file.\n"
<< (vocab_looks_char_based ? "Looks" : "Doesn't look")
<< " like a character based (Bytes Are All You Need) model.\n";
if (!force_utf8.has_value()) {
force_utf8 = vocab_looks_char_based;
cerr << "--force_utf8 was not specified, using value "
<< "infered from vocabulary contents: "
<< (vocab_looks_char_based ? "true" : "false") << "\n";
}
if (force_utf8.value() && !alphabet_path.has_value()) {
cerr << "No --alphabet file specified, not using bytes output mode, can't continue.\n";
return 1;
}
Scorer scorer;
if (force_utf8.value()) {
scorer.set_alphabet(UTF8Alphabet());
} else {
Alphabet alphabet;
alphabet.init(alphabet_path->c_str());
scorer.set_alphabet(alphabet);
}
scorer.set_utf8_mode(force_utf8.value());
scorer.reset_params(default_alpha, default_beta);
int err = scorer.load_lm(lm_path);
if (err != DS_ERR_SCORER_NO_TRIE) {
cerr << "Error loading language model file: "
<< DS_ErrorCodeToErrorMessage(err) << "\n";
return 1;
}
scorer.fill_dictionary(words);
// Copy LM file to final package file destination
{
ifstream lm_src(lm_path, std::ios::binary);
ofstream package_dest(package_path, std::ios::binary);
package_dest << lm_src.rdbuf();
}
// Save dictionary to package file, appending instead of overwriting
if (!scorer.save_dictionary(package_path, true)) {
cerr << "Error when saving package in " << package_path << ".\n";
return 1;
}
cerr << "Package created in " << package_path << ".\n";
return 0;
}
int
main(int argc, char** argv)
{
po::options_description desc("Options");
desc.add_options()
("help", "show help message")
("alphabet", po::value<string>(), "Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.")
("lm", po::value<string>(), "Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.")
("vocab", po::value<string>(), "Path of vocabulary file. Must contain words separated by whitespace.")
("package", po::value<string>(), "Path to save scorer package.")
("default_alpha", po::value<float>(), "Default value of alpha hyperparameter (float).")
("default_beta", po::value<float>(), "Default value of beta hyperparameter (float).")
("force_utf8", po::value<bool>(), "Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://deepspeech.readthedocs.io/en/master/Decoder.html#utf-8-mode> for further explanation.")
;
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
po::notify(vm);
if (vm.count("help")) {
cout << desc << "\n";
return 1;
}
// Check required flags.
for (const string& flag : {"lm", "vocab", "package", "default_alpha", "default_beta"}) {
if (!vm.count(flag)) {
cerr << "--" << flag << " is a required flag. Pass --help for help.\n";
return 1;
}
}
// Parse optional --force_utf8
absl::optional<bool> force_utf8 = absl::nullopt;
if (vm.count("force_utf8")) {
force_utf8 = vm["force_utf8"].as<bool>();
}
// Parse optional --alphabet
absl::optional<string> alphabet = absl::nullopt;
if (vm.count("alphabet")) {
alphabet = vm["alphabet"].as<string>();
}
create_package(alphabet,
vm["lm"].as<string>(),
vm["vocab"].as<string>(),
vm["package"].as<string>(),
force_utf8,
vm["default_alpha"].as<float>(),
vm["default_beta"].as<float>());
return 0;
}