Add intermediateDecodeExpensiveWithMetadata() and Python bindings

- Rename finalizeStream() to flushBuffers(), which also replaces flushBuffer().
- Add optional argument addZeroMfccVectors = false to flushBuffers().
- Fix doc string applied to wrong function.
This commit is contained in:
Jeremiah Rose 2021-09-09 09:56:44 +10:00
parent fe03974d7d
commit ed656aa487
4 changed files with 70 additions and 25 deletions

View File

@ -315,8 +315,10 @@ Metadata* STT_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
unsigned int aNumResults);
/**
* @brief Compute the final decoding of an ongoing streaming inference and return
* the result. Signals the end of an ongoing streaming inference.
* @brief Compute the intermediate decoding of an ongoing streaming inference, flushing
buffers first. This ensures that all audio that has been streamed so far is
included in the result, but is more expensive than STT_IntermediateDecode()
because buffers are processed through the acoustic model.
*
* @param aSctx A streaming state pointer returned by {@link STT_CreateStream()}.
*
@ -331,8 +333,9 @@ char* STT_IntermediateDecodeExpensive(StreamingState* aSctx);
/**
* @brief Compute the intermediate decoding of an ongoing streaming inference, flushing
buffers first. This ensures that all audio that has been streamed so far is
included in the result, but is more expensive than STT_IntermediateDecode()
because buffers are processed through the acoustic model.
included in the result, but is more expensive than
STT_IntermediateDecodeWithMetadata() because buffers are processed through
the acoustic model. Return results including metadata.
*
* @param aSctx A streaming state pointer returned by {@link STT_CreateStream()}.
* @param aNumResults The number of candidate transcripts to return.
@ -343,6 +346,21 @@ char* STT_IntermediateDecodeExpensive(StreamingState* aSctx);
* Returns NULL on error.
*/
STT_EXPORT
Metadata* STT_IntermediateDecodeExpensiveWithMetadata(StreamingState* aSctx,
unsigned int aNumResults);
/**
* @brief Compute the final decoding of an ongoing streaming inference and return
* the result. Signals the end of an ongoing streaming inference.
*
* @param aSctx A streaming state pointer returned by {@link STT_CreateStream()}.
*
* @return The STT result as a string. The user is responsible for freeing the
* string using {@link STT_FreeString()}.
*
* @note This method will free the state pointer (@p aSctx); do not use
* @p aSctx after this call.
*/
STT_EXPORT
char* STT_FinishStream(StreamingState* aSctx);
/**

View File

@ -299,6 +299,28 @@ class Stream(object):
)
return stt.impl.IntermediateDecodeExpensive(self._impl)
def intermediateDecodeExpensiveWithMetadata(self, num_results=1):
    """
    Compute the intermediate decoding of an ongoing streaming inference, flushing
    buffers first. This ensures that all audio that has been streamed so far is
    included in the result, but is more expensive than
    intermediateDecodeWithMetadata() because buffers are processed through the
    acoustic model. Return results including metadata.

    :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this.
    :type num_results: int

    :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information.
    :type: :func:`Metadata`

    :throws: RuntimeError if the stream object is not valid
    """
    if not self._impl:
        raise RuntimeError(
            "Stream object is not valid. Trying to decode an already finished stream?"
        )
    # Must dispatch to the *Expensive* binding: the plain
    # IntermediateDecodeWithMetadata binding does not flush buffers first.
    return stt.impl.IntermediateDecodeExpensiveWithMetadata(self._impl, num_results)
def finishStream(self):
"""
Compute the final decoding of an ongoing streaming inference and return

View File

@ -120,6 +120,7 @@ static PyObject *parent_reference() {
%newobject STT_SpeechToText;
%newobject STT_IntermediateDecode;
%newobject STT_IntermediateDecodeExpensive;
%newobject STT_IntermediateDecodeExpensiveWithMetadata;
%newobject STT_FinishStream;
%newobject STT_Version;
%newobject STT_ErrorCodeToErrorMessage;

View File

@ -80,8 +80,8 @@ struct StreamingState {
char* intermediateDecode() const;
Metadata* intermediateDecodeWithMetadata(unsigned int num_results) const;
char* intermediateDecodeExpensive();
void flushBuffer();
void finalizeStream();
Metadata* intermediateDecodeExpensiveWithMetadata(unsigned int num_results);
void flushBuffers(bool addZeroMfccVectors = false);
char* finishStream();
Metadata* finishStreamWithMetadata(unsigned int num_results);
@ -148,21 +148,28 @@ StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const
// Intermediate decode that first flushes pending buffers, then decodes the
// current decoder state via model_->decode().
// NOTE(review): both flushBuffer() and flushBuffers() are listed below; this
// looks like the removed and the added line of a change shown together --
// confirm that only flushBuffers() remains in the live code.
char*
StreamingState::intermediateDecodeExpensive()
{
flushBuffer();
flushBuffers();
return model_->decode(decoder_state_);
}
// Intermediate decode with metadata: flush pending buffers first, then decode
// the current decoder state into up to `num_results` candidate transcripts.
Metadata*
StreamingState::intermediateDecodeExpensiveWithMetadata(unsigned int num_results)
{
  // Default argument (addZeroMfccVectors = false) is used here on purpose:
  // only buffered data is pushed through, no trailing zero MFCC vectors.
  flushBuffers();

  Metadata* const candidates = model_->decode_metadata(decoder_state_, num_results);
  return candidates;
}
// Final decode: flushes buffers with addZeroMfccVectors = true, then decodes
// the decoder state via model_->decode().
// NOTE(review): both finalizeStream() and flushBuffers(true) are listed below;
// this looks like the removed and the added line of a change shown together --
// confirm that only flushBuffers(true) remains in the live code.
char*
StreamingState::finishStream()
{
finalizeStream();
flushBuffers(true);
return model_->decode(decoder_state_);
}
// Final decode with metadata: flushes buffers with addZeroMfccVectors = true,
// then decodes the decoder state into up to num_results candidates via
// model_->decode_metadata().
// NOTE(review): both finalizeStream() and flushBuffers(true) are listed below;
// this looks like the removed and the added line of a change shown together --
// confirm that only flushBuffers(true) remains in the live code.
Metadata*
StreamingState::finishStreamWithMetadata(unsigned int num_results)
{
finalizeStream();
flushBuffers(true);
return model_->decode_metadata(decoder_state_, num_results);
}
@ -177,27 +184,17 @@ StreamingState::processAudioWindow(const vector<float>& buf)
}
void
StreamingState::flushBuffer()
{
// Flush audio buffer
processAudioWindow(audio_buffer_);
// Process final batch
if (batch_buffer_.size() > 0) {
processBatch(batch_buffer_, batch_buffer_.size()/model_->mfcc_feats_per_timestep_);
}
}
void
StreamingState::finalizeStream()
StreamingState::flushBuffers(bool addZeroMfccVectors)
{
// Flush audio buffer
processAudioWindow(audio_buffer_);
if (addZeroMfccVectors) {
// Add empty mfcc vectors at end of sample
for (int i = 0; i < model_->n_context_; ++i) {
addZeroMfccWindow();
}
}
// Process final batch
if (batch_buffer_.size() > 0) {
@ -492,6 +489,13 @@ STT_IntermediateDecodeExpensive(StreamingState* aSctx)
return aSctx->intermediateDecodeExpensive();
}
// C API entry point: forwards to the streaming state's
// intermediateDecodeExpensiveWithMetadata() and hands back its result.
// No validation of aSctx is performed here.
Metadata*
STT_IntermediateDecodeExpensiveWithMetadata(StreamingState* aSctx,
                                            unsigned int aNumResults)
{
  Metadata* const result = aSctx->intermediateDecodeExpensiveWithMetadata(aNumResults);
  return result;
}
char*
STT_FinishStream(StreamingState* aSctx)
{