STT/native_client/ctcdecode/__init__.py
2021-10-31 20:12:37 +01:00

473 lines
18 KiB
Python

import enum
from . import swigwrapper # pylint: disable=import-self
# The SWIG extension is compiled with SWIG_PYTHON_STRICT_BYTE_CHAR, so every
# string crossing the binding boundary is ``bytes``; encoding and decoding is
# done explicitly here and in the rest of this file.
__version__ = swigwrapper.__version__.decode("utf-8")

# Hack: SWIG cannot yet bind enums to Python in a scoped manner, so the error
# codes are re-exported at module level by matching on their name prefix.
for symbol in dir(swigwrapper):
    if not symbol.startswith("STT_ERR_"):
        continue
    globals()[symbol] = getattr(swigwrapper, symbol)
class Alphabet(swigwrapper.Alphabet):
    """Bidirectional map between tokens (eg. characters) and the internal
    integer representations used by the underlying acoustic models and
    external scorers. Build it from an alphabet configuration file through
    the constructor, or from a list of tokens through
    :py:meth:`Alphabet.InitFromLabels`.

    All methods taking or returning text convert between Python ``str`` and
    the UTF-8 ``bytes`` expected by the SWIG layer.
    """

    def __init__(self, config_path=None):
        """Create the alphabet, loading ``config_path`` when one is given.

        :raises ValueError: if the native initialization reports a non-zero
                            error code.
        """
        super().__init__()
        # Bare construction (no config file) leaves the alphabet
        # uninitialized so it can be filled in later.
        if not config_path:
            return
        status = self.init(config_path.encode("utf-8"))
        if status != 0:
            raise ValueError(
                "Alphabet initialization failed with error code 0x{:X}".format(status)
            )

    def InitFromLabels(self, data):
        """
        Initialize Alphabet from a list of labels ``data``. The integer value
        associated with each label is its position in the list.
        """
        encoded_labels = [label.encode("utf-8") for label in data]
        return super().InitFromLabels(encoded_labels)

    def CanEncodeSingle(self, input):
        """
        Returns true if the single character/output class has a corresponding
        label in the alphabet.
        """
        raw = input.encode("utf-8")
        return super().CanEncodeSingle(raw)

    def CanEncode(self, input):
        """
        Returns true if the entire string can be encoded into labels in this
        alphabet.
        """
        raw = input.encode("utf-8")
        return super().CanEncode(raw)

    def EncodeSingle(self, input):
        """
        Encode a single character/output class into a label. Character must
        be in the alphabet, this method will assert that. Use
        ``CanEncodeSingle`` to test.
        """
        raw = input.encode("utf-8")
        return super().EncodeSingle(raw)

    def Encode(self, input):
        """
        Encode a sequence of character/output classes into a sequence of
        labels. Characters are assumed to always take a single Unicode
        codepoint. Characters must be in the alphabet, this method will
        assert that. Use ``CanEncode`` and ``CanEncodeSingle`` to test.
        """
        # The native call returns SWIG's UnsignedIntVec; hand back a plain
        # Python list instead.
        return list(super().Encode(input.encode("utf-8")))

    def DecodeSingle(self, input):
        """Decode a single label into its string form."""
        raw = super().DecodeSingle(input)
        return raw.decode("utf-8")

    def Decode(self, input):
        """Decode a sequence of labels into a string."""
        raw = super().Decode(input)
        return raw.decode("utf-8")
class Scorer(swigwrapper.Scorer):
    """An external scorer bundles a language model built from text data, the
    vocabulary used to construct that language model, and the parameters
    controlling how decoding uses it: the language model weight ``alpha`` and
    the word insertion score ``beta``.

    :param alpha: Language model weight.
    :type alpha: float
    :param beta: Word insertion score.
    :type beta: float
    :param scorer_path: Path to load scorer from.
    :type scorer_path: str
    :param alphabet: Alphabet object matching the tokens used when creating
                     the external scorer.
    :type alphabet: Alphabet
    :raises ValueError: if loading the scorer reports a non-zero error code.
    """

    def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
        super().__init__()

        # Bare initialization: without an alphabet nothing is loaded and the
        # remaining arguments are ignored.
        if not alphabet:
            return

        # NOTE(review): these checks are ``assert`` statements, so they are
        # skipped when Python runs with -O.
        assert alpha is not None, "alpha parameter is required"
        assert beta is not None, "beta parameter is required"
        assert scorer_path, "scorer_path parameter is required"

        status = self.init(scorer_path.encode("utf-8"), alphabet)
        if status != 0:
            raise ValueError(
                "Scorer initialization failed with error code 0x{:X}".format(status)
            )

        self.reset_params(alpha, beta)
