Improve API naming around Metadata objects

Reuben Morais 2020-02-25 12:29:18 +01:00
parent e1fec4e818
commit 69bd032605
6 changed files with 75 additions and 96 deletions

View File

@@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
 }
 
 std::vector<Output>
-DecoderState::decode() const
+DecoderState::decode(size_t num_results) const
 {
   std::vector<PathTrie*> prefixes_copy = prefixes_;
   std::unordered_map<const PathTrie*, float> scores;
@@ -181,14 +181,12 @@ DecoderState::decode() const
   }
 
   using namespace std::placeholders;
 
-  size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
+  size_t num_returned = std::min(prefixes_copy.size(), num_results);
   std::partial_sort(prefixes_copy.begin(),
-                    prefixes_copy.begin() + num_prefixes,
+                    prefixes_copy.begin() + num_returned,
                     prefixes_copy.end(),
                     std::bind(prefix_compare_external, _1, _2, scores));
 
-  size_t num_returned = std::min(num_prefixes, beam_size_);
   std::vector<Output> outputs;
   outputs.reserve(num_returned);

View File

@@ -61,12 +61,15 @@ public:
             int class_dim);
 
   /* Get transcription from current decoder state
    *
+   * Parameters:
+   *     num_results: Number of beams to return.
+   *
    * Return:
    *     A vector where each element is a pair of score and decoding result,
    *     in descending order.
   */
-  std::vector<Output> decode() const;
+  std::vector<Output> decode(size_t num_results=1) const;
 };

View File

@@ -60,7 +60,7 @@ using std::vector;
    When batch_buffer is full, we do a single step through the acoustic model
    and accumulate the intermediate decoding state in the DecoderState structure.
 
-   When finishStream() is called, we return the corresponding transcription from
+   When finishStream() is called, we return the corresponding transcript from
    the current decoder state.
 */
 struct StreamingState {
@@ -80,7 +80,7 @@ struct StreamingState {
   char* intermediateDecode() const;
   void finalizeStream();
   char* finishStream();
-  Result* finishStreamWithMetadata(unsigned int num_results);
+  Metadata* finishStreamWithMetadata(unsigned int num_results);
 
   void processAudioWindow(const vector<float>& buf);
   void processMfccWindow(const vector<float>& buf);
@@ -143,7 +143,7 @@ StreamingState::finishStream()
   return model_->decode(decoder_state_);
 }
 
-Result*
+Metadata*
 StreamingState::finishStreamWithMetadata(unsigned int num_results)
 {
   finalizeStream();
@@ -411,11 +411,11 @@ DS_FinishStream(StreamingState* aSctx)
   return str;
 }
 
-Result*
+Metadata*
 DS_FinishStreamWithMetadata(StreamingState* aSctx,
                             unsigned int aNumResults)
 {
-  Result* result = aSctx->finishStreamWithMetadata(aNumResults);
+  Metadata* result = aSctx->finishStreamWithMetadata(aNumResults);
   DS_FreeStream(aSctx);
   return result;
 }
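
The streaming call sequence itself is unchanged by this commit; only the type returned at the end is renamed. A rough sketch of that sequence follows. The DS_CreateStream and DS_FeedAudioContent signatures are assumed from the public header of this era and are not shown in this diff:

    // Sketch of a streaming client (model creation and error handling omitted).
    StreamingState* stream = nullptr;
    if (DS_CreateStream(ctx, &stream) == 0) {             // assumed signature
      DS_FeedAudioContent(stream, samples, num_samples);  // assumed signature
      // Decodes, frees `stream` internally, and returns up to 4 candidates.
      Metadata* result = DS_FinishStreamWithMetadata(stream, 4);
      // ... inspect result->transcripts ...
      DS_FreeMetadata(result);
    }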
@@ -443,7 +443,7 @@ DS_SpeechToText(ModelState* aCtx,
   return DS_FinishStream(ctx);
 }
 
-Result*
+Metadata*
 DS_SpeechToTextWithMetadata(ModelState* aCtx,
                             const short* aBuffer,
                             unsigned int aBufferSize,
@@ -463,30 +463,16 @@ void
 DS_FreeMetadata(Metadata* m)
 {
   if (m) {
-    for (int i = 0; i < m->num_items; ++i) {
-      free(m->items[i].character);
-    }
-    delete[] m->items;
-    delete m;
-  }
-}
-
-void
-DS_FreeResult(Result* r)
-{
-  if (r) {
-    for (int i = 0; i < r->num_transcriptions; ++i) {
-      Metadata* m = &r->transcriptions[i];
-      for (int j = 0; j < m->num_items; ++j) {
-        free(m->items[j].character);
+    for (int i = 0; i < m->num_transcripts; ++i) {
+      for (int j = 0; j < m->transcripts[i].num_tokens; ++j) {
+        free(m->transcripts[i].tokens[j].text);
       }
-      delete[] m->items;
+      delete[] m->transcripts[i].tokens;
     }
-    delete[] r->transcriptions;
-    delete r;
+    delete[] m->transcripts;
+    delete m;
  }
 }
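
With DS_FreeResult gone (its declaration is removed from the header below), the old and new object trees map one-to-one. A hypothetical before/after for client code, not taken from this commit:

    // Before: Result -> Metadata -> MetadataItem, released with DS_FreeResult(r):
    //   r->transcriptions[i].items[j].character
    // After: Metadata -> CandidateTranscript -> TokenMetadata, released with DS_FreeMetadata(m):
    Metadata* m = DS_SpeechToTextWithMetadata(ctx, buffer, buffer_size, 1);
    printf("%s\n", m->transcripts[0].tokens[0].text);  // was r->transcriptions[0].items[0].character (bounds checks omitted)
    DS_FreeMetadata(m);  // one call now frees the tokens, the transcripts and m itself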

View File

@@ -20,43 +20,44 @@ typedef struct ModelState ModelState;
 typedef struct StreamingState StreamingState;
 
 /**
- * @brief Stores each individual character, along with its timing information
+ * @brief Stores text of an individual token, along with its timing information
  */
-typedef struct MetadataItem {
-  /** The character generated for transcription */
-  char* character;
+typedef struct TokenMetadata {
+  /** The text corresponding to this token */
+  char* text;
 
-  /** Position of the character in units of 20ms */
+  /** Position of the token in units of 20ms */
   int timestep;
 
-  /** Position of the character in seconds */
+  /** Position of the token in seconds */
   float start_time;
-} MetadataItem;
+} TokenMetadata;
 
 /**
- * @brief Stores the entire CTC output as an array of character metadata objects
+ * @brief A single transcript computed by the model, including a confidence
+ * value and the metadata for its constituent tokens.
  */
-typedef struct Metadata {
-  /** List of items */
-  MetadataItem* items;
-  /** Size of the list of items */
-  int num_items;
+typedef struct CandidateTranscript {
+  /** Array of TokenMetadata objects */
+  TokenMetadata* tokens;
+  /** Size of the tokens array */
+  int num_tokens;
   /** Approximated confidence value for this transcription. This is roughly the
    * sum of the acoustic model logit values for each timestep/character that
    * contributed to the creation of this transcription.
    */
   double confidence;
-} Metadata;
+} CandidateTranscript;
 
 /**
- * @brief Stores Metadata structs for each alternative transcription
+ * @brief An array of CandidateTranscript objects computed by the model
  */
-typedef struct Result {
-  /** List of transcriptions */
-  Metadata* transcriptions;
-  /** Size of the list of transcriptions */
-  int num_transcriptions;
-} Result;
+typedef struct Metadata {
+  /** Array of CandidateTranscript objects */
+  CandidateTranscript* transcripts;
+  /** Size of the transcripts array */
+  int num_transcripts;
+} Metadata;
 
 enum DeepSpeech_Error_Codes
 {
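
Read together, the renamed structs form the ownership chain Metadata -> CandidateTranscript[] -> TokenMetadata[]. A sketch of a consumer follows; the model and buffer setup are elided, and DS_SpeechToTextWithMetadata is declared further down in this header:

    // Hypothetical consumer of the renamed structs (ctx/buffer setup elided).
    Metadata* result = DS_SpeechToTextWithMetadata(ctx, buffer, buffer_size, 3);
    for (int i = 0; result && i < result->num_transcripts; ++i) {
      const CandidateTranscript* t = &result->transcripts[i];
      printf("candidate %d, confidence %.3f\n", i, t->confidence);
      for (int j = 0; j < t->num_tokens; ++j) {
        printf("  '%s' at %.2fs (timestep %d)\n",
               t->tokens[j].text, t->tokens[j].start_time, t->tokens[j].timestep);
      }
    }
    DS_FreeMetadata(result);  // safe to call with NULL, per the guard in deepspeech.cc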
@@ -197,16 +198,16 @@ char* DS_SpeechToText(ModelState* aCtx,
  * @param aBuffer A 16-bit, mono raw audio signal at the appropriate
  *                sample rate (matching what the model was trained on).
  * @param aBufferSize The number of samples in the audio signal.
- * @param aNumResults The number of alternative transcriptions to return.
+ * @param aNumResults The number of candidate transcripts to return.
  *
  * @return Outputs a struct of individual letters along with their timing information.
  *         The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
  */
 DEEPSPEECH_EXPORT
-Result* DS_SpeechToTextWithMetadata(ModelState* aCtx,
-                                    const short* aBuffer,
-                                    unsigned int aBufferSize,
-                                    unsigned int aNumResults);
+Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
+                                      const short* aBuffer,
+                                      unsigned int aBufferSize,
+                                      unsigned int aNumResults);
 
 /**
  * @brief Create a new streaming inference state. The streaming state returned
@@ -266,7 +267,7 @@ char* DS_FinishStream(StreamingState* aSctx);
  * inference, returns per-letter metadata.
  *
  * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
- * @param aNumResults The number of alternative transcriptions to return.
+ * @param aNumResults The number of candidate transcripts to return.
  *
  * @return Outputs a struct of individual letters along with their timing information.
  *         The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
@@ -274,8 +275,8 @@ char* DS_FinishStream(StreamingState* aSctx);
  *
  * @note This method will free the state pointer (@p aSctx).
  */
 DEEPSPEECH_EXPORT
-Result* DS_FinishStreamWithMetadata(StreamingState* aSctx,
-                                    unsigned int aNumResults);
+Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx,
+                                      unsigned int aNumResults);
 
 /**
  * @brief Destroy a streaming state without decoding the computed logits. This
@ -295,12 +296,6 @@ void DS_FreeStream(StreamingState* aSctx);
DEEPSPEECH_EXPORT
void DS_FreeMetadata(Metadata* m);
/**
* @brief Free memory allocated for result information.
*/
DEEPSPEECH_EXPORT
void DS_FreeResult(Result* r);
/**
* @brief Free a char* string returned by the DeepSpeech API.
*/

View File

@@ -36,41 +36,38 @@ ModelState::decode(const DecoderState& state) const
   return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
 }
 
-Result*
+Metadata*
 ModelState::decode_metadata(const DecoderState& state,
                             size_t num_results)
 {
-  vector<Output> out = state.decode();
-  size_t max_results = std::min(num_results, out.size());
+  vector<Output> out = state.decode(num_results);
+  size_t num_returned = out.size();
 
-  std::unique_ptr<Result> result(new Result());
-  result->num_transcriptions = max_results;
+  std::unique_ptr<Metadata> metadata(new Metadata);
+  metadata->num_transcripts = num_returned;
 
-  std::unique_ptr<Metadata[]> transcripts(new Metadata[max_results]());
+  std::unique_ptr<CandidateTranscript[]> transcripts(new CandidateTranscript[num_returned]);
 
-  for (int j = 0; j < max_results; ++j) {
-    Metadata* metadata = &transcripts[j];
-    metadata->num_items = out[j].tokens.size();
-    metadata->confidence = out[j].confidence;
+  for (int i = 0; i < num_returned; ++i) {
+    transcripts[i].num_tokens = out[i].tokens.size();
+    transcripts[i].confidence = out[i].confidence;
 
-    std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
+    std::unique_ptr<TokenMetadata[]> tokens(new TokenMetadata[transcripts[i].num_tokens]);
 
-    // Loop through each character
-    for (int i = 0; i < out[j].tokens.size(); ++i) {
-      items[i].character = strdup(alphabet_.StringFromLabel(out[j].tokens[i]).c_str());
-      items[i].timestep = out[j].timesteps[i];
-      items[i].start_time = out[j].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
-      if (items[i].start_time < 0) {
-        items[i].start_time = 0;
+    // Loop through each token
+    for (int j = 0; j < out[i].tokens.size(); ++j) {
+      tokens[j].text = strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str());
+      tokens[j].timestep = out[i].timesteps[j];
+      tokens[j].start_time = out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_);
+      if (tokens[j].start_time < 0) {
+        tokens[j].start_time = 0;
      }
    }
 
-    metadata->items = items.release();
+    transcripts[i].tokens = tokens.release();
   }
 
-  result->transcriptions = transcripts.release();
-  return result.release();
+  metadata->transcripts = transcripts.release();
+  return metadata.release();
 }
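
The start_time arithmetic carried over here converts a CTC timestep index into seconds via the feature window step. A worked example, assuming the usual 16 kHz model with a 320-sample (20 ms) window step; neither constant is fixed by this diff:

    // start_time = timestep * (audio_win_step / sample_rate)
    const int audio_win_step = 320;    // samples per timestep (assumed, i.e. 20 ms)
    const int sample_rate    = 16000;  // Hz (assumed)
    const int timestep       = 50;     // token emitted at the 50th step
    float start_time = timestep * ((float)audio_win_step / sample_rate);  // 50 * 0.02 = 1.0 s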

View File

@@ -66,14 +66,14 @@ struct ModelState {
    * @brief Return character-level metadata including letter timings.
    *
    * @param state Decoder state to use when decoding.
-   * @param num_results Number of alternate results to return.
+   * @param num_results Number of candidate results to return.
    *
-   * @return A Result struct containing Metadata structs.
-   *         Each represents an alternate transcription, with the first ranked most probable.
-   *         The user is responsible for freeing Result by calling DS_FreeResult().
+   * @return A Metadata struct containing CandidateTranscript structs.
+   *         Each represents a candidate transcript, with the first ranked most probable.
+   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
    */
-  virtual Result* decode_metadata(const DecoderState& state,
-                                  size_t num_results);
+  virtual Metadata* decode_metadata(const DecoderState& state,
+                                    size_t num_results);
 };
 
 #endif // MODELSTATE_H