Merge pull request #3113 from mozilla/generate-package-cpp

Rewrite generate_package.py in C++ to avoid training dependencies

commit d2d46c3aee
@@ -3,9 +3,9 @@ Language-Specific Data
 
 This directory contains language-specific data files. Most importantly, you will find here:
 
-1. A list of unique characters for the target language (e.g. English) in `data/alphabet.txt`
+1. A list of unique characters for the target language (e.g. English) in ``data/alphabet.txt``. After installing the training code, you can check ``python -m deepspeech_training.util.check_characters --help`` for a tool that creates an alphabet file from a list of training CSV files.
 
-2. A scorer package (`data/lm/kenlm.scorer`) generated with `data/lm/generate_package.py`. The scorer package includes a binary n-gram language model generated with `data/lm/generate_lm.py`.
+2. A scorer package (``data/lm/kenlm.scorer``) generated with ``generate_scorer_package`` (``native_client/generate_scorer_package.cpp``). The scorer package includes a binary n-gram language model generated with ``data/lm/generate_lm.py``.
 
 For more information on how to build these resources from scratch, see the ``External scorer scripts`` section on `deepspeech.readthedocs.io <https://deepspeech.readthedocs.io/>`_.
data/lm/generate_package.py (deleted file)

@@ -1,157 +0,0 @@
-#!/usr/bin/env python
-from __future__ import absolute_import, division, print_function
-
-import argparse
-import shutil
-import sys
-
-import ds_ctcdecoder
-from deepspeech_training.util.text import Alphabet, UTF8Alphabet
-from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
-
-
-def create_bundle(
-    alphabet_path,
-    lm_path,
-    vocab_path,
-    package_path,
-    force_utf8,
-    default_alpha,
-    default_beta,
-):
-    words = set()
-    vocab_looks_char_based = True
-    with open(vocab_path) as fin:
-        for line in fin:
-            for word in line.split():
-                words.add(word.encode("utf-8"))
-                if len(word) > 1:
-                    vocab_looks_char_based = False
-    print("{} unique words read from vocabulary file.".format(len(words)))
-
-    cbm = "Looks" if vocab_looks_char_based else "Doesn't look"
-    print("{} like a character based model.".format(cbm))
-
-    if force_utf8 != None:  # pylint: disable=singleton-comparison
-        use_utf8 = force_utf8.value
-    else:
-        use_utf8 = vocab_looks_char_based
-        print("Using detected UTF-8 mode: {}".format(use_utf8))
-
-    if use_utf8:
-        serialized_alphabet = UTF8Alphabet().serialize()
-    else:
-        if not alphabet_path:
-            raise RuntimeError("No --alphabet path specified, can't continue.")
-        serialized_alphabet = Alphabet(alphabet_path).serialize()
-
-    alphabet = NativeAlphabet()
-    err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
-    if err != 0:
-        raise RuntimeError("Error loading alphabet: {}".format(err))
-
-    scorer = Scorer()
-    scorer.set_alphabet(alphabet)
-    scorer.set_utf8_mode(use_utf8)
-    scorer.reset_params(default_alpha, default_beta)
-    err = scorer.load_lm(lm_path)
-    if err != ds_ctcdecoder.DS_ERR_SCORER_NO_TRIE:
-        print('Error loading language model file: 0x{:X}.'.format(err))
-        print('See the error codes section in https://deepspeech.readthedocs.io for a description.')
-        sys.exit(1)
-    scorer.fill_dictionary(list(words))
-    shutil.copy(lm_path, package_path)
-    # append, not overwrite
-    if scorer.save_dictionary(package_path, True):
-        print("Package created in {}".format(package_path))
-    else:
-        print("Error when creating {}".format(package_path))
-        sys.exit(1)
-
-
-class Tristate(object):
-    def __init__(self, value=None):
-        if any(value is v for v in (True, False, None)):
-            self.value = value
-        else:
-            raise ValueError("Tristate value must be True, False, or None")
-
-    def __eq__(self, other):
-        return (
-            self.value is other.value
-            if isinstance(other, Tristate)
-            else self.value is other
-        )
-
-    def __ne__(self, other):
-        return not self == other
-
-    def __bool__(self):
-        raise TypeError("Tristate object may not be used as a Boolean")
-
-    def __str__(self):
-        return str(self.value)
-
-    def __repr__(self):
-        return "Tristate(%s)" % self.value
-
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Generate an external scorer package for DeepSpeech."
-    )
-    parser.add_argument(
-        "--alphabet",
-        help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
-    )
-    parser.add_argument(
-        "--lm",
-        required=True,
-        help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
-    )
-    parser.add_argument(
-        "--vocab",
-        required=True,
-        help="Path of vocabulary file. Must contain words separated by whitespace.",
-    )
-    parser.add_argument("--package", required=True, help="Path to save scorer package.")
-    parser.add_argument(
-        "--default_alpha",
-        type=float,
-        required=True,
-        help="Default value of alpha hyperparameter.",
-    )
-    parser.add_argument(
-        "--default_beta",
-        type=float,
-        required=True,
-        help="Default value of beta hyperparameter.",
-    )
-    parser.add_argument(
-        "--force_utf8",
-        type=str,
-        default="",
-        help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://github.com/mozilla/DeepSpeech/blob/master/doc/Decoder.rst#utf-8-mode> for further explanation",
-    )
-    args = parser.parse_args()
-
-    if args.force_utf8 in ("True", "1", "true", "yes", "y"):
-        force_utf8 = Tristate(True)
-    elif args.force_utf8 in ("False", "0", "false", "no", "n"):
-        force_utf8 = Tristate(False)
-    else:
-        force_utf8 = Tristate(None)
-
-    create_bundle(
-        args.alphabet,
-        args.lm,
-        args.vocab,
-        args.package,
-        force_utf8,
-        args.default_alpha,
-        args.default_beta,
-    )
-
-
-if __name__ == "__main__":
-    main()
doc/Decoder.rst

@@ -56,9 +56,11 @@ At decoding time, the scorer is queried every time a Unicode codepoint is predicted
 
 **Acoustic models trained with ``--utf8`` MUST NOT be used with an alphabet based scorer. Conversely, acoustic models trained with an alphabet file MUST NOT be used with a UTF-8 scorer.**
 
-UTF-8 scorers can be built by using an input corpus with space separated codepoints. If your corpus only contains single codepoints separated by spaces, ``data/lm/generate_package.py`` should automatically enable UTF-8 mode, and it should print the message "Looks like a character based model."
+UTF-8 scorers can be built by using an input corpus with space separated codepoints. If your corpus only contains single codepoints separated by spaces, ``generate_scorer_package`` should automatically enable UTF-8 mode, and it should print the message "Looks like a character based model."
 
-If the message "Doesn't look like a character based model." is printed, you should double check your inputs to make sure it only contains single codepoints separated by spaces. UTF-8 mode can be forced by specifying the ``--force_utf8`` flag when running ``data/lm/generate_package.py``, but it is NOT RECOMMENDED.
+If the message "Doesn't look like a character based model." is printed, you should double check your inputs to make sure it only contains single codepoints separated by spaces. UTF-8 mode can be forced by specifying the ``--force_utf8`` flag when running ``generate_scorer_package``, but it is NOT RECOMMENDED.
+
+See :ref:`scorer-scripts` for more details on using ``generate_scorer_package``.
 
 Because KenLM uses spaces as a word separator, the resulting language model will not include space characters in it. If you wish to use UTF-8 mode but still model spaces, you need to replace spaces in the input corpus with a different character **before** converting it to space separated codepoints. For example:
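A minimal sketch of that space-replacement preprocessing (illustrative only, not part of this commit: it assumes an ASCII-only corpus and ``|`` as the space stand-in; real preprocessing must split multi-byte UTF-8 codepoints rather than bytes):

.. code-block:: cpp

   #include <algorithm>
   #include <iostream>
   #include <sstream>
   #include <string>

   int main() {
     std::string line = "hello world";
     // Replace spaces first, so the word separator survives codepoint splitting.
     std::replace(line.begin(), line.end(), ' ', '|');
     std::ostringstream spaced;
     for (size_t i = 0; i < line.size(); ++i) {
       if (i > 0) spaced << ' ';
       spaced << line[i];  // one byte == one codepoint only because input is ASCII
     }
     std::cout << spaced.str() << "\n";  // prints: h e l l o | w o r l d
     return 0;
   }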
doc/Scorer.rst

@@ -5,7 +5,9 @@ External scorer scripts
 
 DeepSpeech pre-trained models include an external scorer. This document explains how to reproduce our external scorer, as well as adapt the scripts to create your own.
 
-The scorer is composed of two sub-components, a KenLM language model and a trie data structure containing all words in the vocabulary. In order to create the scorer package, first we must create a KenLM language model (using ``data/lm/generate_lm.py``, and then use ``data/lm/generate_package.py`` to create the final package file including the trie data structure.
+The scorer is composed of two sub-components, a KenLM language model and a trie data structure containing all words in the vocabulary. In order to create the scorer package, first we must create a KenLM language model (using ``data/lm/generate_lm.py``), and then use ``generate_scorer_package`` to create the final package file including the trie data structure.
+
+The ``generate_scorer_package`` binary is part of the native client package that is included with official releases. You can find the appropriate archive for your platform in the `GitHub release downloads <https://github.com/mozilla/DeepSpeech/releases/latest>`_. The native client package is named ``native_client.{arch}.{config}.{plat}.tar.xz``, where ``{arch}`` is the architecture the binary was built for, for example ``amd64`` or ``arm64``, ``config`` is the build configuration, which for building decoder packages does not matter, and ``{plat}`` is the platform the binary was built for, for example ``linux`` or ``osx``. If you wanted to run the ``generate_scorer_package`` binary on a Linux desktop, you would download ``native_client.amd64.cpu.linux.tar.xz``.
 
 Reproducing our external scorer
 -------------------------------

@@ -36,12 +38,15 @@ Else you have to build `KenLM <https://github.com/kpu/kenlm>`_ first and then pass
      --binary_a_bits 255 --binary_q_bits 8 --binary_type trie
 
-Afterwards you can use ``generate_package.py`` to generate the scorer package using the ``lm.binary`` and ``vocab-500000.txt`` files:
+Afterwards you can use ``generate_scorer_package`` to generate the scorer package using the ``lm.binary`` and ``vocab-500000.txt`` files:
 
 .. code-block:: bash
 
    cd data/lm
-   python3 generate_package.py --alphabet ../alphabet.txt --lm lm.binary --vocab vocab-500000.txt \
+   # Download and extract appropriate native_client package:
+   curl -LO http://github.com/mozilla/DeepSpeech/releases/...
+   tar xvf native_client.*.tar.xz
+   ./generate_scorer_package --alphabet ../alphabet.txt --lm lm.binary --vocab vocab-500000.txt \
      --package kenlm.scorer --default_alpha 0.931289039105002 --default_beta 1.1834137581510284
 
 Building your own scorer

@@ -51,7 +56,6 @@ Building your own scorer can be useful if you're using models in a narrow usage
 
 The LibriSpeech LM training text used by our scorer is around 4GB uncompressed, which should give an idea of the size of a corpus needed for a reasonable language model for general speech recognition. For more constrained use cases with smaller vocabularies, you don't need as much data, but you should still try to gather as much as you can.
 
-With a text corpus in hand, you can then re-use the ``generate_lm.py`` and ``generate_package.py`` scripts to create your own scorer that is compatible with DeepSpeech clients and language bindings. Before building the language model, you must first familiarize yourself with the `KenLM toolkit <https://kheafield.com/code/kenlm/>`_. Most of the options exposed by the ``generate_lm.py`` script are simply forwarded to KenLM options of the same name, so you must read the KenLM documentation in order to fully understand their behavior.
+With a text corpus in hand, you can then re-use ``generate_lm.py`` and ``generate_scorer_package`` to create your own scorer that is compatible with DeepSpeech clients and language bindings. Before building the language model, you must first familiarize yourself with the `KenLM toolkit <https://kheafield.com/code/kenlm/>`_. Most of the options exposed by the ``generate_lm.py`` script are simply forwarded to KenLM options of the same name, so you must read the KenLM documentation in order to fully understand their behavior.
 
-After using ``generate_lm.py`` to create a KenLM language model binary file, you can use ``generate_package.py`` to create a scorer package as described in the previous section. Note that we have a :github:`lm_optimizer.py script <lm_optimizer.py>` which can be used to find good default values for alpha and beta. To use it, you must first
-generate a package with any value set for default alpha and beta flags. For this step, it doesn't matter what values you use, as they'll be overridden by ``lm_optimizer.py``. Then, use ``lm_optimizer.py`` with this scorer file to find good alpha and beta values. Finally, use ``generate_package.py`` again, this time with the new values.
+After using ``generate_lm.py`` to create a KenLM language model binary file, you can use ``generate_scorer_package`` to create a scorer package as described in the previous section. Note that we have a :github:`lm_optimizer.py script <lm_optimizer.py>` which can be used to find good default values for alpha and beta. To use it, you must first generate a package with any value set for default alpha and beta flags. For this step, it doesn't matter what values you use, as they'll be overridden by ``lm_optimizer.py`` later. Then, use ``lm_optimizer.py`` with this scorer file to find good alpha and beta values. Finally, use ``generate_scorer_package`` again, this time with the new values.
native_client/BUILD

@@ -2,6 +2,7 @@
 load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cc_shared_object")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
 
 load(
     "@org_tensorflow//tensorflow/lite:build_def.bzl",

@@ -74,16 +75,24 @@ cc_library(
         "ctcdecode/scorer.cpp",
         "ctcdecode/path_trie.cpp",
         "ctcdecode/path_trie.h",
+        "alphabet.cc",
     ] + OPENFST_SOURCES_PLATFORM,
     hdrs = [
         "ctcdecode/ctc_beam_search_decoder.h",
         "ctcdecode/scorer.h",
+        "ctcdecode/decoder_utils.h",
+        "alphabet.h",
     ],
     includes = [
         ".",
         "ctcdecode/third_party/ThreadPool",
     ] + OPENFST_INCLUDES_PLATFORM,
-    deps = [":kenlm"]
+    deps = [":kenlm"],
+    linkopts = [
+        "-lm",
+        "-ldl",
+        "-pthread",
+    ],
 )
 
 tf_cc_shared_object(

@@ -91,11 +100,11 @@ tf_cc_shared_object(
     srcs = [
         "deepspeech.cc",
         "deepspeech.h",
-        "alphabet.h",
-        "modelstate.h",
+        "deepspeech_errors.cc",
         "modelstate.cc",
-        "workspace_status.h",
+        "modelstate.h",
         "workspace_status.cc",
+        "workspace_status.h",
     ] + select({
         "//native_client:tflite": [
             "tflitemodelstate.h",

@@ -185,6 +194,27 @@ genrule(
     cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
 )
 
+cc_binary(
+    name = "generate_scorer_package",
+    srcs = [
+        "generate_scorer_package.cpp",
+        "deepspeech_errors.cc",
+    ],
+    copts = ["-std=c++11"],
+    deps = [
+        ":decoder",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/flags:parse",
+        "@com_google_absl//absl/types:optional",
+        "@boost//:program_options",
+    ],
+    linkopts = [
+        "-lm",
+        "-ldl",
+        "-pthread",
+    ],
+)
+
 cc_binary(
     name = "enumerate_kenlm_vocabulary",
     srcs = [

@@ -201,10 +231,5 @@ cc_binary(
         "trie_load.cc",
     ],
     copts = ["-std=c++11"],
-    linkopts = [
-        "-lm",
-        "-ldl",
-        "-pthread",
-    ],
     deps = [":decoder"],
 )
native_client/alphabet.cc (new file, 189 lines)

@@ -0,0 +1,189 @@
+#include "alphabet.h"
+#include "ctcdecode/decoder_utils.h"
+
+#include <fstream>
+
+// std::getline, but handle newline conventions from multiple platforms instead
+// of just the platform this code was built for
+std::istream&
+getline_crossplatform(std::istream& is, std::string& t)
+{
+  t.clear();
+
+  // The characters in the stream are read one-by-one using a std::streambuf.
+  // That is faster than reading them one-by-one using the std::istream.
+  // Code that uses streambuf this way must be guarded by a sentry object.
+  // The sentry object performs various tasks,
+  // such as thread synchronization and updating the stream state.
+  std::istream::sentry se(is, true);
+  std::streambuf* sb = is.rdbuf();
+
+  while (true) {
+    int c = sb->sbumpc();
+    switch (c) {
+      case '\n':
+        return is;
+      case '\r':
+        if(sb->sgetc() == '\n')
+          sb->sbumpc();
+        return is;
+      case std::streambuf::traits_type::eof():
+        // Also handle the case when the last line has no line ending
+        if(t.empty())
+          is.setstate(std::ios::eofbit);
+        return is;
+      default:
+        t += (char)c;
+    }
+  }
+}
+
+int
+Alphabet::init(const char *config_file)
+{
+  std::ifstream in(config_file, std::ios::in);
+  if (!in) {
+    return 1;
+  }
+  unsigned int label = 0;
+  space_label_ = -2;
+  for (std::string line; getline_crossplatform(in, line);) {
+    if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
+      line = '#';
+    } else if (line[0] == '#') {
+      continue;
+    }
+    //TODO: we should probably do something more i18n-aware here
+    if (line == " ") {
+      space_label_ = label;
+    }
+    label_to_str_[label] = line;
+    str_to_label_[line] = label;
+    ++label;
+  }
+  size_ = label;
+  in.close();
+  return 0;
+}
+
+std::string
+Alphabet::Serialize()
+{
+  // Serialization format is a sequence of (key, value) pairs, where key is
+  // a uint16_t and value is a uint16_t length followed by `length` UTF-8
+  // encoded bytes with the label.
+  std::stringstream out;
+
+  // We start by writing the number of pairs in the buffer as uint16_t.
+  uint16_t size = size_;
+  out.write(reinterpret_cast<char*>(&size), sizeof(size));
+
+  for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
+    uint16_t key = it->first;
+    string str = it->second;
+    uint16_t len = str.length();
+    // Then we write the key as uint16_t, followed by the length of the value
+    // as uint16_t, followed by `length` bytes (the value itself).
+    out.write(reinterpret_cast<char*>(&key), sizeof(key));
+    out.write(reinterpret_cast<char*>(&len), sizeof(len));
+    out.write(str.data(), len);
+  }
+
+  return out.str();
+}
+
+int
+Alphabet::Deserialize(const char* buffer, const int buffer_size)
+{
+  // See util/text.py for an explanation of the serialization format.
+  int offset = 0;
+  if (buffer_size - offset < sizeof(uint16_t)) {
+    return 1;
+  }
+  uint16_t size = *(uint16_t*)(buffer + offset);
+  offset += sizeof(uint16_t);
+  size_ = size;
+
+  for (int i = 0; i < size; ++i) {
+    if (buffer_size - offset < sizeof(uint16_t)) {
+      return 1;
+    }
+    uint16_t label = *(uint16_t*)(buffer + offset);
+    offset += sizeof(uint16_t);
+
+    if (buffer_size - offset < sizeof(uint16_t)) {
+      return 1;
+    }
+    uint16_t val_len = *(uint16_t*)(buffer + offset);
+    offset += sizeof(uint16_t);
+
+    if (buffer_size - offset < val_len) {
+      return 1;
+    }
+    std::string val(buffer+offset, val_len);
+    offset += val_len;
+
+    label_to_str_[label] = val;
+    str_to_label_[val] = label;
+
+    if (val == " ") {
+      space_label_ = label;
+    }
+  }
+
+  return 0;
+}
+
+std::string
+Alphabet::DecodeSingle(unsigned int label) const
+{
+  auto it = label_to_str_.find(label);
+  if (it != label_to_str_.end()) {
+    return it->second;
+  } else {
+    std::cerr << "Invalid label " << label << std::endl;
+    abort();
+  }
+}
+
+unsigned int
+Alphabet::EncodeSingle(const std::string& string) const
+{
+  auto it = str_to_label_.find(string);
+  if (it != str_to_label_.end()) {
+    return it->second;
+  } else {
+    std::cerr << "Invalid string " << string << std::endl;
+    abort();
+  }
+}
+
+std::string
+Alphabet::Decode(const std::vector<unsigned int>& input) const
+{
+  std::string word;
+  for (auto ind : input) {
+    word += DecodeSingle(ind);
+  }
+  return word;
+}
+
+std::string
+Alphabet::Decode(const unsigned int* input, int length) const
+{
+  std::string word;
+  for (int i = 0; i < length; ++i) {
+    word += DecodeSingle(input[i]);
+  }
+  return word;
+}
+
+std::vector<unsigned int>
+Alphabet::Encode(const std::string& input) const
+{
+  std::vector<unsigned int> result;
+  for (auto cp : split_into_codepoints(input)) {
+    result.push_back(EncodeSingle(cp));
+  }
+  return result;
+}
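A quick round-trip through the binary format implemented by ``Serialize``/``Deserialize`` above (editorial sketch, not part of the commit; ``alphabet.txt`` is a placeholder path):

.. code-block:: cpp

   #include "alphabet.h"
   #include <cassert>

   int main() {
     Alphabet a;
     if (a.init("alphabet.txt") != 0) {  // placeholder path
       return 1;
     }
     // Buffer layout: uint16_t pair count, then (key, length, bytes) records.
     std::string buf = a.Serialize();
     Alphabet b;
     assert(b.Deserialize(buf.data(), buf.size()) == 0);
     assert(b.GetSize() == a.GetSize());  // same label count after the round trip
     return 0;
   }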
native_client/alphabet.h

@@ -1,9 +1,6 @@
 #ifndef ALPHABET_H
 #define ALPHABET_H
 
-#include <cassert>
-#include <fstream>
-#include <iostream>
 #include <string>
 #include <unordered_map>
 #include <vector>

@@ -18,92 +15,15 @@ public:
   Alphabet() = default;
   Alphabet(const Alphabet&) = default;
   Alphabet& operator=(const Alphabet&) = default;
+  virtual ~Alphabet() = default;
 
-  int init(const char *config_file) {
-    std::ifstream in(config_file, std::ios::in);
-    if (!in) {
-      return 1;
-    }
-    unsigned int label = 0;
-    space_label_ = -2;
-    for (std::string line; std::getline(in, line);) {
-      if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
-        line = '#';
-      } else if (line[0] == '#') {
-        continue;
-      }
-      //TODO: we should probably do something more i18n-aware here
-      if (line == " ") {
-        space_label_ = label;
-      }
-      label_to_str_[label] = line;
-      str_to_label_[line] = label;
-      ++label;
-    }
-    size_ = label;
-    in.close();
-    return 0;
-  }
-
-  int deserialize(const char* buffer, const int buffer_size) {
-    // See util/text.py for an explanation of the serialization format.
-    int offset = 0;
-    if (buffer_size - offset < sizeof(uint16_t)) {
-      return 1;
-    }
-    uint16_t size = *(uint16_t*)(buffer + offset);
-    offset += sizeof(uint16_t);
-    size_ = size;
-
-    for (int i = 0; i < size; ++i) {
-      if (buffer_size - offset < sizeof(uint16_t)) {
-        return 1;
-      }
-      uint16_t label = *(uint16_t*)(buffer + offset);
-      offset += sizeof(uint16_t);
-
-      if (buffer_size - offset < sizeof(uint16_t)) {
-        return 1;
-      }
-      uint16_t val_len = *(uint16_t*)(buffer + offset);
-      offset += sizeof(uint16_t);
-
-      if (buffer_size - offset < val_len) {
-        return 1;
-      }
-      std::string val(buffer+offset, val_len);
-      offset += val_len;
-
-      label_to_str_[label] = val;
-      str_to_label_[val] = label;
-
-      if (val == " ") {
-        space_label_ = label;
-      }
-    }
-
-    return 0;
-  }
-
-  const std::string& StringFromLabel(unsigned int label) const {
-    auto it = label_to_str_.find(label);
-    if (it != label_to_str_.end()) {
-      return it->second;
-    } else {
-      std::cerr << "Invalid label " << label << std::endl;
-      abort();
-    }
-  }
-
-  unsigned int LabelFromString(const std::string& string) const {
-    auto it = str_to_label_.find(string);
-    if (it != str_to_label_.end()) {
-      return it->second;
-    } else {
-      std::cerr << "Invalid string " << string << std::endl;
-      abort();
-    }
-  }
+  virtual int init(const char *config_file);
+
+  // Serialize alphabet into a binary buffer.
+  std::string Serialize();
+
+  // Deserialize alphabet from a binary buffer.
+  int Deserialize(const char* buffer, const int buffer_size);
 
   size_t GetSize() const {
     return size_;

@@ -117,20 +37,47 @@ public:
     return space_label_;
   }
 
-  template <typename T>
-  std::string LabelsToString(const std::vector<T>& input) const {
-    std::string word;
-    for (auto ind : input) {
-      word += StringFromLabel(ind);
-    }
-    return word;
-  }
+  // Decode a single label into a string.
+  std::string DecodeSingle(unsigned int label) const;
+
+  // Encode a single character/output class into a label.
+  unsigned int EncodeSingle(const std::string& string) const;
+
+  // Decode a sequence of labels into a string.
+  std::string Decode(const std::vector<unsigned int>& input) const;
+
+  // We provide a C-style overload for accepting NumPy arrays as input, since
+  // the NumPy library does not have built-in typemaps for std::vector<T>.
+  std::string Decode(const unsigned int* input, int length) const;
+
+  // Encode a sequence of character/output classes into a sequence of labels.
+  // Characters are assumed to always take a single Unicode codepoint.
+  std::vector<unsigned int> Encode(const std::string& input) const;
 
-private:
+protected:
   size_t size_;
   unsigned int space_label_;
   std::unordered_map<unsigned int, std::string> label_to_str_;
   std::unordered_map<std::string, unsigned int> str_to_label_;
 };
+
+class UTF8Alphabet : public Alphabet
+{
+public:
+  UTF8Alphabet() {
+    size_ = 255;
+    space_label_ = ' ' - 1;
+    for (size_t i = 0; i < size_; ++i) {
+      std::string val(1, i+1);
+      label_to_str_[i] = val;
+      str_to_label_[val] = i;
+    }
+  }
+
+  int init(const char*) override {
+    return 0;
+  }
+};
+
 
 #endif //ALPHABET_H
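The new ``UTF8Alphabet`` maps each byte value ``b`` in 1..255 to label ``b - 1``, which is why ``space_label_`` is ``' ' - 1``. A small sketch of the consequences (editorial, not part of the commit):

.. code-block:: cpp

   #include "alphabet.h"
   #include <cassert>

   int main() {
     UTF8Alphabet a;
     assert(a.GetSize() == 255);             // labels 0..254 cover bytes 1..255
     assert(a.EncodeSingle(" ") == ' ' - 1); // space (0x20) -> label 31
     assert(a.DecodeSingle('a' - 1) == "a"); // label 96 -> byte 0x61
     return 0;
   }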
native_client/ctcdecode/__init__.py

@@ -1,7 +1,7 @@
 from __future__ import absolute_import, division, print_function
 
 from . import swigwrapper # pylint: disable=import-self
-from .swigwrapper import Alphabet
+from .swigwrapper import UTF8Alphabet
 
 __version__ = swigwrapper.__version__

@@ -30,24 +30,25 @@ class Scorer(swigwrapper.Scorer):
         assert beta is not None, 'beta parameter is required'
         assert scorer_path, 'scorer_path parameter is required'
 
-        serialized = alphabet.serialize()
-        native_alphabet = swigwrapper.Alphabet()
-        err = native_alphabet.deserialize(serialized, len(serialized))
-        if err != 0:
-            raise ValueError('Error when deserializing alphabet.')
-
-        err = self.init(scorer_path.encode('utf-8'),
-                        native_alphabet)
+        err = self.init(scorer_path, alphabet)
         if err != 0:
-            raise ValueError('Scorer initialization failed with error code {}'.format(err))
+            raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))
 
         self.reset_params(alpha, beta)
 
-    def load_lm(self, lm_path):
-        return super(Scorer, self).load_lm(lm_path.encode('utf-8'))
-
-    def save_dictionary(self, save_path, *args, **kwargs):
-        return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
+
+class Alphabet(swigwrapper.Alphabet):
+    """Convenience wrapper for Alphabet which calls init in the constructor"""
+    def __init__(self, config_path):
+        super(Alphabet, self).__init__()
+        err = self.init(config_path)
+        if err != 0:
+            raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
+
+    def Encode(self, input):
+        """Convert SWIG's UnsignedIntVec to a Python list"""
+        res = super(Alphabet, self).Encode(input)
+        return [el for el in res]
 
 
 def ctc_beam_search_decoder(probs_seq,

@@ -79,15 +80,10 @@ def ctc_beam_search_decoder(probs_seq,
     results, in descending order of the confidence.
     :rtype: list
     """
-    serialized = alphabet.serialize()
-    native_alphabet = swigwrapper.Alphabet()
-    err = native_alphabet.deserialize(serialized, len(serialized))
-    if err != 0:
-        raise ValueError("Error when deserializing alphabet.")
     beam_results = swigwrapper.ctc_beam_search_decoder(
-        probs_seq, native_alphabet, beam_size, cutoff_prob, cutoff_top_n,
+        probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
         scorer)
-    beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
+    beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
     return beam_results

@@ -126,14 +122,9 @@ def ctc_beam_search_decoder_batch(probs_seq,
     results, in descending order of the confidence.
     :rtype: list
     """
-    serialized = alphabet.serialize()
-    native_alphabet = swigwrapper.Alphabet()
-    err = native_alphabet.deserialize(serialized, len(serialized))
-    if err != 0:
-        raise ValueError("Error when deserializing alphabet.")
-    batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, native_alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
+    batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
     batch_beam_results = [
-        [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
+        [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
        for beam_results in batch_beam_results
     ]
     return batch_beam_results
@@ -46,7 +46,8 @@ CTC_DECODER_FILES = [
     'scorer.cpp',
     'path_trie.cpp',
     'decoder_utils.cpp',
-    'workspace_status.cc'
+    'workspace_status.cc',
+    '../alphabet.cc',
 ]
 
 def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):
native_client/ctcdecode/decoder_utils.cpp

@@ -119,7 +119,7 @@ bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::un
   }
 }
 
-void add_word_to_fst(const std::vector<int> &word,
+void add_word_to_fst(const std::vector<unsigned int> &word,
                      fst::StdVectorFst *dictionary) {
   if (dictionary->NumStates() == 0) {
     fst::StdVectorFst::StateId start = dictionary->AddState();

@@ -144,7 +144,7 @@ bool add_word_to_dictionary(
     fst::StdVectorFst *dictionary) {
   auto characters = utf8 ? split_into_bytes(word) : split_into_codepoints(word);
 
-  std::vector<int> int_word;
+  std::vector<unsigned int> int_word;
 
   for (auto &c : characters) {
     auto int_c = char_map.find(c);

native_client/ctcdecode/decoder_utils.h

@@ -86,7 +86,7 @@ std::vector<std::string> split_into_codepoints(const std::string &str);
 std::vector<std::string> split_into_bytes(const std::string &str);
 
 // Add a word in index to the dicionary of fst
-void add_word_to_fst(const std::vector<int> &word,
+void add_word_to_fst(const std::vector<unsigned int> &word,
                      fst::StdVectorFst *dictionary);
 
 // Return whether a byte is a code point boundary (not a continuation byte).
native_client/ctcdecode/output.h

@@ -8,8 +8,8 @@
 */
 struct Output {
   double confidence;
-  std::vector<int> tokens;
-  std::vector<int> timesteps;
+  std::vector<unsigned int> tokens;
+  std::vector<unsigned int> timesteps;
 };
 
 #endif // OUTPUT_H_
native_client/ctcdecode/path_trie.cpp

@@ -35,7 +35,7 @@ PathTrie::~PathTrie() {
   }
 }
 
-PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_prob_c, bool reset) {
+PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
   auto child = children_.begin();
   for (; child != children_.end(); ++child) {
     if (child->first == new_char) {

@@ -102,7 +102,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
   }
 }
 
-void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps) {
+void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
   // Recursive call: recurse back until stop condition, then append data in
   // correct order as we walk back down the stack in the lines below.
   if (parent != nullptr) {

@@ -114,8 +114,8 @@ void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timestep
   }
 }
 
-PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
-                                      std::vector<int>& timesteps,
+PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
+                                      std::vector<unsigned int>& timesteps,
                                       const Alphabet& alphabet)
 {
   PathTrie* stop = this;

@@ -124,7 +124,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
   }
   // Recursive call: recurse back until stop condition, then append data in
   // correct order as we walk back down the stack in the lines below.
-  if (!byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
+  if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
     stop = parent->get_prev_grapheme(output, timesteps, alphabet);
   }
   output.push_back(character);

@@ -135,7 +135,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
 int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
                                              const Alphabet& alphabet)
 {
-  if (byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
+  if (byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
     *first_byte = (unsigned char)character + 1;
     return 1;
   }

@@ -146,8 +146,8 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
   return 0;
 }
 
-PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
-                                  std::vector<int>& timesteps,
+PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
+                                  std::vector<unsigned int>& timesteps,
                                   const Alphabet& alphabet)
 {
   PathTrie* stop = this;

@@ -225,7 +225,7 @@ void PathTrie::print(const Alphabet& a) {
   for (PathTrie* el : chain) {
     printf("%X ", (unsigned char)(el->character));
     if (el->character != ROOT_) {
-      tr.append(a.StringFromLabel(el->character));
+      tr.append(a.DecodeSingle(el->character));
     }
   }
   printf("\ntimesteps:\t ");
native_client/ctcdecode/path_trie.h

@@ -21,22 +21,22 @@ public:
   ~PathTrie();
 
   // get new prefix after appending new char
-  PathTrie* get_path_trie(int new_char, int new_timestep, float log_prob_c, bool reset = true);
+  PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);
 
   // get the prefix data in correct time order from root to current node
-  void get_path_vec(std::vector<int>& output, std::vector<int>& timesteps);
+  void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);
 
   // get the prefix data in correct time order from beginning of last grapheme to current node
-  PathTrie* get_prev_grapheme(std::vector<int>& output,
-                              std::vector<int>& timesteps,
+  PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
+                              std::vector<unsigned int>& timesteps,
                               const Alphabet& alphabet);
 
   // get the distance from current node to the first codepoint boundary, and the byte value at the boundary
   int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);
 
   // get the prefix data in correct time order from beginning of last word to current node
-  PathTrie* get_prev_word(std::vector<int>& output,
-                          std::vector<int>& timesteps,
+  PathTrie* get_prev_word(std::vector<unsigned int>& output,
+                          std::vector<unsigned int>& timesteps,
                           const Alphabet& alphabet);
 
   // update log probs

@@ -64,8 +64,8 @@ public:
   float log_prob_c;
   float score;
   float approx_ctc;
-  int character;
-  int timestep;
+  unsigned int character;
+  unsigned int timestep;
   PathTrie* parent;
 
 private:

@@ -73,7 +73,7 @@ private:
   bool exists_;
   bool has_dictionary_;
 
-  std::vector<std::pair<int, PathTrie*>> children_;
+  std::vector<std::pair<unsigned int, PathTrie*>> children_;
 
   // pointer to dictionary of FST
   std::shared_ptr<FstType> dictionary_;
native_client/ctcdecode/scorer.cpp

@@ -65,7 +65,7 @@ void Scorer::setup_char_map()
     // The initial state of FST is state 0, hence the index of chars in
     // the FST should start from 1 to avoid the conflict with the initial
     // state, otherwise wrong decoding results would be given.
-    char_map_[alphabet_.StringFromLabel(i)] = i + 1;
+    char_map_[alphabet_.DecodeSingle(i)] = i + 1;
   }
 }
 

@@ -314,11 +314,11 @@ void Scorer::reset_params(float alpha, float beta)
   this->beta = beta;
 }
 
-std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<int>& labels)
+std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<unsigned int>& labels)
 {
   if (labels.empty()) return {};
 
-  std::string s = alphabet_.LabelsToString(labels);
+  std::string s = alphabet_.Decode(labels);
   std::vector<std::string> words;
   if (is_utf8_mode_) {
     words = split_into_codepoints(s);

@@ -339,8 +339,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
       break;
     }
 
-    std::vector<int> prefix_vec;
-    std::vector<int> prefix_steps;
+    std::vector<unsigned int> prefix_vec;
+    std::vector<unsigned int> prefix_steps;
 
     if (is_utf8_mode_) {
       new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);

@@ -350,14 +350,14 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
     current_node = new_node->parent;
 
     // reconstruct word
-    std::string word = alphabet_.LabelsToString(prefix_vec);
+    std::string word = alphabet_.Decode(prefix_vec);
     ngram.push_back(word);
   }
   std::reverse(ngram.begin(), ngram.end());
   return ngram;
 }
 
-void Scorer::fill_dictionary(const std::vector<std::string>& vocabulary)
+void Scorer::fill_dictionary(const std::unordered_set<std::string>& vocabulary)
 {
   // ConstFst is immutable, so we need to use a MutableFst to create the trie,
   // and then we convert to a ConstFst for the decoder and for storing on disk.
native_client/ctcdecode/scorer.h

@@ -4,6 +4,7 @@
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 #include "lm/virtual_interface.hh"

@@ -72,7 +73,7 @@ public:
 
   // trransform the labels in index to the vector of words (word based lm) or
   // the vector of characters (character based lm)
-  std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
+  std::vector<std::string> split_labels_into_scored_units(const std::vector<unsigned int> &labels);
 
   void set_alphabet(const Alphabet& alphabet);

@@ -83,7 +84,7 @@ public:
   bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
 
   // fill dictionary FST from a vocabulary
-  void fill_dictionary(const std::vector<std::string> &vocabulary);
+  void fill_dictionary(const std::unordered_set<std::string> &vocabulary);
 
   // load language model from given path
   int load_lm(const std::string &lm_path);
native_client/ctcdecode/swigwrapper.i

@@ -3,7 +3,6 @@
 %{
 #include "ctc_beam_search_decoder.h"
 #define SWIG_FILE_WITH_INIT
-#define SWIG_PYTHON_STRICT_BYTE_CHAR
 #include "workspace_status.h"
 %}
 

@@ -19,6 +18,9 @@ import_array();
 
 namespace std {
   %template(StringVector) vector<string>;
+  %template(UnsignedIntVector) vector<unsigned int>;
+  %template(OutputVector) vector<Output>;
+  %template(OutputVectorVector) vector<vector<Output>>;
 }
 
 %shared_ptr(Scorer);

@@ -27,6 +29,7 @@ namespace std {
 %apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
 %apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
 %apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
+%apply (unsigned int* IN_ARRAY1, int DIM1) {(const unsigned int *input, int length)};
 
 %ignore Scorer::dictionary;
 

@@ -38,10 +41,6 @@ namespace std {
 %constant const char* __version__ = ds_version();
 %constant const char* __git_version__ = ds_git_version();
 
-%template(IntVector) std::vector<int>;
-%template(OutputVector) std::vector<Output>;
-%template(OutputVectorVector) std::vector<std::vector<Output>>;
-
 // Import only the error code enum definitions from deepspeech.h
 // We can't just do |%ignore "";| here because it affects this file globally (even
 // files %include'd above). That causes SWIG to lose destructor information and
native_client/deepspeech.cc

@@ -501,20 +501,3 @@ DS_Version()
 {
   return strdup(ds_version());
 }
-
-char*
-DS_ErrorCodeToErrorMessage(int aErrorCode)
-{
-#define RETURN_MESSAGE(NAME, VALUE, DESC) \
-  case NAME: \
-    return strdup(DESC);
-
-  switch(aErrorCode)
-  {
-    DS_FOR_EACH_ERROR(RETURN_MESSAGE)
-    default:
-      return strdup("Unknown error, please make sure you are using the correct native binary.");
-  }
-
-#undef RETURN_MESSAGE
-}
native_client/deepspeech_errors.cc (new file, 19 lines)

@@ -0,0 +1,19 @@
+#include "deepspeech.h"
+#include <string.h>
+
+char*
+DS_ErrorCodeToErrorMessage(int aErrorCode)
+{
+#define RETURN_MESSAGE(NAME, VALUE, DESC) \
+  case NAME: \
+    return strdup(DESC);
+
+  switch(aErrorCode)
+  {
+    DS_FOR_EACH_ERROR(RETURN_MESSAGE)
+    default:
+      return strdup("Unknown error, please make sure you are using the correct native binary.");
+  }
+
+#undef RETURN_MESSAGE
+}
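``DS_FOR_EACH_ERROR`` is an X-macro: ``deepspeech.h`` defines the error list once and each call site expands it with its own per-entry macro, as ``RETURN_MESSAGE`` does above. A self-contained sketch of the pattern (the error list below is made up for illustration, not DeepSpeech's real one):

.. code-block:: cpp

   #include <cstdio>

   // The list is defined once; X is supplied by each expansion site.
   #define FOR_EACH_ERR(X)           \
     X(ERR_OK,   0, "No error")      \
     X(ERR_INIT, 1, "Init failed")

   // First expansion: declare the enum values.
   #define DECLARE_ENUM(NAME, VALUE, DESC) NAME = VALUE,
   enum ErrorCodes { FOR_EACH_ERR(DECLARE_ENUM) };
   #undef DECLARE_ENUM

   // Second expansion: map codes to messages, as deepspeech_errors.cc does.
   const char* err_msg(int code) {
   #define RETURN_MESSAGE(NAME, VALUE, DESC) case NAME: return DESC;
     switch (code) {
       FOR_EACH_ERR(RETURN_MESSAGE)
       default: return "Unknown error";
     }
   #undef RETURN_MESSAGE
   }

   int main() {
     std::printf("%s\n", err_msg(1));  // prints: Init failed
     return 0;
   }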
native_client/generate_scorer_package.cpp (new file, 146 lines)
@ -0,0 +1,146 @@
#include <string>
#include <vector>
#include <fstream>
#include <unordered_set>
#include <iostream>
using namespace std;

#include "absl/types/optional.h"
#include "boost/program_options.hpp"

#include "ctcdecode/decoder_utils.h"
#include "ctcdecode/scorer.h"
#include "alphabet.h"
#include "deepspeech.h"

namespace po = boost::program_options;

int
create_package(absl::optional<string> alphabet_path,
               string lm_path,
               string vocab_path,
               string package_path,
               absl::optional<bool> force_utf8,
               float default_alpha,
               float default_beta)
{
  // Read vocabulary
  unordered_set<string> words;
  bool vocab_looks_char_based = true;
  ifstream fin(vocab_path);
  if (!fin) {
    cerr << "Invalid vocabulary file " << vocab_path << "\n";
    return 1;
  }
  string word;
  while (fin >> word) {
    words.insert(word);
    if (get_utf8_str_len(word) > 1) {
      vocab_looks_char_based = false;
    }
  }
  cerr << words.size() << " unique words read from vocabulary file.\n"
       << (vocab_looks_char_based ? "Looks" : "Doesn't look")
       << " like a character based (Bytes Are All You Need) model.\n";

  if (!force_utf8.has_value()) {
    force_utf8 = vocab_looks_char_based;
    cerr << "--force_utf8 was not specified, using value "
         << "inferred from vocabulary contents: "
         << (vocab_looks_char_based ? "true" : "false") << "\n";
  }

  if (!force_utf8.value() && !alphabet_path.has_value()) {
    cerr << "No --alphabet file specified, not using bytes output mode, can't continue.\n";
    return 1;
  }

  Scorer scorer;
  if (force_utf8.value()) {
    scorer.set_alphabet(UTF8Alphabet());
  } else {
    Alphabet alphabet;
    alphabet.init(alphabet_path->c_str());
    scorer.set_alphabet(alphabet);
  }
  scorer.set_utf8_mode(force_utf8.value());
  scorer.reset_params(default_alpha, default_beta);
  int err = scorer.load_lm(lm_path);
  if (err != DS_ERR_SCORER_NO_TRIE) {
    cerr << "Error loading language model file: "
         << DS_ErrorCodeToErrorMessage(err) << "\n";
    return 1;
  }
  scorer.fill_dictionary(words);

  // Copy LM file to final package file destination
  {
    ifstream lm_src(lm_path, std::ios::binary);
    ofstream package_dest(package_path, std::ios::binary);
    package_dest << lm_src.rdbuf();
  }

  // Save dictionary to package file, appending instead of overwriting
  if (!scorer.save_dictionary(package_path, true)) {
    cerr << "Error when saving package in " << package_path << ".\n";
    return 1;
  }

  cerr << "Package created in " << package_path << ".\n";
  return 0;
}

int
main(int argc, char** argv)
{
  po::options_description desc("Options");
  desc.add_options()
    ("help", "show help message")
    ("alphabet", po::value<string>(), "Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.")
    ("lm", po::value<string>(), "Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.")
    ("vocab", po::value<string>(), "Path of vocabulary file. Must contain words separated by whitespace.")
    ("package", po::value<string>(), "Path to save scorer package.")
    ("default_alpha", po::value<float>(), "Default value of alpha hyperparameter (float).")
    ("default_beta", po::value<float>(), "Default value of beta hyperparameter (float).")
    ("force_utf8", po::value<bool>(), "Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary. See <https://deepspeech.readthedocs.io/en/master/Decoder.html#utf-8-mode> for further explanation.")
    ;

  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, desc), vm);
  po::notify(vm);

  if (vm.count("help")) {
    cout << desc << "\n";
    return 1;
  }

  // Check required flags.
  for (const string& flag : {"lm", "vocab", "package", "default_alpha", "default_beta"}) {
    if (!vm.count(flag)) {
      cerr << "--" << flag << " is a required flag. Pass --help for help.\n";
      return 1;
    }
  }

  // Parse optional --force_utf8
  absl::optional<bool> force_utf8 = absl::nullopt;
  if (vm.count("force_utf8")) {
    force_utf8 = vm["force_utf8"].as<bool>();
  }

  // Parse optional --alphabet
  absl::optional<string> alphabet = absl::nullopt;
  if (vm.count("alphabet")) {
    alphabet = vm["alphabet"].as<string>();
  }

  create_package(alphabet,
                 vm["lm"].as<string>(),
                 vm["vocab"].as<string>(),
                 vm["package"].as<string>(),
                 force_utf8,
                 vm["default_alpha"].as<float>(),
                 vm["default_beta"].as<float>());

  return 0;
}
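Two details of the tool above are worth spelling out. First, --alphabet is only required when the package is not in UTF-8 mode, and --force_utf8 is deliberately three-state: unset means "infer from the vocabulary", while an explicit true or false always wins. A minimal sketch of that pattern, using std::optional (C++17) in place of absl::optional; the semantics are the same:

#include <iostream>
#include <optional>

int main()
{
  std::optional<bool> force_utf8;      // unset: infer from the vocabulary
  bool vocab_looks_char_based = true;  // assumed result of the vocabulary scan

  if (!force_utf8.has_value()) {
    // Only fall back to the inferred value when the user passed nothing;
    // an explicit --force_utf8=false must not be overridden.
    force_utf8 = vocab_looks_char_based;
  }
  std::cout << "UTF-8 mode: " << (force_utf8.value() ? "true" : "false") << "\n";
  return 0;
}

Second, DS_ERR_SCORER_NO_TRIE is the expected result of load_lm() here: the trie is exactly what this tool is about to build and append to the package, so any other return code is treated as a real error.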
@ -33,7 +33,7 @@ char*
ModelState::decode(const DecoderState& state) const
{
  vector<Output> out = state.decode();
  return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
  return strdup(alphabet_.Decode(out[0].tokens).c_str());
}

Metadata*
@ -50,7 +50,7 @@ ModelState::decode_metadata(const DecoderState& state,

  for (int j = 0; j < out[i].tokens.size(); ++j) {
    TokenMetadata token {
      strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str()), // text
      strdup(alphabet_.DecodeSingle(out[i].tokens[j]).c_str()), // text
      static_cast<unsigned int>(out[i].timesteps[j]), // timestep
      out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time
    };
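The start_time expression above converts a decoder timestep into seconds: each timestep advances the audio by audio_win_step_ samples, so multiplying by the step duration (window step over sample rate) gives the token's start. A worked example, assuming the default 20 ms feature window step at 16 kHz, i.e. audio_win_step_ = 320 samples:

#include <iostream>

int main()
{
  const unsigned int audio_win_step = 320; // samples per timestep (assumed default)
  const unsigned int sample_rate = 16000;  // Hz (assumed default)
  const int timestep = 50;                 // decoder timestep of some token

  // Same arithmetic as the start_time field above: 50 * (320 / 16000) = 1.0 s.
  float start_time = timestep * ((float)audio_win_step / sample_rate);
  std::cout << start_time << " s\n";
  return 0;
}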
@ -206,7 +206,7 @@ TFLiteModelState::init(const char* model_path)
  beam_width_ = (unsigned int)(*beam_width);

  tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
  err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
  err = alphabet_.Deserialize(serialized_alphabet.str, serialized_alphabet.len);
  if (err != 0) {
    return DS_ERR_INVALID_ALPHABET;
  }
@ -119,7 +119,7 @@ TFModelState::init(const char* model_path)
  beam_width_ = (unsigned int)(beam_width);

  string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
  err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
  err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
  if (err != 0) {
    return DS_ERR_INVALID_ALPHABET;
  }
@ -25,7 +25,7 @@ build:
  nc_asset_name: 'native_client.tar.xz'
  args:
    tests_cmdline: ''
  tensorflow_git_desc: 'TensorFlow: v2.2.0-14-g7ead558'
  tensorflow_git_desc: 'TensorFlow: v2.2.0-15-g518c1d0'
  test_model_task: ''
  homebrew:
    url: ''
@ -142,32 +142,32 @@ system:
  namespace: "project.deepspeech.swig.win.amd64.b5fea54d39832d1d132d7dd921b69c0c2c9d5118"
  tensorflow:
    linux_amd64_cpu:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.cpu/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.cpu/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.cpu"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.cpu"
    linux_amd64_cuda:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.cuda/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.cuda/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.cuda"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.cuda"
    linux_armv7:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.arm/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.arm/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.arm"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.arm"
    linux_arm64:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.arm64/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.arm64/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.arm64"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.arm64"
    darwin_amd64:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.osx/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.osx/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.osx"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.osx"
    android_arm64:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.android-arm64/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.android-arm64/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.android-arm64"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.android-arm64"
    android_armv7:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.android-armv7/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.android-armv7/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.android-armv7"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.android-armv7"
    win_amd64_cpu:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.win/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.win/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.win"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.win"
    win_amd64_cuda:
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.win-cuda/artifacts/public/home.tar.xz"
      url: "https://community-tc.services.mozilla.com/api/index/v1/task/project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.win-cuda/artifacts/public/home.tar.xz"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.7ead55807a2ded84c107720ebca61e6285e2c239.1.win-cuda"
      namespace: "project.deepspeech.tensorflow.pip.r2.2.518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d.0.win-cuda"
  username: 'build-user'
  homedir:
    linux: '/home/build-user'
@ -10,6 +10,7 @@ source $(dirname "$0")/tf_tc-vars.sh

BAZEL_TARGETS="
//native_client:libdeepspeech.so
//native_client:generate_scorer_package
"

if [ "${arm_flavor}" = "armeabi-v7a" ]; then
@ -8,6 +8,7 @@ source $(dirname "$0")/tf_tc-vars.sh

BAZEL_TARGETS="
//native_client:libdeepspeech.so
//native_client:generate_scorer_package
"

BAZEL_BUILD_FLAGS="${BAZEL_ARM64_FLAGS} ${BAZEL_EXTRA_FLAGS}"
@ -8,6 +8,7 @@ source $(dirname "$0")/tf_tc-vars.sh

BAZEL_TARGETS="
//native_client:libdeepspeech.so
//native_client:generate_scorer_package
"

BAZEL_ENV_FLAGS="TF_NEED_CUDA=1 ${TF_CUDA_FLAGS}"
@ -10,6 +10,7 @@ source $(dirname "$0")/tf_tc-vars.sh

BAZEL_TARGETS="
//native_client:libdeepspeech.so
//native_client:generate_scorer_package
"

if [ "${runtime}" = "tflite" ]; then
@ -8,6 +8,7 @@ source $(dirname "$0")/tf_tc-vars.sh

BAZEL_TARGETS="
//native_client:libdeepspeech.so
//native_client:generate_scorer_package
"

BAZEL_BUILD_FLAGS="${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS}"
@ -24,6 +24,7 @@ package_native_client()
  ${TAR} -cf - \
    -C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so \
    -C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so.if.lib \
    -C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
    -C ${deepspeech_dir}/ LICENSE \
    -C ${deepspeech_dir}/native_client/ deepspeech${PLATFORM_EXE_SUFFIX} \
    -C ${deepspeech_dir}/native_client/ deepspeech.h \
@ -34,6 +35,7 @@ package_native_client()
package_native_client_ndk()
{
  deepspeech_dir=${DS_DSDIR}
  tensorflow_dir=${DS_TFDIR}
  artifacts_dir=${TASKCLUSTER_ARTIFACTS}
  artifact_name=$1
  arch_abi=$2
@ -56,6 +58,7 @@ package_native_client_ndk()
  tar -cf - \
    -C ${deepspeech_dir}/native_client/libs/${arch_abi}/ deepspeech \
    -C ${deepspeech_dir}/native_client/libs/${arch_abi}/ libdeepspeech.so \
    -C ${tensorflow_dir}/bazel-bin/native_client/ generate_scorer_package \
    -C ${deepspeech_dir}/native_client/libs/${arch_abi}/ libc++_shared.so \
    -C ${deepspeech_dir}/native_client/ deepspeech.h \
    -C ${deepspeech_dir}/ LICENSE \
@ -10,6 +10,7 @@ source $(dirname "$0")/tf_tc-vars.sh

BAZEL_TARGETS="
//native_client:libdeepspeech.so
//native_client:generate_scorer_package
"

if [ "${package_option}" = "--cuda" ]; then
@ -1 +1 @@
Subproject commit 7ead55807a2ded84c107720ebca61e6285e2c239
Subproject commit 518c1d04bf55d362bb11e973b8f5d0aa3e5bf44d
@ -1,7 +1,7 @@
import unittest
import os

from deepspeech_training.util.text import Alphabet
from ds_ctcdecoder import Alphabet


class TestAlphabetParsing(unittest.TestCase):
@ -11,12 +11,12 @@ class TestAlphabetParsing(unittest.TestCase):
        label_id = -1
        for expected_label, expected_label_id in expected:
            try:
                label_id = alphabet.encode(expected_label)
                label_id = alphabet.Encode(expected_label)
            except KeyError:
                pass
            self.assertEqual(label_id, [expected_label_id])
            try:
                label = alphabet.decode([expected_label_id])
                label = alphabet.Decode([expected_label_id])
            except KeyError:
                pass
            self.assertEqual(label, expected_label)
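The lowercase-to-capitalized renames throughout these diffs (encode -> Encode, decode -> Decode, serialize -> Serialize) reflect that the training code and tests now exercise the SWIG-wrapped native Alphabet from ds_ctcdecoder instead of the removed Python class. The contract the test asserts is a simple round trip; a toy illustration of that contract (not the real ds_ctcdecoder class, and with an assumed label table):

#include <cassert>
#include <string>
#include <vector>

struct ToyAlphabet {
  std::string labels = "ab c";  // label id == index into this string (assumed)

  std::vector<int> Encode(const std::string& text) const {
    std::vector<int> ids;
    for (char ch : text)
      ids.push_back(static_cast<int>(labels.find(ch)));  // char -> label id
    return ids;
  }
  std::string Decode(const std::vector<int>& ids) const {
    std::string text;
    for (int id : ids)
      text.push_back(labels[id]);                        // label id -> char
    return text;
  }
};

int main() {
  ToyAlphabet a;
  assert(a.Decode(a.Encode("cab")) == "cab");  // round trip, as the test expects
  return 0;
}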
@ -40,7 +40,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
    for i, index in enumerate(indices):
        results[index[0]].append(values[i])
    # List of strings
    return [alphabet.decode(res) for res in results]
    return [alphabet.Decode(res) for res in results]


def evaluate(test_csvs, create_model):
@ -771,7 +771,7 @@ def export():
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
@ -6,14 +6,15 @@ import tensorflow.compat.v1 as tfv1

from attrdict import AttrDict
from xdg import BaseDirectory as xdg
from ds_ctcdecoder import Alphabet, UTF8Alphabet

from .flags import FLAGS
from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size
from .augmentations import parse_augmentations


class ConfigSingleton:
    _config = None
@ -115,7 +116,7 @@ def initialize_globals():
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1  # +1 for CTC blank label
    c.n_hidden_6 = c.alphabet.GetSize() + 1  # +1 for CTC blank label

    # Size of audio window in samples
    if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:
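For example, with the stock English data/alphabet.txt (26 letters plus space and apostrophe, 28 labels), this gives 28 + 1 = 29 output units.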
@ -151,7 +151,7 @@ def create_flags():

    f.DEFINE_boolean('utf8', False, 'enable UTF-8 mode. When this is used the model outputs UTF-8 sequences directly rather than using an alphabet mapping.')
    f.DEFINE_string('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
    f.DEFINE_string('scorer_path', 'data/lm/kenlm.scorer', 'path to the external scorer file created with data/lm/generate_package.py')
    f.DEFINE_string('scorer_path', 'data/lm/kenlm.scorer', 'path to the external scorer file.')
    f.DEFINE_alias('scorer', 'scorer_path')
    f.DEFINE_integer('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
    f.DEFINE_float('lm_alpha', 0.931289039105002, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
@ -52,12 +52,10 @@ def check_ctcdecoder_version():
            sys.exit(1)
        raise e

    decoder_version_s = decoder_version.decode()

    rv = semver.compare(ds_version_s, decoder_version_s)
    rv = semver.compare(ds_version_s, decoder_version)
    if rv != 0:
        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
              "Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
              "Please ensure matching versions are in use.".format(ds_version_s, decoder_version))
        sys.exit(1)

    return rv
@ -3,121 +3,6 @@ from __future__ import absolute_import, division, print_function
import numpy as np
import struct

from six.moves import range


class Alphabet(object):
    def __init__(self, config_file):
        self._config_file = config_file
        self._label_to_str = {}
        self._str_to_label = {}
        self._size = 0
        if config_file:
            with open(config_file, 'r', encoding='utf-8') as fin:
                for line in fin:
                    if line[0:2] == '\\#':
                        line = '#\n'
                    elif line[0] == '#':
                        continue
                    self._label_to_str[self._size] = line[:-1]  # remove the line ending
                    self._str_to_label[line[:-1]] = self._size
                    self._size += 1

    def _string_from_label(self, label):
        return self._label_to_str[label]

    def _label_from_string(self, string):
        try:
            return self._str_to_label[string]
        except KeyError as e:
            raise KeyError(
                'ERROR: Your transcripts contain characters (e.g. \'{}\') which do not occur in \'{}\'! Use ' \
                'util/check_characters.py to see what characters are in your [train,dev,test].csv transcripts, and ' \
                'then add all these to \'{}\'.'.format(string, self._config_file, self._config_file)
            ).with_traceback(e.__traceback__)

    def has_char(self, char):
        return char in self._str_to_label

    def encode(self, string):
        res = []
        for char in string:
            res.append(self._label_from_string(char))
        return res

    def decode(self, labels):
        res = ''
        for label in labels:
            res += self._string_from_label(label)
        return res

    def serialize(self):
        # Serialization format is a sequence of (key, value) pairs, where key is
        # a uint16_t and value is a uint16_t length followed by `length` UTF-8
        # encoded bytes with the label.
        res = bytearray()

        # We start by writing the number of pairs in the buffer as uint16_t.
        res += struct.pack('<H', self._size)
        for key, value in self._label_to_str.items():
            value = value.encode('utf-8')
            # struct.pack only takes fixed length strings/buffers, so we have to
            # construct the correct format string with the length of the encoded
            # label.
            res += struct.pack('<HH{}s'.format(len(value)), key, len(value), value)
        return bytes(res)

    def size(self):
        return self._size

    def config_file(self):
        return self._config_file


class UTF8Alphabet(object):
    @staticmethod
    def _string_from_label(_):
        assert False

    @staticmethod
    def _label_from_string(_):
        assert False

    @staticmethod
    def encode(string):
        # 0 never happens in the data, so we can shift values by one, use 255 for
        # the CTC blank, and keep the alphabet size = 256
        return np.frombuffer(string.encode('utf-8'), np.uint8).astype(np.int32) - 1

    @staticmethod
    def decode(labels):
        # And here we need to shift back up
        return bytes(np.asarray(labels, np.uint8) + 1).decode('utf-8', errors='replace')

    @staticmethod
    def size():
        return 255

    @staticmethod
    def serialize():
        res = bytearray()
        res += struct.pack('<h', 255)
        for i in range(255):
            # Note that we also shift back up in the mapping constructed here
            # so that the native client sees the correct byte values when decoding.
            res += struct.pack('<hh1s', i, 1, bytes([i+1]))
        return bytes(res)

    @staticmethod
    def deserialize(buf):
        size = struct.unpack('<I', buf)[0]
        assert size == 255
        return UTF8Alphabet()

    @staticmethod
    def config_file():
        return ''

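The shift-by-one trick in the removed UTF8Alphabet above deserves a worked example: UTF-8 text never contains byte value 0, so shifting every byte down by one frees label 255 for the CTC blank while keeping the label space within 256. A minimal sketch of the same arithmetic:

#include <cassert>
#include <string>
#include <vector>

int main() {
  std::string text = "a";  // UTF-8 byte 0x61
  std::vector<int> labels;
  for (unsigned char b : text)
    labels.push_back(static_cast<int>(b) - 1);  // encode: shift down by one
  assert(labels[0] == 0x60);

  std::string decoded;
  for (int label : labels)
    decoded.push_back(static_cast<char>(label + 1));  // decode: shift back up
  assert(decoded == "a");
  return 0;
}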
def text_to_char_array(transcript, alphabet, context=''):
    r"""
    Given a transcript string, map characters to
@ -125,7 +10,7 @@ def text_to_char_array(transcript, alphabet, context=''):
    Use a string in `context` for adding text to raised exceptions.
    """
    try:
        transcript = alphabet.encode(transcript)
        transcript = alphabet.Encode(transcript)
        if len(transcript) == 0:
            raise ValueError('While processing {}: Found an empty transcript! '
                             'You must include a transcript for all training data.'
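For reference, the wire format produced by the removed serialize() above (and, per its comments, what the native side deserializes) is a little-endian uint16 entry count followed by (uint16 label id, uint16 byte length, UTF-8 bytes) records. A minimal sketch of a reader for that layout; this is illustrative only, not the native client's actual implementation:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <map>
#include <string>

// Parse a serialized alphabet: uint16 count, then (uint16 id, uint16 len,
// `len` UTF-8 bytes) per label, all little-endian.
std::map<uint16_t, std::string> ParseAlphabet(const char* buf, size_t len)
{
  std::map<uint16_t, std::string> labels;
  size_t off = 0;
  auto read_u16 = [&](uint16_t* out) {
    if (off + 2 > len) return false;
    std::memcpy(out, buf + off, 2);  // assumes a little-endian host
    off += 2;
    return true;
  };
  uint16_t count = 0;
  if (!read_u16(&count)) return labels;
  for (uint16_t i = 0; i < count; ++i) {
    uint16_t id = 0, size = 0;
    if (!read_u16(&id) || !read_u16(&size) || off + size > len) break;
    labels[id] = std::string(buf + off, size);
    off += size;
  }
  return labels;
}

int main()
{
  // Two labels: 0 -> "a", 1 -> "b", preceded by the entry count (2).
  const char buf[] = {2, 0, 0, 0, 1, 0, 'a', 1, 0, 1, 0, 'b'};
  for (const auto& kv : ParseAlphabet(buf, sizeof(buf)))
    std::cout << kv.first << " -> " << kv.second << "\n";
  return 0;
}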