def ctc_beam_search_decoder(
    probs_seq,
    alphabet,
    beam_size,
    cutoff_prob=1.0,
    cutoff_top_n=40,
    scorer=None,
    hot_words=None,
    num_results=1,
):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over alphabet and blank.
    :type probs_seq: 2-D list
    :param alphabet: Alphabet used by the decoder and to turn the resulting
                     tokens back into text.
    :type alphabet: Alphabet
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in alphabet will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param scorer: External scorer for partially decoded sentence, e.g. word
                   count or language model.
    :type scorer: Scorer
    :param hot_words: Map of words (keys) to their assigned boosts (values).
                      ``None`` (the default) means no hot words.
    :type hot_words: dict[string, float]
    :param num_results: Number of beams to return.
    :type num_results: int
    :return: List of tuples of confidence and sentence as decoding
             results, in descending order of the confidence.
    :rtype: list
    """
    # A fresh dict per call avoids sharing one mutable default object across
    # every invocation of this function.
    if hot_words is None:
        hot_words = {}

    beam_results = swigwrapper.ctc_beam_search_decoder(
        probs_seq,
        alphabet,
        beam_size,
        cutoff_prob,
        cutoff_top_n,
        scorer,
        hot_words,
        num_results,
    )
    # The native results carry label sequences; decode them to strings.
    return [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
def ctc_beam_search_decoder_batch(
    probs_seq,
    seq_lengths,
    alphabet,
    beam_size,
    num_processes,
    cutoff_prob=1.0,
    cutoff_top_n=40,
    scorer=None,
    hot_words=None,
    num_results=1,
):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_seq: 3-D list with each element as an instance of 2-D list
                      of probabilities used by ctc_beam_search_decoder().
    :type probs_seq: 3-D list
    :param seq_lengths: Sequence lengths forwarded to the native decoder,
                        presumably one entry per batch element giving its
                        number of valid time steps (TODO confirm).
    :param alphabet: Alphabet used by the decoder and to turn the resulting
                     tokens back into text.
    :type alphabet: Alphabet
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in alphabet pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in alphabet will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param scorer: External scorer for partially decoded sentence, e.g. word
                   count or language model.
    :type scorer: Scorer
    :param hot_words: Map of words (keys) to their assigned boosts (values).
                      ``None`` (the default) means no hot words.
    :type hot_words: dict[string, float]
    :param num_results: Number of beams to return.
    :type num_results: int
    :return: One list per batch element, each holding tuples of confidence
             and sentence as decoding results, in descending order of the
             confidence.
    :rtype: list
    """
    # A fresh dict per call avoids sharing one mutable default object across
    # every invocation of this function.
    if hot_words is None:
        hot_words = {}

    batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(
        probs_seq,
        seq_lengths,
        alphabet,
        beam_size,
        num_processes,
        cutoff_prob,
        cutoff_top_n,
        scorer,
        hot_words,
        num_results,
    )
    # The native results carry label sequences; decode them to strings while
    # keeping the per-sample grouping.
    return [
        [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
        for beam_results in batch_beam_results
    ]
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
    """
    This class contains constants used to specify the desired behavior for the
    :py:func:`flashlight_beam_search_decoder` and :py:func:`flashlight_beam_search_decoder_batch`
    functions.

    The constants are grouped into three nested :py:class:`enum.IntEnum`
    classes whose values are taken from the corresponding attributes of the
    SWIG-generated ``FlashlightDecoderState`` class.
    """

    class CriterionType(enum.IntEnum):
        """Constants used to specify which loss criterion was used by the
        acoustic model. This class is a Python :py:class:`enum.IntEnum`.
        """

        #: Decoder mode for handling acoustic models trained with CTC loss
        CTC = swigwrapper.FlashlightDecoderState.CTC
        #: Decoder mode for handling acoustic models trained with ASG loss
        ASG = swigwrapper.FlashlightDecoderState.ASG
        #: Decoder mode for handling acoustic models trained with Seq2seq loss
        #: Note: this criterion type is currently not supported.
        S2S = swigwrapper.FlashlightDecoderState.S2S

    class DecoderType(enum.IntEnum):
        """Constants used to specify if decoder should operate in lexicon mode,
        only predicting words present in a fixed vocabulary, or in lexicon-free
        mode, without such restriction. This class is a Python :py:class:`enum.IntEnum`.
        """

        #: Lexicon mode, only predict words in specified vocabulary.
        LexiconBased = swigwrapper.FlashlightDecoderState.LexiconBased
        #: Lexicon-free mode, allow prediction of any word.
        LexiconFree = swigwrapper.FlashlightDecoderState.LexiconFree

    class TokenType(enum.IntEnum):
        """Constants used to specify the granularity of text units used when training
        the external scorer in relation to the text units used when training the
        acoustic model. For example, you can have an acoustic model predicting
        characters and an external scorer trained on words, or an acoustic model
        and an external scorer both trained with sub-word units. If the acoustic
        model and the scorer were both trained on the same text unit granularity,
        use ``TokenType.Single``. Otherwise, if the external scorer was trained
        on a sequence of acoustic model text units, use ``TokenType.Aggregate``.
        This class is a Python :py:class:`enum.IntEnum`.
        """

        #: Token type for external scorers trained on the same textual units as
        #: the acoustic model.
        Single = swigwrapper.FlashlightDecoderState.Single
        #: Token type for external scorers trained on a sequence of acoustic model
        #: textual units.
        Aggregate = swigwrapper.FlashlightDecoderState.Aggregate
def flashlight_beam_search_decoder(
    logits_seq,
    alphabet,
    beam_size,
    decoder_type,
    token_type,
    lm_tokens,
    scorer=None,
    beam_threshold=25.0,
    cutoff_top_n=40,
    silence_score=0.0,
    merge_with_log_add=False,
    criterion_type=FlashlightDecoderState.CriterionType.CTC,
    transitions=None,
    num_results=1,
):
    """Decode acoustic model emissions for a single sample. Note that unlike
    :py:func:`ctc_beam_search_decoder`, this function expects raw outputs
    from CTC and ASG acoustic models, without softmaxing them over
    timesteps.

    :param logits_seq: 2-D list of acoustic model emissions, dimensions are
                       time steps x number of output units.
    :type logits_seq: 2-D list of floats or numpy array
    :param alphabet: Alphabet object matching the tokens used when creating the
                     acoustic model and external scorer if specified.
    :type alphabet: Alphabet
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param decoder_type: Decoding mode, lexicon-constrained or lexicon-free.
    :type decoder_type: FlashlightDecoderState.DecoderType
    :param token_type: Type of token in the external scorer.
    :type token_type: FlashlightDecoderState.TokenType
    :param lm_tokens: List of tokens to constrain decoding to when in lexicon-constrained
                      mode. Must match the token type used in the scorer, ie.
                      must be a list of characters if scorer is character-based,
                      or a list of words if scorer is word-based.
    :type lm_tokens: list[str]
    :param scorer: External scorer.
    :type scorer: Scorer
    :param beam_threshold: Maximum threshold in beam score from leading beam. Any
                           newly created candidate beams which lag behind the best
                           beam so far by more than this value will get pruned.
                           This is a performance optimization parameter and an
                           appropriate value should be found empirically using a
                           validation set.
    :type beam_threshold: float
    :param cutoff_top_n: Maximum number of tokens to expand per time step during
                         decoding. Only the highest probability cutoff_top_n
                         candidates (characters, sub-word units, words) in a given
                         timestep will be expanded. This is a performance
                         optimization parameter and an appropriate value should
                         be found empirically using a validation set.
    :type cutoff_top_n: int
    :param silence_score: Score to add to beam when encountering a predicted
                          silence token (eg. the space symbol).
    :type silence_score: float
    :param merge_with_log_add: Whether to use log-add when merging scores of
                               new candidate beams equivalent to existing ones
                               (leading to the same transcription). When disabled,
                               the maximum score is used.
    :type merge_with_log_add: bool
    :param criterion_type: Criterion used for training the acoustic model.
    :type criterion_type: FlashlightDecoderState.CriterionType
    :param transitions: Transition score matrix for ASG acoustic models.
                        ``None`` (the default) is passed on as an empty list.
    :type transitions: list[float]
    :param num_results: Number of beams to return.
    :type num_results: int
    :return: List of FlashlightOutput structures.
    :rtype: list[FlashlightOutput]
    """
    # A fresh list per call avoids sharing one mutable default object across
    # every invocation of this function.
    if transitions is None:
        transitions = []

    # The native function takes its arguments in a different order than this
    # wrapper's signature, so they are passed positionally in native order.
    return swigwrapper.flashlight_beam_search_decoder(
        logits_seq,
        alphabet,
        beam_size,
        beam_threshold,
        cutoff_top_n,
        scorer,
        token_type,
        lm_tokens,
        decoder_type,
        silence_score,
        merge_with_log_add,
        criterion_type,
        transitions,
        num_results,
    )
def flashlight_beam_search_decoder_batch(
    probs_seq,
    seq_lengths,
    alphabet,
    beam_size,
    decoder_type,
    token_type,
    lm_tokens,
    num_processes,
    scorer=None,
    beam_threshold=25.0,
    cutoff_top_n=40,
    silence_score=0.0,
    merge_with_log_add=False,
    criterion_type=FlashlightDecoderState.CriterionType.CTC,
    transitions=None,
    num_results=1,
):
    """Decode batch acoustic model emissions in parallel. ``num_processes``
    controls how many samples from the batch will be decoded simultaneously.
    All the other parameters are forwarded to :py:func:`flashlight_beam_search_decoder`.

    :param probs_seq: Batch of acoustic model emissions, one entry per sample.
    :param seq_lengths: Sequence lengths forwarded to the native decoder,
                        presumably one entry per sample in the batch
                        (TODO confirm).
    :param num_processes: Number of samples decoded simultaneously.
    :type num_processes: int
    :param transitions: Transition score matrix for ASG acoustic models.
                        ``None`` (the default) is passed on as an empty list.
    :type transitions: list[float]
    :return: A list of lists of FlashlightOutput structures.
    """
    # A fresh list per call avoids sharing one mutable default object across
    # every invocation of this function.
    if transitions is None:
        transitions = []

    # The native function takes its arguments in a different order than this
    # wrapper's signature (``num_processes`` goes last), so they are passed
    # positionally in native order.
    return swigwrapper.flashlight_beam_search_decoder_batch(
        probs_seq,
        seq_lengths,
        alphabet,
        beam_size,
        beam_threshold,
        cutoff_top_n,
        scorer,
        token_type,
        lm_tokens,
        decoder_type,
        silence_score,
        merge_with_log_add,
        criterion_type,
        transitions,
        num_results,
        num_processes,
    )
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
    """Alphabet class representing 255 possible byte values for Bytes Output Mode.
    For internal use only.

    All methods taking or returning text convert between Python ``str`` and
    the UTF-8 ``bytes`` expected by the SWIG layer.
    """

    def __init__(self):
        """Create and initialize the alphabet.

        :raises ValueError: if the native initialization reports a non-zero
                            error code.
        """
        super().__init__()
        # No configuration file is involved; the native initializer is
        # called with an empty path.
        status = self.init(b"")
        if status != 0:
            raise ValueError(
                "UTF8Alphabet initialization failed with error code 0x{:X}".format(status)
            )

    def CanEncodeSingle(self, input):
        """
        Returns true if the single character/output class has a corresponding
        label in the alphabet.
        """
        raw = input.encode("utf-8")
        return super().CanEncodeSingle(raw)

    def CanEncode(self, input):
        """
        Returns true if the entire string can be encoded into labels in this
        alphabet.
        """
        raw = input.encode("utf-8")
        return super().CanEncode(raw)

    def EncodeSingle(self, input):
        """
        Encode a single character/output class into a label. Character must
        be in the alphabet, this method will assert that. Use
        ``CanEncodeSingle`` to test.
        """
        raw = input.encode("utf-8")
        return super().EncodeSingle(raw)

    def Encode(self, input):
        """
        Encode a sequence of character/output classes into a sequence of
        labels. Characters are assumed to always take a single Unicode
        codepoint. Characters must be in the alphabet, this method will
        assert that. Use ``CanEncode`` and ``CanEncodeSingle`` to test.
        """
        # The native call returns SWIG's UnsignedIntVec; hand back a plain
        # Python list instead.
        return list(super().Encode(input.encode("utf-8")))

    def DecodeSingle(self, input):
        """Decode a single label into its string form."""
        raw = super().DecodeSingle(input)
        # NOTE(review): a lone byte above 0x7F is not valid UTF-8 on its own,
        # so this may raise UnicodeDecodeError for such labels - confirm.
        return raw.decode("utf-8")

    def Decode(self, input):
        """Decode a sequence of labels into a string."""
        raw = super().Decode(input)
        return raw.decode("utf-8")