diff --git a/data/lm/generate_package.py b/data/lm/generate_package.py
new file mode 100644
index 00000000..ee3c106b
--- /dev/null
+++ b/data/lm/generate_package.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+from __future__ import absolute_import, division, print_function
+
+# Make sure we can import stuff from util/
+# This script needs to be run from the root of the DeepSpeech repository
+import os
+import sys
+sys.path.insert(1, os.path.join(sys.path[0], '..', '..'))
+
+import argparse
+import io
+import shutil
+
+from util.text import Alphabet, UTF8Alphabet
+from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
+
+
+def create_bundle(alphabet_path, lm_path, vocab_path, package_path, force_utf8):
+    """Build a scorer package: the LM binary with a vocabulary dictionary appended.
+
+    :param alphabet_path: alphabet file path; only read when UTF-8 mode is off
+    :param lm_path: KenLM binary that is copied to ``package_path``
+    :param vocab_path: UTF-8 text file of whitespace-separated words
+    :param package_path: output path of the package
+    :param force_utf8: Tristate; True/False forces the mode, None infers it
+        from the vocabulary (UTF-8 mode when every word is one character)
+    """
+    words = set()
+    vocab_looks_char_based = True
+    # Decode explicitly so the result does not depend on the locale encoding.
+    with io.open(vocab_path, encoding='utf-8') as fin:
+        for line in fin:
+            for word in line.split():
+                words.add(word.encode('utf-8'))
+                if len(word) > 1:
+                    vocab_looks_char_based = False
+    print("{} unique words read from vocabulary file.".format(len(words)))
+    print(
+        "{} like a character based model.".format(
+            "Looks" if vocab_looks_char_based else "Doesn't look"
+        )
+    )
+
+    # force_utf8 is always a Tristate object, so test its payload rather than
+    # the wrapper itself.
+    if force_utf8.value is not None:
+        use_utf8 = force_utf8.value
+    else:
+        use_utf8 = vocab_looks_char_based
+
+    if use_utf8:
+        serialized_alphabet = UTF8Alphabet().serialize()
+    else:
+        if not alphabet_path:
+            print("Error: --alphabet is required when not using UTF-8 mode.")
+            sys.exit(1)
+        serialized_alphabet = Alphabet(alphabet_path).serialize()
+
+    alphabet = NativeAlphabet()
+    err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
+    if err != 0:
+        print("Error loading alphabet: {}".format(err))
+        sys.exit(1)
+
+    scorer = Scorer()
+    scorer.set_alphabet(alphabet)
+    scorer.set_utf8_mode(use_utf8)
+    scorer.load_lm(lm_path, "")
+    scorer.fill_dictionary(list(words))
+    shutil.copy(lm_path, package_path)
+    scorer.save_dictionary(package_path, True)  # append, not overwrite
+    print('Package created in {}'.format(package_path))
+
+
+class Tristate(object):
+    """Wrapper around True, False or None that refuses implicit truth tests."""
+
+    def __init__(self, value=None):
+        if any(value is v for v in (True, False, None)):
+            self.value = value
+        else:
+            raise ValueError("Tristate value must be True, False, or None")
+
+    def __eq__(self, other):
+        return (self.value is other.value if isinstance(other, Tristate)
+                else self.value is other)
+
+    def __ne__(self, other):
+        return not self == other
+
+    def __bool__(self):
+        raise TypeError("Tristate object may not be used as a Boolean")
+
+    # Python 2 looks up __nonzero__ instead of __bool__.
+    __nonzero__ = __bool__
+
+    def __str__(self):
+        return str(self.value)
+
+    def __repr__(self):
+        return "Tristate(%s)" % self.value
+
+
+def main():
+    """Parse the command line and build the scorer package."""
+    parser = argparse.ArgumentParser(
+        description="Generate an external scorer package for DeepSpeech."
+    )
+    parser.add_argument(
+        "--alphabet",
+        help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
+    )
+    parser.add_argument(
+        "--lm",
+        required=True,
+        help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
+    )
+    parser.add_argument(
+        "--vocab",
+        required=True,
+        help="Path of vocabulary file. Must contain words separated by whitespace.",
+    )
+    parser.add_argument("--package", required=True, help="Path to save scorer package.")
+    parser.add_argument(
+        "--force_utf8",
+        default="",
+        help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary.",
+    )
+    args = parser.parse_args()
+
+    # Match case-insensitively and reject unknown values instead of silently
+    # falling back to inference.
+    flag = args.force_utf8.strip().lower()
+    if flag in ("true", "1", "yes", "y"):
+        force_utf8 = Tristate(True)
+    elif flag in ("false", "0", "no", "n"):
+        force_utf8 = Tristate(False)
+    elif not flag:
+        force_utf8 = Tristate(None)
+    else:
+        parser.error(
+            "--force_utf8 must be a boolean value, got {!r}".format(args.force_utf8)
+        )
+
+    create_bundle(args.alphabet, args.lm, args.vocab, args.package, force_utf8)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py
index 71432a7c..3fab4eb7 100644
--- a/native_client/ctcdecode/__init__.py
+++ b/native_client/ctcdecode/__init__.py
@@ -17,29 +17,28 @@ class Scorer(swigwrapper.Scorer):
     :alphabet: Alphabet
     :type model_path: basestring
     """
-    def __init__(self, alpha, beta, model_path, trie_path, alphabet):
+    def __init__(self, alpha=None, beta=None, model_path=None, trie_path=None, alphabet=None):
         super(Scorer, self).__init__()
-        serialized = alphabet.serialize()
-        native_alphabet = swigwrapper.Alphabet()
-        err = native_alphabet.deserialize(serialized, len(serialized))
-        if err != 0:
-            raise ValueError("Error when deserializing alphabet.")
+        # Allow bare initialization
+        if alphabet is not None:
+            serialized = alphabet.serialize()
+            native_alphabet = swigwrapper.Alphabet()
+            err = native_alphabet.deserialize(serialized, len(serialized))
+            if err != 0:
+                raise ValueError("Error when deserializing alphabet.")
 
-        err = self.init(alpha, beta,
-                        model_path.encode('utf-8'),
-                        trie_path.encode('utf-8'),
-                        native_alphabet)
-        if err != 0:
-            raise ValueError("Scorer initialization failed with error code {}".format(err), err)
-
-    def __init__(self):
-        super(Scorer, self).__init__()
+            err = self.init(alpha, beta,
+                            model_path.encode('utf-8'),
+                            trie_path.encode('utf-8'),
+                            native_alphabet)
+            if err != 0:
+                raise ValueError("Scorer initialization failed with error code {}".format(err), err)
 
     def load_lm(self, lm_path, trie_path):
         super(Scorer, self).load_lm(lm_path.encode('utf-8'), trie_path.encode('utf-8'))
 
-    def save_dictionary(self, save_path):
-        super(Scorer, self).save_dictionary(save_path.encode('utf-8'))
+    def save_dictionary(self, save_path, *args, **kwargs):
+        super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
 
 
 def ctc_beam_search_decoder(probs_seq,
diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp
index 3180724f..dfe2824a 100644
--- a/native_client/ctcdecode/scorer.cpp
+++ b/native_client/ctcdecode/scorer.cpp
@@ -128,9 +128,15 @@ void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
   max_order_ = language_model_->Order();
 }
 
-void Scorer::save_dictionary(const std::string& path)
+void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
 {
-  std::fstream fout(path, std::ios::in|std::ios::out|std::ios::binary|std::ios::ate);
+  std::ios::openmode om;
+  if (append_instead_of_overwrite) {
+    om = std::ios::in|std::ios::out|std::ios::binary|std::ios::ate;
+  } else {
+    om = std::ios::out|std::ios::binary;
+  }
+  std::fstream fout(path, om);
   fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
   fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
   fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h
index f6c7d7bb..17bd1028 100644
--- a/native_client/ctcdecode/scorer.h
+++ b/native_client/ctcdecode/scorer.h
@@ -95,7 +95,7 @@ public:
   void set_alphabet(const Alphabet& alphabet);
 
   // save dictionary in file
-  void save_dictionary(const std::string &path);
+  void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
 
   // return weather this step represents a boundary where beam scoring should happen
   bool is_scoring_boundary(PathTrie* prefix, size_t new_label);