Rewrite data/lm/generate_package.py into native_client/generate_scorer_package.cpp
parent 03ca94887c
commit f82c77392d

data/lm/generate_package.py (deleted, 157 lines)
@@ -1,157 +0,0 @@
-#!/usr/bin/env python
-from __future__ import absolute_import, division, print_function
-
-import argparse
-import shutil
-import sys
-
-import ds_ctcdecoder
-from deepspeech_training.util.text import Alphabet, UTF8Alphabet
-from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
-
-
-def create_bundle(
-    alphabet_path,
-    lm_path,
-    vocab_path,
-    package_path,
-    force_utf8,
-    default_alpha,
-    default_beta,
-):
-    words = set()
-    vocab_looks_char_based = True
-    with open(vocab_path) as fin:
-        for line in fin:
-            for word in line.split():
-                words.add(word.encode("utf-8"))
-                if len(word) > 1:
-                    vocab_looks_char_based = False
-    print("{} unique words read from vocabulary file.".format(len(words)))
-
-    cbm = "Looks" if vocab_looks_char_based else "Doesn't look"
-    print("{} like a character based model.".format(cbm))
-
-    if force_utf8 != None:  # pylint: disable=singleton-comparison
-        use_utf8 = force_utf8.value
-    else:
-        use_utf8 = vocab_looks_char_based
-        print("Using detected UTF-8 mode: {}".format(use_utf8))
-
-    if use_utf8:
-        serialized_alphabet = UTF8Alphabet().serialize()
-    else:
-        if not alphabet_path:
-            raise RuntimeError("No --alphabet path specified, can't continue.")
-        serialized_alphabet = Alphabet(alphabet_path).serialize()
-
-    alphabet = NativeAlphabet()
-    err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
-    if err != 0:
-        raise RuntimeError("Error loading alphabet: {}".format(err))
-
-    scorer = Scorer()
-    scorer.set_alphabet(alphabet)
-    scorer.set_utf8_mode(use_utf8)
-    scorer.reset_params(default_alpha, default_beta)
-    err = scorer.load_lm(lm_path)
-    if err != ds_ctcdecoder.DS_ERR_SCORER_NO_TRIE:
-        print('Error loading language model file: 0x{:X}.'.format(err))
-        print('See the error codes section in https://deepspeech.readthedocs.io for a description.')
-        sys.exit(1)
-    scorer.fill_dictionary(list(words))
-    shutil.copy(lm_path, package_path)
-    # append, not overwrite
-    if scorer.save_dictionary(package_path, True):
-        print("Package created in {}".format(package_path))
-    else:
-        print("Error when creating {}".format(package_path))
-        sys.exit(1)
-
-
-class Tristate(object):
-    def __init__(self, value=None):
-        if any(value is v for v in (True, False, None)):
-            self.value = value
-        else:
-            raise ValueError("Tristate value must be True, False, or None")
-
-    def __eq__(self, other):
-        return (
-            self.value is other.value
-            if isinstance(other, Tristate)
-            else self.value is other
-        )
-
-    def __ne__(self, other):
-        return not self == other
-
-    def __bool__(self):
-        raise TypeError("Tristate object may not be used as a Boolean")
-
-    def __str__(self):
-        return str(self.value)
-
-    def __repr__(self):
-        return "Tristate(%s)" % self.value
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Generate an external scorer package for DeepSpeech."
-    )
-    parser.add_argument(
-        "--alphabet",
-        help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
-    )
-    parser.add_argument(
-        "--lm",
-        required=True,
-        help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
-    )
-    parser.add_argument(
-        "--vocab",
-        required=True,
-        help="Path of vocabulary file. Must contain words separated by whitespace.",
-    )
-    parser.add_argument("--package", required=True, help="Path to save scorer package.")
-    parser.add_argument(
-        "--default_alpha",
-        type=float,
-        required=True,
-        help="Default value of alpha hyperparameter.",
-    )
-    parser.add_argument(
-        "--default_beta",
-        type=float,
-        required=True,
-        help="Default value of beta hyperparameter.",
-    )
-    parser.add_argument(
-        "--force_utf8",
-        type=str,
-        default="",
-        help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://github.com/mozilla/DeepSpeech/blob/master/doc/Decoder.rst#utf-8-mode> for further explanation",
-    )
-    args = parser.parse_args()
-
-    if args.force_utf8 in ("True", "1", "true", "yes", "y"):
-        force_utf8 = Tristate(True)
-    elif args.force_utf8 in ("False", "0", "false", "no", "n"):
-        force_utf8 = Tristate(False)
-    else:
-        force_utf8 = Tristate(None)
-
-    create_bundle(
-        args.alphabet,
-        args.lm,
-        args.vocab,
-        args.package,
-        force_utf8,
-        args.default_alpha,
-        args.default_beta,
-    )
-
-
-if __name__ == "__main__":
-    main()

native_client/BUILD
@@ -2,6 +2,7 @@
 
 load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
 
 load(
     "@org_tensorflow//tensorflow/lite:build_def.bzl",
@@ -78,6 +79,8 @@ cc_library(
     hdrs = [
         "ctcdecode/ctc_beam_search_decoder.h",
         "ctcdecode/scorer.h",
+        "ctcdecode/decoder_utils.h",
+        "alphabet.h",
     ],
     includes = [
         ".",
@@ -186,6 +189,22 @@ genrule(
     cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
 )
 
+cc_binary(
+    name = "generate_scorer_package",
+    srcs = [
+        "generate_scorer_package.cpp",
+        "deepspeech_errors.cc",
+    ],
+    copts = ["-std=c++11"],
+    deps = [
+        ":decoder",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/flags:parse",
+        "@com_google_absl//absl/types:optional",
+        "@boost//:program_options",
+    ],
+)
+
 cc_binary(
     name = "enumerate_kenlm_vocabulary",
     srcs = [

native_client/alphabet.h
@@ -19,7 +19,7 @@ public:
   Alphabet(const Alphabet&) = default;
   Alphabet& operator=(const Alphabet&) = default;
 
-  int init(const char *config_file) {
+  virtual int init(const char *config_file) {
     std::ifstream in(config_file, std::ios::in);
     if (!in) {
       return 1;
@@ -45,6 +45,30 @@ public:
     return 0;
   }
 
+  std::string serialize() {
+    // Serialization format is a sequence of (key, value) pairs, where key is
+    // a uint16_t and value is a uint16_t length followed by `length` UTF-8
+    // encoded bytes with the label.
+    std::stringstream out;
+
+    // We start by writing the number of pairs in the buffer as uint16_t.
+    uint16_t size = size_;
+    out.write(reinterpret_cast<char*>(&size), sizeof(size));
+
+    for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
+      uint16_t key = it->first;
+      string str = it->second;
+      uint16_t len = str.length();
+      // Then we write the key as uint16_t, followed by the length of the value
+      // as uint16_t, followed by `length` bytes (the value itself).
+      out.write(reinterpret_cast<char*>(&key), sizeof(key));
+      out.write(reinterpret_cast<char*>(&len), sizeof(len));
+      out.write(str.data(), len);
+    }
+
+    return out.str();
+  }
+
   int deserialize(const char* buffer, const int buffer_size) {
     // See util/text.py for an explanation of the serialization format.
     int offset = 0;
@@ -126,11 +150,28 @@ public:
     return word;
   }
 
-private:
+protected:
   size_t size_;
   unsigned int space_label_;
   std::unordered_map<unsigned int, std::string> label_to_str_;
   std::unordered_map<std::string, unsigned int> str_to_label_;
 };
 
+class UTF8Alphabet : public Alphabet
+{
+public:
+  UTF8Alphabet() {
+    size_ = 255;
+    space_label_ = ' ' - 1;
+    for (int i = 0; i < size_; ++i) {
+      std::string val(1, i+1);
+      label_to_str_[i] = val;
+      str_to_label_[val] = i;
+    }
+  }
+
+  int init(const char*) override {}
+};
+
+
 #endif //ALPHABET_H
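
Aside, for illustration (not part of this commit): the new serialize() gives alphabet.h both ends of the wire format that the deleted Python script produced with deepspeech_training's Alphabet.serialize() and consumed with NativeAlphabet.deserialize(). A minimal round-trip sketch, assuming Alphabet is default-constructible; the standalone main() below is hypothetical:

    // Sketch only: round-trip an alphabet through the serialization format above.
    #include <string>
    #include "alphabet.h"

    int main() {
      UTF8Alphabet src;                    // 255 single-byte labels built by its constructor
      std::string buf = src.serialize();   // pair count, then (key, length, bytes) records

      Alphabet dst;                        // assumes a default constructor exists
      int err = dst.deserialize(buf.data(), static_cast<int>(buf.size()));
      return err;                          // 0 on success, non-zero for a malformed buffer
    }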

native_client/ctcdecode/scorer.cpp
@@ -357,7 +357,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
   return ngram;
 }
 
-void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary)
+void Scorer::fill_dictionary(const std::unordered_set<std::string>& vocabulary)
 {
   // ConstFst is immutable, so we need to use a MutableFst to create the trie,
   // and then we convert to a ConstFst for the decoder and for storing on disk.

native_client/ctcdecode/scorer.h
@@ -4,6 +4,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "lm/virtual_interface.hh"
@@ -83,7 +84,7 @@ public:
   bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
 
   // fill dictionary FST from a vocabulary
-  void fill_dictionary(const std::vector<std::string> &vocabulary);
+  void fill_dictionary(const std::unordered_set<std::string> &vocabulary);
 
   // load language model from given path
   int load_lm(const std::string &lm_path);

native_client/generate_scorer_package.cpp (new file, 146 lines)
@@ -0,0 +1,146 @@
+#include <string>
+#include <vector>
+#include <fstream>
+#include <unordered_set>
+#include <iostream>
+using namespace std;
+
+#include "absl/types/optional.h"
+#include "boost/program_options.hpp"
+
+#include "ctcdecode/decoder_utils.h"
+#include "ctcdecode/scorer.h"
+#include "alphabet.h"
+#include "deepspeech.h"
+
+namespace po = boost::program_options;
+
+int
+create_package(absl::optional<string> alphabet_path,
+               string lm_path,
+               string vocab_path,
+               string package_path,
+               absl::optional<bool> force_utf8,
+               float default_alpha,
+               float default_beta)
+{
+  // Read vocabulary
+  unordered_set<string> words;
+  bool vocab_looks_char_based = true;
+  ifstream fin(vocab_path);
+  if (!fin) {
+    cerr << "Invalid vocabulary file " << vocab_path << "\n";
+    return 1;
+  }
+  string word;
+  while (fin >> word) {
+    words.insert(word);
+    if (get_utf8_str_len(word) > 1) {
+      vocab_looks_char_based = false;
+    }
+  }
+  cerr << words.size() << " unique words read from vocabulary file.\n"
+       << (vocab_looks_char_based ? "Looks" : "Doesn't look")
+       << " like a character based (Bytes Are All You Need) model.\n";
+
+  if (!force_utf8.has_value()) {
+    force_utf8 = vocab_looks_char_based;
+    cerr << "--force_utf8 was not specified, using value "
+         << "infered from vocabulary contents: "
+         << (vocab_looks_char_based ? "true" : "false") << "\n";
+  }
+
+  if (force_utf8.value() && !alphabet_path.has_value()) {
+    cerr << "No --alphabet file specified, not using bytes output mode, can't continue.\n";
+    return 1;
+  }
+
+  Scorer scorer;
+  if (force_utf8.value()) {
+    scorer.set_alphabet(UTF8Alphabet());
+  } else {
+    Alphabet alphabet;
+    alphabet.init(alphabet_path->c_str());
+    scorer.set_alphabet(alphabet);
+  }
+  scorer.set_utf8_mode(force_utf8.value());
+  scorer.reset_params(default_alpha, default_beta);
+  int err = scorer.load_lm(lm_path);
+  if (err != DS_ERR_SCORER_NO_TRIE) {
+    cerr << "Error loading language model file: "
+         << DS_ErrorCodeToErrorMessage(err) << "\n";
+    return 1;
+  }
+  scorer.fill_dictionary(words);
+
+  // Copy LM file to final package file destination
+  {
+    ifstream lm_src(lm_path, std::ios::binary);
+    ofstream package_dest(package_path, std::ios::binary);
+    package_dest << lm_src.rdbuf();
+  }
+
+  // Save dictionary to package file, appending instead of overwriting
+  if (!scorer.save_dictionary(package_path, true)) {
+    cerr << "Error when saving package in " << package_path << ".\n";
+    return 1;
+  }
+
+  cerr << "Package created in " << package_path << ".\n";
+  return 0;
+}
+
+int
+main(int argc, char** argv)
+{
+  po::options_description desc("Options");
+  desc.add_options()
+    ("help", "show help message")
+    ("alphabet", po::value<string>(), "Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.")
+    ("lm", po::value<string>(), "Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.")
+    ("vocab", po::value<string>(), "Path of vocabulary file. Must contain words separated by whitespace.")
+    ("package", po::value<string>(), "Path to save scorer package.")
+    ("default_alpha", po::value<float>(), "Default value of alpha hyperparameter (float).")
+    ("default_beta", po::value<float>(), "Default value of beta hyperparameter (float).")
+    ("force_utf8", po::value<bool>(), "Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://deepspeech.readthedocs.io/en/master/Decoder.html#utf-8-mode> for further explanation.")
+  ;
+
+  po::variables_map vm;
+  po::store(po::parse_command_line(argc, argv, desc), vm);
+  po::notify(vm);
+
+  if (vm.count("help")) {
+    cout << desc << "\n";
+    return 1;
+  }
+
+  // Check required flags.
+  for (const string& flag : {"lm", "vocab", "package", "default_alpha", "default_beta"}) {
+    if (!vm.count(flag)) {
+      cerr << "--" << flag << " is a required flag. Pass --help for help.\n";
+      return 1;
+    }
+  }
+
+  // Parse optional --force_utf8
+  absl::optional<bool> force_utf8 = absl::nullopt;
+  if (vm.count("force_utf8")) {
+    force_utf8 = vm["force_utf8"].as<bool>();
+  }
+
+  // Parse optional --alphabet
+  absl::optional<string> alphabet = absl::nullopt;
+  if (vm.count("alphabet")) {
+    alphabet = vm["alphabet"].as<string>();
+  }
+
+  create_package(alphabet,
+                 vm["lm"].as<string>(),
+                 vm["vocab"].as<string>(),
+                 vm["package"].as<string>(),
+                 force_utf8,
+                 vm["default_alpha"].as<float>(),
+                 vm["default_beta"].as<float>());
+
+  return 0;
+}