diff --git a/data/lm/generate_package.py b/data/lm/generate_package.py deleted file mode 100644 index 30a33fcc..00000000 --- a/data/lm/generate_package.py +++ /dev/null @@ -1,157 +0,0 @@ -#!/usr/bin/env python -from __future__ import absolute_import, division, print_function - -import argparse -import shutil -import sys - -import ds_ctcdecoder -from deepspeech_training.util.text import Alphabet, UTF8Alphabet -from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet - - -def create_bundle( - alphabet_path, - lm_path, - vocab_path, - package_path, - force_utf8, - default_alpha, - default_beta, -): - words = set() - vocab_looks_char_based = True - with open(vocab_path) as fin: - for line in fin: - for word in line.split(): - words.add(word.encode("utf-8")) - if len(word) > 1: - vocab_looks_char_based = False - print("{} unique words read from vocabulary file.".format(len(words))) - - cbm = "Looks" if vocab_looks_char_based else "Doesn't look" - print("{} like a character based model.".format(cbm)) - - if force_utf8 != None: # pylint: disable=singleton-comparison - use_utf8 = force_utf8.value - else: - use_utf8 = vocab_looks_char_based - print("Using detected UTF-8 mode: {}".format(use_utf8)) - - if use_utf8: - serialized_alphabet = UTF8Alphabet().serialize() - else: - if not alphabet_path: - raise RuntimeError("No --alphabet path specified, can't continue.") - serialized_alphabet = Alphabet(alphabet_path).serialize() - - alphabet = NativeAlphabet() - err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet)) - if err != 0: - raise RuntimeError("Error loading alphabet: {}".format(err)) - - scorer = Scorer() - scorer.set_alphabet(alphabet) - scorer.set_utf8_mode(use_utf8) - scorer.reset_params(default_alpha, default_beta) - err = scorer.load_lm(lm_path) - if err != ds_ctcdecoder.DS_ERR_SCORER_NO_TRIE: - print('Error loading language model file: 0x{:X}.'.format(err)) - print('See the error codes section in https://deepspeech.readthedocs.io for a description.') - sys.exit(1) - scorer.fill_dictionary(list(words)) - shutil.copy(lm_path, package_path) - # append, not overwrite - if scorer.save_dictionary(package_path, True): - print("Package created in {}".format(package_path)) - else: - print("Error when creating {}".format(package_path)) - sys.exit(1) - - -class Tristate(object): - def __init__(self, value=None): - if any(value is v for v in (True, False, None)): - self.value = value - else: - raise ValueError("Tristate value must be True, False, or None") - - def __eq__(self, other): - return ( - self.value is other.value - if isinstance(other, Tristate) - else self.value is other - ) - - def __ne__(self, other): - return not self == other - - def __bool__(self): - raise TypeError("Tristate object may not be used as a Boolean") - - def __str__(self): - return str(self.value) - - def __repr__(self): - return "Tristate(%s)" % self.value - - -def main(): - parser = argparse.ArgumentParser( - description="Generate an external scorer package for DeepSpeech." - ) - parser.add_argument( - "--alphabet", - help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.", - ) - parser.add_argument( - "--lm", - required=True, - help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.", - ) - parser.add_argument( - "--vocab", - required=True, - help="Path of vocabulary file. Must contain words separated by whitespace.", - ) - parser.add_argument("--package", required=True, help="Path to save scorer package.") - parser.add_argument( - "--default_alpha", - type=float, - required=True, - help="Default value of alpha hyperparameter.", - ) - parser.add_argument( - "--default_beta", - type=float, - required=True, - help="Default value of beta hyperparameter.", - ) - parser.add_argument( - "--force_utf8", - type=str, - default="", - help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See for further explanation", - ) - args = parser.parse_args() - - if args.force_utf8 in ("True", "1", "true", "yes", "y"): - force_utf8 = Tristate(True) - elif args.force_utf8 in ("False", "0", "false", "no", "n"): - force_utf8 = Tristate(False) - else: - force_utf8 = Tristate(None) - - create_bundle( - args.alphabet, - args.lm, - args.vocab, - args.package, - force_utf8, - args.default_alpha, - args.default_beta, - ) - - -if __name__ == "__main__": - main() diff --git a/native_client/BUILD b/native_client/BUILD index 965a766c..36702088 100644 --- a/native_client/BUILD +++ b/native_client/BUILD @@ -2,6 +2,7 @@ load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") load( "@org_tensorflow//tensorflow/lite:build_def.bzl", @@ -78,6 +79,8 @@ cc_library( hdrs = [ "ctcdecode/ctc_beam_search_decoder.h", "ctcdecode/scorer.h", + "ctcdecode/decoder_utils.h", + "alphabet.h", ], includes = [ ".", @@ -186,6 +189,22 @@ genrule( cmd = "dsymutil $(location :libdeepspeech.so) -o $@" ) +cc_binary( + name = "generate_scorer_package", + srcs = [ + "generate_scorer_package.cpp", + "deepspeech_errors.cc", + ], + copts = ["-std=c++11"], + deps = [ + ":decoder", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/types:optional", + "@boost//:program_options", + ], +) + cc_binary( name = "enumerate_kenlm_vocabulary", srcs = [ diff --git a/native_client/alphabet.h b/native_client/alphabet.h index ace905cc..e57ef914 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -19,7 +19,7 @@ public: Alphabet(const Alphabet&) = default; Alphabet& operator=(const Alphabet&) = default; - int init(const char *config_file) { + virtual int init(const char *config_file) { std::ifstream in(config_file, std::ios::in); if (!in) { return 1; @@ -45,6 +45,30 @@ public: return 0; } + std::string serialize() { + // Serialization format is a sequence of (key, value) pairs, where key is + // a uint16_t and value is a uint16_t length followed by `length` UTF-8 + // encoded bytes with the label. + std::stringstream out; + + // We start by writing the number of pairs in the buffer as uint16_t. + uint16_t size = size_; + out.write(reinterpret_cast(&size), sizeof(size)); + + for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) { + uint16_t key = it->first; + string str = it->second; + uint16_t len = str.length(); + // Then we write the key as uint16_t, followed by the length of the value + // as uint16_t, followed by `length` bytes (the value itself). + out.write(reinterpret_cast(&key), sizeof(key)); + out.write(reinterpret_cast(&len), sizeof(len)); + out.write(str.data(), len); + } + + return out.str(); + } + int deserialize(const char* buffer, const int buffer_size) { // See util/text.py for an explanation of the serialization format. int offset = 0; @@ -126,11 +150,28 @@ public: return word; } -private: +protected: size_t size_; unsigned int space_label_; std::unordered_map label_to_str_; std::unordered_map str_to_label_; }; +class UTF8Alphabet : public Alphabet +{ +public: + UTF8Alphabet() { + size_ = 255; + space_label_ = ' ' - 1; + for (int i = 0; i < size_; ++i) { + std::string val(1, i+1); + label_to_str_[i] = val; + str_to_label_[val] = i; + } + } + + int init(const char*) override {} +}; + + #endif //ALPHABET_H diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index ebf55227..401613d1 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -357,7 +357,7 @@ std::vector Scorer::make_ngram(PathTrie* prefix) return ngram; } -void Scorer::fill_dictionary(const std::vector& vocabulary) +void Scorer::fill_dictionary(const std::unordered_set& vocabulary) { // ConstFst is immutable, so we need to use a MutableFst to create the trie, // and then we convert to a ConstFst for the decoder and for storing on disk. diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index d2a1c8b3..3e7c0761 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "lm/virtual_interface.hh" @@ -83,7 +84,7 @@ public: bool is_scoring_boundary(PathTrie* prefix, size_t new_label); // fill dictionary FST from a vocabulary - void fill_dictionary(const std::vector &vocabulary); + void fill_dictionary(const std::unordered_set &vocabulary); // load language model from given path int load_lm(const std::string &lm_path); diff --git a/native_client/generate_scorer_package.cpp b/native_client/generate_scorer_package.cpp new file mode 100644 index 00000000..910bf9c2 --- /dev/null +++ b/native_client/generate_scorer_package.cpp @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include +using namespace std; + +#include "absl/types/optional.h" +#include "boost/program_options.hpp" + +#include "ctcdecode/decoder_utils.h" +#include "ctcdecode/scorer.h" +#include "alphabet.h" +#include "deepspeech.h" + +namespace po = boost::program_options; + +int +create_package(absl::optional alphabet_path, + string lm_path, + string vocab_path, + string package_path, + absl::optional force_utf8, + float default_alpha, + float default_beta) +{ + // Read vocabulary + unordered_set words; + bool vocab_looks_char_based = true; + ifstream fin(vocab_path); + if (!fin) { + cerr << "Invalid vocabulary file " << vocab_path << "\n"; + return 1; + } + string word; + while (fin >> word) { + words.insert(word); + if (get_utf8_str_len(word) > 1) { + vocab_looks_char_based = false; + } + } + cerr << words.size() << " unique words read from vocabulary file.\n" + << (vocab_looks_char_based ? "Looks" : "Doesn't look") + << " like a character based (Bytes Are All You Need) model.\n"; + + if (!force_utf8.has_value()) { + force_utf8 = vocab_looks_char_based; + cerr << "--force_utf8 was not specified, using value " + << "infered from vocabulary contents: " + << (vocab_looks_char_based ? "true" : "false") << "\n"; + } + + if (force_utf8.value() && !alphabet_path.has_value()) { + cerr << "No --alphabet file specified, not using bytes output mode, can't continue.\n"; + return 1; + } + + Scorer scorer; + if (force_utf8.value()) { + scorer.set_alphabet(UTF8Alphabet()); + } else { + Alphabet alphabet; + alphabet.init(alphabet_path->c_str()); + scorer.set_alphabet(alphabet); + } + scorer.set_utf8_mode(force_utf8.value()); + scorer.reset_params(default_alpha, default_beta); + int err = scorer.load_lm(lm_path); + if (err != DS_ERR_SCORER_NO_TRIE) { + cerr << "Error loading language model file: " + << DS_ErrorCodeToErrorMessage(err) << "\n"; + return 1; + } + scorer.fill_dictionary(words); + + // Copy LM file to final package file destination + { + ifstream lm_src(lm_path, std::ios::binary); + ofstream package_dest(package_path, std::ios::binary); + package_dest << lm_src.rdbuf(); + } + + // Save dictionary to package file, appending instead of overwriting + if (!scorer.save_dictionary(package_path, true)) { + cerr << "Error when saving package in " << package_path << ".\n"; + return 1; + } + + cerr << "Package created in " << package_path << ".\n"; + return 0; +} + +int +main(int argc, char** argv) +{ + po::options_description desc("Options"); + desc.add_options() + ("help", "show help message") + ("alphabet", po::value(), "Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.") + ("lm", po::value(), "Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.") + ("vocab", po::value(), "Path of vocabulary file. Must contain words separated by whitespace.") + ("package", po::value(), "Path to save scorer package.") + ("default_alpha", po::value(), "Default value of alpha hyperparameter (float).") + ("default_beta", po::value(), "Default value of beta hyperparameter (float).") + ("force_utf8", po::value(), "Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See for further explanation.") + ; + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + po::notify(vm); + + if (vm.count("help")) { + cout << desc << "\n"; + return 1; + } + + // Check required flags. + for (const string& flag : {"lm", "vocab", "package", "default_alpha", "default_beta"}) { + if (!vm.count(flag)) { + cerr << "--" << flag << " is a required flag. Pass --help for help.\n"; + return 1; + } + } + + // Parse optional --force_utf8 + absl::optional force_utf8 = absl::nullopt; + if (vm.count("force_utf8")) { + force_utf8 = vm["force_utf8"].as(); + } + + // Parse optional --alphabet + absl::optional alphabet = absl::nullopt; + if (vm.count("alphabet")) { + alphabet = vm["alphabet"].as(); + } + + create_package(alphabet, + vm["lm"].as(), + vm["vocab"].as(), + vm["package"].as(), + force_utf8, + vm["default_alpha"].as(), + vm["default_beta"].as()); + + return 0; +}