STT/native_client/python/__init__.py

import os
import platform
# The API is not snake case which triggers linter errors
# pylint: disable=invalid-name
if platform.system().lower() == "windows":
    dslib_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "lib")

    # On Windows, we can't rely on RPATH being set to $ORIGIN/lib/ or on
    # @loader_path/lib
    if hasattr(os, "add_dll_directory"):
        # Starting with Python 3.8 this properly handles the problem
        os.add_dll_directory(dslib_path)
    else:
        # Before Python 3.8 we need to change the PATH to include the proper
        # directory for the dynamic linker
        os.environ["PATH"] = dslib_path + ";" + os.environ["PATH"]
import stt
# rename for backwards compatibility
from stt.impl import Version as version
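
# Illustrative note: with this rename, consumers can query the native library
# version via the package, e.g.:
#     import stt
#     print(stt.version())
# (assuming this package is installed under the name "stt")
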
class Model(object):
"""
Class holding a Coqui STT model
:param aModelPath: Path to model file to load
:type aModelPath: str
"""
def __init__(self, model_path):
# make sure the attribute is there if CreateModel fails
self._impl = None
status, impl = stt.impl.CreateModel(model_path)
if status != 0:
raise RuntimeError(
"CreateModel failed with '{}' (0x{:X})".format(
stt.impl.ErrorCodeToErrorMessage(status), status
)
)
self._impl = impl
def __del__(self):
if self._impl:
stt.impl.FreeModel(self._impl)
self._impl = None
    def beamWidth(self):
        """
        Get the beam width value used by the model. If :func:`setBeamWidth` was not
        called before, this will return the default value loaded from the model file.

        :return: Beam width value used by the model.
        :type: int
        """
        return stt.impl.GetModelBeamWidth(self._impl)

    def setBeamWidth(self, beam_width):
        """
        Set the beam width value used by the model.

        :param beam_width: The beam width used by the model. A larger beam width value generates better results at the cost of decoding time.
        :type beam_width: int

        :return: Zero on success, non-zero on failure.
        :type: int
        """
        return stt.impl.SetModelBeamWidth(self._impl, beam_width)

    def sampleRate(self):
        """
        Return the sample rate expected by the model.

        :return: Sample rate.
        :type: int
        """
        return stt.impl.GetModelSampleRate(self._impl)
    def enableExternalScorer(self, scorer_path):
        """
        Enable decoding using an external scorer.

        :param scorer_path: The path to the external scorer file.
        :type scorer_path: str

        :throws: RuntimeError on error
        """
        status = stt.impl.EnableExternalScorer(self._impl, scorer_path)
        if status != 0:
            raise RuntimeError(
                "EnableExternalScorer failed with '{}' (0x{:X})".format(
                    stt.impl.ErrorCodeToErrorMessage(status), status
                )
            )

    def disableExternalScorer(self):
        """
        Disable decoding using an external scorer.

        :return: Zero on success, non-zero on failure.
        """
        return stt.impl.DisableExternalScorer(self._impl)
    def addHotWord(self, word, boost):
        """
        Add a word and its boost for decoding.

        Words that don't occur in the scorer (e.g. proper nouns) or strings that contain spaces won't be taken into account.

        :param word: the hot-word
        :type word: str

        :param boost: Positive boost value increases and negative reduces chance of a word occurring in a transcription. Excessive positive boost might lead to splitting up of letters of the word following the hot-word.
        :type boost: float

        :throws: RuntimeError on error
        """
        status = stt.impl.AddHotWord(self._impl, word, boost)
        if status != 0:
            raise RuntimeError(
                "AddHotWord failed with '{}' (0x{:X})".format(
                    stt.impl.ErrorCodeToErrorMessage(status), status
                )
            )

    def eraseHotWord(self, word):
        """
        Remove the entry for a word from the hot-words dict.

        :param word: the hot-word
        :type word: str

        :throws: RuntimeError on error
        """
        status = stt.impl.EraseHotWord(self._impl, word)
        if status != 0:
            raise RuntimeError(
                "EraseHotWord failed with '{}' (0x{:X})".format(
                    stt.impl.ErrorCodeToErrorMessage(status), status
                )
            )

    def clearHotWords(self):
        """
        Remove all entries from the hot-words dict.

        :throws: RuntimeError on error
        """
        status = stt.impl.ClearHotWords(self._impl)
        if status != 0:
            raise RuntimeError(
                "ClearHotWords failed with '{}' (0x{:X})".format(
                    stt.impl.ErrorCodeToErrorMessage(status), status
                )
            )
    def setScorerAlphaBeta(self, alpha, beta):
        """
        Set hyperparameters alpha and beta of the external scorer.

        :param alpha: The alpha hyperparameter of the decoder. Language model weight.
        :type alpha: float

        :param beta: The beta hyperparameter of the decoder. Word insertion weight.
        :type beta: float

        :return: Zero on success, non-zero on failure.
        :type: int
        """
        return stt.impl.SetScorerAlphaBeta(self._impl, alpha, beta)

    def stt(self, audio_buffer):
        """
        Use the Coqui STT model to perform Speech-To-Text.

        :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
        :type audio_buffer: numpy.int16 array

        :return: The STT result.
        :type: str
        """
        return stt.impl.SpeechToText(self._impl, audio_buffer)

    def sttWithMetadata(self, audio_buffer, num_results=1):
        """
        Use the Coqui STT model to perform Speech-To-Text and return results including metadata.

        :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
        :type audio_buffer: numpy.int16 array

        :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
        :type num_results: int

        :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
        :type: :func:`Metadata`
        """
        return stt.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results)
    def createStream(self):
        """
        Create a new streaming inference state. The streaming state returned by
        this function can then be passed to :func:`feedAudioContent()` and :func:`finishStream()`.

        :return: Stream object representing the newly created stream
        :type: :func:`Stream`

        :throws: RuntimeError on error
        """
        status, ctx = stt.impl.CreateStream(self._impl)
        if status != 0:
            raise RuntimeError(
                "CreateStream failed with '{}' (0x{:X})".format(
                    stt.impl.ErrorCodeToErrorMessage(status), status
                )
            )
        return Stream(ctx)
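

# A minimal, illustrative sketch of how the batch API above is typically used.
# It is not part of the binding itself; the file names "model.tflite",
# "kenlm.scorer", and "audio.wav" are placeholders, and the WAV file is assumed
# to already be 16-bit mono at the model's sample rate.
def _example_batch_transcription():
    import wave

    import numpy as np

    model = Model("model.tflite")
    model.enableExternalScorer("kenlm.scorer")  # optional external language model
    model.addHotWord("coqui", 10.0)  # optional hot-word boosting
    with wave.open("audio.wav", "rb") as wav:
        assert wav.getframerate() == model.sampleRate()
        audio = np.frombuffer(wav.readframes(wav.getnframes()), dtype=np.int16)
    return model.stt(audio)

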
class Stream(object):
"""
Class wrapping a stt stream. The constructor cannot be called directly.
Use :func:`Model.createStream()`
"""
def __init__(self, native_stream):
self._impl = native_stream
def __del__(self):
if self._impl:
self.freeStream()
    def feedAudioContent(self, audio_buffer):
        """
        Feed audio samples to an ongoing streaming inference.

        :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on).
        :type audio_buffer: numpy.int16 array

        :throws: RuntimeError if the stream object is not valid
        """
        if not self._impl:
            raise RuntimeError(
                "Stream object is not valid. Trying to feed an already finished stream?"
            )
        stt.impl.FeedAudioContent(self._impl, audio_buffer)

    def intermediateDecode(self):
        """
        Compute the intermediate decoding of an ongoing streaming inference.

        :return: The STT intermediate result.
        :type: str

        :throws: RuntimeError if the stream object is not valid
        """
        if not self._impl:
            raise RuntimeError(
                "Stream object is not valid. Trying to decode an already finished stream?"
            )
        return stt.impl.IntermediateDecode(self._impl)

    def intermediateDecodeWithMetadata(self, num_results=1):
        """
        Compute the intermediate decoding of an ongoing streaming inference and return results including metadata.

        :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
        :type num_results: int

        :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
        :type: :func:`Metadata`

        :throws: RuntimeError if the stream object is not valid
        """
        if not self._impl:
            raise RuntimeError(
                "Stream object is not valid. Trying to decode an already finished stream?"
            )
        return stt.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
    def finishStream(self):
        """
        Compute the final decoding of an ongoing streaming inference and return
        the result. Signals the end of an ongoing streaming inference. The underlying
        stream object must not be used after this method is called.

        :return: The STT result.
        :type: str

        :throws: RuntimeError if the stream object is not valid
        """
        if not self._impl:
            raise RuntimeError(
                "Stream object is not valid. Trying to finish an already finished stream?"
            )
        result = stt.impl.FinishStream(self._impl)
        self._impl = None
        return result

    def finishStreamWithMetadata(self, num_results=1):
        """
        Compute the final decoding of an ongoing streaming inference and return
        results including metadata. Signals the end of an ongoing streaming
        inference. The underlying stream object must not be used after this
        method is called.

        :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
        :type num_results: int

        :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
        :type: :func:`Metadata`

        :throws: RuntimeError if the stream object is not valid
        """
        if not self._impl:
            raise RuntimeError(
                "Stream object is not valid. Trying to finish an already finished stream?"
            )
        result = stt.impl.FinishStreamWithMetadata(self._impl, num_results)
        self._impl = None
        return result

    def freeStream(self):
        """
        Destroy a streaming state without decoding the computed logits. This can
        be used if you no longer need the result of an ongoing streaming inference.

        :throws: RuntimeError if the stream object is not valid
        """
        if not self._impl:
            raise RuntimeError(
                "Stream object is not valid. Trying to free an already finished stream?"
            )
        stt.impl.FreeStream(self._impl)
        self._impl = None
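

# A minimal, illustrative sketch of driving the streaming API above: feed
# fixed-size chunks of a 16-bit mono signal, poll intermediate results, then
# finish the stream. The chunk size and the model/audio inputs are placeholders.
def _example_streaming_transcription(model, audio, chunk_size=1024):
    stream = model.createStream()
    for offset in range(0, len(audio), chunk_size):
        stream.feedAudioContent(audio[offset : offset + chunk_size])
        print(stream.intermediateDecode())  # partial hypothesis so far
    return stream.finishStream()

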
# The classes below are only for documentation purposes
# Metadata, CandidateTranscript and TokenMetadata should be kept in sync with native_client/coqui-stt.h
class TokenMetadata(object):
"""
Stores each individual character, along with its timing information
"""
def text(self):
"""
The text for this token
"""
def timestep(self):
"""
Position of the token in units of 20ms
"""
def start_time(self):
"""
Position of the token in seconds
"""
class CandidateTranscript(object):
"""
Stores the entire CTC output as an array of character metadata objects
"""
def tokens(self):
"""
List of tokens
:return: A list of :func:`TokenMetadata` elements
:type: list
"""
def confidence(self):
"""
Approximated confidence value for this transcription. This is roughly the
sum of the acoustic model logit values for each timestep/character that
contributed to the creation of this transcription.
"""
class Metadata(object):
    def transcripts(self):
        """
        List of candidate transcripts

        :return: A list of :func:`CandidateTranscript` objects
        :type: list
        """