Improve API naming around Metadata objects

This commit is contained in:
Reuben Morais 2020-02-25 12:29:18 +01:00
parent e1fec4e818
commit 69bd032605
6 changed files with 75 additions and 96 deletions

View File

@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
} }
std::vector<Output> std::vector<Output>
DecoderState::decode() const DecoderState::decode(size_t num_results) const
{ {
std::vector<PathTrie*> prefixes_copy = prefixes_; std::vector<PathTrie*> prefixes_copy = prefixes_;
std::unordered_map<const PathTrie*, float> scores; std::unordered_map<const PathTrie*, float> scores;
@ -181,14 +181,12 @@ DecoderState::decode() const
} }
using namespace std::placeholders; using namespace std::placeholders;
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_); size_t num_returned = std::min(prefixes_copy.size(), num_results);
std::partial_sort(prefixes_copy.begin(), std::partial_sort(prefixes_copy.begin(),
prefixes_copy.begin() + num_prefixes, prefixes_copy.begin() + num_returned,
prefixes_copy.end(), prefixes_copy.end(),
std::bind(prefix_compare_external, _1, _2, scores)); std::bind(prefix_compare_external, _1, _2, scores));
size_t num_returned = std::min(num_prefixes, beam_size_);
std::vector<Output> outputs; std::vector<Output> outputs;
outputs.reserve(num_returned); outputs.reserve(num_returned);

View File

@ -61,12 +61,15 @@ public:
int class_dim); int class_dim);
/* Get transcription from current decoder state /* Get transcription from current decoder state
*
* Parameters:
* num_results: Number of beams to return.
* *
* Return: * Return:
* A vector where each element is a pair of score and decoding result, * A vector where each element is a pair of score and decoding result,
* in descending order. * in descending order.
*/ */
std::vector<Output> decode() const; std::vector<Output> decode(size_t num_results=1) const;
}; };

View File

