Improve decoder package docs and include in RTD
This commit is contained in:
parent
91f1307de4
commit
90feb63894
@ -1,14 +1,14 @@
|
|||||||
.. _decoder-docs:
|
.. _decoder-docs:
|
||||||
|
|
||||||
CTC beam search decoder
|
Beam search decoder
|
||||||
=======================
|
===================
|
||||||
|
|
||||||
Introduction
|
Introduction
|
||||||
------------
|
------------
|
||||||
|
|
||||||
🐸STT uses the `Connectionist Temporal Classification <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC <https://distill.pub/2017/ctc/>`_. This document assumes the reader is familiar with the concepts described in that article, and describes 🐸STT specific behaviors that developers building systems with 🐸STT should know to avoid problems.
|
🐸STT uses the `Connectionist Temporal Classification <http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC <https://distill.pub/2017/ctc/>`_. This document assumes the reader is familiar with the concepts described in that article, and describes 🐸STT specific behaviors that developers building systems with 🐸STT should know to avoid problems.
|
||||||
|
|
||||||
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`.
|
Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`. Documentation for the coqui_stt_ctcdecoder Python package used by the training code for decoding is available in :ref:`decoder-api`.
|
||||||
|
|
||||||
The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 <https://tools.ietf.org/html/bcp14>`_ when, and only when, they appear in all capitals, as shown here.
|
The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 <https://tools.ietf.org/html/bcp14>`_ when, and only when, they appear in all capitals, as shown here.
|
||||||
|
|
||||||
|
7
doc/Decoder-API.rst
Normal file
7
doc/Decoder-API.rst
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
.. _decoder-api:
|
||||||
|
|
||||||
|
Decoder API reference
|
||||||
|
=====================
|
||||||
|
|
||||||
|
.. automodule:: native_client.ctcdecode
|
||||||
|
:members:
|
@ -24,7 +24,8 @@ import sys
|
|||||||
|
|
||||||
sys.path.insert(0, os.path.abspath("../"))
|
sys.path.insert(0, os.path.abspath("../"))
|
||||||
|
|
||||||
autodoc_mock_imports = ["stt"]
|
autodoc_mock_imports = ["stt", "native_client.ctcdecode.swigwrapper"]
|
||||||
|
autodoc_member_order = "bysource"
|
||||||
|
|
||||||
# This is in fact only relevant on ReadTheDocs, but we want to run the same way
|
# This is in fact only relevant on ReadTheDocs, but we want to run the same way
|
||||||
# on our CI as in RTD to avoid regressions on RTD that we would not catch on CI
|
# on our CI as in RTD to avoid regressions on RTD that we would not catch on CI
|
||||||
@ -128,7 +129,6 @@ todo_include_todos = False
|
|||||||
|
|
||||||
add_module_names = False
|
add_module_names = False
|
||||||
|
|
||||||
|
|
||||||
# -- Options for HTML output ----------------------------------------------
|
# -- Options for HTML output ----------------------------------------------
|
||||||
|
|
||||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||||
|
@ -89,6 +89,17 @@ The fastest way to use a pre-trained 🐸STT model is with the 🐸STT model man
|
|||||||
|
|
||||||
playbook/README
|
playbook/README
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:maxdepth: 1
|
||||||
|
:caption: Advanced topics
|
||||||
|
|
||||||
|
DECODER
|
||||||
|
|
||||||
|
Decoder-API
|
||||||
|
|
||||||
|
PARALLEL_OPTIMIZATION
|
||||||
|
|
||||||
|
|
||||||
Indices and tables
|
Indices and tables
|
||||||
==================
|
==================
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
import enum
|
||||||
|
|
||||||
from . import swigwrapper # pylint: disable=import-self
|
from . import swigwrapper # pylint: disable=import-self
|
||||||
|
|
||||||
@ -13,41 +13,12 @@ for symbol in dir(swigwrapper):
|
|||||||
globals()[symbol] = getattr(swigwrapper, symbol)
|
globals()[symbol] = getattr(swigwrapper, symbol)
|
||||||
|
|
||||||
|
|
||||||
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Scorer(swigwrapper.Scorer):
|
|
||||||
"""Wrapper for Scorer.
|
|
||||||
|
|
||||||
:param alpha: Language model weight.
|
|
||||||
:type alpha: float
|
|
||||||
:param beta: Word insertion bonus.
|
|
||||||
:type beta: float
|
|
||||||
:scorer_path: Path to load scorer from.
|
|
||||||
:alphabet: Alphabet
|
|
||||||
:type scorer_path: basestring
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
|
|
||||||
super(Scorer, self).__init__()
|
|
||||||
# Allow bare initialization
|
|
||||||
if alphabet:
|
|
||||||
assert alpha is not None, "alpha parameter is required"
|
|
||||||
assert beta is not None, "beta parameter is required"
|
|
||||||
assert scorer_path, "scorer_path parameter is required"
|
|
||||||
|
|
||||||
err = self.init(scorer_path.encode("utf-8"), alphabet)
|
|
||||||
if err != 0:
|
|
||||||
raise ValueError(
|
|
||||||
"Scorer initialization failed with error code 0x{:X}".format(err)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.reset_params(alpha, beta)
|
|
||||||
|
|
||||||
|
|
||||||
class Alphabet(swigwrapper.Alphabet):
|
class Alphabet(swigwrapper.Alphabet):
|
||||||
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
"""An Alphabet is a bidirectional map from tokens (eg. characters) to
|
||||||
|
internal integer representations used by the underlying acoustic models
|
||||||
|
and external scorers. It can be created from alphabet configuration file
|
||||||
|
via the constructor, or from a list of tokens via :py:meth:`Alphabet.InitFromLabels`.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config_path=None):
|
def __init__(self, config_path=None):
|
||||||
super(Alphabet, self).__init__()
|
super(Alphabet, self).__init__()
|
||||||
@ -59,6 +30,10 @@ class Alphabet(swigwrapper.Alphabet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def InitFromLabels(self, data):
|
def InitFromLabels(self, data):
|
||||||
|
"""
|
||||||
|
Initialize Alphabet from a list of labels ``data``. Each label gets
|
||||||
|
associated with an integer value corresponding to its position in the list.
|
||||||
|
"""
|
||||||
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
|
return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data])
|
||||||
|
|
||||||
def CanEncodeSingle(self, input):
|
def CanEncodeSingle(self, input):
|
||||||
@ -87,7 +62,7 @@ class Alphabet(swigwrapper.Alphabet):
|
|||||||
Encode a sequence of character/output classes into a sequence of labels.
|
Encode a sequence of character/output classes into a sequence of labels.
|
||||||
Characters are assumed to always take a single Unicode codepoint.
|
Characters are assumed to always take a single Unicode codepoint.
|
||||||
Characters must be in the alphabet, this method will assert that. Use
|
Characters must be in the alphabet, this method will assert that. Use
|
||||||
`CanEncode` and `CanEncodeSingle` to test.
|
``CanEncode`` and ``CanEncodeSingle`` to test.
|
||||||
"""
|
"""
|
||||||
# Convert SWIG's UnsignedIntVec to a Python list
|
# Convert SWIG's UnsignedIntVec to a Python list
|
||||||
res = super(Alphabet, self).Encode(input.encode("utf-8"))
|
res = super(Alphabet, self).Encode(input.encode("utf-8"))
|
||||||
@ -103,57 +78,39 @@ class Alphabet(swigwrapper.Alphabet):
|
|||||||
return res.decode("utf-8")
|
return res.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
|
class Scorer(swigwrapper.Scorer):
|
||||||
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
"""An external scorer is a data structure composed of a language model built
|
||||||
|
from text data, as well as the vocabulary used in the construction of this
|
||||||
|
language model and additional parameters related to how the decoding
|
||||||
|
process uses the external scorer, such as the language model weight
|
||||||
|
``alpha`` and the word insertion score ``beta``.
|
||||||
|
|
||||||
def __init__(self):
|
:param alpha: Language model weight.
|
||||||
super(UTF8Alphabet, self).__init__()
|
:type alpha: float
|
||||||
err = self.init(b"")
|
:param beta: Word insertion score.
|
||||||
if err != 0:
|
:type beta: float
|
||||||
raise ValueError(
|
:param scorer_path: Path to load scorer from.
|
||||||
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
|
:type scorer_path: str
|
||||||
)
|
:param alphabet: Alphabet object matching the tokens used when creating the
|
||||||
|
external scorer.
|
||||||
|
:type alphabet: Alphabet
|
||||||
|
"""
|
||||||
|
|
||||||
def CanEncodeSingle(self, input):
|
def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None):
|
||||||
"""
|
super(Scorer, self).__init__()
|
||||||
Returns true if the single character/output class has a corresponding label
|
# Allow bare initialization
|
||||||
in the alphabet.
|
if alphabet:
|
||||||
"""
|
assert alpha is not None, "alpha parameter is required"
|
||||||
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
|
assert beta is not None, "beta parameter is required"
|
||||||
|
assert scorer_path, "scorer_path parameter is required"
|
||||||
|
|
||||||
def CanEncode(self, input):
|
err = self.init(scorer_path.encode("utf-8"), alphabet)
|
||||||
"""
|
if err != 0:
|
||||||
Returns true if the entire string can be encoded into labels in this
|
raise ValueError(
|
||||||
alphabet.
|
"Scorer initialization failed with error code 0x{:X}".format(err)
|
||||||
"""
|
)
|
||||||
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
|
|
||||||
|
|
||||||
def EncodeSingle(self, input):
|
self.reset_params(alpha, beta)
|
||||||
"""
|
|
||||||
Encode a single character/output class into a label. Character must be in
|
|
||||||
the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
|
|
||||||
"""
|
|
||||||
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
|
|
||||||
|
|
||||||
def Encode(self, input):
|
|
||||||
"""
|
|
||||||
Encode a sequence of character/output classes into a sequence of labels.
|
|
||||||
Characters are assumed to always take a single Unicode codepoint.
|
|
||||||
Characters must be in the alphabet, this method will assert that. Use
|
|
||||||
`CanEncode` and `CanEncodeSingle` to test.
|
|
||||||
"""
|
|
||||||
# Convert SWIG's UnsignedIntVec to a Python list
|
|
||||||
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
|
|
||||||
return [el for el in res]
|
|
||||||
|
|
||||||
def DecodeSingle(self, input):
|
|
||||||
res = super(UTF8Alphabet, self).DecodeSingle(input)
|
|
||||||
return res.decode("utf-8")
|
|
||||||
|
|
||||||
def Decode(self, input):
|
|
||||||
"""Decode a sequence of labels into a string."""
|
|
||||||
res = super(UTF8Alphabet, self).Decode(input)
|
|
||||||
return res.decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def ctc_beam_search_decoder(
|
def ctc_beam_search_decoder(
|
||||||
@ -186,7 +143,7 @@ def ctc_beam_search_decoder(
|
|||||||
count or language model.
|
count or language model.
|
||||||
:type scorer: Scorer
|
:type scorer: Scorer
|
||||||
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
||||||
:type hot_words: map{string:float}
|
:type hot_words: dict[string, float]
|
||||||
:param num_results: Number of beams to return.
|
:param num_results: Number of beams to return.
|
||||||
:type num_results: int
|
:type num_results: int
|
||||||
:return: List of tuples of confidence and sentence as decoding
|
:return: List of tuples of confidence and sentence as decoding
|
||||||
@ -245,7 +202,7 @@ def ctc_beam_search_decoder_batch(
|
|||||||
count or language model.
|
count or language model.
|
||||||
:type scorer: Scorer
|
:type scorer: Scorer
|
||||||
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
:param hot_words: Map of words (keys) to their assigned boosts (values)
|
||||||
:type hot_words: map{string:float}
|
:type hot_words: dict[string, float]
|
||||||
:param num_results: Number of beams to return.
|
:param num_results: Number of beams to return.
|
||||||
:type num_results: int
|
:type num_results: int
|
||||||
:return: List of tuples of confidence and sentence as decoding
|
:return: List of tuples of confidence and sentence as decoding
|
||||||
@ -271,8 +228,63 @@ def ctc_beam_search_decoder_batch(
|
|||||||
return batch_beam_results
|
return batch_beam_results
|
||||||
|
|
||||||
|
|
||||||
|
class FlashlightDecoderState(swigwrapper.FlashlightDecoderState):
|
||||||
|
"""
|
||||||
|
This class contains constants used to specify the desired behavior for the
|
||||||
|
:py:func:`flashlight_beam_search_decoder` and :py:func:`flashlight_beam_search_decoder_batch`
|
||||||
|
functions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class CriterionType(enum.IntEnum):
|
||||||
|
"""Constants used to specify which loss criterion was used by the
|
||||||
|
acoustic model. This class is a Python :py:class:`enum.IntEnum`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
#: Decoder mode for handling acoustic models trained with CTC loss
|
||||||
|
CTC = swigwrapper.FlashlightDecoderState.CTC
|
||||||
|
|
||||||
|
#: Decoder mode for handling acoustic models trained with ASG loss
|
||||||
|
ASG = swigwrapper.FlashlightDecoderState.ASG
|
||||||
|
|
||||||
|
#: Decoder mode for handling acoustic models trained with Seq2seq loss
|
||||||
|
#: Note: this criterion type is currently not supported.
|
||||||
|
S2S = swigwrapper.FlashlightDecoderState.S2S
|
||||||
|
|
||||||
|
class DecoderType(enum.IntEnum):
|
||||||
|
"""Constants used to specify if decoder should operate in lexicon mode,
|
||||||
|
only predicting words present in a fixed vocabulary, or in lexicon-free
|
||||||
|
mode, without such restriction. This class is a Python :py:class:`enum.IntEnum`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
#: Lexicon mode, only predict words in specified vocabulary.
|
||||||
|
LexiconBased = swigwrapper.FlashlightDecoderState.LexiconBased
|
||||||
|
|
||||||
|
#: Lexicon-free mode, allow prediction of any word.
|
||||||
|
LexiconFree = swigwrapper.FlashlightDecoderState.LexiconFree
|
||||||
|
|
||||||
|
class TokenType(enum.IntEnum):
|
||||||
|
"""Constants used to specify the granularity of text units used when training
|
||||||
|
the external scorer in relation to the text units used when training the
|
||||||
|
acoustic model. For example, you can have an acoustic model predicting
|
||||||
|
characters and an external scorer trained on words, or an acoustic model
|
||||||
|
and an external scorer both trained with sub-word units. If the acoustic
|
||||||
|
model and the scorer were both trained on the same text unit granularity,
|
||||||
|
use ``TokenType.Single``. Otherwise, if the external scorer was trained
|
||||||
|
on a sequence of acoustic model text units, use ``TokenType.Aggregate``.
|
||||||
|
This class is a Python :py:class:`enum.IntEnum`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
#: Token type for external scorers trained on the same textual units as
|
||||||
|
#: the acoustic model.
|
||||||
|
Single = swigwrapper.FlashlightDecoderState.Single
|
||||||
|
|
||||||
|
#: Token type for external scorers trained on a sequence of acoustic model
|
||||||
|
#: textual units.
|
||||||
|
Aggregate = swigwrapper.FlashlightDecoderState.Aggregate
|
||||||
|
|
||||||
|
|
||||||
def flashlight_beam_search_decoder(
|
def flashlight_beam_search_decoder(
|
||||||
probs_seq,
|
logits_seq,
|
||||||
alphabet,
|
alphabet,
|
||||||
beam_size,
|
beam_size,
|
||||||
decoder_type,
|
decoder_type,
|
||||||
@ -283,12 +295,67 @@ def flashlight_beam_search_decoder(
|
|||||||
cutoff_top_n=40,
|
cutoff_top_n=40,
|
||||||
silence_score=0.0,
|
silence_score=0.0,
|
||||||
merge_with_log_add=False,
|
merge_with_log_add=False,
|
||||||
criterion_type=swigwrapper.FlashlightDecoderState.CTC,
|
criterion_type=FlashlightDecoderState.CriterionType.CTC,
|
||||||
transitions=[],
|
transitions=[],
|
||||||
num_results=1,
|
num_results=1,
|
||||||
):
|
):
|
||||||
|
"""Decode acoustic model emissions for a single sample. Note that unlike
|
||||||
|
:py:func:`ctc_beam_search_decoder`, this function expects raw outputs
|
||||||
|
from CTC and ASG acoustic models, without softmaxing them over
|
||||||
|
timesteps.
|
||||||
|
|
||||||
|
:param logits_seq: 2-D list of acoustic model emissions, dimensions are
|
||||||
|
time steps x number of output units.
|
||||||
|
:type logits_seq: 2-D list of floats or numpy array
|
||||||
|
:param alphabet: Alphabet object matching the tokens used when creating the
|
||||||
|
acoustic model and external scorer if specified.
|
||||||
|
:type alphabet: Alphabet
|
||||||
|
:param beam_size: Width for beam search.
|
||||||
|
:type beam_size: int
|
||||||
|
:param decoder_type: Decoding mode, lexicon-constrained or lexicon-free.
|
||||||
|
:type decoder_type: FlashlightDecoderState.DecoderType
|
||||||
|
:param token_type: Type of token in the external scorer.
|
||||||
|
:type token_type: FlashlightDecoderState.TokenType
|
||||||
|
:param lm_tokens: List of tokens to constrain decoding to when in lexicon-constrained
|
||||||
|
mode. Must match the token type used in the scorer, ie.
|
||||||
|
must be a list of characters if scorer is character-based,
|
||||||
|
or a list of words if scorer is word-based.
|
||||||
|
:param lm_tokens: list[str]
|
||||||
|
:param scorer: External scorer.
|
||||||
|
:type scorer: Scorer
|
||||||
|
:param beam_threshold: Maximum threshold in beam score from leading beam. Any
|
||||||
|
newly created candidate beams which lag behind the best
|
||||||
|
beam so far by more than this value will get pruned.
|
||||||
|
This is a performance optimization parameter and an
|
||||||
|
appropriate value should be found empirically using a
|
||||||
|
validation set.
|
||||||
|
:type beam_threshold: float
|
||||||
|
:param cutoff_top_n: Maximum number of tokens to expand per time step during
|
||||||
|
decoding. Only the highest probability cutoff_top_n
|
||||||
|
candidates (characters, sub-word units, words) in a given
|
||||||
|
timestep will be expanded. This is a performance
|
||||||
|
optimization parameter and an appropriate value should
|
||||||
|
be found empirically using a validation set.
|
||||||
|
:type cutoff_top_n: int
|
||||||
|
:param silence_score: Score to add to beam when encountering a predicted
|
||||||
|
silence token (eg. the space symbol).
|
||||||
|
:type silence_score: float
|
||||||
|
:param merge_with_log_add: Whether to use log-add when merging scores of
|
||||||
|
new candidate beams equivalent to existing ones
|
||||||
|
(leading to the same transcription). When disabled,
|
||||||
|
the maximum score is used.
|
||||||
|
:type merge_with_log_add: bool
|
||||||
|
:param criterion_type: Criterion used for training the acoustic model.
|
||||||
|
:type criterion_type: FlashlightDecoderState.CriterionType
|
||||||
|
:param transitions: Transition score matrix for ASG acoustic models.
|
||||||
|
:type transitions: list[float]
|
||||||
|
:param num_results: Number of beams to return.
|
||||||
|
:type num_results: int
|
||||||
|
:return: List of FlashlightOutput structures.
|
||||||
|
:rtype: list[FlashlightOutput]
|
||||||
|
"""
|
||||||
return swigwrapper.flashlight_beam_search_decoder(
|
return swigwrapper.flashlight_beam_search_decoder(
|
||||||
probs_seq,
|
logits_seq,
|
||||||
alphabet,
|
alphabet,
|
||||||
beam_size,
|
beam_size,
|
||||||
beam_threshold,
|
beam_threshold,
|
||||||
@ -319,10 +386,17 @@ def flashlight_beam_search_decoder_batch(
|
|||||||
cutoff_top_n=40,
|
cutoff_top_n=40,
|
||||||
silence_score=0.0,
|
silence_score=0.0,
|
||||||
merge_with_log_add=False,
|
merge_with_log_add=False,
|
||||||
criterion_type=swigwrapper.FlashlightDecoderState.CTC,
|
criterion_type=FlashlightDecoderState.CriterionType.CTC,
|
||||||
transitions=[],
|
transitions=[],
|
||||||
num_results=1,
|
num_results=1,
|
||||||
):
|
):
|
||||||
|
"""Decode batch acoustic model emissions in parallel. ``num_processes``
|
||||||
|
controls how many samples from the batch will be decoded simultaneously.
|
||||||
|
All the other parameters are forwarded to :py:func:`flashlight_beam_search_decoder`.
|
||||||
|
|
||||||
|
Returns a list of lists of FlashlightOutput structures.
|
||||||
|
"""
|
||||||
|
|
||||||
return swigwrapper.flashlight_beam_search_decoder_batch(
|
return swigwrapper.flashlight_beam_search_decoder_batch(
|
||||||
probs_seq,
|
probs_seq,
|
||||||
seq_lengths,
|
seq_lengths,
|
||||||
@ -341,3 +415,58 @@ def flashlight_beam_search_decoder_batch(
|
|||||||
num_results,
|
num_results,
|
||||||
num_processes,
|
num_processes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UTF8Alphabet(swigwrapper.UTF8Alphabet):
|
||||||
|
"""Alphabet class representing 255 possible byte values for Bytes Output Mode.
|
||||||
|
For internal use only.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(UTF8Alphabet, self).__init__()
|
||||||
|
err = self.init(b"")
|
||||||
|
if err != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"UTF8Alphabet initialization failed with error code 0x{:X}".format(err)
|
||||||
|
)
|
||||||
|
|
||||||
|
def CanEncodeSingle(self, input):
|
||||||
|
"""
|
||||||
|
Returns true if the single character/output class has a corresponding label
|
||||||
|
in the alphabet.
|
||||||
|
"""
|
||||||
|
return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8"))
|
||||||
|
|
||||||
|
def CanEncode(self, input):
|
||||||
|
"""
|
||||||
|
Returns true if the entire string can be encoded into labels in this
|
||||||
|
alphabet.
|
||||||
|
"""
|
||||||
|
return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8"))
|
||||||
|
|
||||||
|
def EncodeSingle(self, input):
|
||||||
|
"""
|
||||||
|
Encode a single character/output class into a label. Character must be in
|
||||||
|
the alphabet, this method will assert that. Use ``CanEncodeSingle`` to test.
|
||||||
|
"""
|
||||||
|
return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8"))
|
||||||
|
|
||||||
|
def Encode(self, input):
|
||||||
|
"""
|
||||||
|
Encode a sequence of character/output classes into a sequence of labels.
|
||||||
|
Characters are assumed to always take a single Unicode codepoint.
|
||||||
|
Characters must be in the alphabet, this method will assert that. Use
|
||||||
|
``CanEncode`` and ``CanEncodeSingle`` to test.
|
||||||
|
"""
|
||||||
|
# Convert SWIG's UnsignedIntVec to a Python list
|
||||||
|
res = super(UTF8Alphabet, self).Encode(input.encode("utf-8"))
|
||||||
|
return [el for el in res]
|
||||||
|
|
||||||
|
def DecodeSingle(self, input):
|
||||||
|
res = super(UTF8Alphabet, self).DecodeSingle(input)
|
||||||
|
return res.decode("utf-8")
|
||||||
|
|
||||||
|
def Decode(self, input):
|
||||||
|
"""Decode a sequence of labels into a string."""
|
||||||
|
res = super(UTF8Alphabet, self).Decode(input)
|
||||||
|
return res.decode("utf-8")
|
||||||
|
@ -96,7 +96,9 @@ class BuildExtFirst(build):
|
|||||||
setup(
|
setup(
|
||||||
name="coqui_stt_ctcdecoder",
|
name="coqui_stt_ctcdecoder",
|
||||||
version=project_version,
|
version=project_version,
|
||||||
description="""DS CTC decoder""",
|
description="Coqui STT Python decoder package.",
|
||||||
|
long_description="Documentation available at `stt.readthedocs.io <https://stt.readthedocs.io/en/latest/Decoder-API.html>`_",
|
||||||
|
long_description_content_type="text/x-rst; charset=UTF-8",
|
||||||
cmdclass={"build": BuildExtFirst},
|
cmdclass={"build": BuildExtFirst},
|
||||||
ext_modules=[decoder_module],
|
ext_modules=[decoder_module],
|
||||||
package_dir={"coqui_stt_ctcdecoder": "."},
|
package_dir={"coqui_stt_ctcdecoder": "."},
|
||||||
|
@ -142,8 +142,8 @@ def evaluate(test_csvs, create_model):
|
|||||||
batch_lengths,
|
batch_lengths,
|
||||||
Config.alphabet,
|
Config.alphabet,
|
||||||
beam_size=Config.export_beam_width,
|
beam_size=Config.export_beam_width,
|
||||||
decoder_type=FlashlightDecoderState.LexiconBased,
|
decoder_type=FlashlightDecoderState.DecoderType.LexiconBased,
|
||||||
token_type=FlashlightDecoderState.Aggregate,
|
token_type=FlashlightDecoderState.TokenType.Aggregate,
|
||||||
lm_tokens=vocab,
|
lm_tokens=vocab,
|
||||||
num_processes=num_processes,
|
num_processes=num_processes,
|
||||||
scorer=scorer,
|
scorer=scorer,
|
||||||
|
Loading…
Reference in New Issue
Block a user