Merge branch 'decoder-api-changes' (PR #2681)
This commit is contained in:
commit
5366f90375
4
.gitattributes
vendored
4
.gitattributes
vendored
@ -1,3 +1 @@
|
||||
*.binary filter=lfs diff=lfs merge=lfs -crlf
|
||||
data/lm/trie filter=lfs diff=lfs merge=lfs -crlf
|
||||
data/lm/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
||||
data/lm/kenlm.scorer filter=lfs diff=lfs merge=lfs -text
|
||||
|
@ -7,7 +7,7 @@ extension-pkg-whitelist=
|
||||
|
||||
# Add files or directories to the blacklist. They should be base names, not
|
||||
# paths.
|
||||
ignore=examples
|
||||
ignore=native_client/kenlm
|
||||
|
||||
# Add files or directories matching the regex patterns to the blacklist. The
|
||||
# regex matches against base names, not paths.
|
||||
|
@ -882,8 +882,7 @@ def package_zip():
|
||||
}
|
||||
}, f)
|
||||
|
||||
shutil.copy(FLAGS.lm_binary_path, export_dir)
|
||||
shutil.copy(FLAGS.lm_trie_path, export_dir)
|
||||
shutil.copy(FLAGS.scorer_path, export_dir)
|
||||
|
||||
archive = shutil.make_archive(zip_filename, 'zip', export_dir)
|
||||
log_info('Exported packaged model {}'.format(archive))
|
||||
@ -926,10 +925,9 @@ def do_single_file_inference(input_file_path):
|
||||
|
||||
logits = np.squeeze(logits)
|
||||
|
||||
if FLAGS.lm_binary_path:
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
||||
Config.alphabet)
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,
|
||||
|
@ -172,7 +172,7 @@ RUN ./configure
|
||||
|
||||
|
||||
# Build DeepSpeech
|
||||
RUN bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic --config=cuda -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-mtune=generic --copt=-march=x86-64 --copt=-msse --copt=-msse2 --copt=-msse3 --copt=-msse4.1 --copt=-msse4.2 --copt=-mavx --copt=-fvisibility=hidden //native_client:libdeepspeech.so //native_client:generate_trie --verbose_failures --action_env=LD_LIBRARY_PATH=${LD_LIBRARY_PATH}
|
||||
RUN bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic --config=cuda -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-mtune=generic --copt=-march=x86-64 --copt=-msse --copt=-msse2 --copt=-msse3 --copt=-msse4.1 --copt=-msse4.2 --copt=-mavx --copt=-fvisibility=hidden //native_client:libdeepspeech.so --verbose_failures --action_env=LD_LIBRARY_PATH=${LD_LIBRARY_PATH}
|
||||
|
||||
###
|
||||
### Using TensorFlow upstream should work
|
||||
@ -187,8 +187,7 @@ RUN bazel build --workspace_status_command="bash native_client/bazel_workspace_s
|
||||
# RUN pip3 install /tmp/tensorflow_pkg/*.whl
|
||||
|
||||
# Copy built libs to /DeepSpeech/native_client
|
||||
RUN cp /tensorflow/bazel-bin/native_client/generate_trie /DeepSpeech/native_client/ \
|
||||
&& cp /tensorflow/bazel-bin/native_client/libdeepspeech.so /DeepSpeech/native_client/
|
||||
RUN cp /tensorflow/bazel-bin/native_client/libdeepspeech.so /DeepSpeech/native_client/
|
||||
|
||||
# Install TensorFlow
|
||||
WORKDIR /DeepSpeech/
|
||||
|
@ -36,7 +36,7 @@ To install and use deepspeech all you have to do is:
|
||||
tar xvf audio-0.6.1.tar.gz
|
||||
|
||||
# Transcribe an audio file
|
||||
deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --lm deepspeech-0.6.1-models/lm.binary --trie deepspeech-0.6.1-models/trie --audio audio/2830-3980-0043.wav
|
||||
deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --scorer deepspeech-0.6.1-models/kenlm.scorer --audio audio/2830-3980-0043.wav
|
||||
|
||||
A pre-trained English model is available for use and can be downloaded using `the instructions below <doc/USING.rst#using-a-pre-trained-model>`_. A package with some example audio files is available for download in our `release notes <https://github.com/mozilla/DeepSpeech/releases/latest>`_.
|
||||
|
||||
@ -52,7 +52,7 @@ Quicker inference can be performed using a supported NVIDIA GPU on Linux. See th
|
||||
pip3 install deepspeech-gpu
|
||||
|
||||
# Transcribe an audio file.
|
||||
deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --lm deepspeech-0.6.1-models/lm.binary --trie deepspeech-0.6.1-models/trie --audio audio/2830-3980-0043.wav
|
||||
deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --scorer deepspeech-0.6.1-models/kenlm.scorer --audio audio/2830-3980-0043.wav
|
||||
|
||||
Please ensure you have the required `CUDA dependencies <doc/USING.rst#cuda-dependency>`_.
|
||||
|
||||
|
@ -21,8 +21,7 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||
--n_hidden 100 --epochs 1 \
|
||||
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
|
||||
--learning_rate 0.001 --dropout_rate 0.05 \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' | tee /tmp/resume.log
|
||||
--scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log
|
||||
|
||||
if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then
|
||||
echo "Did not resume training from checkpoint"
|
||||
|
@ -25,6 +25,5 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||
--n_hidden 100 --epochs $epoch_count \
|
||||
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
|
||||
--learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
|
||||
--audio_sample_rate ${audio_sample_rate}
|
||||
|
@ -21,12 +21,10 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
|
||||
--n_hidden 100 --epochs 1 \
|
||||
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \
|
||||
--learning_rate 0.001 --dropout_rate 0.05 \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie'
|
||||
--scorer_path 'data/smoke_test/pruned_lm.scorer'
|
||||
|
||||
python -u DeepSpeech.py \
|
||||
--n_hidden 100 \
|
||||
--checkpoint_dir '/tmp/ckpt' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
|
||||
--one_shot_infer 'data/smoke_test/LDC93S1.wav'
|
||||
|
@ -20,8 +20,7 @@ python -u DeepSpeech.py --noshow_progressbar \
|
||||
--n_hidden 100 \
|
||||
--checkpoint_dir '/tmp/ckpt' \
|
||||
--export_dir '/tmp/train_tflite' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
|
||||
--audio_sample_rate ${audio_sample_rate} \
|
||||
--export_tflite
|
||||
|
||||
@ -31,8 +30,7 @@ python -u DeepSpeech.py --noshow_progressbar \
|
||||
--n_hidden 100 \
|
||||
--checkpoint_dir '/tmp/ckpt' \
|
||||
--export_dir '/tmp/train_tflite/en-us' \
|
||||
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \
|
||||
--lm_trie_path 'data/smoke_test/vocab.trie' \
|
||||
--scorer_path 'data/smoke_test/pruned_lm.scorer' \
|
||||
--audio_sample_rate ${audio_sample_rate} \
|
||||
--export_language 'Fake English (fk-FK)' \
|
||||
--export_zip
|
||||
|
@ -5,9 +5,7 @@ This directory contains language-specific data files. Most importantly, you will
|
||||
|
||||
1. A list of unique characters for the target language (e.g. English) in `data/alphabet.txt`
|
||||
|
||||
2. A binary n-gram language model compiled by `kenlm` in `data/lm/lm.binary`
|
||||
|
||||
3. A trie model compiled by `generate_trie <https://github.com/mozilla/DeepSpeech#using-the-command-line-client>`_ in `data/lm/trie`
|
||||
2. A scorer package (`data/lm/kenlm.scorer`) generated with `data/lm/generate_package.py`. The scorer package includes a binary n-gram language model generated with `data/lm/generate_lm.py`.
|
||||
|
||||
For more information on how to build these resources from scratch, see `data/lm/README.md`
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
|
||||
lm.binary was generated from the LibriSpeech normalized LM training text, available `here <http://www.openslr.org/11>`_\ , using the `generate_lm.py` script (will generate lm.binary in the folder it is run from). KenLM's built binaries must be in your PATH (lmplz, build_binary, filter).
|
||||
The LM binary was generated from the LibriSpeech normalized LM training text, available `here <http://www.openslr.org/11>`_\ , using the `generate_lm.py` script (will generate `lm.binary` and `librispeech-vocab-500k.txt` in the folder it is run from). `KenLM <https://github.com/kpu/kenlm>`_'s built binaries must be in your PATH (lmplz, build_binary, filter).
|
||||
|
||||
The trie was then generated from the vocabulary of the language model:
|
||||
The scorer package was then built using the `generate_package.py` script:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./generate_trie ../data/alphabet.txt lm.binary trie
|
||||
python generate_lm.py # this will create lm.binary and librispeech-vocab-500k.txt
|
||||
python generate_package.py --alphabet ../alphabet.txt --lm lm.binary --vocab librispeech-vocab-500k.txt --default_alpha 0.75 --default_beta 1.85 --package kenlm.scorer
|
||||
|
@ -39,10 +39,13 @@ def main():
|
||||
'--prune', '0', '0', '1'
|
||||
])
|
||||
|
||||
# Filter LM using vocabulary of top 500k words
|
||||
filtered_path = os.path.join(tmp, 'lm_filtered.arpa')
|
||||
vocab_str = '\n'.join(word for word, count in counter.most_common(500000))
|
||||
with open('librispeech-vocab-500k.txt', 'w') as fout:
|
||||
fout.write(vocab_str)
|
||||
|
||||
# Filter LM using vocabulary of top 500k words
|
||||
print('Filtering ARPA file...')
|
||||
filtered_path = os.path.join(tmp, 'lm_filtered.arpa')
|
||||
subprocess.run(['filter', 'single', 'model:{}'.format(lm_path), filtered_path], input=vocab_str.encode('utf-8'), check=True)
|
||||
|
||||
# Quantize and produce trie binary.
|
||||
@ -50,6 +53,7 @@ def main():
|
||||
subprocess.check_call([
|
||||
'build_binary', '-a', '255',
|
||||
'-q', '8',
|
||||
'-v',
|
||||
'trie',
|
||||
filtered_path,
|
||||
'lm.binary'
|
||||
|
154
data/lm/generate_package.py
Normal file
154
data/lm/generate_package.py
Normal file
@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
# Make sure we can import stuff from util/
|
||||
# This script needs to be run from the root of the DeepSpeech repository
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
from util.text import Alphabet, UTF8Alphabet
|
||||
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
|
||||
|
||||
|
||||
def create_bundle(
|
||||
alphabet_path,
|
||||
lm_path,
|
||||
vocab_path,
|
||||
package_path,
|
||||
force_utf8,
|
||||
default_alpha,
|
||||
default_beta,
|
||||
):
|
||||
words = set()
|
||||
vocab_looks_char_based = True
|
||||
with open(vocab_path) as fin:
|
||||
for line in fin:
|
||||
for word in line.split():
|
||||
words.add(word.encode("utf-8"))
|
||||
if len(word) > 1:
|
||||
vocab_looks_char_based = False
|
||||
print("{} unique words read from vocabulary file.".format(len(words)))
|
||||
print(
|
||||
"{} like a character based model.".format(
|
||||
"Looks" if vocab_looks_char_based else "Doesn't look"
|
||||
)
|
||||
)
|
||||
|
||||
if force_utf8 != None: # pylint: disable=singleton-comparison
|
||||
use_utf8 = force_utf8.value
|
||||
print("Forcing UTF-8 mode = {}".format(use_utf8))
|
||||
else:
|
||||
use_utf8 = vocab_looks_char_based
|
||||
|
||||
if use_utf8:
|
||||
serialized_alphabet = UTF8Alphabet().serialize()
|
||||
else:
|
||||
serialized_alphabet = Alphabet(alphabet_path).serialize()
|
||||
|
||||
alphabet = NativeAlphabet()
|
||||
err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
|
||||
if err != 0:
|
||||
print("Error loading alphabet: {}".format(err))
|
||||
sys.exit(1)
|
||||
|
||||
scorer = Scorer()
|
||||
scorer.set_alphabet(alphabet)
|
||||
scorer.set_utf8_mode(use_utf8)
|
||||
scorer.reset_params(default_alpha, default_beta)
|
||||
scorer.load_lm(lm_path)
|
||||
scorer.fill_dictionary(list(words))
|
||||
shutil.copy(lm_path, package_path)
|
||||
scorer.save_dictionary(package_path, True) # append, not overwrite
|
||||
print("Package created in {}".format(package_path))
|
||||
|
||||
|
||||
class Tristate(object):
|
||||
def __init__(self, value=None):
|
||||
if any(value is v for v in (True, False, None)):
|
||||
self.value = value
|
||||
else:
|
||||
raise ValueError("Tristate value must be True, False, or None")
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.value is other.value
|
||||
if isinstance(other, Tristate)
|
||||
else self.value is other
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
|
||||
def __bool__(self):
|
||||
raise TypeError("Tristate object may not be used as a Boolean")
|
||||
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return "Tristate(%s)" % self.value
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate an external scorer package for DeepSpeech."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alphabet",
|
||||
help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lm",
|
||||
required=True,
|
||||
help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vocab",
|
||||
required=True,
|
||||
help="Path of vocabulary file. Must contain words separated by whitespace.",
|
||||
)
|
||||
parser.add_argument("--package", required=True, help="Path to save scorer package.")
|
||||
parser.add_argument(
|
||||
"--default_alpha",
|
||||
type=float,
|
||||
required=True,
|
||||
help="Default value of alpha hyperparameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--default_beta",
|
||||
type=float,
|
||||
required=True,
|
||||
help="Default value of beta hyperparameter.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_utf8",
|
||||
default="",
|
||||
help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.force_utf8 in ("True", "1", "true", "yes", "y"):
|
||||
force_utf8 = Tristate(True)
|
||||
elif args.force_utf8 in ("False", "0", "false", "no", "n"):
|
||||
force_utf8 = Tristate(False)
|
||||
else:
|
||||
force_utf8 = Tristate(None)
|
||||
|
||||
create_bundle(
|
||||
args.alphabet,
|
||||
args.lm,
|
||||
args.vocab,
|
||||
args.package,
|
||||
force_utf8,
|
||||
args.default_alpha,
|
||||
args.default_beta,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
3
data/lm/kenlm.scorer
Normal file
3
data/lm/kenlm.scorer
Normal file
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3ba04978fca285c34c99bf115ee61549937e422ac91def80122a767e114c035e
|
||||
size 953436352
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a24953ce3f013bbf5f4a1c9f5a0e5482bc56eaa81638276de522f39e62ff3a56
|
||||
size 945699324
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0281e5e784ffccb4aeae5e7d64099058a0c22e42dbb7aa2d3ef2fbbff53db3ab
|
||||
size 12200736
|
Binary file not shown.
3540
data/smoke_test/vocab.pruned.txt
Normal file
3540
data/smoke_test/vocab.pruned.txt
Normal file
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@ -7,7 +7,13 @@ C
|
||||
.. doxygenfunction:: DS_FreeModel
|
||||
:project: deepspeech-c
|
||||
|
||||
.. doxygenfunction:: DS_EnableDecoderWithLM
|
||||
.. doxygenfunction:: DS_EnableExternalScorer
|
||||
:project: deepspeech-c
|
||||
|
||||
.. doxygenfunction:: DS_DisableExternalScorer
|
||||
:project: deepspeech-c
|
||||
|
||||
.. doxygenfunction:: DS_SetScorerAlphaBeta
|
||||
:project: deepspeech-c
|
||||
|
||||
.. doxygenfunction:: DS_GetModelSampleRate
|
||||
|
@ -7,7 +7,7 @@ Creating a model instance and loading model
|
||||
.. literalinclude:: ../native_client/client.cc
|
||||
:language: c
|
||||
:linenos:
|
||||
:lines: 370-388
|
||||
:lines: 370-390
|
||||
|
||||
Performing inference
|
||||
--------------------
|
||||
|
@ -7,6 +7,12 @@ Model
|
||||
.. js:autoclass:: Model
|
||||
:members:
|
||||
|
||||
Stream
|
||||
------
|
||||
|
||||
.. js:autoclass:: Stream
|
||||
:members:
|
||||
|
||||
Module exported methods
|
||||
-----------------------
|
||||
|
||||
|
@ -7,7 +7,7 @@ Creating a model instance and loading model
|
||||
.. literalinclude:: ../native_client/javascript/client.js
|
||||
:language: javascript
|
||||
:linenos:
|
||||
:lines: 57-66
|
||||
:lines: 54-72
|
||||
|
||||
Performing inference
|
||||
--------------------
|
||||
@ -15,7 +15,7 @@ Performing inference
|
||||
.. literalinclude:: ../native_client/javascript/client.js
|
||||
:language: javascript
|
||||
:linenos:
|
||||
:lines: 115-117
|
||||
:lines: 117-121
|
||||
|
||||
Full source code
|
||||
----------------
|
||||
|
@ -9,6 +9,12 @@ Model
|
||||
.. autoclass:: Model
|
||||
:members:
|
||||
|
||||
Stream
|
||||
------
|
||||
|
||||
.. autoclass:: Stream
|
||||
:members:
|
||||
|
||||
Metadata
|
||||
--------
|
||||
|
||||
|
@ -7,7 +7,7 @@ Creating a model instance and loading model
|
||||
.. literalinclude:: ../native_client/python/client.py
|
||||
:language: python
|
||||
:linenos:
|
||||
:lines: 69, 78
|
||||
:lines: 111, 120
|
||||
|
||||
Performing inference
|
||||
--------------------
|
||||
@ -15,7 +15,7 @@ Performing inference
|
||||
.. literalinclude:: ../native_client/python/client.py
|
||||
:language: python
|
||||
:linenos:
|
||||
:lines: 95-98
|
||||
:lines: 140-145
|
||||
|
||||
Full source code
|
||||
----------------
|
||||
|
@ -106,9 +106,9 @@ Note: the following command assumes you `downloaded the pre-trained model <#gett
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
deepspeech --model models/output_graph.pbmm --lm models/lm.binary --trie models/trie --audio my_audio_file.wav
|
||||
deepspeech --model models/output_graph.pbmm --scorer models/kenlm.scorer --audio my_audio_file.wav
|
||||
|
||||
The arguments ``--lm`` and ``--trie`` are optional, and represent a language model.
|
||||
The ``--scorer`` argument is optional, and represents an external language model to be used when transcribing the audio.
|
||||
|
||||
See :github:`client.py <native_client/python/client.py>` for an example of how to use the package programatically.
|
||||
|
||||
@ -162,7 +162,7 @@ Note: the following command assumes you `downloaded the pre-trained model <#gett
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
./deepspeech --model models/output_graph.pbmm --lm models/lm.binary --trie models/trie --audio audio_input.wav
|
||||
./deepspeech --model models/output_graph.pbmm --scorer models/kenlm.scorer --audio audio_input.wav
|
||||
|
||||
See the help output with ``./deepspeech -h`` and the :github:`native client README <native_client/README.rst>` for more details.
|
||||
|
||||
|
@ -42,10 +42,9 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
|
||||
|
||||
|
||||
def evaluate(test_csvs, create_model, try_loading):
|
||||
if FLAGS.lm_binary_path:
|
||||
if FLAGS.scorer_path:
|
||||
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
|
||||
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
|
||||
Config.alphabet)
|
||||
FLAGS.scorer_path, Config.alphabet)
|
||||
else:
|
||||
scorer = None
|
||||
|
||||
|
@ -27,17 +27,18 @@ This module should be self-contained:
|
||||
- pip install native_client/python/dist/deepspeech*.whl
|
||||
- pip install -r requirements_eval_tflite.txt
|
||||
|
||||
Then run with a TF Lite model, LM/trie and a CSV test file
|
||||
Then run with a TF Lite model, a scorer and a CSV test file
|
||||
'''
|
||||
|
||||
BEAM_WIDTH = 500
|
||||
LM_ALPHA = 0.75
|
||||
LM_BETA = 1.85
|
||||
|
||||
def tflite_worker(model, lm, trie, queue_in, queue_out, gpu_mask):
|
||||
def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
|
||||
ds = Model(model, BEAM_WIDTH)
|
||||
ds.enableDecoderWithLM(lm, trie, LM_ALPHA, LM_BETA)
|
||||
ds.enableExternalScorer(scorer)
|
||||
ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -64,7 +65,7 @@ def main(args, _):
|
||||
|
||||
processes = []
|
||||
for i in range(args.proc):
|
||||
worker_process = Process(target=tflite_worker, args=(args.model, args.lm, args.trie, work_todo, work_done, i), daemon=True, name='tflite_process_{}'.format(i))
|
||||
worker_process = Process(target=tflite_worker, args=(args.model, args.scorer, work_todo, work_done, i), daemon=True, name='tflite_process_{}'.format(i))
|
||||
worker_process.start() # Launch reader() as a separate python process
|
||||
processes.append(worker_process)
|
||||
|
||||
@ -113,10 +114,8 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to the model (protocol buffer binary file)')
|
||||
parser.add_argument('--lm', required=True,
|
||||
help='Path to the language model binary file')
|
||||
parser.add_argument('--trie', required=True,
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--scorer', required=True,
|
||||
help='Path to the external scorer file')
|
||||
parser.add_argument('--csv', required=True,
|
||||
help='Path to the CSV source file')
|
||||
parser.add_argument('--proc', required=False, default=cpu_count(), type=int,
|
||||
|
@ -27,20 +27,6 @@ genrule(
|
||||
tools = [":gen_workspace_status.sh"],
|
||||
)
|
||||
|
||||
KENLM_SOURCES = glob(
|
||||
[
|
||||
"kenlm/lm/*.cc",
|
||||
"kenlm/util/*.cc",
|
||||
"kenlm/util/double-conversion/*.cc",
|
||||
"kenlm/lm/*.hh",
|
||||
"kenlm/util/*.hh",
|
||||
"kenlm/util/double-conversion/*.h",
|
||||
],
|
||||
exclude = [
|
||||
"kenlm/*/*test.cc",
|
||||
"kenlm/*/*main.cc",
|
||||
],
|
||||
)
|
||||
|
||||
OPENFST_SOURCES_PLATFORM = select({
|
||||
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
|
||||
@ -60,6 +46,27 @@ LINUX_LINKOPTS = [
|
||||
"-Wl,-export-dynamic",
|
||||
]
|
||||
|
||||
cc_library(
|
||||
name = "kenlm",
|
||||
srcs = glob([
|
||||
"kenlm/lm/*.cc",
|
||||
"kenlm/util/*.cc",
|
||||
"kenlm/util/double-conversion/*.cc",
|
||||
"kenlm/util/double-conversion/*.h",
|
||||
],
|
||||
exclude = [
|
||||
"kenlm/*/*test.cc",
|
||||
"kenlm/*/*main.cc",
|
||||
],),
|
||||
hdrs = glob([
|
||||
"kenlm/lm/*.hh",
|
||||
"kenlm/util/*.hh",
|
||||
]),
|
||||
copts = ["-std=c++11"],
|
||||
defines = ["KENLM_MAX_ORDER=6"],
|
||||
includes = ["kenlm"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "decoder",
|
||||
srcs = [
|
||||
@ -69,17 +76,16 @@ cc_library(
|
||||
"ctcdecode/scorer.cpp",
|
||||
"ctcdecode/path_trie.cpp",
|
||||
"ctcdecode/path_trie.h",
|
||||
] + KENLM_SOURCES + OPENFST_SOURCES_PLATFORM,
|
||||
] + OPENFST_SOURCES_PLATFORM,
|
||||
hdrs = [
|
||||
"ctcdecode/ctc_beam_search_decoder.h",
|
||||
"ctcdecode/scorer.h",
|
||||
],
|
||||
defines = ["KENLM_MAX_ORDER=6"],
|
||||
includes = [
|
||||
".",
|
||||
"ctcdecode/third_party/ThreadPool",
|
||||
"kenlm",
|
||||
] + OPENFST_INCLUDES_PLATFORM,
|
||||
deps = [":kenlm"]
|
||||
)
|
||||
|
||||
tf_cc_shared_object(
|
||||
@ -182,18 +188,12 @@ genrule(
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "generate_trie",
|
||||
name = "enumerate_kenlm_vocabulary",
|
||||
srcs = [
|
||||
"alphabet.h",
|
||||
"generate_trie.cpp",
|
||||
"enumerate_kenlm_vocabulary.cpp",
|
||||
],
|
||||
deps = [":kenlm"],
|
||||
copts = ["-std=c++11"],
|
||||
linkopts = [
|
||||
"-lm",
|
||||
"-ldl",
|
||||
"-pthread",
|
||||
],
|
||||
deps = [":decoder"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
|
@ -12,19 +12,17 @@
|
||||
|
||||
char* model = NULL;
|
||||
|
||||
char* lm = NULL;
|
||||
|
||||
char* trie = NULL;
|
||||
char* scorer = NULL;
|
||||
|
||||
char* audio = NULL;
|
||||
|
||||
int beam_width = 500;
|
||||
|
||||
float lm_alpha = 0.75f;
|
||||
bool set_alphabeta = false;
|
||||
|
||||
float lm_beta = 1.85f;
|
||||
float lm_alpha = 0.f;
|
||||
|
||||
bool load_without_trie = false;
|
||||
float lm_beta = 0.f;
|
||||
|
||||
bool show_times = false;
|
||||
|
||||
@ -39,45 +37,42 @@ int stream_size = 0;
|
||||
void PrintHelp(const char* bin)
|
||||
{
|
||||
std::cout <<
|
||||
"Usage: " << bin << " --model MODEL [--lm LM --trie TRIE] --audio AUDIO [-t] [-e]\n"
|
||||
"Usage: " << bin << " --model MODEL [--scorer SCORER] --audio AUDIO [-t] [-e]\n"
|
||||
"\n"
|
||||
"Running DeepSpeech inference.\n"
|
||||
"\n"
|
||||
" --model MODEL Path to the model (protocol buffer binary file)\n"
|
||||
" --lm LM Path to the language model binary file\n"
|
||||
" --trie TRIE Path to the language model trie file created with native_client/generate_trie\n"
|
||||
" --audio AUDIO Path to the audio file to run (WAV format)\n"
|
||||
" --beam_width BEAM_WIDTH Value for decoder beam width (int)\n"
|
||||
" --lm_alpha LM_ALPHA Value for language model alpha param (float)\n"
|
||||
" --lm_beta LM_BETA Value for language model beta param (float)\n"
|
||||
" -t Run in benchmark mode, output mfcc & inference time\n"
|
||||
" --extended Output string from extended metadata\n"
|
||||
" --json Extended output, shows word timings as JSON\n"
|
||||
" --stream size Run in stream mode, output intermediate results\n"
|
||||
" --help Show help\n"
|
||||
" --version Print version and exits\n";
|
||||
"\t--model MODEL\t\tPath to the model (protocol buffer binary file)\n"
|
||||
"\t--scorer SCORER\t\tPath to the external scorer file\n"
|
||||
"\t--audio AUDIO\t\tPath to the audio file to run (WAV format)\n"
|
||||
"\t--beam_width BEAM_WIDTH\tValue for decoder beam width (int)\n"
|
||||
"\t--lm_alpha LM_ALPHA\tValue for language model alpha param (float)\n"
|
||||
"\t--lm_beta LM_BETA\tValue for language model beta param (float)\n"
|
||||
"\t-t\t\t\tRun in benchmark mode, output mfcc & inference time\n"
|
||||
"\t--extended\t\tOutput string from extended metadata\n"
|
||||
"\t--json\t\t\tExtended output, shows word timings as JSON\n"
|
||||
"\t--stream size\t\tRun in stream mode, output intermediate results\n"
|
||||
"\t--help\t\t\tShow help\n"
|
||||
"\t--version\t\tPrint version and exits\n";
|
||||
DS_PrintVersions();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
bool ProcessArgs(int argc, char** argv)
|
||||
{
|
||||
const char* const short_opts = "m:a:l:r:w:c:d:b:tehv";
|
||||
const char* const short_opts = "m:l:a:b:c:d:tejs:vh";
|
||||
const option long_opts[] = {
|
||||
{"model", required_argument, nullptr, 'm'},
|
||||
{"lm", required_argument, nullptr, 'l'},
|
||||
{"trie", required_argument, nullptr, 'r'},
|
||||
{"audio", required_argument, nullptr, 'w'},
|
||||
{"scorer", required_argument, nullptr, 'l'},
|
||||
{"audio", required_argument, nullptr, 'a'},
|
||||
{"beam_width", required_argument, nullptr, 'b'},
|
||||
{"lm_alpha", required_argument, nullptr, 'c'},
|
||||
{"lm_beta", required_argument, nullptr, 'd'},
|
||||
{"run_very_slowly_without_trie_I_really_know_what_Im_doing", no_argument, nullptr, 999},
|
||||
{"t", no_argument, nullptr, 't'},
|
||||
{"extended", no_argument, nullptr, 'e'},
|
||||
{"json", no_argument, nullptr, 'j'},
|
||||
{"stream", required_argument, nullptr, 's'},
|
||||
{"help", no_argument, nullptr, 'h'},
|
||||
{"version", no_argument, nullptr, 'v'},
|
||||
{"help", no_argument, nullptr, 'h'},
|
||||
{nullptr, no_argument, nullptr, 0}
|
||||
};
|
||||
|
||||
@ -95,41 +90,31 @@ bool ProcessArgs(int argc, char** argv)
|
||||
break;
|
||||
|
||||
case 'l':
|
||||
lm = optarg;
|
||||
scorer = optarg;
|
||||
break;
|
||||
|
||||
case 'r':
|
||||
trie = optarg;
|
||||
break;
|
||||
|
||||
case 'w':
|
||||
case 'a':
|
||||
audio = optarg;
|
||||
break;
|
||||
|
||||
case 'b':
|
||||
beam_width = atoi(optarg);
|
||||
break;
|
||||
|
||||
case 'c':
|
||||
lm_alpha = atof(optarg);
|
||||
break;
|
||||
|
||||
case 'd':
|
||||
lm_beta = atof(optarg);
|
||||
break;
|
||||
case 'b':
|
||||
beam_width = atoi(optarg);
|
||||
break;
|
||||
|
||||
case 999:
|
||||
load_without_trie = true;
|
||||
case 'c':
|
||||
set_alphabeta = true;
|
||||
lm_alpha = atof(optarg);
|
||||
break;
|
||||
|
||||
case 'd':
|
||||
set_alphabeta = true;
|
||||
lm_beta = atof(optarg);
|
||||
break;
|
||||
|
||||
case 't':
|
||||
show_times = true;
|
||||
break;
|
||||
|
||||
case 'v':
|
||||
has_versions = true;
|
||||
break;
|
||||
|
||||
case 'e':
|
||||
extended_metadata = true;
|
||||
break;
|
||||
@ -142,6 +127,10 @@ bool ProcessArgs(int argc, char** argv)
|
||||
stream_size = atoi(optarg);
|
||||
break;
|
||||
|
||||
case 'v':
|
||||
has_versions = true;
|
||||
break;
|
||||
|
||||
case 'h': // -h or --help
|
||||
case '?': // Unrecognized option
|
||||
default:
|
||||
|
@ -374,16 +374,19 @@ main(int argc, char **argv)
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (lm && (trie || load_without_trie)) {
|
||||
int status = DS_EnableDecoderWithLM(ctx,
|
||||
lm,
|
||||
trie,
|
||||
lm_alpha,
|
||||
lm_beta);
|
||||
if (scorer) {
|
||||
int status = DS_EnableExternalScorer(ctx, scorer);
|
||||
if (status != 0) {
|
||||
fprintf(stderr, "Could not enable CTC decoder with LM.\n");
|
||||
fprintf(stderr, "Could not enable external scorer.\n");
|
||||
return 1;
|
||||
}
|
||||
if (set_alphabeta) {
|
||||
status = DS_SetScorerAlphaBeta(ctx, lm_alpha, lm_beta);
|
||||
if (status != 0) {
|
||||
fprintf(stderr, "Error setting scorer alpha and beta.\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef NO_SOX
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from . import swigwrapper # pylint: disable=import-self
|
||||
from .swigwrapper import Alphabet
|
||||
|
||||
__version__ = swigwrapper.__version__
|
||||
|
||||
@ -11,26 +12,36 @@ class Scorer(swigwrapper.Scorer):
|
||||
:type alpha: float
|
||||
:param beta: Word insertion bonus.
|
||||
:type beta: float
|
||||
:model_path: Path to load language model.
|
||||
:trie_path: Path to trie file.
|
||||
:scorer_path: Path to load scorer from.
|
||||
:alphabet: Alphabet
|
||||
:type model_path: basestring
|
||||
:type scorer_path: basestring
|
||||
"""
|
||||
|
||||
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
||||
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
|
||||
super(Scorer, self).__init__()
|
||||
serialized = alphabet.serialize()
|
||||
native_alphabet = swigwrapper.Alphabet()
|
||||
err = native_alphabet.deserialize(serialized, len(serialized))
|
||||
if err != 0:
|
||||
raise ValueError("Error when deserializing alphabet.")
|
||||
# Allow bare initialization
|
||||
if alphabet:
|
||||
assert alpha is not None, 'alpha parameter is required'
|
||||
assert beta is not None, 'beta parameter is required'
|
||||
assert scorer_path, 'scorer_path parameter is required'
|
||||
|
||||
err = self.init(alpha, beta,
|
||||
model_path.encode('utf-8'),
|
||||
trie_path.encode('utf-8'),
|
||||
native_alphabet)
|
||||
if err != 0:
|
||||
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
|
||||
serialized = alphabet.serialize()
|
||||
native_alphabet = swigwrapper.Alphabet()
|
||||
err = native_alphabet.deserialize(serialized, len(serialized))
|
||||
if err != 0:
|
||||
raise ValueError('Error when deserializing alphabet.')
|
||||
|
||||
err = self.init(scorer_path.encode('utf-8'),
|
||||
native_alphabet)
|
||||
if err != 0:
|
||||
raise ValueError('Scorer initialization failed with error code {}'.format(err))
|
||||
|
||||
self.reset_params(alpha, beta)
|
||||
|
||||
def load_lm(self, lm_path):
|
||||
super(Scorer, self).load_lm(lm_path.encode('utf-8'))
|
||||
|
||||
def save_dictionary(self, save_path, *args, **kwargs):
|
||||
super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
|
||||
|
||||
|
||||
def ctc_beam_search_decoder(probs_seq,
|
||||
|
@ -18,7 +18,7 @@ DecoderState::init(const Alphabet& alphabet,
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer)
|
||||
std::shared_ptr<Scorer> ext_scorer)
|
||||
{
|
||||
// assign special ids
|
||||
abs_time_step_ = 0;
|
||||
@ -36,7 +36,7 @@ DecoderState::init(const Alphabet& alphabet,
|
||||
prefix_root_.reset(root);
|
||||
prefixes_.push_back(root);
|
||||
|
||||
if (ext_scorer != nullptr) {
|
||||
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
|
||||
// no need for std::make_shared<>() since Copy() does 'new' behind the doors
|
||||
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
|
||||
root->set_dictionary(dict_ptr);
|
||||
@ -58,7 +58,7 @@ DecoderState::next(const double *probs,
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
if (ext_scorer_ != nullptr) {
|
||||
if (ext_scorer_) {
|
||||
size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
|
||||
std::partial_sort(prefixes_.begin(),
|
||||
prefixes_.begin() + num_prefixes,
|
||||
@ -109,7 +109,7 @@ DecoderState::next(const double *probs,
|
||||
log_p = log_prob_c + prefix->score;
|
||||
}
|
||||
|
||||
if (ext_scorer_ != nullptr) {
|
||||
if (ext_scorer_) {
|
||||
// skip scoring the space in word based LMs
|
||||
PathTrie* prefix_to_score;
|
||||
if (ext_scorer_->is_utf8_mode()) {
|
||||
@ -166,7 +166,7 @@ DecoderState::decode() const
|
||||
}
|
||||
|
||||
// score the last word of each prefix that doesn't end with space
|
||||
if (ext_scorer_ != nullptr) {
|
||||
if (ext_scorer_) {
|
||||
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
|
||||
auto prefix = prefixes_copy[i];
|
||||
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
|
||||
@ -200,7 +200,7 @@ DecoderState::decode() const
|
||||
Output output;
|
||||
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
|
||||
double approx_ctc = scores[prefixes_copy[i]];
|
||||
if (ext_scorer_ != nullptr) {
|
||||
if (ext_scorer_) {
|
||||
auto words = ext_scorer_->split_labels_into_scored_units(output.tokens);
|
||||
// remove term insertion weight
|
||||
approx_ctc -= words.size() * ext_scorer_->beta;
|
||||
@ -222,7 +222,7 @@ std::vector<Output> ctc_beam_search_decoder(
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer)
|
||||
std::shared_ptr<Scorer> ext_scorer)
|
||||
{
|
||||
DecoderState state;
|
||||
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
||||
@ -243,7 +243,7 @@ ctc_beam_search_decoder_batch(
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer)
|
||||
std::shared_ptr<Scorer> ext_scorer)
|
||||
{
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");
|
||||
|
@ -1,6 +1,7 @@
|
||||
#ifndef CTC_BEAM_SEARCH_DECODER_H_
|
||||
#define CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -16,7 +17,7 @@ class DecoderState {
|
||||
double cutoff_prob_;
|
||||
size_t cutoff_top_n_;
|
||||
|
||||
Scorer* ext_scorer_; // weak
|
||||
std::shared_ptr<Scorer> ext_scorer_;
|
||||
std::vector<PathTrie*> prefixes_;
|
||||
std::unique_ptr<PathTrie> prefix_root_;
|
||||
|
||||
@ -45,7 +46,7 @@ public:
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer);
|
||||
std::shared_ptr<Scorer> ext_scorer);
|
||||
|
||||
/* Send data to the decoder
|
||||
*
|
||||
@ -95,7 +96,7 @@ std::vector<Output> ctc_beam_search_decoder(
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer);
|
||||
std::shared_ptr<Scorer> ext_scorer);
|
||||
|
||||
/* CTC Beam Search Decoder for batch data
|
||||
* Parameters:
|
||||
@ -126,6 +127,6 @@ ctc_beam_search_decoder_batch(
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer);
|
||||
std::shared_ptr<Scorer> ext_scorer);
|
||||
|
||||
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
@ -24,41 +24,37 @@
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
using namespace lm::ngram;
|
||||
|
||||
static const int32_t MAGIC = 'TRIE';
|
||||
static const int32_t FILE_VERSION = 5;
|
||||
static const int32_t FILE_VERSION = 6;
|
||||
|
||||
int
|
||||
Scorer::init(double alpha,
|
||||
double beta,
|
||||
const std::string& lm_path,
|
||||
const std::string& trie_path,
|
||||
Scorer::init(const std::string& lm_path,
|
||||
const Alphabet& alphabet)
|
||||
{
|
||||
reset_params(alpha, beta);
|
||||
alphabet_ = alphabet;
|
||||
setup(lm_path, trie_path);
|
||||
return 0;
|
||||
set_alphabet(alphabet);
|
||||
return load_lm(lm_path);
|
||||
}
|
||||
|
||||
int
|
||||
Scorer::init(double alpha,
|
||||
double beta,
|
||||
const std::string& lm_path,
|
||||
const std::string& trie_path,
|
||||
Scorer::init(const std::string& lm_path,
|
||||
const std::string& alphabet_config_path)
|
||||
{
|
||||
reset_params(alpha, beta);
|
||||
int err = alphabet_.init(alphabet_config_path.c_str());
|
||||
if (err != 0) {
|
||||
return err;
|
||||
}
|
||||
setup(lm_path, trie_path);
|
||||
return 0;
|
||||
setup_char_map();
|
||||
return load_lm(lm_path);
|
||||
}
|
||||
|
||||
void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
||||
void
|
||||
Scorer::set_alphabet(const Alphabet& alphabet)
|
||||
{
|
||||
alphabet_ = alphabet;
|
||||
setup_char_map();
|
||||
}
|
||||
|
||||
void Scorer::setup_char_map()
|
||||
{
|
||||
// (Re-)Initialize character map
|
||||
char_map_.clear();
|
||||
@ -71,78 +67,99 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
||||
// state, otherwise wrong decoding results would be given.
|
||||
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
|
||||
}
|
||||
|
||||
// load language model
|
||||
const char* filename = lm_path.c_str();
|
||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
||||
|
||||
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
|
||||
|
||||
lm::ngram::Config config;
|
||||
|
||||
if (!has_trie) { // no trie was specified, build it now
|
||||
RetrieveStrEnumerateVocab enumerate;
|
||||
config.enumerate_vocab = &enumerate;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||
auto vocab = enumerate.vocabulary;
|
||||
for (size_t i = 0; i < vocab.size(); ++i) {
|
||||
if (vocab[i] != UNK_TOKEN &&
|
||||
vocab[i] != START_TOKEN &&
|
||||
vocab[i] != END_TOKEN &&
|
||||
get_utf8_str_len(vocab[i]) > 1) {
|
||||
is_utf8_mode_ = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (alphabet_.GetSize() != 255) {
|
||||
is_utf8_mode_ = false;
|
||||
}
|
||||
|
||||
// Add spaces only in word-based scoring
|
||||
fill_dictionary(vocab);
|
||||
} else {
|
||||
config.load_method = util::LoadMethod::LAZY;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||
|
||||
// Read metadata and trie from file
|
||||
std::ifstream fin(trie_path, std::ios::binary);
|
||||
|
||||
int magic;
|
||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != MAGIC) {
|
||||
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
|
||||
"your trie file." << std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
int version;
|
||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
if (version != FILE_VERSION) {
|
||||
std::cerr << "Error: Trie file version mismatch (" << version
|
||||
<< " instead of expected " << FILE_VERSION
|
||||
<< "). Update your trie file."
|
||||
<< std::endl;
|
||||
throw 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
||||
fst::FstReadOptions opt;
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = trie_path;
|
||||
dictionary.reset(FstType::Read(fin, opt));
|
||||
}
|
||||
|
||||
max_order_ = language_model_->Order();
|
||||
}
|
||||
|
||||
void Scorer::save_dictionary(const std::string& path)
|
||||
int Scorer::load_lm(const std::string& lm_path)
|
||||
{
|
||||
std::ofstream fout(path, std::ios::binary);
|
||||
// Check if file is readable to avoid KenLM throwing an exception
|
||||
const char* filename = lm_path.c_str();
|
||||
if (access(filename, R_OK) != 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Check if the file format is valid to avoid KenLM throwing an exception
|
||||
lm::ngram::ModelType model_type;
|
||||
if (!lm::ngram::RecognizeBinary(filename, model_type)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Load the LM
|
||||
lm::ngram::Config config;
|
||||
config.load_method = util::LoadMethod::LAZY;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
|
||||
max_order_ = language_model_->Order();
|
||||
|
||||
uint64_t package_size;
|
||||
{
|
||||
util::scoped_fd fd(util::OpenReadOrThrow(filename));
|
||||
package_size = util::SizeFile(fd.get());
|
||||
}
|
||||
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
|
||||
if (package_size <= trie_offset) {
|
||||
// File ends without a trie structure
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Read metadata and trie from file
|
||||
std::ifstream fin(lm_path, std::ios::binary);
|
||||
fin.seekg(trie_offset);
|
||||
return load_trie(fin, lm_path);
|
||||
}
|
||||
|
||||
int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
|
||||
{
|
||||
int magic;
|
||||
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != MAGIC) {
|
||||
std::cerr << "Error: Can't parse scorer file, invalid header. Try updating "
|
||||
"your scorer file." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int version;
|
||||
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
if (version != FILE_VERSION) {
|
||||
std::cerr << "Error: Scorer file version mismatch (" << version
|
||||
<< " instead of expected " << FILE_VERSION
|
||||
<< "). ";
|
||||
if (version < FILE_VERSION) {
|
||||
std::cerr << "Update your scorer file.";
|
||||
} else {
|
||||
std::cerr << "Downgrade your scorer file or update your version of DeepSpeech.";
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
|
||||
// Read hyperparameters from header
|
||||
double alpha, beta;
|
||||
fin.read(reinterpret_cast<char*>(&alpha), sizeof(alpha));
|
||||
fin.read(reinterpret_cast<char*>(&beta), sizeof(beta));
|
||||
reset_params(alpha, beta);
|
||||
|
||||
fst::FstReadOptions opt;
|
||||
opt.mode = fst::FstReadOptions::MAP;
|
||||
opt.source = file_path;
|
||||
dictionary.reset(FstType::Read(fin, opt));
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
|
||||
{
|
||||
std::ios::openmode om;
|
||||
if (append_instead_of_overwrite) {
|
||||
om = std::ios::in|std::ios::out|std::ios::binary|std::ios::ate;
|
||||
} else {
|
||||
om = std::ios::out|std::ios::binary;
|
||||
}
|
||||
std::fstream fout(path, om);
|
||||
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
|
||||
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
|
||||
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
|
||||
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
|
||||
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
|
||||
fst::FstWriteOptions opt;
|
||||
opt.align = true;
|
||||
opt.source = path;
|
||||
|
@ -6,31 +6,19 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "lm/enumerate_vocab.hh"
|
||||
#include "lm/virtual_interface.hh"
|
||||
#include "lm/word_index.hh"
|
||||
#include "util/string_piece.hh"
|
||||
|
||||
#include "path_trie.h"
|
||||
#include "alphabet.h"
|
||||
#include "deepspeech.h"
|
||||
|
||||
const double OOV_SCORE = -1000.0;
|
||||
const std::string START_TOKEN = "<s>";
|
||||
const std::string UNK_TOKEN = "<unk>";
|
||||
const std::string END_TOKEN = "</s>";
|
||||
|
||||
// Implement a callback to retrieve the dictionary of language model.
|
||||
class RetrieveStrEnumerateVocab : public lm::EnumerateVocab {
|
||||
public:
|
||||
RetrieveStrEnumerateVocab() {}
|
||||
|
||||
void Add(lm::WordIndex index, const StringPiece &str) {
|
||||
vocabulary.push_back(std::string(str.data(), str.length()));
|
||||
}
|
||||
|
||||
std::vector<std::string> vocabulary;
|
||||
};
|
||||
|
||||
/* External scorer to query score for n-gram or sentence, including language
|
||||
* model scoring and word insertion.
|
||||
*
|
||||
@ -40,9 +28,9 @@ public:
|
||||
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
||||
*/
|
||||
class Scorer {
|
||||
public:
|
||||
using FstType = PathTrie::FstType;
|
||||
|
||||
public:
|
||||
Scorer() = default;
|
||||
~Scorer() = default;
|
||||
|
||||
@ -50,16 +38,10 @@ public:
|
||||
Scorer(const Scorer&) = delete;
|
||||
Scorer& operator=(const Scorer&) = delete;
|
||||
|
||||
int init(double alpha,
|
||||
double beta,
|
||||
const std::string &lm_path,
|
||||
const std::string &trie_path,
|
||||
int init(const std::string &lm_path,
|
||||
const Alphabet &alphabet);
|
||||
|
||||
int init(double alpha,
|
||||
double beta,
|
||||
const std::string &lm_path,
|
||||
const std::string &trie_path,
|
||||
int init(const std::string &lm_path,
|
||||
const std::string &alphabet_config_path);
|
||||
|
||||
double get_log_cond_prob(const std::vector<std::string> &words,
|
||||
@ -76,12 +58,15 @@ public:
|
||||
// return the max order
|
||||
size_t get_max_order() const { return max_order_; }
|
||||
|
||||
// retrun true if the language model is character based
|
||||
// return true if the language model is character based
|
||||
bool is_utf8_mode() const { return is_utf8_mode_; }
|
||||
|
||||
// reset params alpha & beta
|
||||
void reset_params(float alpha, float beta);
|
||||
|
||||
// force set UTF-8 mode, ignore value read from file
|
||||
void set_utf8_mode(bool utf8) { is_utf8_mode_ = utf8; }
|
||||
|
||||
// make ngram for a given prefix
|
||||
std::vector<std::string> make_ngram(PathTrie *prefix);
|
||||
|
||||
@ -89,12 +74,20 @@ public:
|
||||
// the vector of characters (character based lm)
|
||||
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
|
||||
|
||||
void set_alphabet(const Alphabet& alphabet);
|
||||
|
||||
// save dictionary in file
|
||||
void save_dictionary(const std::string &path);
|
||||
void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
|
||||
|
||||
// return weather this step represents a boundary where beam scoring should happen
|
||||
bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
|
||||
|
||||
// fill dictionary FST from a vocabulary
|
||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||
|
||||
// load language model from given path
|
||||
int load_lm(const std::string &lm_path);
|
||||
|
||||
// language model weight
|
||||
double alpha = 0.;
|
||||
// word insertion weight
|
||||
@ -104,14 +97,10 @@ public:
|
||||
std::unique_ptr<FstType> dictionary;
|
||||
|
||||
protected:
|
||||
// necessary setup: load language model, fill FST's dictionary
|
||||
void setup(const std::string &lm_path, const std::string &trie_path);
|
||||
// necessary setup after setting alphabet
|
||||
void setup_char_map();
|
||||
|
||||
// load language model from given path
|
||||
void load_lm(const std::string &lm_path);
|
||||
|
||||
// fill dictionary for FST
|
||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||
int load_trie(std::ifstream& fin, const std::string& file_path);
|
||||
|
||||
private:
|
||||
std::unique_ptr<lm::base::Model> language_model_;
|
||||
|
@ -7,15 +7,22 @@
|
||||
#include "workspace_status.h"
|
||||
%}
|
||||
|
||||
%include "pyabc.i"
|
||||
%include "std_string.i"
|
||||
%include "std_vector.i"
|
||||
%include <pyabc.i>
|
||||
%include <std_string.i>
|
||||
%include <std_vector.i>
|
||||
%include <std_shared_ptr.i>
|
||||
%include "numpy.i"
|
||||
|
||||
%init %{
|
||||
import_array();
|
||||
%}
|
||||
|
||||
namespace std {
|
||||
%template(StringVector) vector<string>;
|
||||
}
|
||||
|
||||
%shared_ptr(Scorer);
|
||||
|
||||
// Convert NumPy arrays to pointer+lengths
|
||||
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
||||
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
|
||||
|
@ -304,23 +304,38 @@ DS_FreeModel(ModelState* ctx)
|
||||
}
|
||||
|
||||
int
|
||||
DS_EnableDecoderWithLM(ModelState* aCtx,
|
||||
const char* aLMPath,
|
||||
const char* aTriePath,
|
||||
float aLMAlpha,
|
||||
float aLMBeta)
|
||||
DS_EnableExternalScorer(ModelState* aCtx,
|
||||
const char* aScorerPath)
|
||||
{
|
||||
aCtx->scorer_.reset(new Scorer());
|
||||
int err = aCtx->scorer_->init(aLMAlpha, aLMBeta,
|
||||
aLMPath ? aLMPath : "",
|
||||
aTriePath ? aTriePath : "",
|
||||
aCtx->alphabet_);
|
||||
int err = aCtx->scorer_->init(aScorerPath, aCtx->alphabet_);
|
||||
if (err != 0) {
|
||||
return DS_ERR_INVALID_LM;
|
||||
return DS_ERR_INVALID_SCORER;
|
||||
}
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
int
|
||||
DS_DisableExternalScorer(ModelState* aCtx)
|
||||
{
|
||||
if (aCtx->scorer_) {
|
||||
aCtx->scorer_.reset();
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
return DS_ERR_SCORER_NOT_ENABLED;
|
||||
}
|
||||
|
||||
int DS_SetScorerAlphaBeta(ModelState* aCtx,
|
||||
float aAlpha,
|
||||
float aBeta)
|
||||
{
|
||||
if (aCtx->scorer_) {
|
||||
aCtx->scorer_->reset_params(aAlpha, aBeta);
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
return DS_ERR_SCORER_NOT_ENABLED;
|
||||
}
|
||||
|
||||
int
|
||||
DS_CreateStream(ModelState* aCtx,
|
||||
StreamingState** retval)
|
||||
@ -348,7 +363,7 @@ DS_CreateStream(ModelState* aCtx,
|
||||
aCtx->beam_width_,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
aCtx->scorer_.get());
|
||||
aCtx->scorer_);
|
||||
|
||||
*retval = ctx.release();
|
||||
return DS_ERR_OK;
|
||||
|
@ -59,8 +59,9 @@ enum DeepSpeech_Error_Codes
|
||||
// Invalid parameters
|
||||
DS_ERR_INVALID_ALPHABET = 0x2000,
|
||||
DS_ERR_INVALID_SHAPE = 0x2001,
|
||||
DS_ERR_INVALID_LM = 0x2002,
|
||||
DS_ERR_INVALID_SCORER = 0x2002,
|
||||
DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
|
||||
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
|
||||
|
||||
// Runtime failures
|
||||
DS_ERR_FAIL_INIT_MMAP = 0x3000,
|
||||
@ -106,25 +107,40 @@ DEEPSPEECH_EXPORT
|
||||
void DS_FreeModel(ModelState* ctx);
|
||||
|
||||
/**
|
||||
* @brief Enable decoding using beam scoring with a KenLM language model.
|
||||
* @brief Enable decoding using an external scorer.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
* @param aLMPath The path to the language model binary file.
|
||||
* @param aTriePath The path to the trie file build from the same vocabu-
|
||||
* lary as the language model binary.
|
||||
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
|
||||
weight.
|
||||
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
|
||||
weight.
|
||||
* @param aScorerPath The path to the external scorer file.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_EnableDecoderWithLM(ModelState* aCtx,
|
||||
const char* aLMPath,
|
||||
const char* aTriePath,
|
||||
float aLMAlpha,
|
||||
float aLMBeta);
|
||||
int DS_EnableExternalScorer(ModelState* aCtx,
|
||||
const char* aScorerPath);
|
||||
|
||||
/**
|
||||
* @brief Disable decoding using an external scorer.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_DisableExternalScorer(ModelState* aCtx);
|
||||
|
||||
/**
|
||||
* @brief Set hyperparameters alpha and beta of the external scorer.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
* @param aAlpha The alpha hyperparameter of the decoder. Language model weight.
|
||||
* @param aLMBeta The beta hyperparameter of the decoder. Word insertion weight.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_SetScorerAlphaBeta(ModelState* aCtx,
|
||||
float aAlpha,
|
||||
float aBeta);
|
||||
|
||||
/**
|
||||
* @brief Use the DeepSpeech model to perform Speech-To-Text.
|
||||
|
@ -1,141 +0,0 @@
|
||||
#ifndef DEEPSPEECH_COMPAT_H
|
||||
#define DEEPSPEECH_COMPAT_H
|
||||
|
||||
#include "deepspeech.h"
|
||||
|
||||
#warning This header is a convenience wrapper for compatibility with \
|
||||
the previous API, it has deprecated function names and arguments. \
|
||||
If possible, update your code instead of using this header.
|
||||
|
||||
/**
|
||||
* @brief An object providing an interface to a trained DeepSpeech model.
|
||||
*
|
||||
* @param aModelPath The path to the frozen model graph.
|
||||
* @param aNCep UNUSED, DEPRECATED.
|
||||
* @param aNContext UNUSED, DEPRECATED.
|
||||
* @param aAlphabetConfigPath UNUSED, DEPRECATED.
|
||||
* @param aBeamWidth The beam width used by the decoder. A larger beam
|
||||
* width generates better results at the cost of decoding
|
||||
* time.
|
||||
* @param[out] retval a ModelState pointer
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
int DS_CreateModel(const char* aModelPath,
|
||||
unsigned int /*aNCep*/,
|
||||
unsigned int /*aNContext*/,
|
||||
const char* /*aAlphabetConfigPath*/,
|
||||
unsigned int aBeamWidth,
|
||||
ModelState** retval)
|
||||
{
|
||||
return DS_CreateModel(aModelPath, aBeamWidth, retval);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Frees associated resources and destroys model object.
|
||||
*/
|
||||
void DS_DestroyModel(ModelState* ctx)
|
||||
{
|
||||
return DS_FreeModel(ctx);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Enable decoding using beam scoring with a KenLM language model.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
* @param aAlphabetConfigPath UNUSED, DEPRECATED.
|
||||
* @param aLMPath The path to the language model binary file.
|
||||
* @param aTriePath The path to the trie file build from the same vocabu-
|
||||
* lary as the language model binary.
|
||||
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
|
||||
weight.
|
||||
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
|
||||
weight.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
int DS_EnableDecoderWithLM(ModelState* aCtx,
|
||||
const char* /*aAlphabetConfigPath*/,
|
||||
const char* aLMPath,
|
||||
const char* aTriePath,
|
||||
float aLMAlpha,
|
||||
float aLMBeta)
|
||||
{
|
||||
return DS_EnableDecoderWithLM(aCtx, aLMPath, aTriePath, aLMAlpha, aLMBeta);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create a new streaming inference state. The streaming state returned
|
||||
* by this function can then be passed to {@link DS_FeedAudioContent()}
|
||||
* and {@link DS_FinishStream()}.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model to use.
|
||||
* @param aSampleRate UNUSED, DEPRECATED.
|
||||
* @param[out] retval an opaque pointer that represents the streaming state. Can
|
||||
* be NULL if an error occurs.
|
||||
*
|
||||
* @return Zero for success, non-zero on failure.
|
||||
*/
|
||||
int DS_SetupStream(ModelState* aCtx,
|
||||
unsigned int /*aSampleRate*/,
|
||||
StreamingState** retval)
|
||||
{
|
||||
return DS_CreateStream(aCtx, retval);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Destroy a streaming state without decoding the computed logits. This
|
||||
* can be used if you no longer need the result of an ongoing streaming
|
||||
* inference and don't want to perform a costly decode operation.
|
||||
*
|
||||
* @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
|
||||
*
|
||||
* @note This method will free the state pointer (@p aSctx).
|
||||
*/
|
||||
void DS_DiscardStream(StreamingState* aSctx)
|
||||
{
|
||||
return DS_FreeStream(aSctx);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Use the DeepSpeech model to perform Speech-To-Text.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model to use.
|
||||
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
|
||||
* sample rate (matching what the model was trained on).
|
||||
* @param aBufferSize The number of samples in the audio signal.
|
||||
* @param aSampleRate UNUSED, DEPRECATED.
|
||||
*
|
||||
* @return The STT result. The user is responsible for freeing the string using
|
||||
* {@link DS_FreeString()}. Returns NULL on error.
|
||||
*/
|
||||
char* DS_SpeechToText(ModelState* aCtx,
|
||||
const short* aBuffer,
|
||||
unsigned int aBufferSize,
|
||||
unsigned int /*aSampleRate*/)
|
||||
{
|
||||
return DS_SpeechToText(aCtx, aBuffer, aBufferSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata
|
||||
* about the results.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model to use.
|
||||
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
|
||||
* sample rate (matching what the model was trained on).
|
||||
* @param aBufferSize The number of samples in the audio signal.
|
||||
* @param aSampleRate UNUSED, DEPRECATED.
|
||||
*
|
||||
* @return Outputs a struct of individual letters along with their timing information.
|
||||
* The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
|
||||
*/
|
||||
Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
|
||||
const short* aBuffer,
|
||||
unsigned int aBufferSize,
|
||||
unsigned int /*aSampleRate*/)
|
||||
{
|
||||
return DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize);
|
||||
}
|
||||
|
||||
#endif /* DEEPSPEECH_COMPAT_H */
|
@ -82,8 +82,8 @@ namespace DeepSpeechClient
|
||||
throw new ArgumentException("Invalid alphabet embedded in model. (Data corruption?)");
|
||||
case ErrorCodes.DS_ERR_INVALID_SHAPE:
|
||||
throw new ArgumentException("Invalid model shape.");
|
||||
case ErrorCodes.DS_ERR_INVALID_LM:
|
||||
throw new ArgumentException("Invalid language model file.");
|
||||
case ErrorCodes.DS_ERR_INVALID_SCORER:
|
||||
throw new ArgumentException("Invalid scorer file.");
|
||||
case ErrorCodes.DS_ERR_FAIL_INIT_MMAP:
|
||||
throw new ArgumentException("Failed to initialize memory mapped model.");
|
||||
case ErrorCodes.DS_ERR_FAIL_INIT_SESS:
|
||||
@ -100,6 +100,8 @@ namespace DeepSpeechClient
|
||||
throw new ArgumentException("Error failed to create session.");
|
||||
case ErrorCodes.DS_ERR_MODEL_INCOMPATIBLE:
|
||||
throw new ArgumentException("Error incompatible model.");
|
||||
case ErrorCodes.DS_ERR_SCORER_NOT_ENABLED:
|
||||
throw new ArgumentException("External scorer is not enabled.");
|
||||
default:
|
||||
throw new ArgumentException("Unknown error, please make sure you are using the correct native binary.");
|
||||
}
|
||||
@ -114,45 +116,48 @@ namespace DeepSpeechClient
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Enable decoding using beam scoring with a KenLM language model.
|
||||
/// Enable decoding using an external scorer.
|
||||
/// </summary>
|
||||
/// <param name="aLMPath">The path to the language model binary file.</param>
|
||||
/// <param name="aTriePath">The path to the trie file build from the same vocabulary as the language model binary.</param>
|
||||
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param>
|
||||
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
|
||||
/// <exception cref="FileNotFoundException">Thrown when cannot find the language model or trie file.</exception>
|
||||
public unsafe void EnableDecoderWithLM(string aLMPath, string aTriePath,
|
||||
float aLMAlpha, float aLMBeta)
|
||||
/// <param name="aScorerPath">The path to the external scorer file.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with an external scorer.</exception>
|
||||
/// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception>
|
||||
public unsafe void EnableExternalScorer(string aScorerPath)
|
||||
{
|
||||
string exceptionMessage = null;
|
||||
if (string.IsNullOrWhiteSpace(aLMPath))
|
||||
if (string.IsNullOrWhiteSpace(aScorerPath))
|
||||
{
|
||||
exceptionMessage = "Path to the language model file cannot be empty.";
|
||||
throw new FileNotFoundException("Path to the scorer file cannot be empty.");
|
||||
}
|
||||
if (!File.Exists(aLMPath))
|
||||
if (!File.Exists(aScorerPath))
|
||||
{
|
||||
exceptionMessage = $"Cannot find the language model file: {aLMPath}";
|
||||
}
|
||||
if (string.IsNullOrWhiteSpace(aTriePath))
|
||||
{
|
||||
exceptionMessage = "Path to the trie file cannot be empty.";
|
||||
}
|
||||
if (!File.Exists(aTriePath))
|
||||
{
|
||||
exceptionMessage = $"Cannot find the trie file: {aTriePath}";
|
||||
throw new FileNotFoundException($"Cannot find the scorer file: {aScorerPath}");
|
||||
}
|
||||
|
||||
if (exceptionMessage != null)
|
||||
{
|
||||
throw new FileNotFoundException(exceptionMessage);
|
||||
}
|
||||
var resultCode = NativeImp.DS_EnableExternalScorer(_modelStatePP, aScorerPath);
|
||||
EvaluateResultCode(resultCode);
|
||||
}
|
||||
|
||||
var resultCode = NativeImp.DS_EnableDecoderWithLM(_modelStatePP,
|
||||
aLMPath,
|
||||
aTriePath,
|
||||
aLMAlpha,
|
||||
aLMBeta);
|
||||
/// <summary>
|
||||
/// Disable decoding using an external scorer.
|
||||
/// </summary>
|
||||
/// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
|
||||
public unsafe void DisableExternalScorer()
|
||||
{
|
||||
var resultCode = NativeImp.DS_DisableExternalScorer(_modelStatePP);
|
||||
EvaluateResultCode(resultCode);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Set hyperparameters alpha and beta of the external scorer.
|
||||
/// </summary>
|
||||
/// <param name="aAlpha">The alpha hyperparameter of the decoder. Language model weight.</param>
|
||||
/// <param name="aBeta">The beta hyperparameter of the decoder. Word insertion weight.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
|
||||
public unsafe void SetScorerAlphaBeta(float aAlpha, float aBeta)
|
||||
{
|
||||
var resultCode = NativeImp.DS_SetScorerAlphaBeta(_modelStatePP,
|
||||
aAlpha,
|
||||
aBeta);
|
||||
EvaluateResultCode(resultCode);
|
||||
}
|
||||
|
||||
|
@ -14,8 +14,9 @@
|
||||
// Invalid parameters
|
||||
DS_ERR_INVALID_ALPHABET = 0x2000,
|
||||
DS_ERR_INVALID_SHAPE = 0x2001,
|
||||
DS_ERR_INVALID_LM = 0x2002,
|
||||
DS_ERR_INVALID_SCORER = 0x2002,
|
||||
DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
|
||||
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
|
||||
|
||||
// Runtime failures
|
||||
DS_ERR_FAIL_INIT_MMAP = 0x3000,
|
||||
|
@ -21,18 +21,26 @@ namespace DeepSpeechClient.Interfaces
|
||||
unsafe int GetModelSampleRate();
|
||||
|
||||
/// <summary>
|
||||
/// Enable decoding using beam scoring with a KenLM language model.
|
||||
/// Enable decoding using an external scorer.
|
||||
/// </summary>
|
||||
/// <param name="aLMPath">The path to the language model binary file.</param>
|
||||
/// <param name="aTriePath">The path to the trie file build from the same vocabulary as the language model binary.</param>
|
||||
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param>
|
||||
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
|
||||
/// <exception cref="FileNotFoundException">Thrown when cannot find the language model or trie file.</exception>
|
||||
unsafe void EnableDecoderWithLM(string aLMPath,
|
||||
string aTriePath,
|
||||
float aLMAlpha,
|
||||
float aLMBeta);
|
||||
/// <param name="aScorerPath">The path to the external scorer file.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with an external scorer.</exception>
|
||||
/// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception>
|
||||
unsafe void EnableExternalScorer(string aScorerPath);
|
||||
|
||||
/// <summary>
|
||||
/// Disable decoding using an external scorer.
|
||||
/// </summary>
|
||||
/// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
|
||||
unsafe void DisableExternalScorer();
|
||||
|
||||
/// <summary>
|
||||
/// Set hyperparameters alpha and beta of the external scorer.
|
||||
/// </summary>
|
||||
/// <param name="aAlpha">The alpha hyperparameter of the decoder. Language model weight.</param>
|
||||
/// <param name="aBeta">The beta hyperparameter of the decoder. Word insertion weight.</param>
|
||||
/// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
|
||||
unsafe void SetScorerAlphaBeta(float aAlpha, float aBeta);
|
||||
|
||||
/// <summary>
|
||||
/// Use the DeepSpeech model to perform Speech-To-Text.
|
||||
|
@ -23,11 +23,16 @@ namespace DeepSpeechClient
|
||||
internal unsafe static extern int DS_GetModelSampleRate(IntPtr** aCtx);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static unsafe extern ErrorCodes DS_EnableDecoderWithLM(IntPtr** aCtx,
|
||||
string aLMPath,
|
||||
string aTriePath,
|
||||
float aLMAlpha,
|
||||
float aLMBeta);
|
||||
internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx,
|
||||
string aScorerPath);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
|
||||
internal static unsafe extern ErrorCodes DS_SetScorerAlphaBeta(IntPtr** aCtx,
|
||||
float aAlpha,
|
||||
float aBeta);
|
||||
|
||||
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl,
|
||||
CharSet = CharSet.Ansi, SetLastError = true)]
|
||||
|
@ -35,22 +35,18 @@ namespace CSharpExamples
|
||||
static void Main(string[] args)
|
||||
{
|
||||
string model = null;
|
||||
string lm = null;
|
||||
string trie = null;
|
||||
string scorer = null;
|
||||
string audio = null;
|
||||
bool extended = false;
|
||||
if (args.Length > 0)
|
||||
{
|
||||
model = GetArgument(args, "--model");
|
||||
lm = GetArgument(args, "--lm");
|
||||
trie = GetArgument(args, "--trie");
|
||||
scorer = GetArgument(args, "--scorer");
|
||||
audio = GetArgument(args, "--audio");
|
||||
extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
|
||||
}
|
||||
|
||||
const uint BEAM_WIDTH = 500;
|
||||
const float LM_ALPHA = 0.75f;
|
||||
const float LM_BETA = 1.85f;
|
||||
|
||||
Stopwatch stopwatch = new Stopwatch();
|
||||
try
|
||||
@ -64,14 +60,10 @@ namespace CSharpExamples
|
||||
|
||||
Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms");
|
||||
stopwatch.Reset();
|
||||
if (lm != null)
|
||||
if (scorer != null)
|
||||
{
|
||||
Console.WriteLine("Loadin LM...");
|
||||
sttClient.EnableDecoderWithLM(
|
||||
lm ?? "lm.binary",
|
||||
trie ?? "trie",
|
||||
LM_ALPHA, LM_BETA);
|
||||
|
||||
Console.WriteLine("Loading scorer...");
|
||||
sttClient.EnableExternalScorer(scorer ?? "kenlm.scorer");
|
||||
}
|
||||
|
||||
string audioFile = audio ?? "arctic_a0024.wav";
|
||||
|
50
native_client/enumerate_kenlm_vocabulary.cpp
Normal file
50
native_client/enumerate_kenlm_vocabulary.cpp
Normal file
@ -0,0 +1,50 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "lm/enumerate_vocab.hh"
|
||||
#include "lm/virtual_interface.hh"
|
||||
#include "lm/word_index.hh"
|
||||
#include "lm/model.hh"
|
||||
|
||||
const std::string START_TOKEN = "<s>";
|
||||
const std::string UNK_TOKEN = "<unk>";
|
||||
const std::string END_TOKEN = "</s>";
|
||||
|
||||
// Implement a callback to retrieve the dictionary of language model.
|
||||
class RetrieveStrEnumerateVocab : public lm::EnumerateVocab
|
||||
{
|
||||
public:
|
||||
RetrieveStrEnumerateVocab() {}
|
||||
|
||||
void Add(lm::WordIndex index, const StringPiece &str) {
|
||||
vocabulary.push_back(std::string(str.data(), str.length()));
|
||||
}
|
||||
|
||||
std::vector<std::string> vocabulary;
|
||||
};
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
if (argc != 3) {
|
||||
std::cerr << "Usage: " << argv[0] << " <kenlm_model> <output_path>" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
const char* kenlm_model = argv[1];
|
||||
const char* output_path = argv[2];
|
||||
|
||||
std::unique_ptr<lm::base::Model> language_model_;
|
||||
lm::ngram::Config config;
|
||||
RetrieveStrEnumerateVocab enumerate;
|
||||
config.enumerate_vocab = &enumerate;
|
||||
language_model_.reset(lm::ngram::LoadVirtual(kenlm_model, config));
|
||||
|
||||
std::ofstream fout(output_path);
|
||||
for (const std::string& word : enumerate.vocabulary) {
|
||||
fout << word << "\n";
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ctcdecode/scorer.h"
|
||||
#include "alphabet.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
|
||||
Alphabet alphabet;
|
||||
int err = alphabet.init(alphabet_path);
|
||||
if (err != 0) {
|
||||
return err;
|
||||
}
|
||||
Scorer scorer;
|
||||
err = scorer.init(0.0, 0.0, kenlm_path, "", alphabet);
|
||||
if (err != 0) {
|
||||
return err;
|
||||
}
|
||||
scorer.save_dictionary(trie_path);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 4) {
|
||||
std::cerr << "Usage: " << argv[0] << " <alphabet> <lm_model> <trie_path>" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return generate_trie(argv[1], argv[2], argv[3]);
|
||||
}
|
@ -51,12 +51,11 @@ Please push DeepSpeech data to ``/sdcard/deepspeech/``\ , including:
|
||||
|
||||
|
||||
* ``output_graph.tflite`` which is the TF Lite model
|
||||
* ``lm.binary`` and ``trie`` files, if you want to use the language model ; please
|
||||
be aware that too big language model will make the device run out of memory
|
||||
* ``kenlm.scorer``, if you want to use the scorer; please be aware that too big
|
||||
scorer will make the device run out of memory
|
||||
|
||||
Then, push binaries from ``native_client.tar.xz`` to ``/data/local/tmp/ds``\ :
|
||||
|
||||
|
||||
* ``deepspeech``
|
||||
* ``libdeepspeech.so``
|
||||
* ``libc++_shared.so``
|
||||
|
@ -31,8 +31,6 @@ public class DeepSpeechActivity extends AppCompatActivity {
|
||||
Button _startInference;
|
||||
|
||||
final int BEAM_WIDTH = 50;
|
||||
final float LM_ALPHA = 0.75f;
|
||||
final float LM_BETA = 1.85f;
|
||||
|
||||
private char readLEChar(RandomAccessFile f) throws IOException {
|
||||
byte b1 = f.readByte();
|
||||
|
@ -30,15 +30,11 @@ import java.nio.ByteBuffer;
|
||||
public class BasicTest {
|
||||
|
||||
public static final String modelFile = "/data/local/tmp/test/output_graph.tflite";
|
||||
public static final String lmFile = "/data/local/tmp/test/lm.binary";
|
||||
public static final String trieFile = "/data/local/tmp/test/trie";
|
||||
public static final String scorerFile = "/data/local/tmp/test/kenlm.scorer";
|
||||
public static final String wavFile = "/data/local/tmp/test/LDC93S1.wav";
|
||||
|
||||
public static final int BEAM_WIDTH = 50;
|
||||
|
||||
public static final float LM_ALPHA = 0.75f;
|
||||
public static final float LM_BETA = 1.85f;
|
||||
|
||||
private char readLEChar(RandomAccessFile f) throws IOException {
|
||||
byte b1 = f.readByte();
|
||||
byte b2 = f.readByte();
|
||||
@ -130,7 +126,7 @@ public class BasicTest {
|
||||
@Test
|
||||
public void loadDeepSpeech_stt_withLM() {
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
||||
m.enableDecoderWithLM(lmFile, trieFile, LM_ALPHA, LM_BETA);
|
||||
m.enableExternalScorer(scorerFile);
|
||||
|
||||
String decoded = doSTT(m, false);
|
||||
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
||||
@ -149,7 +145,7 @@ public class BasicTest {
|
||||
@Test
|
||||
public void loadDeepSpeech_sttWithMetadata_withLM() {
|
||||
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
|
||||
m.enableDecoderWithLM(lmFile, trieFile, LM_ALPHA, LM_BETA);
|
||||
m.enableExternalScorer(scorerFile);
|
||||
|
||||
String decoded = doSTT(m, true);
|
||||
assertEquals("she had your dark suit in greasy wash water all year", decoded);
|
||||
|
@ -47,17 +47,35 @@ public class DeepSpeechModel {
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Enable decoding using beam scoring with a KenLM language model.
|
||||
* @brief Enable decoding using an external scorer.
|
||||
*
|
||||
* @param lm The path to the language model binary file.
|
||||
* @param trie The path to the trie file build from the same vocabulary as the language model binary.
|
||||
* @param lm_alpha The alpha hyperparameter of the CTC decoder. Language Model weight.
|
||||
* @param lm_beta The beta hyperparameter of the CTC decoder. Word insertion weight.
|
||||
* @param scorer The path to the external scorer file.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
public void enableDecoderWithLM(String lm, String trie, float lm_alpha, float lm_beta) {
|
||||
impl.EnableDecoderWithLM(this._msp, lm, trie, lm_alpha, lm_beta);
|
||||
public void enableExternalScorer(String scorer) {
|
||||
impl.EnableExternalScorer(this._msp, scorer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Disable decoding using an external scorer.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
public void disableExternalScorer() {
|
||||
impl.DisableExternalScorer(this._msp);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Enable decoding using beam scoring with a KenLM language model.
|
||||
*
|
||||
* @param alpha The alpha hyperparameter of the decoder. Language model weight.
|
||||
* @param beta The beta hyperparameter of the decoder. Word insertion weight.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
public void setScorerAlphaBeta(float alpha, float beta) {
|
||||
impl.SetScorerAlphaBeta(this._msp, alpha, beta);
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -29,12 +29,11 @@ VersionAction.prototype.call = function(parser) {
|
||||
|
||||
var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running DeepSpeech inference.'});
|
||||
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
|
||||
parser.addArgument(['--lm'], {help: 'Path to the language model binary file', nargs: '?'});
|
||||
parser.addArgument(['--trie'], {help: 'Path to the language model trie file created with native_client/generate_trie', nargs: '?'});
|
||||
parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
|
||||
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
|
||||
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
|
||||
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha)', defaultValue: 0.75, type: 'float'});
|
||||
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta)', defaultValue: 1.85, type: 'float'});
|
||||
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
|
||||
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
|
||||
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
|
||||
parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'});
|
||||
var args = parser.parseArgs();
|
||||
@ -60,12 +59,16 @@ console.error('Loaded model in %ds.', totalTime(model_load_end));
|
||||
|
||||
var desired_sample_rate = model.sampleRate();
|
||||
|
||||
if (args['lm'] && args['trie']) {
|
||||
console.error('Loading language model from files %s %s', args['lm'], args['trie']);
|
||||
const lm_load_start = process.hrtime();
|
||||
model.enableDecoderWithLM(args['lm'], args['trie'], args['lm_alpha'], args['lm_beta']);
|
||||
const lm_load_end = process.hrtime(lm_load_start);
|
||||
console.error('Loaded language model in %ds.', totalTime(lm_load_end));
|
||||
if (args['scorer']) {
|
||||
console.error('Loading scorer from file %s', args['scorer']);
|
||||
const scorer_load_start = process.hrtime();
|
||||
model.enableExternalScorer(args['scorer']);
|
||||
const scorer_load_end = process.hrtime(scorer_load_start);
|
||||
console.error('Loaded scorer in %ds.', totalTime(scorer_load_end));
|
||||
|
||||
if (args['lm_alpha'] && args['lm_beta']) {
|
||||
model.setScorerAlphaBeta(args['lm_alpha'], args['lm_beta']);
|
||||
}
|
||||
}
|
||||
|
||||
const buffer = Fs.readFileSync(args['audio']);
|
||||
|
@ -52,31 +52,46 @@ Model.prototype.sampleRate = function() {
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable decoding using beam scoring with a KenLM language model.
|
||||
* Enable decoding using an external scorer.
|
||||
*
|
||||
* @param {string} aScorerPath The path to the external scorer file.
|
||||
*
|
||||
* @return {number} Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
Model.prototype.enableExternalScorer = function(aScorerPath) {
|
||||
return binding.EnableExternalScorer(this._impl, aScorerPath);
|
||||
}
|
||||
|
||||
/**
|
||||
* Disable decoding using an external scorer.
|
||||
*
|
||||
* @return {number} Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
Model.prototype.disableExternalScorer = function() {
|
||||
return binding.EnableExternalScorer(this._impl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set hyperparameters alpha and beta of the external scorer.
|
||||
*
|
||||
* @param {string} aLMPath The path to the language model binary file.
|
||||
* @param {string} aTriePath The path to the trie file build from the same vocabulary as the language model binary.
|
||||
* @param {float} aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model weight.
|
||||
* @param {float} aLMBeta The beta hyperparameter of the CTC decoder. Word insertion weight.
|
||||
*
|
||||
* @return {number} Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
Model.prototype.enableDecoderWithLM = function() {
|
||||
const args = [this._impl].concat(Array.prototype.slice.call(arguments));
|
||||
return binding.EnableDecoderWithLM.apply(null, args);
|
||||
Model.prototype.setScorerAlphaBeta = function(aLMAlpha, aLMBeta) {
|
||||
return binding.SetScorerAlphaBeta(this._impl, aLMAlpha, aLMBeta);
|
||||
}
|
||||
|
||||
/**
|
||||
* Use the DeepSpeech model to perform Speech-To-Text.
|
||||
*
|
||||
* @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
* @param {number} aBufferSize The number of samples in the audio signal.
|
||||
*
|
||||
* @return {string} The STT result. Returns undefined on error.
|
||||
*/
|
||||
Model.prototype.stt = function() {
|
||||
const args = [this._impl].concat(Array.prototype.slice.call(arguments));
|
||||
return binding.SpeechToText.apply(null, args);
|
||||
Model.prototype.stt = function(aBuffer) {
|
||||
return binding.SpeechToText(this._impl, aBuffer);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -84,25 +99,22 @@ Model.prototype.stt = function() {
|
||||
* about the results.
|
||||
*
|
||||
* @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
* @param {number} aBufferSize The number of samples in the audio signal.
|
||||
*
|
||||
* @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
|
||||
*/
|
||||
Model.prototype.sttWithMetadata = function() {
|
||||
const args = [this._impl].concat(Array.prototype.slice.call(arguments));
|
||||
return binding.SpeechToTextWithMetadata.apply(null, args);
|
||||
Model.prototype.sttWithMetadata = function(aBuffer) {
|
||||
return binding.SpeechToTextWithMetadata(this._impl, aBuffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new streaming inference state. The streaming state returned by this function can then be passed to :js:func:`Model.feedAudioContent` and :js:func:`Model.finishStream`.
|
||||
* Create a new streaming inference state. One can then call :js:func:`Stream.feedAudioContent` and :js:func:`Stream.finishStream` on the returned stream object.
|
||||
*
|
||||
* @return {object} an opaque object that represents the streaming state.
|
||||
* @return {object} a :js:func:`Stream` object that represents the streaming state.
|
||||
*
|
||||
* @throws on error
|
||||
*/
|
||||
Model.prototype.createStream = function() {
|
||||
const args = [this._impl].concat(Array.prototype.slice.call(arguments));
|
||||
const rets = binding.CreateStream.apply(null, args);
|
||||
const rets = binding.CreateStream(this._impl);
|
||||
const status = rets[0];
|
||||
const ctx = rets[1];
|
||||
if (status !== 0) {
|
||||
@ -111,55 +123,61 @@ Model.prototype.createStream = function() {
|
||||
return ctx;
|
||||
}
|
||||
|
||||
/**
|
||||
* @class
|
||||
* Provides an interface to a DeepSpeech stream. The constructor cannot be called
|
||||
* directly, use :js:func:`Model.createStream`.
|
||||
*/
|
||||
function Stream(nativeStream) {
|
||||
this._impl = nativeStream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Feed audio samples to an ongoing streaming inference.
|
||||
*
|
||||
* @param {object} aSctx A streaming state returned by :js:func:`Model.setupStream`.
|
||||
* @param {buffer} aBuffer An array of 16-bit, mono raw audio samples at the
|
||||
* appropriate sample rate (matching what the model was trained on).
|
||||
* @param {number} aBufferSize The number of samples in @param aBuffer.
|
||||
*/
|
||||
Model.prototype.feedAudioContent = function() {
|
||||
binding.FeedAudioContent.apply(null, arguments);
|
||||
Stream.prototype.feedAudioContent = function(aBuffer) {
|
||||
binding.FeedAudioContent(this._impl, aBuffer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the intermediate decoding of an ongoing streaming inference.
|
||||
*
|
||||
* @param {object} aSctx A streaming state returned by :js:func:`Model.setupStream`.
|
||||
*
|
||||
* @return {string} The STT intermediate result.
|
||||
*/
|
||||
Model.prototype.intermediateDecode = function() {
|
||||
return binding.IntermediateDecode.apply(null, arguments);
|
||||
Stream.prototype.intermediateDecode = function() {
|
||||
return binding.IntermediateDecode(this._impl);
|
||||
}
|
||||
|
||||
/**
|
||||
* Signal the end of an audio signal to an ongoing streaming inference, returns the STT result over the whole audio signal.
|
||||
*
|
||||
* @param {object} aSctx A streaming state returned by :js:func:`Model.setupStream`.
|
||||
*
|
||||
* @return {string} The STT result.
|
||||
*
|
||||
* This method will free the state (@param aSctx).
|
||||
* This method will free the stream, it must not be used after this method is called.
|
||||
*/
|
||||
Model.prototype.finishStream = function() {
|
||||
return binding.FinishStream.apply(null, arguments);
|
||||
Stream.prototype.finishStream = function() {
|
||||
result = binding.FinishStream(this._impl);
|
||||
this._impl = null;
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata.
|
||||
*
|
||||
* @param {object} aSctx A streaming state pointer returned by :js:func:`Model.setupStream`.
|
||||
*
|
||||
* @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`.
|
||||
*
|
||||
* This method will free the state pointer (@param aSctx).
|
||||
* This method will free the stream, it must not be used after this method is called.
|
||||
*/
|
||||
Model.prototype.finishStreamWithMetadata = function() {
|
||||
return binding.FinishStreamWithMetadata.apply(null, arguments);
|
||||
Stream.prototype.finishStreamWithMetadata = function() {
|
||||
result = binding.FinishStreamWithMetadata(this._impl);
|
||||
this._impl = null;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Frees associated resources and destroys model object.
|
||||
*
|
||||
@ -184,10 +202,10 @@ function FreeMetadata(metadata) {
|
||||
* can be used if you no longer need the result of an ongoing streaming
|
||||
* inference and don't want to perform a costly decode operation.
|
||||
*
|
||||
* @param {Object} stream A streaming state pointer returned by :js:func:`Model.createStream`.
|
||||
* @param {Object} stream A stream object returned by :js:func:`Model.createStream`.
|
||||
*/
|
||||
function FreeStream(stream) {
|
||||
return binding.FreeStream(stream);
|
||||
return binding.FreeStream(stream._impl);
|
||||
}
|
||||
|
||||
/**
|
||||
|
3
native_client/kenlm/.gitignore
vendored
3
native_client/kenlm/.gitignore
vendored
@ -3,6 +3,9 @@ util/file_piece.cc.gz
|
||||
*.o
|
||||
doc/
|
||||
build/
|
||||
/bin
|
||||
/lib
|
||||
/tests
|
||||
._*
|
||||
windows/Win32
|
||||
windows/x64
|
||||
|
@ -12,3 +12,7 @@ If you only want the query code and do not care about compression (.gz, .bz2, an
|
||||
Windows:
|
||||
The windows directory has visual studio files. Note that you need to compile
|
||||
the kenlm project before build_binary and ngram_query projects.
|
||||
|
||||
OSX:
|
||||
Missing dependencies can be remedied with brew.
|
||||
brew install cmake boost eigen
|
||||
|
@ -1 +1 @@
|
||||
cdd794598ea15dc23a7daaf7a8cf89423c97f7e6
|
||||
b9f35777d112ce2fc10bd3986302517a16dc3883
|
||||
|
@ -2,9 +2,9 @@
|
||||
|
||||
Language model inference code by Kenneth Heafield (kenlm at kheafield.com)
|
||||
|
||||
I do development in master on https://github.com/kpu/kenlm/. Normally, it works, but I do not guarantee it will compile, give correct answers, or generate non-broken binary files. For a more stable release, get http://kheafield.com/code/kenlm.tar.gz .
|
||||
I do development in master on https://github.com/kpu/kenlm/. Normally, it works, but I do not guarantee it will compile, give correct answers, or generate non-broken binary files. For a more stable release, get https://kheafield.com/code/kenlm.tar.gz .
|
||||
|
||||
The website http://kheafield.com/code/kenlm/ has more documentation. If you're a decoder developer, please download the latest version from there instead of copying from another decoder.
|
||||
The website https://kheafield.com/code/kenlm/ has more documentation. If you're a decoder developer, please download the latest version from there instead of copying from another decoder.
|
||||
|
||||
## Compiling
|
||||
Use cmake, see [BUILDING](BUILDING) for more detail.
|
||||
@ -33,7 +33,7 @@ lmplz estimates unpruned language models with modified Kneser-Ney smoothing. Af
|
||||
```bash
|
||||
bin/lmplz -o 5 <text >text.arpa
|
||||
```
|
||||
The algorithm is on-disk, using an amount of memory that you specify. See http://kheafield.com/code/kenlm/estimation/ for more.
|
||||
The algorithm is on-disk, using an amount of memory that you specify. See https://kheafield.com/code/kenlm/estimation/ for more.
|
||||
|
||||
MT Marathon 2012 team members Ivan Pouzyrevsky and Mohammed Mediani contributed to the computation design and early implementation. Jon Clark contributed to the design, clarified points about smoothing, and added logging.
|
||||
|
||||
@ -43,15 +43,15 @@ filter takes an ARPA or count file and removes entries that will never be querie
|
||||
```bash
|
||||
bin/filter
|
||||
```
|
||||
and see http://kheafield.com/code/kenlm/filter/ for more documentation.
|
||||
and see https://kheafield.com/code/kenlm/filter/ for more documentation.
|
||||
|
||||
## Querying
|
||||
|
||||
Two data structures are supported: probing and trie. Probing is a probing hash table with keys that are 64-bit hashes of n-grams and floats as values. Trie is a fairly standard trie but with bit-level packing so it uses the minimum number of bits to store word indices and pointers. The trie node entries are sorted by word index. Probing is the fastest and uses the most memory. Trie uses the least memory and a bit slower.
|
||||
Two data structures are supported: probing and trie. Probing is a probing hash table with keys that are 64-bit hashes of n-grams and floats as values. Trie is a fairly standard trie but with bit-level packing so it uses the minimum number of bits to store word indices and pointers. The trie node entries are sorted by word index. Probing is the fastest and uses the most memory. Trie uses the least memory and is a bit slower.
|
||||
|
||||
As is the custom in language modeling, all probabilities are log base 10.
|
||||
|
||||
With trie, resident memory is 58% of IRST's smallest version and 21% of SRI's compact version. Simultaneously, trie CPU's use is 81% of IRST's fastest version and 84% of SRI's fast version. KenLM's probing hash table implementation goes even faster at the expense of using more memory. See http://kheafield.com/code/kenlm/benchmark/.
|
||||
With trie, resident memory is 58% of IRST's smallest version and 21% of SRI's compact version. Simultaneously, trie CPU's use is 81% of IRST's fastest version and 84% of SRI's fast version. KenLM's probing hash table implementation goes even faster at the expense of using more memory. See https://kheafield.com/code/kenlm/benchmark/.
|
||||
|
||||
Binary format via mmap is supported. Run `./build_binary` to make one then pass the binary file name to the appropriate Model constructor.
|
||||
|
||||
@ -71,7 +71,7 @@ Hideo Okuma and Tomoyuki Yoshimura from NICT contributed ports to ARM and MinGW.
|
||||
|
||||
- Select the macros you want, listed in the previous section.
|
||||
|
||||
- There are two build systems: compile.sh and Jamroot+Jamfile. They're pretty simple and are intended to be reimplemented in your build system.
|
||||
- There are two build systems: compile.sh and cmake. They're pretty simple and are intended to be reimplemented in your build system.
|
||||
|
||||
- Use either the interface in `lm/model.hh` or `lm/virtual_interface.hh`. Interface documentation is in comments of `lm/virtual_interface.hh` and `lm/model.hh`.
|
||||
|
||||
@ -101,4 +101,4 @@ See [python/example.py](python/example.py) and [python/kenlm.pyx](python/kenlm.p
|
||||
|
||||
---
|
||||
|
||||
The name was Hieu Hoang's idea, not mine.
|
||||
The name was Hieu Hoang's idea, not mine.
|
||||
|
@ -1,7 +1,7 @@
|
||||
KenLM source downloaded from http://kheafield.com/code/kenlm.tar.gz on 2017/08/05
|
||||
sha256 c4c9f587048470c9a6a592914f0609a71fbb959f0a4cad371e8c355ce81f7c6b
|
||||
KenLM source downloaded from https://github.com/kpu/kenlm on 2020/01/15
|
||||
commit b9f35777d112ce2fc10bd3986302517a16dc3883
|
||||
|
||||
This corresponds to https://github.com/kpu/kenlm/commit/cdd794598ea15dc23a7daaf7a8cf89423c97f7e6
|
||||
This corresponds to https://github.com/kpu/kenlm/commit/b9f35777d112ce2fc10bd3986302517a16dc3883
|
||||
|
||||
The following procedure was run to remove unneeded files:
|
||||
|
||||
@ -10,19 +10,3 @@ rm -rf windows include lm/filter lm/builder util/stream util/getopt.* python
|
||||
|
||||
This was done in order to ensure uniqueness of double_conversion:
|
||||
git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g'
|
||||
|
||||
Please apply this patch to be able to build on Android:
|
||||
diff --git a/native_client/kenlm/util/file.cc b/native_client/kenlm/util/file.cc
|
||||
index d53dc0a..b5e36b2 100644
|
||||
--- a/native_client/kenlm/util/file.cc
|
||||
+++ b/native_client/kenlm/util/file.cc
|
||||
@@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
|
||||
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
|
||||
for (int i=0; vars[i]; ++i) {
|
||||
char *val =
|
||||
-#if defined(_GNU_SOURCE)
|
||||
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
|
||||
#if __GLIBC_PREREQ(2,17)
|
||||
secure_getenv
|
||||
#else // __GLIBC_PREREQ
|
||||
|
||||
|
@ -10,7 +10,6 @@
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
#include <cstdlib>
|
||||
|
||||
#ifdef WIN32
|
||||
#include "util/getopt.hh"
|
||||
@ -23,11 +22,12 @@ namespace ngram {
|
||||
namespace {
|
||||
|
||||
void Usage(const char *name, const char *default_mem) {
|
||||
std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
|
||||
std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-v] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
|
||||
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
|
||||
" Default is -100. The ARPA file will always take precedence.\n"
|
||||
"-s allows models to be built even if they do not have <s> and </s>.\n"
|
||||
"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
|
||||
"-v disables inclusion of the vocabulary in the binary file.\n"
|
||||
"-w mmap|after determines how writing is done.\n"
|
||||
" mmap maps the binary file and writes to it. Default for trie.\n"
|
||||
" after allocates anonymous memory, builds, and writes. Default for probing.\n"
|
||||
@ -112,7 +112,7 @@ int main(int argc, char *argv[]) {
|
||||
lm::ngram::Config config;
|
||||
config.building_memory = util::ParseSize(default_mem);
|
||||
int opt;
|
||||
while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) {
|
||||
while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:vh")) != -1) {
|
||||
switch(opt) {
|
||||
case 'q':
|
||||
config.prob_bits = ParseBitCount(optarg);
|
||||
@ -165,6 +165,9 @@ int main(int argc, char *argv[]) {
|
||||
ParseFileList(optarg, config.rest_lower_files);
|
||||
config.rest_function = Config::REST_LOWER;
|
||||
break;
|
||||
case 'v':
|
||||
config.include_vocab = false;
|
||||
break;
|
||||
case 'h': // help
|
||||
default:
|
||||
Usage(argv[0], default_mem);
|
||||
|
@ -7,7 +7,7 @@
|
||||
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
|
||||
*/
|
||||
#ifndef KENLM_ORDER_MESSAGE
|
||||
#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh."
|
||||
#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. With cmake:\n cmake -DKENLM_MAX_ORDER=10 ..\nWith Moses:\n bjam --max-kenlm-order=10 -a\nOtherwise, edit lm/max_order.hh."
|
||||
#endif
|
||||
|
||||
#endif // LM_MAX_ORDER_H
|
||||
|
@ -226,6 +226,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::GetEndOfSearchOffset() const {
|
||||
return backing_.VocabStringReadingOffset();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Do a paraonoid copy of history, assuming new_word has already been copied
|
||||
// (hence the -1). out_state.length could be zero so I avoided using
|
||||
|
@ -102,6 +102,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
|
||||
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
|
||||
}
|
||||
|
||||
uint64_t GetEndOfSearchOffset() const;
|
||||
|
||||
private:
|
||||
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;
|
||||
|
||||
|
@ -19,8 +19,8 @@ void Usage(const char *name) {
|
||||
"Each word in the output is formatted as:\n"
|
||||
" word=vocab_id ngram_length log10(p(word|context))\n"
|
||||
"where ngram_length is the length of n-gram matched. A vocab_id of 0 indicates\n"
|
||||
"indicates the unknown word. Sentence-level output includes log10 probability of\n"
|
||||
"the sentence and OOV count.\n";
|
||||
"the unknown word. Sentence-level output includes log10 probability of the\n"
|
||||
"sentence and OOV count.\n";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
|
@ -19,8 +19,8 @@
|
||||
|
||||
namespace lm {
|
||||
|
||||
// 1 for '\t', '\n', and ' '. This is stricter than isspace.
|
||||
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
// 1 for '\t', '\n', '\r', and ' '. This is stricter than isspace. Apparently ARPA allows vertical tab inside a word.
|
||||
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
|
||||
|
||||
namespace {
|
||||
|
||||
@ -85,6 +85,11 @@ void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
|
||||
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");
|
||||
}
|
||||
|
||||
void ConsumeNewline(util::FilePiece &in) {
|
||||
char follow = in.get();
|
||||
UTIL_THROW_IF('\n' != follow, FormatLoadException, "Expected newline got '" << follow << "'");
|
||||
}
|
||||
|
||||
void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
|
||||
switch (in.get()) {
|
||||
case '\t':
|
||||
@ -94,6 +99,9 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
|
||||
UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff");
|
||||
}
|
||||
break;
|
||||
case '\r':
|
||||
ConsumeNewline(in);
|
||||
// Intentionally no break.
|
||||
case '\n':
|
||||
break;
|
||||
default:
|
||||
@ -120,8 +128,18 @@ void ReadBackoff(util::FilePiece &in, float &backoff) {
|
||||
UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff);
|
||||
#endif
|
||||
}
|
||||
UTIL_THROW_IF(in.get() != '\n', FormatLoadException, "Expected newline after backoff");
|
||||
switch (char got = in.get()) {
|
||||
case '\r':
|
||||
ConsumeNewline(in);
|
||||
case '\n':
|
||||
break;
|
||||
default:
|
||||
UTIL_THROW(FormatLoadException, "Expected newline after backoffs, got " << got);
|
||||
}
|
||||
break;
|
||||
case '\r':
|
||||
ConsumeNewline(in);
|
||||
// Intentionally no break.
|
||||
case '\n':
|
||||
backoff = ngram::kNoExtensionBackoff;
|
||||
break;
|
||||
|
@ -137,6 +137,8 @@ class Model {
|
||||
|
||||
const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
|
||||
|
||||
virtual uint64_t GetEndOfSearchOffset() const = 0;
|
||||
|
||||
private:
|
||||
template <class T, class U, class V> friend class ModelFacade;
|
||||
explicit Model(size_t state_size) : state_size_(state_size) {}
|
||||
|
@ -282,7 +282,7 @@ void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to
|
||||
if (have_words) ReadWords(fd, to, bound_, offset);
|
||||
}
|
||||
|
||||
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
|
||||
void MissingUnknown(const Config &config) {
|
||||
switch(config.unknown_missing) {
|
||||
case SILENT:
|
||||
return;
|
||||
@ -294,7 +294,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
|
||||
}
|
||||
}
|
||||
|
||||
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) {
|
||||
void MissingSentenceMarker(const Config &config, const char *str) {
|
||||
switch (config.sentence_marker_missing) {
|
||||
case SILENT:
|
||||
return;
|
||||
|
@ -207,10 +207,10 @@ class ProbingVocabulary : public base::Vocabulary {
|
||||
detail::ProbingVocabularyHeader *header_;
|
||||
};
|
||||
|
||||
void MissingUnknown(const Config &config) throw(SpecialWordMissingException);
|
||||
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException);
|
||||
void MissingUnknown(const Config &config);
|
||||
void MissingSentenceMarker(const Config &config, const char *str);
|
||||
|
||||
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) {
|
||||
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) {
|
||||
if (!vocab.SawUnk()) MissingUnknown(config);
|
||||
if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
|
||||
if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");
|
||||
|
@ -2,6 +2,8 @@ from setuptools import setup, Extension
|
||||
import glob
|
||||
import platform
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
|
||||
#Does gcc compile with this header and library?
|
||||
def compile_test(header, library):
|
||||
@ -9,16 +11,28 @@ def compile_test(header, library):
|
||||
command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
|
||||
return os.system(command) == 0
|
||||
|
||||
max_order = "6"
|
||||
is_max_order = [s for s in sys.argv if "--max_order" in s]
|
||||
for element in is_max_order:
|
||||
max_order = re.split('[= ]',element)[1]
|
||||
sys.argv.remove(element)
|
||||
|
||||
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob('util/double-conversion/*.cc')
|
||||
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob('util/double-conversion/*.cc') + glob.glob('python/*.cc')
|
||||
FILES = [fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))]
|
||||
|
||||
LIBS = ['stdc++']
|
||||
if platform.system() != 'Darwin':
|
||||
LIBS.append('rt')
|
||||
if platform.system() == 'Linux':
|
||||
LIBS = ['stdc++', 'rt']
|
||||
elif platform.system() == 'Darwin':
|
||||
LIBS = ['c++']
|
||||
else:
|
||||
LIBS = []
|
||||
|
||||
#We don't need -std=c++11 but python seems to be compiled with it now. https://github.com/kpu/kenlm/issues/86
|
||||
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']
|
||||
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER='+max_order, '-std=c++11']
|
||||
|
||||
#Attempted fix to https://github.com/kpu/kenlm/issues/186 and https://github.com/kpu/kenlm/issues/197
|
||||
if platform.system() == 'Darwin':
|
||||
ARGS += ["-stdlib=libc++", "-mmacosx-version-min=10.7"]
|
||||
|
||||
if compile_test('zlib.h', 'z'):
|
||||
ARGS.append('-DHAVE_ZLIB')
|
||||
|
@ -108,7 +108,7 @@ typedef union { float f; uint32_t i; } FloatEnc;
|
||||
|
||||
inline float ReadFloat32(const void *base, uint64_t bit_off) {
|
||||
FloatEnc encoded;
|
||||
encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32);
|
||||
encoded.i = static_cast<uint32_t>(ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32));
|
||||
return encoded.f;
|
||||
}
|
||||
inline void WriteFloat32(void *base, uint64_t bit_off, float value) {
|
||||
@ -135,7 +135,7 @@ inline void UnsetSign(float &to) {
|
||||
|
||||
inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) {
|
||||
FloatEnc encoded;
|
||||
encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31);
|
||||
encoded.i = static_cast<uint32_t>(ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31));
|
||||
// Sign bit set means negative.
|
||||
encoded.i |= kSignBit;
|
||||
return encoded.f;
|
||||
|
@ -25,7 +25,7 @@
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <cmath>
|
||||
#include <math.h>
|
||||
|
||||
#include "bignum-dtoa.h"
|
||||
|
||||
@ -192,13 +192,13 @@ static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator,
|
||||
delta_plus = delta_minus;
|
||||
}
|
||||
*length = 0;
|
||||
while (true) {
|
||||
for (;;) {
|
||||
uint16_t digit;
|
||||
digit = numerator->DivideModuloIntBignum(*denominator);
|
||||
ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive.
|
||||
// digit = numerator / denominator (integer division).
|
||||
// numerator = numerator % denominator.
|
||||
buffer[(*length)++] = digit + '0';
|
||||
buffer[(*length)++] = static_cast<char>(digit + '0');
|
||||
|
||||
// Can we stop already?
|
||||
// If the remainder of the division is less than the distance to the lower
|
||||
@ -282,7 +282,7 @@ static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator,
|
||||
// exponent (decimal_point), when rounding upwards.
|
||||
static void GenerateCountedDigits(int count, int* decimal_point,
|
||||
Bignum* numerator, Bignum* denominator,
|
||||
Vector<char>(buffer), int* length) {
|
||||
Vector<char> buffer, int* length) {
|
||||
ASSERT(count >= 0);
|
||||
for (int i = 0; i < count - 1; ++i) {
|
||||
uint16_t digit;
|
||||
@ -290,7 +290,7 @@ static void GenerateCountedDigits(int count, int* decimal_point,
|
||||
ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive.
|
||||
// digit = numerator / denominator (integer division).
|
||||
// numerator = numerator % denominator.
|
||||
buffer[i] = digit + '0';
|
||||
buffer[i] = static_cast<char>(digit + '0');
|
||||
// Prepare for next iteration.
|
||||
numerator->Times10();
|
||||
}
|
||||
@ -300,7 +300,8 @@ static void GenerateCountedDigits(int count, int* decimal_point,
|
||||
if (Bignum::PlusCompare(*numerator, *numerator, *denominator) >= 0) {
|
||||
digit++;
|
||||
}
|
||||
buffer[count - 1] = digit + '0';
|
||||
ASSERT(digit <= 10);
|
||||
buffer[count - 1] = static_cast<char>(digit + '0');
|
||||
// Correct bad digits (in case we had a sequence of '9's). Propagate the
|
||||
// carry until we hat a non-'9' or til we reach the first digit.
|
||||
for (int i = count - 1; i > 0; --i) {
|
||||
|
@ -40,6 +40,7 @@ Bignum::Bignum()
|
||||
|
||||
template<typename S>
|
||||
static int BitSize(S value) {
|
||||
(void) value; // Mark variable as used.
|
||||
return 8 * sizeof(value);
|
||||
}
|
||||
|
||||
@ -103,7 +104,7 @@ void Bignum::AssignDecimalString(Vector<const char> value) {
|
||||
const int kMaxUint64DecimalDigits = 19;
|
||||
Zero();
|
||||
int length = value.length();
|
||||
int pos = 0;
|
||||
unsigned int pos = 0;
|
||||
// Let's just say that each digit needs 4 bits.
|
||||
while (length >= kMaxUint64DecimalDigits) {
|
||||
uint64_t digits = ReadUInt64(value, pos, kMaxUint64DecimalDigits);
|
||||
@ -122,9 +123,8 @@ void Bignum::AssignDecimalString(Vector<const char> value) {
|
||||
static int HexCharValue(char c) {
|
||||
if ('0' <= c && c <= '9') return c - '0';
|
||||
if ('a' <= c && c <= 'f') return 10 + c - 'a';
|
||||
if ('A' <= c && c <= 'F') return 10 + c - 'A';
|
||||
UNREACHABLE();
|
||||
return 0; // To make compiler happy.
|
||||
ASSERT('A' <= c && c <= 'F');
|
||||
return 10 + c - 'A';
|
||||
}
|
||||
|
||||
|
||||
@ -501,13 +501,14 @@ uint16_t Bignum::DivideModuloIntBignum(const Bignum& other) {
|
||||
// Start by removing multiples of 'other' until both numbers have the same
|
||||
// number of digits.
|
||||
while (BigitLength() > other.BigitLength()) {
|
||||
// This naive approach is extremely inefficient if the this divided other
|
||||
// might be big. This function is implemented for doubleToString where
|
||||
// This naive approach is extremely inefficient if `this` divided by other
|
||||
// is big. This function is implemented for doubleToString where
|
||||
// the result should be small (less than 10).
|
||||
ASSERT(other.bigits_[other.used_digits_ - 1] >= ((1 << kBigitSize) / 16));
|
||||
ASSERT(bigits_[used_digits_ - 1] < 0x10000);
|
||||
// Remove the multiples of the first digit.
|
||||
// Example this = 23 and other equals 9. -> Remove 2 multiples.
|
||||
result += bigits_[used_digits_ - 1];
|
||||
result += static_cast<uint16_t>(bigits_[used_digits_ - 1]);
|
||||
SubtractTimes(other, bigits_[used_digits_ - 1]);
|
||||
}
|
||||
|
||||
@ -523,13 +524,15 @@ uint16_t Bignum::DivideModuloIntBignum(const Bignum& other) {
|
||||
// Shortcut for easy (and common) case.
|
||||
int quotient = this_bigit / other_bigit;
|
||||
bigits_[used_digits_ - 1] = this_bigit - other_bigit * quotient;
|
||||
result += quotient;
|
||||
ASSERT(quotient < 0x10000);
|
||||
result += static_cast<uint16_t>(quotient);
|
||||
Clamp();
|
||||
return result;
|
||||
}
|
||||
|
||||
int division_estimate = this_bigit / (other_bigit + 1);
|
||||
result += division_estimate;
|
||||
ASSERT(division_estimate < 0x10000);
|
||||
result += static_cast<uint16_t>(division_estimate);
|
||||
SubtractTimes(other, division_estimate);
|
||||
|
||||
if (other_bigit * (division_estimate + 1) > this_bigit) {
|
||||
@ -560,8 +563,8 @@ static int SizeInHexChars(S number) {
|
||||
|
||||
static char HexCharOfValue(int value) {
|
||||
ASSERT(0 <= value && value <= 16);
|
||||
if (value < 10) return value + '0';
|
||||
return value - 10 + 'A';
|
||||
if (value < 10) return static_cast<char>(value + '0');
|
||||
return static_cast<char>(value - 10 + 'A');
|
||||
}
|
||||
|
||||
|
||||
@ -755,7 +758,6 @@ void Bignum::SubtractTimes(const Bignum& other, int factor) {
|
||||
Chunk difference = bigits_[i] - borrow;
|
||||
bigits_[i] = difference & kBigitMask;
|
||||
borrow = difference >> (kChunkSize - 1);
|
||||
++i;
|
||||
}
|
||||
Clamp();
|
||||
}
|
||||
|
@ -49,7 +49,6 @@ class Bignum {
|
||||
|
||||
void AssignPowerUInt16(uint16_t base, int exponent);
|
||||
|
||||
void AddUInt16(uint16_t operand);
|
||||
void AddUInt64(uint64_t operand);
|
||||
void AddBignum(const Bignum& other);
|
||||
// Precondition: this >= other.
|
||||
|
@ -25,9 +25,9 @@
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <cstdarg>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <stdarg.h>
|
||||
#include <limits.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
@ -131,7 +131,6 @@ static const CachedPower kCachedPowers[] = {
|
||||
{UINT64_2PART_C(0xaf87023b, 9bf0ee6b), 1066, 340},
|
||||
};
|
||||
|
||||
static const int kCachedPowersLength = ARRAY_SIZE(kCachedPowers);
|
||||
static const int kCachedPowersOffset = 348; // -1 * the first decimal_exponent.
|
||||
static const double kD_1_LOG2_10 = 0.30102999566398114; // 1 / lg(10)
|
||||
// Difference between the decimal exponents in the table above.
|
||||
@ -149,9 +148,10 @@ void PowersOfTenCache::GetCachedPowerForBinaryExponentRange(
|
||||
int foo = kCachedPowersOffset;
|
||||
int index =
|
||||
(foo + static_cast<int>(k) - 1) / kDecimalExponentDistance + 1;
|
||||
ASSERT(0 <= index && index < kCachedPowersLength);
|
||||
ASSERT(0 <= index && index < static_cast<int>(ARRAY_SIZE(kCachedPowers)));
|
||||
CachedPower cached_power = kCachedPowers[index];
|
||||
ASSERT(min_exponent <= cached_power.binary_exponent);
|
||||
(void) max_exponent; // Mark variable as used.
|
||||
ASSERT(cached_power.binary_exponent <= max_exponent);
|
||||
*decimal_exponent = cached_power.decimal_exponent;
|
||||
*power = DiyFp(cached_power.significand, cached_power.binary_exponent);
|
||||
|
@ -42,7 +42,7 @@ class DiyFp {
|
||||
static const int kSignificandSize = 64;
|
||||
|
||||
DiyFp() : f_(0), e_(0) {}
|
||||
DiyFp(uint64_t f, int e) : f_(f), e_(e) {}
|
||||
DiyFp(uint64_t significand, int exponent) : f_(significand), e_(exponent) {}
|
||||
|
||||
// this = this - other.
|
||||
// The exponents of both numbers must be the same and the significand of this
|
||||
@ -76,22 +76,22 @@ class DiyFp {
|
||||
|
||||
void Normalize() {
|
||||
ASSERT(f_ != 0);
|
||||
uint64_t f = f_;
|
||||
int e = e_;
|
||||
uint64_t significand = f_;
|
||||
int exponent = e_;
|
||||
|
||||
// This method is mainly called for normalizing boundaries. In general
|
||||
// boundaries need to be shifted by 10 bits. We thus optimize for this case.
|
||||
const uint64_t k10MSBits = UINT64_2PART_C(0xFFC00000, 00000000);
|
||||
while ((f & k10MSBits) == 0) {
|
||||
f <<= 10;
|
||||
e -= 10;
|
||||
while ((significand & k10MSBits) == 0) {
|
||||
significand <<= 10;
|
||||
exponent -= 10;
|
||||
}
|
||||
while ((f & kUint64MSB) == 0) {
|
||||
f <<= 1;
|
||||
e--;
|
||||
while ((significand & kUint64MSB) == 0) {
|
||||
significand <<= 1;
|
||||
exponent--;
|
||||
}
|
||||
f_ = f;
|
||||
e_ = e;
|
||||
f_ = significand;
|
||||
e_ = exponent;
|
||||
}
|
||||
|
||||
static DiyFp Normalize(const DiyFp& a) {
|
||||
|
@ -25,8 +25,8 @@
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
#include <limits.h>
|
||||
#include <math.h>
|
||||
|
||||
#include "double-conversion.h"
|
||||
|
||||
@ -118,7 +118,7 @@ void DoubleToStringConverter::CreateDecimalRepresentation(
|
||||
StringBuilder* result_builder) const {
|
||||
// Create a representation that is padded with zeros if needed.
|
||||
if (decimal_point <= 0) {
|
||||
// "0.00000decimal_rep".
|
||||
// "0.00000decimal_rep" or "0.000decimal_rep00".
|
||||
result_builder->AddCharacter('0');
|
||||
if (digits_after_point > 0) {
|
||||
result_builder->AddCharacter('.');
|
||||
@ -129,7 +129,7 @@ void DoubleToStringConverter::CreateDecimalRepresentation(
|
||||
result_builder->AddPadding('0', remaining_digits);
|
||||
}
|
||||
} else if (decimal_point >= length) {
|
||||
// "decimal_rep0000.00000" or "decimal_rep.0000"
|
||||
// "decimal_rep0000.00000" or "decimal_rep.0000".
|
||||
result_builder->AddSubstring(decimal_digits, length);
|
||||
result_builder->AddPadding('0', decimal_point - length);
|
||||
if (digits_after_point > 0) {
|
||||
@ -137,7 +137,7 @@ void DoubleToStringConverter::CreateDecimalRepresentation(
|
||||
result_builder->AddPadding('0', digits_after_point);
|
||||
}
|
||||
} else {
|
||||
// "decima.l_rep000"
|
||||
// "decima.l_rep000".
|
||||
ASSERT(digits_after_point > 0);
|
||||
result_builder->AddSubstring(decimal_digits, decimal_point);
|
||||
result_builder->AddCharacter('.');
|
||||
@ -348,7 +348,6 @@ static BignumDtoaMode DtoaToBignumDtoaMode(
|
||||
case DoubleToStringConverter::PRECISION: return BIGNUM_DTOA_PRECISION;
|
||||
default:
|
||||
UNREACHABLE();
|
||||
return BIGNUM_DTOA_SHORTEST; // To silence compiler.
|
||||
}
|
||||
}
|
||||
|
||||
@ -403,8 +402,8 @@ void DoubleToStringConverter::DoubleToAscii(double v,
|
||||
vector, length, point);
|
||||
break;
|
||||
default:
|
||||
UNREACHABLE();
|
||||
fast_worked = false;
|
||||
UNREACHABLE();
|
||||
}
|
||||
if (fast_worked) return;
|
||||
|
||||
@ -417,8 +416,9 @@ void DoubleToStringConverter::DoubleToAscii(double v,
|
||||
|
||||
// Consumes the given substring from the iterator.
|
||||
// Returns false, if the substring does not match.
|
||||
static bool ConsumeSubString(const char** current,
|
||||
const char* end,
|
||||
template <class Iterator>
|
||||
static bool ConsumeSubString(Iterator* current,
|
||||
Iterator end,
|
||||
const char* substring) {
|
||||
ASSERT(**current == *substring);
|
||||
for (substring++; *substring != '\0'; substring++) {
|
||||
@ -440,10 +440,36 @@ static bool ConsumeSubString(const char** current,
|
||||
const int kMaxSignificantDigits = 772;
|
||||
|
||||
|
||||
static const char kWhitespaceTable7[] = { 32, 13, 10, 9, 11, 12 };
|
||||
static const int kWhitespaceTable7Length = ARRAY_SIZE(kWhitespaceTable7);
|
||||
|
||||
|
||||
static const uc16 kWhitespaceTable16[] = {
|
||||
160, 8232, 8233, 5760, 6158, 8192, 8193, 8194, 8195,
|
||||
8196, 8197, 8198, 8199, 8200, 8201, 8202, 8239, 8287, 12288, 65279
|
||||
};
|
||||
static const int kWhitespaceTable16Length = ARRAY_SIZE(kWhitespaceTable16);
|
||||
|
||||
|
||||
static bool isWhitespace(int x) {
|
||||
if (x < 128) {
|
||||
for (int i = 0; i < kWhitespaceTable7Length; i++) {
|
||||
if (kWhitespaceTable7[i] == x) return true;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < kWhitespaceTable16Length; i++) {
|
||||
if (kWhitespaceTable16[i] == x) return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Returns true if a nonspace found and false if the end has reached.
|
||||
static inline bool AdvanceToNonspace(const char** current, const char* end) {
|
||||
template <class Iterator>
|
||||
static inline bool AdvanceToNonspace(Iterator* current, Iterator end) {
|
||||
while (*current != end) {
|
||||
if (**current != ' ') return true;
|
||||
if (!isWhitespace(**current)) return true;
|
||||
++*current;
|
||||
}
|
||||
return false;
|
||||
@ -462,26 +488,57 @@ static double SignedZero(bool sign) {
|
||||
}
|
||||
|
||||
|
||||
// Returns true if 'c' is a decimal digit that is valid for the given radix.
|
||||
//
|
||||
// The function is small and could be inlined, but VS2012 emitted a warning
|
||||
// because it constant-propagated the radix and concluded that the last
|
||||
// condition was always true. By moving it into a separate function the
|
||||
// compiler wouldn't warn anymore.
|
||||
#if _MSC_VER
|
||||
#pragma optimize("",off)
|
||||
static bool IsDecimalDigitForRadix(int c, int radix) {
|
||||
return '0' <= c && c <= '9' && (c - '0') < radix;
|
||||
}
|
||||
#pragma optimize("",on)
|
||||
#else
|
||||
static bool inline IsDecimalDigitForRadix(int c, int radix) {
|
||||
return '0' <= c && c <= '9' && (c - '0') < radix;
|
||||
}
|
||||
#endif
|
||||
// Returns true if 'c' is a character digit that is valid for the given radix.
|
||||
// The 'a_character' should be 'a' or 'A'.
|
||||
//
|
||||
// The function is small and could be inlined, but VS2012 emitted a warning
|
||||
// because it constant-propagated the radix and concluded that the first
|
||||
// condition was always false. By moving it into a separate function the
|
||||
// compiler wouldn't warn anymore.
|
||||
static bool IsCharacterDigitForRadix(int c, int radix, char a_character) {
|
||||
return radix > 10 && c >= a_character && c < a_character + radix - 10;
|
||||
}
|
||||
|
||||
|
||||
// Parsing integers with radix 2, 4, 8, 16, 32. Assumes current != end.
|
||||
template <int radix_log_2>
|
||||
static double RadixStringToIeee(const char* current,
|
||||
const char* end,
|
||||
template <int radix_log_2, class Iterator>
|
||||
static double RadixStringToIeee(Iterator* current,
|
||||
Iterator end,
|
||||
bool sign,
|
||||
bool allow_trailing_junk,
|
||||
double junk_string_value,
|
||||
bool read_as_double,
|
||||
const char** trailing_pointer) {
|
||||
ASSERT(current != end);
|
||||
bool* result_is_junk) {
|
||||
ASSERT(*current != end);
|
||||
|
||||
const int kDoubleSize = Double::kSignificandSize;
|
||||
const int kSingleSize = Single::kSignificandSize;
|
||||
const int kSignificandSize = read_as_double? kDoubleSize: kSingleSize;
|
||||
|
||||
*result_is_junk = true;
|
||||
|
||||
// Skip leading 0s.
|
||||
while (*current == '0') {
|
||||
++current;
|
||||
if (current == end) {
|
||||
*trailing_pointer = end;
|
||||
while (**current == '0') {
|
||||
++(*current);
|
||||
if (*current == end) {
|
||||
*result_is_junk = false;
|
||||
return SignedZero(sign);
|
||||
}
|
||||
}
|
||||
@ -492,14 +549,14 @@ static double RadixStringToIeee(const char* current,
|
||||
|
||||
do {
|
||||
int digit;
|
||||
if (*current >= '0' && *current <= '9' && *current < '0' + radix) {
|
||||
digit = static_cast<char>(*current) - '0';
|
||||
} else if (radix > 10 && *current >= 'a' && *current < 'a' + radix - 10) {
|
||||
digit = static_cast<char>(*current) - 'a' + 10;
|
||||
} else if (radix > 10 && *current >= 'A' && *current < 'A' + radix - 10) {
|
||||
digit = static_cast<char>(*current) - 'A' + 10;
|
||||
if (IsDecimalDigitForRadix(**current, radix)) {
|
||||
digit = static_cast<char>(**current) - '0';
|
||||
} else if (IsCharacterDigitForRadix(**current, radix, 'a')) {
|
||||
digit = static_cast<char>(**current) - 'a' + 10;
|
||||
} else if (IsCharacterDigitForRadix(**current, radix, 'A')) {
|
||||
digit = static_cast<char>(**current) - 'A' + 10;
|
||||
} else {
|
||||
if (allow_trailing_junk || !AdvanceToNonspace(¤t, end)) {
|
||||
if (allow_trailing_junk || !AdvanceToNonspace(current, end)) {
|
||||
break;
|
||||
} else {
|
||||
return junk_string_value;
|
||||
@ -523,14 +580,14 @@ static double RadixStringToIeee(const char* current,
|
||||
exponent = overflow_bits_count;
|
||||
|
||||
bool zero_tail = true;
|
||||
while (true) {
|
||||
++current;
|
||||
if (current == end || !isDigit(*current, radix)) break;
|
||||
zero_tail = zero_tail && *current == '0';
|
||||
for (;;) {
|
||||
++(*current);
|
||||
if (*current == end || !isDigit(**current, radix)) break;
|
||||
zero_tail = zero_tail && **current == '0';
|
||||
exponent += radix_log_2;
|
||||
}
|
||||
|
||||
if (!allow_trailing_junk && AdvanceToNonspace(¤t, end)) {
|
||||
if (!allow_trailing_junk && AdvanceToNonspace(current, end)) {
|
||||
return junk_string_value;
|
||||
}
|
||||
|
||||
@ -552,13 +609,13 @@ static double RadixStringToIeee(const char* current,
|
||||
}
|
||||
break;
|
||||
}
|
||||
++current;
|
||||
} while (current != end);
|
||||
++(*current);
|
||||
} while (*current != end);
|
||||
|
||||
ASSERT(number < ((int64_t)1 << kSignificandSize));
|
||||
ASSERT(static_cast<int64_t>(static_cast<double>(number)) == number);
|
||||
|
||||
*trailing_pointer = current;
|
||||
*result_is_junk = false;
|
||||
|
||||
if (exponent == 0) {
|
||||
if (sign) {
|
||||
@ -573,13 +630,14 @@ static double RadixStringToIeee(const char* current,
|
||||
}
|
||||
|
||||
|
||||
template <class Iterator>
|
||||
double StringToDoubleConverter::StringToIeee(
|
||||
const char* input,
|
||||
Iterator input,
|
||||
int length,
|
||||
int* processed_characters_count,
|
||||
bool read_as_double) const {
|
||||
const char* current = input;
|
||||
const char* end = input + length;
|
||||
bool read_as_double,
|
||||
int* processed_characters_count) const {
|
||||
Iterator current = input;
|
||||
Iterator end = input + length;
|
||||
|
||||
*processed_characters_count = 0;
|
||||
|
||||
@ -600,7 +658,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
|
||||
if (allow_leading_spaces || allow_trailing_spaces) {
|
||||
if (!AdvanceToNonspace(¤t, end)) {
|
||||
*processed_characters_count = current - input;
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return empty_string_value_;
|
||||
}
|
||||
if (!allow_leading_spaces && (input != current)) {
|
||||
@ -626,7 +684,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
if (*current == '+' || *current == '-') {
|
||||
sign = (*current == '-');
|
||||
++current;
|
||||
const char* next_non_space = current;
|
||||
Iterator next_non_space = current;
|
||||
// Skip following spaces (if allowed).
|
||||
if (!AdvanceToNonspace(&next_non_space, end)) return junk_string_value_;
|
||||
if (!allow_spaces_after_sign && (current != next_non_space)) {
|
||||
@ -649,7 +707,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
}
|
||||
|
||||
ASSERT(buffer_pos == 0);
|
||||
*processed_characters_count = current - input;
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return sign ? -Double::Infinity() : Double::Infinity();
|
||||
}
|
||||
}
|
||||
@ -668,7 +726,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
}
|
||||
|
||||
ASSERT(buffer_pos == 0);
|
||||
*processed_characters_count = current - input;
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return sign ? -Double::NaN() : Double::NaN();
|
||||
}
|
||||
}
|
||||
@ -677,7 +735,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
if (*current == '0') {
|
||||
++current;
|
||||
if (current == end) {
|
||||
*processed_characters_count = current - input;
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return SignedZero(sign);
|
||||
}
|
||||
|
||||
@ -690,17 +748,17 @@ double StringToDoubleConverter::StringToIeee(
|
||||
return junk_string_value_; // "0x".
|
||||
}
|
||||
|
||||
const char* tail_pointer = NULL;
|
||||
double result = RadixStringToIeee<4>(current,
|
||||
bool result_is_junk;
|
||||
double result = RadixStringToIeee<4>(¤t,
|
||||
end,
|
||||
sign,
|
||||
allow_trailing_junk,
|
||||
junk_string_value_,
|
||||
read_as_double,
|
||||
&tail_pointer);
|
||||
if (tail_pointer != NULL) {
|
||||
if (allow_trailing_spaces) AdvanceToNonspace(&tail_pointer, end);
|
||||
*processed_characters_count = tail_pointer - input;
|
||||
&result_is_junk);
|
||||
if (!result_is_junk) {
|
||||
if (allow_trailing_spaces) AdvanceToNonspace(¤t, end);
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -709,7 +767,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
while (*current == '0') {
|
||||
++current;
|
||||
if (current == end) {
|
||||
*processed_characters_count = current - input;
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return SignedZero(sign);
|
||||
}
|
||||
}
|
||||
@ -757,7 +815,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
while (*current == '0') {
|
||||
++current;
|
||||
if (current == end) {
|
||||
*processed_characters_count = current - input;
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return SignedZero(sign);
|
||||
}
|
||||
exponent--; // Move this 0 into the exponent.
|
||||
@ -801,9 +859,9 @@ double StringToDoubleConverter::StringToIeee(
|
||||
return junk_string_value_;
|
||||
}
|
||||
}
|
||||
char sign = '+';
|
||||
char exponen_sign = '+';
|
||||
if (*current == '+' || *current == '-') {
|
||||
sign = static_cast<char>(*current);
|
||||
exponen_sign = static_cast<char>(*current);
|
||||
++current;
|
||||
if (current == end) {
|
||||
if (allow_trailing_junk) {
|
||||
@ -837,7 +895,7 @@ double StringToDoubleConverter::StringToIeee(
|
||||
++current;
|
||||
} while (current != end && *current >= '0' && *current <= '9');
|
||||
|
||||
exponent += (sign == '-' ? -num : num);
|
||||
exponent += (exponen_sign == '-' ? -num : num);
|
||||
}
|
||||
|
||||
if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) {
|
||||
@ -855,16 +913,17 @@ double StringToDoubleConverter::StringToIeee(
|
||||
|
||||
if (octal) {
|
||||
double result;
|
||||
const char* tail_pointer = NULL;
|
||||
result = RadixStringToIeee<3>(buffer,
|
||||
bool result_is_junk;
|
||||
char* start = buffer;
|
||||
result = RadixStringToIeee<3>(&start,
|
||||
buffer + buffer_pos,
|
||||
sign,
|
||||
allow_trailing_junk,
|
||||
junk_string_value_,
|
||||
read_as_double,
|
||||
&tail_pointer);
|
||||
ASSERT(tail_pointer != NULL);
|
||||
*processed_characters_count = current - input;
|
||||
&result_is_junk);
|
||||
ASSERT(!result_is_junk);
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -882,8 +941,42 @@ double StringToDoubleConverter::StringToIeee(
|
||||
} else {
|
||||
converted = Strtof(Vector<const char>(buffer, buffer_pos), exponent);
|
||||
}
|
||||
*processed_characters_count = current - input;
|
||||
*processed_characters_count = static_cast<int>(current - input);
|
||||
return sign? -converted: converted;
|
||||
}
|
||||
|
||||
|
||||
double StringToDoubleConverter::StringToDouble(
|
||||
const char* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const {
|
||||
return StringToIeee(buffer, length, true, processed_characters_count);
|
||||
}
|
||||
|
||||
|
||||
double StringToDoubleConverter::StringToDouble(
|
||||
const uc16* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const {
|
||||
return StringToIeee(buffer, length, true, processed_characters_count);
|
||||
}
|
||||
|
||||
|
||||
float StringToDoubleConverter::StringToFloat(
|
||||
const char* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const {
|
||||
return static_cast<float>(StringToIeee(buffer, length, false,
|
||||
processed_characters_count));
|
||||
}
|
||||
|
||||
|
||||
float StringToDoubleConverter::StringToFloat(
|
||||
const uc16* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const {
|
||||
return static_cast<float>(StringToIeee(buffer, length, false,
|
||||
processed_characters_count));
|
||||
}
|
||||
|
||||
} // namespace kenlm_double_conversion
|
||||
|
@ -415,9 +415,10 @@ class StringToDoubleConverter {
|
||||
// junk, too.
|
||||
// - ALLOW_TRAILING_JUNK: ignore trailing characters that are not part of
|
||||
// a double literal.
|
||||
// - ALLOW_LEADING_SPACES: skip over leading spaces.
|
||||
// - ALLOW_TRAILING_SPACES: ignore trailing spaces.
|
||||
// - ALLOW_SPACES_AFTER_SIGN: ignore spaces after the sign.
|
||||
// - ALLOW_LEADING_SPACES: skip over leading whitespace, including spaces,
|
||||
// new-lines, and tabs.
|
||||
// - ALLOW_TRAILING_SPACES: ignore trailing whitespace.
|
||||
// - ALLOW_SPACES_AFTER_SIGN: ignore whitespace after the sign.
|
||||
// Ex: StringToDouble("- 123.2") -> -123.2.
|
||||
// StringToDouble("+ 123.2") -> 123.2
|
||||
//
|
||||
@ -502,19 +503,24 @@ class StringToDoubleConverter {
|
||||
// in the 'processed_characters_count'. Trailing junk is never included.
|
||||
double StringToDouble(const char* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const {
|
||||
return StringToIeee(buffer, length, processed_characters_count, true);
|
||||
}
|
||||
int* processed_characters_count) const;
|
||||
|
||||
// Same as StringToDouble above but for 16 bit characters.
|
||||
double StringToDouble(const uc16* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const;
|
||||
|
||||
// Same as StringToDouble but reads a float.
|
||||
// Note that this is not equivalent to static_cast<float>(StringToDouble(...))
|
||||
// due to potential double-rounding.
|
||||
float StringToFloat(const char* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const {
|
||||
return static_cast<float>(StringToIeee(buffer, length,
|
||||
processed_characters_count, false));
|
||||
}
|
||||
int* processed_characters_count) const;
|
||||
|
||||
// Same as StringToFloat above but for 16 bit characters.
|
||||
float StringToFloat(const uc16* buffer,
|
||||
int length,
|
||||
int* processed_characters_count) const;
|
||||
|
||||
private:
|
||||
const int flags_;
|
||||
@ -523,10 +529,11 @@ class StringToDoubleConverter {
|
||||
const char* const infinity_symbol_;
|
||||
const char* const nan_symbol_;
|
||||
|
||||
double StringToIeee(const char* buffer,
|
||||
template <class Iterator>
|
||||
double StringToIeee(Iterator start_pointer,
|
||||
int length,
|
||||
int* processed_characters_count,
|
||||
bool read_as_double) const;
|
||||
bool read_as_double,
|
||||
int* processed_characters_count) const;
|
||||
|
||||
DISALLOW_IMPLICIT_CONSTRUCTORS(StringToDoubleConverter);
|
||||
};
|
||||
|
@ -248,10 +248,7 @@ static void BiggestPowerTen(uint32_t number,
|
||||
// Note: kPowersOf10[i] == 10^(i-1).
|
||||
exponent_plus_one_guess++;
|
||||
// We don't have any guarantees that 2^number_bits <= number.
|
||||
// TODO(floitsch): can we change the 'while' into an 'if'? We definitely see
|
||||
// number < (2^number_bits - 1), but I haven't encountered
|
||||
// number < (2^number_bits - 2) yet.
|
||||
while (number < kSmallPowersOfTen[exponent_plus_one_guess]) {
|
||||
if (number < kSmallPowersOfTen[exponent_plus_one_guess]) {
|
||||
exponent_plus_one_guess--;
|
||||
}
|
||||
*power = kSmallPowersOfTen[exponent_plus_one_guess];
|
||||
@ -350,7 +347,8 @@ static bool DigitGen(DiyFp low,
|
||||
// that is smaller than integrals.
|
||||
while (*kappa > 0) {
|
||||
int digit = integrals / divisor;
|
||||
buffer[*length] = '0' + digit;
|
||||
ASSERT(digit <= 9);
|
||||
buffer[*length] = static_cast<char>('0' + digit);
|
||||
(*length)++;
|
||||
integrals %= divisor;
|
||||
(*kappa)--;
|
||||
@ -379,13 +377,14 @@ static bool DigitGen(DiyFp low,
|
||||
ASSERT(one.e() >= -60);
|
||||
ASSERT(fractionals < one.f());
|
||||
ASSERT(UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF) / 10 >= one.f());
|
||||
while (true) {
|
||||
for (;;) {
|
||||
fractionals *= 10;
|
||||
unit *= 10;
|
||||
unsafe_interval.set_f(unsafe_interval.f() * 10);
|
||||
// Integer division by one.
|
||||
int digit = static_cast<int>(fractionals >> -one.e());
|
||||
buffer[*length] = '0' + digit;
|
||||
ASSERT(digit <= 9);
|
||||
buffer[*length] = static_cast<char>('0' + digit);
|
||||
(*length)++;
|
||||
fractionals &= one.f() - 1; // Modulo by one.
|
||||
(*kappa)--;
|
||||
@ -459,7 +458,8 @@ static bool DigitGenCounted(DiyFp w,
|
||||
// that is smaller than 'integrals'.
|
||||
while (*kappa > 0) {
|
||||
int digit = integrals / divisor;
|
||||
buffer[*length] = '0' + digit;
|
||||
ASSERT(digit <= 9);
|
||||
buffer[*length] = static_cast<char>('0' + digit);
|
||||
(*length)++;
|
||||
requested_digits--;
|
||||
integrals %= divisor;
|
||||
@ -492,7 +492,8 @@ static bool DigitGenCounted(DiyFp w,
|
||||
w_error *= 10;
|
||||
// Integer division by one.
|
||||
int digit = static_cast<int>(fractionals >> -one.e());
|
||||
buffer[*length] = '0' + digit;
|
||||
ASSERT(digit <= 9);
|
||||
buffer[*length] = static_cast<char>('0' + digit);
|
||||
(*length)++;
|
||||
requested_digits--;
|
||||
fractionals &= one.f() - 1; // Modulo by one.
|
||||
|
@ -25,7 +25,7 @@
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <cmath>
|
||||
#include <math.h>
|
||||
|
||||
#include "fixed-dtoa.h"
|
||||
#include "ieee.h"
|
||||
@ -98,7 +98,7 @@ class UInt128 {
|
||||
return high_bits_ == 0 && low_bits_ == 0;
|
||||
}
|
||||
|
||||
int BitAt(int position) {
|
||||
int BitAt(int position) const {
|
||||
if (position >= 64) {
|
||||
return static_cast<int>(high_bits_ >> (position - 64)) & 1;
|
||||
} else {
|
||||
@ -133,7 +133,7 @@ static void FillDigits32(uint32_t number, Vector<char> buffer, int* length) {
|
||||
while (number != 0) {
|
||||
int digit = number % 10;
|
||||
number /= 10;
|
||||
buffer[(*length) + number_length] = '0' + digit;
|
||||
buffer[(*length) + number_length] = static_cast<char>('0' + digit);
|
||||
number_length++;
|
||||
}
|
||||
// Exchange the digits.
|
||||
@ -150,7 +150,7 @@ static void FillDigits32(uint32_t number, Vector<char> buffer, int* length) {
|
||||
}
|
||||
|
||||
|
||||
static void FillDigits64FixedLength(uint64_t number, int requested_length,
|
||||
static void FillDigits64FixedLength(uint64_t number,
|
||||
Vector<char> buffer, int* length) {
|
||||
const uint32_t kTen7 = 10000000;
|
||||
// For efficiency cut the number into 3 uint32_t parts, and print those.
|
||||
@ -253,12 +253,14 @@ static void FillFractionals(uint64_t fractionals, int exponent,
|
||||
fractionals *= 5;
|
||||
point--;
|
||||
int digit = static_cast<int>(fractionals >> point);
|
||||
buffer[*length] = '0' + digit;
|
||||
ASSERT(digit <= 9);
|
||||
buffer[*length] = static_cast<char>('0' + digit);
|
||||
(*length)++;
|
||||
fractionals -= static_cast<uint64_t>(digit) << point;
|
||||
}
|
||||
// If the first bit after the point is set we have to round up.
|
||||
if (((fractionals >> (point - 1)) & 1) == 1) {
|
||||
ASSERT(fractionals == 0 || point - 1 >= 0);
|
||||
if ((fractionals != 0) && ((fractionals >> (point - 1)) & 1) == 1) {
|
||||
RoundUp(buffer, length, decimal_point);
|
||||
}
|
||||
} else { // We need 128 bits.
|
||||
@ -274,7 +276,8 @@ static void FillFractionals(uint64_t fractionals, int exponent,
|
||||
fractionals128.Multiply(5);
|
||||
point--;
|
||||
int digit = fractionals128.DivModPowerOf2(point);
|
||||
buffer[*length] = '0' + digit;
|
||||
ASSERT(digit <= 9);
|
||||
buffer[*length] = static_cast<char>('0' + digit);
|
||||
(*length)++;
|
||||
}
|
||||
if (fractionals128.BitAt(point - 1) == 1) {
|
||||
@ -358,7 +361,7 @@ bool FastFixedDtoa(double v,
|
||||
remainder = (dividend % divisor) << exponent;
|
||||
}
|
||||
FillDigits32(quotient, buffer, length);
|
||||
FillDigits64FixedLength(remainder, divisor_power, buffer, length);
|
||||
FillDigits64FixedLength(remainder, buffer, length);
|
||||
*decimal_point = *length;
|
||||
} else if (exponent >= 0) {
|
||||
// 0 <= exponent <= 11
|
||||
|
@ -99,7 +99,7 @@ class Double {
|
||||
}
|
||||
|
||||
double PreviousDouble() const {
|
||||
if (d64_ == (kInfinity | kSignMask)) return -Double::Infinity();
|
||||
if (d64_ == (kInfinity | kSignMask)) return -Infinity();
|
||||
if (Sign() < 0) {
|
||||
return Double(d64_ + 1).value();
|
||||
} else {
|
||||
@ -256,6 +256,8 @@ class Double {
|
||||
return (significand & kSignificandMask) |
|
||||
(biased_exponent << kPhysicalSignificandSize);
|
||||
}
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(Double);
|
||||
};
|
||||
|
||||
class Single {
|
||||
@ -391,6 +393,8 @@ class Single {
|
||||
static const uint32_t kNaN = 0x7FC00000;
|
||||
|
||||
const uint32_t d32_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(Single);
|
||||
};
|
||||
|
||||
} // namespace kenlm_double_conversion
|
||||
|
@ -25,8 +25,8 @@
|
||||
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
#include <cstdarg>
|
||||
#include <climits>
|
||||
#include <stdarg.h>
|
||||
#include <limits.h>
|
||||
|
||||
#include "strtod.h"
|
||||
#include "bignum.h"
|
||||
@ -137,6 +137,7 @@ static void TrimAndCut(Vector<const char> buffer, int exponent,
|
||||
Vector<const char> right_trimmed = TrimTrailingZeros(left_trimmed);
|
||||
exponent += left_trimmed.length() - right_trimmed.length();
|
||||
if (right_trimmed.length() > kMaxSignificantDecimalDigits) {
|
||||
(void) space_size; // Mark variable as used.
|
||||
ASSERT(space_size >= kMaxSignificantDecimalDigits);
|
||||
CutToMaxSignificantDigits(right_trimmed, exponent,
|
||||
buffer_copy_space, updated_exponent);
|
||||
@ -263,7 +264,6 @@ static DiyFp AdjustmentPowerOfTen(int exponent) {
|
||||
case 7: return DiyFp(UINT64_2PART_C(0x98968000, 00000000), -40);
|
||||
default:
|
||||
UNREACHABLE();
|
||||
return DiyFp(0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -286,7 +286,7 @@ static bool DiyFpStrtod(Vector<const char> buffer,
|
||||
const int kDenominator = 1 << kDenominatorLog;
|
||||
// Move the remaining decimals into the exponent.
|
||||
exponent += remaining_decimals;
|
||||
int error = (remaining_decimals == 0 ? 0 : kDenominator / 2);
|
||||
uint64_t error = (remaining_decimals == 0 ? 0 : kDenominator / 2);
|
||||
|
||||
int old_e = input.e();
|
||||
input.Normalize();
|
||||
@ -506,9 +506,7 @@ float Strtof(Vector<const char> buffer, int exponent) {
|
||||
double double_previous = Double(double_guess).PreviousDouble();
|
||||
|
||||
float f1 = static_cast<float>(double_previous);
|
||||
#ifndef NDEBUG
|
||||
float f2 = float_guess;
|
||||
#endif
|
||||
float f3 = static_cast<float>(double_next);
|
||||
float f4;
|
||||
if (is_correct) {
|
||||
@ -517,9 +515,8 @@ float Strtof(Vector<const char> buffer, int exponent) {
|
||||
double double_next2 = Double(double_next).NextDouble();
|
||||
f4 = static_cast<float>(double_next2);
|
||||
}
|
||||
#ifndef NDEBUG
|
||||
(void) f2; // Mark variable as used.
|
||||
ASSERT(f1 <= f2 && f2 <= f3 && f3 <= f4);
|
||||
#endif
|
||||
|
||||
// If the guess doesn't lie near a single-precision boundary we can simply
|
||||
// return its float-value.
|
||||
|
@ -33,14 +33,29 @@
|
||||
|
||||
#include <assert.h>
|
||||
#ifndef ASSERT
|
||||
#define ASSERT(condition) (assert(condition))
|
||||
#define ASSERT(condition) \
|
||||
assert(condition);
|
||||
#endif
|
||||
#ifndef UNIMPLEMENTED
|
||||
#define UNIMPLEMENTED() (abort())
|
||||
#endif
|
||||
#ifndef DOUBLE_CONVERSION_NO_RETURN
|
||||
#ifdef _MSC_VER
|
||||
#define DOUBLE_CONVERSION_NO_RETURN __declspec(noreturn)
|
||||
#else
|
||||
#define DOUBLE_CONVERSION_NO_RETURN __attribute__((noreturn))
|
||||
#endif
|
||||
#endif
|
||||
#ifndef UNREACHABLE
|
||||
#ifdef _MSC_VER
|
||||
void DOUBLE_CONVERSION_NO_RETURN abort_noreturn();
|
||||
inline void abort_noreturn() { abort(); }
|
||||
#define UNREACHABLE() (abort_noreturn())
|
||||
#else
|
||||
#define UNREACHABLE() (abort())
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
// Double operations detection based on target architecture.
|
||||
// Linux uses a 80bit wide floating point stack on x86. This induces double
|
||||
@ -55,11 +70,17 @@
|
||||
#if defined(_M_X64) || defined(__x86_64__) || \
|
||||
defined(__ARMEL__) || defined(__avr32__) || \
|
||||
defined(__hppa__) || defined(__ia64__) || \
|
||||
defined(__mips__) || defined(__powerpc__) || \
|
||||
defined(__mips__) || \
|
||||
defined(__powerpc__) || defined(__ppc__) || defined(__ppc64__) || \
|
||||
defined(_POWER) || defined(_ARCH_PPC) || defined(_ARCH_PPC64) || \
|
||||
defined(__sparc__) || defined(__sparc) || defined(__s390__) || \
|
||||
defined(__SH4__) || defined(__alpha__) || \
|
||||
defined(_MIPS_ARCH_MIPS32R2) || defined(__aarch64__)
|
||||
defined(_MIPS_ARCH_MIPS32R2) || \
|
||||
defined(__AARCH64EL__) || defined(__aarch64__) || \
|
||||
defined(__riscv)
|
||||
#define DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS 1
|
||||
#elif defined(__mc68000__)
|
||||
#undef DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS
|
||||
#elif defined(_M_IX86) || defined(__i386__) || defined(__i386)
|
||||
#if defined(_WIN32)
|
||||
// Windows uses a 64bit wide floating point stack.
|
||||
@ -71,6 +92,11 @@
|
||||
#error Target architecture was not detected as supported by Double-Conversion.
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#define DOUBLE_CONVERSION_UNUSED __attribute__((unused))
|
||||
#else
|
||||
#define DOUBLE_CONVERSION_UNUSED
|
||||
#endif
|
||||
|
||||
#if defined(_WIN32) && !defined(__MINGW32__)
|
||||
|
||||
@ -90,6 +116,8 @@ typedef unsigned __int64 uint64_t;
|
||||
|
||||
#endif
|
||||
|
||||
typedef uint16_t uc16;
|
||||
|
||||
// The following macro works on both 32 and 64-bit platforms.
|
||||
// Usage: instead of writing 0x1234567890123456
|
||||
// write UINT64_2PART_C(0x12345678,90123456);
|
||||
@ -155,8 +183,8 @@ template <typename T>
|
||||
class Vector {
|
||||
public:
|
||||
Vector() : start_(NULL), length_(0) {}
|
||||
Vector(T* data, int length) : start_(data), length_(length) {
|
||||
ASSERT(length == 0 || (length > 0 && data != NULL));
|
||||
Vector(T* data, int len) : start_(data), length_(len) {
|
||||
ASSERT(len == 0 || (len > 0 && data != NULL));
|
||||
}
|
||||
|
||||
// Returns a vector using the same backing storage as this one,
|
||||
@ -198,8 +226,8 @@ class Vector {
|
||||
// buffer bounds on all operations in debug mode.
|
||||
class StringBuilder {
|
||||
public:
|
||||
StringBuilder(char* buffer, int size)
|
||||
: buffer_(buffer, size), position_(0) { }
|
||||
StringBuilder(char* buffer, int buffer_size)
|
||||
: buffer_(buffer, buffer_size), position_(0) { }
|
||||
|
||||
~StringBuilder() { if (!is_finalized()) Finalize(); }
|
||||
|
||||
@ -218,8 +246,7 @@ class StringBuilder {
|
||||
// 0-characters; use the Finalize() method to terminate the string
|
||||
// instead.
|
||||
void AddCharacter(char c) {
|
||||
// I just extract raw data not a cstr so null is fine.
|
||||
//ASSERT(c != '\0');
|
||||
ASSERT(c != '\0');
|
||||
ASSERT(!is_finalized() && position_ < buffer_.length());
|
||||
buffer_[position_++] = c;
|
||||
}
|
||||
@ -234,8 +261,7 @@ class StringBuilder {
|
||||
// builder. The input string must have enough characters.
|
||||
void AddSubstring(const char* s, int n) {
|
||||
ASSERT(!is_finalized() && position_ + n < buffer_.length());
|
||||
// I just extract raw data not a cstr so null is fine.
|
||||
//ASSERT(static_cast<size_t>(n) <= strlen(s));
|
||||
ASSERT(static_cast<size_t>(n) <= strlen(s));
|
||||
memmove(&buffer_[position_], s, n * kCharSize);
|
||||
position_ += n;
|
||||
}
|
||||
@ -255,8 +281,7 @@ class StringBuilder {
|
||||
buffer_[position_] = '\0';
|
||||
// Make sure nobody managed to add a 0-character to the
|
||||
// buffer while building the string.
|
||||
// I just extract raw data not a cstr so null is fine.
|
||||
//ASSERT(strlen(buffer_.start()) == static_cast<size_t>(position_));
|
||||
ASSERT(strlen(buffer_.start()) == static_cast<size_t>(position_));
|
||||
position_ = -1;
|
||||
ASSERT(is_finalized());
|
||||
return buffer_.start();
|
||||
@ -299,11 +324,8 @@ template <class Dest, class Source>
|
||||
inline Dest BitCast(const Source& source) {
|
||||
// Compile time assertion: sizeof(Dest) == sizeof(Source)
|
||||
// A compile error here means your Dest and Source have different sizes.
|
||||
typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1]
|
||||
#if __GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 8
|
||||
__attribute__((unused))
|
||||
#endif
|
||||
;
|
||||
DOUBLE_CONVERSION_UNUSED
|
||||
typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1];
|
||||
|
||||
Dest dest;
|
||||
memmove(&dest, &source, sizeof(dest));
|
||||
|
@ -134,7 +134,7 @@ class OverflowException : public Exception {
|
||||
|
||||
template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) {
|
||||
UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code.");
|
||||
return value;
|
||||
return static_cast<std::size_t>(value);
|
||||
}
|
||||
|
||||
template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) {
|
||||
|
@ -490,7 +490,7 @@ int
|
||||
mkstemp_and_unlink(char *tmpl) {
|
||||
int ret = mkstemp(tmpl);
|
||||
if (ret != -1) {
|
||||
UTIL_THROW_IF(unlink(tmpl), ErrnoException, "while deleting delete " << tmpl);
|
||||
UTIL_THROW_IF(unlink(tmpl), ErrnoException, "while deleting " << tmpl);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -103,7 +103,7 @@ class FilePiece {
|
||||
if (position_ == position_end_) {
|
||||
try {
|
||||
Shift();
|
||||
} catch (const util::EndOfFileException &e) { return false; }
|
||||
} catch (const util::EndOfFileException &) { return false; }
|
||||
// And break out at end of file.
|
||||
if (position_ == position_end_) return false;
|
||||
}
|
||||
|
@ -142,7 +142,7 @@ void UnmapOrThrow(void *start, size_t length) {
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
UTIL_THROW_IF(!::UnmapViewOfFile(start), ErrnoException, "Failed to unmap a file");
|
||||
#else
|
||||
UTIL_THROW_IF(munmap(start, length), ErrnoException, "munmap failed");
|
||||
UTIL_THROW_IF(munmap(start, length), ErrnoException, "munmap failed with " << start << " for length " << length);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -30,7 +30,7 @@ class DivMod {
|
||||
public:
|
||||
explicit DivMod(std::size_t buckets) : buckets_(buckets) {}
|
||||
|
||||
static std::size_t RoundBuckets(std::size_t from) {
|
||||
static uint64_t RoundBuckets(uint64_t from) {
|
||||
return from;
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ class Power2Mod {
|
||||
}
|
||||
|
||||
// Round up to next power of 2.
|
||||
static std::size_t RoundBuckets(std::size_t from) {
|
||||
static uint64_t RoundBuckets(uint64_t from) {
|
||||
--from;
|
||||
from |= from >> 1;
|
||||
from |= from >> 2;
|
||||
|
@ -5,10 +5,9 @@
|
||||
#include "util/spaces.hh"
|
||||
#include "util/string_piece.hh"
|
||||
|
||||
#include <boost/iterator/iterator_facade.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <iterator>
|
||||
|
||||
namespace util {
|
||||
|
||||
@ -97,12 +96,12 @@ class AnyCharacterLast {
|
||||
StringPiece chars_;
|
||||
};
|
||||
|
||||
template <class Find, bool SkipEmpty = false> class TokenIter : public boost::iterator_facade<TokenIter<Find, SkipEmpty>, const StringPiece, boost::forward_traversal_tag> {
|
||||
template <class Find, bool SkipEmpty = false> class TokenIter : public std::iterator<std::forward_iterator_tag, const StringPiece, std::ptrdiff_t, const StringPiece *, const StringPiece &> {
|
||||
public:
|
||||
TokenIter() {}
|
||||
|
||||
template <class Construct> TokenIter(const StringPiece &str, const Construct &construct) : after_(str), finder_(construct) {
|
||||
increment();
|
||||
++*this;
|
||||
}
|
||||
|
||||
bool operator!() const {
|
||||
@ -116,10 +115,15 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
|
||||
return TokenIter<Find, SkipEmpty>();
|
||||
}
|
||||
|
||||
private:
|
||||
friend class boost::iterator_core_access;
|
||||
bool operator==(const TokenIter<Find, SkipEmpty> &other) const {
|
||||
return current_.data() == other.current_.data();
|
||||
}
|
||||
|
||||
void increment() {
|
||||
bool operator!=(const TokenIter<Find, SkipEmpty> &other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
TokenIter<Find, SkipEmpty> &operator++() {
|
||||
do {
|
||||
StringPiece found(finder_.Find(after_));
|
||||
current_ = StringPiece(after_.data(), found.data() - after_.data());
|
||||
@ -129,17 +133,25 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
|
||||
after_ = StringPiece(found.data() + found.size(), after_.data() - found.data() + after_.size() - found.size());
|
||||
}
|
||||
} while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false.
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool equal(const TokenIter<Find, SkipEmpty> &other) const {
|
||||
return current_.data() == other.current_.data();
|
||||
TokenIter<Find, SkipEmpty> &operator++(int) {
|
||||
TokenIter<Find, SkipEmpty> ret(*this);
|
||||
++*this;
|
||||
return ret;
|
||||
}
|
||||
|
||||
const StringPiece &dereference() const {
|
||||
const StringPiece &operator*() const {
|
||||
UTIL_THROW_IF(!current_.data(), OutOfTokens, "Ran out of tokens");
|
||||
return current_;
|
||||
}
|
||||
const StringPiece *operator->() const {
|
||||
UTIL_THROW_IF(!current_.data(), OutOfTokens, "Ran out of tokens");
|
||||
return ¤t_;
|
||||
}
|
||||
|
||||
private:
|
||||
StringPiece current_;
|
||||
StringPiece after_;
|
||||
|
||||
|
@ -16,7 +16,7 @@ struct ModelState {
|
||||
static constexpr unsigned int BATCH_SIZE = 1;
|
||||
|
||||
Alphabet alphabet_;
|
||||
std::unique_ptr<Scorer> scorer_;
|
||||
std::shared_ptr<Scorer> scorer_;
|
||||
unsigned int beam_width_;
|
||||
unsigned int n_steps_;
|
||||
unsigned int n_context_;
|
||||
|
@ -21,7 +21,6 @@ import deepspeech
|
||||
|
||||
# rename for backwards compatibility
|
||||
from deepspeech.impl import PrintVersions as printVersions
|
||||
from deepspeech.impl import FreeStream as freeStream
|
||||
|
||||
class Model(object):
|
||||
"""
|
||||
@ -56,127 +55,163 @@ class Model(object):
|
||||
"""
|
||||
return deepspeech.impl.GetModelSampleRate(self._impl)
|
||||
|
||||
def enableDecoderWithLM(self, *args, **kwargs):
|
||||
def enableExternalScorer(self, scorer_path):
|
||||
"""
|
||||
Enable decoding using beam scoring with a KenLM language model.
|
||||
Enable decoding using an external scorer.
|
||||
|
||||
:param aLMPath: The path to the language model binary file.
|
||||
:type aLMPath: str
|
||||
:param scorer_path: The path to the external scorer file.
|
||||
:type scorer_path: str
|
||||
|
||||
:param aTriePath: The path to the trie file build from the same vocabulary as the language model binary.
|
||||
:type aTriePath: str
|
||||
|
||||
:param aLMAlpha: The alpha hyperparameter of the CTC decoder. Language Model weight.
|
||||
:type aLMAlpha: float
|
||||
|
||||
:param aLMBeta: The beta hyperparameter of the CTC decoder. Word insertion weight.
|
||||
:type aLMBeta: float
|
||||
|
||||
:return: Zero on success, non-zero on failure (invalid arguments).
|
||||
:return: Zero on success, non-zero on failure.
|
||||
:type: int
|
||||
"""
|
||||
return deepspeech.impl.EnableDecoderWithLM(self._impl, *args, **kwargs)
|
||||
return deepspeech.impl.EnableExternalScorer(self._impl, scorer_path)
|
||||
|
||||
def stt(self, *args, **kwargs):
|
||||
def disableExternalScorer(self):
|
||||
"""
|
||||
Disable decoding using an external scorer.
|
||||
|
||||
:return: Zero on success, non-zero on failure.
|
||||
"""
|
||||
return deepspeech.impl.DisableExternalScorer(self._impl)
|
||||
|
||||
def setScorerAlphaBeta(self, alpha, beta):
|
||||
"""
|
||||
Set hyperparameters alpha and beta of the external scorer.
|
||||
|
||||
:param alpha: The alpha hyperparameter of the decoder. Language model weight.
|
||||
:type alpha: float
|
||||
|
||||
:param beta: The beta hyperparameter of the decoder. Word insertion weight.
|
||||
:type beta: float
|
||||
|
||||
:return: Zero on success, non-zero on failure.
|
||||
:type: int
|
||||
"""
|
||||
return deepspeech.impl.SetScorerAlphaBeta(self._impl, alpha, beta)
|
||||
|
||||
def stt(self, audio_buffer):
|
||||
"""
|
||||
Use the DeepSpeech model to perform Speech-To-Text.
|
||||
|
||||
:param aBuffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
:type aBuffer: int array
|
||||
|
||||
:param aBufferSize: The number of samples in the audio signal.
|
||||
:type aBufferSize: int
|
||||
:param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
:type audio_buffer: numpy.int16 array
|
||||
|
||||
:return: The STT result.
|
||||
:type: str
|
||||
"""
|
||||
return deepspeech.impl.SpeechToText(self._impl, *args, **kwargs)
|
||||
return deepspeech.impl.SpeechToText(self._impl, audio_buffer)
|
||||
|
||||
def sttWithMetadata(self, *args, **kwargs):
|
||||
def sttWithMetadata(self, audio_buffer):
|
||||
"""
|
||||
Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results.
|
||||
|
||||
:param aBuffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
:type aBuffer: int array
|
||||
|
||||
:param aBufferSize: The number of samples in the audio signal.
|
||||
:type aBufferSize: int
|
||||
:param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
:type audio_buffer: numpy.int16 array
|
||||
|
||||
:return: Outputs a struct of individual letters along with their timing information.
|
||||
:type: :func:`Metadata`
|
||||
"""
|
||||
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, *args, **kwargs)
|
||||
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer)
|
||||
|
||||
def createStream(self):
|
||||
"""
|
||||
Create a new streaming inference state. The streaming state returned
|
||||
by this function can then be passed to :func:`feedAudioContent()` and :func:`finishStream()`.
|
||||
Create a new streaming inference state. The streaming state returned by
|
||||
this function can then be passed to :func:`feedAudioContent()` and :func:`finishStream()`.
|
||||
|
||||
:return: Object holding the stream
|
||||
:return: Stream object representing the newly created stream
|
||||
:type: :func:`Stream`
|
||||
|
||||
:throws: RuntimeError on error
|
||||
"""
|
||||
status, ctx = deepspeech.impl.CreateStream(self._impl)
|
||||
if status != 0:
|
||||
raise RuntimeError("CreateStream failed with error code {}".format(status))
|
||||
return ctx
|
||||
return Stream(ctx)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def feedAudioContent(self, *args, **kwargs):
|
||||
|
||||
class Stream(object):
|
||||
"""
|
||||
Class wrapping a DeepSpeech stream. The constructor cannot be called directly.
|
||||
Use :func:`Model.createStream()`
|
||||
"""
|
||||
def __init__(self, native_stream):
|
||||
self._impl = native_stream
|
||||
|
||||
def __del__(self):
|
||||
if self._impl:
|
||||
self.freeStream()
|
||||
|
||||
def feedAudioContent(self, audio_buffer):
|
||||
"""
|
||||
Feed audio samples to an ongoing streaming inference.
|
||||
|
||||
:param aSctx: A streaming state pointer returned by :func:`createStream()`.
|
||||
:type aSctx: object
|
||||
:param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
|
||||
:type audio_buffer: numpy.int16 array
|
||||
|
||||
:param aBuffer: An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on).
|
||||
:type aBuffer: int array
|
||||
|
||||
:param aBufferSize: The number of samples in @p aBuffer.
|
||||
:type aBufferSize: int
|
||||
:throws: RuntimeError if the stream object is not valid
|
||||
"""
|
||||
deepspeech.impl.FeedAudioContent(*args, **kwargs)
|
||||
if not self._impl:
|
||||
raise RuntimeError("Stream object is not valid. Trying to feed an already finished stream?")
|
||||
deepspeech.impl.FeedAudioContent(self._impl, audio_buffer)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def intermediateDecode(self, *args, **kwargs):
|
||||
def intermediateDecode(self):
|
||||
"""
|
||||
Compute the intermediate decoding of an ongoing streaming inference.
|
||||
|
||||
:param aSctx: A streaming state pointer returned by :func:`createStream()`.
|
||||
:type aSctx: object
|
||||
|
||||
:return: The STT intermediate result.
|
||||
:type: str
|
||||
"""
|
||||
return deepspeech.impl.IntermediateDecode(*args, **kwargs)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def finishStream(self, *args, **kwargs):
|
||||
:throws: RuntimeError if the stream object is not valid
|
||||
"""
|
||||
Signal the end of an audio signal to an ongoing streaming
|
||||
inference, returns the STT result over the whole audio signal.
|
||||
if not self._impl:
|
||||
raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
|
||||
return deepspeech.impl.IntermediateDecode(self._impl)
|
||||
|
||||
:param aSctx: A streaming state pointer returned by :func:`createStream()`.
|
||||
:type aSctx: object
|
||||
def finishStream(self):
|
||||
"""
|
||||
Signal the end of an audio signal to an ongoing streaming inference,
|
||||
returns the STT result over the whole audio signal.
|
||||
|
||||
:return: The STT result.
|
||||
:type: str
|
||||
"""
|
||||
return deepspeech.impl.FinishStream(*args, **kwargs)
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def finishStreamWithMetadata(self, *args, **kwargs):
|
||||
:throws: RuntimeError if the stream object is not valid
|
||||
"""
|
||||
Signal the end of an audio signal to an ongoing streaming
|
||||
inference, returns per-letter metadata.
|
||||
if not self._impl:
|
||||
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
|
||||
result = deepspeech.impl.FinishStream(self._impl)
|
||||
self._impl = None
|
||||
return result
|
||||
|
||||
:param aSctx: A streaming state pointer returned by :func:`createStream()`.
|
||||
:type aSctx: object
|
||||
def finishStreamWithMetadata(self):
|
||||
"""
|
||||
Signal the end of an audio signal to an ongoing streaming inference,
|
||||
returns per-letter metadata.
|
||||
|
||||
:return: Outputs a struct of individual letters along with their timing information.
|
||||
:type: :func:`Metadata`
|
||||
|
||||
:throws: RuntimeError if the stream object is not valid
|
||||
"""
|
||||
return deepspeech.impl.FinishStreamWithMetadata(*args, **kwargs)
|
||||
if not self._impl:
|
||||
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
|
||||
result = deepspeech.impl.FinishStreamWithMetadata(self._impl)
|
||||
self._impl = None
|
||||
return result
|
||||
|
||||
def freeStream(self):
|
||||
"""
|
||||
Destroy a streaming state without decoding the computed logits. This can
|
||||
be used if you no longer need the result of an ongoing streaming inference.
|
||||
|
||||
:throws: RuntimeError if the stream object is not valid
|
||||
"""
|
||||
if not self._impl:
|
||||
raise RuntimeError("Stream object is not valid. Trying to free an already finished stream?")
|
||||
deepspeech.impl.FreeStream(self._impl)
|
||||
self._impl = None
|
||||
|
||||
|
||||
# This is only for documentation purpose
|
||||
# Metadata and MetadataItem should be in sync with native_client/deepspeech.h
|
||||
@ -189,22 +224,18 @@ class MetadataItem(object):
|
||||
"""
|
||||
The character generated for transcription
|
||||
"""
|
||||
# pylint: disable=unnecessary-pass
|
||||
pass
|
||||
|
||||
|
||||
def timestep(self):
|
||||
"""
|
||||
Position of the character in units of 20ms
|
||||
"""
|
||||
# pylint: disable=unnecessary-pass
|
||||
pass
|
||||
|
||||
|
||||
def start_time(self):
|
||||
"""
|
||||
Position of the character in seconds
|
||||
"""
|
||||
# pylint: disable=unnecessary-pass
|
||||
pass
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
@ -218,8 +249,7 @@ class Metadata(object):
|
||||
:return: A list of :func:`MetadataItem` elements
|
||||
:type: list
|
||||
"""
|
||||
# pylint: disable=unnecessary-pass
|
||||
pass
|
||||
|
||||
|
||||
def num_items(self):
|
||||
"""
|
||||
@ -228,8 +258,7 @@ class Metadata(object):
|
||||
:return: Size of the list of items
|
||||
:type: int
|
||||
"""
|
||||
# pylint: disable=unnecessary-pass
|
||||
pass
|
||||
|
||||
|
||||
def confidence(self):
|
||||
"""
|
||||
@ -237,5 +266,4 @@ class Metadata(object):
|
||||
sum of the acoustic model logit values for each timestep/character that
|
||||
contributed to the creation of this transcription.
|
||||
"""
|
||||
# pylint: disable=unnecessary-pass
|
||||
pass
|
||||
|
||||
|
@ -72,7 +72,7 @@ def metadata_json_output(metadata):
|
||||
json_result["words"] = words_from_metadata(metadata)
|
||||
json_result["confidence"] = metadata.confidence
|
||||
return json.dumps(json_result)
|
||||
|
||||
|
||||
|
||||
|
||||
class VersionAction(argparse.Action):
|
||||
@ -88,18 +88,16 @@ def main():
|
||||
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to the model (protocol buffer binary file)')
|
||||
parser.add_argument('--lm', nargs='?',
|
||||
help='Path to the language model binary file')
|
||||
parser.add_argument('--trie', nargs='?',
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--scorer', required=False,
|
||||
help='Path to the external scorer file')
|
||||
parser.add_argument('--audio', required=True,
|
||||
help='Path to the audio file to run (WAV format)')
|
||||
parser.add_argument('--beam_width', type=int, default=500,
|
||||
help='Beam width for the CTC decoder')
|
||||
parser.add_argument('--lm_alpha', type=float, default=0.75,
|
||||
help='Language model weight (lm_alpha)')
|
||||
parser.add_argument('--lm_beta', type=float, default=1.85,
|
||||
help='Word insertion bonus (lm_beta)')
|
||||
parser.add_argument('--lm_alpha', type=float,
|
||||
help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
|
||||
parser.add_argument('--lm_beta', type=float,
|
||||
help='Word insertion bonus (lm_beta). If not specified, use default from the scorer package.')
|
||||
parser.add_argument('--version', action=VersionAction,
|
||||
help='Print version and exits')
|
||||
parser.add_argument('--extended', required=False, action='store_true',
|
||||
@ -116,12 +114,15 @@ def main():
|
||||
|
||||
desired_sample_rate = ds.sampleRate()
|
||||
|
||||
if args.lm and args.trie:
|
||||
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
|
||||
lm_load_start = timer()
|
||||
ds.enableDecoderWithLM(args.lm, args.trie, args.lm_alpha, args.lm_beta)
|
||||
lm_load_end = timer() - lm_load_start
|
||||
print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr)
|
||||
if args.scorer:
|
||||
print('Loading scorer from files {}'.format(args.scorer), file=sys.stderr)
|
||||
scorer_load_start = timer()
|
||||
ds.enableExternalScorer(args.scorer)
|
||||
scorer_load_end = timer() - scorer_load_start
|
||||
print('Loaded scorer in {:.3}s.'.format(scorer_load_end), file=sys.stderr)
|
||||
|
||||
if args.lm_alpha and args.lm_beta:
|
||||
ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)
|
||||
|
||||
fin = wave.open(args.audio, 'rb')
|
||||
fs = fin.getframerate()
|
||||
|
@ -14,21 +14,13 @@ from deepspeech import Model
|
||||
# Beam width used in the CTC decoder when building candidate transcriptions
|
||||
BEAM_WIDTH = 500
|
||||
|
||||
# The alpha hyperparameter of the CTC decoder. Language Model weight
|
||||
LM_ALPHA = 0.75
|
||||
|
||||
# The beta hyperparameter of the CTC decoder. Word insertion bonus.
|
||||
LM_BETA = 1.85
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
|
||||
parser.add_argument('--model', required=True,
|
||||
help='Path to the model (protocol buffer binary file)')
|
||||
parser.add_argument('--lm', nargs='?',
|
||||
help='Path to the language model binary file')
|
||||
parser.add_argument('--trie', nargs='?',
|
||||
help='Path to the language model trie file created with native_client/generate_trie')
|
||||
parser.add_argument('--scorer', nargs='?',
|
||||
help='Path to the external scorer file')
|
||||
parser.add_argument('--audio1', required=True,
|
||||
help='First audio file to use in interleaved streams')
|
||||
parser.add_argument('--audio2', required=True,
|
||||
@ -37,8 +29,8 @@ def main():
|
||||
|
||||
ds = Model(args.model, BEAM_WIDTH)
|
||||
|
||||
if args.lm and args.trie:
|
||||
ds.enableDecoderWithLM(args.lm, args.trie, LM_ALPHA, LM_BETA)
|
||||
if args.scorer:
|
||||
ds.enableExternalScorer(args.scorer)
|
||||
|
||||
fin = wave.open(args.audio1, 'rb')
|
||||
fs1 = fin.getframerate()
|
||||
@ -57,11 +49,11 @@ def main():
|
||||
splits2 = np.array_split(audio2, 10)
|
||||
|
||||
for part1, part2 in zip(splits1, splits2):
|
||||
ds.feedAudioContent(stream1, part1)
|
||||
ds.feedAudioContent(stream2, part2)
|
||||
stream1.feedAudioContent(part1)
|
||||
stream2.feedAudioContent(part2)
|
||||
|
||||
print(ds.finishStream(stream1))
|
||||
print(ds.finishStream(stream2))
|
||||
print(stream1.finishStream())
|
||||
print(stream2.finishStream())
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -27,9 +27,9 @@ int main(int argc, char** argv)
|
||||
return err;
|
||||
}
|
||||
Scorer scorer;
|
||||
|
||||
err = scorer.init(kenlm_path, alphabet);
|
||||
#ifndef DEBUG
|
||||
return scorer.init(0.0, 0.0, kenlm_path, trie_path, alphabet);
|
||||
return err;
|
||||
#else
|
||||
// Print some info about the FST
|
||||
using FstType = fst::ConstFst<fst::StdArc>;
|
||||
@ -60,7 +60,6 @@ int main(int argc, char** argv)
|
||||
// for (int i = 1; i < 10; ++i) {
|
||||
// print_states_from(i);
|
||||
// }
|
||||
#endif // DEBUG
|
||||
|
||||
return 0;
|
||||
#endif // DEBUG
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
|
||||
|
||||
BAZEL_TARGETS="
|
||||
//native_client:libdeepspeech.so
|
||||
//native_client:generate_trie
|
||||
"
|
||||
|
||||
BAZEL_BUILD_FLAGS="${BAZEL_ARM64_FLAGS} ${BAZEL_EXTRA_FLAGS}"
|
||||
|
@ -8,7 +8,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
|
||||
|
||||
BAZEL_TARGETS="
|
||||
//native_client:libdeepspeech.so
|
||||
//native_client:generate_trie
|
||||
"
|
||||
|
||||
BAZEL_ENV_FLAGS="TF_NEED_CUDA=1 ${TF_CUDA_FLAGS}"
|
||||
|
@ -30,11 +30,11 @@ then:
|
||||
image: ${build.docker_image}
|
||||
|
||||
env:
|
||||
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.15/models.tar.gz"
|
||||
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models.tar.gz"
|
||||
DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz"
|
||||
PIP_DEFAULT_TIMEOUT: "60"
|
||||
EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples"
|
||||
EXAMPLES_CHECKOUT_TARGET: "master"
|
||||
EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a"
|
||||
|
||||
command:
|
||||
- "/bin/bash"
|
||||
|
@ -10,7 +10,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
|
||||
|
||||
BAZEL_TARGETS="
|
||||
//native_client:libdeepspeech.so
|
||||
//native_client:generate_trie
|
||||
"
|
||||
|
||||
if [ "${runtime}" = "tflite" ]; then
|
||||
|
@ -8,7 +8,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
|
||||
|
||||
BAZEL_TARGETS="
|
||||
//native_client:libdeepspeech.so
|
||||
//native_client:generate_trie
|
||||
"
|
||||
|
||||
BAZEL_BUILD_FLAGS="${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS}"
|
||||
|
@ -49,7 +49,7 @@ deepspeech --version
|
||||
|
||||
pushd ${HOME}/DeepSpeech/ds/
|
||||
python bin/import_ldc93s1.py data/smoke_test
|
||||
python evaluate_tflite.py --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --lm data/smoke_test/vocab.pruned.lm --trie data/smoke_test/vocab.trie --csv data/smoke_test/ldc93s1.csv
|
||||
python evaluate_tflite.py --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --scorer data/smoke_test/pruned_lm.scorer --csv data/smoke_test/ldc93s1.csv
|
||||
popd
|
||||
|
||||
virtualenv_deactivate "${pyalias}" "${PYENV_NAME}"
|
||||
|
@ -378,7 +378,7 @@ run_netframework_inference_tests()
|
||||
assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm=$(DeepSpeechConsole.exe --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm=$(DeepSpeechConsole.exe --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
set -e
|
||||
assert_working_ldc93s1_lm "${phrase_pbmodel_withlm}" "$?"
|
||||
}
|
||||
@ -401,7 +401,7 @@ run_electronjs_inference_tests()
|
||||
assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
set -e
|
||||
assert_working_ldc93s1_lm "${phrase_pbmodel_withlm}" "$?"
|
||||
}
|
||||
@ -427,7 +427,7 @@ run_basic_inference_tests()
|
||||
assert_correct_ldc93s1 "${phrase_pbmodel_nolm}" "$status"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm}" "$status"
|
||||
@ -444,7 +444,7 @@ run_all_inference_tests()
|
||||
assert_correct_ldc93s1 "${phrase_pbmodel_nolm_stereo_44k}" "$status"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_stereo_44k}" "$status"
|
||||
@ -457,7 +457,7 @@ run_all_inference_tests()
|
||||
assert_correct_warning_upsampling "${phrase_pbmodel_nolm_mono_8k}"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||
set -e
|
||||
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
||||
fi;
|
||||
@ -470,8 +470,7 @@ run_prod_concurrent_stream_tests()
|
||||
set +e
|
||||
output=$(python ${TASKCLUSTER_TMP_DIR}/test_sources/concurrent_streams.py \
|
||||
--model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} \
|
||||
--lm ${TASKCLUSTER_TMP_DIR}/lm.binary \
|
||||
--trie ${TASKCLUSTER_TMP_DIR}/trie \
|
||||
--scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer \
|
||||
--audio1 ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_16000.wav \
|
||||
--audio2 ${TASKCLUSTER_TMP_DIR}/new-home-in-the-stars-16k.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
@ -489,19 +488,19 @@ run_prod_inference_tests()
|
||||
local _bitrate=$1
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_prodmodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status" "${_bitrate}"
|
||||
@ -509,7 +508,7 @@ run_prod_inference_tests()
|
||||
# Run down-sampling warning test only when we actually perform downsampling
|
||||
if [ "${ldc93s1_sample_filename}" != "LDC93S1_pcms16le_1_8000.wav" ]; then
|
||||
set +e
|
||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||
set -e
|
||||
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
||||
fi;
|
||||
@ -520,19 +519,19 @@ run_prodtflite_inference_tests()
|
||||
local _bitrate=$1
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_prodtflitemodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_prodtflitemodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
|
||||
|
||||
set +e
|
||||
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_prodtflitemodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status" "${_bitrate}"
|
||||
@ -540,7 +539,7 @@ run_prodtflite_inference_tests()
|
||||
# Run down-sampling warning test only when we actually perform downsampling
|
||||
if [ "${ldc93s1_sample_filename}" != "LDC93S1_pcms16le_1_8000.wav" ]; then
|
||||
set +e
|
||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||
set -e
|
||||
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
||||
fi;
|
||||
@ -555,7 +554,7 @@ run_multi_inference_tests()
|
||||
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_nolm}" "$status"
|
||||
|
||||
set +e -o pipefail
|
||||
multi_phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/ 2>${TASKCLUSTER_TMP_DIR}/stderr | tr '\n' '%')
|
||||
multi_phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/ 2>${TASKCLUSTER_TMP_DIR}/stderr | tr '\n' '%')
|
||||
status=$?
|
||||
set -e +o pipefail
|
||||
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status"
|
||||
@ -564,7 +563,7 @@ run_multi_inference_tests()
|
||||
run_cpp_only_inference_tests()
|
||||
{
|
||||
set +e
|
||||
phrase_pbmodel_withlm_intermediate_decode=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} --stream 1280 2>${TASKCLUSTER_TMP_DIR}/stderr | tail -n 1)
|
||||
phrase_pbmodel_withlm_intermediate_decode=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} --stream 1280 2>${TASKCLUSTER_TMP_DIR}/stderr | tail -n 1)
|
||||
status=$?
|
||||
set -e
|
||||
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_intermediate_decode}" "$status"
|
||||
@ -669,8 +668,7 @@ download_data()
|
||||
${WGET} -P "${TASKCLUSTER_TMP_DIR}" "${model_source}"
|
||||
${WGET} -P "${TASKCLUSTER_TMP_DIR}" "${model_source_mmap}"
|
||||
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/*.wav ${TASKCLUSTER_TMP_DIR}/
|
||||
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/vocab.pruned.lm ${TASKCLUSTER_TMP_DIR}/lm.binary
|
||||
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/vocab.trie ${TASKCLUSTER_TMP_DIR}/trie
|
||||
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/pruned_lm.scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer
|
||||
cp -R ${DS_ROOT_TASK}/DeepSpeech/ds/native_client/test ${TASKCLUSTER_TMP_DIR}/test_sources
|
||||
}
|
||||
|
||||
@ -1562,7 +1560,6 @@ package_native_client()
|
||||
fi;
|
||||
|
||||
${TAR} -cf - \
|
||||
-C ${tensorflow_dir}/bazel-bin/native_client/ generate_trie${PLATFORM_EXE_SUFFIX} \
|
||||
-C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so \
|
||||
-C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so.if.lib \
|
||||
-C ${deepspeech_dir}/ LICENSE \
|
||||
@ -1767,8 +1764,7 @@ android_setup_apk_data()
|
||||
adb push \
|
||||
${TASKCLUSTER_TMP_DIR}/${model_name} \
|
||||
${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \
|
||||
${TASKCLUSTER_TMP_DIR}/lm.binary \
|
||||
${TASKCLUSTER_TMP_DIR}/trie \
|
||||
${TASKCLUSTER_TMP_DIR}/kenlm.scorer \
|
||||
${ANDROID_TMP_DIR}/test/
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
|
||||
|
||||
BAZEL_TARGETS="
|
||||
//native_client:libdeepspeech.so
|
||||
//native_client:generate_trie
|
||||
"
|
||||
|
||||
if [ "${package_option}" = "--cuda" ]; then
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user