Merge branch 'decoder-api-changes' (PR #2681)

This commit is contained in:
Reuben Morais 2020-02-11 21:18:57 +01:00
commit 5366f90375
103 changed files with 4836 additions and 986 deletions

4
.gitattributes vendored
View File

@ -1,3 +1 @@
*.binary filter=lfs diff=lfs merge=lfs -crlf data/lm/kenlm.scorer filter=lfs diff=lfs merge=lfs -text
data/lm/trie filter=lfs diff=lfs merge=lfs -crlf
data/lm/vocab.txt filter=lfs diff=lfs merge=lfs -text

View File

@ -7,7 +7,7 @@ extension-pkg-whitelist=
# Add files or directories to the blacklist. They should be base names, not # Add files or directories to the blacklist. They should be base names, not
# paths. # paths.
ignore=examples ignore=native_client/kenlm
# Add files or directories matching the regex patterns to the blacklist. The # Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths. # regex matches against base names, not paths.

View File

@ -882,8 +882,7 @@ def package_zip():
} }
}, f) }, f)
shutil.copy(FLAGS.lm_binary_path, export_dir) shutil.copy(FLAGS.scorer_path, export_dir)
shutil.copy(FLAGS.lm_trie_path, export_dir)
archive = shutil.make_archive(zip_filename, 'zip', export_dir) archive = shutil.make_archive(zip_filename, 'zip', export_dir)
log_info('Exported packaged model {}'.format(archive)) log_info('Exported packaged model {}'.format(archive))
@ -926,10 +925,9 @@ def do_single_file_inference(input_file_path):
logits = np.squeeze(logits) logits = np.squeeze(logits)
if FLAGS.lm_binary_path: if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.lm_binary_path, FLAGS.lm_trie_path, FLAGS.scorer_path, Config.alphabet)
Config.alphabet)
else: else:
scorer = None scorer = None
decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width, decoded = ctc_beam_search_decoder(logits, Config.alphabet, FLAGS.beam_width,

View File

@ -172,7 +172,7 @@ RUN ./configure
# Build DeepSpeech # Build DeepSpeech
RUN bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic --config=cuda -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-mtune=generic --copt=-march=x86-64 --copt=-msse --copt=-msse2 --copt=-msse3 --copt=-msse4.1 --copt=-msse4.2 --copt=-mavx --copt=-fvisibility=hidden //native_client:libdeepspeech.so //native_client:generate_trie --verbose_failures --action_env=LD_LIBRARY_PATH=${LD_LIBRARY_PATH} RUN bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic --config=cuda -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-mtune=generic --copt=-march=x86-64 --copt=-msse --copt=-msse2 --copt=-msse3 --copt=-msse4.1 --copt=-msse4.2 --copt=-mavx --copt=-fvisibility=hidden //native_client:libdeepspeech.so --verbose_failures --action_env=LD_LIBRARY_PATH=${LD_LIBRARY_PATH}
### ###
### Using TensorFlow upstream should work ### Using TensorFlow upstream should work
@ -187,8 +187,7 @@ RUN bazel build --workspace_status_command="bash native_client/bazel_workspace_s
# RUN pip3 install /tmp/tensorflow_pkg/*.whl # RUN pip3 install /tmp/tensorflow_pkg/*.whl
# Copy built libs to /DeepSpeech/native_client # Copy built libs to /DeepSpeech/native_client
RUN cp /tensorflow/bazel-bin/native_client/generate_trie /DeepSpeech/native_client/ \ RUN cp /tensorflow/bazel-bin/native_client/libdeepspeech.so /DeepSpeech/native_client/
&& cp /tensorflow/bazel-bin/native_client/libdeepspeech.so /DeepSpeech/native_client/
# Install TensorFlow # Install TensorFlow
WORKDIR /DeepSpeech/ WORKDIR /DeepSpeech/

View File

@ -36,7 +36,7 @@ To install and use deepspeech all you have to do is:
tar xvf audio-0.6.1.tar.gz tar xvf audio-0.6.1.tar.gz
# Transcribe an audio file # Transcribe an audio file
deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --lm deepspeech-0.6.1-models/lm.binary --trie deepspeech-0.6.1-models/trie --audio audio/2830-3980-0043.wav deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --scorer deepspeech-0.6.1-models/kenlm.scorer --audio audio/2830-3980-0043.wav
A pre-trained English model is available for use and can be downloaded using `the instructions below <doc/USING.rst#using-a-pre-trained-model>`_. A package with some example audio files is available for download in our `release notes <https://github.com/mozilla/DeepSpeech/releases/latest>`_. A pre-trained English model is available for use and can be downloaded using `the instructions below <doc/USING.rst#using-a-pre-trained-model>`_. A package with some example audio files is available for download in our `release notes <https://github.com/mozilla/DeepSpeech/releases/latest>`_.
@ -52,7 +52,7 @@ Quicker inference can be performed using a supported NVIDIA GPU on Linux. See th
pip3 install deepspeech-gpu pip3 install deepspeech-gpu
# Transcribe an audio file. # Transcribe an audio file.
deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --lm deepspeech-0.6.1-models/lm.binary --trie deepspeech-0.6.1-models/trie --audio audio/2830-3980-0043.wav deepspeech --model deepspeech-0.6.1-models/output_graph.pbmm --scorer deepspeech-0.6.1-models/kenlm.scorer --audio audio/2830-3980-0043.wav
Please ensure you have the required `CUDA dependencies <doc/USING.rst#cuda-dependency>`_. Please ensure you have the required `CUDA dependencies <doc/USING.rst#cuda-dependency>`_.

View File

@ -21,8 +21,7 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--n_hidden 100 --epochs 1 \ --n_hidden 100 --epochs 1 \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \ --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
--learning_rate 0.001 --dropout_rate 0.05 \ --learning_rate 0.001 --dropout_rate 0.05 \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --scorer_path 'data/smoke_test/pruned_lm.scorer' | tee /tmp/resume.log
--lm_trie_path 'data/smoke_test/vocab.trie' | tee /tmp/resume.log
if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then if ! grep "Restored variables from most recent checkpoint" /tmp/resume.log; then
echo "Did not resume training from checkpoint" echo "Did not resume training from checkpoint"

View File

@ -25,6 +25,5 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--n_hidden 100 --epochs $epoch_count \ --n_hidden 100 --epochs $epoch_count \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \ --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' \
--learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train' \ --learning_rate 0.001 --dropout_rate 0.05 --export_dir '/tmp/train' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --scorer_path 'data/smoke_test/pruned_lm.scorer' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--audio_sample_rate ${audio_sample_rate} --audio_sample_rate ${audio_sample_rate}

View File

@ -21,12 +21,10 @@ python -u DeepSpeech.py --noshow_progressbar --noearly_stop \
--n_hidden 100 --epochs 1 \ --n_hidden 100 --epochs 1 \
--max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \ --max_to_keep 1 --checkpoint_dir '/tmp/ckpt' --checkpoint_secs 0 \
--learning_rate 0.001 --dropout_rate 0.05 \ --learning_rate 0.001 --dropout_rate 0.05 \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --scorer_path 'data/smoke_test/pruned_lm.scorer'
--lm_trie_path 'data/smoke_test/vocab.trie'
python -u DeepSpeech.py \ python -u DeepSpeech.py \
--n_hidden 100 \ --n_hidden 100 \
--checkpoint_dir '/tmp/ckpt' \ --checkpoint_dir '/tmp/ckpt' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --scorer_path 'data/smoke_test/pruned_lm.scorer' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--one_shot_infer 'data/smoke_test/LDC93S1.wav' --one_shot_infer 'data/smoke_test/LDC93S1.wav'

View File

@ -20,8 +20,7 @@ python -u DeepSpeech.py --noshow_progressbar \
--n_hidden 100 \ --n_hidden 100 \
--checkpoint_dir '/tmp/ckpt' \ --checkpoint_dir '/tmp/ckpt' \
--export_dir '/tmp/train_tflite' \ --export_dir '/tmp/train_tflite' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --scorer_path 'data/smoke_test/pruned_lm.scorer' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--audio_sample_rate ${audio_sample_rate} \ --audio_sample_rate ${audio_sample_rate} \
--export_tflite --export_tflite
@ -31,8 +30,7 @@ python -u DeepSpeech.py --noshow_progressbar \
--n_hidden 100 \ --n_hidden 100 \
--checkpoint_dir '/tmp/ckpt' \ --checkpoint_dir '/tmp/ckpt' \
--export_dir '/tmp/train_tflite/en-us' \ --export_dir '/tmp/train_tflite/en-us' \
--lm_binary_path 'data/smoke_test/vocab.pruned.lm' \ --scorer_path 'data/smoke_test/pruned_lm.scorer' \
--lm_trie_path 'data/smoke_test/vocab.trie' \
--audio_sample_rate ${audio_sample_rate} \ --audio_sample_rate ${audio_sample_rate} \
--export_language 'Fake English (fk-FK)' \ --export_language 'Fake English (fk-FK)' \
--export_zip --export_zip

View File

@ -5,9 +5,7 @@ This directory contains language-specific data files. Most importantly, you will
1. A list of unique characters for the target language (e.g. English) in `data/alphabet.txt` 1. A list of unique characters for the target language (e.g. English) in `data/alphabet.txt`
2. A binary n-gram language model compiled by `kenlm` in `data/lm/lm.binary` 2. A scorer package (`data/lm/kenlm.scorer`) generated with `data/lm/generate_package.py`. The scorer package includes a binary n-gram language model generated with `data/lm/generate_lm.py`.
3. A trie model compiled by `generate_trie <https://github.com/mozilla/DeepSpeech#using-the-command-line-client>`_ in `data/lm/trie`
For more information on how to build these resources from scratch, see `data/lm/README.md` For more information on how to build these resources from scratch, see `data/lm/README.md`

View File

@ -1,8 +1,8 @@
lm.binary was generated from the LibriSpeech normalized LM training text, available `here <http://www.openslr.org/11>`_\ , using the `generate_lm.py` script (will generate lm.binary in the folder it is run from). KenLM's built binaries must be in your PATH (lmplz, build_binary, filter). The LM binary was generated from the LibriSpeech normalized LM training text, available `here <http://www.openslr.org/11>`_\ , using the `generate_lm.py` script (will generate `lm.binary` and `librispeech-vocab-500k.txt` in the folder it is run from). `KenLM <https://github.com/kpu/kenlm>`_'s built binaries must be in your PATH (lmplz, build_binary, filter).
The trie was then generated from the vocabulary of the language model: The scorer package was then built using the `generate_package.py` script:
.. code-block:: bash .. code-block:: bash
python generate_lm.py # this will create lm.binary and librispeech-vocab-500k.txt
./generate_trie ../data/alphabet.txt lm.binary trie python generate_package.py --alphabet ../alphabet.txt --lm lm.binary --vocab librispeech-vocab-500k.txt --default_alpha 0.75 --default_beta 1.85 --package kenlm.scorer

View File

@ -39,10 +39,13 @@ def main():
'--prune', '0', '0', '1' '--prune', '0', '0', '1'
]) ])
# Filter LM using vocabulary of top 500k words
filtered_path = os.path.join(tmp, 'lm_filtered.arpa')
vocab_str = '\n'.join(word for word, count in counter.most_common(500000)) vocab_str = '\n'.join(word for word, count in counter.most_common(500000))
with open('librispeech-vocab-500k.txt', 'w') as fout:
fout.write(vocab_str)
# Filter LM using vocabulary of top 500k words
print('Filtering ARPA file...') print('Filtering ARPA file...')
filtered_path = os.path.join(tmp, 'lm_filtered.arpa')
subprocess.run(['filter', 'single', 'model:{}'.format(lm_path), filtered_path], input=vocab_str.encode('utf-8'), check=True) subprocess.run(['filter', 'single', 'model:{}'.format(lm_path), filtered_path], input=vocab_str.encode('utf-8'), check=True)
# Quantize and produce trie binary. # Quantize and produce trie binary.
@ -50,6 +53,7 @@ def main():
subprocess.check_call([ subprocess.check_call([
'build_binary', '-a', '255', 'build_binary', '-a', '255',
'-q', '8', '-q', '8',
'-v',
'trie', 'trie',
filtered_path, filtered_path,
'lm.binary' 'lm.binary'

154
data/lm/generate_package.py Normal file
View File

@ -0,0 +1,154 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function
# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
import argparse
import shutil
from util.text import Alphabet, UTF8Alphabet
from ds_ctcdecoder import Scorer, Alphabet as NativeAlphabet
def create_bundle(
alphabet_path,
lm_path,
vocab_path,
package_path,
force_utf8,
default_alpha,
default_beta,
):
words = set()
vocab_looks_char_based = True
with open(vocab_path) as fin:
for line in fin:
for word in line.split():
words.add(word.encode("utf-8"))
if len(word) > 1:
vocab_looks_char_based = False
print("{} unique words read from vocabulary file.".format(len(words)))
print(
"{} like a character based model.".format(
"Looks" if vocab_looks_char_based else "Doesn't look"
)
)
if force_utf8 != None: # pylint: disable=singleton-comparison
use_utf8 = force_utf8.value
print("Forcing UTF-8 mode = {}".format(use_utf8))
else:
use_utf8 = vocab_looks_char_based
if use_utf8:
serialized_alphabet = UTF8Alphabet().serialize()
else:
serialized_alphabet = Alphabet(alphabet_path).serialize()
alphabet = NativeAlphabet()
err = alphabet.deserialize(serialized_alphabet, len(serialized_alphabet))
if err != 0:
print("Error loading alphabet: {}".format(err))
sys.exit(1)
scorer = Scorer()
scorer.set_alphabet(alphabet)
scorer.set_utf8_mode(use_utf8)
scorer.reset_params(default_alpha, default_beta)
scorer.load_lm(lm_path)
scorer.fill_dictionary(list(words))
shutil.copy(lm_path, package_path)
scorer.save_dictionary(package_path, True) # append, not overwrite
print("Package created in {}".format(package_path))
class Tristate(object):
def __init__(self, value=None):
if any(value is v for v in (True, False, None)):
self.value = value
else:
raise ValueError("Tristate value must be True, False, or None")
def __eq__(self, other):
return (
self.value is other.value
if isinstance(other, Tristate)
else self.value is other
)
def __ne__(self, other):
return not self == other
def __bool__(self):
raise TypeError("Tristate object may not be used as a Boolean")
def __str__(self):
return str(self.value)
def __repr__(self):
return "Tristate(%s)" % self.value
def main():
parser = argparse.ArgumentParser(
description="Generate an external scorer package for DeepSpeech."
)
parser.add_argument(
"--alphabet",
help="Path of alphabet file to use for vocabulary construction. Words with characters not in the alphabet will not be included in the vocabulary. Optional if using UTF-8 mode.",
)
parser.add_argument(
"--lm",
required=True,
help="Path of KenLM binary LM file. Must be built without including the vocabulary (use the -v flag). See generate_lm.py for how to create a binary LM.",
)
parser.add_argument(
"--vocab",
required=True,
help="Path of vocabulary file. Must contain words separated by whitespace.",
)
parser.add_argument("--package", required=True, help="Path to save scorer package.")
parser.add_argument(
"--default_alpha",
type=float,
required=True,
help="Default value of alpha hyperparameter.",
)
parser.add_argument(
"--default_beta",
type=float,
required=True,
help="Default value of beta hyperparameter.",
)
parser.add_argument(
"--force_utf8",
default="",
help="Boolean flag, force set or unset UTF-8 mode in the scorer package. If not set, infers from the vocabulary.",
)
args = parser.parse_args()
if args.force_utf8 in ("True", "1", "true", "yes", "y"):
force_utf8 = Tristate(True)
elif args.force_utf8 in ("False", "0", "false", "no", "n"):
force_utf8 = Tristate(False)
else:
force_utf8 = Tristate(None)
create_bundle(
args.alphabet,
args.lm,
args.vocab,
args.package,
force_utf8,
args.default_alpha,
args.default_beta,
)
if __name__ == "__main__":
main()

3
data/lm/kenlm.scorer Normal file
View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ba04978fca285c34c99bf115ee61549937e422ac91def80122a767e114c035e
size 953436352

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a24953ce3f013bbf5f4a1c9f5a0e5482bc56eaa81638276de522f39e62ff3a56
size 945699324

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0281e5e784ffccb4aeae5e7d64099058a0c22e42dbb7aa2d3ef2fbbff53db3ab
size 12200736

File diff suppressed because it is too large Load Diff

Binary file not shown.

View File

@ -7,7 +7,13 @@ C
.. doxygenfunction:: DS_FreeModel .. doxygenfunction:: DS_FreeModel
:project: deepspeech-c :project: deepspeech-c
.. doxygenfunction:: DS_EnableDecoderWithLM .. doxygenfunction:: DS_EnableExternalScorer
:project: deepspeech-c
.. doxygenfunction:: DS_DisableExternalScorer
:project: deepspeech-c
.. doxygenfunction:: DS_SetScorerAlphaBeta
:project: deepspeech-c :project: deepspeech-c
.. doxygenfunction:: DS_GetModelSampleRate .. doxygenfunction:: DS_GetModelSampleRate

View File

@ -7,7 +7,7 @@ Creating a model instance and loading model
.. literalinclude:: ../native_client/client.cc .. literalinclude:: ../native_client/client.cc
:language: c :language: c
:linenos: :linenos:
:lines: 370-388 :lines: 370-390
Performing inference Performing inference
-------------------- --------------------

View File

@ -7,6 +7,12 @@ Model
.. js:autoclass:: Model .. js:autoclass:: Model
:members: :members:
Stream
------
.. js:autoclass:: Stream
:members:
Module exported methods Module exported methods
----------------------- -----------------------

View File

@ -7,7 +7,7 @@ Creating a model instance and loading model
.. literalinclude:: ../native_client/javascript/client.js .. literalinclude:: ../native_client/javascript/client.js
:language: javascript :language: javascript
:linenos: :linenos:
:lines: 57-66 :lines: 54-72
Performing inference Performing inference
-------------------- --------------------
@ -15,7 +15,7 @@ Performing inference
.. literalinclude:: ../native_client/javascript/client.js .. literalinclude:: ../native_client/javascript/client.js
:language: javascript :language: javascript
:linenos: :linenos:
:lines: 115-117 :lines: 117-121
Full source code Full source code
---------------- ----------------

View File

@ -9,6 +9,12 @@ Model
.. autoclass:: Model .. autoclass:: Model
:members: :members:
Stream
------
.. autoclass:: Stream
:members:
Metadata Metadata
-------- --------

View File

@ -7,7 +7,7 @@ Creating a model instance and loading model
.. literalinclude:: ../native_client/python/client.py .. literalinclude:: ../native_client/python/client.py
:language: python :language: python
:linenos: :linenos:
:lines: 69, 78 :lines: 111, 120
Performing inference Performing inference
-------------------- --------------------
@ -15,7 +15,7 @@ Performing inference
.. literalinclude:: ../native_client/python/client.py .. literalinclude:: ../native_client/python/client.py
:language: python :language: python
:linenos: :linenos:
:lines: 95-98 :lines: 140-145
Full source code Full source code
---------------- ----------------

View File

@ -106,9 +106,9 @@ Note: the following command assumes you `downloaded the pre-trained model <#gett
.. code-block:: bash .. code-block:: bash
deepspeech --model models/output_graph.pbmm --lm models/lm.binary --trie models/trie --audio my_audio_file.wav deepspeech --model models/output_graph.pbmm --scorer models/kenlm.scorer --audio my_audio_file.wav
The arguments ``--lm`` and ``--trie`` are optional, and represent a language model. The ``--scorer`` argument is optional, and represents an external language model to be used when transcribing the audio.
See :github:`client.py <native_client/python/client.py>` for an example of how to use the package programatically. See :github:`client.py <native_client/python/client.py>` for an example of how to use the package programatically.
@ -162,7 +162,7 @@ Note: the following command assumes you `downloaded the pre-trained model <#gett
.. code-block:: bash .. code-block:: bash
./deepspeech --model models/output_graph.pbmm --lm models/lm.binary --trie models/trie --audio audio_input.wav ./deepspeech --model models/output_graph.pbmm --scorer models/kenlm.scorer --audio audio_input.wav
See the help output with ``./deepspeech -h`` and the :github:`native client README <native_client/README.rst>` for more details. See the help output with ``./deepspeech -h`` and the :github:`native client README <native_client/README.rst>` for more details.

View File

@ -42,10 +42,9 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
def evaluate(test_csvs, create_model, try_loading): def evaluate(test_csvs, create_model, try_loading):
if FLAGS.lm_binary_path: if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.lm_binary_path, FLAGS.lm_trie_path, FLAGS.scorer_path, Config.alphabet)
Config.alphabet)
else: else:
scorer = None scorer = None

View File

@ -27,17 +27,18 @@ This module should be self-contained:
- pip install native_client/python/dist/deepspeech*.whl - pip install native_client/python/dist/deepspeech*.whl
- pip install -r requirements_eval_tflite.txt - pip install -r requirements_eval_tflite.txt
Then run with a TF Lite model, LM/trie and a CSV test file Then run with a TF Lite model, a scorer and a CSV test file
''' '''
BEAM_WIDTH = 500 BEAM_WIDTH = 500
LM_ALPHA = 0.75 LM_ALPHA = 0.75
LM_BETA = 1.85 LM_BETA = 1.85
def tflite_worker(model, lm, trie, queue_in, queue_out, gpu_mask): def tflite_worker(model, scorer, queue_in, queue_out, gpu_mask):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask) os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_mask)
ds = Model(model, BEAM_WIDTH) ds = Model(model, BEAM_WIDTH)
ds.enableDecoderWithLM(lm, trie, LM_ALPHA, LM_BETA) ds.enableExternalScorer(scorer)
ds.setScorerAlphaBeta(LM_ALPHA, LM_BETA)
while True: while True:
try: try:
@ -64,7 +65,7 @@ def main(args, _):
processes = [] processes = []
for i in range(args.proc): for i in range(args.proc):
worker_process = Process(target=tflite_worker, args=(args.model, args.lm, args.trie, work_todo, work_done, i), daemon=True, name='tflite_process_{}'.format(i)) worker_process = Process(target=tflite_worker, args=(args.model, args.scorer, work_todo, work_done, i), daemon=True, name='tflite_process_{}'.format(i))
worker_process.start() # Launch reader() as a separate python process worker_process.start() # Launch reader() as a separate python process
processes.append(worker_process) processes.append(worker_process)
@ -113,10 +114,8 @@ def parse_args():
parser = argparse.ArgumentParser(description='Computing TFLite accuracy') parser = argparse.ArgumentParser(description='Computing TFLite accuracy')
parser.add_argument('--model', required=True, parser.add_argument('--model', required=True,
help='Path to the model (protocol buffer binary file)') help='Path to the model (protocol buffer binary file)')
parser.add_argument('--lm', required=True, parser.add_argument('--scorer', required=True,
help='Path to the language model binary file') help='Path to the external scorer file')
parser.add_argument('--trie', required=True,
help='Path to the language model trie file created with native_client/generate_trie')
parser.add_argument('--csv', required=True, parser.add_argument('--csv', required=True,
help='Path to the CSV source file') help='Path to the CSV source file')
parser.add_argument('--proc', required=False, default=cpu_count(), type=int, parser.add_argument('--proc', required=False, default=cpu_count(), type=int,

View File

@ -27,20 +27,6 @@ genrule(
tools = [":gen_workspace_status.sh"], tools = [":gen_workspace_status.sh"],
) )
KENLM_SOURCES = glob(
[
"kenlm/lm/*.cc",
"kenlm/util/*.cc",
"kenlm/util/double-conversion/*.cc",
"kenlm/lm/*.hh",
"kenlm/util/*.hh",
"kenlm/util/double-conversion/*.h",
],
exclude = [
"kenlm/*/*test.cc",
"kenlm/*/*main.cc",
],
)
OPENFST_SOURCES_PLATFORM = select({ OPENFST_SOURCES_PLATFORM = select({
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]), "//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
@ -60,6 +46,27 @@ LINUX_LINKOPTS = [
"-Wl,-export-dynamic", "-Wl,-export-dynamic",
] ]
cc_library(
name = "kenlm",
srcs = glob([
"kenlm/lm/*.cc",
"kenlm/util/*.cc",
"kenlm/util/double-conversion/*.cc",
"kenlm/util/double-conversion/*.h",
],
exclude = [
"kenlm/*/*test.cc",
"kenlm/*/*main.cc",
],),
hdrs = glob([
"kenlm/lm/*.hh",
"kenlm/util/*.hh",
]),
copts = ["-std=c++11"],
defines = ["KENLM_MAX_ORDER=6"],
includes = ["kenlm"],
)
cc_library( cc_library(
name = "decoder", name = "decoder",
srcs = [ srcs = [
@ -69,17 +76,16 @@ cc_library(
"ctcdecode/scorer.cpp", "ctcdecode/scorer.cpp",
"ctcdecode/path_trie.cpp", "ctcdecode/path_trie.cpp",
"ctcdecode/path_trie.h", "ctcdecode/path_trie.h",
] + KENLM_SOURCES + OPENFST_SOURCES_PLATFORM, ] + OPENFST_SOURCES_PLATFORM,
hdrs = [ hdrs = [
"ctcdecode/ctc_beam_search_decoder.h", "ctcdecode/ctc_beam_search_decoder.h",
"ctcdecode/scorer.h", "ctcdecode/scorer.h",
], ],
defines = ["KENLM_MAX_ORDER=6"],
includes = [ includes = [
".", ".",
"ctcdecode/third_party/ThreadPool", "ctcdecode/third_party/ThreadPool",
"kenlm",
] + OPENFST_INCLUDES_PLATFORM, ] + OPENFST_INCLUDES_PLATFORM,
deps = [":kenlm"]
) )
tf_cc_shared_object( tf_cc_shared_object(
@ -182,18 +188,12 @@ genrule(
) )
cc_binary( cc_binary(
name = "generate_trie", name = "enumerate_kenlm_vocabulary",
srcs = [ srcs = [
"alphabet.h", "enumerate_kenlm_vocabulary.cpp",
"generate_trie.cpp",
], ],
deps = [":kenlm"],
copts = ["-std=c++11"], copts = ["-std=c++11"],
linkopts = [
"-lm",
"-ldl",
"-pthread",
],
deps = [":decoder"],
) )
cc_binary( cc_binary(

View File

@ -12,19 +12,17 @@
char* model = NULL; char* model = NULL;
char* lm = NULL; char* scorer = NULL;
char* trie = NULL;
char* audio = NULL; char* audio = NULL;
int beam_width = 500; int beam_width = 500;
float lm_alpha = 0.75f; bool set_alphabeta = false;
float lm_beta = 1.85f; float lm_alpha = 0.f;
bool load_without_trie = false; float lm_beta = 0.f;
bool show_times = false; bool show_times = false;
@ -39,45 +37,42 @@ int stream_size = 0;
void PrintHelp(const char* bin) void PrintHelp(const char* bin)
{ {
std::cout << std::cout <<
"Usage: " << bin << " --model MODEL [--lm LM --trie TRIE] --audio AUDIO [-t] [-e]\n" "Usage: " << bin << " --model MODEL [--scorer SCORER] --audio AUDIO [-t] [-e]\n"
"\n" "\n"
"Running DeepSpeech inference.\n" "Running DeepSpeech inference.\n"
"\n" "\n"
" --model MODEL Path to the model (protocol buffer binary file)\n" "\t--model MODEL\t\tPath to the model (protocol buffer binary file)\n"
" --lm LM Path to the language model binary file\n" "\t--scorer SCORER\t\tPath to the external scorer file\n"
" --trie TRIE Path to the language model trie file created with native_client/generate_trie\n" "\t--audio AUDIO\t\tPath to the audio file to run (WAV format)\n"
" --audio AUDIO Path to the audio file to run (WAV format)\n" "\t--beam_width BEAM_WIDTH\tValue for decoder beam width (int)\n"
" --beam_width BEAM_WIDTH Value for decoder beam width (int)\n" "\t--lm_alpha LM_ALPHA\tValue for language model alpha param (float)\n"
" --lm_alpha LM_ALPHA Value for language model alpha param (float)\n" "\t--lm_beta LM_BETA\tValue for language model beta param (float)\n"
" --lm_beta LM_BETA Value for language model beta param (float)\n" "\t-t\t\t\tRun in benchmark mode, output mfcc & inference time\n"
" -t Run in benchmark mode, output mfcc & inference time\n" "\t--extended\t\tOutput string from extended metadata\n"
" --extended Output string from extended metadata\n" "\t--json\t\t\tExtended output, shows word timings as JSON\n"
" --json Extended output, shows word timings as JSON\n" "\t--stream size\t\tRun in stream mode, output intermediate results\n"
" --stream size Run in stream mode, output intermediate results\n" "\t--help\t\t\tShow help\n"
" --help Show help\n" "\t--version\t\tPrint version and exits\n";
" --version Print version and exits\n";
DS_PrintVersions(); DS_PrintVersions();
exit(1); exit(1);
} }
bool ProcessArgs(int argc, char** argv) bool ProcessArgs(int argc, char** argv)
{ {
const char* const short_opts = "m:a:l:r:w:c:d:b:tehv"; const char* const short_opts = "m:l:a:b:c:d:tejs:vh";
const option long_opts[] = { const option long_opts[] = {
{"model", required_argument, nullptr, 'm'}, {"model", required_argument, nullptr, 'm'},
{"lm", required_argument, nullptr, 'l'}, {"scorer", required_argument, nullptr, 'l'},
{"trie", required_argument, nullptr, 'r'}, {"audio", required_argument, nullptr, 'a'},
{"audio", required_argument, nullptr, 'w'},
{"beam_width", required_argument, nullptr, 'b'}, {"beam_width", required_argument, nullptr, 'b'},
{"lm_alpha", required_argument, nullptr, 'c'}, {"lm_alpha", required_argument, nullptr, 'c'},
{"lm_beta", required_argument, nullptr, 'd'}, {"lm_beta", required_argument, nullptr, 'd'},
{"run_very_slowly_without_trie_I_really_know_what_Im_doing", no_argument, nullptr, 999},
{"t", no_argument, nullptr, 't'}, {"t", no_argument, nullptr, 't'},
{"extended", no_argument, nullptr, 'e'}, {"extended", no_argument, nullptr, 'e'},
{"json", no_argument, nullptr, 'j'}, {"json", no_argument, nullptr, 'j'},
{"stream", required_argument, nullptr, 's'}, {"stream", required_argument, nullptr, 's'},
{"help", no_argument, nullptr, 'h'},
{"version", no_argument, nullptr, 'v'}, {"version", no_argument, nullptr, 'v'},
{"help", no_argument, nullptr, 'h'},
{nullptr, no_argument, nullptr, 0} {nullptr, no_argument, nullptr, 0}
}; };
@ -95,41 +90,31 @@ bool ProcessArgs(int argc, char** argv)
break; break;
case 'l': case 'l':
lm = optarg; scorer = optarg;
break; break;
case 'r': case 'a':
trie = optarg;
break;
case 'w':
audio = optarg; audio = optarg;
break; break;
case 'b': case 'b':
beam_width = atoi(optarg); beam_width = atoi(optarg);
break; break;
case 'c':
lm_alpha = atof(optarg);
break;
case 'd':
lm_beta = atof(optarg);
break;
case 999: case 'c':
load_without_trie = true; set_alphabeta = true;
lm_alpha = atof(optarg);
break;
case 'd':
set_alphabeta = true;
lm_beta = atof(optarg);
break; break;
case 't': case 't':
show_times = true; show_times = true;
break; break;
case 'v':
has_versions = true;
break;
case 'e': case 'e':
extended_metadata = true; extended_metadata = true;
break; break;
@ -142,6 +127,10 @@ bool ProcessArgs(int argc, char** argv)
stream_size = atoi(optarg); stream_size = atoi(optarg);
break; break;
case 'v':
has_versions = true;
break;
case 'h': // -h or --help case 'h': // -h or --help
case '?': // Unrecognized option case '?': // Unrecognized option
default: default:

View File

@ -374,16 +374,19 @@ main(int argc, char **argv)
return 1; return 1;
} }
if (lm && (trie || load_without_trie)) { if (scorer) {
int status = DS_EnableDecoderWithLM(ctx, int status = DS_EnableExternalScorer(ctx, scorer);
lm,
trie,
lm_alpha,
lm_beta);
if (status != 0) { if (status != 0) {
fprintf(stderr, "Could not enable CTC decoder with LM.\n"); fprintf(stderr, "Could not enable external scorer.\n");
return 1; return 1;
} }
if (set_alphabeta) {
status = DS_SetScorerAlphaBeta(ctx, lm_alpha, lm_beta);
if (status != 0) {
fprintf(stderr, "Error setting scorer alpha and beta.\n");
return 1;
}
}
} }
#ifndef NO_SOX #ifndef NO_SOX

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from . import swigwrapper # pylint: disable=import-self from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import Alphabet
__version__ = swigwrapper.__version__ __version__ = swigwrapper.__version__
@ -11,26 +12,36 @@ class Scorer(swigwrapper.Scorer):
:type alpha: float :type alpha: float
:param beta: Word insertion bonus. :param beta: Word insertion bonus.
:type beta: float :type beta: float
:model_path: Path to load language model. :scorer_path: Path to load scorer from.
:trie_path: Path to trie file.
:alphabet: Alphabet :alphabet: Alphabet
:type model_path: basestring :type scorer_path: basestring
""" """
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
super(Scorer, self).__init__() super(Scorer, self).__init__()
serialized = alphabet.serialize() # Allow bare initialization
native_alphabet = swigwrapper.Alphabet() if alphabet:
err = native_alphabet.deserialize(serialized, len(serialized)) assert alpha is not None, 'alpha parameter is required'
if err != 0: assert beta is not None, 'beta parameter is required'
raise ValueError("Error when deserializing alphabet.") assert scorer_path, 'scorer_path parameter is required'
err = self.init(alpha, beta, serialized = alphabet.serialize()
model_path.encode('utf-8'), native_alphabet = swigwrapper.Alphabet()
trie_path.encode('utf-8'), err = native_alphabet.deserialize(serialized, len(serialized))
native_alphabet) if err != 0:
if err != 0: raise ValueError('Error when deserializing alphabet.')
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
err = self.init(scorer_path.encode('utf-8'),
native_alphabet)
if err != 0:
raise ValueError('Scorer initialization failed with error code {}'.format(err))
self.reset_params(alpha, beta)
def load_lm(self, lm_path):
super(Scorer, self).load_lm(lm_path.encode('utf-8'))
def save_dictionary(self, save_path, *args, **kwargs):
super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
def ctc_beam_search_decoder(probs_seq, def ctc_beam_search_decoder(probs_seq,

View File

@ -18,7 +18,7 @@ DecoderState::init(const Alphabet& alphabet,
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) std::shared_ptr<Scorer> ext_scorer)
{ {
// assign special ids // assign special ids
abs_time_step_ = 0; abs_time_step_ = 0;
@ -36,7 +36,7 @@ DecoderState::init(const Alphabet& alphabet,
prefix_root_.reset(root); prefix_root_.reset(root);
prefixes_.push_back(root); prefixes_.push_back(root);
if (ext_scorer != nullptr) { if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
// no need for std::make_shared<>() since Copy() does 'new' behind the doors // no need for std::make_shared<>() since Copy() does 'new' behind the doors
auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true)); auto dict_ptr = std::shared_ptr<PathTrie::FstType>(ext_scorer->dictionary->Copy(true));
root->set_dictionary(dict_ptr); root->set_dictionary(dict_ptr);
@ -58,7 +58,7 @@ DecoderState::next(const double *probs,
float min_cutoff = -NUM_FLT_INF; float min_cutoff = -NUM_FLT_INF;
bool full_beam = false; bool full_beam = false;
if (ext_scorer_ != nullptr) { if (ext_scorer_) {
size_t num_prefixes = std::min(prefixes_.size(), beam_size_); size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
std::partial_sort(prefixes_.begin(), std::partial_sort(prefixes_.begin(),
prefixes_.begin() + num_prefixes, prefixes_.begin() + num_prefixes,
@ -109,7 +109,7 @@ DecoderState::next(const double *probs,
log_p = log_prob_c + prefix->score; log_p = log_prob_c + prefix->score;
} }
if (ext_scorer_ != nullptr) { if (ext_scorer_) {
// skip scoring the space in word based LMs // skip scoring the space in word based LMs
PathTrie* prefix_to_score; PathTrie* prefix_to_score;
if (ext_scorer_->is_utf8_mode()) { if (ext_scorer_->is_utf8_mode()) {
@ -166,7 +166,7 @@ DecoderState::decode() const
} }
// score the last word of each prefix that doesn't end with space // score the last word of each prefix that doesn't end with space
if (ext_scorer_ != nullptr) { if (ext_scorer_) {
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) { for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i]; auto prefix = prefixes_copy[i];
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) { if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
@ -200,7 +200,7 @@ DecoderState::decode() const
Output output; Output output;
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps); prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
double approx_ctc = scores[prefixes_copy[i]]; double approx_ctc = scores[prefixes_copy[i]];
if (ext_scorer_ != nullptr) { if (ext_scorer_) {
auto words = ext_scorer_->split_labels_into_scored_units(output.tokens); auto words = ext_scorer_->split_labels_into_scored_units(output.tokens);
// remove term insertion weight // remove term insertion weight
approx_ctc -= words.size() * ext_scorer_->beta; approx_ctc -= words.size() * ext_scorer_->beta;
@ -222,7 +222,7 @@ std::vector<Output> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) std::shared_ptr<Scorer> ext_scorer)
{ {
DecoderState state; DecoderState state;
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer); state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
@ -243,7 +243,7 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) std::shared_ptr<Scorer> ext_scorer)
{ {
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element"); VALID_CHECK_EQ(batch_size, seq_lengths_size, "must have one sequence length per batch element");

View File

@ -1,6 +1,7 @@
#ifndef CTC_BEAM_SEARCH_DECODER_H_ #ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_ #define CTC_BEAM_SEARCH_DECODER_H_
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
@ -16,7 +17,7 @@ class DecoderState {
double cutoff_prob_; double cutoff_prob_;
size_t cutoff_top_n_; size_t cutoff_top_n_;
Scorer* ext_scorer_; // weak std::shared_ptr<Scorer> ext_scorer_;
std::vector<PathTrie*> prefixes_; std::vector<PathTrie*> prefixes_;
std::unique_ptr<PathTrie> prefix_root_; std::unique_ptr<PathTrie> prefix_root_;
@ -45,7 +46,7 @@ public:
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer); std::shared_ptr<Scorer> ext_scorer);
/* Send data to the decoder /* Send data to the decoder
* *
@ -95,7 +96,7 @@ std::vector<Output> ctc_beam_search_decoder(
size_t beam_size, size_t beam_size,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer); std::shared_ptr<Scorer> ext_scorer);
/* CTC Beam Search Decoder for batch data /* CTC Beam Search Decoder for batch data
* Parameters: * Parameters:
@ -126,6 +127,6 @@ ctc_beam_search_decoder_batch(
size_t num_processes, size_t num_processes,
double cutoff_prob, double cutoff_prob,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer); std::shared_ptr<Scorer> ext_scorer);
#endif // CTC_BEAM_SEARCH_DECODER_H_ #endif // CTC_BEAM_SEARCH_DECODER_H_

View File

@ -24,41 +24,37 @@
#include "decoder_utils.h" #include "decoder_utils.h"
using namespace lm::ngram;
static const int32_t MAGIC = 'TRIE'; static const int32_t MAGIC = 'TRIE';
static const int32_t FILE_VERSION = 5; static const int32_t FILE_VERSION = 6;
int int
Scorer::init(double alpha, Scorer::init(const std::string& lm_path,
double beta,
const std::string& lm_path,
const std::string& trie_path,
const Alphabet& alphabet) const Alphabet& alphabet)
{ {
reset_params(alpha, beta); set_alphabet(alphabet);
alphabet_ = alphabet; return load_lm(lm_path);
setup(lm_path, trie_path);
return 0;
} }
int int
Scorer::init(double alpha, Scorer::init(const std::string& lm_path,
double beta,
const std::string& lm_path,
const std::string& trie_path,
const std::string& alphabet_config_path) const std::string& alphabet_config_path)
{ {
reset_params(alpha, beta);
int err = alphabet_.init(alphabet_config_path.c_str()); int err = alphabet_.init(alphabet_config_path.c_str());
if (err != 0) { if (err != 0) {
return err; return err;
} }
setup(lm_path, trie_path); setup_char_map();
return 0; return load_lm(lm_path);
} }
void Scorer::setup(const std::string& lm_path, const std::string& trie_path) void
Scorer::set_alphabet(const Alphabet& alphabet)
{
alphabet_ = alphabet;
setup_char_map();
}
void Scorer::setup_char_map()
{ {
// (Re-)Initialize character map // (Re-)Initialize character map
char_map_.clear(); char_map_.clear();
@ -71,78 +67,99 @@ void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
// state, otherwise wrong decoding results would be given. // state, otherwise wrong decoding results would be given.
char_map_[alphabet_.StringFromLabel(i)] = i + 1; char_map_[alphabet_.StringFromLabel(i)] = i + 1;
} }
// load language model
const char* filename = lm_path.c_str();
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
bool has_trie = trie_path.size() && access(trie_path.c_str(), R_OK) == 0;
lm::ngram::Config config;
if (!has_trie) { // no trie was specified, build it now
RetrieveStrEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
auto vocab = enumerate.vocabulary;
for (size_t i = 0; i < vocab.size(); ++i) {
if (vocab[i] != UNK_TOKEN &&
vocab[i] != START_TOKEN &&
vocab[i] != END_TOKEN &&
get_utf8_str_len(vocab[i]) > 1) {
is_utf8_mode_ = false;
break;
}
}
if (alphabet_.GetSize() != 255) {
is_utf8_mode_ = false;
}
// Add spaces only in word-based scoring
fill_dictionary(vocab);
} else {
config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
// Read metadata and trie from file
std::ifstream fin(trie_path, std::ios::binary);
int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
std::cerr << "Error: Can't parse trie file, invalid header. Try updating "
"your trie file." << std::endl;
throw 1;
}
int version;
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
if (version != FILE_VERSION) {
std::cerr << "Error: Trie file version mismatch (" << version
<< " instead of expected " << FILE_VERSION
<< "). Update your trie file."
<< std::endl;
throw 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = trie_path;
dictionary.reset(FstType::Read(fin, opt));
}
max_order_ = language_model_->Order();
} }
void Scorer::save_dictionary(const std::string& path) int Scorer::load_lm(const std::string& lm_path)
{ {
std::ofstream fout(path, std::ios::binary); // Check if file is readable to avoid KenLM throwing an exception
const char* filename = lm_path.c_str();
if (access(filename, R_OK) != 0) {
return 1;
}
// Check if the file format is valid to avoid KenLM throwing an exception
lm::ngram::ModelType model_type;
if (!lm::ngram::RecognizeBinary(filename, model_type)) {
return 1;
}
// Load the LM
lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
max_order_ = language_model_->Order();
uint64_t package_size;
{
util::scoped_fd fd(util::OpenReadOrThrow(filename));
package_size = util::SizeFile(fd.get());
}
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
if (package_size <= trie_offset) {
// File ends without a trie structure
return 1;
}
// Read metadata and trie from file
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
return load_trie(fin, lm_path);
}
int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
{
int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
std::cerr << "Error: Can't parse scorer file, invalid header. Try updating "
"your scorer file." << std::endl;
return 1;
}
int version;
fin.read(reinterpret_cast<char*>(&version), sizeof(version));
if (version != FILE_VERSION) {
std::cerr << "Error: Scorer file version mismatch (" << version
<< " instead of expected " << FILE_VERSION
<< "). ";
if (version < FILE_VERSION) {
std::cerr << "Update your scorer file.";
} else {
std::cerr << "Downgrade your scorer file or update your version of DeepSpeech.";
}
std::cerr << std::endl;
return 1;
}
fin.read(reinterpret_cast<char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
// Read hyperparameters from header
double alpha, beta;
fin.read(reinterpret_cast<char*>(&alpha), sizeof(alpha));
fin.read(reinterpret_cast<char*>(&beta), sizeof(beta));
reset_params(alpha, beta);
fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt));
return 0;
}
void Scorer::save_dictionary(const std::string& path, bool append_instead_of_overwrite)
{
std::ios::openmode om;
if (append_instead_of_overwrite) {
om = std::ios::in|std::ios::out|std::ios::binary|std::ios::ate;
} else {
om = std::ios::out|std::ios::binary;
}
std::fstream fout(path, om);
fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC)); fout.write(reinterpret_cast<const char*>(&MAGIC), sizeof(MAGIC));
fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION)); fout.write(reinterpret_cast<const char*>(&FILE_VERSION), sizeof(FILE_VERSION));
fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_)); fout.write(reinterpret_cast<const char*>(&is_utf8_mode_), sizeof(is_utf8_mode_));
fout.write(reinterpret_cast<const char*>(&alpha), sizeof(alpha));
fout.write(reinterpret_cast<const char*>(&beta), sizeof(beta));
fst::FstWriteOptions opt; fst::FstWriteOptions opt;
opt.align = true; opt.align = true;
opt.source = path; opt.source = path;

View File

@ -6,31 +6,19 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh" #include "lm/virtual_interface.hh"
#include "lm/word_index.hh" #include "lm/word_index.hh"
#include "util/string_piece.hh" #include "util/string_piece.hh"
#include "path_trie.h" #include "path_trie.h"
#include "alphabet.h" #include "alphabet.h"
#include "deepspeech.h"
const double OOV_SCORE = -1000.0; const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>"; const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>"; const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>"; const std::string END_TOKEN = "</s>";
// Implement a callback to retrieve the dictionary of language model.
class RetrieveStrEnumerateVocab : public lm::EnumerateVocab {
public:
RetrieveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
/* External scorer to query score for n-gram or sentence, including language /* External scorer to query score for n-gram or sentence, including language
* model scoring and word insertion. * model scoring and word insertion.
* *
@ -40,9 +28,9 @@ public:
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" }); * scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
*/ */
class Scorer { class Scorer {
public:
using FstType = PathTrie::FstType; using FstType = PathTrie::FstType;
public:
Scorer() = default; Scorer() = default;
~Scorer() = default; ~Scorer() = default;
@ -50,16 +38,10 @@ public:
Scorer(const Scorer&) = delete; Scorer(const Scorer&) = delete;
Scorer& operator=(const Scorer&) = delete; Scorer& operator=(const Scorer&) = delete;
int init(double alpha, int init(const std::string &lm_path,
double beta,
const std::string &lm_path,
const std::string &trie_path,
const Alphabet &alphabet); const Alphabet &alphabet);
int init(double alpha, int init(const std::string &lm_path,
double beta,
const std::string &lm_path,
const std::string &trie_path,
const std::string &alphabet_config_path); const std::string &alphabet_config_path);
double get_log_cond_prob(const std::vector<std::string> &words, double get_log_cond_prob(const std::vector<std::string> &words,
@ -76,12 +58,15 @@ public:
// return the max order // return the max order
size_t get_max_order() const { return max_order_; } size_t get_max_order() const { return max_order_; }
// retrun true if the language model is character based // return true if the language model is character based
bool is_utf8_mode() const { return is_utf8_mode_; } bool is_utf8_mode() const { return is_utf8_mode_; }
// reset params alpha & beta // reset params alpha & beta
void reset_params(float alpha, float beta); void reset_params(float alpha, float beta);
// force set UTF-8 mode, ignore value read from file
void set_utf8_mode(bool utf8) { is_utf8_mode_ = utf8; }
// make ngram for a given prefix // make ngram for a given prefix
std::vector<std::string> make_ngram(PathTrie *prefix); std::vector<std::string> make_ngram(PathTrie *prefix);
@ -89,12 +74,20 @@ public:
// the vector of characters (character based lm) // the vector of characters (character based lm)
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels); std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
void set_alphabet(const Alphabet& alphabet);
// save dictionary in file // save dictionary in file
void save_dictionary(const std::string &path); void save_dictionary(const std::string &path, bool append_instead_of_overwrite=false);
// return weather this step represents a boundary where beam scoring should happen // return weather this step represents a boundary where beam scoring should happen
bool is_scoring_boundary(PathTrie* prefix, size_t new_label); bool is_scoring_boundary(PathTrie* prefix, size_t new_label);
// fill dictionary FST from a vocabulary
void fill_dictionary(const std::vector<std::string> &vocabulary);
// load language model from given path
int load_lm(const std::string &lm_path);
// language model weight // language model weight
double alpha = 0.; double alpha = 0.;
// word insertion weight // word insertion weight
@ -104,14 +97,10 @@ public:
std::unique_ptr<FstType> dictionary; std::unique_ptr<FstType> dictionary;
protected: protected:
// necessary setup: load language model, fill FST's dictionary // necessary setup after setting alphabet
void setup(const std::string &lm_path, const std::string &trie_path); void setup_char_map();
// load language model from given path int load_trie(std::ifstream& fin, const std::string& file_path);
void load_lm(const std::string &lm_path);
// fill dictionary for FST
void fill_dictionary(const std::vector<std::string> &vocabulary);
private: private:
std::unique_ptr<lm::base::Model> language_model_; std::unique_ptr<lm::base::Model> language_model_;

View File

@ -7,15 +7,22 @@
#include "workspace_status.h" #include "workspace_status.h"
%} %}
%include "pyabc.i" %include <pyabc.i>
%include "std_string.i" %include <std_string.i>
%include "std_vector.i" %include <std_vector.i>
%include <std_shared_ptr.i>
%include "numpy.i" %include "numpy.i"
%init %{ %init %{
import_array(); import_array();
%} %}
namespace std {
%template(StringVector) vector<string>;
}
%shared_ptr(Scorer);
// Convert NumPy arrays to pointer+lengths // Convert NumPy arrays to pointer+lengths
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)}; %apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)}; %apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};

View File

@ -304,23 +304,38 @@ DS_FreeModel(ModelState* ctx)
} }
int int
DS_EnableDecoderWithLM(ModelState* aCtx, DS_EnableExternalScorer(ModelState* aCtx,
const char* aLMPath, const char* aScorerPath)
const char* aTriePath,
float aLMAlpha,
float aLMBeta)
{ {
aCtx->scorer_.reset(new Scorer()); aCtx->scorer_.reset(new Scorer());
int err = aCtx->scorer_->init(aLMAlpha, aLMBeta, int err = aCtx->scorer_->init(aScorerPath, aCtx->alphabet_);
aLMPath ? aLMPath : "",
aTriePath ? aTriePath : "",
aCtx->alphabet_);
if (err != 0) { if (err != 0) {
return DS_ERR_INVALID_LM; return DS_ERR_INVALID_SCORER;
} }
return DS_ERR_OK; return DS_ERR_OK;
} }
int
DS_DisableExternalScorer(ModelState* aCtx)
{
if (aCtx->scorer_) {
aCtx->scorer_.reset();
return DS_ERR_OK;
}
return DS_ERR_SCORER_NOT_ENABLED;
}
int DS_SetScorerAlphaBeta(ModelState* aCtx,
float aAlpha,
float aBeta)
{
if (aCtx->scorer_) {
aCtx->scorer_->reset_params(aAlpha, aBeta);
return DS_ERR_OK;
}
return DS_ERR_SCORER_NOT_ENABLED;
}
int int
DS_CreateStream(ModelState* aCtx, DS_CreateStream(ModelState* aCtx,
StreamingState** retval) StreamingState** retval)
@ -348,7 +363,7 @@ DS_CreateStream(ModelState* aCtx,
aCtx->beam_width_, aCtx->beam_width_,
cutoff_prob, cutoff_prob,
cutoff_top_n, cutoff_top_n,
aCtx->scorer_.get()); aCtx->scorer_);
*retval = ctx.release(); *retval = ctx.release();
return DS_ERR_OK; return DS_ERR_OK;

View File

@ -59,8 +59,9 @@ enum DeepSpeech_Error_Codes
// Invalid parameters // Invalid parameters
DS_ERR_INVALID_ALPHABET = 0x2000, DS_ERR_INVALID_ALPHABET = 0x2000,
DS_ERR_INVALID_SHAPE = 0x2001, DS_ERR_INVALID_SHAPE = 0x2001,
DS_ERR_INVALID_LM = 0x2002, DS_ERR_INVALID_SCORER = 0x2002,
DS_ERR_MODEL_INCOMPATIBLE = 0x2003, DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
// Runtime failures // Runtime failures
DS_ERR_FAIL_INIT_MMAP = 0x3000, DS_ERR_FAIL_INIT_MMAP = 0x3000,
@ -106,25 +107,40 @@ DEEPSPEECH_EXPORT
void DS_FreeModel(ModelState* ctx); void DS_FreeModel(ModelState* ctx);
/** /**
* @brief Enable decoding using beam scoring with a KenLM language model. * @brief Enable decoding using an external scorer.
* *
* @param aCtx The ModelState pointer for the model being changed. * @param aCtx The ModelState pointer for the model being changed.
* @param aLMPath The path to the language model binary file. * @param aScorerPath The path to the external scorer file.
* @param aTriePath The path to the trie file build from the same vocabu-
* lary as the language model binary.
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
weight.
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
weight.
* *
* @return Zero on success, non-zero on failure (invalid arguments). * @return Zero on success, non-zero on failure (invalid arguments).
*/ */
DEEPSPEECH_EXPORT DEEPSPEECH_EXPORT
int DS_EnableDecoderWithLM(ModelState* aCtx, int DS_EnableExternalScorer(ModelState* aCtx,
const char* aLMPath, const char* aScorerPath);
const char* aTriePath,
float aLMAlpha, /**
float aLMBeta); * @brief Disable decoding using an external scorer.
*
* @param aCtx The ModelState pointer for the model being changed.
*
* @return Zero on success, non-zero on failure.
*/
DEEPSPEECH_EXPORT
int DS_DisableExternalScorer(ModelState* aCtx);
/**
* @brief Set hyperparameters alpha and beta of the external scorer.
*
* @param aCtx The ModelState pointer for the model being changed.
* @param aAlpha The alpha hyperparameter of the decoder. Language model weight.
* @param aLMBeta The beta hyperparameter of the decoder. Word insertion weight.
*
* @return Zero on success, non-zero on failure.
*/
DEEPSPEECH_EXPORT
int DS_SetScorerAlphaBeta(ModelState* aCtx,
float aAlpha,
float aBeta);
/** /**
* @brief Use the DeepSpeech model to perform Speech-To-Text. * @brief Use the DeepSpeech model to perform Speech-To-Text.

View File

@ -1,141 +0,0 @@
#ifndef DEEPSPEECH_COMPAT_H
#define DEEPSPEECH_COMPAT_H
#include "deepspeech.h"
#warning This header is a convenience wrapper for compatibility with \
the previous API, it has deprecated function names and arguments. \
If possible, update your code instead of using this header.
/**
* @brief An object providing an interface to a trained DeepSpeech model.
*
* @param aModelPath The path to the frozen model graph.
* @param aNCep UNUSED, DEPRECATED.
* @param aNContext UNUSED, DEPRECATED.
* @param aAlphabetConfigPath UNUSED, DEPRECATED.
* @param aBeamWidth The beam width used by the decoder. A larger beam
* width generates better results at the cost of decoding
* time.
* @param[out] retval a ModelState pointer
*
* @return Zero on success, non-zero on failure.
*/
int DS_CreateModel(const char* aModelPath,
unsigned int /*aNCep*/,
unsigned int /*aNContext*/,
const char* /*aAlphabetConfigPath*/,
unsigned int aBeamWidth,
ModelState** retval)
{
return DS_CreateModel(aModelPath, aBeamWidth, retval);
}
/**
* @brief Frees associated resources and destroys model object.
*/
void DS_DestroyModel(ModelState* ctx)
{
return DS_FreeModel(ctx);
}
/**
* @brief Enable decoding using beam scoring with a KenLM language model.
*
* @param aCtx The ModelState pointer for the model being changed.
* @param aAlphabetConfigPath UNUSED, DEPRECATED.
* @param aLMPath The path to the language model binary file.
* @param aTriePath The path to the trie file build from the same vocabu-
* lary as the language model binary.
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
weight.
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
weight.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
int DS_EnableDecoderWithLM(ModelState* aCtx,
const char* /*aAlphabetConfigPath*/,
const char* aLMPath,
const char* aTriePath,
float aLMAlpha,
float aLMBeta)
{
return DS_EnableDecoderWithLM(aCtx, aLMPath, aTriePath, aLMAlpha, aLMBeta);
}
/**
* @brief Create a new streaming inference state. The streaming state returned
* by this function can then be passed to {@link DS_FeedAudioContent()}
* and {@link DS_FinishStream()}.
*
* @param aCtx The ModelState pointer for the model to use.
* @param aSampleRate UNUSED, DEPRECATED.
* @param[out] retval an opaque pointer that represents the streaming state. Can
* be NULL if an error occurs.
*
* @return Zero for success, non-zero on failure.
*/
int DS_SetupStream(ModelState* aCtx,
unsigned int /*aSampleRate*/,
StreamingState** retval)
{
return DS_CreateStream(aCtx, retval);
}
/**
* @brief Destroy a streaming state without decoding the computed logits. This
* can be used if you no longer need the result of an ongoing streaming
* inference and don't want to perform a costly decode operation.
*
* @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
*
* @note This method will free the state pointer (@p aSctx).
*/
void DS_DiscardStream(StreamingState* aSctx)
{
return DS_FreeStream(aSctx);
}
/**
* @brief Use the DeepSpeech model to perform Speech-To-Text.
*
* @param aCtx The ModelState pointer for the model to use.
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
* sample rate (matching what the model was trained on).
* @param aBufferSize The number of samples in the audio signal.
* @param aSampleRate UNUSED, DEPRECATED.
*
* @return The STT result. The user is responsible for freeing the string using
* {@link DS_FreeString()}. Returns NULL on error.
*/
char* DS_SpeechToText(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int /*aSampleRate*/)
{
return DS_SpeechToText(aCtx, aBuffer, aBufferSize);
}
/**
* @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata
* about the results.
*
* @param aCtx The ModelState pointer for the model to use.
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
* sample rate (matching what the model was trained on).
* @param aBufferSize The number of samples in the audio signal.
* @param aSampleRate UNUSED, DEPRECATED.
*
* @return Outputs a struct of individual letters along with their timing information.
* The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
*/
Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int /*aSampleRate*/)
{
return DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize);
}
#endif /* DEEPSPEECH_COMPAT_H */

View File

@ -82,8 +82,8 @@ namespace DeepSpeechClient
throw new ArgumentException("Invalid alphabet embedded in model. (Data corruption?)"); throw new ArgumentException("Invalid alphabet embedded in model. (Data corruption?)");
case ErrorCodes.DS_ERR_INVALID_SHAPE: case ErrorCodes.DS_ERR_INVALID_SHAPE:
throw new ArgumentException("Invalid model shape."); throw new ArgumentException("Invalid model shape.");
case ErrorCodes.DS_ERR_INVALID_LM: case ErrorCodes.DS_ERR_INVALID_SCORER:
throw new ArgumentException("Invalid language model file."); throw new ArgumentException("Invalid scorer file.");
case ErrorCodes.DS_ERR_FAIL_INIT_MMAP: case ErrorCodes.DS_ERR_FAIL_INIT_MMAP:
throw new ArgumentException("Failed to initialize memory mapped model."); throw new ArgumentException("Failed to initialize memory mapped model.");
case ErrorCodes.DS_ERR_FAIL_INIT_SESS: case ErrorCodes.DS_ERR_FAIL_INIT_SESS:
@ -100,6 +100,8 @@ namespace DeepSpeechClient
throw new ArgumentException("Error failed to create session."); throw new ArgumentException("Error failed to create session.");
case ErrorCodes.DS_ERR_MODEL_INCOMPATIBLE: case ErrorCodes.DS_ERR_MODEL_INCOMPATIBLE:
throw new ArgumentException("Error incompatible model."); throw new ArgumentException("Error incompatible model.");
case ErrorCodes.DS_ERR_SCORER_NOT_ENABLED:
throw new ArgumentException("External scorer is not enabled.");
default: default:
throw new ArgumentException("Unknown error, please make sure you are using the correct native binary."); throw new ArgumentException("Unknown error, please make sure you are using the correct native binary.");
} }
@ -114,45 +116,48 @@ namespace DeepSpeechClient
} }
/// <summary> /// <summary>
/// Enable decoding using beam scoring with a KenLM language model. /// Enable decoding using an external scorer.
/// </summary> /// </summary>
/// <param name="aLMPath">The path to the language model binary file.</param> /// <param name="aScorerPath">The path to the external scorer file.</param>
/// <param name="aTriePath">The path to the trie file build from the same vocabulary as the language model binary.</param> /// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with an external scorer.</exception>
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param> /// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception>
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param> public unsafe void EnableExternalScorer(string aScorerPath)
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
/// <exception cref="FileNotFoundException">Thrown when cannot find the language model or trie file.</exception>
public unsafe void EnableDecoderWithLM(string aLMPath, string aTriePath,
float aLMAlpha, float aLMBeta)
{ {
string exceptionMessage = null; string exceptionMessage = null;
if (string.IsNullOrWhiteSpace(aLMPath)) if (string.IsNullOrWhiteSpace(aScorerPath))
{ {
exceptionMessage = "Path to the language model file cannot be empty."; throw new FileNotFoundException("Path to the scorer file cannot be empty.");
} }
if (!File.Exists(aLMPath)) if (!File.Exists(aScorerPath))
{ {
exceptionMessage = $"Cannot find the language model file: {aLMPath}"; throw new FileNotFoundException($"Cannot find the scorer file: {aScorerPath}");
}
if (string.IsNullOrWhiteSpace(aTriePath))
{
exceptionMessage = "Path to the trie file cannot be empty.";
}
if (!File.Exists(aTriePath))
{
exceptionMessage = $"Cannot find the trie file: {aTriePath}";
} }
if (exceptionMessage != null) var resultCode = NativeImp.DS_EnableExternalScorer(_modelStatePP, aScorerPath);
{ EvaluateResultCode(resultCode);
throw new FileNotFoundException(exceptionMessage); }
}
var resultCode = NativeImp.DS_EnableDecoderWithLM(_modelStatePP, /// <summary>
aLMPath, /// Disable decoding using an external scorer.
aTriePath, /// </summary>
aLMAlpha, /// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
aLMBeta); public unsafe void DisableExternalScorer()
{
var resultCode = NativeImp.DS_DisableExternalScorer(_modelStatePP);
EvaluateResultCode(resultCode);
}
/// <summary>
/// Set hyperparameters alpha and beta of the external scorer.
/// </summary>
/// <param name="aAlpha">The alpha hyperparameter of the decoder. Language model weight.</param>
/// <param name="aBeta">The beta hyperparameter of the decoder. Word insertion weight.</param>
/// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
public unsafe void SetScorerAlphaBeta(float aAlpha, float aBeta)
{
var resultCode = NativeImp.DS_SetScorerAlphaBeta(_modelStatePP,
aAlpha,
aBeta);
EvaluateResultCode(resultCode); EvaluateResultCode(resultCode);
} }

View File

@ -14,8 +14,9 @@
// Invalid parameters // Invalid parameters
DS_ERR_INVALID_ALPHABET = 0x2000, DS_ERR_INVALID_ALPHABET = 0x2000,
DS_ERR_INVALID_SHAPE = 0x2001, DS_ERR_INVALID_SHAPE = 0x2001,
DS_ERR_INVALID_LM = 0x2002, DS_ERR_INVALID_SCORER = 0x2002,
DS_ERR_MODEL_INCOMPATIBLE = 0x2003, DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
// Runtime failures // Runtime failures
DS_ERR_FAIL_INIT_MMAP = 0x3000, DS_ERR_FAIL_INIT_MMAP = 0x3000,

View File

@ -21,18 +21,26 @@ namespace DeepSpeechClient.Interfaces
unsafe int GetModelSampleRate(); unsafe int GetModelSampleRate();
/// <summary> /// <summary>
/// Enable decoding using beam scoring with a KenLM language model. /// Enable decoding using an external scorer.
/// </summary> /// </summary>
/// <param name="aLMPath">The path to the language model binary file.</param> /// <param name="aScorerPath">The path to the external scorer file.</param>
/// <param name="aTriePath">The path to the trie file build from the same vocabulary as the language model binary.</param> /// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with an external scorer.</exception>
/// <param name="aLMAlpha">The alpha hyperparameter of the CTC decoder. Language Model weight.</param> /// <exception cref="FileNotFoundException">Thrown when cannot find the scorer file.</exception>
/// <param name="aLMBeta">The beta hyperparameter of the CTC decoder. Word insertion weight.</param> unsafe void EnableExternalScorer(string aScorerPath);
/// <exception cref="ArgumentException">Thrown when the native binary failed to enable decoding with a language model.</exception>
/// <exception cref="FileNotFoundException">Thrown when cannot find the language model or trie file.</exception> /// <summary>
unsafe void EnableDecoderWithLM(string aLMPath, /// Disable decoding using an external scorer.
string aTriePath, /// </summary>
float aLMAlpha, /// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
float aLMBeta); unsafe void DisableExternalScorer();
/// <summary>
/// Set hyperparameters alpha and beta of the external scorer.
/// </summary>
/// <param name="aAlpha">The alpha hyperparameter of the decoder. Language model weight.</param>
/// <param name="aBeta">The beta hyperparameter of the decoder. Word insertion weight.</param>
/// <exception cref="ArgumentException">Thrown when an external scorer is not enabled.</exception>
unsafe void SetScorerAlphaBeta(float aAlpha, float aBeta);
/// <summary> /// <summary>
/// Use the DeepSpeech model to perform Speech-To-Text. /// Use the DeepSpeech model to perform Speech-To-Text.

View File

@ -23,11 +23,16 @@ namespace DeepSpeechClient
internal unsafe static extern int DS_GetModelSampleRate(IntPtr** aCtx); internal unsafe static extern int DS_GetModelSampleRate(IntPtr** aCtx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_EnableDecoderWithLM(IntPtr** aCtx, internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx,
string aLMPath, string aScorerPath);
string aTriePath,
float aLMAlpha, [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
float aLMBeta); internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)]
internal static unsafe extern ErrorCodes DS_SetScorerAlphaBeta(IntPtr** aCtx,
float aAlpha,
float aBeta);
[DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl,
CharSet = CharSet.Ansi, SetLastError = true)] CharSet = CharSet.Ansi, SetLastError = true)]

