diff --git a/doc/C-API.rst b/doc/C-API.rst index 2506d9b2..2b0e7e05 100644 --- a/doc/C-API.rst +++ b/doc/C-API.rst @@ -34,6 +34,9 @@ C .. doxygenfunction:: DS_IntermediateDecode :project: deepspeech-c +.. doxygenfunction:: DS_IntermediateDecodeWithMetadata + :project: deepspeech-c + .. doxygenfunction:: DS_FinishStream :project: deepspeech-c diff --git a/doc/DotNet-API.rst b/doc/DotNet-API.rst index 2ba3415f..b4f85dfc 100644 --- a/doc/DotNet-API.rst +++ b/doc/DotNet-API.rst @@ -31,13 +31,20 @@ ErrorCodes Metadata -------- -.. doxygenstruct:: DeepSpeechClient::Structs::Metadata +.. doxygenclass:: DeepSpeechClient::Models::Metadata :project: deepspeech-dotnet - :members: items, num_items, confidence + :members: Transcripts -MetadataItem ------------- +CandidateTranscript +------------------- -.. doxygenstruct:: DeepSpeechClient::Structs::MetadataItem +.. doxygenclass:: DeepSpeechClient::Models::CandidateTranscript :project: deepspeech-dotnet - :members: character, timestep, start_time + :members: Tokens, Confidence + +TokenMetadata +------------- + +.. doxygenclass:: DeepSpeechClient::Models::TokenMetadata + :project: deepspeech-dotnet + :members: Text, Timestep, StartTime diff --git a/doc/Java-API.rst b/doc/Java-API.rst index a485dc02..2986ca97 100644 --- a/doc/Java-API.rst +++ b/doc/Java-API.rst @@ -13,11 +13,17 @@ Metadata .. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::Metadata :project: deepspeech-java - :members: getItems, getNum_items, getProbability, getItem + :members: getTranscripts, getNum_transcripts, getTranscript -MetadataItem ------------- +CandidateTranscript +------------------- -.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::MetadataItem +.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::CandidateTranscript :project: deepspeech-java - :members: getCharacter, getTimestep, getStart_time + :members: getTokens, getNum_tokens, getConfidence, getToken + +TokenMetadata +------------- +.. doxygenclass:: org::mozilla::deepspeech::libdeepspeech::TokenMetadata + :project: deepspeech-java + :members: getText, getTimestep, getStart_time diff --git a/doc/NodeJS-API.rst b/doc/NodeJS-API.rst index aaba718c..b6170b5b 100644 --- a/doc/NodeJS-API.rst +++ b/doc/NodeJS-API.rst @@ -30,8 +30,14 @@ Metadata .. js:autoclass:: Metadata :members: -MetadataItem ------------- +CandidateTranscript +------------------- -.. js:autoclass:: MetadataItem +.. js:autoclass:: CandidateTranscript + :members: + +TokenMetadata +------------- + +.. js:autoclass:: TokenMetadata :members: diff --git a/doc/Python-API.rst b/doc/Python-API.rst index b2b3567f..9aec57f0 100644 --- a/doc/Python-API.rst +++ b/doc/Python-API.rst @@ -21,8 +21,14 @@ Metadata .. autoclass:: Metadata :members: -MetadataItem ------------- +CandidateTranscript +------------------- -.. autoclass:: MetadataItem +.. autoclass:: CandidateTranscript + :members: + +TokenMetadata +------------- + +.. autoclass:: TokenMetadata :members: diff --git a/doc/Structs.rst b/doc/Structs.rst index 713e52e0..5d532277 100644 --- a/doc/Structs.rst +++ b/doc/Structs.rst @@ -8,9 +8,16 @@ Metadata :project: deepspeech-c :members: -MetadataItem ------------- +CandidateTranscript +------------------- -.. doxygenstruct:: MetadataItem +.. doxygenstruct:: CandidateTranscript + :project: deepspeech-c + :members: + +TokenMetadata +------------- + +.. 
doxygenstruct:: TokenMetadata :project: deepspeech-c :members: diff --git a/doc/doxygen-dotnet.conf b/doc/doxygen-dotnet.conf index ad64cfcb..74c2c5bb 100644 --- a/doc/doxygen-dotnet.conf +++ b/doc/doxygen-dotnet.conf @@ -790,7 +790,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = native_client/dotnet/DeepSpeechClient/ native_client/dotnet/DeepSpeechClient/Interfaces/ native_client/dotnet/DeepSpeechClient/Enums/ native_client/dotnet/DeepSpeechClient/Structs/ +INPUT = native_client/dotnet/DeepSpeechClient/ native_client/dotnet/DeepSpeechClient/Interfaces/ native_client/dotnet/DeepSpeechClient/Enums/ native_client/dotnet/DeepSpeechClient/Models/ # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/native_client/args.h b/native_client/args.h index 33b9b8fe..ca28bfb7 100644 --- a/native_client/args.h +++ b/native_client/args.h @@ -34,6 +34,8 @@ bool extended_metadata = false; bool json_output = false; +int json_candidate_transcripts = 3; + int stream_size = 0; void PrintHelp(const char* bin) @@ -43,18 +45,19 @@ void PrintHelp(const char* bin) "\n" "Running DeepSpeech inference.\n" "\n" - "\t--model MODEL\t\tPath to the model (protocol buffer binary file)\n" - "\t--scorer SCORER\t\tPath to the external scorer file\n" - "\t--audio AUDIO\t\tPath to the audio file to run (WAV format)\n" - "\t--beam_width BEAM_WIDTH\tValue for decoder beam width (int)\n" - "\t--lm_alpha LM_ALPHA\tValue for language model alpha param (float)\n" - "\t--lm_beta LM_BETA\tValue for language model beta param (float)\n" - "\t-t\t\t\tRun in benchmark mode, output mfcc & inference time\n" - "\t--extended\t\tOutput string from extended metadata\n" - "\t--json\t\t\tExtended output, shows word timings as JSON\n" - "\t--stream size\t\tRun in stream mode, output intermediate results\n" - "\t--help\t\t\tShow help\n" - "\t--version\t\tPrint version and exits\n"; + "\t--model MODEL\t\t\tPath to the model (protocol buffer binary file)\n" + "\t--scorer SCORER\t\t\tPath to the external scorer file\n" + "\t--audio AUDIO\t\t\tPath to the audio file to run (WAV format)\n" + "\t--beam_width BEAM_WIDTH\t\tValue for decoder beam width (int)\n" + "\t--lm_alpha LM_ALPHA\t\tValue for language model alpha param (float)\n" + "\t--lm_beta LM_BETA\t\tValue for language model beta param (float)\n" + "\t-t\t\t\t\tRun in benchmark mode, output mfcc & inference time\n" + "\t--extended\t\t\tOutput string from extended metadata\n" + "\t--json\t\t\t\tExtended output, shows word timings as JSON\n" + "\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in output\n" + "\t--stream size\t\t\tRun in stream mode, output intermediate results\n" + "\t--help\t\t\t\tShow help\n" + "\t--version\t\t\tPrint version and exit\n"; char* version = DS_Version(); std::cerr << "DeepSpeech " << version << "\n"; DS_FreeString(version); @@ -74,6 +77,7 @@ bool ProcessArgs(int argc, char** argv) {"t", no_argument, nullptr, 't'}, {"extended", no_argument, nullptr, 'e'}, {"json", no_argument, nullptr, 'j'}, + {"candidate_transcripts", required_argument, nullptr, 150}, {"stream", required_argument, nullptr, 's'}, {"version", no_argument, nullptr, 'v'}, {"help", no_argument, nullptr, 'h'}, @@ -128,6 +132,10 @@ bool ProcessArgs(int argc, char** argv) json_output = true; break; + case 150: + json_candidate_transcripts = atoi(optarg); + break; + case 's': stream_size = atoi(optarg); break;
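A note on the wiring above: `atoi()` returns 0 on non-numeric input, and the C API treats `aNumResults` as an upper bound, so the emitted JSON can contain fewer alternatives than requested. A minimal sketch of how the parsed flag value is meant to reach the new API; the wrapper itself is hypothetical, only `DS_SpeechToTextWithMetadata` and `DS_FreeMetadata` are real:

```cpp
#include "deepspeech.h"

// Hypothetical helper: flag_value is json_candidate_transcripts as parsed
// above; clamp a possibly-zero value before requesting candidates.
Metadata*
RequestCandidates(ModelState* ctx, const short* buffer,
                  unsigned int buffer_size, int flag_value)
{
  unsigned int n = flag_value > 0 ? flag_value : 1;
  // The result may hold fewer than n transcripts; release it with
  // DS_FreeMetadata() when done.
  return DS_SpeechToTextWithMetadata(ctx, buffer, buffer_size, n);
}
```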
diff --git a/native_client/client.cc b/native_client/client.cc index abcadd8d..1f7f78eb 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -44,9 +44,115 @@ struct meta_word { float duration; }; -char* metadataToString(Metadata* metadata); -std::vector<meta_word> WordsFromMetadata(Metadata* metadata); -char* JSONOutput(Metadata* metadata); +char* +CandidateTranscriptToString(const CandidateTranscript* transcript) +{ + std::string retval = ""; + for (int i = 0; i < transcript->num_tokens; i++) { + const TokenMetadata& token = transcript->tokens[i]; + retval += token.text; + } + return strdup(retval.c_str()); +} + +std::vector<meta_word> +CandidateTranscriptToWords(const CandidateTranscript* transcript) +{ + std::vector<meta_word> word_list; + + std::string word = ""; + float word_start_time = 0; + + // Loop through each token + for (int i = 0; i < transcript->num_tokens; i++) { + const TokenMetadata& token = transcript->tokens[i]; + + // Append token to word if it's not a space + if (strcmp(token.text, u8" ") != 0) { + // Log the start time of the new word + if (word.length() == 0) { + word_start_time = token.start_time; + } + word.append(token.text); + } + + // Word boundary is either a space or the last token in the array + if (strcmp(token.text, u8" ") == 0 || i == transcript->num_tokens-1) { + float word_duration = token.start_time - word_start_time; + + if (word_duration < 0) { + word_duration = 0; + } + + meta_word w; + w.word = word; + w.start_time = word_start_time; + w.duration = word_duration; + + word_list.push_back(w); + + // Reset + word = ""; + word_start_time = 0; + } + } + + return word_list; +} + +std::string +CandidateTranscriptToJSON(const CandidateTranscript *transcript) +{ + std::ostringstream out_string; + + std::vector<meta_word> words = CandidateTranscriptToWords(transcript); + + out_string << R"("metadata":{"confidence":)" << transcript->confidence << R"(},"words":[)"; + + for (int i = 0; i < words.size(); i++) { + meta_word w = words[i]; + out_string << R"({"word":")" << w.word << R"(","time":)" << w.start_time << R"(,"duration":)" << w.duration << "}"; + + if (i < words.size() - 1) { + out_string << ","; + } + } + + out_string << "]"; + + return out_string.str(); +} + +char* +MetadataToJSON(Metadata* result) +{ + std::ostringstream out_string; + out_string << "{\n"; + + for (int j=0; j < result->num_transcripts; ++j) { + const CandidateTranscript *transcript = &result->transcripts[j]; + + if (j == 0) { + out_string << CandidateTranscriptToJSON(transcript); + + if (result->num_transcripts > 1) { + out_string << ",\n" << R"("alternatives")" << ":[\n"; + } + } else { + out_string << "{" << CandidateTranscriptToJSON(transcript) << "}"; + + if (j < result->num_transcripts - 1) { + out_string << ",\n"; + } else { + out_string << "\n]"; + } + } + } + + out_string << "\n}\n"; + + return strdup(out_string.str().c_str()); +} ds_result LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize, @@ -57,13 +163,13 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize, clock_t ds_start_time = clock(); if (extended_output) { - Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize); - res.string = metadataToString(metadata); - DS_FreeMetadata(metadata); + Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, 1); + res.string = CandidateTranscriptToString(&result->transcripts[0]); + DS_FreeMetadata(result); } else if (json_output) { - Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize);
- res.string = JSONOutput(metadata); - DS_FreeMetadata(metadata); + Metadata *result = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, json_candidate_transcripts); + res.string = MetadataToJSON(result); + DS_FreeMetadata(result); } else if (stream_size > 0) { StreamingState* ctx; int status = DS_CreateStream(aCtx, &ctx); @@ -278,87 +384,6 @@ ProcessFile(ModelState* context, const char* path, bool show_times) } } -char* -metadataToString(Metadata* metadata) -{ - std::string retval = ""; - for (int i = 0; i < metadata->num_items; i++) { - MetadataItem item = metadata->items[i]; - retval += item.character; - } - return strdup(retval.c_str()); -} - -std::vector<meta_word> -WordsFromMetadata(Metadata* metadata) -{ - std::vector<meta_word> word_list; - - std::string word = ""; - float word_start_time = 0; - - // Loop through each character - for (int i = 0; i < metadata->num_items; i++) { - MetadataItem item = metadata->items[i]; - - // Append character to word if it's not a space - if (strcmp(item.character, u8" ") != 0) { - // Log the start time of the new word - if (word.length() == 0) { - word_start_time = item.start_time; - } - word.append(item.character); - } - - // Word boundary is either a space or the last character in the array - if (strcmp(item.character, " ") == 0 - || strcmp(item.character, u8" ") == 0 - || i == metadata->num_items-1) { - - float word_duration = item.start_time - word_start_time; - - if (word_duration < 0) { - word_duration = 0; - } - - meta_word w; - w.word = word; - w.start_time = word_start_time; - w.duration = word_duration; - - word_list.push_back(w); - - // Reset - word = ""; - word_start_time = 0; - } - } - - return word_list; -} - -char* -JSONOutput(Metadata* metadata) -{ - std::vector<meta_word> words = WordsFromMetadata(metadata); - - std::ostringstream out_string; - out_string << R"({"metadata":{"confidence":)" << metadata->confidence << R"(},"words":[)"; - - for (int i = 0; i < words.size(); i++) { - meta_word w = words[i]; - out_string << R"({"word":")" << w.word << R"(","time":)" << w.start_time << R"(,"duration":)" << w.duration << "}"; - - if (i < words.size() - 1) { - out_string << ","; - } - } - - out_string << "]}\n"; - - return strdup(out_string.str().c_str()); -} - int main(int argc, char **argv) { diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.cpp b/native_client/ctcdecode/ctc_beam_search_decoder.cpp index 5dadd57f..8a072c53 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.cpp +++ b/native_client/ctcdecode/ctc_beam_search_decoder.cpp @@ -157,7 +157,7 @@ DecoderState::next(const double *probs, } std::vector<Output> -DecoderState::decode() const +DecoderState::decode(size_t num_results) const { std::vector<PathTrie*> prefixes_copy = prefixes_; std::unordered_map<const PathTrie*, float> scores; @@ -181,16 +181,12 @@ DecoderState::decode() const } using namespace std::placeholders; - size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_); + size_t num_returned = std::min(prefixes_copy.size(), num_results); std::partial_sort(prefixes_copy.begin(), - prefixes_copy.begin() + num_prefixes, + prefixes_copy.begin() + num_returned, prefixes_copy.end(), std::bind(prefix_compare_external, _1, _2, scores)); - //TODO: expose this as an API parameter - const size_t top_paths = 1; - size_t num_returned = std::min(num_prefixes, top_paths); - std::vector<Output> outputs; outputs.reserve(num_returned);
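With the hard-coded `top_paths = 1` removed, callers of the decoder choose the beam count directly. A hedged sketch of the internal call, assuming a `DecoderState` already populated via `next()`, the `Output` type from the decoder headers, and this include path:

```cpp
#include "ctcdecode/ctc_beam_search_decoder.h"

// Sketch: request up to five beams; the result is sorted by descending
// score and may hold fewer entries if fewer prefixes survived.
size_t
TopBeamCount(const DecoderState& decoder_state)
{
  std::vector<Output> beams = decoder_state.decode(/*num_results=*/5);
  return beams.size(); // <= 5
}
```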
diff --git a/native_client/ctcdecode/ctc_beam_search_decoder.h b/native_client/ctcdecode/ctc_beam_search_decoder.h index a3d5c480..b785e097 100644 --- a/native_client/ctcdecode/ctc_beam_search_decoder.h +++ b/native_client/ctcdecode/ctc_beam_search_decoder.h @@ -60,13 +60,16 @@ public: int time_dim, int class_dim); - /* Get transcription from current decoder state + /* Get up to num_results transcriptions from current decoder state. + * + * Parameters: + * num_results: Number of beams to return. * * Return: * A vector where each element is a pair of score and decoding result, * in descending order. */ - std::vector<Output> decode() const; + std::vector<Output> decode(size_t num_results=1) const; }; diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index dd2a95ea..96989e04 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -60,7 +60,7 @@ using std::vector; When batch_buffer is full, we do a single step through the acoustic model and accumulate the intermediate decoding state in the DecoderState structure. - When finishStream() is called, we return the corresponding transcription from + When finishStream() is called, we return the corresponding transcript from the current decoder state. */ struct StreamingState { @@ -78,9 +78,10 @@ struct StreamingState { void feedAudioContent(const short* buffer, unsigned int buffer_size); char* intermediateDecode() const; + Metadata* intermediateDecodeWithMetadata(unsigned int num_results) const; void finalizeStream(); char* finishStream(); - Metadata* finishStreamWithMetadata(); + Metadata* finishStreamWithMetadata(unsigned int num_results); void processAudioWindow(const vector<float>& buf); void processMfccWindow(const vector<float>& buf); @@ -136,6 +137,12 @@ StreamingState::intermediateDecode() const return model_->decode(decoder_state_); } +Metadata* +StreamingState::intermediateDecodeWithMetadata(unsigned int num_results) const +{ + return model_->decode_metadata(decoder_state_, num_results); +} + char* StreamingState::finishStream() { @@ -144,10 +151,10 @@ StreamingState::finishStream() } Metadata* -StreamingState::finishStreamWithMetadata() +StreamingState::finishStreamWithMetadata(unsigned int num_results) { finalizeStream(); - return model_->decode_metadata(decoder_state_); + return model_->decode_metadata(decoder_state_, num_results); } void @@ -402,6 +409,13 @@ DS_IntermediateDecode(const StreamingState* aSctx) return aSctx->intermediateDecode(); } +Metadata* +DS_IntermediateDecodeWithMetadata(const StreamingState* aSctx, + unsigned int aNumResults) +{ + return aSctx->intermediateDecodeWithMetadata(aNumResults); +} + char* DS_FinishStream(StreamingState* aSctx) { @@ -411,11 +425,12 @@ DS_FinishStream(StreamingState* aSctx) } Metadata* -DS_FinishStreamWithMetadata(StreamingState* aSctx) +DS_FinishStreamWithMetadata(StreamingState* aSctx, + unsigned int aNumResults) { - Metadata* metadata = aSctx->finishStreamWithMetadata(); + Metadata* result = aSctx->finishStreamWithMetadata(aNumResults); DS_FreeStream(aSctx); - return metadata; + return result; } StreamingState* @@ -444,10 +459,11 @@ DS_SpeechToText(ModelState* aCtx, Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx, const short* aBuffer, - unsigned int aBufferSize) + unsigned int aBufferSize, + unsigned int aNumResults) { StreamingState* ctx = CreateStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize); - return DS_FinishStreamWithMetadata(ctx); + return DS_FinishStreamWithMetadata(ctx, aNumResults); } void @@ -460,11 +476,16 @@ void DS_FreeMetadata(Metadata* m) { if (m) { - for (int i = 0; i < m->num_items; ++i) { - free(m->items[i].character); + for (int i = 0; i < m->num_transcripts; ++i) { + for (int j = 0; j < 
m->transcripts[i].num_tokens; ++j) { + free((void*)m->transcripts[i].tokens[j].text); + } + + free((void*)m->transcripts[i].tokens); } - delete[] m->items; - delete m; + + free((void*)m->transcripts); + free(m); } } diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 6dad59db..a8c29c93 100644 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -20,32 +20,43 @@ typedef struct ModelState ModelState; typedef struct StreamingState StreamingState; /** - * @brief Stores each individual character, along with its timing information + * @brief Stores text of an individual token, along with its timing information */ -typedef struct MetadataItem { - /** The character generated for transcription */ - char* character; +typedef struct TokenMetadata { + /** The text corresponding to this token */ + const char* const text; - /** Position of the character in units of 20ms */ - int timestep; + /** Position of the token in units of 20ms */ + const unsigned int timestep; - /** Position of the character in seconds */ - float start_time; -} MetadataItem; + /** Position of the token in seconds */ + const float start_time; +} TokenMetadata; /** - * @brief Stores the entire CTC output as an array of character metadata objects + * @brief A single transcript computed by the model, including a confidence + * value and the metadata for its constituent tokens. + */ +typedef struct CandidateTranscript { + /** Array of TokenMetadata objects */ + const TokenMetadata* const tokens; + /** Size of the tokens array */ + const unsigned int num_tokens; + /** Approximated confidence value for this transcript. This is roughly the + * sum of the acoustic model logit values for each timestep/character that + * contributed to the creation of this transcript. + */ + const double confidence; +} CandidateTranscript; + +/** + * @brief An array of CandidateTranscript objects computed by the model. */ typedef struct Metadata { - /** List of items */ - MetadataItem* items; - /** Size of the list of items */ - int num_items; - /** Approximated confidence value for this transcription. This is roughly the - * sum of the acoustic model logit values for each timestep/character that - * contributed to the creation of this transcription. - */ - double confidence; + /** Array of CandidateTranscript objects */ + const CandidateTranscript* const transcripts; + /** Size of the transcripts array */ + const unsigned int num_transcripts; } Metadata; enum DeepSpeech_Error_Codes @@ -164,7 +175,7 @@ int DS_SetScorerAlphaBeta(ModelState* aCtx, float aBeta); /** - * @brief Use the DeepSpeech model to perform Speech-To-Text. + * @brief Use the DeepSpeech model to convert speech to text. * * @param aCtx The ModelState pointer for the model to use. * @param aBuffer A 16-bit, mono raw audio signal at the appropriate @@ -180,21 +191,25 @@ char* DS_SpeechToText(ModelState* aCtx, unsigned int aBufferSize); /** - * @brief Use the DeepSpeech model to perform Speech-To-Text and output metadata - * about the results. + * @brief Use the DeepSpeech model to convert speech to text and output results + * including metadata. * * @param aCtx The ModelState pointer for the model to use. * @param aBuffer A 16-bit, mono raw audio signal at the appropriate * sample rate (matching what the model was trained on). * @param aBufferSize The number of samples in the audio signal. + * @param aNumResults The maximum number of CandidateTranscript structs to return. Returned value might be smaller than this. 
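Because the freeing logic now walks a nested structure, a consumption sketch may help reviewers; the struct fields and the two API calls are as declared in this header, everything else is assumed scaffolding:

```cpp
#include <cstdio>
#include "deepspeech.h"

// Sketch: request up to three candidates, print each one with its
// confidence, then release everything with a single DS_FreeMetadata() call.
void
PrintCandidates(ModelState* ctx, const short* buffer, unsigned int buffer_size)
{
  Metadata* m = DS_SpeechToTextWithMetadata(ctx, buffer, buffer_size, 3);
  if (!m) {
    return; // NULL on error
  }
  for (unsigned int i = 0; i < m->num_transcripts; ++i) {
    const CandidateTranscript* t = &m->transcripts[i];
    fprintf(stderr, "#%u (confidence %.3f): ", i, t->confidence);
    for (unsigned int j = 0; j < t->num_tokens; ++j) {
      fprintf(stderr, "%s", t->tokens[j].text);
    }
    fprintf(stderr, "\n");
  }
  DS_FreeMetadata(m);
}
```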
* - * @return Outputs a struct of individual letters along with their timing information. - * The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error. + * @return Metadata struct containing multiple CandidateTranscript structs. Each + * transcript has per-token metadata including timing information. The + * user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. + * Returns NULL on error. */ DEEPSPEECH_EXPORT Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx, const short* aBuffer, - unsigned int aBufferSize); + unsigned int aBufferSize, + unsigned int aNumResults); /** * @brief Create a new streaming inference state. The streaming state returned @@ -236,8 +251,24 @@ DEEPSPEECH_EXPORT char* DS_IntermediateDecode(const StreamingState* aSctx); /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns the STT result over the whole audio signal. + * @brief Compute the intermediate decoding of an ongoing streaming inference, + * return results including metadata. + * + * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}. + * @param aNumResults The number of candidate transcripts to return. + * + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. The user is + * responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. + * Returns NULL on error. + */ +DEEPSPEECH_EXPORT +Metadata* DS_IntermediateDecodeWithMetadata(const StreamingState* aSctx, + unsigned int aNumResults); + +/** + * @brief Compute the final decoding of an ongoing streaming inference and return + * the result. Signals the end of an ongoing streaming inference. * * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}. * @@ -250,18 +281,23 @@ DEEPSPEECH_EXPORT char* DS_FinishStream(StreamingState* aSctx); /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns per-letter metadata. + * @brief Compute the final decoding of an ongoing streaming inference and return + * results including metadata. Signals the end of an ongoing streaming + * inference. * * @param aSctx A streaming state pointer returned by {@link DS_CreateStream()}. + * @param aNumResults The number of candidate transcripts to return. * - * @return Outputs a struct of individual letters along with their timing information. - * The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error. + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. The user is + * responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. + * Returns NULL on error. * * @note This method will free the state pointer (@p aSctx). */ DEEPSPEECH_EXPORT -Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx); +Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx, + unsigned int aNumResults); /** * @brief Destroy a streaming state without decoding the computed logits. 
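The streaming counterpart, again a hedged sketch built only from the entry points declared above (feeding loop and error handling elided):

```cpp
#include "deepspeech.h"

// Sketch: peek at one intermediate result while streaming, then finish
// with up to three candidates. DS_FinishStreamWithMetadata frees the stream.
void
StreamCandidates(ModelState* model, const short* chunk, unsigned int chunk_len)
{
  StreamingState* sctx;
  if (DS_CreateStream(model, &sctx) != 0) {
    return;
  }
  DS_FeedAudioContent(sctx, chunk, chunk_len);

  Metadata* partial = DS_IntermediateDecodeWithMetadata(sctx, 1);
  // ... inspect partial->transcripts[0].tokens while audio keeps arriving ...
  DS_FreeMetadata(partial);

  Metadata* result = DS_FinishStreamWithMetadata(sctx, 3);
  // ... consume result as in the sketch above ...
  DS_FreeMetadata(result);
}
```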
This diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs index 576ed308..3340c9b3 100644 --- a/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/DeepSpeech.cs @@ -199,13 +199,14 @@ namespace DeepSpeechClient } /// - /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. + /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal, including metadata. /// /// Instance of the stream to finish. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. /// The extended metadata result. - public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream) + public unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream, uint aNumResults) { - return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer()).PtrToMetadata(); + return NativeImp.DS_FinishStreamWithMetadata(stream.GetNativePointer(), aNumResults).PtrToMetadata(); } /// @@ -218,6 +219,17 @@ namespace DeepSpeechClient return NativeImp.DS_IntermediateDecode(stream.GetNativePointer()).PtrToString(); } + /// + /// Computes the intermediate decoding of an ongoing streaming inference, including metadata. + /// + /// Instance of the stream to decode. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. + /// The STT intermediate result. + public unsafe Metadata IntermediateDecodeWithMetadata(DeepSpeechStream stream, uint aNumResults) + { + return NativeImp.DS_IntermediateDecodeWithMetadata(stream.GetNativePointer(), aNumResults).PtrToMetadata(); + } + /// /// Return version of this library. The returned version is a semantic version /// (SemVer 2.0.0). @@ -261,14 +273,15 @@ namespace DeepSpeechClient } /// - /// Use the DeepSpeech model to perform Speech-To-Text. + /// Use the DeepSpeech model to perform Speech-To-Text, return results including metadata. /// /// A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). /// The number of samples in the audio signal. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. /// The extended metadata. Returns NULL on error. - public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize) + public unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, uint aBufferSize, uint aNumResults) { - return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize).PtrToMetadata(); + return NativeImp.DS_SpeechToTextWithMetadata(_modelStatePP, aBuffer, aBufferSize, aNumResults).PtrToMetadata(); } #endregion diff --git a/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj b/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj index b9077361..0139b3e8 100644 --- a/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj +++ b/native_client/dotnet/DeepSpeechClient/DeepSpeechClient.csproj @@ -50,11 +50,13 @@ - + + - + + diff --git a/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs b/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs index 6b7f4c6a..9325f4b8 100644 --- a/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs +++ b/native_client/dotnet/DeepSpeechClient/Extensions/NativeExtensions.cs @@ -26,35 +26,68 @@ namespace DeepSpeechClient.Extensions } /// - /// Converts a pointer into managed metadata object. 
+ /// Converts a pointer into managed TokenMetadata object. + /// + /// Native pointer. + /// TokenMetadata managed object. + private static Models.TokenMetadata PtrToTokenMetadata(this IntPtr intPtr) + { + var token = Marshal.PtrToStructure<TokenMetadata>(intPtr); + var managedToken = new Models.TokenMetadata + { + Timestep = token.timestep, + StartTime = token.start_time, + Text = token.text.PtrToString(releasePtr: false) + }; + return managedToken; + } + + /// + /// Converts a pointer into managed CandidateTranscript object. + /// + /// Native pointer. + /// CandidateTranscript managed object. + private static Models.CandidateTranscript PtrToCandidateTranscript(this IntPtr intPtr) + { + var managedTranscript = new Models.CandidateTranscript(); + var transcript = Marshal.PtrToStructure<CandidateTranscript>(intPtr); + + managedTranscript.Tokens = new Models.TokenMetadata[transcript.num_tokens]; + managedTranscript.Confidence = transcript.confidence; + + //we need to manually read each item from the native ptr using its size + var sizeOfTokenMetadata = Marshal.SizeOf(typeof(TokenMetadata)); + for (int i = 0; i < transcript.num_tokens; i++) + { + managedTranscript.Tokens[i] = transcript.tokens.PtrToTokenMetadata(); + transcript.tokens += sizeOfTokenMetadata; + } + + return managedTranscript; + } + + /// + /// Converts a pointer into managed Metadata object. /// /// Native pointer. /// Metadata managed object. internal static Models.Metadata PtrToMetadata(this IntPtr intPtr) { - var managedMetaObject = new Models.Metadata(); - var metaData = (Metadata)Marshal.PtrToStructure(intPtr, typeof(Metadata)); - - managedMetaObject.Items = new Models.MetadataItem[metaData.num_items]; - managedMetaObject.Confidence = metaData.confidence; + var managedMetadata = new Models.Metadata(); + var metadata = Marshal.PtrToStructure<Metadata>(intPtr); + managedMetadata.Transcripts = new Models.CandidateTranscript[metadata.num_transcripts]; //we need to manually read each item from the native ptr using its size - var sizeOfMetaItem = Marshal.SizeOf(typeof(MetadataItem)); - for (int i = 0; i < metaData.num_items; i++) + var sizeOfCandidateTranscript = Marshal.SizeOf(typeof(CandidateTranscript)); + for (int i = 0; i < metadata.num_transcripts; i++) { - var tempItem = Marshal.PtrToStructure<MetadataItem>(metaData.items); - managedMetaObject.Items[i] = new Models.MetadataItem - { - Timestep = tempItem.timestep, - StartTime = tempItem.start_time, - Character = tempItem.character.PtrToString(releasePtr: false) - }; - //we keep the offset on each read - metaData.items += sizeOfMetaItem; + managedMetadata.Transcripts[i] = metadata.transcripts.PtrToCandidateTranscript(); + metadata.transcripts += sizeOfCandidateTranscript; } + NativeImp.DS_FreeMetadata(intPtr); - return managedMetaObject; + return managedMetadata; } } } diff --git a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs index 18677abc..37d6ce59 100644 --- a/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs +++ b/native_client/dotnet/DeepSpeechClient/Interfaces/IDeepSpeech.cs @@ -68,13 +68,15 @@ namespace DeepSpeechClient.Interfaces uint aBufferSize); /// - /// Use the DeepSpeech model to perform Speech-To-Text. + /// Use the DeepSpeech model to perform Speech-To-Text, return results including metadata. /// /// A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). /// The number of samples in the audio signal. + /// Maximum number of candidate transcripts to return. 
Returned list might be smaller than this. /// The extended metadata. Returns NULL on error. unsafe Metadata SpeechToTextWithMetadata(short[] aBuffer, - uint aBufferSize); + uint aBufferSize, + uint aNumResults); /// /// Destroy a streaming state without decoding the computed logits. @@ -102,6 +104,14 @@ namespace DeepSpeechClient.Interfaces /// The STT intermediate result. unsafe string IntermediateDecode(DeepSpeechStream stream); + /// + /// Computes the intermediate decoding of an ongoing streaming inference, including metadata. + /// + /// Instance of the stream to decode. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. + /// The extended metadata result. + unsafe Metadata IntermediateDecodeWithMetadata(DeepSpeechStream stream, uint aNumResults); + /// /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. /// @@ -110,10 +120,11 @@ namespace DeepSpeechClient.Interfaces unsafe string FinishStream(DeepSpeechStream stream); /// - /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal. + /// Closes the ongoing streaming inference, returns the STT result over the whole audio signal, including metadata. /// /// Instance of the stream to finish. + /// Maximum number of candidate transcripts to return. Returned list might be smaller than this. /// The extended metadata result. - unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream); + unsafe Metadata FinishStreamWithMetadata(DeepSpeechStream stream, uint aNumResults); } } diff --git a/native_client/dotnet/DeepSpeechClient/Models/CandidateTranscript.cs b/native_client/dotnet/DeepSpeechClient/Models/CandidateTranscript.cs new file mode 100644 index 00000000..cc6b5d28 --- /dev/null +++ b/native_client/dotnet/DeepSpeechClient/Models/CandidateTranscript.cs @@ -0,0 +1,17 @@ +namespace DeepSpeechClient.Models +{ + /// + /// A single transcript computed by the model, including a confidence value and the metadata of its constituent tokens. + /// + public class CandidateTranscript + { + /// + /// Approximated confidence value for this transcript. + /// + public double Confidence { get; set; } + /// + /// List of metadata tokens containing text, timestep, and time offset. + /// + public TokenMetadata[] Tokens { get; set; } + } +} \ No newline at end of file diff --git a/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs b/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs index 870eb162..fb6c613d 100644 --- a/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs +++ b/native_client/dotnet/DeepSpeechClient/Models/Metadata.cs @@ -6,12 +6,8 @@ public class Metadata { /// - /// Approximated confidence value for this transcription. + /// List of candidate transcripts. /// - public double Confidence { get; set; } - /// - /// List of metada items containing char, timespet, and time offset. 
- /// - public MetadataItem[] Items { get; set; } + public CandidateTranscript[] Transcripts { get; set; } } } \ No newline at end of file diff --git a/native_client/dotnet/DeepSpeechClient/Models/MetadataItem.cs b/native_client/dotnet/DeepSpeechClient/Models/TokenMetadata.cs similarity index 89% rename from native_client/dotnet/DeepSpeechClient/Models/MetadataItem.cs rename to native_client/dotnet/DeepSpeechClient/Models/TokenMetadata.cs index e329c6cb..5f2dea56 100644 --- a/native_client/dotnet/DeepSpeechClient/Models/MetadataItem.cs +++ b/native_client/dotnet/DeepSpeechClient/Models/TokenMetadata.cs @@ -3,12 +3,12 @@ /// /// Stores each individual character, along with its timing information. /// - public class MetadataItem + public class TokenMetadata { /// /// Char of the current timestep. /// - public string Character; + public string Text; /// /// Position of the character in units of 20ms. /// diff --git a/native_client/dotnet/DeepSpeechClient/NativeImp.cs b/native_client/dotnet/DeepSpeechClient/NativeImp.cs index 6c3494b6..eabbfe48 100644 --- a/native_client/dotnet/DeepSpeechClient/NativeImp.cs +++ b/native_client/dotnet/DeepSpeechClient/NativeImp.cs @@ -17,45 +17,46 @@ namespace DeepSpeechClient [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, - ref IntPtr** pint); + ref IntPtr** pint); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern uint DS_GetModelBeamWidth(IntPtr** aCtx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_SetModelBeamWidth(IntPtr** aCtx, - uint aBeamWidth); + uint aBeamWidth); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern ErrorCodes DS_CreateModel(string aModelPath, - uint aBeamWidth, - ref IntPtr** pint); + uint aBeamWidth, + ref IntPtr** pint); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal unsafe static extern int DS_GetModelSampleRate(IntPtr** aCtx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern ErrorCodes DS_EnableExternalScorer(IntPtr** aCtx, - string aScorerPath); + string aScorerPath); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern ErrorCodes DS_DisableExternalScorer(IntPtr** aCtx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern ErrorCodes DS_SetScorerAlphaBeta(IntPtr** aCtx, - float aAlpha, - float aBeta); + float aAlpha, + float aBeta); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Ansi, SetLastError = true)] internal static unsafe extern IntPtr DS_SpeechToText(IntPtr** aCtx, - short[] aBuffer, - uint aBufferSize); + short[] aBuffer, + uint aBufferSize); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, SetLastError = true)] internal static unsafe extern IntPtr DS_SpeechToTextWithMetadata(IntPtr** aCtx, - short[] aBuffer, - uint aBufferSize); + short[] aBuffer, + uint aBufferSize, + uint aNumResults); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern void DS_FreeModel(IntPtr** aCtx); @@ -76,18 +77,23 @@ namespace DeepSpeechClient [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, CharSet = 
CharSet.Ansi, SetLastError = true)] internal static unsafe extern void DS_FeedAudioContent(IntPtr** aSctx, - short[] aBuffer, - uint aBufferSize); + short[] aBuffer, + uint aBufferSize); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] internal static unsafe extern IntPtr DS_IntermediateDecode(IntPtr** aSctx); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] + internal static unsafe extern IntPtr DS_IntermediateDecodeWithMetadata(IntPtr** aSctx, + uint aNumResults); + [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl, CharSet = CharSet.Ansi, SetLastError = true)] internal static unsafe extern IntPtr DS_FinishStream(IntPtr** aSctx); [DllImport("libdeepspeech.so", CallingConvention = CallingConvention.Cdecl)] - internal static unsafe extern IntPtr DS_FinishStreamWithMetadata(IntPtr** aSctx); + internal static unsafe extern IntPtr DS_FinishStreamWithMetadata(IntPtr** aSctx, + uint aNumResults); #endregion } } diff --git a/native_client/dotnet/DeepSpeechClient/Structs/CandidateTranscript.cs b/native_client/dotnet/DeepSpeechClient/Structs/CandidateTranscript.cs new file mode 100644 index 00000000..54581f6f --- /dev/null +++ b/native_client/dotnet/DeepSpeechClient/Structs/CandidateTranscript.cs @@ -0,0 +1,22 @@ +using System; +using System.Runtime.InteropServices; + +namespace DeepSpeechClient.Structs +{ + [StructLayout(LayoutKind.Sequential)] + internal unsafe struct CandidateTranscript + { + /// + /// Native list of tokens. + /// + internal unsafe IntPtr tokens; + /// + /// Count of tokens from the native side. + /// + internal unsafe int num_tokens; + /// + /// Approximated confidence value for this transcription. + /// + internal unsafe double confidence; + } +} diff --git a/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs b/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs index 411da9f2..0a9beddc 100644 --- a/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs +++ b/native_client/dotnet/DeepSpeechClient/Structs/Metadata.cs @@ -7,16 +7,12 @@ namespace DeepSpeechClient.Structs internal unsafe struct Metadata { /// - /// Native list of items. + /// Native list of candidate transcripts. /// - internal unsafe IntPtr items; + internal unsafe IntPtr transcripts; /// - /// Count of items from the native side. + /// Count of transcripts from the native side. /// - internal unsafe int num_items; - /// - /// Approximated confidence value for this transcription. - /// - internal unsafe double confidence; + internal unsafe int num_transcripts; } } diff --git a/native_client/dotnet/DeepSpeechClient/Structs/MetadataItem.cs b/native_client/dotnet/DeepSpeechClient/Structs/TokenMetadata.cs similarity index 80% rename from native_client/dotnet/DeepSpeechClient/Structs/MetadataItem.cs rename to native_client/dotnet/DeepSpeechClient/Structs/TokenMetadata.cs index 10092742..1c660c71 100644 --- a/native_client/dotnet/DeepSpeechClient/Structs/MetadataItem.cs +++ b/native_client/dotnet/DeepSpeechClient/Structs/TokenMetadata.cs @@ -4,12 +4,12 @@ using System.Runtime.InteropServices; namespace DeepSpeechClient.Structs { [StructLayout(LayoutKind.Sequential)] - internal unsafe struct MetadataItem + internal unsafe struct TokenMetadata { /// - /// Native character. + /// Native text. /// - internal unsafe IntPtr character; + internal unsafe IntPtr text; /// /// Position of the character in units of 20ms. 
/// diff --git a/native_client/dotnet/DeepSpeechConsole/Program.cs b/native_client/dotnet/DeepSpeechConsole/Program.cs index b35c7046..a08e44b6 100644 --- a/native_client/dotnet/DeepSpeechConsole/Program.cs +++ b/native_client/dotnet/DeepSpeechConsole/Program.cs @@ -21,14 +21,14 @@ namespace CSharpExamples static string GetArgument(IEnumerable args, string option) => args.SkipWhile(i => i != option).Skip(1).Take(1).FirstOrDefault(); - static string MetadataToString(Metadata meta) + static string MetadataToString(CandidateTranscript transcript) { var nl = Environment.NewLine; string retval = - Environment.NewLine + $"Recognized text: {string.Join("", meta?.Items?.Select(x => x.Character))} {nl}" - + $"Confidence: {meta?.Confidence} {nl}" - + $"Item count: {meta?.Items?.Length} {nl}" - + string.Join(nl, meta?.Items?.Select(x => $"Timestep : {x.Timestep} TimeOffset: {x.StartTime} Char: {x.Character}")); + Environment.NewLine + $"Recognized text: {string.Join("", transcript?.Tokens?.Select(x => x.Text))} {nl}" + + $"Confidence: {transcript?.Confidence} {nl}" + + $"Item count: {transcript?.Tokens?.Length} {nl}" + + string.Join(nl, transcript?.Tokens?.Select(x => $"Timestep : {x.Timestep} TimeOffset: {x.StartTime} Char: {x.Text}")); return retval; } @@ -75,8 +75,8 @@ namespace CSharpExamples if (extended) { Metadata metaResult = sttClient.SpeechToTextWithMetadata(waveBuffer.ShortBuffer, - Convert.ToUInt32(waveBuffer.MaxSize / 2)); - speechResult = MetadataToString(metaResult); + Convert.ToUInt32(waveBuffer.MaxSize / 2), 1); + speechResult = MetadataToString(metaResult.Transcripts[0]); } else { diff --git a/native_client/java/jni/deepspeech.i b/native_client/java/jni/deepspeech.i index ded18439..c028714c 100644 --- a/native_client/java/jni/deepspeech.i +++ b/native_client/java/jni/deepspeech.i @@ -6,6 +6,8 @@ %} %include "typemaps.i" +%include "enums.swg" +%javaconst(1); %include "arrays_java.i" // apply to DS_FeedAudioContent and DS_SpeechToText @@ -15,21 +17,29 @@ %pointer_functions(ModelState*, modelstatep); %pointer_functions(StreamingState*, streamingstatep); -%typemap(newfree) char* "DS_FreeString($1);"; - -%include "carrays.i" -%array_functions(struct MetadataItem, metadataItem_array); +%extend struct CandidateTranscript { + /** + * Retrieve one TokenMetadata element + * + * @param i Array index of the TokenMetadata to get + * + * @return The TokenMetadata requested or null + */ + const TokenMetadata& getToken(int i) { + return self->tokens[i]; + } +} %extend struct Metadata { /** - * Retrieve one MetadataItem element + * Retrieve one CandidateTranscript element * - * @param i Array index of the MetadataItem to get + * @param i Array index of the CandidateTranscript to get * - * @return The MetadataItem requested or null + * @return The CandidateTranscript requested or null */ - MetadataItem getItem(int i) { - return metadataItem_array_getitem(self->items, i); + const CandidateTranscript& getTranscript(int i) { + return self->transcripts[i]; } ~Metadata() { @@ -37,14 +47,18 @@ } } -%nodefaultdtor Metadata; %nodefaultctor Metadata; -%nodefaultctor MetadataItem; -%nodefaultdtor MetadataItem; +%nodefaultdtor Metadata; +%nodefaultctor CandidateTranscript; +%nodefaultdtor CandidateTranscript; +%nodefaultctor TokenMetadata; +%nodefaultdtor TokenMetadata; +%typemap(newfree) char* "DS_FreeString($1);"; %newobject DS_SpeechToText; %newobject DS_IntermediateDecode; %newobject DS_FinishStream; +%newobject DS_ErrorCodeToErrorMessage; %rename ("%(strip:[DS_])s") ""; diff --git 
a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java index 2957b2e7..f7eccf00 100644 --- a/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java +++ b/native_client/java/libdeepspeech/src/androidTest/java/org/mozilla/deepspeech/libdeepspeech/test/BasicTest.java @@ -12,7 +12,7 @@ import org.junit.runners.MethodSorters; import static org.junit.Assert.*; import org.mozilla.deepspeech.libdeepspeech.DeepSpeechModel; -import org.mozilla.deepspeech.libdeepspeech.Metadata; +import org.mozilla.deepspeech.libdeepspeech.CandidateTranscript; import java.io.RandomAccessFile; import java.io.FileNotFoundException; @@ -61,10 +61,10 @@ public class BasicTest { m.freeModel(); } - private String metadataToString(Metadata m) { + private String candidateTranscriptToString(CandidateTranscript t) { String retval = ""; - for (int i = 0; i < m.getNum_items(); ++i) { - retval += m.getItem(i).getCharacter(); + for (int i = 0; i < t.getNum_tokens(); ++i) { + retval += t.getToken(i).getText(); } return retval; } @@ -97,7 +97,7 @@ public class BasicTest { ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asShortBuffer().get(shorts); if (extendedMetadata) { - return metadataToString(m.sttWithMetadata(shorts, shorts.length)); + return candidateTranscriptToString(m.sttWithMetadata(shorts, shorts.length, 1).getTranscript(0)); } else { return m.stt(shorts, shorts.length); } diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java index 6d0a316b..eafa11e2 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech/DeepSpeechModel.java @@ -11,8 +11,15 @@ public class DeepSpeechModel { } // FIXME: We should have something better than those SWIGTYPE_* - SWIGTYPE_p_p_ModelState _mspp; - SWIGTYPE_p_ModelState _msp; + private SWIGTYPE_p_p_ModelState _mspp; + private SWIGTYPE_p_ModelState _msp; + + private void evaluateErrorCode(int errorCode) { + DeepSpeech_Error_Codes code = DeepSpeech_Error_Codes.swigToEnum(errorCode); + if (code != DeepSpeech_Error_Codes.ERR_OK) { + throw new RuntimeException("Error: " + impl.ErrorCodeToErrorMessage(errorCode) + " (0x" + Integer.toHexString(errorCode) + ")."); + } + } /** * @brief An object providing an interface to a trained DeepSpeech model. @@ -20,10 +27,12 @@ public class DeepSpeechModel { * @constructor * * @param modelPath The path to the frozen model graph. + * + * @throws RuntimeException on failure. */ public DeepSpeechModel(String modelPath) { this._mspp = impl.new_modelstatep(); - impl.CreateModel(modelPath, this._mspp); + evaluateErrorCode(impl.CreateModel(modelPath, this._mspp)); this._msp = impl.modelstatep_value(this._mspp); } @@ -43,10 +52,10 @@ public class DeepSpeechModel { * @param aBeamWidth The beam width used by the model. A larger beam width value * generates better results at the cost of decoding time. * - * @return Zero on success, non-zero on failure. + * @throws RuntimeException on failure. 
*/ - public int setBeamWidth(long beamWidth) { - return impl.SetModelBeamWidth(this._msp, beamWidth); + public void setBeamWidth(long beamWidth) { + evaluateErrorCode(impl.SetModelBeamWidth(this._msp, beamWidth)); } /** @@ -70,19 +79,19 @@ public class DeepSpeechModel { * * @param scorer The path to the external scorer file. * - * @return Zero on success, non-zero on failure (invalid arguments). + * @throws RuntimeException on failure. */ public void enableExternalScorer(String scorer) { - impl.EnableExternalScorer(this._msp, scorer); + evaluateErrorCode(impl.EnableExternalScorer(this._msp, scorer)); } /** * @brief Disable decoding using an external scorer. * - * @return Zero on success, non-zero on failure (invalid arguments). + * @throws RuntimeException on failure. */ public void disableExternalScorer() { - impl.DisableExternalScorer(this._msp); + evaluateErrorCode(impl.DisableExternalScorer(this._msp)); } /** @@ -91,10 +100,10 @@ public class DeepSpeechModel { * @param alpha The alpha hyperparameter of the decoder. Language model weight. * @param beta The beta hyperparameter of the decoder. Word insertion weight. * - * @return Zero on success, non-zero on failure (invalid arguments). + * @throws RuntimeException on failure. */ public void setScorerAlphaBeta(float alpha, float beta) { - impl.SetScorerAlphaBeta(this._msp, alpha, beta); + evaluateErrorCode(impl.SetScorerAlphaBeta(this._msp, alpha, beta)); } /* @@ -117,11 +126,13 @@ public class DeepSpeechModel { * @param buffer A 16-bit, mono raw audio signal at the appropriate * sample rate (matching what the model was trained on). * @param buffer_size The number of samples in the audio signal. + * @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this. * - * @return Outputs a Metadata object of individual letters along with their timing information. + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. */ - public Metadata sttWithMetadata(short[] buffer, int buffer_size) { - return impl.SpeechToTextWithMetadata(this._msp, buffer, buffer_size); + public Metadata sttWithMetadata(short[] buffer, int buffer_size, int num_results) { + return impl.SpeechToTextWithMetadata(this._msp, buffer, buffer_size, num_results); } /** @@ -130,10 +141,12 @@ public class DeepSpeechModel { * and finishStream(). * * @return An opaque object that represents the streaming state. + * + * @throws RuntimeException on failure. */ public DeepSpeechStreamingState createStream() { SWIGTYPE_p_p_StreamingState ssp = impl.new_streamingstatep(); - impl.CreateStream(this._msp, ssp); + evaluateErrorCode(impl.CreateStream(this._msp, ssp)); return new DeepSpeechStreamingState(impl.streamingstatep_value(ssp)); } @@ -161,8 +174,20 @@ public class DeepSpeechModel { } /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns the STT result over the whole audio signal. + * @brief Compute the intermediate decoding of an ongoing streaming inference. + * + * @param ctx A streaming state pointer returned by createStream(). + * @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this. + * + * @return The STT intermediate result. 
+ */ + public Metadata intermediateDecodeWithMetadata(DeepSpeechStreamingState ctx, int num_results) { + return impl.IntermediateDecodeWithMetadata(ctx.get(), num_results); + } + + /** + * @brief Compute the final decoding of an ongoing streaming inference and return + * the result. Signals the end of an ongoing streaming inference. * * @param ctx A streaming state pointer returned by createStream(). * @@ -175,16 +200,19 @@ public class DeepSpeechModel { } /** - * @brief Signal the end of an audio signal to an ongoing streaming - * inference, returns per-letter metadata. + * @brief Compute the final decoding of an ongoing streaming inference and return + * the results including metadata. Signals the end of an ongoing streaming + * inference. * * @param ctx A streaming state pointer returned by createStream(). + * @param num_results Maximum number of candidate transcripts to return. Returned list might be smaller than this. * - * @return Outputs a Metadata object of individual letters along with their timing information. + * @return Metadata struct containing multiple candidate transcripts. Each transcript + * has per-token metadata including timing information. * * @note This method will free the state pointer (@p ctx). */ - public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx) { - return impl.FinishStreamWithMetadata(ctx.get()); + public Metadata finishStreamWithMetadata(DeepSpeechStreamingState ctx, int num_results) { + return impl.FinishStreamWithMetadata(ctx.get(), num_results); } } diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/CandidateTranscript.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/CandidateTranscript.java new file mode 100644 index 00000000..fa13c474 --- /dev/null +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/CandidateTranscript.java @@ -0,0 +1,73 @@ +/* ---------------------------------------------------------------------------- + * This file was automatically generated by SWIG (http://www.swig.org). + * Version 4.0.1 + * + * Do not make changes to this file unless you know what you are doing--modify + * the SWIG interface file instead. + * ----------------------------------------------------------------------------- */ + +package org.mozilla.deepspeech.libdeepspeech; + +/** + * A single transcript computed by the model, including a confidence
+ * value and the metadata for its constituent tokens. + */ +public class CandidateTranscript { + private transient long swigCPtr; + protected transient boolean swigCMemOwn; + + protected CandidateTranscript(long cPtr, boolean cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = cPtr; + } + + protected static long getCPtr(CandidateTranscript obj) { + return (obj == null) ? 0 : obj.swigCPtr; + } + + public synchronized void delete() { + if (swigCPtr != 0) { + if (swigCMemOwn) { + swigCMemOwn = false; + throw new UnsupportedOperationException("C++ destructor does not have public access"); + } + swigCPtr = 0; + } + } + + /** + * Array of TokenMetadata objects + */ + public TokenMetadata getTokens() { + long cPtr = implJNI.CandidateTranscript_tokens_get(swigCPtr, this); + return (cPtr == 0) ? null : new TokenMetadata(cPtr, false); + } + + /** + * Size of the tokens array + */ + public long getNum_tokens() { + return implJNI.CandidateTranscript_num_tokens_get(swigCPtr, this); + } + + /** + * Approximated confidence value for this transcript. This is roughly the
+ * sum of the acoustic model logit values for each timestep/character that
+ * contributed to the creation of this transcript. + */ + public double getConfidence() { + return implJNI.CandidateTranscript_confidence_get(swigCPtr, this); + } + + /** + * Retrieve one TokenMetadata element
+ *
+ * @param i Array index of the TokenMetadata to get
+ *
+ * @return The TokenMetadata requested or null + */ + public TokenMetadata getToken(int i) { + return new TokenMetadata(implJNI.CandidateTranscript_getToken(swigCPtr, this, i), false); + } + +} diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java new file mode 100644 index 00000000..ed47183e --- /dev/null +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/DeepSpeech_Error_Codes.java @@ -0,0 +1,65 @@ +/* ---------------------------------------------------------------------------- + * This file was automatically generated by SWIG (http://www.swig.org). + * Version 4.0.1 + * + * Do not make changes to this file unless you know what you are doing--modify + * the SWIG interface file instead. + * ----------------------------------------------------------------------------- */ + +package org.mozilla.deepspeech.libdeepspeech; + +public enum DeepSpeech_Error_Codes { + ERR_OK(0x0000), + ERR_NO_MODEL(0x1000), + ERR_INVALID_ALPHABET(0x2000), + ERR_INVALID_SHAPE(0x2001), + ERR_INVALID_SCORER(0x2002), + ERR_MODEL_INCOMPATIBLE(0x2003), + ERR_SCORER_NOT_ENABLED(0x2004), + ERR_FAIL_INIT_MMAP(0x3000), + ERR_FAIL_INIT_SESS(0x3001), + ERR_FAIL_INTERPRETER(0x3002), + ERR_FAIL_RUN_SESS(0x3003), + ERR_FAIL_CREATE_STREAM(0x3004), + ERR_FAIL_READ_PROTOBUF(0x3005), + ERR_FAIL_CREATE_SESS(0x3006), + ERR_FAIL_CREATE_MODEL(0x3007); + + public final int swigValue() { + return swigValue; + } + + public static DeepSpeech_Error_Codes swigToEnum(int swigValue) { + DeepSpeech_Error_Codes[] swigValues = DeepSpeech_Error_Codes.class.getEnumConstants(); + if (swigValue < swigValues.length && swigValue >= 0 && swigValues[swigValue].swigValue == swigValue) + return swigValues[swigValue]; + for (DeepSpeech_Error_Codes swigEnum : swigValues) + if (swigEnum.swigValue == swigValue) + return swigEnum; + throw new IllegalArgumentException("No enum " + DeepSpeech_Error_Codes.class + " with value " + swigValue); + } + + @SuppressWarnings("unused") + private DeepSpeech_Error_Codes() { + this.swigValue = SwigNext.next++; + } + + @SuppressWarnings("unused") + private DeepSpeech_Error_Codes(int swigValue) { + this.swigValue = swigValue; + SwigNext.next = swigValue+1; + } + + @SuppressWarnings("unused") + private DeepSpeech_Error_Codes(DeepSpeech_Error_Codes swigEnum) { + this.swigValue = swigEnum.swigValue; + SwigNext.next = this.swigValue+1; + } + + private final int swigValue; + + private static class SwigNext { + private static int next = 0; + } +} + diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java index 482b7c58..d2831bc4 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/Metadata.java @@ -1,6 +1,6 @@ /* ---------------------------------------------------------------------------- * This file was automatically generated by SWIG (http://www.swig.org). - * Version 4.0.2 + * Version 4.0.1 * * Do not make changes to this file unless you know what you are doing--modify * the SWIG interface file instead. 
@@ -9,7 +9,7 @@ package org.mozilla.deepspeech.libdeepspeech; /** - * Stores the entire CTC output as an array of character metadata objects + * An array of CandidateTranscript objects computed by the model. */ public class Metadata { private transient long swigCPtr; @@ -40,61 +40,29 @@ public class Metadata { } /** - * List of items + * Array of CandidateTranscript objects */ - public void setItems(MetadataItem value) { - implJNI.Metadata_items_set(swigCPtr, this, MetadataItem.getCPtr(value), value); + public CandidateTranscript getTranscripts() { + long cPtr = implJNI.Metadata_transcripts_get(swigCPtr, this); + return (cPtr == 0) ? null : new CandidateTranscript(cPtr, false); } /** - * List of items + * Size of the transcripts array */ - public MetadataItem getItems() { - long cPtr = implJNI.Metadata_items_get(swigCPtr, this); - return (cPtr == 0) ? null : new MetadataItem(cPtr, false); + public long getNum_transcripts() { + return implJNI.Metadata_num_transcripts_get(swigCPtr, this); } /** - * Size of the list of items - */ - public void setNum_items(int value) { - implJNI.Metadata_num_items_set(swigCPtr, this, value); - } - - /** - * Size of the list of items - */ - public int getNum_items() { - return implJNI.Metadata_num_items_get(swigCPtr, this); - } - - /** - * Approximated confidence value for this transcription. This is roughly the
- * sum of the acoustic model logit values for each timestep/character that
- * contributed to the creation of this transcription. - */ - public void setConfidence(double value) { - implJNI.Metadata_confidence_set(swigCPtr, this, value); - } - - /** - * Approximated confidence value for this transcription. This is roughly the
- * sum of the acoustic model logit values for each timestep/character that
- * contributed to the creation of this transcription. - */ - public double getConfidence() { - return implJNI.Metadata_confidence_get(swigCPtr, this); - } - - /** - * Retrieve one MetadataItem element
+ * Retrieve one CandidateTranscript element
*
- * @param i Array index of the MetadataItem to get
+ * @param i Array index of the CandidateTranscript to get
*
- * @return The MetadataItem requested or null + * @return The CandidateTranscript requested or null */ - public MetadataItem getItem(int i) { - return new MetadataItem(implJNI.Metadata_getItem(swigCPtr, this, i), true); + public CandidateTranscript getTranscript(int i) { + return new CandidateTranscript(implJNI.Metadata_getTranscript(swigCPtr, this, i), false); } } diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst index 1279d717..bd89f9b8 100644 --- a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/README.rst @@ -4,7 +4,7 @@ Javadoc for Sphinx This code is only here for reference for documentation generation. -To update, please build SWIG (4.0 at least) and then run from native_client/java: +To update, please install SWIG (4.0 at least) and then run from native_client/java: .. code-block:: diff --git a/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/TokenMetadata.java b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/TokenMetadata.java new file mode 100644 index 00000000..d14fc161 --- /dev/null +++ b/native_client/java/libdeepspeech/src/main/java/org/mozilla/deepspeech/libdeepspeech_doc/TokenMetadata.java @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + * This file was automatically generated by SWIG (http://www.swig.org). + * Version 4.0.1 + * + * Do not make changes to this file unless you know what you are doing--modify + * the SWIG interface file instead. + * ----------------------------------------------------------------------------- */ + +package org.mozilla.deepspeech.libdeepspeech; + +/** + * Stores text of an individual token, along with its timing information + */ +public class TokenMetadata { + private transient long swigCPtr; + protected transient boolean swigCMemOwn; + + protected TokenMetadata(long cPtr, boolean cMemoryOwn) { + swigCMemOwn = cMemoryOwn; + swigCPtr = cPtr; + } + + protected static long getCPtr(TokenMetadata obj) { + return (obj == null) ?
0 : obj.swigCPtr; + } + + public synchronized void delete() { + if (swigCPtr != 0) { + if (swigCMemOwn) { + swigCMemOwn = false; + throw new UnsupportedOperationException("C++ destructor does not have public access"); + } + swigCPtr = 0; + } + } + + /** + * The text corresponding to this token + */ + public String getText() { + return implJNI.TokenMetadata_text_get(swigCPtr, this); + } + + /** + * Position of the token in units of 20ms + */ + public long getTimestep() { + return implJNI.TokenMetadata_timestep_get(swigCPtr, this); + } + + /** + * Position of the token in seconds + */ + public float getStart_time() { + return implJNI.TokenMetadata_start_time_get(swigCPtr, this); + } + +} diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js index abbfe59e..16dd19e8 100644 --- a/native_client/javascript/client.js +++ b/native_client/javascript/client.js @@ -42,12 +42,11 @@ function totalTime(hrtimeValue) { return (hrtimeValue[0] + hrtimeValue[1] / 1000000000).toPrecision(4); } -function metadataToString(metadata) { +function candidateTranscriptToString(transcript) { var retval = "" - for (var i = 0; i < metadata.num_items; ++i) { - retval += metadata.items[i].character; + for (var i = 0; i < transcript.tokens.length; ++i) { + retval += transcript.tokens[i].text; } - Ds.FreeMetadata(metadata); return retval; } @@ -117,7 +116,9 @@ audioStream.on('finish', () => { const audioLength = (audioBuffer.length / 2) * (1 / desired_sample_rate); if (args['extended']) { - console.log(metadataToString(model.sttWithMetadata(audioBuffer))); + let metadata = model.sttWithMetadata(audioBuffer, 1); + console.log(candidateTranscriptToString(metadata.transcripts[0])); + Ds.FreeMetadata(metadata); } else { console.log(model.stt(audioBuffer)); } diff --git a/native_client/javascript/deepspeech.i b/native_client/javascript/deepspeech.i index efbaa360..cb3968c2 100644 --- a/native_client/javascript/deepspeech.i +++ b/native_client/javascript/deepspeech.i @@ -47,8 +47,8 @@ using namespace node; %typemap(argout) ModelState **retval { $result = SWIGV8_ARRAY_NEW(); SWIGV8_AppendOutput($result, SWIG_From_int(result)); - // owned by SWIG, ModelState destructor gets called when the JavaScript object is finalized (see below) - %append_output(SWIG_NewPointerObj(%as_voidptr(*$1), $*1_descriptor, SWIG_POINTER_OWN)); + // owned by the application. NodeJS does not guarantee the finalizer will be called so applications must call FreeMetadata themselves. 
+ %append_output(SWIG_NewPointerObj(%as_voidptr(*$1), $*1_descriptor, 0)); } @@ -68,27 +68,29 @@ using namespace node; %nodefaultctor ModelState; %nodefaultdtor ModelState; -%typemap(out) MetadataItem* %{ +%typemap(out) TokenMetadata* %{ $result = SWIGV8_ARRAY_NEW(); - for (int i = 0; i < arg1->num_items; ++i) { - SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_MetadataItem, SWIG_POINTER_OWN)); + for (int i = 0; i < arg1->num_tokens; ++i) { + SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_TokenMetadata, 0)); } %} -%nodefaultdtor Metadata; -%nodefaultctor Metadata; -%nodefaultctor MetadataItem; -%nodefaultdtor MetadataItem; - -%extend struct Metadata { - ~Metadata() { - DS_FreeMetadata($self); +%typemap(out) CandidateTranscript* %{ + $result = SWIGV8_ARRAY_NEW(); + for (int i = 0; i < arg1->num_transcripts; ++i) { + SWIGV8_AppendOutput($result, SWIG_NewPointerObj(SWIG_as_voidptr(&result[i]), SWIGTYPE_p_CandidateTranscript, 0)); } -} +%} -%extend struct MetadataItem { - ~MetadataItem() { } -} +%ignore Metadata::num_transcripts; +%ignore CandidateTranscript::num_tokens; + +%nodefaultctor Metadata; +%nodefaultdtor Metadata; +%nodefaultctor CandidateTranscript; +%nodefaultdtor CandidateTranscript; +%nodefaultctor TokenMetadata; +%nodefaultdtor TokenMetadata; %rename ("%(strip:[DS_])s") ""; diff --git a/native_client/javascript/index.js b/native_client/javascript/index.js index cca483f1..6ce06c0d 100644 --- a/native_client/javascript/index.js +++ b/native_client/javascript/index.js @@ -115,15 +115,16 @@ Model.prototype.stt = function(aBuffer) { } /** - * Use the DeepSpeech model to perform Speech-To-Text and output metadata - * about the results. + * Use the DeepSpeech model to perform Speech-To-Text and output results including metadata. * * @param {object} aBuffer A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). + * @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified. * - * @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error. + * @return {object} :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error. */ -Model.prototype.sttWithMetadata = function(aBuffer) { - return binding.SpeechToTextWithMetadata(this._impl, aBuffer); +Model.prototype.sttWithMetadata = function(aBuffer, aNumResults) { + aNumResults = aNumResults || 1; + return binding.SpeechToTextWithMetadata(this._impl, aBuffer, aNumResults); } /** @@ -172,7 +173,19 @@ Stream.prototype.intermediateDecode = function() { } /** - * Signal the end of an audio signal to an ongoing streaming inference, returns the STT result over the whole audio signal. + * Compute the intermediate decoding of an ongoing streaming inference and return results including metadata. + * + * @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified. + * + * @return {object} :js:func:`Metadata` object containing multiple candidate transcripts. 
Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. Returns undefined on error. + */ +Stream.prototype.intermediateDecodeWithMetadata = function(aNumResults) { + aNumResults = aNumResults || 1; + return binding.IntermediateDecodeWithMetadata(this._impl, aNumResults); +} + +/** + * Compute the final decoding of an ongoing streaming inference and return the result. Signals the end of an ongoing streaming inference. * * @return {string} The STT result. * @@ -185,14 +198,17 @@ Stream.prototype.finishStream = function() { } /** - * Signal the end of an audio signal to an ongoing streaming inference, returns per-letter metadata. + * Compute the final decoding of an ongoing streaming inference and return the results including metadata. Signals the end of an ongoing streaming inference. + * + * @param {number} aNumResults Maximum number of candidate transcripts to return. Returned list might be smaller than this. Default value is 1 if not specified. * * - * @return {object} Outputs a :js:func:`Metadata` struct of individual letters along with their timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. + * @return {object} :js:func:`Metadata` object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. The user is responsible for freeing Metadata by calling :js:func:`FreeMetadata`. * * This method will free the stream, it must not be used after this method is called. */ -Stream.prototype.finishStreamWithMetadata = function() { - result = binding.FinishStreamWithMetadata(this._impl); +Stream.prototype.finishStreamWithMetadata = function(aNumResults) { + aNumResults = aNumResults || 1; + result = binding.FinishStreamWithMetadata(this._impl, aNumResults); this._impl = null; return result; } @@ -236,70 +252,80 @@ function Version() { } -//// Metadata and MetadataItem are here only for documentation purposes +//// Metadata, CandidateTranscript and TokenMetadata are here only for documentation purposes /** * @class * - * Stores each individual character, along with its timing information + * Stores text of an individual token, along with its timing information */ -function MetadataItem() {} +function TokenMetadata() {} /** - * The character generated for transcription + * The text corresponding to this token * - * @return {string} The character generated + * @return {string} The text generated */ -MetadataItem.prototype.character = function() {} +TokenMetadata.prototype.text = function() {} /** - * Position of the character in units of 20ms + * Position of the token in units of 20ms * - * @return {int} The position of the character + * @return {int} The position of the token */ -MetadataItem.prototype.timestep = function() {}; +TokenMetadata.prototype.timestep = function() {}; /** - * Position of the character in seconds + * Position of the token in seconds * - * @return {float} The position of the character + * @return {float} The position of the token */ -MetadataItem.prototype.start_time = function() {}; +TokenMetadata.prototype.start_time = function() {}; /** * @class * - * Stores the entire CTC output as an array of character metadata objects + * A single transcript computed by the model, including a confidence value and + * the metadata for its constituent tokens. 
*/ -function Metadata () {} +function CandidateTranscript () {} /** - * List of items + * Array of tokens * - * @return {array} List of :js:func:`MetadataItem` + * @return {array} Array of :js:func:`TokenMetadata` */ -Metadata.prototype.items = function() {} - -/** - * Size of the list of items - * - * @return {int} Number of items - */ -Metadata.prototype.num_items = function() {} +CandidateTranscript.prototype.tokens = function() {} /** * Approximated confidence value for this transcription. This is roughly the - * sum of the acoustic model logit values for each timestep/character that + * sum of the acoustic model logit values for each timestep/token that * contributed to the creation of this transcription. * * @return {float} Confidence value */ -Metadata.prototype.confidence = function() {} +CandidateTranscript.prototype.confidence = function() {} + +/** + * @class + * + * An array of CandidateTranscript objects computed by the model. + */ +function Metadata () {} + +/** + * Array of transcripts + * + * @return {array} Array of :js:func:`CandidateTranscript` objects + */ +Metadata.prototype.transcripts = function() {} + module.exports = { Model: Model, Metadata: Metadata, - MetadataItem: MetadataItem, + CandidateTranscript: CandidateTranscript, + TokenMetadata: TokenMetadata, Version: Version, FreeModel: FreeModel, FreeStream: FreeStream, diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index ea8928bd..3cb06ac2 100644 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -37,27 +37,39 @@ ModelState::decode(const DecoderState& state) const } Metadata* -ModelState::decode_metadata(const DecoderState& state) +ModelState::decode_metadata(const DecoderState& state, + size_t num_results) { - vector<Output> out = state.decode(); + vector<Output> out = state.decode(num_results); + unsigned int num_returned = out.size(); - std::unique_ptr<Metadata> metadata(new Metadata()); - metadata->num_items = out[0].tokens.size(); - metadata->confidence = out[0].confidence; + CandidateTranscript* transcripts = (CandidateTranscript*)malloc(sizeof(CandidateTranscript)*num_returned); - std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]()); + for (int i = 0; i < num_returned; ++i) { + TokenMetadata* tokens = (TokenMetadata*)malloc(sizeof(TokenMetadata)*out[i].tokens.size()); - // Loop through each character - for (int i = 0; i < out[0].tokens.size(); ++i) { - items[i].character = strdup(alphabet_.StringFromLabel(out[0].tokens[i]).c_str()); - items[i].timestep = out[0].timesteps[i]; - items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_); - - if (items[i].start_time < 0) { - items[i].start_time = 0; + for (int j = 0; j < out[i].tokens.size(); ++j) { + TokenMetadata token { + strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str()), // text + static_cast<unsigned int>(out[i].timesteps[j]), // timestep + out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time + }; + memcpy(&tokens[j], &token, sizeof(TokenMetadata)); } + + CandidateTranscript transcript { + tokens, // tokens + static_cast<unsigned int>(out[i].tokens.size()), // num_tokens + out[i].confidence, // confidence + }; + memcpy(&transcripts[i], &transcript, sizeof(CandidateTranscript)); } - metadata->items = items.release(); - return metadata.release(); + Metadata* ret = (Metadata*)malloc(sizeof(Metadata)); + Metadata metadata { + transcripts, // transcripts + num_returned, // num_transcripts + }; + memcpy(ret, &metadata, sizeof(Metadata)); + return ret; } diff --git a/native_client/modelstate.h
b/native_client/modelstate.h index 25251e15..0dbe108a 100644 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -66,11 +66,14 @@ struct ModelState { * @brief Return character-level metadata including letter timings. * * @param state Decoder state to use when decoding. + * @param num_results Maximum number of candidate results to return. * - * @return Metadata struct containing MetadataItem structs for each character. - * The user is responsible for freeing Metadata by calling DS_FreeMetadata(). + * @return A Metadata struct containing CandidateTranscript structs. + * Each represents a candidate transcript, with the first ranked most probable. + * The user is responsible for freeing the returned Metadata by calling DS_FreeMetadata(). */ - virtual Metadata* decode_metadata(const DecoderState& state); + virtual Metadata* decode_metadata(const DecoderState& state, + size_t num_results); }; #endif // MODELSTATE_H diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index a6511efe..a44cf05f 100644 --- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -121,17 +121,20 @@ class Model(object): """ return deepspeech.impl.SpeechToText(self._impl, audio_buffer) - def sttWithMetadata(self, audio_buffer): + def sttWithMetadata(self, audio_buffer, num_results=1): """ - Use the DeepSpeech model to perform Speech-To-Text and output metadata about the results. + Use the DeepSpeech model to perform Speech-To-Text and return results including metadata. :param audio_buffer: A 16-bit, mono raw audio signal at the appropriate sample rate (matching what the model was trained on). :type audio_buffer: numpy.int16 array - :return: Outputs a struct of individual letters along with their timing information. + :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this. + :type num_results: int + + :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. :type: :func:`Metadata` """ - return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer) + return deepspeech.impl.SpeechToTextWithMetadata(self._impl, audio_buffer, num_results) def createStream(self): """ @@ -187,10 +190,27 @@ class Stream(object): raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") return deepspeech.impl.IntermediateDecode(self._impl) + def intermediateDecodeWithMetadata(self, num_results=1): + """ + Compute the intermediate decoding of an ongoing streaming inference and return results including metadata. + + :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this. + :type num_results: int + + :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. + :type: :func:`Metadata` + + :throws: RuntimeError if the stream object is not valid + """ + if not self._impl: + raise RuntimeError("Stream object is not valid. Trying to decode an already finished stream?") + return deepspeech.impl.IntermediateDecodeWithMetadata(self._impl, num_results) + def finishStream(self): """ - Signal the end of an audio signal to an ongoing streaming inference, - returns the STT result over the whole audio signal. + Compute the final decoding of an ongoing streaming inference and return + the result. Signals the end of an ongoing streaming inference. 
The underlying + stream object must not be used after this method is called. :return: The STT result. :type: str @@ -203,19 +223,24 @@ class Stream(object): self._impl = None return result - def finishStreamWithMetadata(self): + def finishStreamWithMetadata(self, num_results=1): """ - Signal the end of an audio signal to an ongoing streaming inference, - returns per-letter metadata. + Compute the final decoding of an ongoing streaming inference and return + results including metadata. Signals the end of an ongoing streaming + inference. The underlying stream object must not be used after this + method is called. - :return: Outputs a struct of individual letters along with their timing information. + :param num_results: Maximum number of candidate transcripts to return. Returned list might be smaller than this. + :type num_results: int + + :return: Metadata object containing multiple candidate transcripts. Each transcript has per-token metadata including timing information. :type: :func:`Metadata` :throws: RuntimeError if the stream object is not valid """ if not self._impl: raise RuntimeError("Stream object is not valid. Trying to finish an already finished stream?") - result = deepspeech.impl.FinishStreamWithMetadata(self._impl) + result = deepspeech.impl.FinishStreamWithMetadata(self._impl, num_results) self._impl = None return result @@ -233,52 +258,43 @@ class Stream(object): # This is only for documentation purpose -# Metadata and MetadataItem should be in sync with native_client/deepspeech.h -class MetadataItem(object): +# Metadata, CandidateTranscript and TokenMetadata should be in sync with native_client/deepspeech.h +class TokenMetadata(object): """ Stores each individual character, along with its timing information """ - def character(self): + def text(self): """ - The character generated for transcription + The text for this token """ def timestep(self): """ - Position of the character in units of 20ms + Position of the token in units of 20ms """ def start_time(self): """ - Position of the character in seconds + Position of the token in seconds """ -class Metadata(object): +class CandidateTranscript(object): """ Stores the entire CTC output as an array of character metadata objects """ - def items(self): + def tokens(self): """ - List of items + List of tokens - :return: A list of :func:`MetadataItem` elements + :return: A list of :func:`TokenMetadata` elements :type: list """ - def num_items(self): - """ - Size of the list of items - - :return: Size of the list of items - :type: int - """ - - def confidence(self): """ Approximated confidence value for this transcription. This is roughly the @@ -286,3 +302,12 @@ class Metadata(object): contributed to the creation of this transcription. 
""" + +class Metadata(object): + def transcripts(self): + """ + List of candidate transcripts + + :return: A list of :func:`CandidateTranscript` objects + :type: list + """ diff --git a/native_client/python/client.py b/native_client/python/client.py index 671968b9..00fa2ff6 100644 --- a/native_client/python/client.py +++ b/native_client/python/client.py @@ -18,6 +18,7 @@ try: except ImportError: from pipes import quote + def convert_samplerate(audio_path, desired_sample_rate): sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate) try: @@ -31,25 +32,25 @@ def convert_samplerate(audio_path, desired_sample_rate): def metadata_to_string(metadata): - return ''.join(item.character for item in metadata.items) + return ''.join(token.text for token in metadata.tokens) -def words_from_metadata(metadata): + +def words_from_candidate_transcript(metadata): word = "" word_list = [] word_start_time = 0 # Loop through each character - for i in range(0, metadata.num_items): - item = metadata.items[i] + for i, token in enumerate(metadata.tokens): # Append character to word if it's not a space - if item.character != " ": + if token.text != " ": if len(word) == 0: # Log the start time of the new word - word_start_time = item.start_time + word_start_time = token.start_time - word = word + item.character + word = word + token.text # Word boundary is either a space or the last character in the array - if item.character == " " or i == metadata.num_items - 1: - word_duration = item.start_time - word_start_time + if token.text == " " or i == len(metadata.tokens) - 1: + word_duration = token.start_time - word_start_time if word_duration < 0: word_duration = 0 @@ -69,9 +70,11 @@ def words_from_metadata(metadata): def metadata_json_output(metadata): json_result = dict() - json_result["words"] = words_from_metadata(metadata) - json_result["confidence"] = metadata.confidence - return json.dumps(json_result) + json_result["transcripts"] = [{ + "confidence": transcript.confidence, + "words": words_from_candidate_transcript(transcript), + } for transcript in metadata.transcripts] + return json.dumps(json_result, indent=2) @@ -141,9 +144,9 @@ def main(): print('Running inference.', file=sys.stderr) inference_start = timer() if args.extended: - print(metadata_to_string(ds.sttWithMetadata(audio))) + print(metadata_to_string(ds.sttWithMetadata(audio, 1).transcripts[0])) elif args.json: - print(metadata_json_output(ds.sttWithMetadata(audio))) + print(metadata_json_output(ds.sttWithMetadata(audio, 3))) else: print(ds.stt(audio)) inference_end = timer() - inference_start diff --git a/native_client/python/impl.i b/native_client/python/impl.i index d6c7ba19..259a5b5d 100644 --- a/native_client/python/impl.i +++ b/native_client/python/impl.i @@ -38,30 +38,69 @@ import_array(); %append_output(SWIG_NewPointerObj(%as_voidptr($1), $1_descriptor, SWIG_POINTER_OWN)); } -%typemap(out) MetadataItem* %{ - $result = PyList_New(arg1->num_items); - for (int i = 0; i < arg1->num_items; ++i) { - PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->items[i]), SWIGTYPE_p_MetadataItem, 0); +%fragment("parent_reference_init", "init") { + // Thread-safe initialization - initialize during Python module initialization + parent_reference(); +} + +%fragment("parent_reference_function", "header", fragment="parent_reference_init") { + +static PyObject *parent_reference() { + static PyObject *parent_reference_string = 
SWIG_Python_str_FromChar("__parent_reference"); + return parent_reference_string; +} + +} + +%typemap(out, fragment="parent_reference_function") CandidateTranscript* %{ + $result = PyList_New(arg1->num_transcripts); + for (int i = 0; i < arg1->num_transcripts; ++i) { + PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->transcripts[i]), SWIGTYPE_p_CandidateTranscript, 0); + // Add a reference to Metadata in the returned elements to avoid premature + // garbage collection + PyObject_SetAttr(o, parent_reference(), $self); PyList_SetItem($result, i, o); } %} -%extend struct MetadataItem { +%typemap(out, fragment="parent_reference_function") TokenMetadata* %{ + $result = PyList_New(arg1->num_tokens); + for (int i = 0; i < arg1->num_tokens; ++i) { + PyObject* o = SWIG_NewPointerObj(SWIG_as_voidptr(&arg1->tokens[i]), SWIGTYPE_p_TokenMetadata, 0); + // Add a reference to CandidateTranscript in the returned elements to avoid premature + // garbage collection + PyObject_SetAttr(o, parent_reference(), $self); + PyList_SetItem($result, i, o); + } +%} + +%extend struct TokenMetadata { %pythoncode %{ def __repr__(self): - return 'MetadataItem(character=\'{}\', timestep={}, start_time={})'.format(self.character, self.timestep, self.start_time) + return 'TokenMetadata(text=\'{}\', timestep={}, start_time={})'.format(self.text, self.timestep, self.start_time) +%} +} + +%extend struct CandidateTranscript { +%pythoncode %{ + def __repr__(self): + tokens_repr = ',\n'.join(repr(i) for i in self.tokens) + tokens_repr = '\n'.join(' ' + l for l in tokens_repr.split('\n')) + return 'CandidateTranscript(confidence={}, tokens=[\n{}\n])'.format(self.confidence, tokens_repr) %} } %extend struct Metadata { %pythoncode %{ def __repr__(self): - items_repr = ', \n'.join(' ' + repr(i) for i in self.items) - return 'Metadata(confidence={}, items=[\n{}\n])'.format(self.confidence, items_repr) + transcripts_repr = ',\n'.join(repr(i) for i in self.transcripts) + transcripts_repr = '\n'.join(' ' + l for l in transcripts_repr.split('\n')) + return 'Metadata(transcripts=[\n{}\n])'.format(transcripts_repr) %} } -%ignore Metadata::num_items; +%ignore Metadata::num_transcripts; +%ignore CandidateTranscript::num_tokens; %extend struct Metadata { ~Metadata() { @@ -69,10 +108,12 @@ import_array(); } } -%nodefaultdtor Metadata; %nodefaultctor Metadata; -%nodefaultctor MetadataItem; -%nodefaultdtor MetadataItem; +%nodefaultdtor Metadata; +%nodefaultctor CandidateTranscript; +%nodefaultdtor CandidateTranscript; +%nodefaultctor TokenMetadata; +%nodefaultdtor TokenMetadata; %typemap(newfree) char* "DS_FreeString($1);";
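
The net effect of this change is easiest to see end to end from Python. Below is a minimal usage sketch, assuming the ``deepspeech`` package built from this branch and a 0.7-style ``Model(model_path)`` constructor; the model and audio file names are placeholders, not files from this repository:

.. code-block:: python

    import wave
    import numpy as np
    from deepspeech import Model

    model = Model('deepspeech.pbmm')  # placeholder model path

    # Load 16-bit mono PCM at the model's sample rate (placeholder file).
    with wave.open('audio.wav', 'rb') as fin:
        audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)

    # Ask the decoder for up to three candidate transcripts; the returned
    # list may be shorter than requested.
    metadata = model.sttWithMetadata(audio, num_results=3)
    for transcript in metadata.transcripts:
        text = ''.join(token.text for token in transcript.tokens)
        print('{:.2f}\t{}'.format(transcript.confidence, text))

In the Python binding the wrapper owns the ``Metadata``, and the ``__parent_reference`` attribute added in ``impl.i`` keeps it alive while any ``CandidateTranscript`` or ``TokenMetadata`` view of it is still reachable; in the NodeJS binding, per the ``deepspeech.i`` comment above, the application must call ``FreeMetadata`` itself.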
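
The streaming variants mirror the batch call. A sketch under the same assumptions, feeding audio in chunks and polling ``intermediateDecodeWithMetadata`` before ending the stream (``createStream`` and ``feedAudioContent`` are the pre-existing streaming entry points, not changed by this diff):

.. code-block:: python

    # Continues the sketch above: 'model' and 'audio' as defined there.
    stream = model.createStream()
    for chunk in np.array_split(audio, 10):
        stream.feedAudioContent(chunk)
        # Peek at the current best hypothesis without ending the stream.
        partial = stream.intermediateDecodeWithMetadata(num_results=1)
        print(''.join(t.text for t in partial.transcripts[0].tokens))

    # Ends the stream; the Stream object must not be reused afterwards.
    final = stream.finishStreamWithMetadata(num_results=3)
    for transcript in final.transcripts:
        print(''.join(t.text for t in transcript.tokens))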