Output word-level metadata from the client with the -e tag

This commit is contained in:
dabinat 2019-03-21 15:53:04 -07:00
parent 79830fe512
commit a3b81d054e
2 changed files with 88 additions and 5 deletions

View File

@ -26,10 +26,12 @@ bool show_times = false;
bool has_versions = false;
bool extended_metadata = false;
void PrintHelp(const char* bin)
{
std::cout <<
"Usage: " << bin << " --model MODEL --alphabet ALPHABET [--lm LM --trie TRIE] --audio AUDIO [-t]\n"
"Usage: " << bin << " --model MODEL --alphabet ALPHABET [--lm LM --trie TRIE] --audio AUDIO [-t] [-e]\n"
"\n"
"Running DeepSpeech inference.\n"
"\n"
@ -39,6 +41,7 @@ void PrintHelp(const char* bin)
" --trie TRIE Path to the language model trie file created with native_client/generate_trie\n"
" --audio AUDIO Path to the audio file to run (WAV format)\n"
" -t Run in benchmark mode, output mfcc & inference time\n"
" -e Extended output, shows word timings as CSV (word, start time, duration)\n"
" --help Show help\n"
" --version Print version and exits\n";
DS_PrintVersions();
@ -47,7 +50,7 @@ void PrintHelp(const char* bin)
bool ProcessArgs(int argc, char** argv)
{
const char* const short_opts = "m:a:l:r:w:thv";
const char* const short_opts = "m:a:l:r:w:tehv";
const option long_opts[] = {
{"model", required_argument, nullptr, 'm'},
{"alphabet", required_argument, nullptr, 'a'},
@ -56,6 +59,7 @@ bool ProcessArgs(int argc, char** argv)
{"audio", required_argument, nullptr, 'w'},
{"run_very_slowly_without_trie_I_really_know_what_Im_doing", no_argument, nullptr, 999},
{"t", no_argument, nullptr, 't'},
{"e", no_argument, nullptr, 'e'},
{"help", no_argument, nullptr, 'h'},
{"version", no_argument, nullptr, 'v'},
{nullptr, no_argument, nullptr, 0}
@ -102,6 +106,10 @@ bool ProcessArgs(int argc, char** argv)
has_versions = true;
break;
case 'e':
extended_metadata = true;
break;
case 'h': // -h or --help
case '?': // Unrecognized option
default:

View File

@ -28,6 +28,7 @@
#include <dirent.h>
#include <unistd.h>
#endif // NO_DIR
#include <vector>
#include "deepspeech.h"
#include "args.h"
@ -43,15 +44,30 @@ typedef struct {
double cpu_time_overall;
} ds_result;
struct meta_word {
std::string word;
float start_time;
float duration;
};
std::vector<meta_word> WordsFromMetadata(Metadata* metadata);
char* CSVOutput(std::vector<meta_word> words);
ds_result
LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
int aSampleRate)
int aSampleRate, bool extended_output)
{
ds_result res = {0};
clock_t ds_start_time = clock();
res.string = DS_SpeechToText(aCtx, aBuffer, aBufferSize, aSampleRate);
if (extended_output) {
Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, aSampleRate);
res.string = CSVOutput(WordsFromMetadata(metadata));
DS_FreeMetadata(metadata);
} else {
res.string = DS_SpeechToText(aCtx, aBuffer, aBufferSize, aSampleRate);
}
clock_t ds_end_infer = clock();
@ -224,7 +240,8 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
ds_result result = LocalDsSTT(context,
(const short*)audio.buffer,
audio.buffer_size / 2,
audio.sample_rate);
audio.sample_rate,
extended_metadata);
free(audio.buffer);
if (result.string) {
@ -238,6 +255,64 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
}
}
std::vector<meta_word>
WordsFromMetadata(Metadata* metadata)
{
std::vector<meta_word> word_list;
std::string word = "";
float word_start_time = 0;
// Loop through each character
for (int i=0; i < metadata->num_items; i++) {
MetadataItem item = metadata->items[i];
if (strcmp(item.character," ") != 0) {
word.append(item.character);
}
// Word boundary is either a space or the last character in the array
if (strcmp(item.character," ") == 0 || i == metadata->num_items-1) {
float word_duration = item.start_time - word_start_time;
if (word_duration < 0) {
word_duration = 0;
}
meta_word w;
w.word = word;
w.start_time = word_start_time;
w.duration = word_duration;
word_list.push_back(w);
// Reset
word = "";
word_start_time = 0;
} else {
if (word.length() == 1) {
word_start_time = item.start_time; // Log the start time of the new word
}
}
}
return word_list;
}
char*
CSVOutput(std::vector<meta_word> words)
{
std::ostringstream out_string;
for (int i=0; i < words.size(); i++) {
meta_word w = words[i];
out_string << w.word << "," << std::to_string(w.start_time) << "," << std::to_string(w.duration) << "\n";
}
return strdup(out_string.str().c_str());
}
int
main(int argc, char **argv)
{