Improve decoder package docs and include in RTD

This commit is contained in:
Reuben Morais 2021-10-31 19:37:40 +01:00
parent 91f1307de4
commit 90feb63894
7 changed files with 246 additions and 97 deletions

View File

@ -1,14 +1,14 @@
.. _decoder-docs:
CTC beam search decoder
=======================
Beam search decoder
===================
Introduction
------------
🐸STT uses the `Connectionist Temporal Classification <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC <https://distill.pub/2017/ctc/>`_. This document assumes the reader is familiar with the concepts described in that article, and describes 🐸STT specific behaviors that developers building systems with 🐸STT should know to avoid problems.
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`.
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`. Documentation for the coqui_stt_ctcdecoder Python package used by the training code for decoding is available in :ref:`decoder-api`.
The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 <https://tools.ietf.org/html/bcp14>`_ when, and only when, they appear in all capitals, as shown here.

7
doc/Decoder-API.rst Normal file
View File

@ -0,0 +1,7 @@
.. _decoder-api:
Decoder API reference
=====================
.. automodule:: native_client.ctcdecode
:members:

View File

@ -24,7 +24,8 @@ import sys
sys.path.insert(0, os.path.abspath("../"))
autodoc_mock_imports = ["stt"]
autodoc_mock_imports = ["stt", "native_client.ctcdecode.swigwrapper"]
autodoc_member_order = "bysource"
# This is in fact only relevant on ReadTheDocs, but we want to run the same way
# on our CI as in RTD to avoid regressions on RTD that we would not catch on CI
@ -128,7 +129,6 @@ todo_include_todos = False
add_module_names = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for

View File

@ -89,6 +89,17 @@ The fastest way to use a pre-trained 🐸STT model is with the 🐸STT model man
playbook/README
.. toctree::
:maxdepth: 1
:caption: Advanced topics
DECODER
Decoder-API
PARALLEL_OPTIMIZATION
Indices and tables
==================

View File

