Improve API naming around Metadata objects
commit 69bd032605
parent e1fec4e818
@@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
 }
 
 std::vector<Output>
-DecoderState::decode() const
+DecoderState::decode(size_t num_results) const
 {
   std::vector<PathTrie*> prefixes_copy = prefixes_;
   std::unordered_map<const PathTrie*, float> scores;
@@ -181,14 +181,12 @@ DecoderState::decode() const
   }
 
   using namespace std::placeholders;
-  size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
+  size_t num_returned = std::min(prefixes_copy.size(), num_results);
   std::partial_sort(prefixes_copy.begin(),
-                    prefixes_copy.begin() + num_prefixes,
+                    prefixes_copy.begin() + num_returned,
                     prefixes_copy.end(),
                     std::bind(prefix_compare_external, _1, _2, scores));
 
-  size_t num_returned = std::min(num_prefixes, beam_size_);
-
   std::vector<Output> outputs;
   outputs.reserve(num_returned);
 
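The decoder change folds the old two-step selection (partial-sort the top beam_size_ prefixes, then clamp again) into a single cap chosen by the caller via num_results. The sketch below shows the same top-N idiom with std::partial_sort on stand-in data; the Candidate type and the scores are made up for illustration and are not part of the DeepSpeech sources.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for a decoded prefix and its score; the real decoder works on
// PathTrie pointers plus an external score map.
struct Candidate {
  std::string text;
  float score;
};

int main() {
  std::vector<Candidate> candidates = {
      {"hello word", -4.2f}, {"hello world", -1.3f},
      {"yellow world", -6.7f}, {"hello worlds", -2.9f}};

  // Ask for at most num_results candidates, like decode(num_results).
  size_t num_results = 2;
  size_t num_returned = std::min(candidates.size(), num_results);

  // Only the first num_returned elements need to end up sorted.
  std::partial_sort(candidates.begin(),
                    candidates.begin() + num_returned,
                    candidates.end(),
                    [](const Candidate& a, const Candidate& b) {
                      return a.score > b.score;  // higher score first
                    });

  for (size_t i = 0; i < num_returned; ++i) {
    std::cout << candidates[i].text << " (" << candidates[i].score << ")\n";
  }
  return 0;
}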
@@ -61,12 +61,15 @@ public:
               int class_dim);
 
   /* Get transcription from current decoder state
    *
+   * Parameters:
+   *     num_results: Number of beams to return.
+   *
    * Return:
   *     A vector where each element is a pair of score and decoding result,
   *     in descending order.
   */
-  std::vector<Output> decode() const;
+  std::vector<Output> decode(size_t num_results=1) const;
 };
 
 
@@ -60,7 +60,7 @@ using std::vector;
    When batch_buffer is full, we do a single step through the acoustic model
    and accumulate the intermediate decoding state in the DecoderState structure.
 
-   When finishStream() is called, we return the corresponding transcription from
+   When finishStream() is called, we return the corresponding transcript from
    the current decoder state.
 */
 struct StreamingState {
@@ -80,7 +80,7 @@ struct StreamingState {
   char* intermediateDecode() const;
   void finalizeStream();
   char* finishStream();
-  Result* finishStreamWithMetadata(unsigned int num_results);
+  Metadata* finishStreamWithMetadata(unsigned int num_results);
 
   void processAudioWindow(const vector<float>& buf);
   void processMfccWindow(const vector<float>& buf);
@@ -143,7 +143,7 @@ StreamingState::finishStream()
   return model_->decode(decoder_state_);
 }
 
-Result*
+Metadata*
 StreamingState::finishStreamWithMetadata(unsigned int num_results)
 {
   finalizeStream();
@@ -411,11 +411,11 @@ DS_FinishStream(StreamingState* aSctx)
   return str;
 }
 
-Result*
+Metadata*
 DS_FinishStreamWithMetadata(StreamingState* aSctx,
                             unsigned int aNumResults)
 {
-  Result* result = aSctx->finishStreamWithMetadata(aNumResults);
+  Metadata* result = aSctx->finishStreamWithMetadata(aNumResults);
   DS_FreeStream(aSctx);
   return result;
 }
@@ -443,7 +443,7 @@ DS_SpeechToText(ModelState* aCtx,
   return DS_FinishStream(ctx);
 }
 
-Result*
+Metadata*
 DS_SpeechToTextWithMetadata(ModelState* aCtx,
                             const short* aBuffer,
                             unsigned int aBufferSize,
@@ -463,30 +463,16 @@ void
 DS_FreeMetadata(Metadata* m)
 {
   if (m) {
-    for (int i = 0; i < m->num_items; ++i) {
-      free(m->items[i].character);
-    }
-    delete[] m->items;
-    delete m;
-  }
-}
-
-void
-DS_FreeResult(Result* r)
-{
-  if (r) {
-    for (int i = 0; i < r->num_transcriptions; ++i) {
-      Metadata* m = &r->transcriptions[i];
-
-      for (int j = 0; j < m->num_items; ++j) {
-        free(m->items[j].character);
+    for (int i = 0; i < m->num_transcripts; ++i) {
+      for (int j = 0; j < m->transcripts[i].num_tokens; ++j) {
+        free(m->transcripts[i].tokens[j].text);
       }
 
-      delete[] m->items;
+      delete[] m->transcripts[i].tokens;
     }
 
-    delete[] r->transcriptions;
-    delete r;
+    delete[] m->transcripts;
+    delete m;
   }
 }
 
@@ -20,43 +20,44 @@ typedef struct ModelState ModelState;
 typedef struct StreamingState StreamingState;
 
 /**
- * @brief Stores each individual character, along with its timing information
+ * @brief Stores text of an individual token, along with its timing information
  */
-typedef struct MetadataItem {
-  /** The character generated for transcription */
-  char* character;
+typedef struct TokenMetadata {
+  /** The text corresponding to this token */
+  char* text;
 
-  /** Position of the character in units of 20ms */
+  /** Position of the token in units of 20ms */
   int timestep;
 
-  /** Position of the character in seconds */
+  /** Position of the token in seconds */
   float start_time;
-} MetadataItem;
+} TokenMetadata;
 
 /**
- * @brief Stores the entire CTC output as an array of character metadata objects
+ * @brief A single transcript computed by the model, including a confidence
+ *        value and the metadata for its constituent tokens.
  */
-typedef struct Metadata {
-  /** List of items */
-  MetadataItem* items;
-  /** Size of the list of items */
-  int num_items;
+typedef struct CandidateTranscript {
+  /** Array of TokenMetadata objects */
+  TokenMetadata* tokens;
+  /** Size of the tokens array */
+  int num_tokens;
   /** Approximated confidence value for this transcription. This is roughly the
    * sum of the acoustic model logit values for each timestep/character that
    * contributed to the creation of this transcription.
    */
   double confidence;
-} Metadata;
+} CandidateTranscript;
 
 /**
- * @brief Stores Metadata structs for each alternative transcription
+ * @brief An array of CandidateTranscript objects computed by the model
  */
-typedef struct Result {
-  /** List of transcriptions */
-  Metadata* transcriptions;
-  /** Size of the list of transcriptions */
-  int num_transcriptions;
-} Result;
+typedef struct Metadata {
+  /** Array of CandidateTranscript objects */
+  CandidateTranscript* transcripts;
+  /** Size of the transcriptions array */
+  int num_transcripts;
+} Metadata;
 
 enum DeepSpeech_Error_Codes
 {
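With the renaming above, a caller walks Metadata, then CandidateTranscript, then TokenMetadata, instead of Result, Metadata, MetadataItem. Below is a rough usage sketch against these declarations; only DS_SpeechToTextWithMetadata, DS_FreeMetadata and the three structs are taken from this header, while the include name, the helper function and the model/audio setup are assumptions made for illustration.

#include <cstdio>
#include "deepspeech.h"  // assumed include name for the header shown above

// Print every candidate transcript with per-token timing, then release it.
// `ctx`, `buffer` and `buffer_size` are assumed to come from the usual
// model-loading and audio-reading code, which is not part of this diff.
void print_candidates(ModelState* ctx, const short* buffer,
                      unsigned int buffer_size) {
  Metadata* metadata = DS_SpeechToTextWithMetadata(ctx, buffer, buffer_size, 3);
  if (!metadata) {
    return;  // NULL on error, per the header docs
  }
  for (int i = 0; i < metadata->num_transcripts; ++i) {
    const CandidateTranscript* t = &metadata->transcripts[i];
    printf("candidate %d (confidence %.3f):\n", i, t->confidence);
    for (int j = 0; j < t->num_tokens; ++j) {
      printf("  '%s' @ %.2fs\n", t->tokens[j].text, t->tokens[j].start_time);
    }
  }
  DS_FreeMetadata(metadata);  // one call frees transcripts and their tokens
}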
@@ -197,16 +198,16 @@ char* DS_SpeechToText(ModelState* aCtx,
  * @param aBuffer A 16-bit, mono raw audio signal at the appropriate
  *                sample rate (matching what the model was trained on).
  * @param aBufferSize The number of samples in the audio signal.
- * @param aNumResults The number of alternative transcriptions to return.
+ * @param aNumResults The number of candidate transcripts to return.
  *
  * @return Outputs a struct of individual letters along with their timing information.
  *         The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
  */
 DEEPSPEECH_EXPORT
-Result* DS_SpeechToTextWithMetadata(ModelState* aCtx,
-                                    const short* aBuffer,
-                                    unsigned int aBufferSize,
-                                    unsigned int aNumResults);
+Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
+                                      const short* aBuffer,
+                                      unsigned int aBufferSize,
+                                      unsigned int aNumResults);
 
 /**
  * @brief Create a new streaming inference state. The streaming state returned
@@ -266,7 +267,7 @@ char* DS_FinishStream(StreamingState* aSctx);
  * inference, returns per-letter metadata.
  *
  * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
- * @param aNumResults The number of alternative transcriptions to return.
+ * @param aNumResults The number of candidate transcripts to return.
  *
  * @return Outputs a struct of individual letters along with their timing information.
  *         The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
@@ -274,8 +275,8 @@ char* DS_FinishStream(StreamingState* aSctx);
  * @note This method will free the state pointer (@p aSctx).
  */
 DEEPSPEECH_EXPORT
-Result* DS_FinishStreamWithMetadata(StreamingState* aSctx,
-                                    unsigned int aNumResults);
+Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx,
+                                      unsigned int aNumResults);
 
 /**
  * @brief Destroy a streaming state without decoding the computed logits. This
@@ -295,12 +296,6 @@ void DS_FreeStream(StreamingState* aSctx);
 DEEPSPEECH_EXPORT
 void DS_FreeMetadata(Metadata* m);
 
-/**
- * @brief Free memory allocated for result information.
- */
-DEEPSPEECH_EXPORT
-void DS_FreeResult(Result* r);
-
 /**
  * @brief Free a char* string returned by the DeepSpeech API.
  */
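The streaming variant has the same shape. A minimal sketch, assuming a StreamingState* that was created with DS_CreateStream() and has already been fed audio (those calls are outside this diff and are not shown); only DS_FinishStreamWithMetadata and DS_FreeMetadata are taken from the header above.

#include "deepspeech.h"  // assumed include name for the header shown above

// Finish a stream and keep the best 2 candidates. Per the @note above,
// DS_FinishStreamWithMetadata frees the streaming state itself, so only
// the returned Metadata needs an explicit DS_FreeMetadata().
Metadata* finish_with_alternatives(StreamingState* stream) {
  Metadata* metadata = DS_FinishStreamWithMetadata(stream, 2);
  // `stream` must not be used (or passed to DS_FreeStream) after this point.
  return metadata;  // caller releases it with DS_FreeMetadata()
}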
@@ -36,41 +36,38 @@ ModelState::decode(const DecoderState& state) const
   return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
 }
 
-Result*
+Metadata*
 ModelState::decode_metadata(const DecoderState& state,
                             size_t num_results)
 {
-  vector<Output> out = state.decode();
-
-  size_t max_results = std::min(num_results, out.size());
-
-  std::unique_ptr<Result> result(new Result());
-  result->num_transcriptions = max_results;
+  vector<Output> out = state.decode(num_results);
+  size_t num_returned = out.size();
 
-  std::unique_ptr<Metadata[]> transcripts(new Metadata[max_results]());
+  std::unique_ptr<Metadata> metadata(new Metadata);
+  metadata->num_transcripts = num_returned;
 
-  for (int j = 0; j < max_results; ++j) {
-    Metadata* metadata = &transcripts[j];
-    metadata->num_items = out[j].tokens.size();
-    metadata->confidence = out[j].confidence;
+  std::unique_ptr<CandidateTranscript[]> transcripts(new CandidateTranscript[num_returned]);
 
-    std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
+  for (int i = 0; i < num_returned; ++i) {
+    transcripts[i].num_tokens = out[i].tokens.size();
+    transcripts[i].confidence = out[i].confidence;
 
-    // Loop through each character
-    for (int i = 0; i < out[j].tokens.size(); ++i) {
-      items[i].character = strdup(alphabet_.StringFromLabel(out[j].tokens[i]).c_str());
-      items[i].timestep = out[j].timesteps[i];
-      items[i].start_time = out[j].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
+    std::unique_ptr<TokenMetadata[]> tokens(new TokenMetadata[transcripts[i].num_tokens]);
+
+    // Loop through each token
+    for (int j = 0; j < out[i].tokens.size(); ++j) {
+      tokens[j].text = strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str());
+      tokens[j].timestep = out[i].timesteps[j];
+      tokens[j].start_time = out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_);
 
-      if (items[i].start_time < 0) {
-        items[i].start_time = 0;
+      if (tokens[j].start_time < 0) {
+        tokens[j].start_time = 0;
       }
     }
 
-    metadata->items = items.release();
+    transcripts[i].tokens = tokens.release();
   }
 
-  result->transcriptions = transcripts.release();
-
-  return result.release();
+  metadata->transcripts = transcripts.release();
+  return metadata.release();
 }
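The rewritten decode_metadata builds each nested array in a std::unique_ptr and then hands the raw pointer to the flat C structs with release(), which is what allows DS_FreeMetadata to delete[] them later. The generic sketch below illustrates that hand-off pattern; the Item and Container names are stand-ins, not the real TokenMetadata/CandidateTranscript types.

#include <memory>

// Toy stand-ins for the C structs that cross the API boundary.
struct Item { int value; };
struct Container { Item* items; int num_items; };

// Build with unique_ptr, then release() the raw pointer into the C struct
// once construction is complete.
Container* make_container(int n) {
  std::unique_ptr<Container> container(new Container);
  container->num_items = n;

  std::unique_ptr<Item[]> items(new Item[n]);
  for (int i = 0; i < n; ++i) {
    items[i].value = i;
  }

  container->items = items.release();   // Container now owns the array
  return container.release();           // caller owns the Container
}

// The matching teardown mirrors DS_FreeMetadata: delete[] the array,
// then delete the owning struct.
void free_container(Container* c) {
  if (c) {
    delete[] c->items;
    delete c;
  }
}

Holding the buffers in unique_ptr until the very end keeps construction safe: if a later allocation throws, everything built so far is still reclaimed automatically.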
@@ -66,14 +66,14 @@ struct ModelState {
    * @brief Return character-level metadata including letter timings.
    *
    * @param state Decoder state to use when decoding.
-   * @param num_results Number of alternate results to return.
+   * @param num_results Number of candidate results to return.
    *
-   * @return A Result struct containing Metadata structs.
-   *         Each represents an alternate transcription, with the first ranked most probable.
-   *         The user is responsible for freeing Result by calling DS_FreeResult().
+   * @return A Metadata struct containing CandidateTranscript structs.
+   *         Each represents an candidate transcript, with the first ranked most probable.
+   *         The user is responsible for freeing Result by calling DS_FreeMetadata().
    */
-  virtual Result* decode_metadata(const DecoderState& state,
-                                  size_t num_results);
+  virtual Metadata* decode_metadata(const DecoderState& state,
+                                    size_t num_results);
 };
 
 #endif // MODELSTATE_H