diff --git a/native_client/client.cc b/native_client/client.cc
index 8f0a0455..5b7b4aea 100644
--- a/native_client/client.cc
+++ b/native_client/client.cc
@@ -22,7 +22,7 @@ struct ds_result {
 
 // DsSTT() instrumented
 struct ds_result*
-LocalDsSTT(DeepSpeechContext* aCtx, const short* aBuffer, size_t aBufferSize,
+LocalDsSTT(DeepSpeech& aCtx, const short* aBuffer, size_t aBufferSize,
            int aSampleRate)
 {
   float* mfcc;
@@ -34,11 +34,11 @@ LocalDsSTT(DeepSpeechContext* aCtx, const short* aBuffer, size_t aBufferSize,
   clock_t ds_start_time = clock();
   clock_t ds_end_mfcc = 0, ds_end_infer = 0;
 
-  int n_frames =
-    DsGetMfccFrames(aCtx, aBuffer, aBufferSize, aSampleRate, &mfcc);
+  int n_frames = 0;
+  aCtx.getMfccFrames(aBuffer, aBufferSize, aSampleRate, &mfcc, &n_frames);
   ds_end_mfcc = clock();
 
-  res->string = DsInfer(aCtx, mfcc, n_frames);
+  res->string = aCtx.infer(mfcc, n_frames);
   ds_end_infer = clock();
 
   free(mfcc);
@@ -66,8 +66,7 @@ main(int argc, char **argv)
   }
 
   // Initialise DeepSpeech
-  DeepSpeechContext* ctx = DsInit(argv[1], N_CEP, N_CONTEXT);
-  assert(ctx);
+  DeepSpeech ctx(argv[1], N_CEP, N_CONTEXT);
 
   // Initialise SOX
   assert(sox_init() == SOX_SUCCESS);
@@ -178,7 +177,6 @@ main(int argc, char **argv)
   }
 
   // Deinitialise and quit
-  DsClose(ctx);
   sox_quit();
 
   return 0;
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 9d713cc7..3e6fafb4 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -1,7 +1,7 @@
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/core/platform/env.h"
 #include "deepspeech.h"
 #include "c_speech_features.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/platform/env.h"
 
 #define COEFF 0.97f
 #define WIN_LEN 0.025f
@@ -9,76 +9,72 @@
 #define N_FFT 512
 #define N_FILTERS 26
 #define LOWFREQ 0
-#define N_CEP 26
 #define CEP_LIFTER 22
-#define N_CONTEXT 9
 
 using namespace tensorflow;
 
-struct _DeepSpeechContext {
+struct _DeepSpeechPrivate {
   Session* session;
   GraphDef graph_def;
   int ncep;
   int ncontext;
 };
 
-DeepSpeechContext*
-DsInit(const char* aModelPath, int aNCep, int aNContext)
+DeepSpeech::DeepSpeech(const char* aModelPath, int aNCep, int aNContext)
 {
+  mPriv = new DeepSpeechPrivate;
+
+  // Make failed initialisation detectable by infer()/stt() and safe for the
+  // destructor: null the session and record the model parameters up-front.
+  mPriv->session = nullptr;
+  mPriv->ncep = aNCep;
+  mPriv->ncontext = aNContext;
+
   if (!aModelPath) {
-    return NULL;
-  }
-
-  DeepSpeechContext* ctx = new DeepSpeechContext;
-
-  Status status = NewSession(SessionOptions(), &ctx->session);
-  if (!status.ok()) {
-    delete ctx;
-    return NULL;
-  }
-
-  status = ReadBinaryProto(Env::Default(), aModelPath, &ctx->graph_def);
-  if (!status.ok()) {
-    ctx->session->Close();
-    delete ctx;
-    return NULL;
-  }
-
-  status = ctx->session->Create(ctx->graph_def);
-  if (!status.ok()) {
-    ctx->session->Close();
-    delete ctx;
-    return NULL;
-  }
-
-  ctx->ncep = aNCep;
-  ctx->ncontext = aNContext;
-
-  return ctx;
-}
-
-void
-DsClose(DeepSpeechContext* aCtx)
-{
-  if (!aCtx) {
     return;
   }
 
-  aCtx->session->Close();
-  delete aCtx;
+  Status status = NewSession(SessionOptions(), &mPriv->session);
+  if (!status.ok()) {
+    mPriv->session = nullptr;
+    return;
+  }
+
+  status = ReadBinaryProto(Env::Default(), aModelPath, &mPriv->graph_def);
+  if (!status.ok()) {
+    mPriv->session->Close();
+    mPriv->session = nullptr;
+    return;
+  }
+
+  status = mPriv->session->Create(mPriv->graph_def);
+  if (!status.ok()) {
+    mPriv->session->Close();
+    mPriv->session = nullptr;
+    return;
+  }
 }
 
-int
-DsGetMfccFrames(DeepSpeechContext* aCtx, const short* aBuffer,
-                size_t aBufferSize, int aSampleRate, float** aMfcc)
+DeepSpeech::~DeepSpeech()
 {
-  const int contextSize = aCtx->ncep * aCtx->ncontext;
-  const int frameSize = aCtx->ncep + (2 * aCtx->ncep * aCtx->ncontext);
+  if (mPriv->session) {
+    mPriv->session->Close();
+  }
+
+  delete mPriv;
+}
+
+void
+DeepSpeech::getMfccFrames(const short* aBuffer, unsigned int aBufferSize,
+                          int aSampleRate, float** aMfcc, int* aNFrames,
+                          int* aFrameLen)
+{
+  const int contextSize = mPriv->ncep * mPriv->ncontext;
+  const int frameSize = mPriv->ncep + (2 * mPriv->ncep * mPriv->ncontext);
 
   // Compute MFCC features
   float* mfcc;
   int n_frames = csf_mfcc(aBuffer, aBufferSize, aSampleRate,
-                          WIN_LEN, WIN_STEP, aCtx->ncep, N_FILTERS, N_FFT,
+                          WIN_LEN, WIN_STEP, mPriv->ncep, N_FILTERS, N_FFT,
                           LOWFREQ, aSampleRate/2, COEFF, CEP_LIFTER, 1, NULL,
                           &mfcc);
 
@@ -87,30 +83,30 @@ DsGetMfccFrames(DeepSpeechContext* aCtx, const short* aBuffer,
   // TODO: Use MFCC of silence instead of zero
   float* ds_input = (float*)calloc(sizeof(float), ds_input_length * frameSize);
   for (int i = 0, idx = 0, mfcc_idx = 0; i < ds_input_length;
-       i++, idx += frameSize, mfcc_idx += aCtx->ncep * 2) {
+       i++, idx += frameSize, mfcc_idx += mPriv->ncep * 2) {
     // Past context
-    for (int j = N_CONTEXT; j > 0; j--) {
+    for (int j = mPriv->ncontext; j > 0; j--) {
       int frame_index = (i * 2) - (j * 2);
       if (frame_index < 0) { continue; }
-      int mfcc_base = frame_index * aCtx->ncep;
-      int base = (N_CONTEXT - j) * N_CEP;
-      for (int k = 0; k < N_CEP; k++) {
+      int mfcc_base = frame_index * mPriv->ncep;
+      int base = (mPriv->ncontext - j) * mPriv->ncep;
+      for (int k = 0; k < mPriv->ncep; k++) {
         ds_input[idx + base + k] = mfcc[mfcc_base + k];
       }
     }
 
     // Present context
-    for (int j = 0; j < N_CEP; j++) {
+    for (int j = 0; j < mPriv->ncep; j++) {
       ds_input[idx + j + contextSize] = mfcc[mfcc_idx + j];
     }
 
     // Future context
-    for (int j = 1; j <= N_CONTEXT; j++) {
+    for (int j = 1; j <= mPriv->ncontext; j++) {
       int frame_index = (i * 2) + (j * 2);
       if (frame_index >= n_frames) { continue; }
-      int mfcc_base = frame_index * aCtx->ncep;
-      int base = contextSize + N_CEP + ((j - 1) * N_CEP);
-      for (int k = 0; k < N_CEP; k++) {
+      int mfcc_base = frame_index * mPriv->ncep;
+      int base = contextSize + mPriv->ncep + ((j - 1) * mPriv->ncep);
+      for (int k = 0; k < mPriv->ncep; k++) {
         ds_input[idx + base + k] = mfcc[mfcc_base + k];
       }
     }
@@ -136,14 +132,33 @@
     ds_input[idx] = (float)((ds_input[idx] - mean) / stddev);
   }
 
-  *aMfcc = ds_input;
-  return ds_input_length;
+  if (aMfcc) {
+    *aMfcc = ds_input;
+  }
+  if (aNFrames) {
+    *aNFrames = ds_input_length;
+  }
+  if (aFrameLen) {
+    // The stride of each frame in aMfcc (ncep + 2 * ncep * ncontext), as
+    // infer() expects; returning contextSize here would under-report it.
+    *aFrameLen = frameSize;
+  }
 }
 
 char*
-DsInfer(DeepSpeechContext* aCtx, float* aMfcc, int aNFrames)
+DeepSpeech::infer(float* aMfcc, int aNFrames, int aFrameLen)
 {
-  const int frameSize = aCtx->ncep + (2 * aCtx->ncep * aCtx->ncontext);
+  if (!mPriv->session) {
+    return nullptr;
+  }
+
+  const int frameSize = mPriv->ncep + (2 * mPriv->ncep * mPriv->ncontext);
+
+  if (aFrameLen == 0) {
+    aFrameLen = frameSize;
+  } else if (aFrameLen < frameSize) {
+    std::cerr << "mfcc features array is too small (expected " <<
+      frameSize << ", got " << aFrameLen << ")\n";
+    return nullptr;
+  }
 
   Tensor input(DT_FLOAT, TensorShape({1, aNFrames, frameSize}));
   auto input_mapped = input.tensor<float, 3>();
@@ -151,18 +166,19 @@ DsInfer(DeepSpeechContext* aCtx, float* aMfcc, int aNFrames)
     for (int j = 0; j < frameSize; j++, idx++) {
       input_mapped(0, i, j) = aMfcc[idx];
     }
+    idx += (aFrameLen - frameSize);
   }
   Tensor n_frames(DT_INT32, TensorShape({1}));
   n_frames.scalar<int>()() = aNFrames;
 
   std::vector<Tensor> outputs;
-  Status status =
-    aCtx->session->Run({{ "input_node", input }, { "input_lengths", n_frames }},
-                       {"output_node"}, {}, &outputs);
+  Status status = mPriv->session->Run(
+    {{ "input_node", input }, { "input_lengths", n_frames }},
+    {"output_node"}, {}, &outputs);
 
   if (!status.ok()) {
     std::cerr << "Error running session: " << status.ToString() << "\n";
-    return NULL;
+    return nullptr;
   }
 
   // Output is an array of shape (1, n_results, result_length).
@@ -180,14 +196,14 @@
 }
 
 char*
-DsSTT(DeepSpeechContext* aCtx, const short* aBuffer, size_t aBufferSize,
-      int aSampleRate)
+DeepSpeech::stt(const short* aBuffer, unsigned int aBufferSize, int aSampleRate)
 {
   float* mfcc;
   char* string;
-  int n_frames =
-    DsGetMfccFrames(aCtx, aBuffer, aBufferSize, aSampleRate, &mfcc);
-  string = DsInfer(aCtx, mfcc, n_frames);
+  int n_frames;
+
+  getMfccFrames(aBuffer, aBufferSize, aSampleRate, &mfcc, &n_frames, nullptr);
+  string = infer(mfcc, n_frames);
   free(mfcc);
   return string;
 }
diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h
index 555e52a2..3b6c5cda 100644
--- a/native_client/deepspeech.h
+++ b/native_client/deepspeech.h
@@ -2,73 +2,81 @@
 #ifndef __DEEPSPEECH_H__
 #define __DEEPSPEECH_H__
 
-typedef struct _DeepSpeechContext DeepSpeechContext;
+typedef struct _DeepSpeechPrivate DeepSpeechPrivate;
 
-/**
- * @brief Initialise a DeepSpeech context.
- *
- * @param aModelPath The path to the frozen model graph.
- * @param aNCep The number of cepstrum the model was trained with.
- * @param aNContext The context window the model was trained with.
- *
- * @return A DeepSpeech context.
- */
-DeepSpeechContext* DsInit(const char* aModelPath, int aNCep, int aNContext);
+class DeepSpeech {
+  private:
+    DeepSpeechPrivate* mPriv;
 
-/**
- * @brief De-initialise a DeepSpeech context.
- *
- * @param aCtx A DeepSpeech context.
- */
-void DsClose(DeepSpeechContext* aCtx);
+  public:
+    /**
+     * @brief Initialise a DeepSpeech context.
+     *
+     * @param aModelPath The path to the frozen model graph.
+     * @param aNCep The number of cepstrum the model was trained with.
+     * @param aNContext The context window the model was trained with.
+     */
+    DeepSpeech(const char* aModelPath, int aNCep, int aNContext);
+    ~DeepSpeech();
 
-/**
- * @brief Extract MFCC features from a given audio signal and add context.
- *
- * Extracts MFCC features from a given audio signal and adds the appropriate
- * amount of context to run inference with the given DeepSpeech context.
- *
- * @param aCtx A DeepSpeech context.
- * @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample
- *                rate.
- * @param aBufferSize The sample-length of the audio signal.
- * @param aSampleRate The sample-rate of the audio signal.
- * @param[out] aMFCC An array containing features, of shape
- *                   (frames, ncep * ncontext). The user is responsible for
- *                   freeing the array.
- *
- * @return The number of frames in @p aMFCC.
- */
-int DsGetMfccFrames(DeepSpeechContext* aCtx, const short* aBuffer,
-                    size_t aBufferSize, int aSampleRate, float** aMfcc);
+    /**
+     * @brief Extract MFCC features from a given audio signal and add context.
+     *
+     * Extracts MFCC features from a given audio signal and adds the
+     * appropriate amount of context to run inference with this DeepSpeech
+     * context.
+     *
+     * @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample
+     *                rate.
+     * @param aBufferSize The sample-length of the audio signal.
+     * @param aSampleRate The sample-rate of the audio signal.
+     * @param[out] aMfcc An array containing features, of shape
+     *                   (@p aNFrames, ncep + 2 * ncep * ncontext). The user
+     *                   is responsible for freeing the array.
+     * @param[out] aNFrames (optional) The number of frames in @p aMfcc.
+     * @param[out] aFrameLen (optional) The length of each frame
+     *                       (ncep + 2 * ncep * ncontext) in @p aMfcc.
+     */
+    void getMfccFrames(const short* aBuffer,
+                       unsigned int aBufferSize,
+                       int aSampleRate,
+                       float** aMfcc,
+                       int* aNFrames = nullptr,
+                       int* aFrameLen = nullptr);
 
-/**
- * @brief Run inference on the given audio.
- *
- * Runs inference on the given MFCC audio features with the given DeepSpeech
- * context. See DsGetMfccFrames().
- *
- * @param aCtx A DeepSpeech context.
- * @param aMfcc MFCC features with the appropriate amount of context per frame.
- * @param aNFrames The number of frames in @p aMfcc.
- *
- * @return The resulting string after running inference. The user is
- *         responsible for freeing this string.
- */
-char* DsInfer(DeepSpeechContext* aCtx, float* aMfcc, int aNFrames);
+    /**
+     * @brief Run inference on the given audio.
+     *
+     * Runs inference on the given MFCC audio features with this DeepSpeech
+     * context. See getMfccFrames().
+     *
+     * @param aMfcc MFCC features with the appropriate amount of context per
+     *              frame.
+     * @param aNFrames The number of frames in @p aMfcc.
+     * @param aFrameLen (optional) The length of each frame in @p aMfcc. If
+     *                  specified, this will be used to verify the array is
+     *                  large enough.
+     *
+     * @return The resulting string after running inference. The user is
+     *         responsible for freeing this string.
+     */
+    char* infer(float* aMfcc,
+                int aNFrames,
+                int aFrameLen = 0);
 
-/**
- * @brief Use DeepSpeech to perform Speech-To-Text.
- *
- * @param aMfcc An MFCC features array.
- * @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample
- *                rate.
- * @param aBufferSize The number of samples in the audio signal.
- * @param aSampleRate The sample-rate of the audio signal.
- *
- * @return The STT result. The user is responsible for freeing this string.
- */
-char* DsSTT(DeepSpeechContext* aCtx, const short* aBuffer, size_t aBufferSize,
-            int aSampleRate);
+    /**
+     * @brief Use DeepSpeech to perform Speech-To-Text.
+     *
+     * @param aBuffer A 16-bit, mono raw audio signal at the appropriate sample
+     *                rate.
+     * @param aBufferSize The number of samples in the audio signal.
+     * @param aSampleRate The sample-rate of the audio signal.
+     *
+     * @return The STT result. The user is responsible for freeing this string.
+     */
+    char* stt(const short* aBuffer,
+              unsigned int aBufferSize,
+              int aSampleRate);
+};
 
 #endif /* __DEEPSPEECH_H__ */
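
For reference, a minimal caller sketch against the new class API follows. It is not part of the patch: the file name, the silence buffer, and the 16kHz sample rate are placeholder assumptions, while N_CEP and N_CONTEXT mirror the constants client.cc passes to the constructor, matching the 26/9 defaults this patch removes from deepspeech.cc.

// usage_sketch.cc -- illustrative only, not part of the patch above.
#include <cstdio>
#include <cstdlib>
#include <vector>

#include "deepspeech.h"

// Must match what the model was trained with (the defaults this patch
// removes from deepspeech.cc).
#define N_CEP 26
#define N_CONTEXT 9

int
main(int argc, char** argv)
{
  if (argc < 2) {
    fprintf(stderr, "Usage: %s <model.pb>\n", argv[0]);
    return 1;
  }

  // Construction replaces DsInit(). If the graph fails to load, the private
  // session stays null and infer()/stt() report it by returning nullptr,
  // where DsInit() used to return NULL.
  DeepSpeech ctx(argv[1], N_CEP, N_CONTEXT);

  // One second of 16-bit mono silence, standing in for real audio at an
  // assumed 16kHz sample rate.
  std::vector<short> audio(16000, 0);

  char* text = ctx.stt(audio.data(), static_cast<unsigned int>(audio.size()),
                       16000);
  if (text) {
    printf("%s\n", text);
    free(text); // the caller owns the returned string
  }

  return 0; // ~DeepSpeech() replaces DsClose()
}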