Add intermediateDecodeExpensive() function and Python binding

This flushes the buffer through the acoustic model before returning
an intermediate decode. Provides an alternative to finishStream()
that doesn't end the stream.
This commit is contained in:
Jeremiah Rose 2021-09-08 15:09:51 +10:00
parent ba24f010eb
commit 3a791d82e8
4 changed files with 63 additions and 0 deletions

View File

@ -326,6 +326,20 @@ Metadata* STT_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
* @note This method will free the state pointer (@p aSctx).
*/
/**
 * @brief Compute the intermediate decoding of an ongoing streaming inference,
 *        first flushing the data buffered so far through the acoustic model.
 *        Unlike {@link STT_FinishStream()}, this does not end the stream.
 *
 * @param aSctx A streaming state pointer returned by {@link STT_CreateStream()}.
 *              Taken as non-const because flushing goes through non-const
 *              members of the streaming state.
 *
 * @return The STT intermediate result.
 *         NOTE(review): the SWIG binding marks this function %newobject, so
 *         the caller presumably owns the returned string - confirm which
 *         free function applies.
 */
STT_EXPORT
char* STT_IntermediateDecodeExpensive(StreamingState* aSctx);
/**
 * @brief Compute the final decoding of an ongoing streaming inference and
 *        return the result. This ends the stream.
 *
 * @param aSctx A streaming state pointer returned by {@link STT_CreateStream()}.
 *
 * @return The STT result as a string.
 *         NOTE(review): the SWIG binding marks this function %newobject, so
 *         the caller presumably owns the returned string - confirm which
 *         free function applies and what is returned on error.
 */
STT_EXPORT
char* STT_FinishStream(StreamingState* aSctx);
/**

View File

@ -281,6 +281,22 @@ class Stream(object):
)
return stt.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
def intermediateDecodeExpensive(self):
    """
    Flush the buffered data and compute an intermediate decoding of the
    ongoing streaming inference, so that everything streamed so far is
    reflected in the result.

    :return: The STT intermediate result.
    :type: str

    :throws: RuntimeError if the stream object is not valid
    """
    # Happy path first: a live stream handle is forwarded to the native call.
    if self._impl:
        return stt.impl.IntermediateDecodeExpensive(self._impl)
    raise RuntimeError(
        "Stream object is not valid. Trying to decode an already finished stream?"
    )
def finishStream(self):
"""
Compute the final decoding of an ongoing streaming inference and return

View File

@ -119,6 +119,7 @@ static PyObject *parent_reference() {
// Each function listed with %newobject returns newly allocated memory; the
// directive tells SWIG that the generated wrapper should take ownership of
// the returned object instead of treating it as borrowed.
%newobject STT_SpeechToText;
%newobject STT_IntermediateDecode;
%newobject STT_IntermediateDecodeExpensive;
%newobject STT_FinishStream;
%newobject STT_Version;
%newobject STT_ErrorCodeToErrorMessage;

View File

@ -79,6 +79,8 @@ struct StreamingState {
void feedAudioContent(const short* buffer, unsigned int buffer_size);
char* intermediateDecode() const;
Metadata* intermediateDecodeWithMetadata(unsigned int num_results) const;
// Calls flushBuffer() and then decodes the current decoder state. Non-const,
// unlike intermediateDecode(), because it calls the non-const flushBuffer().
char* intermediateDecodeExpensive();
// Runs the buffered audio through processAudioWindow() and, if the batch
// buffer is non-empty, through processBatch(). It does not add the
// end-of-sample zero MFCC windows.
void flushBuffer();
void finalizeStream();
char* finishStream();
Metadata* finishStreamWithMetadata(unsigned int num_results);
@ -143,6 +145,13 @@ StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const
return model_->decode_metadata(decoder_state_, num_results);
}
// Intermediate decode that first flushes the buffered data through the model
// (via flushBuffer()) so the returned transcript covers everything fed so far.
// The stream is not finished by this call.
char*
StreamingState::intermediateDecodeExpensive()
{
  // Flush pending buffers before asking the decoder for a result.
  flushBuffer();

  char* transcript = model_->decode(decoder_state_);
  return transcript;
}
char*
StreamingState::finishStream()
{
@ -167,6 +176,23 @@ StreamingState::processAudioWindow(const vector<float>& buf)
pushMfccBuffer(mfcc);
}
// Flushes the data buffered so far: the audio buffer is run through
// processAudioWindow(), then any pending batch is run through processBatch().
//
// NOTE(review): neither audio_buffer_ nor batch_buffer_ is cleared in this
// function. Confirm that processAudioWindow()/processBatch() leave them in a
// state that is safe for further feedAudioContent() calls and for repeated
// flushes; otherwise the same data could be processed twice.
void
StreamingState::flushBuffer()
{
  // Push whatever audio is currently buffered through the MFCC stage.
  processAudioWindow(audio_buffer_);

  // The end-of-sample padding step (model_->n_context_ calls to
  // addZeroMfccWindow()) is disabled here - presumably because the stream
  // is still open; TODO confirm.

  // Nothing accumulated for a final batch: done.
  if (batch_buffer_.size() == 0) {
    return;
  }

  // Process the remaining partial batch.
  const auto pending_steps = batch_buffer_.size() / model_->mfcc_feats_per_timestep_;
  processBatch(batch_buffer_, pending_steps);
}
void
StreamingState::finalizeStream()
{
@ -465,6 +491,12 @@ STT_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
return aSctx->intermediateDecodeWithMetadata(aNumResults);
}
// C API entry point: forwards to StreamingState::intermediateDecodeExpensive()
// on the given streaming state and hands its result back unchanged.
// aSctx is dereferenced without a null check.
char*
STT_IntermediateDecodeExpensive(StreamingState* aSctx)
{
  char* const transcript = aSctx->intermediateDecodeExpensive();
  return transcript;
}
char*
STT_FinishStream(StreamingState* aSctx)
{