From 3a791d82e83f11610b9e7a927fa298381dbb03a7 Mon Sep 17 00:00:00 2001
From: Jeremiah Rose
Date: Wed, 8 Sep 2021 15:09:51 +1000
Subject: [PATCH] Add intermediateDecodeExpensive() function and Python binding

This flushes the buffer through the acoustic model before returning an
intermediate decode. Provides an alternative to finishStream() that doesn't
end the stream.
---
 native_client/coqui-stt.h        | 18 ++++++++++++++++++
 native_client/python/__init__.py | 16 ++++++++++++++++
 native_client/python/impl.i      |  1 +
 native_client/stt.cc             | 34 ++++++++++++++++++++++++++++++++++
 4 files changed, 69 insertions(+)

diff --git a/native_client/coqui-stt.h b/native_client/coqui-stt.h
index 7794bc79..a08e3ab4 100644
--- a/native_client/coqui-stt.h
+++ b/native_client/coqui-stt.h
@@ -326,6 +326,24 @@ Metadata* STT_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
  * @note This method will free the state pointer (@p aSctx).
  */
 STT_EXPORT
 char* STT_FinishStream(StreamingState* aSctx);
 
+/**
+ * @brief Compute the intermediate decoding of an ongoing streaming inference,
+ *        flushing buffers first. This ensures that all data that has been
+ *        streamed so far is included in the result. The buffered data is run
+ *        through the acoustic model before decoding, which makes this call
+ *        more expensive than {@link STT_IntermediateDecode()}. Unlike
+ *        {@link STT_FinishStream()}, it does not end the stream.
+ *
+ * @param aSctx A streaming state pointer returned by {@link STT_CreateStream()}.
+ *
+ * @return The STT intermediate result. The user is responsible for freeing the
+ *         string using {@link STT_FreeString()}.
+ *
+ * @note This method does not free the state pointer (@p aSctx).
+ */
+STT_EXPORT
+char* STT_IntermediateDecodeExpensive(StreamingState* aSctx);
+
 /**
diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py
index e1d37a70..ee5685bd 100644
--- a/native_client/python/__init__.py
+++ b/native_client/python/__init__.py
@@ -281,6 +281,22 @@ class Stream(object):
             )
         return stt.impl.IntermediateDecodeWithMetadata(self._impl, num_results)
 
+    def intermediateDecodeExpensive(self):
+        """
+        Compute the intermediate decoding of an ongoing streaming inference, flushing buffers.
+        This ensures that all data that has been streamed so far is included in the result.
+
+        :return: The STT intermediate result.
+        :type: str
+
+        :throws: RuntimeError if the stream object is not valid
+        """
+        if not self._impl:
+            raise RuntimeError(
+                "Stream object is not valid. Trying to decode an already finished stream?"
+            )
+        return stt.impl.IntermediateDecodeExpensive(self._impl)
+
     def finishStream(self):
         """
         Compute the final decoding of an ongoing streaming inference and return
diff --git a/native_client/python/impl.i b/native_client/python/impl.i
index b9405238..142d2e8d 100644
--- a/native_client/python/impl.i
+++ b/native_client/python/impl.i
@@ -119,6 +119,7 @@ static PyObject *parent_reference() {
 
 %newobject STT_SpeechToText;
 %newobject STT_IntermediateDecode;
+%newobject STT_IntermediateDecodeExpensive;
 %newobject STT_FinishStream;
 %newobject STT_Version;
 %newobject STT_ErrorCodeToErrorMessage;
diff --git a/native_client/stt.cc b/native_client/stt.cc
index 28715ec5..0ea5bb02 100644
--- a/native_client/stt.cc
+++ b/native_client/stt.cc
@@ -79,6 +79,8 @@ struct StreamingState {
   void feedAudioContent(const short* buffer, unsigned int buffer_size);
   char* intermediateDecode() const;
   Metadata* intermediateDecodeWithMetadata(unsigned int num_results) const;
+  char* intermediateDecodeExpensive();
+  void flushBuffer();
   void finalizeStream();
   char* finishStream();
   Metadata* finishStreamWithMetadata(unsigned int num_results);
@@ -143,6 +145,15 @@ StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const
   return model_->decode_metadata(decoder_state_, num_results);
 }
 
+// Pushes everything buffered so far through the acoustic model, then decodes.
+// Not const, unlike intermediateDecode(): flushing changes the stream's buffers.
+char*
+StreamingState::intermediateDecodeExpensive()
+{
+  flushBuffer();
+  return model_->decode(decoder_state_);
+}
+
 char*
 StreamingState::finishStream()
 {
@@ -167,6 +178,23 @@ StreamingState::processAudioWindow(const vector& buf)
   pushMfccBuffer(mfcc);
 }
 
+void
+StreamingState::flushBuffer()
+{
+  // Flush audio buffer
+  processAudioWindow(audio_buffer_);
+
+  // No zero MFCC windows are appended here: that padding marks the end of the
+  // sample, and the stream stays open after a flush.
+
+  // Process the pending partial batch, then empty the buffer so that audio fed
+  // afterwards does not send the same timesteps through the model again.
+  if (batch_buffer_.size() > 0) {
+    processBatch(batch_buffer_, batch_buffer_.size()/model_->mfcc_feats_per_timestep_);
+    batch_buffer_.resize(0);
+  }
+}
+
 void
 StreamingState::finalizeStream()
 {
@@ -465,6 +493,12 @@ STT_IntermediateDecodeWithMetadata(const StreamingState* aSctx,
   return aSctx->intermediateDecodeWithMetadata(aNumResults);
 }
 
+char*
+STT_IntermediateDecodeExpensive(StreamingState* aSctx)
+{
+  return aSctx->intermediateDecodeExpensive();
+}
+
 char*
 STT_FinishStream(StreamingState* aSctx)
 {