Expose letter timings on the API
This commit is contained in:
parent
730ef1b5c8
commit
192e17f2d5
@ -115,7 +115,9 @@ struct StreamingState {
|
|||||||
|
|
||||||
void feedAudioContent(const short* buffer, unsigned int buffer_size);
|
void feedAudioContent(const short* buffer, unsigned int buffer_size);
|
||||||
char* intermediateDecode();
|
char* intermediateDecode();
|
||||||
|
void finalizeStream();
|
||||||
char* finishStream();
|
char* finishStream();
|
||||||
|
Metadata* finishStreamWithMetadata();
|
||||||
|
|
||||||
void processAudioWindow(const vector<float>& buf);
|
void processAudioWindow(const vector<float>& buf);
|
||||||
void processMfccWindow(const vector<float>& buf);
|
void processMfccWindow(const vector<float>& buf);
|
||||||
@ -170,6 +172,28 @@ struct ModelState {
|
|||||||
*/
|
*/
|
||||||
char* decode(vector<float>& logits);
|
char* decode(vector<float>& logits);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||||
|
* CTC decoder with KenLM enabled
|
||||||
|
*
|
||||||
|
* @param logits Flat matrix of logits, of size:
|
||||||
|
* n_frames * batch_size * num_classes
|
||||||
|
*
|
||||||
|
* @return Vector of Output structs directly from the CTC decoder for additional processing.
|
||||||
|
*/
|
||||||
|
vector<Output> decode_raw(vector<float>& logits);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return character-level metadata including letter timings.
|
||||||
|
*
|
||||||
|
* @param logits Flat matrix of logits, of size:
|
||||||
|
* n_frames * batch_size * num_classes
|
||||||
|
*
|
||||||
|
* @return Metadata struct containing MetadataItem structs for each character.
|
||||||
|
* The user is responsible for freeing Metadata and Metadata.items.
|
||||||
|
*/
|
||||||
|
Metadata* decode_metadata(vector<float>& logits);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Do a single inference step in the acoustic model, with:
|
* @brief Do a single inference step in the acoustic model, with:
|
||||||
* input=mfcc
|
* input=mfcc
|
||||||
@ -183,6 +207,9 @@ struct ModelState {
|
|||||||
void infer(const float* mfcc, unsigned int n_frames, vector<float>& output_logits);
|
void infer(const float* mfcc, unsigned int n_frames, vector<float>& output_logits);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
StreamingState* setupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
|
||||||
|
unsigned int aBufferSize, unsigned int aSampleRate);
|
||||||
|
|
||||||
ModelState::ModelState()
|
ModelState::ModelState()
|
||||||
:
|
:
|
||||||
#ifndef USE_TFLITE
|
#ifndef USE_TFLITE
|
||||||
@ -260,22 +287,19 @@ StreamingState::intermediateDecode()
|
|||||||
char*
|
char*
|
||||||
StreamingState::finishStream()
|
StreamingState::finishStream()
|
||||||
{
|
{
|
||||||
// Flush audio buffer
|
finalizeStream();
|
||||||
processAudioWindow(audio_buffer);
|
|
||||||
|
|
||||||
// Add empty mfcc vectors at end of sample
|
|
||||||
for (int i = 0; i < model->n_context; ++i) {
|
|
||||||
addZeroMfccWindow();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process final batch
|
|
||||||
if (batch_buffer.size() > 0) {
|
|
||||||
processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
|
|
||||||
}
|
|
||||||
|
|
||||||
return model->decode(accumulated_logits);
|
return model->decode(accumulated_logits);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Metadata*
|
||||||
|
StreamingState::finishStreamWithMetadata()
|
||||||
|
{
|
||||||
|
finalizeStream();
|
||||||
|
|
||||||
|
return model->decode_metadata(accumulated_logits);
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
StreamingState::processAudioWindow(const vector<float>& buf)
|
StreamingState::processAudioWindow(const vector<float>& buf)
|
||||||
{
|
{
|
||||||
@ -291,6 +315,23 @@ StreamingState::processAudioWindow(const vector<float>& buf)
|
|||||||
free(mfcc);
|
free(mfcc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
StreamingState::finalizeStream()
|
||||||
|
{
|
||||||
|
// Flush audio buffer
|
||||||
|
processAudioWindow(audio_buffer);
|
||||||
|
|
||||||
|
// Add empty mfcc vectors at end of sample
|
||||||
|
for (int i = 0; i < model->n_context; ++i) {
|
||||||
|
addZeroMfccWindow();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process final batch
|
||||||
|
if (batch_buffer.size() > 0) {
|
||||||
|
processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
StreamingState::addZeroMfccWindow()
|
StreamingState::addZeroMfccWindow()
|
||||||
{
|
{
|
||||||
@ -415,6 +456,14 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
|||||||
|
|
||||||
char*
|
char*
|
||||||
ModelState::decode(vector<float>& logits)
|
ModelState::decode(vector<float>& logits)
|
||||||
|
{
|
||||||
|
vector<Output> out = ModelState::decode_raw(logits);
|
||||||
|
|
||||||
|
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<Output>
|
||||||
|
ModelState::decode_raw(vector<float>& logits)
|
||||||
{
|
{
|
||||||
const int cutoff_top_n = 40;
|
const int cutoff_top_n = 40;
|
||||||
const double cutoff_prob = 1.0;
|
const double cutoff_prob = 1.0;
|
||||||
@ -429,7 +478,37 @@ ModelState::decode(vector<float>& logits)
|
|||||||
inputs.data(), n_frames, num_classes, *alphabet, beam_width,
|
inputs.data(), n_frames, num_classes, *alphabet, beam_width,
|
||||||
cutoff_prob, cutoff_top_n, scorer);
|
cutoff_prob, cutoff_top_n, scorer);
|
||||||
|
|
||||||
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
Metadata* ModelState::decode_metadata(vector<float>& logits)
|
||||||
|
{
|
||||||
|
vector<Output> out = decode_raw(logits);
|
||||||
|
|
||||||
|
Metadata* metadata = (Metadata*)malloc(sizeof (Metadata));
|
||||||
|
metadata->num_items = out[0].tokens.size();
|
||||||
|
metadata->items = (MetadataItem*)malloc(sizeof(MetadataItem) * metadata->num_items);
|
||||||
|
|
||||||
|
// Loop through each character
|
||||||
|
for (int i = 0; i < out[0].tokens.size(); ++i) {
|
||||||
|
char* character = (char*)alphabet->StringFromLabel(out[0].tokens[i]).c_str();
|
||||||
|
|
||||||
|
// Note: 1 timestep = 20ms
|
||||||
|
float start_time = static_cast<float>(out[0].timesteps[i] * AUDIO_WIN_STEP);
|
||||||
|
|
||||||
|
MetadataItem item;
|
||||||
|
item.character = character;
|
||||||
|
item.timestep = out[0].timesteps[i];
|
||||||
|
item.start_time = start_time;
|
||||||
|
|
||||||
|
if (item.start_time < 0) {
|
||||||
|
item.start_time = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata->items[i] = item;
|
||||||
|
}
|
||||||
|
|
||||||
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_TFLITE
|
#ifdef USE_TFLITE
|
||||||
@ -660,14 +739,36 @@ DS_SpeechToText(ModelState* aCtx,
|
|||||||
const short* aBuffer,
|
const short* aBuffer,
|
||||||
unsigned int aBufferSize,
|
unsigned int aBufferSize,
|
||||||
unsigned int aSampleRate)
|
unsigned int aSampleRate)
|
||||||
|
{
|
||||||
|
StreamingState* ctx = setupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
|
||||||
|
return DS_FinishStream(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
Metadata*
|
||||||
|
DS_SpeechToTextWithMetadata(ModelState* aCtx,
|
||||||
|
const short* aBuffer,
|
||||||
|
unsigned int aBufferSize,
|
||||||
|
unsigned int aSampleRate)
|
||||||
|
{
|
||||||
|
StreamingState* ctx = setupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
|
||||||
|
return DS_FinishStreamWithMetadata(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
StreamingState*
|
||||||
|
setupStreamAndFeedAudioContent(ModelState* aCtx,
|
||||||
|
const short* aBuffer,
|
||||||
|
unsigned int aBufferSize,
|
||||||
|
unsigned int aSampleRate)
|
||||||
{
|
{
|
||||||
StreamingState* ctx;
|
StreamingState* ctx;
|
||||||
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
|
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
|
||||||
if (status != DS_ERR_OK) {
|
if (status != DS_ERR_OK) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
|
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
|
||||||
return DS_FinishStream(ctx);
|
|
||||||
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
@ -735,6 +836,14 @@ DS_FinishStream(StreamingState* aSctx)
|
|||||||
return str;
|
return str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Metadata*
|
||||||
|
DS_FinishStreamWithMetadata(StreamingState* aSctx)
|
||||||
|
{
|
||||||
|
Metadata* metadata = aSctx->finishStreamWithMetadata();
|
||||||
|
DS_DiscardStream(aSctx);
|
||||||
|
return metadata;
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
DS_DiscardStream(StreamingState* aSctx)
|
DS_DiscardStream(StreamingState* aSctx)
|
||||||
{
|
{
|
||||||
@ -809,6 +918,13 @@ DS_AudioToInputVector(const short* aBuffer,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
DS_FreeMetadata(Metadata* m)
|
||||||
|
{
|
||||||
|
free(m->items);
|
||||||
|
free(m);
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
DS_PrintVersions() {
|
DS_PrintVersions() {
|
||||||
std::cerr << "TensorFlow: " << tf_local_git_version() << std::endl;
|
std::cerr << "TensorFlow: " << tf_local_git_version() << std::endl;
|
||||||
|
@ -15,6 +15,19 @@ struct ModelState;
|
|||||||
|
|
||||||
struct StreamingState;
|
struct StreamingState;
|
||||||
|
|
||||||
|
// Stores each individual character, along with its timing information
|
||||||
|
struct MetadataItem {
|
||||||
|
char* character;
|
||||||
|
int timestep; // Position of the character in units of 20ms
|
||||||
|
float start_time; // Position of the character in seconds
|
||||||
|
};
|
||||||
|
|
||||||
|
// Stores the entire CTC output as an array of character metadata objects
|
||||||
|
struct Metadata {
|
||||||
|
MetadataItem* items;
|
||||||
|
int num_items;
|
||||||
|
};
|
||||||
|
|
||||||
enum DeepSpeech_Error_Codes
|
enum DeepSpeech_Error_Codes
|
||||||
{
|
{
|
||||||
// OK
|
// OK
|
||||||
@ -109,6 +122,25 @@ char* DS_SpeechToText(ModelState* aCtx,
|
|||||||
unsigned int aBufferSize,
|
unsigned int aBufferSize,
|
||||||
unsigned int aSampleRate);
|
unsigned int aSampleRate);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata
|
||||||
|
* about the results.
|
||||||
|
*
|
||||||
|
* @param aCtx The ModelState pointer for the model to use.
|
||||||
|
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
|
||||||
|
* sample rate.
|
||||||
|
* @param aBufferSize The number of samples in the audio signal.
|
||||||
|
* @param aSampleRate The sample-rate of the audio signal.
|
||||||
|
*
|
||||||
|
* @return Outputs a struct of individual letters along with their timing information.
|
||||||
|
* The user is responsible for freeing Metadata and Metadata.items. Returns NULL on error.
|
||||||
|
*/
|
||||||
|
DEEPSPEECH_EXPORT
|
||||||
|
Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
|
||||||
|
const short* aBuffer,
|
||||||
|
unsigned int aBufferSize,
|
||||||
|
unsigned int aSampleRate);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Create a new streaming inference state. The streaming state returned
|
* @brief Create a new streaming inference state. The streaming state returned
|
||||||
* by this function can then be passed to {@link DS_FeedAudioContent()}
|
* by this function can then be passed to {@link DS_FeedAudioContent()}
|
||||||
@ -170,6 +202,20 @@ char* DS_IntermediateDecode(StreamingState* aSctx);
|
|||||||
DEEPSPEECH_EXPORT
|
DEEPSPEECH_EXPORT
|
||||||
char* DS_FinishStream(StreamingState* aSctx);
|
char* DS_FinishStream(StreamingState* aSctx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Signal the end of an audio signal to an ongoing streaming
|
||||||
|
* inference, returns per-letter metadata.
|
||||||
|
*
|
||||||
|
* @param aSctx A streaming state pointer returned by {@link DS_SetupStream()}.
|
||||||
|
*
|
||||||
|
* @return Outputs a struct of individual letters along with their timing information.
|
||||||
|
* The user is responsible for freeing Metadata and Metadata.items. Returns NULL on error.
|
||||||
|
*
|
||||||
|
* @note This method will free the state pointer (@p aSctx).
|
||||||
|
*/
|
||||||
|
DEEPSPEECH_EXPORT
|
||||||
|
Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Destroy a streaming state without decoding the computed logits. This
|
* @brief Destroy a streaming state without decoding the computed logits. This
|
||||||
* can be used if you no longer need the result of an ongoing streaming
|
* can be used if you no longer need the result of an ongoing streaming
|
||||||
@ -213,6 +259,13 @@ void DS_AudioToInputVector(const short* aBuffer,
|
|||||||
int* aNFrames = NULL,
|
int* aNFrames = NULL,
|
||||||
int* aFrameLen = NULL);
|
int* aFrameLen = NULL);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Free memory allocated for metadata information.
|
||||||
|
*/
|
||||||
|
|
||||||
|
DEEPSPEECH_EXPORT
|
||||||
|
void DS_FreeMetadata(Metadata* m);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Print version of this library and of the linked TensorFlow library.
|
* @brief Print version of this library and of the linked TensorFlow library.
|
||||||
*/
|
*/
|
||||||
|
Loading…
x
Reference in New Issue
Block a user