Improve decoder package docs and include in RTD
This commit is contained in:
parent
91f1307de4
commit
90feb63894
@ -1,14 +1,14 @@
|
||||
.. _decoder-docs:
|
||||
|
||||
CTC beam search decoder
|
||||
=======================
|
||||
Beam search decoder
|
||||
===================
|
||||
|
||||
Introduction
|
||||
------------
|
||||
|
||||
🐸STT uses the `Connectionist Temporal Classification <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC <https://distill.pub/2017/ctc/>`_. This document assumes the reader is familiar with the concepts described in that article, and describes 🐸STT specific behaviors that developers building systems with 🐸STT should know to avoid problems.
|
||||
|
||||
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`.
|
||||
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`. Documentation for the coqui_stt_ctcdecoder Python package used by the training code for decoding is available in :ref:`decoder-api`.
|
||||
|
||||
The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 <https://tools.ietf.org/html/bcp14>`_ when, and only when, they appear in all capitals, as shown here.
|
||||
|
||||
|
7
doc/Decoder-API.rst
Normal file
7
doc/Decoder-API.rst
Normal file
@ -0,0 +1,7 @@
|
||||
.. _decoder-api:
|
||||
|
||||
Decoder API reference
|
||||
=====================
|
||||
|
||||
.. automodule:: native_client.ctcdecode
|
||||
:members:
|
@ -24,7 +24,8 @@ import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../"))
|
||||
|
||||
autodoc_mock_imports = ["stt"]
|
||||
autodoc_mock_imports = ["stt", "native_client.ctcdecode.swigwrapper"]
|
||||
autodoc_member_order = "bysource"
|
||||
|
||||
# This is in fact only relevant on ReadTheDocs, but we want to run the same way
|
||||
# on our CI as in RTD to avoid regressions on RTD that we would not catch on CI
|
||||
@ -128,7 +129,6 @@ todo_include_todos = False
|
||||
|
||||
add_module_names = False
|
||||
|
||||
|
||||
# -- Options for HTML output ----------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
|
@ -89,6 +89,17 @@ The fastest way to use a pre-trained 🐸STT model is with the 🐸STT model man
|
||||
|
||||
playbook/README
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Advanced topics
|
||||
|
||||
DECODER
|
||||
|
||||
Decoder-API
|
||||
|
||||
PARALLEL_OPTIMIZATION
|
||||
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import enum
|
||||
|
||||
from . import swigwrapper # pylint: disable=import-self
|
||||
|
||||
@ -13,41 +13,12 @@ for symbol in dir(swigwrapper):
|
||||
globals()[symbol] = getattr(swigwrapper, symbol)
|
||||
|
||||
|
||||
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
|
||||
pass
|
||||
|
||||
|
||||
class Scorer(swigwrapper.Scorer):
|
||||
"""Wrapper for Scorer.
|
||||
|
||||
:param alpha: Language model weight.
|
||||
:type alpha: float
|
||||
:param beta: Word insertion bonus.
|
||||
:type beta: float
|
||||
:scorer_path: Path to load scorer from.
|
||||
:alphabet: Alphabet
|
||||
:type scorer_path: basestring
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
|
||||
super(Scorer, self).__init__()
|
||||
# Allow bare initialization
|
||||
if alphabet:
|
||||
assert alpha is not None, "alpha parameter is required"
|
||||
assert beta is not None, "beta parameter is required"
|
||||
assert scorer_path, "scorer_path parameter is required"
|
||||
|
||||
err = self.init(scorer_path.encode("utf-8"), alphabet)
|
||||
if err != 0:
|
||||
raise ValueError(
|
||||
"Scorer initialization failed with error code 0x{:X}".format(err)
|
||||
)
|
||||
|
||||
self.reset_params(alpha, beta)
|
||||
|
||||
|
||||
class Alphabet(swigwrapper.Alphabet):
|
||||
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
||||
"""An Alphabet is a bidirectional map from tokens (eg. characters) to
|
||||
internal integer representations used by the underlying acoustic models
|
||||
and external scorers. It can be created from alphabet configuration file
|
||||
via the constructor, or from a list of tokens via :py:meth:`Alphabet.InitFromLabels`.
|
||||
"""
|
||||
|
||||
def __init__(self, config_path=None):
|
||||
super(Alphabet, self).__init__()
|
||||
@ -59,6 +30,10 @@ class Alphabet(swigwrapper.Alphabet):
|
||||
)
|
||||
|
||||
def InitFromLabels(self, data):
|
||||
"""
|
||||
Initialize Alphabet from a list of labels ``data``. Each label gets
|
||||
associated with an integer value corresponding to its position in the list.
|
||||
"""
|
||||
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
|
||||
|
||||
def CanEncodeSingle(self, input):
|
||||
@ -87,7 +62,7 @@ class Alphabet(swigwrapper.Alphabet):
|
||||
Encode a sequence of character/output classes into a sequence of labels.
|
||||
Characters are assumed to always take a single Unicode codepoint.
|
||||
Characters must be in the alphabet, this method will assert that. Use
|
||||
`CanEncode` and `CanEncodeSingle` to test.
|
||||
``CanEncode`` and ``CanEncodeSingle`` to test.
|
||||
"""
|
||||
# Convert SWIG's UnsignedIntVec to a Python list
|
||||
res = super(Alphabet, self).Encode(input.encode("utf-8"))
|
||||
@ -103,57 +78,39 @@ class Alphabet(swigwrapper.Alphabet):
|
||||
return res.decode("utf-8")
|
||||
|
||||
|
||||
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
|
||||
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
||||
class Scorer(swigwrapper.Scorer):
|
||||
"""An external scorer is a data structure composed of a language model built
|
||||
from text data, as well as the vocabulary used in the construction of this
|
||||
language model and additional parameters related to how the decoding
|
||||
process uses the external scorer, such as the language model weight
|
||||
``alpha`` and the word insertion score ``beta``.
|
||||
|
||||
def __init__(self):
|
||||
super(UTF8Alphabet, self).__init__()
|
||||
err = self.init(b"")
|
||||
:param alpha: Language model weight.
|
||||
:type alpha: float
|
||||
:param beta: Word insertion score.
|
||||
:type beta: float
|
||||
:param scorer_path: Path to load scorer from.
|
||||
:type scorer_path: str
|
||||
:param alphabet: Alphabet object matching the tokens used when creating the
|
||||
external scorer.
|
||||
:type alphabet: Alphabet
|
||||
"""
|
||||
|
||||
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
|
||||
super(Scorer, self).__init__()
|
||||
# Allow bare initialization
|
||||
if alphabet:
|
||||
assert alpha is not None, "alpha parameter is required"
|
||||
assert beta is not None, "beta parameter is required"
|
||||
assert scorer_path, "scorer_path parameter is required"
|
||||
|
||||
err = self.init(scorer_path.encode("utf-8"), alphabet)
|
||||
if err != 0:
|
||||
raise ValueError(
|
||||
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
|
||||
"Scorer initialization failed with error code 0x{:X}".format(err)
|
||||
)
|
||||
|
||||
def CanEncodeSingle(self, input):
|
||||
"""
|
||||
Returns true if the single character/output class has a corresponding label
|
||||
in the alphabet.
|
||||
"""
|
||||
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
|
||||
|
||||
def CanEncode(self, input):
|
||||
"""
|
||||
Returns true if the entire string can be encoded into labels in this
|
||||
alphabet.
|
||||
"""
|
||||
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
|
||||
|
||||
def EncodeSingle(self, input):
|
||||
"""
|
||||
Encode a single character/output class into a label. Character must be in
|
||||
the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
|
||||
"""
|
||||
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
|
||||
|
||||
def Encode(self, input):
|
||||
"""
|
||||
Encode a sequence of character/output classes into a sequence of labels.
|
||||
Characters are assumed to always take a single Unicode codepoint.
|
||||
Characters must be in the alphabet, this method will assert that. Use
|
||||
`CanEncode` and `CanEncodeSingle` to test.
|
||||
"""
|
||||
# Convert SWIG's UnsignedIntVec to a Python list
|
||||
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
|
||||
return [el for el in res]
|
||||
|
||||
def DecodeSingle(self, input):
|
||||
res = super(UTF8Alphabet, self).DecodeSingle(input)
|
||||
return res.decode("utf-8")
|
||||
|
||||
def Decode(self, input):
|
||||
"""Decode a sequence of labels into a string."""
|
||||
res = super(UTF8Alphabet, self).Decode(input)
|
||||
return res.decode("utf-8")
|
||||
self.reset_params(alpha, beta)
|
||||
|
||||
|
||||
def ctc_beam_search_decoder(
|
||||
@ -186,7 +143,7 @@ def ctc_beam_search_decoder(
|
||||
count or language model.
|
||||
:type scorer: Scorer
|
||||
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
||||
:type hot_words: map{string:float}
|
||||
:type hot_words: dict[string, float]
|
||||
:param num_results: Number of beams to return.
|
||||
:type num_results: int
|
||||
:return: List of tuples of confidence and sentence as decoding
|
||||
@ -245,7 +202,7 @@ def ctc_beam_search_decoder_batch(
|
||||
count or language model.
|
||||
:type scorer: Scorer
|
||||
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
||||
:type hot_words: map{string:float}
|
||||
:type hot_words: dict[string, float]
|
||||
:param num_results: Number of beams to return.
|
||||
:type num_results: int
|
||||
:return: List of tuples of confidence and sentence as decoding
|
||||
@ -271,8 +228,63 @@ def ctc_beam_search_decoder_batch(
|
||||
return batch_beam_results
|
||||
|
||||
|
||||
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
|
||||
"""
|
||||
This class contains constants used to specify the desired behavior for the
|
||||
:py:func:`flashlight_beam_search_decoder` and :py:func:`flashlight_beam_search_decoder_batch`
|
||||
functions.
|
||||
"""
|
||||
|
||||
class CriterionType(enum.IntEnum):
|
||||
"""Constants used to specify which loss criterion was used by the
|
||||
acoustic model. This class is a Python :py:class:`enum.IntEnum`.
|
||||
"""
|
||||
|
||||
#: Decoder mode for handling acoustic models trained with CTC loss
|
||||
CTC = swigwrapper.FlashlightDecoderState.CTC
|
||||
|
||||
#: Decoder mode for handling acoustic models trained with ASG loss
|
||||
ASG = swigwrapper.FlashlightDecoderState.ASG
|
||||
|
||||
#: Decoder mode for handling acoustic models trained with Seq2seq loss
|
||||
#: Note: this criterion type is currently not supported.
|
||||
S2S = swigwrapper.FlashlightDecoderState.S2S
|
||||
|
||||
class DecoderType(enum.IntEnum):
|
||||
"""Constants used to specify if decoder should operate in lexicon mode,
|
||||
only predicting words present in a fixed vocabulary, or in lexicon-free
|
||||
mode, without such restriction. This class is a Python :py:class:`enum.IntEnum`.
|
||||
"""
|
||||
|
||||
#: Lexicon mode, only predict words in specified vocabulary.
|
||||
LexiconBased = swigwrapper.FlashlightDecoderState.LexiconBased
|
||||
|
||||
#: Lexicon-free mode, allow prediction of any word.
|
||||
LexiconFree = swigwrapper.FlashlightDecoderState.LexiconFree
|
||||
|
||||
class TokenType(enum.IntEnum):
|
||||
"""Constants used to specify the granularity of text units used when training
|
||||
the external scorer in relation to the text units used when training the
|
||||
acoustic model. For example, you can have an acoustic model predicting
|
||||
characters and an external scorer trained on words, or an acoustic model
|
||||
and an external scorer both trained with sub-word units. If the acoustic
|
||||
model and the scorer were both trained on the same text unit granularity,
|
||||
use ``TokenType.Single``. Otherwise, if the external scorer was trained
|
||||
on a sequence of acoustic model text units, use ``TokenType.Aggregate``.
|
||||
This class is a Python :py:class:`enum.IntEnum`.
|
||||
"""
|
||||
|
||||
#: Token type for external scorers trained on the same textual units as
|
||||
#: the acoustic model.
|
||||
Single = swigwrapper.FlashlightDecoderState.Single
|
||||
|
||||
#: Token type for external scorers trained on a sequence of acoustic model
|
||||
#: textual units.
|
||||
Aggregate = swigwrapper.FlashlightDecoderState.Aggregate
|
||||
|
||||
|
||||
def flashlight_beam_search_decoder(
|
||||
probs_seq,
|
||||
logits_seq,
|
||||
alphabet,
|
||||
beam_size,
|
||||
decoder_type,
|
||||
@ -283,12 +295,67 @@ def flashlight_beam_search_decoder(
|
||||
cutoff_top_n=40,
|
||||
silence_score=0.0,
|
||||
merge_with_log_add=False,
|
||||
criterion_type=swigwrapper.FlashlightDecoderState.CTC,
|
||||
criterion_type=FlashlightDecoderState.CriterionType.CTC,
|
||||
transitions=[],
|
||||
num_results=1,
|
||||
):
|
||||
"""Decode acoustic model emissions for a single sample. Note that unlike
|
||||
:py:func:`ctc_beam_search_decoder`, this function expects raw outputs
|
||||
from CTC and ASG acoustic models, without softmaxing them over
|
||||
timesteps.
|
||||
|
||||
:param logits_seq: 2-D list of acoustic model emissions, dimensions are
|
||||
time steps x number of output units.
|
||||
:type logits_seq: 2-D list of floats or numpy array
|
||||
:param alphabet: Alphabet object matching the tokens used when creating the
|
||||
acoustic model and external scorer if specified.
|
||||
:type alphabet: Alphabet
|
||||
:param beam_size: Width for beam search.
|
||||
:type beam_size: int
|
||||
:param decoder_type: Decoding mode, lexicon-constrained or lexicon-free.
|
||||
:type decoder_type: FlashlightDecoderState.DecoderType
|
||||
:param token_type: Type of token in the external scorer.
|
||||
:type token_type: FlashlightDecoderState.TokenType
|
||||
:param lm_tokens: List of tokens to constrain decoding to when in lexicon-constrained
|
||||
mode. Must match the token type used in the scorer, ie.
|
||||
must be a list of characters if scorer is character-based,
|
||||
or a list of words if scorer is word-based.
|
||||
:param lm_tokens: list[str]
|
||||
:param scorer: External scorer.
|
||||
:type scorer: Scorer
|
||||
:param beam_threshold: Maximum threshold in beam score from leading beam. Any
|
||||
newly created candidate beams which lag behind the best
|
||||
beam so far by more than this value will get pruned.
|
||||
This is a performance optimization parameter and an
|
||||
appropriate value should be found empirically using a
|
||||
validation set.
|
||||
:type beam_threshold: float
|
||||
:param cutoff_top_n: Maximum number of tokens to expand per time step during
|
||||
decoding. Only the highest probability cutoff_top_n
|
||||
candidates (characters, sub-word units, words) in a given
|
||||
timestep will be expanded. This is a performance
|
||||
optimization parameter and an appropriate value should
|
||||
be found empirically using a validation set.
|
||||
:type cutoff_top_n: int
|
||||
:param silence_score: Score to add to beam when encountering a predicted
|
||||
silence token (eg. the space symbol).
|
||||
:type silence_score: float
|
||||
:param merge_with_log_add: Whether to use log-add when merging scores of
|
||||
new candidate beams equivalent to existing ones
|
||||
(leading to the same transcription). When disabled,
|
||||
the maximum score is used.
|
||||
:type merge_with_log_add: bool
|
||||
:param criterion_type: Criterion used for training the acoustic model.
|
||||
:type criterion_type: FlashlightDecoderState.CriterionType
|
||||
:param transitions: Transition score matrix for ASG acoustic models.
|
||||
:type transitions: list[float]
|
||||
:param num_results: Number of beams to return.
|
||||
:type num_results: int
|
||||
:return: List of FlashlightOutput structures.
|
||||
:rtype: list[FlashlightOutput]
|
||||
"""
|
||||
return swigwrapper.flashlight_beam_search_decoder(
|
||||
probs_seq,
|
||||
logits_seq,
|
||||
alphabet,
|
||||
beam_size,
|
||||
beam_threshold,
|
||||
@ -319,10 +386,17 @@ def flashlight_beam_search_decoder_batch(
|
||||
cutoff_top_n=40,
|
||||
silence_score=0.0,
|
||||
merge_with_log_add=False,
|
||||
criterion_type=swigwrapper.FlashlightDecoderState.CTC,
|
||||
criterion_type=FlashlightDecoderState.CriterionType.CTC,
|
||||
transitions=[],
|
||||
num_results=1,
|
||||
):
|
||||
"""Decode batch acoustic model emissions in parallel. ``num_processes``
|
||||
controls how many samples from the batch will be decoded simultaneously.
|
||||
All the other parameters are forwarded to :py:func:`flashlight_beam_search_decoder`.
|
||||
|
||||
Returns a list of lists of FlashlightOutput structures.
|
||||
"""
|
||||
|
||||
return swigwrapper.flashlight_beam_search_decoder_batch(
|
||||
probs_seq,
|
||||
seq_lengths,
|
||||
@ -341,3 +415,58 @@ def flashlight_beam_search_decoder_batch(
|
||||
num_results,
|
||||
num_processes,
|
||||
)
|
||||
|
||||
|
||||
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
|
||||
"""Alphabet class representing 255 possible byte values for Bytes Output Mode.
|
||||
For internal use only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(UTF8Alphabet, self).__init__()
|
||||
err = self.init(b"")
|
||||
if err != 0:
|
||||
raise ValueError(
|
||||
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
|
||||
)
|
||||
|
||||
def CanEncodeSingle(self, input):
|
||||
"""
|
||||
Returns true if the single character/output class has a corresponding label
|
||||
in the alphabet.
|
||||
"""
|
||||
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
|
||||
|
||||
def CanEncode(self, input):
|
||||
"""
|
||||
Returns true if the entire string can be encoded into labels in this
|
||||
alphabet.
|
||||
"""
|
||||
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
|
||||
|
||||
def EncodeSingle(self, input):
|
||||
"""
|
||||
Encode a single character/output class into a label. Character must be in
|
||||
the alphabet, this method will assert that. Use ``CanEncodeSingle`` to test.
|
||||
"""
|
||||
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
|
||||
|
||||
def Encode(self, input):
|
||||
"""
|
||||
Encode a sequence of character/output classes into a sequence of labels.
|
||||
Characters are assumed to always take a single Unicode codepoint.
|
||||
Characters must be in the alphabet, this method will assert that. Use
|
||||
``CanEncode`` and ``CanEncodeSingle`` to test.
|
||||
"""
|
||||
# Convert SWIG's UnsignedIntVec to a Python list
|
||||
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
|
||||
return [el for el in res]
|
||||
|
||||
def DecodeSingle(self, input):
|
||||
res = super(UTF8Alphabet, self).DecodeSingle(input)
|
||||
return res.decode("utf-8")
|
||||
|
||||
def Decode(self, input):
|
||||
"""Decode a sequence of labels into a string."""
|
||||
res = super(UTF8Alphabet, self).Decode(input)
|
||||
return res.decode("utf-8")
|
||||
|
@ -96,7 +96,9 @@ class BuildExtFirst(build):
|
||||
setup(
|
||||
name="coqui_stt_ctcdecoder",
|
||||
version=project_version,
|
||||
description="""DS CTC decoder""",
|
||||
description="Coqui STT Python decoder package.",
|
||||
long_description="Documentation available at `stt.readthedocs.io <https://stt.readthedocs.io/en/latest/Decoder-API.html>`_",
|
||||
long_description_content_type="text/x-rst; charset=UTF-8",
|
||||
cmdclass={"build": BuildExtFirst},
|
||||
ext_modules=[decoder_module],
|
||||
package_dir={"coqui_stt_ctcdecoder": "."},
|
||||
|
@ -142,8 +142,8 @@ def evaluate(test_csvs, create_model):
|
||||
batch_lengths,
|
||||
Config.alphabet,
|
||||
beam_size=Config.export_beam_width,
|
||||
decoder_type=FlashlightDecoderState.LexiconBased,
|
||||
token_type=FlashlightDecoderState.Aggregate,
|
||||
decoder_type=FlashlightDecoderState.DecoderType.LexiconBased,
|
||||
token_type=FlashlightDecoderState.TokenType.Aggregate,
|
||||
lm_tokens=vocab,
|
||||
num_processes=num_processes,
|
||||
scorer=scorer,
|
||||
|
Loading…
Reference in New Issue
Block a user