Expose letter timings on the API

dabinat 2019-03-21 15:50:02 -07:00
parent 730ef1b5c8
commit 192e17f2d5
2 changed files with 183 additions and 14 deletions

native_client/deepspeech.cc

@@ -115,7 +115,9 @@ struct StreamingState {
  void feedAudioContent(const short* buffer, unsigned int buffer_size);
  char* intermediateDecode();
  void finalizeStream();
  char* finishStream();
  Metadata* finishStreamWithMetadata();

  void processAudioWindow(const vector<float>& buf);
  void processMfccWindow(const vector<float>& buf);
@@ -170,6 +172,28 @@ struct ModelState {
   */
  char* decode(vector<float>& logits);

  /**
   * @brief Perform decoding of the logits, using the basic CTC decoder or
   *        the CTC decoder with KenLM enabled
   *
   * @param logits Flat matrix of logits, of size:
   *               n_frames * batch_size * num_classes
   *
   * @return Vector of Output structs directly from the CTC decoder for
   *         additional processing.
   */
  vector<Output> decode_raw(vector<float>& logits);

  /**
   * @brief Return character-level metadata including letter timings.
   *
   * @param logits Flat matrix of logits, of size:
   *               n_frames * batch_size * num_classes
   *
   * @return Metadata struct containing MetadataItem structs for each character.
   *         The user is responsible for freeing Metadata and Metadata.items.
   */
  Metadata* decode_metadata(vector<float>& logits);

  /**
   * @brief Do a single inference step in the acoustic model, with:
   *        input=mfcc
@@ -183,6 +207,9 @@ struct ModelState {
  void infer(const float* mfcc, unsigned int n_frames, vector<float>& output_logits);
};

StreamingState* setupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
                                               unsigned int aBufferSize, unsigned int aSampleRate);
ModelState::ModelState()
:
#ifndef USE_TFLITE
@@ -260,22 +287,19 @@ StreamingState::intermediateDecode()
char*
StreamingState::finishStream()
{
-  // Flush audio buffer
-  processAudioWindow(audio_buffer);
-
-  // Add empty mfcc vectors at end of sample
-  for (int i = 0; i < model->n_context; ++i) {
-    addZeroMfccWindow();
-  }
-
-  // Process final batch
-  if (batch_buffer.size() > 0) {
-    processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
-  }
+  finalizeStream();
  return model->decode(accumulated_logits);
}

Metadata*
StreamingState::finishStreamWithMetadata()
{
  finalizeStream();
  return model->decode_metadata(accumulated_logits);
}
void
StreamingState::processAudioWindow(const vector<float>& buf)
{
@@ -291,6 +315,23 @@ StreamingState::processAudioWindow(const vector<float>& buf)
  free(mfcc);
}
void
StreamingState::finalizeStream()
{
  // Flush audio buffer
  processAudioWindow(audio_buffer);

  // Add empty mfcc vectors at end of sample
  for (int i = 0; i < model->n_context; ++i) {
    addZeroMfccWindow();
  }

  // Process final batch
  if (batch_buffer.size() > 0) {
    processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
  }
}
void
StreamingState::addZeroMfccWindow()
{
@@ -415,6 +456,14 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
char*
ModelState::decode(vector<float>& logits)
{
  vector<Output> out = ModelState::decode_raw(logits);
  return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
}

vector<Output>
ModelState::decode_raw(vector<float>& logits)
{
  const int cutoff_top_n = 40;
  const double cutoff_prob = 1.0;
@@ -429,7 +478,37 @@ ModelState::decode(vector<float>& logits)
      inputs.data(), n_frames, num_classes, *alphabet, beam_width,
      cutoff_prob, cutoff_top_n, scorer);
-  return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
+  return out;
}
Metadata*
ModelState::decode_metadata(vector<float>& logits)
{
  vector<Output> out = decode_raw(logits);

  Metadata* metadata = (Metadata*)malloc(sizeof(Metadata));
  metadata->num_items = out[0].tokens.size();
  metadata->items = (MetadataItem*)malloc(sizeof(MetadataItem) * metadata->num_items);

  // Loop through each character
  for (size_t i = 0; i < out[0].tokens.size(); ++i) {
    // Copy the character: StringFromLabel() returns a temporary std::string,
    // so holding on to its c_str() pointer would leave item.character dangling
    char* character = strdup(alphabet->StringFromLabel(out[0].tokens[i]).c_str());

    // Note: 1 timestep = 20ms
    float start_time = static_cast<float>(out[0].timesteps[i] * AUDIO_WIN_STEP);

    MetadataItem item;
    item.character = character;
    item.timestep = out[0].timesteps[i];
    item.start_time = start_time;
    if (item.start_time < 0) {
      item.start_time = 0;
    }

    metadata->items[i] = item;
  }
  return metadata;
}
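
For intuition, a worked example of the timestep-to-seconds conversion above. The 0.02 s figure comes from the "1 timestep = 20ms" note in the code; the actual AUDIO_WIN_STEP constant is defined elsewhere and not visible in this diff:

// Hypothetical: a character whose first CTC emission lands at timestep 42,
// assuming AUDIO_WIN_STEP == 0.02f (20 ms per timestep)
float start_time = 42 * 0.02f; // the letter starts 0.84 s into the audio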
#ifdef USE_TFLITE
@@ -660,14 +739,36 @@ DS_SpeechToText(ModelState* aCtx,
                const short* aBuffer,
                unsigned int aBufferSize,
                unsigned int aSampleRate)
{
  StreamingState* ctx = setupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
  if (ctx == nullptr) {
    // Setup failed; don't hand a null stream to DS_FinishStream()
    return nullptr;
  }
  return DS_FinishStream(ctx);
}
Metadata*
DS_SpeechToTextWithMetadata(ModelState* aCtx,
                            const short* aBuffer,
                            unsigned int aBufferSize,
                            unsigned int aSampleRate)
{
  StreamingState* ctx = setupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
  if (ctx == nullptr) {
    // Matches the documented NULL-on-error contract
    return nullptr;
  }
  return DS_FinishStreamWithMetadata(ctx);
}
StreamingState*
setupStreamAndFeedAudioContent(ModelState* aCtx,
                               const short* aBuffer,
                               unsigned int aBufferSize,
                               unsigned int aSampleRate)
{
  StreamingState* ctx;
  int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
  if (status != DS_ERR_OK) {
    return nullptr;
  }
  DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
-  return DS_FinishStream(ctx);
+  return ctx;
}
int
@@ -735,6 +836,14 @@ DS_FinishStream(StreamingState* aSctx)
  return str;
}
Metadata*
DS_FinishStreamWithMetadata(StreamingState* aSctx)
{
  Metadata* metadata = aSctx->finishStreamWithMetadata();
  DS_DiscardStream(aSctx);
  return metadata;
}
void
DS_DiscardStream(StreamingState* aSctx)
{
@@ -809,6 +918,13 @@ DS_AudioToInputVector(const short* aBuffer,
}
}
void
DS_FreeMetadata(Metadata* m)
{
  // Free the strdup'd character strings, then the array and the struct itself
  for (int i = 0; i < m->num_items; ++i) {
    free(m->items[i].character);
  }
  free(m->items);
  free(m);
}
void
DS_PrintVersions() {
  std::cerr << "TensorFlow: " << tf_local_git_version() << std::endl;

native_client/deepspeech.h

@@ -15,6 +15,19 @@ struct ModelState;
struct StreamingState;
// Stores each individual character, along with its timing information
struct MetadataItem {
  char* character;
  int timestep; // Position of the character in units of 20ms
  float start_time; // Position of the character in seconds
};

// Stores the entire CTC output as an array of character metadata objects
struct Metadata {
  MetadataItem* items;
  int num_items;
};
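
To make the layout concrete, a hedged sketch of what a populated Metadata could look like for the word "hi" (the timesteps and times are invented for illustration; real values come from the CTC decoder):

// Illustrative values only; in real output, items and each character
// string are heap-allocated and must be released via DS_FreeMetadata()
MetadataItem items[2] = {
  { (char*)"h", 10, 0.20f }, // first emitted at timestep 10 -> 0.20 s
  { (char*)"i", 14, 0.28f }, // first emitted at timestep 14 -> 0.28 s
};
Metadata example = { items, 2 };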
enum DeepSpeech_Error_Codes
{
  // OK
@@ -109,6 +122,25 @@ char* DS_SpeechToText(ModelState* aCtx,
                      unsigned int aBufferSize,
                      unsigned int aSampleRate);
/**
 * @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata
 *        about the results.
 *
 * @param aCtx The ModelState pointer for the model to use.
 * @param aBuffer A 16-bit, mono raw audio signal at the appropriate
 *                sample rate.
 * @param aBufferSize The number of samples in the audio signal.
 * @param aSampleRate The sample-rate of the audio signal.
 *
 * @return Metadata struct containing each individual letter along with its
 *         timing information. The user is responsible for freeing Metadata
 *         and Metadata.items. Returns NULL on error.
 */
DEEPSPEECH_EXPORT
Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
                                      const short* aBuffer,
                                      unsigned int aBufferSize,
                                      unsigned int aSampleRate);
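
As a usage sketch (not part of this commit): a client holding a loaded ModelState* could retrieve and print letter timings as below. The helper name and the 16000 Hz sample rate are assumptions for illustration:

#include <stdio.h>
#include "deepspeech.h"

// Hypothetical helper: run STT on a 16-bit mono buffer and print each
// recognized character with its start time in seconds
void print_letter_timings(ModelState* model,
                          const short* samples,
                          unsigned int sample_count)
{
  Metadata* result = DS_SpeechToTextWithMetadata(model, samples,
                                                 sample_count, 16000);
  if (result == NULL) {
    fprintf(stderr, "inference failed\n");
    return;
  }
  for (int i = 0; i < result->num_items; ++i) {
    printf("%s\t%.2f s\n", result->items[i].character,
           result->items[i].start_time);
  }
  DS_FreeMetadata(result); // releases the items array and the struct itself
}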
/**
 * @brief Create a new streaming inference state. The streaming state returned
 *        by this function can then be passed to {@link DS_FeedAudioContent()}
@@ -170,6 +202,20 @@ char* DS_IntermediateDecode(StreamingState* aSctx);
DEEPSPEECH_EXPORT
char* DS_FinishStream(StreamingState* aSctx);
/**
 * @brief Signal the end of an audio signal to an ongoing streaming
 *        inference and return per-letter metadata.
 *
 * @param aSctx A streaming state pointer returned by {@link DS_SetupStream()}.
 *
 * @return Metadata struct containing each individual letter along with its
 *         timing information. The user is responsible for freeing Metadata
 *         and Metadata.items. Returns NULL on error.
 *
 * @note This method will free the state pointer (@p aSctx).
 */
DEEPSPEECH_EXPORT
Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx);
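
A corresponding streaming sketch (again illustrative, not from this commit): feed audio in chunks, then collect the per-letter metadata. The chunk size and the 16000 Hz rate are arbitrary assumptions:

#include <stdio.h>
#include "deepspeech.h"

// Hypothetical streaming consumer. Assumes a loaded ModelState* and a
// complete 16-bit mono buffer that is fed to the decoder in fixed chunks
int stream_letter_timings(ModelState* model,
                          const short* samples,
                          unsigned int sample_count)
{
  StreamingState* sctx;
  if (DS_SetupStream(model, 0, 16000, &sctx) != DS_ERR_OK) {
    return 1;
  }
  const unsigned int chunk = 1024;
  for (unsigned int off = 0; off < sample_count; off += chunk) {
    unsigned int n = (sample_count - off < chunk) ? (sample_count - off) : chunk;
    DS_FeedAudioContent(sctx, samples + off, n);
  }
  // Also frees the streaming state, per the @note above
  Metadata* m = DS_FinishStreamWithMetadata(sctx);
  if (m == NULL) {
    return 1;
  }
  for (int i = 0; i < m->num_items; ++i) {
    printf("%s @ %.2f s\n", m->items[i].character, m->items[i].start_time);
  }
  DS_FreeMetadata(m);
  return 0;
}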
/**
 * @brief Destroy a streaming state without decoding the computed logits. This
 *        can be used if you no longer need the result of an ongoing streaming
@@ -213,6 +259,13 @@ void DS_AudioToInputVector(const short* aBuffer,
                           int* aNFrames = NULL,
                           int* aFrameLen = NULL);
/**
 * @brief Free memory allocated for metadata information.
 */
DEEPSPEECH_EXPORT
void DS_FreeMetadata(Metadata* m);
/**
 * @brief Print version of this library and of the linked TensorFlow library.
 */