Expose letter timings on the API
This commit is contained in:
parent
730ef1b5c8
commit
192e17f2d5
@ -115,7 +115,9 @@ struct StreamingState {
|
||||
|
||||
void feedAudioContent(const short* buffer, unsigned int buffer_size);
|
||||
char* intermediateDecode();
|
||||
void finalizeStream();
|
||||
char* finishStream();
|
||||
Metadata* finishStreamWithMetadata();
|
||||
|
||||
void processAudioWindow(const vector<float>& buf);
|
||||
void processMfccWindow(const vector<float>& buf);
|
||||
@ -170,6 +172,28 @@ struct ModelState {
|
||||
*/
|
||||
char* decode(vector<float>& logits);
|
||||
|
||||
/**
|
||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||
* CTC decoder with KenLM enabled
|
||||
*
|
||||
* @param logits Flat matrix of logits, of size:
|
||||
* n_frames * batch_size * num_classes
|
||||
*
|
||||
* @return Vector of Output structs directly from the CTC decoder for additional processing.
|
||||
*/
|
||||
vector<Output> decode_raw(vector<float>& logits);
|
||||
|
||||
/**
|
||||
* @brief Return character-level metadata including letter timings.
|
||||
*
|
||||
* @param logits Flat matrix of logits, of size:
|
||||
* n_frames * batch_size * num_classes
|
||||
*
|
||||
* @return Metadata struct containing MetadataItem structs for each character.
|
||||
* The user is responsible for freeing Metadata and Metadata.items.
|
||||
*/
|
||||
Metadata* decode_metadata(vector<float>& logits);
|
||||
|
||||
/**
|
||||
* @brief Do a single inference step in the acoustic model, with:
|
||||
* input=mfcc
|
||||
@ -183,6 +207,9 @@ struct ModelState {
|
||||
void infer(const float* mfcc, unsigned int n_frames, vector<float>& output_logits);
|
||||
};
|
||||
|
||||
StreamingState* setupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
|
||||
unsigned int aBufferSize, unsigned int aSampleRate);
|
||||
|
||||
ModelState::ModelState()
|
||||
:
|
||||
#ifndef USE_TFLITE
|
||||
@ -260,22 +287,19 @@ StreamingState::intermediateDecode()
|
||||
char*
|
||||
StreamingState::finishStream()
|
||||
{
|
||||
// Flush audio buffer
|
||||
processAudioWindow(audio_buffer);
|
||||
|
||||
// Add empty mfcc vectors at end of sample
|
||||
for (int i = 0; i < model->n_context; ++i) {
|
||||
addZeroMfccWindow();
|
||||
}
|
||||
|
||||
// Process final batch
|
||||
if (batch_buffer.size() > 0) {
|
||||
processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
|
||||
}
|
||||
finalizeStream();
|
||||
|
||||
return model->decode(accumulated_logits);
|
||||
}
|
||||
|
||||
Metadata*
StreamingState::finishStreamWithMetadata()
{
  // Flush remaining audio, pad the sample, and run the final batch.
  finalizeStream();

  // Decode accumulated logits into per-character metadata. The caller
  // owns the returned Metadata (free with DS_FreeMetadata).
  return model->decode_metadata(accumulated_logits);
}
|
||||
|
||||
void
|
||||
StreamingState::processAudioWindow(const vector<float>& buf)
|
||||
{
|
||||
@ -291,6 +315,23 @@ StreamingState::processAudioWindow(const vector<float>& buf)
|
||||
free(mfcc);
|
||||
}
|
||||
|
||||
void
|
||||
StreamingState::finalizeStream()
|
||||
{
|
||||
// Flush audio buffer
|
||||
processAudioWindow(audio_buffer);
|
||||
|
||||
// Add empty mfcc vectors at end of sample
|
||||
for (int i = 0; i < model->n_context; ++i) {
|
||||
addZeroMfccWindow();
|
||||
}
|
||||
|
||||
// Process final batch
|
||||
if (batch_buffer.size() > 0) {
|
||||
processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
StreamingState::addZeroMfccWindow()
|
||||
{
|
||||
@ -415,6 +456,14 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
||||
|
||||
char*
|
||||
ModelState::decode(vector<float>& logits)
|
||||
{
|
||||
vector<Output> out = ModelState::decode_raw(logits);
|
||||
|
||||
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
||||
}
|
||||
|
||||
vector<Output>
|
||||
ModelState::decode_raw(vector<float>& logits)
|
||||
{
|
||||
const int cutoff_top_n = 40;
|
||||
const double cutoff_prob = 1.0;
|
||||
@ -429,7 +478,37 @@ ModelState::decode(vector<float>& logits)
|
||||
inputs.data(), n_frames, num_classes, *alphabet, beam_width,
|
||||
cutoff_prob, cutoff_top_n, scorer);
|
||||
|
||||
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
 * @brief Decode logits into per-character metadata (letter + timing).
 *
 * @param logits Flat matrix of logits, of size:
 *               n_frames * batch_size * num_classes
 *
 * @return Metadata struct with one MetadataItem per decoded character.
 *         The caller is responsible for freeing Metadata and
 *         Metadata.items (see DS_FreeMetadata).
 */
Metadata* ModelState::decode_metadata(vector<float>& logits)
{
  vector<Output> out = decode_raw(logits);

  Metadata* metadata = (Metadata*)malloc(sizeof(Metadata));
  metadata->num_items = out[0].tokens.size();
  metadata->items = (MetadataItem*)malloc(sizeof(MetadataItem) * metadata->num_items);

  // Loop through each character
  for (size_t i = 0; i < out[0].tokens.size(); ++i) {
    MetadataItem item;

    // Duplicate the character text: StringFromLabel returns a temporary
    // std::string, so keeping its c_str() pointer (as the previous code
    // did) left item.character dangling as soon as the temporary died.
    // NOTE(review): these strdup'd strings are not released by
    // DS_FreeMetadata — confirm ownership/free strategy with callers.
    item.character = strdup(alphabet->StringFromLabel(out[0].tokens[i]).c_str());
    item.timestep = out[0].timesteps[i];

    // Note: 1 timestep = 20ms
    item.start_time = static_cast<float>(out[0].timesteps[i] * AUDIO_WIN_STEP);

    // Clamp: the decoder can report a slightly negative start offset.
    if (item.start_time < 0) {
      item.start_time = 0;
    }

    metadata->items[i] = item;
  }

  return metadata;
}
|
||||
|
||||
#ifdef USE_TFLITE
|
||||
@ -660,14 +739,36 @@ DS_SpeechToText(ModelState* aCtx,
|
||||
const short* aBuffer,
|
||||
unsigned int aBufferSize,
|
||||
unsigned int aSampleRate)
|
||||
{
|
||||
StreamingState* ctx = setupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
|
||||
return DS_FinishStream(ctx);
|
||||
}
|
||||
|
||||
Metadata*
DS_SpeechToTextWithMetadata(ModelState* aCtx,
                            const short* aBuffer,
                            unsigned int aBufferSize,
                            unsigned int aSampleRate)
{
  // One-shot STT with per-letter metadata: set up a stream, feed the
  // whole buffer, then finish it.
  StreamingState* ctx = setupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
  // Setup can fail and return nullptr; honor the documented contract of
  // returning NULL on error instead of dereferencing a null stream.
  if (ctx == nullptr) {
    return nullptr;
  }
  return DS_FinishStreamWithMetadata(ctx);
}
|
||||
|
||||
StreamingState*
setupStreamAndFeedAudioContent(ModelState* aCtx,
                               const short* aBuffer,
                               unsigned int aBufferSize,
                               unsigned int aSampleRate)
{
  // Create a streaming state and feed the entire buffer into it. The
  // caller takes ownership of the returned state (finish or discard it).
  StreamingState* ctx;
  int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
  if (status != DS_ERR_OK) {
    return nullptr;
  }

  DS_FeedAudioContent(ctx, aBuffer, aBufferSize);

  // Fixed: a stray `return DS_FinishStream(ctx);` here returned a char*
  // from a StreamingState* function and made this return unreachable.
  return ctx;
}
|
||||
|
||||
int
|
||||
@ -735,6 +836,14 @@ DS_FinishStream(StreamingState* aSctx)
|
||||
return str;
|
||||
}
|
||||
|
||||
Metadata*
DS_FinishStreamWithMetadata(StreamingState* aSctx)
{
  // Finalize the stream and decode into per-letter metadata.
  Metadata* metadata = aSctx->finishStreamWithMetadata();
  // The stream is consumed by finishing it: release the state now.
  DS_DiscardStream(aSctx);
  return metadata;
}
|
||||
|
||||
void
|
||||
DS_DiscardStream(StreamingState* aSctx)
|
||||
{
|
||||
@ -809,6 +918,13 @@ DS_AudioToInputVector(const short* aBuffer,
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
DS_FreeMetadata(Metadata* m)
|
||||
{
|
||||
free(m->items);
|
||||
free(m);
|
||||
}
|
||||
|
||||
void
|
||||
DS_PrintVersions() {
|
||||
std::cerr << "TensorFlow: " << tf_local_git_version() << std::endl;
|
||||
|
@ -15,6 +15,19 @@ struct ModelState;
|
||||
|
||||
struct StreamingState;
|
||||
|
||||
// Stores each individual character, along with its timing information
struct MetadataItem {
  // Text of this character. NOTE(review): DS_FreeMetadata does not free
  // this pointer — confirm ownership with the producing decoder.
  char* character;
  int timestep;     // Position of the character in units of 20ms
  float start_time; // Position of the character in seconds
};
|
||||
|
||||
// Stores the entire CTC output as an array of character metadata objects
struct Metadata {
  MetadataItem* items; // Array of num_items entries; released by DS_FreeMetadata
  int num_items;       // Number of entries in items
};
|
||||
|
||||
enum DeepSpeech_Error_Codes
|
||||
{
|
||||
// OK
|
||||
@ -109,6 +122,25 @@ char* DS_SpeechToText(ModelState* aCtx,
|
||||
unsigned int aBufferSize,
|
||||
unsigned int aSampleRate);
|
||||
|
||||
/**
|
||||
* @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata
|
||||
* about the results.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model to use.
|
||||
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate
|
||||
* sample rate.
|
||||
* @param aBufferSize The number of samples in the audio signal.
|
||||
* @param aSampleRate The sample-rate of the audio signal.
|
||||
*
|
||||
* @return Outputs a struct of individual letters along with their timing information.
|
||||
* The user is responsible for freeing Metadata and Metadata.items. Returns NULL on error.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
|
||||
const short* aBuffer,
|
||||
unsigned int aBufferSize,
|
||||
unsigned int aSampleRate);
|
||||
|
||||
/**
|
||||
* @brief Create a new streaming inference state. The streaming state returned
|
||||
* by this function can then be passed to {@link DS_FeedAudioContent()}
|
||||
@ -170,6 +202,20 @@ char* DS_IntermediateDecode(StreamingState* aSctx);
|
||||
DEEPSPEECH_EXPORT
|
||||
char* DS_FinishStream(StreamingState* aSctx);
|
||||
|
||||
/**
|
||||
* @brief Signal the end of an audio signal to an ongoing streaming
|
||||
* inference, returns per-letter metadata.
|
||||
*
|
||||
* @param aSctx A streaming state pointer returned by {@link DS_SetupStream()}.
|
||||
*
|
||||
* @return Outputs a struct of individual letters along with their timing information.
|
||||
* The user is responsible for freeing Metadata and Metadata.items. Returns NULL on error.
|
||||
*
|
||||
* @note This method will free the state pointer (@p aSctx).
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx);
|
||||
|
||||
/**
|
||||
* @brief Destroy a streaming state without decoding the computed logits. This
|
||||
* can be used if you no longer need the result of an ongoing streaming
|
||||
@ -213,6 +259,13 @@ void DS_AudioToInputVector(const short* aBuffer,
|
||||
int* aNFrames = NULL,
|
||||
int* aFrameLen = NULL);
|
||||
|
||||
/**
 * @brief Free memory allocated for metadata information: releases both
 *        the Metadata.items array and the Metadata struct itself.
 */
|
||||
|
||||
DEEPSPEECH_EXPORT
|
||||
void DS_FreeMetadata(Metadata* m);
|
||||
|
||||
/**
|
||||
* @brief Print version of this library and of the linked TensorFlow library.
|
||||
*/
|
||||
|
Loading…
x
Reference in New Issue
Block a user