@ -1,4 +1,4 @@
from __future__ import absolute_import, division, print_function
import enum
from . import swigwrapper # pylint: disable=import-self
@ -13,41 +13,12 @@ for symbol in dir(swigwrapper):
globals()[symbol] = getattr(swigwrapper, symbol)
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
pass
class Scorer(swigwrapper.Scorer):
"""Wrapper for Scorer.
:param alpha: Language model weight.
:type alpha: float
:param beta: Word insertion bonus.
:type beta: float
:scorer_path: Path to load scorer from.
:alphabet: Alphabet
:type scorer_path: basestring
"""
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
super(Scorer, self).__init__()
# Allow bare initialization
if alphabet:
assert alpha is not None, "alpha parameter is required"
assert beta is not None, "beta parameter is required"
assert scorer_path, "scorer_path parameter is required"
err = self.init(scorer_path.encode("utf-8"), alphabet)
if err != 0:
raise ValueError(
"Scorer initialization failed with error code 0x{:X}".format(err)
)
self.reset_params(alpha, beta)
class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
"""An Alphabet is a bidirectional map from tokens (eg. characters) to
internal integer representations used by the underlying acoustic models
and external scorers. It can be created from alphabet configuration file
via the constructor, or from a list of tokens via :py:meth:`Alphabet.InitFromLabels`.
"""
def __init__(self, config_path=None):
super(Alphabet, self).__init__()
@ -59,6 +30,10 @@ class Alphabet(swigwrapper.Alphabet):
)
def InitFromLabels(self, data):
"""
Initialize Alphabet from a list of labels ``data``. Each label gets
associated with an integer value corresponding to its position in the list.
"""
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
def CanEncodeSingle(self, input):
@ -87,7 +62,7 @@ class Alphabet(swigwrapper.Alphabet):
Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use
`CanEncode` and `CanEncodeSingle` to test.
``CanEncode`` and ``CanEncodeSingle`` to test.
"""
# Convert SWIG's UnsignedIntVec to a Python list
res = super(Alphabet, self).Encode(input.encode("utf-8"))
@ -103,57 +78,39 @@ class Alphabet(swigwrapper.Alphabet):
return res.decode("utf-8")
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
class Scorer(swigwrapper.Scorer):
"""An external scorer is a data structure composed of a language model built
from text data, as well as the vocabulary used in the construction of this
language model and additional parameters related to how the decoding
process uses the external scorer, such as the language model weight
``alpha`` and the word insertion score ``beta``.
def __init__(self):
super(UTF8Alphabet, self).__init__()
err = self.init(b"")
if err != 0:
raise ValueError(
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
)
:param alpha: Language model weight.
:type alpha: float
:param beta: Word insertion score.
:type beta: float
:param scorer_path: Path to load scorer from.
:type scorer_path: str
:param alphabet: Alphabet object matching the tokens used when creating the
external scorer.
:type alphabet: Alphabet
"""
def CanEncodeSingle(self, input):
"""
Returns true if the single character/output class has a corresponding label
in the alphabet.
"""
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
super(Scorer, self).__init__()
# Allow bare initialization
if alphabet:
assert alpha is not None, "alpha parameter is required"
assert beta is not None, "beta parameter is required"
assert scorer_path, "scorer_path parameter is required"
def CanEncode(self, input):
"""
Returns true if the entire string can be encoded into labels in this
alphabet.
"""
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
err = self.init(scorer_path.encode("utf-8"), alphabet)
if err != 0:
raise ValueError(
"Scorer initialization failed with error code 0x{:X}".format(err)
)
def EncodeSingle(self, input):
"""
Encode a single character/output class into a label. Character must be in
the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
"""
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
def Encode(self, input):
"""
Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use
`CanEncode` and `CanEncodeSingle` to test.
"""
# Convert SWIG's UnsignedIntVec to a Python list
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
return [el for el in res]
def DecodeSingle(self, input):
res = super(UTF8Alphabet, self).DecodeSingle(input)
return res.decode("utf-8")
def Decode(self, input):
"""Decode a sequence of labels into a string."""
res = super(UTF8Alphabet, self).Decode(input)
return res.decode("utf-8")
self.reset_params(alpha, beta)
def ctc_beam_search_decoder(
@ -186,7 +143,7 @@ def ctc_beam_search_decoder(
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:type hot_words: dict[string, float]
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
@ -245,7 +202,7 @@ def ctc_beam_search_decoder_batch(
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:type hot_words: dict[string, float]
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
@ -271,8 +228,63 @@ def ctc_beam_search_decoder_batch(
return batch_beam_results
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
"""
This class contains constants used to specify the desired behavior for the
:py:func:`flashlight_beam_search_decoder` and :py:func:`flashlight_beam_search_decoder_batch`
functions.
"""
class CriterionType(enum.IntEnum):
"""Constants used to specify which loss criterion was used by the
acoustic model. This class is a Python :py:class:`enum.IntEnum`.
"""
#: Decoder mode for handling acoustic models trained with CTC loss
CTC = swigwrapper.FlashlightDecoderState.CTC
#: Decoder mode for handling acoustic models trained with ASG loss
ASG = swigwrapper.FlashlightDecoderState.ASG
#: Decoder mode for handling acoustic models trained with Seq2seq loss
#: Note: this criterion type is currently not supported.
S2S = swigwrapper.FlashlightDecoderState.S2S
class DecoderType(enum.IntEnum):
"""Constants used to specify if decoder should operate in lexicon mode,
only predicting words present in a fixed vocabulary, or in lexicon-free
mode, without such restriction. This class is a Python :py:class:`enum.IntEnum`.
"""
#: Lexicon mode, only predict words in specified vocabulary.
LexiconBased = swigwrapper.FlashlightDecoderState.LexiconBased
#: Lexicon-free mode, allow prediction of any word.
LexiconFree = swigwrapper.FlashlightDecoderState.LexiconFree
class TokenType(enum.IntEnum):
"""Constants used to specify the granularity of text units used when training
the external scorer in relation to the text units used when training the
acoustic model. For example, you can have an acoustic model predicting
characters and an external scorer trained on words, or an acoustic model
and an external scorer both trained with sub-word units. If the acoustic
model and the scorer were both trained on the same text unit granularity,
use ``TokenType.Single``. Otherwise, if the external scorer was trained
on a sequence of acoustic model text units, use ``TokenType.Aggregate``.
This class is a Python :py:class:`enum.IntEnum`.
"""
#: Token type for external scorers trained on the same textual units as
#: the acoustic model.
Single = swigwrapper.FlashlightDecoderState.Single
#: Token type for external scorers trained on a sequence of acoustic model
#: textual units.
Aggregate = swigwrapper.FlashlightDecoderState.Aggregate
def flashlight_beam_search_decoder(
probs_seq,
logits_seq,
alphabet,
beam_size,
decoder_type,
@ -283,12 +295,67 @@ def flashlight_beam_search_decoder(
cutoff_top_n=40,
silence_score=0.0,
merge_with_log_add=False,
criterion_type=swigwrapper.FlashlightDecoderState.CTC,
criterion_type=FlashlightDecoderState.CriterionType.CTC,
transitions=[],
num_results=1,
):
"""Decode acoustic model emissions for a single sample. Note that unlike
:py:func:`ctc_beam_search_decoder`, this function expects raw outputs
from CTC and ASG acoustic models, without softmaxing them over
timesteps.
:param logits_seq: 2-D list of acoustic model emissions, dimensions are
time steps x number of output units.
:type logits_seq: 2-D list of floats or numpy array
:param alphabet: Alphabet object matching the tokens used when creating the
acoustic model and external scorer if specified.
:type alphabet: Alphabet
:param beam_size: Width for beam search.
:type beam_size: int
:param decoder_type: Decoding mode, lexicon-constrained or lexicon-free.
:type decoder_type: FlashlightDecoderState.DecoderType
:param token_type: Type of token in the external scorer.
:type token_type: FlashlightDecoderState.TokenType
:param lm_tokens: List of tokens to constrain decoding to when in lexicon-constrained
mode. Must match the token type used in the scorer, ie.
must be a list of characters if scorer is character-based,
or a list of words if scorer is word-based.
:param lm_tokens: list[str]
:param scorer: External scorer.
:type scorer: Scorer
:param beam_threshold: Maximum threshold in beam score from leading beam. Any
newly created candidate beams which lag behind the best
beam so far by more than this value will get pruned.
This is a performance optimization parameter and an
appropriate value should be found empirically using a
validation set.
:type beam_threshold: float
:param cutoff_top_n: Maximum number of tokens to expand per time step during
decoding. Only the highest probability cutoff_top_n
candidates (characters, sub-word units, words) in a given
timestep will be expanded. This is a performance
optimization parameter and an appropriate value should
be found empirically using a validation set.
:type cutoff_top_n: int
:param silence_score: Score to add to beam when encountering a predicted
silence token (eg. the space symbol).
:type silence_score: float
:param merge_with_log_add: Whether to use log-add when merging scores of
new candidate beams equivalent to existing ones
(leading to the same transcription). When disabled,
the maximum score is used.
:type merge_with_log_add: bool
:param criterion_type: Criterion used for training the acoustic model.
:type criterion_type: FlashlightDecoderState.CriterionType
:param transitions: Transition score matrix for ASG acoustic models.
:type transitions: list[float]
:param num_results: Number of beams to return.
:type num_results: int
:return: List of FlashlightOutput structures.
:rtype: list[FlashlightOutput]
"""
return swigwrapper.flashlight_beam_search_decoder(
probs_seq,
logits_seq,
alphabet,
beam_size,
beam_threshold,
@ -319,10 +386,17 @@ def flashlight_beam_search_decoder_batch(
cutoff_top_n=40,
silence_score=0.0,
merge_with_log_add=False,
criterion_type=swigwrapper.FlashlightDecoderState.CTC,
criterion_type=FlashlightDecoderState.CriterionType.CTC,
transitions=[],
num_results=1,
):
"""Decode batch acoustic model emissions in parallel. ``num_processes``
controls how many samples from the batch will be decoded simultaneously.
All the other parameters are forwarded to :py:func:`flashlight_beam_search_decoder`.
Returns a list of lists of FlashlightOutput structures.
"""
return swigwrapper.flashlight_beam_search_decoder_batch(
probs_seq,
seq_lengths,
@ -341,3 +415,58 @@ def flashlight_beam_search_decoder_batch(
num_results,
num_processes,
)
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
"""Alphabet class representing 255 possible byte values for Bytes Output Mode.
For internal use only.
"""
def __init__(self):
super(UTF8Alphabet, self).__init__()
err = self.init(b"")
if err != 0:
raise ValueError(
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
)
def CanEncodeSingle(self, input):
"""
Returns true if the single character/output class has a corresponding label
in the alphabet.
"""
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
def CanEncode(self, input):
"""
Returns true if the entire string can be encoded into labels in this
alphabet.
"""
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
def EncodeSingle(self, input):
"""
Encode a single character/output class into a label. Character must be in
the alphabet, this method will assert that. Use ``CanEncodeSingle`` to test.
"""
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
def Encode(self, input):
"""
Encode a sequence of character/output classes into a sequence of labels.
Characters are assumed to always take a single Unicode codepoint.
Characters must be in the alphabet, this method will assert that. Use
``CanEncode`` and ``CanEncodeSingle`` to test.
"""
# Convert SWIG's UnsignedIntVec to a Python list
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
return [el for el in res]
def DecodeSingle(self, input):
res = super(UTF8Alphabet, self).DecodeSingle(input)
return res.decode("utf-8")
def Decode(self, input):
"""Decode a sequence of labels into a string."""
res = super(UTF8Alphabet, self).Decode(input)
return res.decode("utf-8")

View File

@ -96,7 +96,9 @@ class BuildExtFirst(build):
setup(
name="coqui_stt_ctcdecoder",
version=project_version,
description="""DS CTC decoder""",
description="Coqui STT Python decoder package.",
long_description="Documentation available at `stt.readthedocs.io <https://stt.readthedocs.io/en/latest/Decoder-API.html>`_",
long_description_content_type="text/x-rst; charset=UTF-8",
cmdclass={"build": BuildExtFirst},
ext_modules=[decoder_module],
package_dir={"coqui_stt_ctcdecoder": "."},

View File

@ -142,8 +142,8 @@ def evaluate(test_csvs, create_model):
batch_lengths,
Config.alphabet,
beam_size=Config.export_beam_width,
decoder_type=FlashlightDecoderState.LexiconBased,
token_type=FlashlightDecoderState.Aggregate,
decoder_type=FlashlightDecoderState.DecoderType.LexiconBased,
token_type=FlashlightDecoderState.TokenType.Aggregate,
lm_tokens=vocab,
num_processes=num_processes,
scorer=scorer,