View File

@ -35,22 +35,18 @@ namespace CSharpExamples
static void Main(string[] args) static void Main(string[] args)
{ {
string model = null; string model = null;
string lm = null; string scorer = null;
string trie = null;
string audio = null; string audio = null;
bool extended = false; bool extended = false;
if (args.Length > 0) if (args.Length > 0)
{ {
model = GetArgument(args, "--model"); model = GetArgument(args, "--model");
lm = GetArgument(args, "--lm"); scorer = GetArgument(args, "--scorer");
trie = GetArgument(args, "--trie");
audio = GetArgument(args, "--audio"); audio = GetArgument(args, "--audio");
extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended")); extended = !string.IsNullOrWhiteSpace(GetArgument(args, "--extended"));
} }
const uint BEAM_WIDTH = 500; const uint BEAM_WIDTH = 500;
const float LM_ALPHA = 0.75f;
const float LM_BETA = 1.85f;
Stopwatch stopwatch = new Stopwatch(); Stopwatch stopwatch = new Stopwatch();
try try
@ -64,14 +60,10 @@ namespace CSharpExamples
Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms"); Console.WriteLine($"Model loaded - {stopwatch.Elapsed.Milliseconds} ms");
stopwatch.Reset(); stopwatch.Reset();
if (lm != null) if (scorer != null)
{ {
Console.WriteLine("Loadin LM..."); Console.WriteLine("Loading scorer...");
sttClient.EnableDecoderWithLM( sttClient.EnableExternalScorer(scorer ?? "kenlm.scorer");
lm ?? "lm.binary",
trie ?? "trie",
LM_ALPHA, LM_BETA);
} }
string audioFile = audio ?? "arctic_a0024.wav"; string audioFile = audio ?? "arctic_a0024.wav";

View File

@ -0,0 +1,50 @@
#include <string>
#include <vector>
#include <iostream>
#include <fstream>
#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "lm/model.hh"
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";
// Implement a callback to retrieve the dictionary of language model.
class RetrieveStrEnumerateVocab : public lm::EnumerateVocab
{
public:
RetrieveStrEnumerateVocab() {}
void Add(lm::WordIndex index, const StringPiece &str) {
vocabulary.push_back(std::string(str.data(), str.length()));
}
std::vector<std::string> vocabulary;
};
int main(int argc, char** argv)
{
if (argc != 3) {
std::cerr << "Usage: " << argv[0] << " <kenlm_model> <output_path>" << std::endl;
return -1;
}
const char* kenlm_model = argv[1];
const char* output_path = argv[2];
std::unique_ptr<lm::base::Model> language_model_;
lm::ngram::Config config;
RetrieveStrEnumerateVocab enumerate;
config.enumerate_vocab = &enumerate;
language_model_.reset(lm::ngram::LoadVirtual(kenlm_model, config));
std::ofstream fout(output_path);
for (const std::string& word : enumerate.vocabulary) {
fout << word << "\n";
}
return 0;
}

View File

@ -1,32 +0,0 @@
#include <algorithm>
#include <iostream>
#include <string>
#include "ctcdecode/scorer.h"
#include "alphabet.h"
using namespace std;
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
Alphabet alphabet;
int err = alphabet.init(alphabet_path);
if (err != 0) {
return err;
}
Scorer scorer;
err = scorer.init(0.0, 0.0, kenlm_path, "", alphabet);
if (err != 0) {
return err;
}
scorer.save_dictionary(trie_path);
return 0;
}
int main(int argc, char** argv) {
if (argc != 4) {
std::cerr << "Usage: " << argv[0] << " <alphabet> <lm_model> <trie_path>" << std::endl;
return -1;
}
return generate_trie(argv[1], argv[2], argv[3]);
}

View File

@ -51,12 +51,11 @@ Please push DeepSpeech data to ``/sdcard/deepspeech/``\ , including:
* ``output_graph.tflite`` which is the TF Lite model * ``output_graph.tflite`` which is the TF Lite model
* ``lm.binary`` and ``trie`` files, if you want to use the language model ; please * ``kenlm.scorer``, if you want to use the scorer; please be aware that too big
be aware that too big language model will make the device run out of memory scorer will make the device run out of memory
Then, push binaries from ``native_client.tar.xz`` to ``/data/local/tmp/ds``\ : Then, push binaries from ``native_client.tar.xz`` to ``/data/local/tmp/ds``\ :
* ``deepspeech`` * ``deepspeech``
* ``libdeepspeech.so`` * ``libdeepspeech.so``
* ``libc++_shared.so`` * ``libc++_shared.so``

View File

@ -31,8 +31,6 @@ public class DeepSpeechActivity extends AppCompatActivity {
Button _startInference; Button _startInference;
final int BEAM_WIDTH = 50; final int BEAM_WIDTH = 50;
final float LM_ALPHA = 0.75f;
final float LM_BETA = 1.85f;
private char readLEChar(RandomAccessFile f) throws IOException { private char readLEChar(RandomAccessFile f) throws IOException {
byte b1 = f.readByte(); byte b1 = f.readByte();

View File

@ -30,15 +30,11 @@ import java.nio.ByteBuffer;
public class BasicTest { public class BasicTest {
public static final String modelFile = "/data/local/tmp/test/output_graph.tflite"; public static final String modelFile = "/data/local/tmp/test/output_graph.tflite";
public static final String lmFile = "/data/local/tmp/test/lm.binary"; public static final String scorerFile = "/data/local/tmp/test/kenlm.scorer";
public static final String trieFile = "/data/local/tmp/test/trie";
public static final String wavFile = "/data/local/tmp/test/LDC93S1.wav"; public static final String wavFile = "/data/local/tmp/test/LDC93S1.wav";
public static final int BEAM_WIDTH = 50; public static final int BEAM_WIDTH = 50;
public static final float LM_ALPHA = 0.75f;
public static final float LM_BETA = 1.85f;
private char readLEChar(RandomAccessFile f) throws IOException { private char readLEChar(RandomAccessFile f) throws IOException {
byte b1 = f.readByte(); byte b1 = f.readByte();
byte b2 = f.readByte(); byte b2 = f.readByte();
@ -130,7 +126,7 @@ public class BasicTest {
@Test @Test
public void loadDeepSpeech_stt_withLM() { public void loadDeepSpeech_stt_withLM() {
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
m.enableDecoderWithLM(lmFile, trieFile, LM_ALPHA, LM_BETA); m.enableExternalScorer(scorerFile);
String decoded = doSTT(m, false); String decoded = doSTT(m, false);
assertEquals("she had your dark suit in greasy wash water all year", decoded); assertEquals("she had your dark suit in greasy wash water all year", decoded);
@ -149,7 +145,7 @@ public class BasicTest {
@Test @Test
public void loadDeepSpeech_sttWithMetadata_withLM() { public void loadDeepSpeech_sttWithMetadata_withLM() {
DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH); DeepSpeechModel m = new DeepSpeechModel(modelFile, BEAM_WIDTH);
m.enableDecoderWithLM(lmFile, trieFile, LM_ALPHA, LM_BETA); m.enableExternalScorer(scorerFile);
String decoded = doSTT(m, true); String decoded = doSTT(m, true);
assertEquals("she had your dark suit in greasy wash water all year", decoded); assertEquals("she had your dark suit in greasy wash water all year", decoded);

View File

@ -47,17 +47,35 @@ public class DeepSpeechModel {
} }
/** /**
* @brief Enable decoding using beam scoring with a KenLM language model. * @brief Enable decoding using an external scorer.
* *
* @param lm The path to the language model binary file. * @param scorer The path to the external scorer file.
* @param trie The path to the trie file build from the same vocabulary as the language model binary.
* @param lm_alpha The alpha hyperparameter of the CTC decoder. Language Model weight.
* @param lm_beta The beta hyperparameter of the CTC decoder. Word insertion weight.
* *
* @return Zero on success, non-zero on failure (invalid arguments). * @return Zero on success, non-zero on failure (invalid arguments).
*/ */
public void enableDecoderWithLM(String lm, String trie, float lm_alpha, float lm_beta) { public void enableExternalScorer(String scorer) {
impl.EnableDecoderWithLM(this._msp, lm, trie, lm_alpha, lm_beta); impl.EnableExternalScorer(this._msp, scorer);
}
/**
* @brief Disable decoding using an external scorer.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
public void disableExternalScorer() {
impl.DisableExternalScorer(this._msp);
}
/**
* @brief Enable decoding using beam scoring with a KenLM language model.
*
* @param alpha The alpha hyperparameter of the decoder. Language model weight.
* @param beta The beta hyperparameter of the decoder. Word insertion weight.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
public void setScorerAlphaBeta(float alpha, float beta) {
impl.SetScorerAlphaBeta(this._msp, alpha, beta);
} }
/* /*

View File

@ -29,12 +29,11 @@ VersionAction.prototype.call = function(parser) {
var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running DeepSpeech inference.'}); var parser = new argparse.ArgumentParser({addHelp: true, description: 'Running DeepSpeech inference.'});
parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'}); parser.addArgument(['--model'], {required: true, help: 'Path to the model (protocol buffer binary file)'});
parser.addArgument(['--lm'], {help: 'Path to the language model binary file', nargs: '?'}); parser.addArgument(['--scorer'], {help: 'Path to the external scorer file'});
parser.addArgument(['--trie'], {help: 'Path to the language model trie file created with native_client/generate_trie', nargs: '?'});
parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'}); parser.addArgument(['--audio'], {required: true, help: 'Path to the audio file to run (WAV format)'});
parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'}); parser.addArgument(['--beam_width'], {help: 'Beam width for the CTC decoder', defaultValue: 500, type: 'int'});
parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha)', defaultValue: 0.75, type: 'float'}); parser.addArgument(['--lm_alpha'], {help: 'Language model weight (lm_alpha). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta)', defaultValue: 1.85, type: 'float'}); parser.addArgument(['--lm_beta'], {help: 'Word insertion bonus (lm_beta). If not specified, use default from the scorer package.', type: 'float'});
parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'}); parser.addArgument(['--version'], {action: VersionAction, help: 'Print version and exits'});
parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'}); parser.addArgument(['--extended'], {action: 'storeTrue', help: 'Output string from extended metadata'});
var args = parser.parseArgs(); var args = parser.parseArgs();
@ -60,12 +59,16 @@ console.error('Loaded model in %ds.', totalTime(model_load_end));
var desired_sample_rate = model.sampleRate(); var desired_sample_rate = model.sampleRate();
if (args['lm'] && args['trie']) { if (args['scorer']) {
console.error('Loading language model from files %s %s', args['lm'], args['trie']); console.error('Loading scorer from file %s', args['scorer']);
const lm_load_start = process.hrtime(); const scorer_load_start = process.hrtime();
model.enableDecoderWithLM(args['lm'], args['trie'], args['lm_alpha'], args['lm_beta']); model.enableExternalScorer(args['scorer']);
const lm_load_end = process.hrtime(lm_load_start); const scorer_load_end = process.hrtime(scorer_load_start);
console.error('Loaded language model in %ds.', totalTime(lm_load_end)); console.error('Loaded scorer in %ds.', totalTime(scorer_load_end));
if (args['lm_alpha'] && args['lm_beta']) {
model.setScorerAlphaBeta(args['lm_alpha'], args['lm_beta']);
}
} }
const buffer = Fs.readFileSync(args['audio']); const buffer = Fs.readFileSync(args['audio']);

View File

@ -52,31 +52,46 @@ Model.prototype.sampleRate = function() {
} }
/** /**
* Enable decoding using beam scoring with a KenLM language model. * Enable decoding using an external scorer.
*
* @param {string} aScorerPath The path to the external scorer file.
*
* @return {number} Zero on success, non-zero on failure (invalid arguments).
*/
Model.prototype.enableExternalScorer = function(aScorerPath) {
return binding.EnableExternalScorer(this._impl, aScorerPath);
}
/**
* Disable decoding using an external scorer.
*
* @return {number} Zero on success, non-zero on failure (invalid arguments).
*/
Model.prototype.disableExternalScorer = function() {
return binding.EnableExternalScorer(this._impl);
}
/**
* Set hyperparameters alpha and beta of the external scorer.
* *
* @param {string} aLMPath The path to the language model binary file.
* @param {string} aTriePath The path to the trie file build from the same vocabulary as the language model binary.
* @param {float} aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model weight. * @param {float} aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model weight.
* @param {float} aLMBeta The beta hyperparameter of the CTC decoder. Word insertion weight. * @param {float} aLMBeta The beta hyperparameter of the CTC decoder. Word insertion weight.
* *
* @return {number} Zero on success, non-zero on failure (invalid arguments). * @return {number} Zero on success, non-zero on failure (invalid arguments).
*/ */
Model.prototype.enableDecoderWithLM = function() { Model.prototype.setScorerAlphaBeta = function(aLMAlpha, aLMBeta) {
const args = [this._impl].concat(Array.prototype.slice.call(arguments)); return binding.SetScorerAlphaBeta(this._impl, aLMAlpha, aLMBeta);
return binding.EnableDecoderWithLM.apply(null, args);
} }
/** /**
* Use the DeepSpeech model to perform Speech-To-Text. * Use the DeepSpeech model to perform Speech-To-Text.
* *
* @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). * @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
* @param {number} aBufferSize The number of samples in the audio signal.
* *
* @return {string} The STT result. Returns undefined on error. * @return {string} The STT result. Returns undefined on error.
*/ */
Model.prototype.stt = function() { Model.prototype.stt = function(aBuffer) {
const args = [this._impl].concat(Array.prototype.slice.call(arguments)); return binding.SpeechToText(this._impl, aBuffer);
return binding.SpeechToText.apply(null, args);
} }
/** /**
@ -84,25 +99,22 @@ Model.prototype.stt = function() {
* about the results. * about the results.
* *
* @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). * @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
* @param {number} aBufferSize The number of samples in the audio signal.
* *
* @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error. * @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error.
*/ */
Model.prototype.sttWithMetadata = function() { Model.prototype.sttWithMetadata = function(aBuffer) {
const args = [this._impl].concat(Array.prototype.slice.call(arguments)); return binding.SpeechToTextWithMetadata(this._impl, aBuffer);
return binding.SpeechToTextWithMetadata.apply(null, args);
} }
/** /**
* Create a new streaming inference state. The streaming state returned by this function can then be passed to :js:func:`Model.feedAudioContent` and :js:func:`Model.finishStream`. * Create a new streaming inference state. One can then call :js:func:`Stream.feedAudioContent` and :js:func:`Stream.finishStream` on the returned stream object.
* *
* @return {object} an opaque object that represents the streaming state. * @return {object} a :js:func:`Stream` object that represents the streaming state.
* *
* @throws on error * @throws on error
*/ */
Model.prototype.createStream = function() { Model.prototype.createStream = function() {
const args = [this._impl].concat(Array.prototype.slice.call(arguments)); const rets = binding.CreateStream(this._impl);
const rets = binding.CreateStream.apply(null, args);
const status = rets[0]; const status = rets[0];
const ctx = rets[1]; const ctx = rets[1];
if (status !== 0) { if (status !== 0) {
@ -111,55 +123,61 @@ Model.prototype.createStream = function() {
return ctx; return ctx;
} }
/**
* @class
* Provides an interface to a DeepSpeech stream. The constructor cannot be called
* directly, use :js:func:`Model.createStream`.
*/
function Stream(nativeStream) {
this._impl = nativeStream;
}
/** /**
* Feed audio samples to an ongoing streaming inference. * Feed audio samples to an ongoing streaming inference.
* *
* @param {object} aSctx A streaming state returned by :js:func:`Model.setupStream`.
* @param {buffer} aBuffer An array of 16-bit, mono raw audio samples at the * @param {buffer} aBuffer An array of 16-bit, mono raw audio samples at the
* appropriate sample rate (matching what the model was trained on). * appropriate sample rate (matching what the model was trained on).
* @param {number} aBufferSize The number of samples in @param aBuffer.
*/ */
Model.prototype.feedAudioContent = function() { Stream.prototype.feedAudioContent = function(aBuffer) {
binding.FeedAudioContent.apply(null, arguments); binding.FeedAudioContent(this._impl, aBuffer);
} }
/** /**
* Compute the intermediate decoding of an ongoing streaming inference. * Compute the intermediate decoding of an ongoing streaming inference.
* *
* @param {object} aSctx A streaming state returned by :js:func:`Model.setupStream`.
*
* @return {string} The STT intermediate result. * @return {string} The STT intermediate result.
*/ */
Model.prototype.intermediateDecode = function() { Stream.prototype.intermediateDecode = function() {
return binding.IntermediateDecode.apply(null, arguments); return binding.IntermediateDecode(this._impl);
} }
/** /**
* Signal the end of an audio signal to an ongoing streaming inference, returns the STT result over the whole audio signal. * Signal the end of an audio signal to an ongoing streaming inference, returns the STT result over the whole audio signal.
* *
* @param {object} aSctx A streaming state returned by :js:func:`Model.setupStream`.
*
* @return {string} The STT result. * @return {string} The STT result.
* *
* This method will free the state (@param aSctx). * This method will free the stream, it must not be used after this method is called.
*/ */
Model.prototype.finishStream = function() { Stream.prototype.finishStream = function() {
return binding.FinishStream.apply(null, arguments); result = binding.FinishStream(this._impl);
this._impl = null;
return result;
} }
/** /**
* Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata. * Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata.
* *
* @param {object} aSctx A streaming state pointer returned by :js:func:`Model.setupStream`.
*
* @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. * @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`.
* *
* This method will free the state pointer (@param aSctx). * This method will free the stream, it must not be used after this method is called.
*/ */
Model.prototype.finishStreamWithMetadata = function() { Stream.prototype.finishStreamWithMetadata = function() {
return binding.FinishStreamWithMetadata.apply(null, arguments); result = binding.FinishStreamWithMetadata(this._impl);
this._impl = null;
return result;
} }
/** /**
* Frees associated resources and destroys model object. * Frees associated resources and destroys model object.
* *
@ -184,10 +202,10 @@ function FreeMetadata(metadata) {
* can be used if you no longer need the result of an ongoing streaming * can be used if you no longer need the result of an ongoing streaming
* inference and don't want to perform a costly decode operation. * inference and don't want to perform a costly decode operation.
* *
* @param {Object} stream A streaming state pointer returned by :js:func:`Model.createStream`. * @param {Object} stream A stream object returned by :js:func:`Model.createStream`.
*/ */
function FreeStream(stream) { function FreeStream(stream) {
return binding.FreeStream(stream); return binding.FreeStream(stream._impl);
} }
/** /**

View File

@ -3,6 +3,9 @@ util/file_piece.cc.gz
*.o *.o
doc/ doc/
build/ build/
/bin
/lib
/tests
._* ._*
windows/Win32 windows/Win32
windows/x64 windows/x64

View File

@ -12,3 +12,7 @@ If you only want the query code and do not care about compression (.gz, .bz2, an
Windows: Windows:
The windows directory has visual studio files. Note that you need to compile The windows directory has visual studio files. Note that you need to compile
the kenlm project before build_binary and ngram_query projects. the kenlm project before build_binary and ngram_query projects.
OSX:
Missing dependencies can be remedied with brew.
brew install cmake boost eigen

View File

@ -1 +1 @@
cdd794598ea15dc23a7daaf7a8cf89423c97f7e6 b9f35777d112ce2fc10bd3986302517a16dc3883

View File

@ -2,9 +2,9 @@
Language model inference code by Kenneth Heafield (kenlm at kheafield.com) Language model inference code by Kenneth Heafield (kenlm at kheafield.com)
I do development in master on https://github.com/kpu/kenlm/. Normally, it works, but I do not guarantee it will compile, give correct answers, or generate non-broken binary files. For a more stable release, get http://kheafield.com/code/kenlm.tar.gz . I do development in master on https://github.com/kpu/kenlm/. Normally, it works, but I do not guarantee it will compile, give correct answers, or generate non-broken binary files. For a more stable release, get https://kheafield.com/code/kenlm.tar.gz .
The website http://kheafield.com/code/kenlm/ has more documentation. If you're a decoder developer, please download the latest version from there instead of copying from another decoder. The website https://kheafield.com/code/kenlm/ has more documentation. If you're a decoder developer, please download the latest version from there instead of copying from another decoder.
## Compiling ## Compiling
Use cmake, see [BUILDING](BUILDING) for more detail. Use cmake, see [BUILDING](BUILDING) for more detail.
@ -33,7 +33,7 @@ lmplz estimates unpruned language models with modified Kneser-Ney smoothing. Af
```bash ```bash
bin/lmplz -o 5 <text >text.arpa bin/lmplz -o 5 <text >text.arpa
``` ```
The algorithm is on-disk, using an amount of memory that you specify. See http://kheafield.com/code/kenlm/estimation/ for more. The algorithm is on-disk, using an amount of memory that you specify. See https://kheafield.com/code/kenlm/estimation/ for more.
MT Marathon 2012 team members Ivan Pouzyrevsky and Mohammed Mediani contributed to the computation design and early implementation. Jon Clark contributed to the design, clarified points about smoothing, and added logging. MT Marathon 2012 team members Ivan Pouzyrevsky and Mohammed Mediani contributed to the computation design and early implementation. Jon Clark contributed to the design, clarified points about smoothing, and added logging.
@ -43,15 +43,15 @@ filter takes an ARPA or count file and removes entries that will never be querie
```bash ```bash
bin/filter bin/filter
``` ```
and see http://kheafield.com/code/kenlm/filter/ for more documentation. and see https://kheafield.com/code/kenlm/filter/ for more documentation.
## Querying ## Querying
Two data structures are supported: probing and trie. Probing is a probing hash table with keys that are 64-bit hashes of n-grams and floats as values. Trie is a fairly standard trie but with bit-level packing so it uses the minimum number of bits to store word indices and pointers. The trie node entries are sorted by word index. Probing is the fastest and uses the most memory. Trie uses the least memory and a bit slower. Two data structures are supported: probing and trie. Probing is a probing hash table with keys that are 64-bit hashes of n-grams and floats as values. Trie is a fairly standard trie but with bit-level packing so it uses the minimum number of bits to store word indices and pointers. The trie node entries are sorted by word index. Probing is the fastest and uses the most memory. Trie uses the least memory and is a bit slower.
As is the custom in language modeling, all probabilities are log base 10. As is the custom in language modeling, all probabilities are log base 10.
With trie, resident memory is 58% of IRST's smallest version and 21% of SRI's compact version. Simultaneously, trie CPU's use is 81% of IRST's fastest version and 84% of SRI's fast version. KenLM's probing hash table implementation goes even faster at the expense of using more memory. See http://kheafield.com/code/kenlm/benchmark/. With trie, resident memory is 58% of IRST's smallest version and 21% of SRI's compact version. Simultaneously, trie CPU's use is 81% of IRST's fastest version and 84% of SRI's fast version. KenLM's probing hash table implementation goes even faster at the expense of using more memory. See https://kheafield.com/code/kenlm/benchmark/.
Binary format via mmap is supported. Run `./build_binary` to make one then pass the binary file name to the appropriate Model constructor. Binary format via mmap is supported. Run `./build_binary` to make one then pass the binary file name to the appropriate Model constructor.
@ -71,7 +71,7 @@ Hideo Okuma and Tomoyuki Yoshimura from NICT contributed ports to ARM and MinGW.
- Select the macros you want, listed in the previous section. - Select the macros you want, listed in the previous section.
- There are two build systems: compile.sh and Jamroot+Jamfile. They're pretty simple and are intended to be reimplemented in your build system. - There are two build systems: compile.sh and cmake. They're pretty simple and are intended to be reimplemented in your build system.
- Use either the interface in `lm/model.hh` or `lm/virtual_interface.hh`. Interface documentation is in comments of `lm/virtual_interface.hh` and `lm/model.hh`. - Use either the interface in `lm/model.hh` or `lm/virtual_interface.hh`. Interface documentation is in comments of `lm/virtual_interface.hh` and `lm/model.hh`.
@ -101,4 +101,4 @@ See [python/example.py](python/example.py) and [python/kenlm.pyx](python/kenlm.p
--- ---
The name was Hieu Hoang's idea, not mine. The name was Hieu Hoang's idea, not mine.

View File

@ -1,7 +1,7 @@
KenLM source downloaded from http://kheafield.com/code/kenlm.tar.gz on 2017/08/05 KenLM source downloaded from https://github.com/kpu/kenlm on 2020/01/15
sha256 c4c9f587048470c9a6a592914f0609a71fbb959f0a4cad371e8c355ce81f7c6b commit b9f35777d112ce2fc10bd3986302517a16dc3883
This corresponds to https://github.com/kpu/kenlm/commit/cdd794598ea15dc23a7daaf7a8cf89423c97f7e6 This corresponds to https://github.com/kpu/kenlm/commit/b9f35777d112ce2fc10bd3986302517a16dc3883
The following procedure was run to remove unneeded files: The following procedure was run to remove unneeded files:
@ -10,19 +10,3 @@ rm -rf windows include lm/filter lm/builder util/stream util/getopt.* python
This was done in order to ensure uniqueness of double_conversion: This was done in order to ensure uniqueness of double_conversion:
git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g' git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g'
Please apply this patch to be able to build on Android:
diff --git a/native_client/kenlm/util/file.cc b/native_client/kenlm/util/file.cc
index d53dc0a..b5e36b2 100644
--- a/native_client/kenlm/util/file.cc
+++ b/native_client/kenlm/util/file.cc
@@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
for (int i=0; vars[i]; ++i) {
char *val =
-#if defined(_GNU_SOURCE)
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
#if __GLIBC_PREREQ(2,17)
secure_getenv
#else // __GLIBC_PREREQ

View File

@ -10,7 +10,6 @@
#include <iomanip> #include <iomanip>
#include <limits> #include <limits>
#include <cmath> #include <cmath>
#include <cstdlib>
#ifdef WIN32 #ifdef WIN32
#include "util/getopt.hh" #include "util/getopt.hh"
@ -23,11 +22,12 @@ namespace ngram {
namespace { namespace {
void Usage(const char *name, const char *default_mem) { void Usage(const char *name, const char *default_mem) {
std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-v] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n"
"-u sets the log10 probability for <unk> if the ARPA file does not have one.\n" "-u sets the log10 probability for <unk> if the ARPA file does not have one.\n"
" Default is -100. The ARPA file will always take precedence.\n" " Default is -100. The ARPA file will always take precedence.\n"
"-s allows models to be built even if they do not have <s> and </s>.\n" "-s allows models to be built even if they do not have <s> and </s>.\n"
"-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n" "-i allows buggy models from IRSTLM by mapping positive log probability to 0.\n"
"-v disables inclusion of the vocabulary in the binary file.\n"
"-w mmap|after determines how writing is done.\n" "-w mmap|after determines how writing is done.\n"
" mmap maps the binary file and writes to it. Default for trie.\n" " mmap maps the binary file and writes to it. Default for trie.\n"
" after allocates anonymous memory, builds, and writes. Default for probing.\n" " after allocates anonymous memory, builds, and writes. Default for probing.\n"
@ -112,7 +112,7 @@ int main(int argc, char *argv[]) {
lm::ngram::Config config; lm::ngram::Config config;
config.building_memory = util::ParseSize(default_mem); config.building_memory = util::ParseSize(default_mem);
int opt; int opt;
while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:h")) != -1) { while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:vh")) != -1) {
switch(opt) { switch(opt) {
case 'q': case 'q':
config.prob_bits = ParseBitCount(optarg); config.prob_bits = ParseBitCount(optarg);
@ -165,6 +165,9 @@ int main(int argc, char *argv[]) {
ParseFileList(optarg, config.rest_lower_files); ParseFileList(optarg, config.rest_lower_files);
config.rest_function = Config::REST_LOWER; config.rest_function = Config::REST_LOWER;
break; break;
case 'v':
config.include_vocab = false;
break;
case 'h': // help case 'h': // help
default: default:
Usage(argv[0], default_mem); Usage(argv[0], default_mem);

View File

@ -7,7 +7,7 @@
* sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead * sizeof(float*) + (kMaxOrder - 1) * sizeof(float) + malloc overhead
*/ */
#ifndef KENLM_ORDER_MESSAGE #ifndef KENLM_ORDER_MESSAGE
#define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. With cmake:\n cmake -DKENLM_MAX_ORDER=10 ..\nWith Moses:\n bjam --max-kenlm-order=10 -a\nOtherwise, edit lm/max_order.hh."
#endif #endif
#endif // LM_MAX_ORDER_H #endif // LM_MAX_ORDER_H

View File

@ -226,6 +226,10 @@ template <class Search, class VocabularyT> FullScoreReturn GenericModel<Search,
return ret; return ret;
} }
template <class Search, class VocabularyT> uint64_t GenericModel<Search, VocabularyT>::GetEndOfSearchOffset() const {
return backing_.VocabStringReadingOffset();
}
namespace { namespace {
// Do a paraonoid copy of history, assuming new_word has already been copied // Do a paraonoid copy of history, assuming new_word has already been copied
// (hence the -1). out_state.length could be zero so I avoided using // (hence the -1). out_state.length could be zero so I avoided using

View File

@ -102,6 +102,8 @@ template <class Search, class VocabularyT> class GenericModel : public base::Mod
return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0; return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0;
} }
uint64_t GetEndOfSearchOffset() const;
private: private:
FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const; FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;

View File

@ -19,8 +19,8 @@ void Usage(const char *name) {
"Each word in the output is formatted as:\n" "Each word in the output is formatted as:\n"
" word=vocab_id ngram_length log10(p(word|context))\n" " word=vocab_id ngram_length log10(p(word|context))\n"
"where ngram_length is the length of n-gram matched. A vocab_id of 0 indicates\n" "where ngram_length is the length of n-gram matched. A vocab_id of 0 indicates\n"
"indicates the unknown word. Sentence-level output includes log10 probability of\n" "the unknown word. Sentence-level output includes log10 probability of the\n"
"the sentence and OOV count.\n"; "sentence and OOV count.\n";
exit(1); exit(1);
} }

View File

@ -19,8 +19,8 @@
namespace lm { namespace lm {
// 1 for '\t', '\n', and ' '. This is stricter than isspace. // 1 for '\t', '\n', '\r', and ' '. This is stricter than isspace. Apparently ARPA allows vertical tab inside a word.
const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; const bool kARPASpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
namespace { namespace {
@ -85,6 +85,11 @@ void ReadNGramHeader(util::FilePiece &in, unsigned int length) {
if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead"); if (line != expected.str()) UTIL_THROW(FormatLoadException, "Was expecting n-gram header " << expected.str() << " but got " << line << " instead");
} }
void ConsumeNewline(util::FilePiece &in) {
char follow = in.get();
UTIL_THROW_IF('\n' != follow, FormatLoadException, "Expected newline got '" << follow << "'");
}
void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) { void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
switch (in.get()) { switch (in.get()) {
case '\t': case '\t':
@ -94,6 +99,9 @@ void ReadBackoff(util::FilePiece &in, Prob &/*weights*/) {
UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff"); UTIL_THROW(FormatLoadException, "Non-zero backoff " << got << " provided for an n-gram that should have no backoff");
} }
break; break;
case '\r':
ConsumeNewline(in);
// Intentionally no break.
case '\n': case '\n':
break; break;
default: default:
@ -120,8 +128,18 @@ void ReadBackoff(util::FilePiece &in, float &backoff) {
UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff); UTIL_THROW_IF(float_class == FP_NAN || float_class == FP_INFINITE, FormatLoadException, "Bad backoff " << backoff);
#endif #endif
} }
UTIL_THROW_IF(in.get() != '\n', FormatLoadException, "Expected newline after backoff"); switch (char got = in.get()) {
case '\r':
ConsumeNewline(in);
case '\n':
break;
default:
UTIL_THROW(FormatLoadException, "Expected newline after backoffs, got " << got);
}
break; break;
case '\r':
ConsumeNewline(in);
// Intentionally no break.
case '\n': case '\n':
backoff = ngram::kNoExtensionBackoff; backoff = ngram::kNoExtensionBackoff;
break; break;

View File

@ -137,6 +137,8 @@ class Model {
const Vocabulary &BaseVocabulary() const { return *base_vocab_; } const Vocabulary &BaseVocabulary() const { return *base_vocab_; }
virtual uint64_t GetEndOfSearchOffset() const = 0;
private: private:
template <class T, class U, class V> friend class ModelFacade; template <class T, class U, class V> friend class ModelFacade;
explicit Model(size_t state_size) : state_size_(state_size) {} explicit Model(size_t state_size) : state_size_(state_size) {}

View File

@ -282,7 +282,7 @@ void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to
if (have_words) ReadWords(fd, to, bound_, offset); if (have_words) ReadWords(fd, to, bound_, offset);
} }
void MissingUnknown(const Config &config) throw(SpecialWordMissingException) { void MissingUnknown(const Config &config) {
switch(config.unknown_missing) { switch(config.unknown_missing) {
case SILENT: case SILENT:
return; return;
@ -294,7 +294,7 @@ void MissingUnknown(const Config &config) throw(SpecialWordMissingException) {
} }
} }
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException) { void MissingSentenceMarker(const Config &config, const char *str) {
switch (config.sentence_marker_missing) { switch (config.sentence_marker_missing) {
case SILENT: case SILENT:
return; return;

View File

@ -207,10 +207,10 @@ class ProbingVocabulary : public base::Vocabulary {
detail::ProbingVocabularyHeader *header_; detail::ProbingVocabularyHeader *header_;
}; };
void MissingUnknown(const Config &config) throw(SpecialWordMissingException); void MissingUnknown(const Config &config);
void MissingSentenceMarker(const Config &config, const char *str) throw(SpecialWordMissingException); void MissingSentenceMarker(const Config &config, const char *str);
template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) throw(SpecialWordMissingException) { template <class Vocab> void CheckSpecials(const Config &config, const Vocab &vocab) {
if (!vocab.SawUnk()) MissingUnknown(config); if (!vocab.SawUnk()) MissingUnknown(config);
if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>"); if (vocab.BeginSentence() == vocab.NotFound()) MissingSentenceMarker(config, "<s>");
if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>"); if (vocab.EndSentence() == vocab.NotFound()) MissingSentenceMarker(config, "</s>");

View File

@ -2,6 +2,8 @@ from setuptools import setup, Extension
import glob import glob
import platform import platform
import os import os
import sys
import re
#Does gcc compile with this header and library? #Does gcc compile with this header and library?
def compile_test(header, library): def compile_test(header, library):
@ -9,16 +11,28 @@ def compile_test(header, library):
command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\"" command = "bash -c \"g++ -include " + header + " -l" + library + " -x c++ - <<<'int main() {}' -o " + dummy_path + " >/dev/null 2>/dev/null && rm " + dummy_path + " 2>/dev/null\""
return os.system(command) == 0 return os.system(command) == 0
max_order = "6"
is_max_order = [s for s in sys.argv if "--max_order" in s]
for element in is_max_order:
max_order = re.split('[= ]',element)[1]
sys.argv.remove(element)
FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob('util/double-conversion/*.cc') FILES = glob.glob('util/*.cc') + glob.glob('lm/*.cc') + glob.glob('util/double-conversion/*.cc') + glob.glob('python/*.cc')
FILES = [fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))] FILES = [fn for fn in FILES if not (fn.endswith('main.cc') or fn.endswith('test.cc'))]
LIBS = ['stdc++'] if platform.system() == 'Linux':
if platform.system() != 'Darwin': LIBS = ['stdc++', 'rt']
LIBS.append('rt') elif platform.system() == 'Darwin':
LIBS = ['c++']
else:
LIBS = []
#We don't need -std=c++11 but python seems to be compiled with it now. https://github.com/kpu/kenlm/issues/86 #We don't need -std=c++11 but python seems to be compiled with it now. https://github.com/kpu/kenlm/issues/86
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11'] ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER='+max_order, '-std=c++11']
#Attempted fix to https://github.com/kpu/kenlm/issues/186 and https://github.com/kpu/kenlm/issues/197
if platform.system() == 'Darwin':
ARGS += ["-stdlib=libc++", "-mmacosx-version-min=10.7"]
if compile_test('zlib.h', 'z'): if compile_test('zlib.h', 'z'):
ARGS.append('-DHAVE_ZLIB') ARGS.append('-DHAVE_ZLIB')

View File

@ -108,7 +108,7 @@ typedef union { float f; uint32_t i; } FloatEnc;
inline float ReadFloat32(const void *base, uint64_t bit_off) { inline float ReadFloat32(const void *base, uint64_t bit_off) {
FloatEnc encoded; FloatEnc encoded;
encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32); encoded.i = static_cast<uint32_t>(ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 32));
return encoded.f; return encoded.f;
} }
inline void WriteFloat32(void *base, uint64_t bit_off, float value) { inline void WriteFloat32(void *base, uint64_t bit_off, float value) {
@ -135,7 +135,7 @@ inline void UnsetSign(float &to) {
inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) { inline float ReadNonPositiveFloat31(const void *base, uint64_t bit_off) {
FloatEnc encoded; FloatEnc encoded;
encoded.i = ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31); encoded.i = static_cast<uint32_t>(ReadOff(base, bit_off) >> BitPackShift(bit_off & 7, 31));
// Sign bit set means negative. // Sign bit set means negative.
encoded.i |= kSignBit; encoded.i |= kSignBit;
return encoded.f; return encoded.f;

View File

@ -25,7 +25,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cmath> #include <math.h>
#include "bignum-dtoa.h" #include "bignum-dtoa.h"
@ -192,13 +192,13 @@ static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator,
delta_plus = delta_minus; delta_plus = delta_minus;
} }
*length = 0; *length = 0;
while (true) { for (;;) {
uint16_t digit; uint16_t digit;
digit = numerator->DivideModuloIntBignum(*denominator); digit = numerator->DivideModuloIntBignum(*denominator);
ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive. ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive.
// digit = numerator / denominator (integer division). // digit = numerator / denominator (integer division).
// numerator = numerator % denominator. // numerator = numerator % denominator.
buffer[(*length)++] = digit + '0'; buffer[(*length)++] = static_cast<char>(digit + '0');
// Can we stop already? // Can we stop already?
// If the remainder of the division is less than the distance to the lower // If the remainder of the division is less than the distance to the lower
@ -282,7 +282,7 @@ static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator,
// exponent (decimal_point), when rounding upwards. // exponent (decimal_point), when rounding upwards.
static void GenerateCountedDigits(int count, int* decimal_point, static void GenerateCountedDigits(int count, int* decimal_point,
Bignum* numerator, Bignum* denominator, Bignum* numerator, Bignum* denominator,
Vector<char>(buffer), int* length) { Vector<char> buffer, int* length) {
ASSERT(count >= 0); ASSERT(count >= 0);
for (int i = 0; i < count - 1; ++i) { for (int i = 0; i < count - 1; ++i) {
uint16_t digit; uint16_t digit;
@ -290,7 +290,7 @@ static void GenerateCountedDigits(int count, int* decimal_point,
ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive. ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive.
// digit = numerator / denominator (integer division). // digit = numerator / denominator (integer division).
// numerator = numerator % denominator. // numerator = numerator % denominator.
buffer[i] = digit + '0'; buffer[i] = static_cast<char>(digit + '0');
// Prepare for next iteration. // Prepare for next iteration.
numerator->Times10(); numerator->Times10();
} }
@ -300,7 +300,8 @@ static void GenerateCountedDigits(int count, int* decimal_point,
if (Bignum::PlusCompare(*numerator, *numerator, *denominator) >= 0) { if (Bignum::PlusCompare(*numerator, *numerator, *denominator) >= 0) {
digit++; digit++;
} }
buffer[count - 1] = digit + '0'; ASSERT(digit <= 10);
buffer[count - 1] = static_cast<char>(digit + '0');
// Correct bad digits (in case we had a sequence of '9's). Propagate the // Correct bad digits (in case we had a sequence of '9's). Propagate the
// carry until we hat a non-'9' or til we reach the first digit. // carry until we hat a non-'9' or til we reach the first digit.
for (int i = count - 1; i > 0; --i) { for (int i = count - 1; i > 0; --i) {

View File

@ -40,6 +40,7 @@ Bignum::Bignum()
template<typename S> template<typename S>
static int BitSize(S value) { static int BitSize(S value) {
(void) value; // Mark variable as used.
return 8 * sizeof(value); return 8 * sizeof(value);
} }
@ -103,7 +104,7 @@ void Bignum::AssignDecimalString(Vector<const char> value) {
const int kMaxUint64DecimalDigits = 19; const int kMaxUint64DecimalDigits = 19;
Zero(); Zero();
int length = value.length(); int length = value.length();
int pos = 0; unsigned int pos = 0;
// Let's just say that each digit needs 4 bits. // Let's just say that each digit needs 4 bits.
while (length >= kMaxUint64DecimalDigits) { while (length >= kMaxUint64DecimalDigits) {
uint64_t digits = ReadUInt64(value, pos, kMaxUint64DecimalDigits); uint64_t digits = ReadUInt64(value, pos, kMaxUint64DecimalDigits);
@ -122,9 +123,8 @@ void Bignum::AssignDecimalString(Vector<const char> value) {
static int HexCharValue(char c) { static int HexCharValue(char c) {
if ('0' <= c && c <= '9') return c - '0'; if ('0' <= c && c <= '9') return c - '0';
if ('a' <= c && c <= 'f') return 10 + c - 'a'; if ('a' <= c && c <= 'f') return 10 + c - 'a';
if ('A' <= c && c <= 'F') return 10 + c - 'A'; ASSERT('A' <= c && c <= 'F');
UNREACHABLE(); return 10 + c - 'A';
return 0; // To make compiler happy.
} }
@ -501,13 +501,14 @@ uint16_t Bignum::DivideModuloIntBignum(const Bignum& other) {
// Start by removing multiples of 'other' until both numbers have the same // Start by removing multiples of 'other' until both numbers have the same
// number of digits. // number of digits.
while (BigitLength() > other.BigitLength()) { while (BigitLength() > other.BigitLength()) {
// This naive approach is extremely inefficient if the this divided other // This naive approach is extremely inefficient if `this` divided by other
// might be big. This function is implemented for doubleToString where // is big. This function is implemented for doubleToString where
// the result should be small (less than 10). // the result should be small (less than 10).
ASSERT(other.bigits_[other.used_digits_ - 1] >= ((1 << kBigitSize) / 16)); ASSERT(other.bigits_[other.used_digits_ - 1] >= ((1 << kBigitSize) / 16));
ASSERT(bigits_[used_digits_ - 1] < 0x10000);
// Remove the multiples of the first digit. // Remove the multiples of the first digit.
// Example this = 23 and other equals 9. -> Remove 2 multiples. // Example this = 23 and other equals 9. -> Remove 2 multiples.
result += bigits_[used_digits_ - 1]; result += static_cast<uint16_t>(bigits_[used_digits_ - 1]);
SubtractTimes(other, bigits_[used_digits_ - 1]); SubtractTimes(other, bigits_[used_digits_ - 1]);
} }
@ -523,13 +524,15 @@ uint16_t Bignum::DivideModuloIntBignum(const Bignum& other) {
// Shortcut for easy (and common) case. // Shortcut for easy (and common) case.
int quotient = this_bigit / other_bigit; int quotient = this_bigit / other_bigit;
bigits_[used_digits_ - 1] = this_bigit - other_bigit * quotient; bigits_[used_digits_ - 1] = this_bigit - other_bigit * quotient;
result += quotient; ASSERT(quotient < 0x10000);
result += static_cast<uint16_t>(quotient);
Clamp(); Clamp();
return result; return result;
} }
int division_estimate = this_bigit / (other_bigit + 1); int division_estimate = this_bigit / (other_bigit + 1);
result += division_estimate; ASSERT(division_estimate < 0x10000);
result += static_cast<uint16_t>(division_estimate);
SubtractTimes(other, division_estimate); SubtractTimes(other, division_estimate);
if (other_bigit * (division_estimate + 1) > this_bigit) { if (other_bigit * (division_estimate + 1) > this_bigit) {
@ -560,8 +563,8 @@ static int SizeInHexChars(S number) {
static char HexCharOfValue(int value) { static char HexCharOfValue(int value) {
ASSERT(0 <= value && value <= 16); ASSERT(0 <= value && value <= 16);
if (value < 10) return value + '0'; if (value < 10) return static_cast<char>(value + '0');
return value - 10 + 'A'; return static_cast<char>(value - 10 + 'A');
} }
@ -755,7 +758,6 @@ void Bignum::SubtractTimes(const Bignum& other, int factor) {
Chunk difference = bigits_[i] - borrow; Chunk difference = bigits_[i] - borrow;
bigits_[i] = difference & kBigitMask; bigits_[i] = difference & kBigitMask;
borrow = difference >> (kChunkSize - 1); borrow = difference >> (kChunkSize - 1);
++i;
} }
Clamp(); Clamp();
} }

View File

@ -49,7 +49,6 @@ class Bignum {
void AssignPowerUInt16(uint16_t base, int exponent); void AssignPowerUInt16(uint16_t base, int exponent);
void AddUInt16(uint16_t operand);
void AddUInt64(uint64_t operand); void AddUInt64(uint64_t operand);
void AddBignum(const Bignum& other); void AddBignum(const Bignum& other);
// Precondition: this >= other. // Precondition: this >= other.

View File

@ -25,9 +25,9 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cstdarg> #include <stdarg.h>
#include <climits> #include <limits.h>
#include <cmath> #include <math.h>
#include "utils.h" #include "utils.h"
@ -131,7 +131,6 @@ static const CachedPower kCachedPowers[] = {
{UINT64_2PART_C(0xaf87023b, 9bf0ee6b), 1066, 340}, {UINT64_2PART_C(0xaf87023b, 9bf0ee6b), 1066, 340},
}; };
static const int kCachedPowersLength = ARRAY_SIZE(kCachedPowers);
static const int kCachedPowersOffset = 348; // -1 * the first decimal_exponent. static const int kCachedPowersOffset = 348; // -1 * the first decimal_exponent.
static const double kD_1_LOG2_10 = 0.30102999566398114; // 1 / lg(10) static const double kD_1_LOG2_10 = 0.30102999566398114; // 1 / lg(10)
// Difference between the decimal exponents in the table above. // Difference between the decimal exponents in the table above.
@ -149,9 +148,10 @@ void PowersOfTenCache::GetCachedPowerForBinaryExponentRange(
int foo = kCachedPowersOffset; int foo = kCachedPowersOffset;
int index = int index =
(foo + static_cast<int>(k) - 1) / kDecimalExponentDistance + 1; (foo + static_cast<int>(k) - 1) / kDecimalExponentDistance + 1;
ASSERT(0 <= index && index < kCachedPowersLength); ASSERT(0 <= index && index < static_cast<int>(ARRAY_SIZE(kCachedPowers)));
CachedPower cached_power = kCachedPowers[index]; CachedPower cached_power = kCachedPowers[index];
ASSERT(min_exponent <= cached_power.binary_exponent); ASSERT(min_exponent <= cached_power.binary_exponent);
(void) max_exponent; // Mark variable as used.
ASSERT(cached_power.binary_exponent <= max_exponent); ASSERT(cached_power.binary_exponent <= max_exponent);
*decimal_exponent = cached_power.decimal_exponent; *decimal_exponent = cached_power.decimal_exponent;
*power = DiyFp(cached_power.significand, cached_power.binary_exponent); *power = DiyFp(cached_power.significand, cached_power.binary_exponent);

View File

@ -42,7 +42,7 @@ class DiyFp {
static const int kSignificandSize = 64; static const int kSignificandSize = 64;
DiyFp() : f_(0), e_(0) {} DiyFp() : f_(0), e_(0) {}
DiyFp(uint64_t f, int e) : f_(f), e_(e) {} DiyFp(uint64_t significand, int exponent) : f_(significand), e_(exponent) {}
// this = this - other. // this = this - other.
// The exponents of both numbers must be the same and the significand of this // The exponents of both numbers must be the same and the significand of this
@ -76,22 +76,22 @@ class DiyFp {
void Normalize() { void Normalize() {
ASSERT(f_ != 0); ASSERT(f_ != 0);
uint64_t f = f_; uint64_t significand = f_;
int e = e_; int exponent = e_;
// This method is mainly called for normalizing boundaries. In general // This method is mainly called for normalizing boundaries. In general
// boundaries need to be shifted by 10 bits. We thus optimize for this case. // boundaries need to be shifted by 10 bits. We thus optimize for this case.
const uint64_t k10MSBits = UINT64_2PART_C(0xFFC00000, 00000000); const uint64_t k10MSBits = UINT64_2PART_C(0xFFC00000, 00000000);
while ((f & k10MSBits) == 0) { while ((significand & k10MSBits) == 0) {
f <<= 10; significand <<= 10;
e -= 10; exponent -= 10;
} }
while ((f & kUint64MSB) == 0) { while ((significand & kUint64MSB) == 0) {
f <<= 1; significand <<= 1;
e--; exponent--;
} }
f_ = f; f_ = significand;
e_ = e; e_ = exponent;
} }
static DiyFp Normalize(const DiyFp& a) { static DiyFp Normalize(const DiyFp& a) {

View File

@ -25,8 +25,8 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <climits> #include <limits.h>
#include <cmath> #include <math.h>
#include "double-conversion.h" #include "double-conversion.h"
@ -118,7 +118,7 @@ void DoubleToStringConverter::CreateDecimalRepresentation(
StringBuilder* result_builder) const { StringBuilder* result_builder) const {
// Create a representation that is padded with zeros if needed. // Create a representation that is padded with zeros if needed.
if (decimal_point <= 0) { if (decimal_point <= 0) {
// "0.00000decimal_rep". // "0.00000decimal_rep" or "0.000decimal_rep00".
result_builder->AddCharacter('0'); result_builder->AddCharacter('0');
if (digits_after_point > 0) { if (digits_after_point > 0) {
result_builder->AddCharacter('.'); result_builder->AddCharacter('.');
@ -129,7 +129,7 @@ void DoubleToStringConverter::CreateDecimalRepresentation(
result_builder->AddPadding('0', remaining_digits); result_builder->AddPadding('0', remaining_digits);
} }
} else if (decimal_point >= length) { } else if (decimal_point >= length) {
// "decimal_rep0000.00000" or "decimal_rep.0000" // "decimal_rep0000.00000" or "decimal_rep.0000".
result_builder->AddSubstring(decimal_digits, length); result_builder->AddSubstring(decimal_digits, length);
result_builder->AddPadding('0', decimal_point - length); result_builder->AddPadding('0', decimal_point - length);
if (digits_after_point > 0) { if (digits_after_point > 0) {
@ -137,7 +137,7 @@ void DoubleToStringConverter::CreateDecimalRepresentation(
result_builder->AddPadding('0', digits_after_point); result_builder->AddPadding('0', digits_after_point);
} }
} else { } else {
// "decima.l_rep000" // "decima.l_rep000".
ASSERT(digits_after_point > 0); ASSERT(digits_after_point > 0);
result_builder->AddSubstring(decimal_digits, decimal_point); result_builder->AddSubstring(decimal_digits, decimal_point);
result_builder->AddCharacter('.'); result_builder->AddCharacter('.');
@ -348,7 +348,6 @@ static BignumDtoaMode DtoaToBignumDtoaMode(
case DoubleToStringConverter::PRECISION: return BIGNUM_DTOA_PRECISION; case DoubleToStringConverter::PRECISION: return BIGNUM_DTOA_PRECISION;
default: default:
UNREACHABLE(); UNREACHABLE();
return BIGNUM_DTOA_SHORTEST; // To silence compiler.
} }
} }
@ -403,8 +402,8 @@ void DoubleToStringConverter::DoubleToAscii(double v,
vector, length, point); vector, length, point);
break; break;
default: default:
UNREACHABLE();
fast_worked = false; fast_worked = false;
UNREACHABLE();
} }
if (fast_worked) return; if (fast_worked) return;
@ -417,8 +416,9 @@ void DoubleToStringConverter::DoubleToAscii(double v,
// Consumes the given substring from the iterator. // Consumes the given substring from the iterator.
// Returns false, if the substring does not match. // Returns false, if the substring does not match.
static bool ConsumeSubString(const char** current, template <class Iterator>
const char* end, static bool ConsumeSubString(Iterator* current,
Iterator end,
const char* substring) { const char* substring) {
ASSERT(**current == *substring); ASSERT(**current == *substring);
for (substring++; *substring != '\0'; substring++) { for (substring++; *substring != '\0'; substring++) {
@ -440,10 +440,36 @@ static bool ConsumeSubString(const char** current,
const int kMaxSignificantDigits = 772; const int kMaxSignificantDigits = 772;
static const char kWhitespaceTable7[] = { 32, 13, 10, 9, 11, 12 };
static const int kWhitespaceTable7Length = ARRAY_SIZE(kWhitespaceTable7);
static const uc16 kWhitespaceTable16[] = {
160, 8232, 8233, 5760, 6158, 8192, 8193, 8194, 8195,
8196, 8197, 8198, 8199, 8200, 8201, 8202, 8239, 8287, 12288, 65279
};
static const int kWhitespaceTable16Length = ARRAY_SIZE(kWhitespaceTable16);
static bool isWhitespace(int x) {
if (x < 128) {
for (int i = 0; i < kWhitespaceTable7Length; i++) {
if (kWhitespaceTable7[i] == x) return true;
}
} else {
for (int i = 0; i < kWhitespaceTable16Length; i++) {
if (kWhitespaceTable16[i] == x) return true;
}
}
return false;
}
// Returns true if a nonspace found and false if the end has reached. // Returns true if a nonspace found and false if the end has reached.
static inline bool AdvanceToNonspace(const char** current, const char* end) { template <class Iterator>
static inline bool AdvanceToNonspace(Iterator* current, Iterator end) {
while (*current != end) { while (*current != end) {
if (**current != ' ') return true; if (!isWhitespace(**current)) return true;
++*current; ++*current;
} }
return false; return false;
@ -462,26 +488,57 @@ static double SignedZero(bool sign) {
} }
// Returns true if 'c' is a decimal digit that is valid for the given radix.
//
// The function is small and could be inlined, but VS2012 emitted a warning
// because it constant-propagated the radix and concluded that the last
// condition was always true. By moving it into a separate function the
// compiler wouldn't warn anymore.
#if _MSC_VER
#pragma optimize("",off)
static bool IsDecimalDigitForRadix(int c, int radix) {
return '0' <= c && c <= '9' && (c - '0') < radix;
}
#pragma optimize("",on)
#else
static bool inline IsDecimalDigitForRadix(int c, int radix) {
return '0' <= c && c <= '9' && (c - '0') < radix;
}
#endif
// Returns true if 'c' is a character digit that is valid for the given radix.
// The 'a_character' should be 'a' or 'A'.
//
// The function is small and could be inlined, but VS2012 emitted a warning
// because it constant-propagated the radix and concluded that the first
// condition was always false. By moving it into a separate function the
// compiler wouldn't warn anymore.
static bool IsCharacterDigitForRadix(int c, int radix, char a_character) {
return radix > 10 && c >= a_character && c < a_character + radix - 10;
}
// Parsing integers with radix 2, 4, 8, 16, 32. Assumes current != end. // Parsing integers with radix 2, 4, 8, 16, 32. Assumes current != end.
template <int radix_log_2> template <int radix_log_2, class Iterator>
static double RadixStringToIeee(const char* current, static double RadixStringToIeee(Iterator* current,
const char* end, Iterator end,
bool sign, bool sign,
bool allow_trailing_junk, bool allow_trailing_junk,
double junk_string_value, double junk_string_value,
bool read_as_double, bool read_as_double,
const char** trailing_pointer) { bool* result_is_junk) {
ASSERT(current != end); ASSERT(*current != end);
const int kDoubleSize = Double::kSignificandSize; const int kDoubleSize = Double::kSignificandSize;
const int kSingleSize = Single::kSignificandSize; const int kSingleSize = Single::kSignificandSize;
const int kSignificandSize = read_as_double? kDoubleSize: kSingleSize; const int kSignificandSize = read_as_double? kDoubleSize: kSingleSize;
*result_is_junk = true;
// Skip leading 0s. // Skip leading 0s.
while (*current == '0') { while (**current == '0') {
++current; ++(*current);
if (current == end) { if (*current == end) {
*trailing_pointer = end; *result_is_junk = false;
return SignedZero(sign); return SignedZero(sign);
} }
} }
@ -492,14 +549,14 @@ static double RadixStringToIeee(const char* current,
do { do {
int digit; int digit;
if (*current >= '0' && *current <= '9' && *current < '0' + radix) { if (IsDecimalDigitForRadix(**current, radix)) {
digit = static_cast<char>(*current) - '0'; digit = static_cast<char>(**current) - '0';
} else if (radix > 10 && *current >= 'a' && *current < 'a' + radix - 10) { } else if (IsCharacterDigitForRadix(**current, radix, 'a')) {
digit = static_cast<char>(*current) - 'a' + 10; digit = static_cast<char>(**current) - 'a' + 10;
} else if (radix > 10 && *current >= 'A' && *current < 'A' + radix - 10) { } else if (IsCharacterDigitForRadix(**current, radix, 'A')) {
digit = static_cast<char>(*current) - 'A' + 10; digit = static_cast<char>(**current) - 'A' + 10;
} else { } else {
if (allow_trailing_junk || !AdvanceToNonspace(&current, end)) { if (allow_trailing_junk || !AdvanceToNonspace(current, end)) {
break; break;
} else { } else {
return junk_string_value; return junk_string_value;
@ -523,14 +580,14 @@ static double RadixStringToIeee(const char* current,
exponent = overflow_bits_count; exponent = overflow_bits_count;
bool zero_tail = true; bool zero_tail = true;
while (true) { for (;;) {
++current; ++(*current);
if (current == end || !isDigit(*current, radix)) break; if (*current == end || !isDigit(**current, radix)) break;
zero_tail = zero_tail && *current == '0'; zero_tail = zero_tail && **current == '0';
exponent += radix_log_2; exponent += radix_log_2;
} }
if (!allow_trailing_junk && AdvanceToNonspace(&current, end)) { if (!allow_trailing_junk && AdvanceToNonspace(current, end)) {
return junk_string_value; return junk_string_value;
} }
@ -552,13 +609,13 @@ static double RadixStringToIeee(const char* current,
} }
break; break;
} }
++current; ++(*current);
} while (current != end); } while (*current != end);
ASSERT(number < ((int64_t)1 << kSignificandSize)); ASSERT(number < ((int64_t)1 << kSignificandSize));
ASSERT(static_cast<int64_t>(static_cast<double>(number)) == number); ASSERT(static_cast<int64_t>(static_cast<double>(number)) == number);
*trailing_pointer = current; *result_is_junk = false;
if (exponent == 0) { if (exponent == 0) {
if (sign) { if (sign) {
@ -573,13 +630,14 @@ static double RadixStringToIeee(const char* current,
} }
template <class Iterator>
double StringToDoubleConverter::StringToIeee( double StringToDoubleConverter::StringToIeee(
const char* input, Iterator input,
int length, int length,
int* processed_characters_count, bool read_as_double,
bool read_as_double) const { int* processed_characters_count) const {
const char* current = input; Iterator current = input;
const char* end = input + length; Iterator end = input + length;
*processed_characters_count = 0; *processed_characters_count = 0;
@ -600,7 +658,7 @@ double StringToDoubleConverter::StringToIeee(
if (allow_leading_spaces || allow_trailing_spaces) { if (allow_leading_spaces || allow_trailing_spaces) {
if (!AdvanceToNonspace(&current, end)) { if (!AdvanceToNonspace(&current, end)) {
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return empty_string_value_; return empty_string_value_;
} }
if (!allow_leading_spaces && (input != current)) { if (!allow_leading_spaces && (input != current)) {
@ -626,7 +684,7 @@ double StringToDoubleConverter::StringToIeee(
if (*current == '+' || *current == '-') { if (*current == '+' || *current == '-') {
sign = (*current == '-'); sign = (*current == '-');
++current; ++current;
const char* next_non_space = current; Iterator next_non_space = current;
// Skip following spaces (if allowed). // Skip following spaces (if allowed).
if (!AdvanceToNonspace(&next_non_space, end)) return junk_string_value_; if (!AdvanceToNonspace(&next_non_space, end)) return junk_string_value_;
if (!allow_spaces_after_sign && (current != next_non_space)) { if (!allow_spaces_after_sign && (current != next_non_space)) {
@ -649,7 +707,7 @@ double StringToDoubleConverter::StringToIeee(
} }
ASSERT(buffer_pos == 0); ASSERT(buffer_pos == 0);
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return sign ? -Double::Infinity() : Double::Infinity(); return sign ? -Double::Infinity() : Double::Infinity();
} }
} }
@ -668,7 +726,7 @@ double StringToDoubleConverter::StringToIeee(
} }
ASSERT(buffer_pos == 0); ASSERT(buffer_pos == 0);
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return sign ? -Double::NaN() : Double::NaN(); return sign ? -Double::NaN() : Double::NaN();
} }
} }
@ -677,7 +735,7 @@ double StringToDoubleConverter::StringToIeee(
if (*current == '0') { if (*current == '0') {
++current; ++current;
if (current == end) { if (current == end) {
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return SignedZero(sign); return SignedZero(sign);
} }
@ -690,17 +748,17 @@ double StringToDoubleConverter::StringToIeee(
return junk_string_value_; // "0x". return junk_string_value_; // "0x".
} }
const char* tail_pointer = NULL; bool result_is_junk;
double result = RadixStringToIeee<4>(current, double result = RadixStringToIeee<4>(&current,
end, end,
sign, sign,
allow_trailing_junk, allow_trailing_junk,
junk_string_value_, junk_string_value_,
read_as_double, read_as_double,
&tail_pointer); &result_is_junk);
if (tail_pointer != NULL) { if (!result_is_junk) {
if (allow_trailing_spaces) AdvanceToNonspace(&tail_pointer, end); if (allow_trailing_spaces) AdvanceToNonspace(&current, end);
*processed_characters_count = tail_pointer - input; *processed_characters_count = static_cast<int>(current - input);
} }
return result; return result;
} }
@ -709,7 +767,7 @@ double StringToDoubleConverter::StringToIeee(
while (*current == '0') { while (*current == '0') {
++current; ++current;
if (current == end) { if (current == end) {
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return SignedZero(sign); return SignedZero(sign);
} }
} }
@ -757,7 +815,7 @@ double StringToDoubleConverter::StringToIeee(
while (*current == '0') { while (*current == '0') {
++current; ++current;
if (current == end) { if (current == end) {
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return SignedZero(sign); return SignedZero(sign);
} }
exponent--; // Move this 0 into the exponent. exponent--; // Move this 0 into the exponent.
@ -801,9 +859,9 @@ double StringToDoubleConverter::StringToIeee(
return junk_string_value_; return junk_string_value_;
} }
} }
char sign = '+'; char exponen_sign = '+';
if (*current == '+' || *current == '-') { if (*current == '+' || *current == '-') {
sign = static_cast<char>(*current); exponen_sign = static_cast<char>(*current);
++current; ++current;
if (current == end) { if (current == end) {
if (allow_trailing_junk) { if (allow_trailing_junk) {
@ -837,7 +895,7 @@ double StringToDoubleConverter::StringToIeee(
++current; ++current;
} while (current != end && *current >= '0' && *current <= '9'); } while (current != end && *current >= '0' && *current <= '9');
exponent += (sign == '-' ? -num : num); exponent += (exponen_sign == '-' ? -num : num);
} }
if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) { if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) {
@ -855,16 +913,17 @@ double StringToDoubleConverter::StringToIeee(
if (octal) { if (octal) {
double result; double result;
const char* tail_pointer = NULL; bool result_is_junk;
result = RadixStringToIeee<3>(buffer, char* start = buffer;
result = RadixStringToIeee<3>(&start,
buffer + buffer_pos, buffer + buffer_pos,
sign, sign,
allow_trailing_junk, allow_trailing_junk,
junk_string_value_, junk_string_value_,
read_as_double, read_as_double,
&tail_pointer); &result_is_junk);
ASSERT(tail_pointer != NULL); ASSERT(!result_is_junk);
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return result; return result;
} }
@ -882,8 +941,42 @@ double StringToDoubleConverter::StringToIeee(
} else { } else {
converted = Strtof(Vector<const char>(buffer, buffer_pos), exponent); converted = Strtof(Vector<const char>(buffer, buffer_pos), exponent);
} }
*processed_characters_count = current - input; *processed_characters_count = static_cast<int>(current - input);
return sign? -converted: converted; return sign? -converted: converted;
} }
double StringToDoubleConverter::StringToDouble(
const char* buffer,
int length,
int* processed_characters_count) const {
return StringToIeee(buffer, length, true, processed_characters_count);
}
double StringToDoubleConverter::StringToDouble(
const uc16* buffer,
int length,
int* processed_characters_count) const {
return StringToIeee(buffer, length, true, processed_characters_count);
}
float StringToDoubleConverter::StringToFloat(
const char* buffer,
int length,
int* processed_characters_count) const {
return static_cast<float>(StringToIeee(buffer, length, false,
processed_characters_count));
}
float StringToDoubleConverter::StringToFloat(
const uc16* buffer,
int length,
int* processed_characters_count) const {
return static_cast<float>(StringToIeee(buffer, length, false,
processed_characters_count));
}
} // namespace kenlm_double_conversion } // namespace kenlm_double_conversion

View File

@ -415,9 +415,10 @@ class StringToDoubleConverter {
// junk, too. // junk, too.
// - ALLOW_TRAILING_JUNK: ignore trailing characters that are not part of // - ALLOW_TRAILING_JUNK: ignore trailing characters that are not part of
// a double literal. // a double literal.
// - ALLOW_LEADING_SPACES: skip over leading spaces. // - ALLOW_LEADING_SPACES: skip over leading whitespace, including spaces,
// - ALLOW_TRAILING_SPACES: ignore trailing spaces. // new-lines, and tabs.
// - ALLOW_SPACES_AFTER_SIGN: ignore spaces after the sign. // - ALLOW_TRAILING_SPACES: ignore trailing whitespace.
// - ALLOW_SPACES_AFTER_SIGN: ignore whitespace after the sign.
// Ex: StringToDouble("- 123.2") -> -123.2. // Ex: StringToDouble("- 123.2") -> -123.2.
// StringToDouble("+ 123.2") -> 123.2 // StringToDouble("+ 123.2") -> 123.2
// //
@ -502,19 +503,24 @@ class StringToDoubleConverter {
// in the 'processed_characters_count'. Trailing junk is never included. // in the 'processed_characters_count'. Trailing junk is never included.
double StringToDouble(const char* buffer, double StringToDouble(const char* buffer,
int length, int length,
int* processed_characters_count) const { int* processed_characters_count) const;
return StringToIeee(buffer, length, processed_characters_count, true);
} // Same as StringToDouble above but for 16 bit characters.
double StringToDouble(const uc16* buffer,
int length,
int* processed_characters_count) const;
// Same as StringToDouble but reads a float. // Same as StringToDouble but reads a float.
// Note that this is not equivalent to static_cast<float>(StringToDouble(...)) // Note that this is not equivalent to static_cast<float>(StringToDouble(...))
// due to potential double-rounding. // due to potential double-rounding.
float StringToFloat(const char* buffer, float StringToFloat(const char* buffer,
int length, int length,
int* processed_characters_count) const { int* processed_characters_count) const;
return static_cast<float>(StringToIeee(buffer, length,
processed_characters_count, false)); // Same as StringToFloat above but for 16 bit characters.
} float StringToFloat(const uc16* buffer,
int length,
int* processed_characters_count) const;
private: private:
const int flags_; const int flags_;
@ -523,10 +529,11 @@ class StringToDoubleConverter {
const char* const infinity_symbol_; const char* const infinity_symbol_;
const char* const nan_symbol_; const char* const nan_symbol_;
double StringToIeee(const char* buffer, template <class Iterator>
double StringToIeee(Iterator start_pointer,
int length, int length,
int* processed_characters_count, bool read_as_double,
bool read_as_double) const; int* processed_characters_count) const;
DISALLOW_IMPLICIT_CONSTRUCTORS(StringToDoubleConverter); DISALLOW_IMPLICIT_CONSTRUCTORS(StringToDoubleConverter);
}; };

View File

@ -248,10 +248,7 @@ static void BiggestPowerTen(uint32_t number,
// Note: kPowersOf10[i] == 10^(i-1). // Note: kPowersOf10[i] == 10^(i-1).
exponent_plus_one_guess++; exponent_plus_one_guess++;
// We don't have any guarantees that 2^number_bits <= number. // We don't have any guarantees that 2^number_bits <= number.
// TODO(floitsch): can we change the 'while' into an 'if'? We definitely see if (number < kSmallPowersOfTen[exponent_plus_one_guess]) {
// number < (2^number_bits - 1), but I haven't encountered
// number < (2^number_bits - 2) yet.
while (number < kSmallPowersOfTen[exponent_plus_one_guess]) {
exponent_plus_one_guess--; exponent_plus_one_guess--;
} }
*power = kSmallPowersOfTen[exponent_plus_one_guess]; *power = kSmallPowersOfTen[exponent_plus_one_guess];
@ -350,7 +347,8 @@ static bool DigitGen(DiyFp low,
// that is smaller than integrals. // that is smaller than integrals.
while (*kappa > 0) { while (*kappa > 0) {
int digit = integrals / divisor; int digit = integrals / divisor;
buffer[*length] = '0' + digit; ASSERT(digit <= 9);
buffer[*length] = static_cast<char>('0' + digit);
(*length)++; (*length)++;
integrals %= divisor; integrals %= divisor;
(*kappa)--; (*kappa)--;
@ -379,13 +377,14 @@ static bool DigitGen(DiyFp low,
ASSERT(one.e() >= -60); ASSERT(one.e() >= -60);
ASSERT(fractionals < one.f()); ASSERT(fractionals < one.f());
ASSERT(UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF) / 10 >= one.f()); ASSERT(UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF) / 10 >= one.f());
while (true) { for (;;) {
fractionals *= 10; fractionals *= 10;
unit *= 10; unit *= 10;
unsafe_interval.set_f(unsafe_interval.f() * 10); unsafe_interval.set_f(unsafe_interval.f() * 10);
// Integer division by one. // Integer division by one.
int digit = static_cast<int>(fractionals >> -one.e()); int digit = static_cast<int>(fractionals >> -one.e());
buffer[*length] = '0' + digit; ASSERT(digit <= 9);
buffer[*length] = static_cast<char>('0' + digit);
(*length)++; (*length)++;
fractionals &= one.f() - 1; // Modulo by one. fractionals &= one.f() - 1; // Modulo by one.
(*kappa)--; (*kappa)--;
@ -459,7 +458,8 @@ static bool DigitGenCounted(DiyFp w,
// that is smaller than 'integrals'. // that is smaller than 'integrals'.
while (*kappa > 0) { while (*kappa > 0) {
int digit = integrals / divisor; int digit = integrals / divisor;
buffer[*length] = '0' + digit; ASSERT(digit <= 9);
buffer[*length] = static_cast<char>('0' + digit);
(*length)++; (*length)++;
requested_digits--; requested_digits--;
integrals %= divisor; integrals %= divisor;
@ -492,7 +492,8 @@ static bool DigitGenCounted(DiyFp w,
w_error *= 10; w_error *= 10;
// Integer division by one. // Integer division by one.
int digit = static_cast<int>(fractionals >> -one.e()); int digit = static_cast<int>(fractionals >> -one.e());
buffer[*length] = '0' + digit; ASSERT(digit <= 9);
buffer[*length] = static_cast<char>('0' + digit);
(*length)++; (*length)++;
requested_digits--; requested_digits--;
fractionals &= one.f() - 1; // Modulo by one. fractionals &= one.f() - 1; // Modulo by one.

View File

@ -25,7 +25,7 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cmath> #include <math.h>
#include "fixed-dtoa.h" #include "fixed-dtoa.h"
#include "ieee.h" #include "ieee.h"
@ -98,7 +98,7 @@ class UInt128 {
return high_bits_ == 0 && low_bits_ == 0; return high_bits_ == 0 && low_bits_ == 0;
} }
int BitAt(int position) { int BitAt(int position) const {
if (position >= 64) { if (position >= 64) {
return static_cast<int>(high_bits_ >> (position - 64)) & 1; return static_cast<int>(high_bits_ >> (position - 64)) & 1;
} else { } else {
@ -133,7 +133,7 @@ static void FillDigits32(uint32_t number, Vector<char> buffer, int* length) {
while (number != 0) { while (number != 0) {
int digit = number % 10; int digit = number % 10;
number /= 10; number /= 10;
buffer[(*length) + number_length] = '0' + digit; buffer[(*length) + number_length] = static_cast<char>('0' + digit);
number_length++; number_length++;
} }
// Exchange the digits. // Exchange the digits.
@ -150,7 +150,7 @@ static void FillDigits32(uint32_t number, Vector<char> buffer, int* length) {
} }
static void FillDigits64FixedLength(uint64_t number, int requested_length, static void FillDigits64FixedLength(uint64_t number,
Vector<char> buffer, int* length) { Vector<char> buffer, int* length) {
const uint32_t kTen7 = 10000000; const uint32_t kTen7 = 10000000;
// For efficiency cut the number into 3 uint32_t parts, and print those. // For efficiency cut the number into 3 uint32_t parts, and print those.
@ -253,12 +253,14 @@ static void FillFractionals(uint64_t fractionals, int exponent,
fractionals *= 5; fractionals *= 5;
point--; point--;
int digit = static_cast<int>(fractionals >> point); int digit = static_cast<int>(fractionals >> point);
buffer[*length] = '0' + digit; ASSERT(digit <= 9);
buffer[*length] = static_cast<char>('0' + digit);
(*length)++; (*length)++;
fractionals -= static_cast<uint64_t>(digit) << point; fractionals -= static_cast<uint64_t>(digit) << point;
} }
// If the first bit after the point is set we have to round up. // If the first bit after the point is set we have to round up.
if (((fractionals >> (point - 1)) & 1) == 1) { ASSERT(fractionals == 0 || point - 1 >= 0);
if ((fractionals != 0) && ((fractionals >> (point - 1)) & 1) == 1) {
RoundUp(buffer, length, decimal_point); RoundUp(buffer, length, decimal_point);
} }
} else { // We need 128 bits. } else { // We need 128 bits.
@ -274,7 +276,8 @@ static void FillFractionals(uint64_t fractionals, int exponent,
fractionals128.Multiply(5); fractionals128.Multiply(5);
point--; point--;
int digit = fractionals128.DivModPowerOf2(point); int digit = fractionals128.DivModPowerOf2(point);
buffer[*length] = '0' + digit; ASSERT(digit <= 9);
buffer[*length] = static_cast<char>('0' + digit);
(*length)++; (*length)++;
} }
if (fractionals128.BitAt(point - 1) == 1) { if (fractionals128.BitAt(point - 1) == 1) {
@ -358,7 +361,7 @@ bool FastFixedDtoa(double v,
remainder = (dividend % divisor) << exponent; remainder = (dividend % divisor) << exponent;
} }
FillDigits32(quotient, buffer, length); FillDigits32(quotient, buffer, length);
FillDigits64FixedLength(remainder, divisor_power, buffer, length); FillDigits64FixedLength(remainder, buffer, length);
*decimal_point = *length; *decimal_point = *length;
} else if (exponent >= 0) { } else if (exponent >= 0) {
// 0 <= exponent <= 11 // 0 <= exponent <= 11

View File

@ -99,7 +99,7 @@ class Double {
} }
double PreviousDouble() const { double PreviousDouble() const {
if (d64_ == (kInfinity | kSignMask)) return -Double::Infinity(); if (d64_ == (kInfinity | kSignMask)) return -Infinity();
if (Sign() < 0) { if (Sign() < 0) {
return Double(d64_ + 1).value(); return Double(d64_ + 1).value();
} else { } else {
@ -256,6 +256,8 @@ class Double {
return (significand & kSignificandMask) | return (significand & kSignificandMask) |
(biased_exponent << kPhysicalSignificandSize); (biased_exponent << kPhysicalSignificandSize);
} }
DISALLOW_COPY_AND_ASSIGN(Double);
}; };
class Single { class Single {
@ -391,6 +393,8 @@ class Single {
static const uint32_t kNaN = 0x7FC00000; static const uint32_t kNaN = 0x7FC00000;
const uint32_t d32_; const uint32_t d32_;
DISALLOW_COPY_AND_ASSIGN(Single);
}; };
} // namespace kenlm_double_conversion } // namespace kenlm_double_conversion

View File

@ -25,8 +25,8 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cstdarg> #include <stdarg.h>
#include <climits> #include <limits.h>
#include "strtod.h" #include "strtod.h"
#include "bignum.h" #include "bignum.h"
@ -137,6 +137,7 @@ static void TrimAndCut(Vector<const char> buffer, int exponent,
Vector<const char> right_trimmed = TrimTrailingZeros(left_trimmed); Vector<const char> right_trimmed = TrimTrailingZeros(left_trimmed);
exponent += left_trimmed.length() - right_trimmed.length(); exponent += left_trimmed.length() - right_trimmed.length();
if (right_trimmed.length() > kMaxSignificantDecimalDigits) { if (right_trimmed.length() > kMaxSignificantDecimalDigits) {
(void) space_size; // Mark variable as used.
ASSERT(space_size >= kMaxSignificantDecimalDigits); ASSERT(space_size >= kMaxSignificantDecimalDigits);
CutToMaxSignificantDigits(right_trimmed, exponent, CutToMaxSignificantDigits(right_trimmed, exponent,
buffer_copy_space, updated_exponent); buffer_copy_space, updated_exponent);
@ -263,7 +264,6 @@ static DiyFp AdjustmentPowerOfTen(int exponent) {
case 7: return DiyFp(UINT64_2PART_C(0x98968000, 00000000), -40); case 7: return DiyFp(UINT64_2PART_C(0x98968000, 00000000), -40);
default: default:
UNREACHABLE(); UNREACHABLE();
return DiyFp(0, 0);
} }
} }
@ -286,7 +286,7 @@ static bool DiyFpStrtod(Vector<const char> buffer,
const int kDenominator = 1 << kDenominatorLog; const int kDenominator = 1 << kDenominatorLog;
// Move the remaining decimals into the exponent. // Move the remaining decimals into the exponent.
exponent += remaining_decimals; exponent += remaining_decimals;
int error = (remaining_decimals == 0 ? 0 : kDenominator / 2); uint64_t error = (remaining_decimals == 0 ? 0 : kDenominator / 2);
int old_e = input.e(); int old_e = input.e();
input.Normalize(); input.Normalize();
@ -506,9 +506,7 @@ float Strtof(Vector<const char> buffer, int exponent) {
double double_previous = Double(double_guess).PreviousDouble(); double double_previous = Double(double_guess).PreviousDouble();
float f1 = static_cast<float>(double_previous); float f1 = static_cast<float>(double_previous);
#ifndef NDEBUG
float f2 = float_guess; float f2 = float_guess;
#endif
float f3 = static_cast<float>(double_next); float f3 = static_cast<float>(double_next);
float f4; float f4;
if (is_correct) { if (is_correct) {
@ -517,9 +515,8 @@ float Strtof(Vector<const char> buffer, int exponent) {
double double_next2 = Double(double_next).NextDouble(); double double_next2 = Double(double_next).NextDouble();
f4 = static_cast<float>(double_next2); f4 = static_cast<float>(double_next2);
} }
#ifndef NDEBUG (void) f2; // Mark variable as used.
ASSERT(f1 <= f2 && f2 <= f3 && f3 <= f4); ASSERT(f1 <= f2 && f2 <= f3 && f3 <= f4);
#endif
// If the guess doesn't lie near a single-precision boundary we can simply // If the guess doesn't lie near a single-precision boundary we can simply
// return its float-value. // return its float-value.

View File

@ -33,14 +33,29 @@
#include <assert.h> #include <assert.h>
#ifndef ASSERT #ifndef ASSERT
#define ASSERT(condition) (assert(condition)) #define ASSERT(condition) \
assert(condition);
#endif #endif
#ifndef UNIMPLEMENTED #ifndef UNIMPLEMENTED
#define UNIMPLEMENTED() (abort()) #define UNIMPLEMENTED() (abort())
#endif #endif
#ifndef DOUBLE_CONVERSION_NO_RETURN
#ifdef _MSC_VER
#define DOUBLE_CONVERSION_NO_RETURN __declspec(noreturn)
#else
#define DOUBLE_CONVERSION_NO_RETURN __attribute__((noreturn))
#endif
#endif
#ifndef UNREACHABLE #ifndef UNREACHABLE
#ifdef _MSC_VER
void DOUBLE_CONVERSION_NO_RETURN abort_noreturn();
inline void abort_noreturn() { abort(); }
#define UNREACHABLE() (abort_noreturn())
#else
#define UNREACHABLE() (abort()) #define UNREACHABLE() (abort())
#endif #endif
#endif
// Double operations detection based on target architecture. // Double operations detection based on target architecture.
// Linux uses a 80bit wide floating point stack on x86. This induces double // Linux uses a 80bit wide floating point stack on x86. This induces double
@ -55,11 +70,17 @@
#if defined(_M_X64) || defined(__x86_64__) || \ #if defined(_M_X64) || defined(__x86_64__) || \
defined(__ARMEL__) || defined(__avr32__) || \ defined(__ARMEL__) || defined(__avr32__) || \
defined(__hppa__) || defined(__ia64__) || \ defined(__hppa__) || defined(__ia64__) || \
defined(__mips__) || defined(__powerpc__) || \ defined(__mips__) || \
defined(__powerpc__) || defined(__ppc__) || defined(__ppc64__) || \
defined(_POWER) || defined(_ARCH_PPC) || defined(_ARCH_PPC64) || \
defined(__sparc__) || defined(__sparc) || defined(__s390__) || \ defined(__sparc__) || defined(__sparc) || defined(__s390__) || \
defined(__SH4__) || defined(__alpha__) || \ defined(__SH4__) || defined(__alpha__) || \
defined(_MIPS_ARCH_MIPS32R2) || defined(__aarch64__) defined(_MIPS_ARCH_MIPS32R2) || \
defined(__AARCH64EL__) || defined(__aarch64__) || \
defined(__riscv)
#define DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS 1 #define DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS 1
#elif defined(__mc68000__)
#undef DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS
#elif defined(_M_IX86) || defined(__i386__) || defined(__i386) #elif defined(_M_IX86) || defined(__i386__) || defined(__i386)
#if defined(_WIN32) #if defined(_WIN32)
// Windows uses a 64bit wide floating point stack. // Windows uses a 64bit wide floating point stack.
@ -71,6 +92,11 @@
#error Target architecture was not detected as supported by Double-Conversion. #error Target architecture was not detected as supported by Double-Conversion.
#endif #endif
#if defined(__GNUC__)
#define DOUBLE_CONVERSION_UNUSED __attribute__((unused))
#else
#define DOUBLE_CONVERSION_UNUSED
#endif
#if defined(_WIN32) && !defined(__MINGW32__) #if defined(_WIN32) && !defined(__MINGW32__)
@ -90,6 +116,8 @@ typedef unsigned __int64 uint64_t;
#endif #endif
typedef uint16_t uc16;
// The following macro works on both 32 and 64-bit platforms. // The following macro works on both 32 and 64-bit platforms.
// Usage: instead of writing 0x1234567890123456 // Usage: instead of writing 0x1234567890123456
// write UINT64_2PART_C(0x12345678,90123456); // write UINT64_2PART_C(0x12345678,90123456);
@ -155,8 +183,8 @@ template <typename T>
class Vector { class Vector {
public: public:
Vector() : start_(NULL), length_(0) {} Vector() : start_(NULL), length_(0) {}
Vector(T* data, int length) : start_(data), length_(length) { Vector(T* data, int len) : start_(data), length_(len) {
ASSERT(length == 0 || (length > 0 && data != NULL)); ASSERT(len == 0 || (len > 0 && data != NULL));
} }
// Returns a vector using the same backing storage as this one, // Returns a vector using the same backing storage as this one,
@ -198,8 +226,8 @@ class Vector {
// buffer bounds on all operations in debug mode. // buffer bounds on all operations in debug mode.
class StringBuilder { class StringBuilder {
public: public:
StringBuilder(char* buffer, int size) StringBuilder(char* buffer, int buffer_size)
: buffer_(buffer, size), position_(0) { } : buffer_(buffer, buffer_size), position_(0) { }
~StringBuilder() { if (!is_finalized()) Finalize(); } ~StringBuilder() { if (!is_finalized()) Finalize(); }
@ -218,8 +246,7 @@ class StringBuilder {
// 0-characters; use the Finalize() method to terminate the string // 0-characters; use the Finalize() method to terminate the string
// instead. // instead.
void AddCharacter(char c) { void AddCharacter(char c) {
// I just extract raw data not a cstr so null is fine. ASSERT(c != '\0');
//ASSERT(c != '\0');
ASSERT(!is_finalized() && position_ < buffer_.length()); ASSERT(!is_finalized() && position_ < buffer_.length());
buffer_[position_++] = c; buffer_[position_++] = c;
} }
@ -234,8 +261,7 @@ class StringBuilder {
// builder. The input string must have enough characters. // builder. The input string must have enough characters.
void AddSubstring(const char* s, int n) { void AddSubstring(const char* s, int n) {
ASSERT(!is_finalized() && position_ + n < buffer_.length()); ASSERT(!is_finalized() && position_ + n < buffer_.length());
// I just extract raw data not a cstr so null is fine. ASSERT(static_cast<size_t>(n) <= strlen(s));
//ASSERT(static_cast<size_t>(n) <= strlen(s));
memmove(&buffer_[position_], s, n * kCharSize); memmove(&buffer_[position_], s, n * kCharSize);
position_ += n; position_ += n;
} }
@ -255,8 +281,7 @@ class StringBuilder {
buffer_[position_] = '\0'; buffer_[position_] = '\0';
// Make sure nobody managed to add a 0-character to the // Make sure nobody managed to add a 0-character to the
// buffer while building the string. // buffer while building the string.
// I just extract raw data not a cstr so null is fine. ASSERT(strlen(buffer_.start()) == static_cast<size_t>(position_));
//ASSERT(strlen(buffer_.start()) == static_cast<size_t>(position_));
position_ = -1; position_ = -1;
ASSERT(is_finalized()); ASSERT(is_finalized());
return buffer_.start(); return buffer_.start();
@ -299,11 +324,8 @@ template <class Dest, class Source>
inline Dest BitCast(const Source& source) { inline Dest BitCast(const Source& source) {
// Compile time assertion: sizeof(Dest) == sizeof(Source) // Compile time assertion: sizeof(Dest) == sizeof(Source)
// A compile error here means your Dest and Source have different sizes. // A compile error here means your Dest and Source have different sizes.
typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1] DOUBLE_CONVERSION_UNUSED
#if __GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 8 typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1];
__attribute__((unused))
#endif
;
Dest dest; Dest dest;
memmove(&dest, &source, sizeof(dest)); memmove(&dest, &source, sizeof(dest));

View File

@ -134,7 +134,7 @@ class OverflowException : public Exception {
template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) { template <unsigned len> inline std::size_t CheckOverflowInternal(uint64_t value) {
UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code."); UTIL_THROW_IF(value > static_cast<uint64_t>(std::numeric_limits<std::size_t>::max()), OverflowException, "Integer overflow detected. This model is too big for 32-bit code.");
return value; return static_cast<std::size_t>(value);
} }
template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) { template <> inline std::size_t CheckOverflowInternal<8>(uint64_t value) {

View File

@ -490,7 +490,7 @@ int
mkstemp_and_unlink(char *tmpl) { mkstemp_and_unlink(char *tmpl) {
int ret = mkstemp(tmpl); int ret = mkstemp(tmpl);
if (ret != -1) { if (ret != -1) {
UTIL_THROW_IF(unlink(tmpl), ErrnoException, "while deleting delete " << tmpl); UTIL_THROW_IF(unlink(tmpl), ErrnoException, "while deleting " << tmpl);
} }
return ret; return ret;
} }

View File

@ -103,7 +103,7 @@ class FilePiece {
if (position_ == position_end_) { if (position_ == position_end_) {
try { try {
Shift(); Shift();
} catch (const util::EndOfFileException &e) { return false; } } catch (const util::EndOfFileException &) { return false; }
// And break out at end of file. // And break out at end of file.
if (position_ == position_end_) return false; if (position_ == position_end_) return false;
} }

View File

@ -142,7 +142,7 @@ void UnmapOrThrow(void *start, size_t length) {
#if defined(_WIN32) || defined(_WIN64) #if defined(_WIN32) || defined(_WIN64)
UTIL_THROW_IF(!::UnmapViewOfFile(start), ErrnoException, "Failed to unmap a file"); UTIL_THROW_IF(!::UnmapViewOfFile(start), ErrnoException, "Failed to unmap a file");
#else #else
UTIL_THROW_IF(munmap(start, length), ErrnoException, "munmap failed"); UTIL_THROW_IF(munmap(start, length), ErrnoException, "munmap failed with " << start << " for length " << length);
#endif #endif
} }

View File

@ -30,7 +30,7 @@ class DivMod {
public: public:
explicit DivMod(std::size_t buckets) : buckets_(buckets) {} explicit DivMod(std::size_t buckets) : buckets_(buckets) {}
static std::size_t RoundBuckets(std::size_t from) { static uint64_t RoundBuckets(uint64_t from) {
return from; return from;
} }
@ -58,7 +58,7 @@ class Power2Mod {
} }
// Round up to next power of 2. // Round up to next power of 2.
static std::size_t RoundBuckets(std::size_t from) { static uint64_t RoundBuckets(uint64_t from) {
--from; --from;
from |= from >> 1; from |= from >> 1;
from |= from >> 2; from |= from >> 2;

View File

@ -5,10 +5,9 @@
#include "util/spaces.hh" #include "util/spaces.hh"
#include "util/string_piece.hh" #include "util/string_piece.hh"
#include <boost/iterator/iterator_facade.hpp>
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <iterator>
namespace util { namespace util {
@ -97,12 +96,12 @@ class AnyCharacterLast {
StringPiece chars_; StringPiece chars_;
}; };
template <class Find, bool SkipEmpty = false> class TokenIter : public boost::iterator_facade<TokenIter<Find, SkipEmpty>, const StringPiece, boost::forward_traversal_tag> { template <class Find, bool SkipEmpty = false> class TokenIter : public std::iterator<std::forward_iterator_tag, const StringPiece, std::ptrdiff_t, const StringPiece *, const StringPiece &> {
public: public:
TokenIter() {} TokenIter() {}
template <class Construct> TokenIter(const StringPiece &str, const Construct &construct) : after_(str), finder_(construct) { template <class Construct> TokenIter(const StringPiece &str, const Construct &construct) : after_(str), finder_(construct) {
increment(); ++*this;
} }
bool operator!() const { bool operator!() const {
@ -116,10 +115,15 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
return TokenIter<Find, SkipEmpty>(); return TokenIter<Find, SkipEmpty>();
} }
private: bool operator==(const TokenIter<Find, SkipEmpty> &other) const {
friend class boost::iterator_core_access; return current_.data() == other.current_.data();
}
void increment() { bool operator!=(const TokenIter<Find, SkipEmpty> &other) const {
return !(*this == other);
}
TokenIter<Find, SkipEmpty> &operator++() {
do { do {
StringPiece found(finder_.Find(after_)); StringPiece found(finder_.Find(after_));
current_ = StringPiece(after_.data(), found.data() - after_.data()); current_ = StringPiece(after_.data(), found.data() - after_.data());
@ -129,17 +133,25 @@ template <class Find, bool SkipEmpty = false> class TokenIter : public boost::it
after_ = StringPiece(found.data() + found.size(), after_.data() - found.data() + after_.size() - found.size()); after_ = StringPiece(found.data() + found.size(), after_.data() - found.data() + after_.size() - found.size());
} }
} while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false. } while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false.
return *this;
} }
bool equal(const TokenIter<Find, SkipEmpty> &other) const { TokenIter<Find, SkipEmpty> &operator++(int) {
return current_.data() == other.current_.data(); TokenIter<Find, SkipEmpty> ret(*this);
++*this;
return ret;
} }
const StringPiece &dereference() const { const StringPiece &operator*() const {
UTIL_THROW_IF(!current_.data(), OutOfTokens, "Ran out of tokens"); UTIL_THROW_IF(!current_.data(), OutOfTokens, "Ran out of tokens");
return current_; return current_;
} }
const StringPiece *operator->() const {
UTIL_THROW_IF(!current_.data(), OutOfTokens, "Ran out of tokens");
return &current_;
}
private:
StringPiece current_; StringPiece current_;
StringPiece after_; StringPiece after_;

View File

@ -16,7 +16,7 @@ struct ModelState {
static constexpr unsigned int BATCH_SIZE = 1; static constexpr unsigned int BATCH_SIZE = 1;
Alphabet alphabet_; Alphabet alphabet_;
std::unique_ptr<Scorer> scorer_; std::shared_ptr<Scorer> scorer_;
unsigned int beam_width_; unsigned int beam_width_;
unsigned int n_steps_; unsigned int n_steps_;
unsigned int n_context_; unsigned int n_context_;

View File

@ -21,7 +21,6 @@ import deepspeech
# rename for backwards compatibility # rename for backwards compatibility
from deepspeech.impl import PrintVersions as printVersions from deepspeech.impl import PrintVersions as printVersions
from deepspeech.impl import FreeStream as freeStream
class Model(object): class Model(object):
""" """
@ -56,127 +55,163 @@ class Model(object):
""" """
return deepspeech.impl.GetModelSampleRate(self._impl) return deepspeech.impl.GetModelSampleRate(self._impl)
def enableDecoderWithLM(self, *args, **kwargs): def enableExternalScorer(self, scorer_path):
""" """
Enable decoding using beam scoring with a KenLM language model. Enable decoding using an external scorer.
:param aLMPath: The path to the language model binary file. :param scorer_path: The path to the external scorer file.
:type aLMPath: str :type scorer_path: str
:param aTriePath: The path to the trie file build from the same vocabulary as the language model binary. :return: Zero on success, non-zero on failure.
:type aTriePath: str
:param aLMAlpha: The alpha hyperparameter of the CTC decoder. Language Model weight.
:type aLMAlpha: float
:param aLMBeta: The beta hyperparameter of the CTC decoder. Word insertion weight.
:type aLMBeta: float
:return: Zero on success, non-zero on failure (invalid arguments).
:type: int :type: int
""" """
return deepspeech.impl.EnableDecoderWithLM(self._impl, *args, **kwargs) return deepspeech.impl.EnableExternalScorer(self._impl, scorer_path)
def stt(self, *args, **kwargs): def disableExternalScorer(self):
"""
Disable decoding using an external scorer.
:return: Zero on success, non-zero on failure.
"""
return deepspeech.impl.DisableExternalScorer(self._impl)
def setScorerAlphaBeta(self, alpha, beta):
"""
Set hyperparameters alpha and beta of the external scorer.
:param alpha: The alpha hyperparameter of the decoder. Language model weight.
:type alpha: float
:param beta: The beta hyperparameter of the decoder. Word insertion weight.
:type beta: float
:return: Zero on success, non-zero on failure.
:type: int
"""
return deepspeech.impl.SetScorerAlphaBeta(self._impl, alpha, beta)
def stt(self, audio_buffer):
""" """
Use the DeepSpeech model to perform Speech-To-Text. Use the DeepSpeech model to perform Speech-To-Text.
:param aBuffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
:type aBuffer: int array :type audio_buffer: numpy.int16 array
:param aBufferSize: The number of samples in the audio signal.
:type aBufferSize: int
:return: The STT result. :return: The STT result.
:type: str :type: str
""" """
return deepspeech.impl.SpeechToText(self._impl, *args, **kwargs) return deepspeech.impl.SpeechToText(self._impl, audio_buffer)
def sttWithMetadata(self, *args, **kwargs): def sttWithMetadata(self, audio_buffer):
""" """
Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results. Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results.
:param aBuffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
:type aBuffer: int array :type audio_buffer: numpy.int16 array
:param aBufferSize: The number of samples in the audio signal.
:type aBufferSize: int
:return: Outputs a struct of individual letters along with their timing information. :return: Outputs a struct of individual letters along with their timing information.
:type: :func:`Metadata` :type: :func:`Metadata`
""" """
return deepspeech.impl.SpeechToTextWithMetadata(self._impl, *args, **kwargs) return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer)
def createStream(self): def createStream(self):
""" """
Create a new streaming inference state. The streaming state returned Create a new streaming inference state. The streaming state returned by
by this function can then be passed to :func:`feedAudioContent()` and :func:`finishStream()`. this function can then be passed to :func:`feedAudioContent()` and :func:`finishStream()`.
:return: Object holding the stream :return: Stream object representing the newly created stream
:type: :func:`Stream`
:throws: RuntimeError on error :throws: RuntimeError on error
""" """
status, ctx = deepspeech.impl.CreateStream(self._impl) status, ctx = deepspeech.impl.CreateStream(self._impl)
if status != 0: if status != 0:
raise RuntimeError("CreateStream failed with error code {}".format(status)) raise RuntimeError("CreateStream failed with error code {}".format(status))
return ctx return Stream(ctx)
# pylint: disable=no-self-use
def feedAudioContent(self, *args, **kwargs): class Stream(object):
"""
Class wrapping a DeepSpeech stream. The constructor cannot be called directly.
Use :func:`Model.createStream()`
"""
def __init__(self, native_stream):
self._impl = native_stream
def __del__(self):
if self._impl:
self.freeStream()
def feedAudioContent(self, audio_buffer):
""" """
Feed audio samples to an ongoing streaming inference. Feed audio samples to an ongoing streaming inference.
:param aSctx: A streaming state pointer returned by :func:`createStream()`. :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
:type aSctx: object :type audio_buffer: numpy.int16 array
:param aBuffer: An array of 16-bit, mono raw audio samples at the appropriate sample rate (matching what the model was trained on). :throws: RuntimeError if the stream object is not valid
:type aBuffer: int array
:param aBufferSize: The number of samples in @p aBuffer.
:type aBufferSize: int
""" """
deepspeech.impl.FeedAudioContent(*args, **kwargs) if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to feed an already finished stream?")
deepspeech.impl.FeedAudioContent(self._impl, audio_buffer)
# pylint: disable=no-self-use def intermediateDecode(self):
def intermediateDecode(self, *args, **kwargs):
""" """
Compute the intermediate decoding of an ongoing streaming inference. Compute the intermediate decoding of an ongoing streaming inference.
:param aSctx: A streaming state pointer returned by :func:`createStream()`.
:type aSctx: object
:return: The STT intermediate result. :return: The STT intermediate result.
:type: str :type: str
"""
return deepspeech.impl.IntermediateDecode(*args, **kwargs)
# pylint: disable=no-self-use :throws: RuntimeError if the stream object is not valid
def finishStream(self, *args, **kwargs):
""" """
Signal the end of an audio signal to an ongoing streaming if not self._impl:
inference, returns the STT result over the whole audio signal. raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?")
return deepspeech.impl.IntermediateDecode(self._impl)
:param aSctx: A streaming state pointer returned by :func:`createStream()`. def finishStream(self):
:type aSctx: object """
Signal the end of an audio signal to an ongoing streaming inference,
returns the STT result over the whole audio signal.
:return: The STT result. :return: The STT result.
:type: str :type: str
"""
return deepspeech.impl.FinishStream(*args, **kwargs)
# pylint: disable=no-self-use :throws: RuntimeError if the stream object is not valid
def finishStreamWithMetadata(self, *args, **kwargs):
""" """
Signal the end of an audio signal to an ongoing streaming if not self._impl:
inference, returns per-letter metadata. raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
result = deepspeech.impl.FinishStream(self._impl)
self._impl = None
return result
:param aSctx: A streaming state pointer returned by :func:`createStream()`. def finishStreamWithMetadata(self):
:type aSctx: object """
Signal the end of an audio signal to an ongoing streaming inference,
returns per-letter metadata.
:return: Outputs a struct of individual letters along with their timing information. :return: Outputs a struct of individual letters along with their timing information.
:type: :func:`Metadata` :type: :func:`Metadata`
:throws: RuntimeError if the stream object is not valid
""" """
return deepspeech.impl.FinishStreamWithMetadata(*args, **kwargs) if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?")
result = deepspeech.impl.FinishStreamWithMetadata(self._impl)
self._impl = None
return result
def freeStream(self):
"""
Destroy a streaming state without decoding the computed logits. This can
be used if you no longer need the result of an ongoing streaming inference.
:throws: RuntimeError if the stream object is not valid
"""
if not self._impl:
raise RuntimeError("Stream object is not valid. Trying to free an already finished stream?")
deepspeech.impl.FreeStream(self._impl)
self._impl = None
# This is only for documentation purpose # This is only for documentation purpose
# Metadata and MetadataItem should be in sync with native_client/deepspeech.h # Metadata and MetadataItem should be in sync with native_client/deepspeech.h
@ -189,22 +224,18 @@ class MetadataItem(object):
""" """
The character generated for transcription The character generated for transcription
""" """
# pylint: disable=unnecessary-pass
pass
def timestep(self): def timestep(self):
""" """
Position of the character in units of 20ms Position of the character in units of 20ms
""" """
# pylint: disable=unnecessary-pass
pass
def start_time(self): def start_time(self):
""" """
Position of the character in seconds Position of the character in seconds
""" """
# pylint: disable=unnecessary-pass
pass
class Metadata(object): class Metadata(object):
@ -218,8 +249,7 @@ class Metadata(object):
:return: A list of :func:`MetadataItem` elements :return: A list of :func:`MetadataItem` elements
:type: list :type: list
""" """
# pylint: disable=unnecessary-pass
pass
def num_items(self): def num_items(self):
""" """
@ -228,8 +258,7 @@ class Metadata(object):
:return: Size of the list of items :return: Size of the list of items
:type: int :type: int
""" """
# pylint: disable=unnecessary-pass
pass
def confidence(self): def confidence(self):
""" """
@ -237,5 +266,4 @@ class Metadata(object):
sum of the acoustic model logit values for each timestep/character that sum of the acoustic model logit values for each timestep/character that
contributed to the creation of this transcription. contributed to the creation of this transcription.
""" """
# pylint: disable=unnecessary-pass
pass

View File

@ -72,7 +72,7 @@ def metadata_json_output(metadata):
json_result["words"] = words_from_metadata(metadata) json_result["words"] = words_from_metadata(metadata)
json_result["confidence"] = metadata.confidence json_result["confidence"] = metadata.confidence
return json.dumps(json_result) return json.dumps(json_result)
class VersionAction(argparse.Action): class VersionAction(argparse.Action):
@ -88,18 +88,16 @@ def main():
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.') parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
parser.add_argument('--model', required=True, parser.add_argument('--model', required=True,
help='Path to the model (protocol buffer binary file)') help='Path to the model (protocol buffer binary file)')
parser.add_argument('--lm', nargs='?', parser.add_argument('--scorer', required=False,
help='Path to the language model binary file') help='Path to the external scorer file')
parser.add_argument('--trie', nargs='?',
help='Path to the language model trie file created with native_client/generate_trie')
parser.add_argument('--audio', required=True, parser.add_argument('--audio', required=True,
help='Path to the audio file to run (WAV format)') help='Path to the audio file to run (WAV format)')
parser.add_argument('--beam_width', type=int, default=500, parser.add_argument('--beam_width', type=int, default=500,
help='Beam width for the CTC decoder') help='Beam width for the CTC decoder')
parser.add_argument('--lm_alpha', type=float, default=0.75, parser.add_argument('--lm_alpha', type=float,
help='Language model weight (lm_alpha)') help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
parser.add_argument('--lm_beta', type=float, default=1.85, parser.add_argument('--lm_beta', type=float,
help='Word insertion bonus (lm_beta)') help='Word insertion bonus (lm_beta). If not specified, use default from the scorer package.')
parser.add_argument('--version', action=VersionAction, parser.add_argument('--version', action=VersionAction,
help='Print version and exits') help='Print version and exits')
parser.add_argument('--extended', required=False, action='store_true', parser.add_argument('--extended', required=False, action='store_true',
@ -116,12 +114,15 @@ def main():
desired_sample_rate = ds.sampleRate() desired_sample_rate = ds.sampleRate()
if args.lm and args.trie: if args.scorer:
print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr) print('Loading scorer from files {}'.format(args.scorer), file=sys.stderr)
lm_load_start = timer() scorer_load_start = timer()
ds.enableDecoderWithLM(args.lm, args.trie, args.lm_alpha, args.lm_beta) ds.enableExternalScorer(args.scorer)
lm_load_end = timer() - lm_load_start scorer_load_end = timer() - scorer_load_start
print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr) print('Loaded scorer in {:.3}s.'.format(scorer_load_end), file=sys.stderr)
if args.lm_alpha and args.lm_beta:
ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)
fin = wave.open(args.audio, 'rb') fin = wave.open(args.audio, 'rb')
fs = fin.getframerate() fs = fin.getframerate()

View File

@ -14,21 +14,13 @@ from deepspeech import Model
# Beam width used in the CTC decoder when building candidate transcriptions # Beam width used in the CTC decoder when building candidate transcriptions
BEAM_WIDTH = 500 BEAM_WIDTH = 500
# The alpha hyperparameter of the CTC decoder. Language Model weight
LM_ALPHA = 0.75
# The beta hyperparameter of the CTC decoder. Word insertion bonus.
LM_BETA = 1.85
def main(): def main():
parser = argparse.ArgumentParser(description='Running DeepSpeech inference.') parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
parser.add_argument('--model', required=True, parser.add_argument('--model', required=True,
help='Path to the model (protocol buffer binary file)') help='Path to the model (protocol buffer binary file)')
parser.add_argument('--lm', nargs='?', parser.add_argument('--scorer', nargs='?',
help='Path to the language model binary file') help='Path to the external scorer file')
parser.add_argument('--trie', nargs='?',
help='Path to the language model trie file created with native_client/generate_trie')
parser.add_argument('--audio1', required=True, parser.add_argument('--audio1', required=True,
help='First audio file to use in interleaved streams') help='First audio file to use in interleaved streams')
parser.add_argument('--audio2', required=True, parser.add_argument('--audio2', required=True,
@ -37,8 +29,8 @@ def main():
ds = Model(args.model, BEAM_WIDTH) ds = Model(args.model, BEAM_WIDTH)
if args.lm and args.trie: if args.scorer:
ds.enableDecoderWithLM(args.lm, args.trie, LM_ALPHA, LM_BETA) ds.enableExternalScorer(args.scorer)
fin = wave.open(args.audio1, 'rb') fin = wave.open(args.audio1, 'rb')
fs1 = fin.getframerate() fs1 = fin.getframerate()
@ -57,11 +49,11 @@ def main():
splits2 = np.array_split(audio2, 10) splits2 = np.array_split(audio2, 10)
for part1, part2 in zip(splits1, splits2): for part1, part2 in zip(splits1, splits2):
ds.feedAudioContent(stream1, part1) stream1.feedAudioContent(part1)
ds.feedAudioContent(stream2, part2) stream2.feedAudioContent(part2)
print(ds.finishStream(stream1)) print(stream1.finishStream())
print(ds.finishStream(stream2)) print(stream2.finishStream())
if __name__ == '__main__': if __name__ == '__main__':
main() main()

View File

@ -27,9 +27,9 @@ int main(int argc, char** argv)
return err; return err;
} }
Scorer scorer; Scorer scorer;
err = scorer.init(kenlm_path, alphabet);
#ifndef DEBUG #ifndef DEBUG
return scorer.init(0.0, 0.0, kenlm_path, trie_path, alphabet); return err;
#else #else
// Print some info about the FST // Print some info about the FST
using FstType = fst::ConstFst<fst::StdArc>; using FstType = fst::ConstFst<fst::StdArc>;
@ -60,7 +60,6 @@ int main(int argc, char** argv)
// for (int i = 1; i < 10; ++i) { // for (int i = 1; i < 10; ++i) {
// print_states_from(i); // print_states_from(i);
// } // }
#endif // DEBUG
return 0; return 0;
#endif // DEBUG
} }

View File

@ -8,7 +8,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
BAZEL_TARGETS=" BAZEL_TARGETS="
//native_client:libdeepspeech.so //native_client:libdeepspeech.so
//native_client:generate_trie
" "
BAZEL_BUILD_FLAGS="${BAZEL_ARM64_FLAGS} ${BAZEL_EXTRA_FLAGS}" BAZEL_BUILD_FLAGS="${BAZEL_ARM64_FLAGS} ${BAZEL_EXTRA_FLAGS}"

View File

@ -8,7 +8,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
BAZEL_TARGETS=" BAZEL_TARGETS="
//native_client:libdeepspeech.so //native_client:libdeepspeech.so
//native_client:generate_trie
" "
BAZEL_ENV_FLAGS="TF_NEED_CUDA=1 ${TF_CUDA_FLAGS}" BAZEL_ENV_FLAGS="TF_NEED_CUDA=1 ${TF_CUDA_FLAGS}"

View File

@ -30,11 +30,11 @@ then:
image: ${build.docker_image} image: ${build.docker_image}
env: env:
DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.6.0-alpha.15/models.tar.gz" DEEPSPEECH_MODEL: "https://github.com/reuben/DeepSpeech/releases/download/v0.7.0-alpha.1/models.tar.gz"
DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz" DEEPSPEECH_AUDIO: "https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/audio-0.4.1.tar.gz"
PIP_DEFAULT_TIMEOUT: "60" PIP_DEFAULT_TIMEOUT: "60"
EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples" EXAMPLES_CLONE_URL: "https://github.com/mozilla/DeepSpeech-examples"
EXAMPLES_CHECKOUT_TARGET: "master" EXAMPLES_CHECKOUT_TARGET: "4b97ac41d03ca0d23fa92526433db72a90f47d4a"
command: command:
- "/bin/bash" - "/bin/bash"

View File

@ -10,7 +10,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
BAZEL_TARGETS=" BAZEL_TARGETS="
//native_client:libdeepspeech.so //native_client:libdeepspeech.so
//native_client:generate_trie
" "
if [ "${runtime}" = "tflite" ]; then if [ "${runtime}" = "tflite" ]; then

View File

@ -8,7 +8,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
BAZEL_TARGETS=" BAZEL_TARGETS="
//native_client:libdeepspeech.so //native_client:libdeepspeech.so
//native_client:generate_trie
" "
BAZEL_BUILD_FLAGS="${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS}" BAZEL_BUILD_FLAGS="${BAZEL_ARM_FLAGS} ${BAZEL_EXTRA_FLAGS}"

View File

@ -49,7 +49,7 @@ deepspeech --version
pushd ${HOME}/DeepSpeech/ds/ pushd ${HOME}/DeepSpeech/ds/
python bin/import_ldc93s1.py data/smoke_test python bin/import_ldc93s1.py data/smoke_test
python evaluate_tflite.py --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --lm data/smoke_test/vocab.pruned.lm --trie data/smoke_test/vocab.trie --csv data/smoke_test/ldc93s1.csv python evaluate_tflite.py --model "${TASKCLUSTER_TMP_DIR}/${model_name_mmap}" --scorer data/smoke_test/pruned_lm.scorer --csv data/smoke_test/ldc93s1.csv
popd popd
virtualenv_deactivate "${pyalias}" "${PYENV_NAME}" virtualenv_deactivate "${pyalias}" "${PYENV_NAME}"

View File

@ -378,7 +378,7 @@ run_netframework_inference_tests()
assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?" assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?"
set +e set +e
phrase_pbmodel_withlm=$(DeepSpeechConsole.exe --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm=$(DeepSpeechConsole.exe --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
set -e set -e
assert_working_ldc93s1_lm "${phrase_pbmodel_withlm}" "$?" assert_working_ldc93s1_lm "${phrase_pbmodel_withlm}" "$?"
} }
@ -401,7 +401,7 @@ run_electronjs_inference_tests()
assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?" assert_working_ldc93s1 "${phrase_pbmodel_nolm}" "$?"
set +e set +e
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
set -e set -e
assert_working_ldc93s1_lm "${phrase_pbmodel_withlm}" "$?" assert_working_ldc93s1_lm "${phrase_pbmodel_withlm}" "$?"
} }
@ -427,7 +427,7 @@ run_basic_inference_tests()
assert_correct_ldc93s1 "${phrase_pbmodel_nolm}" "$status" assert_correct_ldc93s1 "${phrase_pbmodel_nolm}" "$status"
set +e set +e
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm}" "$status" assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm}" "$status"
@ -444,7 +444,7 @@ run_all_inference_tests()
assert_correct_ldc93s1 "${phrase_pbmodel_nolm_stereo_44k}" "$status" assert_correct_ldc93s1 "${phrase_pbmodel_nolm_stereo_44k}" "$status"
set +e set +e
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_stereo_44k}" "$status" assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_stereo_44k}" "$status"
@ -457,7 +457,7 @@ run_all_inference_tests()
assert_correct_warning_upsampling "${phrase_pbmodel_nolm_mono_8k}" assert_correct_warning_upsampling "${phrase_pbmodel_nolm_mono_8k}"
set +e set +e
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null) phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
set -e set -e
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}" assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
fi; fi;
@ -470,8 +470,7 @@ run_prod_concurrent_stream_tests()
set +e set +e
output=$(python ${TASKCLUSTER_TMP_DIR}/test_sources/concurrent_streams.py \ output=$(python ${TASKCLUSTER_TMP_DIR}/test_sources/concurrent_streams.py \
--model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} \ --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} \
--lm ${TASKCLUSTER_TMP_DIR}/lm.binary \ --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer \
--trie ${TASKCLUSTER_TMP_DIR}/trie \
--audio1 ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_16000.wav \ --audio1 ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_16000.wav \
--audio2 ${TASKCLUSTER_TMP_DIR}/new-home-in-the-stars-16k.wav 2>${TASKCLUSTER_TMP_DIR}/stderr) --audio2 ${TASKCLUSTER_TMP_DIR}/new-home-in-the-stars-16k.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
@ -489,19 +488,19 @@ run_prod_inference_tests()
local _bitrate=$1 local _bitrate=$1
set +e set +e
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}" assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
set +e set +e
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}" assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
set +e set +e
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_prodmodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status" "${_bitrate}" assert_correct_ldc93s1_prodmodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status" "${_bitrate}"
@ -509,7 +508,7 @@ run_prod_inference_tests()
# Run down-sampling warning test only when we actually perform downsampling # Run down-sampling warning test only when we actually perform downsampling
if [ "${ldc93s1_sample_filename}" != "LDC93S1_pcms16le_1_8000.wav" ]; then if [ "${ldc93s1_sample_filename}" != "LDC93S1_pcms16le_1_8000.wav" ]; then
set +e set +e
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null) phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
set -e set -e
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}" assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
fi; fi;
@ -520,19 +519,19 @@ run_prodtflite_inference_tests()
local _bitrate=$1 local _bitrate=$1
set +e set +e
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_prodtflitemodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}" assert_correct_ldc93s1_prodtflitemodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
set +e set +e
phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_prodtflitemodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}" assert_correct_ldc93s1_prodtflitemodel "${phrase_pbmodel_withlm}" "$status" "${_bitrate}"
set +e set +e
phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr) phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav 2>${TASKCLUSTER_TMP_DIR}/stderr)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_prodtflitemodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status" "${_bitrate}" assert_correct_ldc93s1_prodtflitemodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status" "${_bitrate}"
@ -540,7 +539,7 @@ run_prodtflite_inference_tests()
# Run down-sampling warning test only when we actually perform downsampling # Run down-sampling warning test only when we actually perform downsampling
if [ "${ldc93s1_sample_filename}" != "LDC93S1_pcms16le_1_8000.wav" ]; then if [ "${ldc93s1_sample_filename}" != "LDC93S1_pcms16le_1_8000.wav" ]; then
set +e set +e
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null) phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
set -e set -e
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}" assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
fi; fi;
@ -555,7 +554,7 @@ run_multi_inference_tests()
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_nolm}" "$status" assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_nolm}" "$status"
set +e -o pipefail set +e -o pipefail
multi_phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/ 2>${TASKCLUSTER_TMP_DIR}/stderr | tr '\n' '%') multi_phrase_pbmodel_withlm=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/ 2>${TASKCLUSTER_TMP_DIR}/stderr | tr '\n' '%')
status=$? status=$?
set -e +o pipefail set -e +o pipefail
assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status" assert_correct_multi_ldc93s1 "${multi_phrase_pbmodel_withlm}" "$status"
@ -564,7 +563,7 @@ run_multi_inference_tests()
run_cpp_only_inference_tests() run_cpp_only_inference_tests()
{ {
set +e set +e
phrase_pbmodel_withlm_intermediate_decode=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} --stream 1280 2>${TASKCLUSTER_TMP_DIR}/stderr | tail -n 1) phrase_pbmodel_withlm_intermediate_decode=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer --audio ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} --stream 1280 2>${TASKCLUSTER_TMP_DIR}/stderr | tail -n 1)
status=$? status=$?
set -e set -e
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_intermediate_decode}" "$status" assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_intermediate_decode}" "$status"
@ -669,8 +668,7 @@ download_data()
${WGET} -P "${TASKCLUSTER_TMP_DIR}" "${model_source}" ${WGET} -P "${TASKCLUSTER_TMP_DIR}" "${model_source}"
${WGET} -P "${TASKCLUSTER_TMP_DIR}" "${model_source_mmap}" ${WGET} -P "${TASKCLUSTER_TMP_DIR}" "${model_source_mmap}"
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/*.wav ${TASKCLUSTER_TMP_DIR}/ cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/*.wav ${TASKCLUSTER_TMP_DIR}/
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/vocab.pruned.lm ${TASKCLUSTER_TMP_DIR}/lm.binary cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/pruned_lm.scorer ${TASKCLUSTER_TMP_DIR}/kenlm.scorer
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/smoke_test/vocab.trie ${TASKCLUSTER_TMP_DIR}/trie
cp -R ${DS_ROOT_TASK}/DeepSpeech/ds/native_client/test ${TASKCLUSTER_TMP_DIR}/test_sources cp -R ${DS_ROOT_TASK}/DeepSpeech/ds/native_client/test ${TASKCLUSTER_TMP_DIR}/test_sources
} }
@ -1562,7 +1560,6 @@ package_native_client()
fi; fi;
${TAR} -cf - \ ${TAR} -cf - \
-C ${tensorflow_dir}/bazel-bin/native_client/ generate_trie${PLATFORM_EXE_SUFFIX} \
-C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so \ -C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so \
-C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so.if.lib \ -C ${tensorflow_dir}/bazel-bin/native_client/ libdeepspeech.so.if.lib \
-C ${deepspeech_dir}/ LICENSE \ -C ${deepspeech_dir}/ LICENSE \
@ -1767,8 +1764,7 @@ android_setup_apk_data()
adb push \ adb push \
${TASKCLUSTER_TMP_DIR}/${model_name} \ ${TASKCLUSTER_TMP_DIR}/${model_name} \
${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \ ${TASKCLUSTER_TMP_DIR}/${ldc93s1_sample_filename} \
${TASKCLUSTER_TMP_DIR}/lm.binary \ ${TASKCLUSTER_TMP_DIR}/kenlm.scorer \
${TASKCLUSTER_TMP_DIR}/trie \
${ANDROID_TMP_DIR}/test/ ${ANDROID_TMP_DIR}/test/
} }

View File

@ -10,7 +10,6 @@ source ${DS_ROOT_TASK}/DeepSpeech/tf/tc-vars.sh
BAZEL_TARGETS=" BAZEL_TARGETS="
//native_client:libdeepspeech.so //native_client:libdeepspeech.so
//native_client:generate_trie
" "
if [ "${package_option}" = "--cuda" ]; then if [ "${package_option}" = "--cuda" ]; then

Some files were not shown because too many files have changed in this diff Show More