Expose multiple transcriptions through the API

This commit is contained in:
dabinat 2020-02-05 07:55:15 +00:00 committed by Reuben Morais
parent b57eaa19d6
commit 32c969c184
5 changed files with 108 additions and 41 deletions

View File

@ -157,7 +157,7 @@ DecoderState::next(const double *probs,
}
std::vector<Output>
DecoderState::decode() const
DecoderState::decode(size_t top_paths) const
{
std::vector<PathTrie*> prefixes_copy = prefixes_;
std::unordered_map<const PathTrie*, float> scores;
@ -167,7 +167,7 @@ DecoderState::decode() const
// score the last word of each prefix that doesn't end with space
if (ext_scorer_) {
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
auto prefix = prefixes_copy[i];
if (!ext_scorer_->is_scoring_boundary(prefix->parent, prefix->character)) {
float score = 0.0;
@ -181,14 +181,12 @@ DecoderState::decode() const
}
using namespace std::placeholders;
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
size_t num_prefixes = std::min(prefixes_copy.size(), top_paths);
std::partial_sort(prefixes_copy.begin(),
prefixes_copy.begin() + num_prefixes,
prefixes_copy.end(),
std::bind(prefix_compare_external, _1, _2, scores));
//TODO: expose this as an API parameter
const size_t top_paths = 1;
size_t num_returned = std::min(num_prefixes, top_paths);
std::vector<Output> outputs;
@ -220,6 +218,7 @@ std::vector<Output> ctc_beam_search_decoder(
int class_dim,
const Alphabet &alphabet,
size_t beam_size,
size_t top_paths,
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer)
@ -227,7 +226,7 @@ std::vector<Output> ctc_beam_search_decoder(
DecoderState state;
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
state.next(probs, time_dim, class_dim);
return state.decode();
return state.decode(top_paths);
}
std::vector<std::vector<Output>>
@ -240,6 +239,7 @@ ctc_beam_search_decoder_batch(
int seq_lengths_size,
const Alphabet &alphabet,
size_t beam_size,
size_t top_paths,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
@ -259,6 +259,7 @@ ctc_beam_search_decoder_batch(
class_dim,
alphabet,
beam_size,
top_paths,
cutoff_prob,
cutoff_top_n,
ext_scorer));

View File

@ -80,7 +80,7 @@ struct StreamingState {
char* intermediateDecode() const;
void finalizeStream();
char* finishStream();
Metadata* finishStreamWithMetadata();
Result* finishStreamWithMetadata(unsigned int numResults);
void processAudioWindow(const vector<float>& buf);
void processMfccWindow(const vector<float>& buf);
@ -143,11 +143,26 @@ StreamingState::finishStream()
return model_->decode(decoder_state_);
}
Metadata*
StreamingState::finishStreamWithMetadata()
Result*
StreamingState::finishStreamWithMetadata(unsigned int numResults)
{
finalizeStream();
return model_->decode_metadata(decoder_state_);
vector<Metadata*> metadata = model_->decode_metadata(decoder_state_, numResults);
std::unique_ptr<Result> result(new Result());
result->num_transcriptions = metadata.size();
std::unique_ptr<Metadata[]> items(new Metadata[result->num_transcriptions]);
for (int i = 0; i < result->num_transcriptions; ++i) {
std::unique_ptr<Metadata> pointer(new Metadata(*metadata[i]));
items[i] = *pointer.release();
}
result->transcriptions = items.release();
return result.release();
}
void
@ -410,12 +425,13 @@ DS_FinishStream(StreamingState* aSctx)
return str;
}
Metadata*
DS_FinishStreamWithMetadata(StreamingState* aSctx)
Result*
DS_FinishStreamWithMetadata(StreamingState* aSctx,
unsigned int numResults)
{
Metadata* metadata = aSctx->finishStreamWithMetadata();
Result* result = aSctx->finishStreamWithMetadata(numResults);
DS_FreeStream(aSctx);
return metadata;
return result;
}
StreamingState*
@ -441,13 +457,14 @@ DS_SpeechToText(ModelState* aCtx,
return DS_FinishStream(ctx);
}
Metadata*
Result*
DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize)
unsigned int aBufferSize,
unsigned int numResults)
{
StreamingState* ctx = CreateStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize);
return DS_FinishStreamWithMetadata(ctx);
return DS_FinishStreamWithMetadata(ctx, numResults);
}
void
@ -468,6 +485,25 @@ DS_FreeMetadata(Metadata* m)
}
}
void
DS_FreeResult(Result* r)
{
if (r) {
for (int i = 0; i < r->num_transcriptions; ++i) {
Metadata* m = &r->transcriptions[i];
for (int j = 0; j < m->num_items; ++j) {
free(m->items[j].character);
}
delete[] m->items;
}
delete[] r->transcriptions;
delete r;
}
}
void
DS_FreeString(char* str)
{

View File

@ -48,6 +48,16 @@ typedef struct Metadata {
double confidence;
} Metadata;
/**
* @brief Stores Metadata structs for each alternative transcription
*/
typedef struct Result {
/** List of transcriptions */
Metadata* transcriptions;
/** Size of the list of transcriptions */
int num_transcriptions;
} Result;
enum DeepSpeech_Error_Codes
{
// OK
@ -192,9 +202,10 @@ char* DS_SpeechToText(ModelState* aCtx,
* The user is responsible for freeing Metadata by calling {@link DS_FreeMetadata()}. Returns NULL on error.
*/
DEEPSPEECH_EXPORT
Metadata* DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize);
Result* DS_SpeechToTextWithMetadata(ModelState* aCtx,
const short* aBuffer,
unsigned int aBufferSize,
unsigned int numResults);
/**
* @brief Create a new streaming inference state. The streaming state returned
@ -261,7 +272,8 @@ char* DS_FinishStream(StreamingState* aSctx);
* @note This method will free the state pointer (@p aSctx).
*/
DEEPSPEECH_EXPORT
Metadata* DS_FinishStreamWithMetadata(StreamingState* aSctx);
Result* DS_FinishStreamWithMetadata(StreamingState* aSctx,
unsigned int numResults);
/**
* @brief Destroy a streaming state without decoding the computed logits. This
@ -281,6 +293,12 @@ void DS_FreeStream(StreamingState* aSctx);
DEEPSPEECH_EXPORT
void DS_FreeMetadata(Metadata* m);
/**
* @brief Free memory allocated for result information.
*/
DEEPSPEECH_EXPORT
void DS_FreeResult(Result* r);
/**
* @brief Free a char* string returned by the DeepSpeech API.
*/

View File

@ -32,32 +32,41 @@ ModelState::init(const char* model_path)
char*
ModelState::decode(const DecoderState& state) const
{
vector<Output> out = state.decode();
vector<Output> out = state.decode(1);
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
}
Metadata*
ModelState::decode_metadata(const DecoderState& state)
vector<Metadata*>
ModelState::decode_metadata(const DecoderState& state,
size_t top_paths)
{
vector<Output> out = state.decode();
vector<Output> out = state.decode(top_paths);
std::unique_ptr<Metadata> metadata(new Metadata());
metadata->num_items = out[0].tokens.size();
metadata->confidence = out[0].confidence;
vector<Metadata*> meta_out;
std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
size_t max_results = std::min(top_paths, out.size());
// Loop through each character
for (int i = 0; i < out[0].tokens.size(); ++i) {
items[i].character = strdup(alphabet_.StringFromLabel(out[0].tokens[i]).c_str());
items[i].timestep = out[0].timesteps[i];
items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
for (int j = 0; j < max_results; ++j) {
std::unique_ptr<Metadata> metadata(new Metadata());
metadata->num_items = out[j].tokens.size();
metadata->confidence = out[j].confidence;
if (items[i].start_time < 0) {
items[i].start_time = 0;
std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
// Loop through each character
for (int i = 0; i < out[j].tokens.size(); ++i) {
items[i].character = strdup(alphabet_.StringFromLabel(out[j].tokens[i]).c_str());
items[i].timestep = out[j].timesteps[i];
items[i].start_time = out[j].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
if (items[i].start_time < 0) {
items[i].start_time = 0;
}
}
metadata->items = items.release();
meta_out.push_back(metadata.release());
}
metadata->items = items.release();
return metadata.release();
return meta_out;
}

View File

@ -66,11 +66,14 @@ struct ModelState {
* @brief Return character-level metadata including letter timings.
*
* @param state Decoder state to use when decoding.
* @param top_paths Number of alternate results to return.
*
* @return Metadata struct containing MetadataItem structs for each character.
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
* @return Vector of Metadata structs containing MetadataItem structs for each character.
* Each represents an alternate transcription, with the first ranked most probable.
* The user is responsible for freeing Metadata by calling DS_FreeMetadata() on each item.
*/
virtual Metadata* decode_metadata(const DecoderState& state);
virtual std::vector<Metadata*> decode_metadata(const DecoderState& state,
size_t top_paths);
};
#endif // MODELSTATE_H