diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
index 3039d47c..8a072c53 100644
--- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp
+++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp
@@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
 }
 
 std::vector<Output>
-DecoderState::decode() const
+DecoderState::decode(size_t num_results) const
 {
   std::vector<PathTrie*> prefixes_copy = prefixes_;
   std::unordered_map<const PathTrie*, float> scores;
@@ -181,14 +181,12 @@ DecoderState::decode() const
   }
 
   using namespace std::placeholders;
-  size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
+  size_t num_returned = std::min(prefixes_copy.size(), num_results);
   std::partial_sort(prefixes_copy.begin(),
-                    prefixes_copy.begin() + num_prefixes,
+                    prefixes_copy.begin() + num_returned,
                     prefixes_copy.end(),
                     std::bind(prefix_compare_external, _1, _2, scores));
 
-  size_t num_returned = std::min(num_prefixes, beam_size_);
-
   std::vector<Output> outputs;
   outputs.reserve(num_returned);
diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h
index a3d5c480..78871b2a 100644
--- a/native_client/ctcdecode/ctc_beam_search_decoder.h
+++ b/native_client/ctcdecode/ctc_beam_search_decoder.h
@@ -61,12 +61,15 @@ public:
        int class_dim);
 
   /* Get transcription from current decoder state
+   *
+   * Parameters:
+   *     num_results: Number of beams to return.
    *
    * Return:
    *     A vector where each element is a pair of score and decoding result,
    *     in descending order.
   */
-  std::vector<Output> decode() const;
+  std::vector<Output> decode(size_t num_results=1) const;
 };
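Note: decode() now partial-sorts only the top num_results prefixes instead of the full beam, and the redundant second std::min against beam_size_ is gone. A minimal caller sketch (hedged: `state` is a hypothetical DecoderState that has already consumed logits via next(); Output's confidence/tokens fields are the ones used in modelstate.cc below):

    // Ask the decoder for the five best beams; entries come back
    // best-first, per the header comment ("in descending order").
    std::vector<Output> best = state.decode(/*num_results=*/5);
    for (const Output& out : best) {
      // out.confidence is the ranking score; out.tokens holds label ids.
    }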
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index ffc10a13..adaa0445 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -60,7 +60,7 @@ using std::vector;
    When batch_buffer is full, we do a single step through the acoustic model
    and accumulate the intermediate decoding state in the DecoderState structure.
 
-   When finishStream() is called, we return the corresponding transcription from
+   When finishStream() is called, we return the corresponding transcript from
    the current decoder state.
 */
 struct StreamingState {
@@ -80,7 +80,7 @@ struct StreamingState {
   char* intermediateDecode() const;
   void finalizeStream();
   char* finishStream();
-  Result* finishStreamWithMetadata(unsigned int num_results);
+  Metadata* finishStreamWithMetadata(unsigned int num_results);
 
   void processAudioWindow(const vector<float>& buf);
   void processMfccWindow(const vector<float>& buf);
@@ -143,7 +143,7 @@ StreamingState::finishStream()
   return model_->decode(decoder_state_);
 }
 
-Result*
+Metadata*
 StreamingState::finishStreamWithMetadata(unsigned int num_results)
 {
   finalizeStream();
@@ -411,11 +411,11 @@ DS_FinishStream(StreamingState* aSctx)
   return str;
 }
 
-Result*
+Metadata*
 DS_FinishStreamWithMetadata(StreamingState* aSctx,
                             unsigned int aNumResults)
 {
-  Result* result = aSctx->finishStreamWithMetadata(aNumResults);
+  Metadata* result = aSctx->finishStreamWithMetadata(aNumResults);
   DS_FreeStream(aSctx);
   return result;
 }
@@ -443,7 +443,7 @@ DS_SpeechToText(ModelState* aCtx,
   return DS_FinishStream(ctx);
 }
 
-Result*
+Metadata*
 DS_SpeechToTextWithMetadata(ModelState* aCtx,
                             const short* aBuffer,
                             unsigned int aBufferSize,
@@ -463,30 +463,16 @@
 void
 DS_FreeMetadata(Metadata* m)
 {
   if (m) {
-    for (int i = 0; i < m->num_items; ++i) {
-      free(m->items[i].character);
-    }
-    delete[] m->items;
-    delete m;
-  }
-}
-
-void
-DS_FreeResult(Result* r)
-{
-  if (r) {
-    for (int i = 0; i < r->num_transcriptions; ++i) {
-      Metadata* m = &r->transcriptions[i];
-
-      for (int j = 0; j < m->num_items; ++j) {
-        free(m->items[j].character);
+    for (int i = 0; i < m->num_transcripts; ++i) {
+      for (int j = 0; j < m->transcripts[i].num_tokens; ++j) {
+        free(m->transcripts[i].tokens[j].text);
       }
-      delete[] m->items;
+      delete[] m->transcripts[i].tokens;
     }
-    delete[] r->transcriptions;
-    delete r;
+    delete[] m->transcripts;
+    delete m;
   }
 }
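Note: DS_FreeMetadata now releases the whole tree in one call: free() for each strdup'ed token text, delete[] for each tokens array and for transcripts, delete for the Metadata itself, so the removed DS_FreeResult has no replacement. A hedged streaming sketch (assumes `sctx` is a live StreamingState that has already been fed audio):

    /* Sketch only: finish the stream, asking for three candidate transcripts. */
    Metadata* m = DS_FinishStreamWithMetadata(sctx, 3);
    if (m) {
      /* DS_FinishStreamWithMetadata has already freed sctx; only m remains. */
      DS_FreeMetadata(m);
    }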
diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h
index 53f1954f..7aee1048 100644
--- a/native_client/deepspeech.h
+++ b/native_client/deepspeech.h
@@ -20,43 +20,44 @@ typedef struct ModelState ModelState;
 typedef struct StreamingState StreamingState;
 
 /**
- * @brief Stores each individual character, along with its timing information
+ * @brief Stores text of an individual token, along with its timing information
 */
-typedef struct MetadataItem {
-  /** The character generated for transcription */
-  char* character;
+typedef struct TokenMetadata {
+  /** The text corresponding to this token */
+  char* text;
 
-  /** Position of the character in units of 20ms */
+  /** Position of the token in units of 20ms */
   int timestep;
 
-  /** Position of the character in seconds */
+  /** Position of the token in seconds */
   float start_time;
-} MetadataItem;
+} TokenMetadata;
 
 /**
- * @brief Stores the entire CTC output as an array of character metadata objects
+ * @brief A single transcript computed by the model, including a confidence
+ * value and the metadata for its constituent tokens.
 */
-typedef struct Metadata {
-  /** List of items */
-  MetadataItem* items;
-  /** Size of the list of items */
-  int num_items;
+typedef struct CandidateTranscript {
+  /** Array of TokenMetadata objects */
+  TokenMetadata* tokens;
+  /** Size of the tokens array */
+  int num_tokens;
   /** Approximated confidence value for this transcription. This is roughly the
    * sum of the acoustic model logit values for each timestep/character that
    * contributed to the creation of this transcription.
   */
   double confidence;
-} Metadata;
+} CandidateTranscript;
 
 /**
- * @brief Stores Metadata structs for each alternative transcription
+ * @brief An array of CandidateTranscript objects computed by the model
 */
-typedef struct Result {
-  /** List of transcriptions */
-  Metadata* transcriptions;
-  /** Size of the list of transcriptions */
-  int num_transcriptions;
-} Result;
+typedef struct Metadata {
+  /** Array of CandidateTranscript objects */
+  CandidateTranscript* transcripts;
+  /** Size of the transcripts array */
+  int num_transcripts;
+} Metadata;
 
 enum DeepSpeech_Error_Codes
 {
@@ -197,16 +198,16 @@ char* DS_SpeechToText(ModelState* aCtx,
  * @param aBuffer A 16-bit, mono raw audio signal at the appropriate
  *                sample rate (matching what the model was trained on).
  * @param aBufferSize The number of samples in the audio signal.
- * @param aNumResults The number of alternative transcriptions to return.
+ * @param aNumResults The number of candidate transcripts to return.
  *
  * @return Outputs a struct of individual letters along with their timing information.
  *         The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}.
  *         Returns NULL on error.
  */
 DEEPSPEECH_EXPORT
-Result* DS_SpeechToTextWithMetadata(ModelState* aCtx,
-                                    const short* aBuffer,
-                                    unsigned int aBufferSize,
-                                    unsigned int aNumResults);
+Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
+                                      const short* aBuffer,
+                                      unsigned int aBufferSize,
+                                      unsigned int aNumResults);
 
 /**
  * @brief Create a new streaming inference state. The streaming state returned
@@ -266,7 +267,7 @@ char* DS_FinishStream(StreamingState* aSctx);
  *        inference, returns per-letter metadata.
  *
  * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}.
- * @param aNumResults The number of alternative transcriptions to return.
+ * @param aNumResults The number of candidate transcripts to return.
  *
  * @return Outputs a struct of individual letters along with their timing information.
  *         The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
@@ -274,8 +275,8 @@ char* DS_FinishStream(StreamingState* aSctx);
  *
  * @note This method will free the state pointer (@p aSctx).
  */
 DEEPSPEECH_EXPORT
-Result* DS_FinishStreamWithMetadata(StreamingState* aSctx,
-                                    unsigned int aNumResults);
+Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx,
+                                      unsigned int aNumResults);
 
 /**
  * @brief Destroy a streaming state without decoding the computed logits. This
@@ -295,12 +296,6 @@ void DS_FreeStream(StreamingState* aSctx);
 DEEPSPEECH_EXPORT
 void DS_FreeMetadata(Metadata* m);
 
-/**
- * @brief Free memory allocated for result information.
- */
-DEEPSPEECH_EXPORT
-void DS_FreeResult(Result* r);
-
 /**
  * @brief Free a char* string returned by the DeepSpeech API.
  */
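The renamed types nest as Metadata -> CandidateTranscript -> TokenMetadata. A minimal consumer sketch against the declarations above (hedged: `model`, `buffer`, and `buffer_size` are assumed to already exist; error handling elided):

    #include <stdio.h>
    #include "deepspeech.h"

    /* Print the best transcript with per-token timings. */
    Metadata* m = DS_SpeechToTextWithMetadata(model, buffer, buffer_size, 1);
    if (m && m->num_transcripts > 0) {
      const CandidateTranscript* t = &m->transcripts[0];
      printf("confidence: %f\n", t->confidence);
      for (int i = 0; i < t->num_tokens; ++i) {
        printf("%s @ %.2fs\n", t->tokens[i].text, t->tokens[i].start_time);
      }
    }
    DS_FreeMetadata(m); /* safe on NULL; frees the whole tree */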
diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc
index 5a8afae3..d4f16636 100644
--- a/native_client/modelstate.cc
+++ b/native_client/modelstate.cc
@@ -36,41 +36,38 @@ ModelState::decode(const DecoderState& state) const
   return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
 }
 
-Result*
+Metadata*
 ModelState::decode_metadata(const DecoderState& state, size_t num_results)
 {
-  vector<Output> out = state.decode();
+  vector<Output> out = state.decode(num_results);
+  size_t num_returned = out.size();
 
-  size_t max_results = std::min(num_results, out.size());
+  std::unique_ptr<Metadata> metadata(new Metadata);
+  metadata->num_transcripts = num_returned;
 
-  std::unique_ptr<Result> result(new Result());
-  result->num_transcriptions = max_results;
+  std::unique_ptr<CandidateTranscript[]> transcripts(new CandidateTranscript[num_returned]);
 
-  std::unique_ptr<Metadata[]> transcripts(new Metadata[max_results]());
+  for (int i = 0; i < num_returned; ++i) {
+    transcripts[i].num_tokens = out[i].tokens.size();
+    transcripts[i].confidence = out[i].confidence;
 
-  for (int j = 0; j < max_results; ++j) {
-    Metadata* metadata = &transcripts[j];
-    metadata->num_items = out[j].tokens.size();
-    metadata->confidence = out[j].confidence;
+    std::unique_ptr<TokenMetadata[]> tokens(new TokenMetadata[transcripts[i].num_tokens]);
 
-    std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
+    // Loop through each token
+    for (int j = 0; j < out[i].tokens.size(); ++j) {
+      tokens[j].text = strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str());
+      tokens[j].timestep = out[i].timesteps[j];
+      tokens[j].start_time = out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_);
 
-    // Loop through each character
-    for (int i = 0; i < out[j].tokens.size(); ++i) {
-      items[i].character = strdup(alphabet_.StringFromLabel(out[j].tokens[i]).c_str());
-      items[i].timestep = out[j].timesteps[i];
-      items[i].start_time = out[j].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
-
-      if (items[i].start_time < 0) {
-        items[i].start_time = 0;
+      if (tokens[j].start_time < 0) {
+        tokens[j].start_time = 0;
       }
     }
 
-    metadata->items = items.release();
+    transcripts[i].tokens = tokens.release();
   }
 
-  result->transcriptions = transcripts.release();
-
-  return result.release();
+  metadata->transcripts = transcripts.release();
+  return metadata.release();
 }
diff --git a/native_client/modelstate.h b/native_client/modelstate.h
index 8ea7ad99..43eef970 100644
--- a/native_client/modelstate.h
+++ b/native_client/modelstate.h
@@ -66,14 +66,14 @@ struct ModelState {
    * @brief Return character-level metadata including letter timings.
    *
    * @param state Decoder state to use when decoding.
-   * @param num_results Number of alternate results to return.
+   * @param num_results Number of candidate results to return.
    *
-   * @return A Result struct containing Metadata structs.
-   *         Each represents an alternate transcription, with the first ranked most probable.
-   *         The user is responsible for freeing Result by calling DS_FreeResult().
+   * @return A Metadata struct containing CandidateTranscript structs.
+   *         Each represents a candidate transcript, with the first ranked most probable.
+   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
    */
-  virtual Result* decode_metadata(const DecoderState& state,
-                                  size_t num_results);
+  virtual Metadata* decode_metadata(const DecoderState& state,
+                                    size_t num_results);
 };
 
 #endif // MODELSTATE_H
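Note on the start_time computation above: start_time = timestep * (audio_win_step_ / sample_rate_), clamped at zero. The header's "units of 20ms" comment only holds for the common configuration; a quick worked check (hedged: the 320-sample step and 16 kHz rate are assumptions about the default models, not guaranteed by this diff):

    // Worked example of the timestep -> seconds conversion.
    const float audio_win_step = 320.0f;   // samples per timestep (assumed)
    const float sample_rate    = 16000.0f; // Hz (assumed)
    const float secs_per_step  = audio_win_step / sample_rate; // = 0.02 s
    // A token at timestep 57 therefore starts at 57 * 0.02 = 1.14 s.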