diff --git a/doc/DECODER.rst b/doc/DECODER.rst index 5a59ec34..d7ea951a 100644 --- a/doc/DECODER.rst +++ b/doc/DECODER.rst @@ -1,14 +1,14 @@ .. _decoder-docs: -CTC beam search decoder -======================= +Beam search decoder +=================== Introduction ------------ 🐸STT uses the `Connectionist Temporal Classification `_ loss function. For an excellent explanation of CTC and its usage, see this Distill article: `Sequence Modeling with CTC `_. This document assumes the reader is familiar with the concepts described in that article, and describes 🐸STT specific behaviors that developers building systems with 🐸STT should know to avoid problems. -Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`. +Note: Documentation for the tooling for creating custom scorer packages is available in :ref:`language-model`. Documentation for the coqui_stt_ctcdecoder Python package used by the training code for decoding is available in :ref:`decoder-api`. The key words "MUST", "MUST NOT", "REQUIRED", "SHALL", "SHALL NOT", "SHOULD", "SHOULD NOT", "RECOMMENDED", "MAY", and "OPTIONAL" in this document are to be interpreted as described in `BCP 14 `_ when, and only when, they appear in all capitals, as shown here. diff --git a/doc/Decoder-API.rst b/doc/Decoder-API.rst new file mode 100644 index 00000000..158cc206 --- /dev/null +++ b/doc/Decoder-API.rst @@ -0,0 +1,7 @@ +.. _decoder-api: + +Decoder API reference +===================== + +.. 
automodule:: native_client.ctcdecode + :members: diff --git a/doc/conf.py b/doc/conf.py index 27fa54e2..ce7e99bf 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -24,7 +24,8 @@ import sys sys.path.insert(0, os.path.abspath("../")) -autodoc_mock_imports = ["stt"] +autodoc_mock_imports = ["stt", "native_client.ctcdecode.swigwrapper"] +autodoc_member_order = "bysource" # This is in fact only relevant on ReadTheDocs, but we want to run the same way # on our CI as in RTD to avoid regressions on RTD that we would not catch on CI @@ -128,7 +129,6 @@ todo_include_todos = False add_module_names = False - # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for diff --git a/doc/index.rst b/doc/index.rst index 3e763222..6400f39d 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -89,6 +89,17 @@ The fastest way to use a pre-trained 🐸STT model is with the 🐸STT model man playbook/README +.. toctree:: + :maxdepth: 1 + :caption: Advanced topics + + DECODER + + Decoder-API + + PARALLEL_OPTIMIZATION + + Indices and tables ================== diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index 7676e6ad..b2b94371 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -1,4 +1,4 @@ -from __future__ import absolute_import, division, print_function +import enum from . import swigwrapper # pylint: disable=import-self @@ -13,41 +13,12 @@ for symbol in dir(swigwrapper): globals()[symbol] = getattr(swigwrapper, symbol) -class FlashlightDecoderState(swigwrapper.FlashlightDecoderState): - pass - - -class Scorer(swigwrapper.Scorer): - """Wrapper for Scorer. - - :param alpha: Language model weight. - :type alpha: float - :param beta: Word insertion bonus. - :type beta: float - :scorer_path: Path to load scorer from. 
- :alphabet: Alphabet - :type scorer_path: basestring - """ - - def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None): - super(Scorer, self).__init__() - # Allow bare initialization - if alphabet: - assert alpha is not None, "alpha parameter is required" - assert beta is not None, "beta parameter is required" - assert scorer_path, "scorer_path parameter is required" - - err = self.init(scorer_path.encode("utf-8"), alphabet) - if err != 0: - raise ValueError( - "Scorer initialization failed with error code 0x{:X}".format(err) - ) - - self.reset_params(alpha, beta) - - class Alphabet(swigwrapper.Alphabet): - """Convenience wrapper for Alphabet which calls init in the constructor""" + """An Alphabet is a bidirectional map from tokens (e.g. characters) to + internal integer representations used by the underlying acoustic models + and external scorers. It can be created from an alphabet configuration file + via the constructor, or from a list of tokens via :py:meth:`Alphabet.InitFromLabels`. + """ def __init__(self, config_path=None): super(Alphabet, self).__init__() @@ -59,6 +30,10 @@ class Alphabet(swigwrapper.Alphabet): ) def InitFromLabels(self, data): + """ + Initialize Alphabet from a list of labels ``data``. Each label gets + associated with an integer value corresponding to its position in the list. + """ return super(Alphabet, self).InitFromLabels([c.encode("utf-8") for c in data]) def CanEncodeSingle(self, input): @@ -87,7 +62,7 @@ class Alphabet(swigwrapper.Alphabet): Encode a sequence of character/output classes into a sequence of labels. Characters are assumed to always take a single Unicode codepoint. Characters must be in the alphabet, this method will assert that. Use - `CanEncode` and `CanEncodeSingle` to test. + ``CanEncode`` and ``CanEncodeSingle`` to test. 
""" # Convert SWIG's UnsignedIntVec to a Python list res = super(Alphabet, self).Encode(input.encode("utf-8")) @@ -103,57 +78,39 @@ class Alphabet(swigwrapper.Alphabet): return res.decode("utf-8") -class UTF8Alphabet(swigwrapper.UTF8Alphabet): - """Convenience wrapper for Alphabet which calls init in the constructor""" +class Scorer(swigwrapper.Scorer): + """An external scorer is a data structure composed of a language model built + from text data, as well as the vocabulary used in the construction of this + language model and additional parameters related to how the decoding + process uses the external scorer, such as the language model weight + ``alpha`` and the word insertion score ``beta``. - def __init__(self): - super(UTF8Alphabet, self).__init__() - err = self.init(b"") - if err != 0: - raise ValueError( - "UTF8Alphabet initialization failed with error code 0x{:X}".format(err) - ) + :param alpha: Language model weight. + :type alpha: float + :param beta: Word insertion score. + :type beta: float + :param scorer_path: Path to load scorer from. + :type scorer_path: str + :param alphabet: Alphabet object matching the tokens used when creating the + external scorer. + :type alphabet: Alphabet + """ - def CanEncodeSingle(self, input): - """ - Returns true if the single character/output class has a corresponding label - in the alphabet. - """ - return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8")) + def __init__(self, alpha=None, beta=None, scorer_path=None, alphabet=None): + super(Scorer, self).__init__() + # Allow bare initialization + if alphabet: + assert alpha is not None, "alpha parameter is required" + assert beta is not None, "beta parameter is required" + assert scorer_path, "scorer_path parameter is required" - def CanEncode(self, input): - """ - Returns true if the entire string can be encoded into labels in this - alphabet. 
- """ - return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8")) + err = self.init(scorer_path.encode("utf-8"), alphabet) + if err != 0: + raise ValueError( + "Scorer initialization failed with error code 0x{:X}".format(err) + ) - def EncodeSingle(self, input): - """ - Encode a single character/output class into a label. Character must be in - the alphabet, this method will assert that. Use `CanEncodeSingle` to test. - """ - return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8")) - - def Encode(self, input): - """ - Encode a sequence of character/output classes into a sequence of labels. - Characters are assumed to always take a single Unicode codepoint. - Characters must be in the alphabet, this method will assert that. Use - `CanEncode` and `CanEncodeSingle` to test. - """ - # Convert SWIG's UnsignedIntVec to a Python list - res = super(UTF8Alphabet, self).Encode(input.encode("utf-8")) - return [el for el in res] - - def DecodeSingle(self, input): - res = super(UTF8Alphabet, self).DecodeSingle(input) - return res.decode("utf-8") - - def Decode(self, input): - """Decode a sequence of labels into a string.""" - res = super(UTF8Alphabet, self).Decode(input) - return res.decode("utf-8") + self.reset_params(alpha, beta) def ctc_beam_search_decoder( @@ -186,7 +143,7 @@ def ctc_beam_search_decoder( count or language model. :type scorer: Scorer :param hot_words: Map of words (keys) to their assigned boosts (values) - :type hot_words: map{string:float} + :type hot_words: dict[string, float] :param num_results: Number of beams to return. :type num_results: int :return: List of tuples of confidence and sentence as decoding @@ -245,7 +202,7 @@ def ctc_beam_search_decoder_batch( count or language model. :type scorer: Scorer :param hot_words: Map of words (keys) to their assigned boosts (values) - :type hot_words: map{string:float} + :type hot_words: dict[string, float] :param num_results: Number of beams to return. 
:type num_results: int :return: List of tuples of confidence and sentence as decoding @@ -271,8 +228,63 @@ def ctc_beam_search_decoder_batch( return batch_beam_results +class FlashlightDecoderState(swigwrapper.FlashlightDecoderState): + """ + This class contains constants used to specify the desired behavior for the + :py:func:`flashlight_beam_search_decoder` and :py:func:`flashlight_beam_search_decoder_batch` + functions. + """ + + class CriterionType(enum.IntEnum): + """Constants used to specify which loss criterion was used by the + acoustic model. This class is a Python :py:class:`enum.IntEnum`. + """ + + #: Decoder mode for handling acoustic models trained with CTC loss + CTC = swigwrapper.FlashlightDecoderState.CTC + + #: Decoder mode for handling acoustic models trained with ASG loss + ASG = swigwrapper.FlashlightDecoderState.ASG + + #: Decoder mode for handling acoustic models trained with Seq2seq loss + #: Note: this criterion type is currently not supported. + S2S = swigwrapper.FlashlightDecoderState.S2S + + class DecoderType(enum.IntEnum): + """Constants used to specify if decoder should operate in lexicon mode, + only predicting words present in a fixed vocabulary, or in lexicon-free + mode, without such restriction. This class is a Python :py:class:`enum.IntEnum`. + """ + + #: Lexicon mode, only predict words in specified vocabulary. + LexiconBased = swigwrapper.FlashlightDecoderState.LexiconBased + + #: Lexicon-free mode, allow prediction of any word. + LexiconFree = swigwrapper.FlashlightDecoderState.LexiconFree + + class TokenType(enum.IntEnum): + """Constants used to specify the granularity of text units used when training + the external scorer in relation to the text units used when training the + acoustic model. For example, you can have an acoustic model predicting + characters and an external scorer trained on words, or an acoustic model + and an external scorer both trained with sub-word units. 
If the acoustic + model and the scorer were both trained on the same text unit granularity, + use ``TokenType.Single``. Otherwise, if the external scorer was trained + on a sequence of acoustic model text units, use ``TokenType.Aggregate``. + This class is a Python :py:class:`enum.IntEnum`. + """ + + #: Token type for external scorers trained on the same textual units as + #: the acoustic model. + Single = swigwrapper.FlashlightDecoderState.Single + + #: Token type for external scorers trained on a sequence of acoustic model + #: textual units. + Aggregate = swigwrapper.FlashlightDecoderState.Aggregate + + def flashlight_beam_search_decoder( - probs_seq, + logits_seq, alphabet, beam_size, decoder_type, @@ -283,12 +295,67 @@ def flashlight_beam_search_decoder( cutoff_top_n=40, silence_score=0.0, merge_with_log_add=False, - criterion_type=swigwrapper.FlashlightDecoderState.CTC, + criterion_type=FlashlightDecoderState.CriterionType.CTC, transitions=[], num_results=1, ): + """Decode acoustic model emissions for a single sample. Note that unlike + :py:func:`ctc_beam_search_decoder`, this function expects raw outputs + from CTC and ASG acoustic models, without softmaxing them over + timesteps. + + :param logits_seq: 2-D list of acoustic model emissions, dimensions are + time steps x number of output units. + :type logits_seq: 2-D list of floats or numpy array + :param alphabet: Alphabet object matching the tokens used when creating the + acoustic model and external scorer if specified. + :type alphabet: Alphabet + :param beam_size: Width for beam search. + :type beam_size: int + :param decoder_type: Decoding mode, lexicon-constrained or lexicon-free. + :type decoder_type: FlashlightDecoderState.DecoderType + :param token_type: Type of token in the external scorer. + :type token_type: FlashlightDecoderState.TokenType + :param lm_tokens: List of tokens to constrain decoding to when in lexicon-constrained + mode. Must match the token type used in the scorer, ie. 
+ must be a list of characters if scorer is character-based, + or a list of words if scorer is word-based. + :type lm_tokens: list[str] + :param scorer: External scorer. + :type scorer: Scorer + :param beam_threshold: Maximum threshold in beam score from leading beam. Any + newly created candidate beams which lag behind the best + beam so far by more than this value will get pruned. + This is a performance optimization parameter and an + appropriate value should be found empirically using a + validation set. + :type beam_threshold: float + :param cutoff_top_n: Maximum number of tokens to expand per time step during + decoding. Only the highest probability cutoff_top_n + candidates (characters, sub-word units, words) in a given + timestep will be expanded. This is a performance + optimization parameter and an appropriate value should + be found empirically using a validation set. + :type cutoff_top_n: int + :param silence_score: Score to add to beam when encountering a predicted + silence token (e.g. the space symbol). + :type silence_score: float + :param merge_with_log_add: Whether to use log-add when merging scores of + new candidate beams equivalent to existing ones + (leading to the same transcription). When disabled, + the maximum score is used. + :type merge_with_log_add: bool + :param criterion_type: Criterion used for training the acoustic model. + :type criterion_type: FlashlightDecoderState.CriterionType + :param transitions: Transition score matrix for ASG acoustic models. + :type transitions: list[float] + :param num_results: Number of beams to return. + :type num_results: int + :return: List of FlashlightOutput structures. 
+ :rtype: list[FlashlightOutput] + """ return swigwrapper.flashlight_beam_search_decoder( - probs_seq, + logits_seq, alphabet, beam_size, beam_threshold, @@ -319,10 +386,17 @@ def flashlight_beam_search_decoder_batch( cutoff_top_n=40, silence_score=0.0, merge_with_log_add=False, - criterion_type=swigwrapper.FlashlightDecoderState.CTC, + criterion_type=FlashlightDecoderState.CriterionType.CTC, transitions=[], num_results=1, ): + """Decode batch acoustic model emissions in parallel. ``num_processes`` + controls how many samples from the batch will be decoded simultaneously. + All the other parameters are forwarded to :py:func:`flashlight_beam_search_decoder`. + + Returns a list of lists of FlashlightOutput structures. + """ + return swigwrapper.flashlight_beam_search_decoder_batch( probs_seq, seq_lengths, @@ -341,3 +415,58 @@ def flashlight_beam_search_decoder_batch( num_results, num_processes, ) + + +class UTF8Alphabet(swigwrapper.UTF8Alphabet): + """Alphabet class representing 255 possible byte values for Bytes Output Mode. + For internal use only. + """ + + def __init__(self): + super(UTF8Alphabet, self).__init__() + err = self.init(b"") + if err != 0: + raise ValueError( + "UTF8Alphabet initialization failed with error code 0x{:X}".format(err) + ) + + def CanEncodeSingle(self, input): + """ + Returns true if the single character/output class has a corresponding label + in the alphabet. + """ + return super(UTF8Alphabet, self).CanEncodeSingle(input.encode("utf-8")) + + def CanEncode(self, input): + """ + Returns true if the entire string can be encoded into labels in this + alphabet. + """ + return super(UTF8Alphabet, self).CanEncode(input.encode("utf-8")) + + def EncodeSingle(self, input): + """ + Encode a single character/output class into a label. Character must be in + the alphabet, this method will assert that. Use ``CanEncodeSingle`` to test. 
+ """ + return super(UTF8Alphabet, self).EncodeSingle(input.encode("utf-8")) + + def Encode(self, input): + """ + Encode a sequence of character/output classes into a sequence of labels. + Characters are assumed to always take a single Unicode codepoint. + Characters must be in the alphabet, this method will assert that. Use + ``CanEncode`` and ``CanEncodeSingle`` to test. + """ + # Convert SWIG's UnsignedIntVec to a Python list + res = super(UTF8Alphabet, self).Encode(input.encode("utf-8")) + return [el for el in res] + + def DecodeSingle(self, input): + res = super(UTF8Alphabet, self).DecodeSingle(input) + return res.decode("utf-8") + + def Decode(self, input): + """Decode a sequence of labels into a string.""" + res = super(UTF8Alphabet, self).Decode(input) + return res.decode("utf-8") diff --git a/native_client/ctcdecode/setup.py b/native_client/ctcdecode/setup.py index 5131f365..231e0df9 100644 --- a/native_client/ctcdecode/setup.py +++ b/native_client/ctcdecode/setup.py @@ -96,7 +96,9 @@ class BuildExtFirst(build): setup( name="coqui_stt_ctcdecoder", version=project_version, - description="""DS CTC decoder""", + description="Coqui STT Python decoder package.", + long_description="Documentation available at `stt.readthedocs.io `_", + long_description_content_type="text/x-rst; charset=UTF-8", cmdclass={"build": BuildExtFirst}, ext_modules=[decoder_module], package_dir={"coqui_stt_ctcdecoder": "."}, diff --git a/training/coqui_stt_training/evaluate_flashlight.py b/training/coqui_stt_training/evaluate_flashlight.py index 6ac96ce0..1c7ae155 100644 --- a/training/coqui_stt_training/evaluate_flashlight.py +++ b/training/coqui_stt_training/evaluate_flashlight.py @@ -142,8 +142,8 @@ def evaluate(test_csvs, create_model): batch_lengths, Config.alphabet, beam_size=Config.export_beam_width, - decoder_type=FlashlightDecoderState.LexiconBased, - token_type=FlashlightDecoderState.Aggregate, + decoder_type=FlashlightDecoderState.DecoderType.LexiconBased, + 
token_type=FlashlightDecoderState.TokenType.Aggregate, lm_tokens=vocab, num_processes=num_processes, scorer=scorer,