diff --git a/native_client/args.h b/native_client/args.h
index 60630651..3af3f54b 100644
--- a/native_client/args.h
+++ b/native_client/args.h
@@ -30,6 +30,9 @@ bool extended_metadata = false;
 
 bool json_output = false;
 
+// Number of audio values fed per step in stream mode (--stream); 0 keeps stream mode off
+int stream_size = 0;
+
 void PrintHelp(const char* bin)
 {
     std::cout <<
@@ -45,6 +48,7 @@ void PrintHelp(const char* bin)
     " -t Run in benchmark mode, output mfcc & inference time\n"
     " --extended Output string from extended metadata\n"
     " --json Extended output, shows word timings as JSON\n"
+    " --stream size Run in stream mode, output intermediate results\n"
     " --help Show help\n"
     " --version Print version and exits\n";
     DS_PrintVersions();
@@ -64,6 +68,7 @@ bool ProcessArgs(int argc, char** argv)
             {"t", no_argument, nullptr, 't'},
             {"extended", no_argument, nullptr, 'e'},
             {"json", no_argument, nullptr, 'j'},
+            {"stream", required_argument, nullptr, 's'},
             {"help", no_argument, nullptr, 'h'},
             {"version", no_argument, nullptr, 'v'},
             {nullptr, no_argument, nullptr, 0}
@@ -118,6 +123,11 @@ bool ProcessArgs(int argc, char** argv)
             json_output = true;
             break;
 
+        case 's':
+            // NOTE: atoi yields 0 for non-numeric input, which leaves stream mode off
+            stream_size = atoi(optarg);
+            break;
+
         case 'h': // -h or --help
         case '?': // Unrecognized option
         default:
@@ -136,6 +146,13 @@ bool ProcessArgs(int argc, char** argv)
         return false;
     }
 
+    // Only non-negative multiples of 160 are accepted (0 means no streaming)
+    if (stream_size < 0 || stream_size % 160 != 0) {
+        std::cout <<
+            "Stream buffer size must be multiples of 160\n";
+        return false;
+    }
+
     return true;
 }
 
diff --git a/native_client/client.cc b/native_client/client.cc
index 618e9345..f1148ebc 100644
--- a/native_client/client.cc
+++ b/native_client/client.cc
@@ -70,6 +70,40 @@ LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize,
     Metadata *metadata = DS_SpeechToTextWithMetadata(aCtx, aBuffer, aBufferSize, aSampleRate);
     res.string = JSONOutput(metadata);
     DS_FreeMetadata(metadata);
+  } else if (stream_size > 0) {
+    // Stream mode: feed the buffer stream_size values at a time and print
+    // the intermediate transcript each time it differs from the last one.
+    StreamingState* ctx;
+    int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
+    if (status != DS_ERR_OK) {
+      res.string = strdup("");
+      return res;
+    }
+    // stream_size is positive in this branch, so the conversion is lossless
+    const size_t chunk = static_cast<size_t>(stream_size);
+    size_t off = 0;
+    const char* last = nullptr;
+    while (off < aBufferSize) {
+      const size_t remaining = aBufferSize - off;
+      const size_t cur = remaining > chunk ? chunk : remaining;
+      DS_FeedAudioContent(ctx, aBuffer + off, cur);
+      off += cur;
+      const char* partial = DS_IntermediateDecode(ctx);
+      if (last == nullptr || strcmp(last, partial)) {
+        printf("%s\n", partial);
+        // Release the transcript being replaced so it is not leaked
+        if (last != nullptr) {
+          DS_FreeString(const_cast<char*>(last));
+        }
+        last = partial;
+      } else {
+        DS_FreeString(const_cast<char*>(partial));
+      }
+    }
+    if (last != nullptr) {
+      DS_FreeString(const_cast<char*>(last));
+    }
+    res.string = DS_FinishStream(ctx);
   } else {
     res.string = DS_SpeechToText(aCtx, aBuffer, aBufferSize, aSampleRate);
   }