@ -60,7 +60,7 @@ using std::vector;
When batch_buffer is full, we do a single step through the acoustic model When batch_buffer is full, we do a single step through the acoustic model
and accumulate the intermediate decoding state in the DecoderState structure. and accumulate the intermediate decoding state in the DecoderState structure.
When finishStream() is called, we return the corresponding transcription from When finishStream() is called, we return the corresponding transcript from
the current decoder state. the current decoder state.
*/ */
struct StreamingState { struct StreamingState {
@ -80,7 +80,7 @@ struct StreamingState {
char* intermediateDecode() const; char* intermediateDecode() const;
void finalizeStream(); void finalizeStream();
char* finishStream(); char* finishStream();
Result* finishStreamWithMetadata(unsigned int num_results); Metadata* finishStreamWithMetadata(unsigned int num_results);
void processAudioWindow(const vector<float>& buf); void processAudioWindow(const vector<float>& buf);
void processMfccWindow(const vector<float>& buf); void processMfccWindow(const vector<float>& buf);
@ -143,7 +143,7 @@ StreamingState::finishStream()
return model_->decode(decoder_state_); return model_->decode(decoder_state_);
} }
Result* Metadata*
StreamingState::finishStreamWithMetadata(unsigned int num_results) StreamingState::finishStreamWithMetadata(unsigned int num_results)
{ {
finalizeStream(); finalizeStream();
@ -411,11 +411,11 @@ DS_FinishStream(StreamingState* aSctx)
return str; return str;
} }
Result* Metadata*
DS_FinishStreamWithMetadata(StreamingState* aSctx, DS_FinishStreamWithMetadata(StreamingState* aSctx,
unsigned int aNumResults) unsigned int aNumResults)
{ {
Result* result = aSctx->finishStreamWithMetadata(aNumResults); Metadata* result = aSctx->finishStreamWithMetadata(aNumResults);
DS_FreeStream(aSctx); DS_FreeStream(aSctx);
return result; return result;
} }
@ -443,7 +443,7 @@ DS_SpeechToText(ModelState* aCtx,
return DS_FinishStream(ctx); return DS_FinishStream(ctx);
} }
Result* Metadata*
DS_SpeechToTextWithMetadata(ModelState* aCtx, DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer, const short* aBuffer,
unsigned int aBufferSize, unsigned int aBufferSize,
@ -463,30 +463,16 @@ void
DS_FreeMetadata(Metadata* m) DS_FreeMetadata(Metadata* m)
{ {
if (m) { if (m) {
for (int i = 0; i < m->num_items; ++i) { for (int i = 0; i < m->num_transcripts; ++i) {
free(m->items[i].character); for (int j = 0; j < m->transcripts[i].num_tokens; ++j) {
} free(m->transcripts[i].tokens[j].text);
delete[] m->items;
delete m;
}
}
void
DS_FreeResult(Result* r)
{
if (r) {
for (int i = 0; i < r->num_transcriptions; ++i) {
Metadata* m = &r->transcriptions[i];
for (int j = 0; j < m->num_items; ++j) {
free(m->items[j].character);
} }
delete[] m->items; delete[] m->transcripts[i].tokens;
} }
delete[] r->transcriptions; delete[] m->transcripts;
delete r; delete m;
} }
} }

View File

@ -20,43 +20,44 @@ typedef struct ModelState ModelState;
typedef struct StreamingState StreamingState; typedef struct StreamingState StreamingState;
/** /**
* @brief Stores each individual character, along with its timing information * @brief Stores text of an individual token, along with its timing information
*/ */
typedef struct MetadataItem { typedef struct TokenMetadata {
/** The character generated for transcription */ /** The text corresponding to this token */
char* character; char* text;
/** Position of the character in units of 20ms */ /** Position of the token in units of 20ms */
int timestep; int timestep;
/** Position of the character in seconds */ /** Position of the token in seconds */
float start_time; float start_time;
} MetadataItem; } TokenMetadata;
/** /**
* @brief Stores the entire CTC output as an array of character metadata objects * @brief A single transcript computed by the model, including a confidence
* value and the metadata for its constituent tokens.
*/ */
typedef struct Metadata { typedef struct CandidateTranscript {
/** List of items */ /** Array of TokenMetadata objects */
MetadataItem* items; TokenMetadata* tokens;
/** Size of the list of items */ /** Size of the tokens array */
int num_items; int num_tokens;
/** Approximated confidence value for this transcription. This is roughly the /** Approximated confidence value for this transcription. This is roughly the
* sum of the acoustic model logit values for each timestep/character that * sum of the acoustic model logit values for each timestep/character that
* contributed to the creation of this transcription. * contributed to the creation of this transcription.
*/ */
double confidence; double confidence;
} Metadata; } CandidateTranscript;
/** /**
* @brief Stores Metadata structs for each alternative transcription * @brief An array of CandidateTranscript objects computed by the model
*/ */
typedef struct Result { typedef struct Metadata {
/** List of transcriptions */ /** Array of CandidateTranscript objects */
Metadata* transcriptions; CandidateTranscript* transcripts;
/** Size of the list of transcriptions */ /** Size of the transcripts array */
int num_transcriptions; int num_transcripts;
} Result; } Metadata;
enum DeepSpeech_Error_Codes enum DeepSpeech_Error_Codes
{ {
@ -197,16 +198,16 @@ char* DS_SpeechToText(ModelState* aCtx,
* @param aBuffer A 16-bit, mono raw audio signal at the appropriate * @param aBuffer A 16-bit, mono raw audio signal at the appropriate
* sample rate (matching what the model was trained on). * sample rate (matching what the model was trained on).
* @param aBufferSize The number of samples in the audio signal. * @param aBufferSize The number of samples in the audio signal.
* @param aNumResults The number of alternative transcriptions to return. * @param aNumResults The number of candidate transcripts to return.
* *
* @return Outputs a struct of individual letters along with their timing information. * @return Outputs a struct of individual letters along with their timing information.
* The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error. * The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
*/ */
DEEPSPEECH_EXPORT DEEPSPEECH_EXPORT
Result* DS_SpeechToTextWithMetadata(ModelState* aCtx, Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer, const short* aBuffer,
unsigned int aBufferSize, unsigned int aBufferSize,
unsigned int aNumResults); unsigned int aNumResults);
/** /**
* @brief Create a new streaming inference state. The streaming state returned * @brief Create a new streaming inference state. The streaming state returned
@ -266,7 +267,7 @@ char* DS_FinishStream(StreamingState* aSctx);
* inference, returns per-letter metadata. * inference, returns per-letter metadata.
* *
* @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}. * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
* @param aNumResults The number of alternative transcriptions to return. * @param aNumResults The number of candidate transcripts to return.
* *
* @return Outputs a struct of individual letters along with their timing information. * @return Outputs a struct of individual letters along with their timing information.
* The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error. * The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
@ -274,8 +275,8 @@ char* DS_FinishStream(StreamingState* aSctx);
* @note This method will free the state pointer (@p aSctx). * @note This method will free the state pointer (@p aSctx).
*/ */
DEEPSPEECH_EXPORT DEEPSPEECH_EXPORT
Result* DS_FinishStreamWithMetadata(StreamingState* aSctx, Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx,
unsigned int aNumResults); unsigned int aNumResults);
/** /**
* @brief Destroy a streaming state without decoding the computed logits. This * @brief Destroy a streaming state without decoding the computed logits. This
@ -295,12 +296,6 @@ void DS_FreeStream(StreamingState* aSctx);
DEEPSPEECH_EXPORT DEEPSPEECH_EXPORT
void DS_FreeMetadata(Metadata* m); void DS_FreeMetadata(Metadata* m);
/**
* @brief Free memory allocated for result information.
*/
DEEPSPEECH_EXPORT
void DS_FreeResult(Result* r);
/** /**
* @brief Free a char* string returned by the DeepSpeech API. * @brief Free a char* string returned by the DeepSpeech API.
*/ */

View File

@ -36,41 +36,38 @@ ModelState::decode(const DecoderState& state) const
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str()); return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
} }
Result* Metadata*
ModelState::decode_metadata(const DecoderState& state, ModelState::decode_metadata(const DecoderState& state,
size_t num_results) size_t num_results)
{ {
vector<Output> out = state.decode(); vector<Output> out = state.decode(num_results);
size_t num_returned = out.size();
size_t max_results = std::min(num_results, out.size()); std::unique_ptr<Metadata> metadata(new Metadata);
metadata->num_transcripts = num_returned;
std::unique_ptr<Result> result(new Result()); std::unique_ptr<CandidateTranscript[]> transcripts(new CandidateTranscript[num_returned]);
result->num_transcriptions = max_results;
std::unique_ptr<Metadata[]> transcripts(new Metadata[max_results]()); for (int i = 0; i < num_returned; ++i) {
transcripts[i].num_tokens = out[i].tokens.size();
transcripts[i].confidence = out[i].confidence;
for (int j = 0; j < max_results; ++j) { std::unique_ptr<TokenMetadata[]> tokens(new TokenMetadata[transcripts[i].num_tokens]);
Metadata* metadata = &transcripts[j];
metadata->num_items = out[j].tokens.size();
metadata->confidence = out[j].confidence;
std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]()); // Loop through each token
for (int j = 0; j < out[i].tokens.size(); ++j) {
tokens[j].text = strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str());
tokens[j].timestep = out[i].timesteps[j];
tokens[j].start_time = out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_);
// Loop through each character if (tokens[j].start_time < 0) {
for (int i = 0; i < out[j].tokens.size(); ++i) { tokens[j].start_time = 0;
items[i].character = strdup(alphabet_.StringFromLabel(out[j].tokens[i]).c_str());
items[i].timestep = out[j].timesteps[i];
items[i].start_time = out[j].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
if (items[i].start_time < 0) {
items[i].start_time = 0;
} }
} }
metadata->items = items.release(); transcripts[i].tokens = tokens.release();
} }
result->transcriptions = transcripts.release(); metadata->transcripts = transcripts.release();
return metadata.release();
return result.release();
} }

View File

@ -66,14 +66,14 @@ struct ModelState {
* @brief Return character-level metadata including letter timings. * @brief Return character-level metadata including letter timings.
* *
* @param state Decoder state to use when decoding. * @param state Decoder state to use when decoding.
* @param num_results Number of alternate results to return. * @param num_results Number of candidate results to return.
* *
* @return A Result struct containing Metadata structs. * @return A Metadata struct containing CandidateTranscript structs.
* Each represents an alternate transcription, with the first ranked most probable. * Each represents a candidate transcript, with the first ranked most probable.
* The user is responsible for freeing Result by calling DS_FreeResult(). * The user is responsible for freeing Metadata by calling DS_FreeMetadata().
*/ */
virtual Result* decode_metadata(const DecoderState& state, virtual Metadata* decode_metadata(const DecoderState& state,
size_t num_results); size_t num_results);
}; };
#endif // MODELSTATE_H #endif // MODELSTATE_H