Refactor TF and TFLite model implementations into their own classes/files

parent: e136b5299a
commit: 6f953837fa
native_client/BUILD

@@ -70,9 +70,20 @@ tf_cc_shared_object(
     srcs = ["deepspeech.cc",
             "deepspeech.h",
             "alphabet.h",
+            "modelstate.h",
+            "modelstate.cc",
             "ds_version.h",
             "ds_graph_version.h"] +
-           DECODER_SOURCES,
+           DECODER_SOURCES +
+           select({
+                "//native_client:tflite": [
+                    "tflitemodelstate.h",
+                    "tflitemodelstate.cc"
+                ],
+                "//conditions:default": [
+                    "tfmodelstate.h",
+                    "tfmodelstate.cc"
+                ]}),
     copts = select({
         # -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
         "//tensorflow:windows": ["/w"],
native_client/deepspeech.cc

@@ -11,17 +11,14 @@
 
 #include "deepspeech.h"
 #include "alphabet.h"
+#include "modelstate.h"
 
 #include "native_client/ds_version.h"
-#include "native_client/ds_graph_version.h"
 
 #ifndef USE_TFLITE
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/util/memmapped_file_system.h"
-#else // USE_TFLITE
-#include "tensorflow/lite/model.h"
-#include "tensorflow/lite/kernels/register.h"
+#include "tfmodelstate.h"
+#else
+#include "tflitemodelstate.h"
 #endif // USE_TFLITE
 
 #include "ctcdecode/ctc_beam_search_decoder.h"
@@ -36,23 +33,9 @@
 #define LOGE(...)
 #endif // __ANDROID__
 
-//TODO: infer batch size from model/use dynamic batch size
-constexpr unsigned int BATCH_SIZE = 1;
-
-constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
-constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
-constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
-
-#ifndef USE_TFLITE
-using namespace tensorflow;
-#else
-using namespace tflite;
-#endif
-
 using std::vector;
 
-/* This is the actual implementation of the streaming inference API, with the
-   Model class just forwarding the calls to this class.
+/* This is the implementation of the streaming inference API.
 
    The streaming process uses three buffers that are fed eagerly as audio data
    is fed in. The buffers only hold the minimum amount of data needed to do a
@@ -75,17 +58,20 @@ using std::vector;
    API. When audio_buffer is full, features are computed from it and pushed to
    mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer.
    When batch_buffer is full, we do a single step through the acoustic model
-   and accumulate results in the DecoderState structure.
+   and accumulate the intermediate decoding state in the DecoderState structure.
 
-   When finishStream() is called, we decode the accumulated logits and return
-   the corresponding transcription.
+   When finishStream() is called, we return the corresponding transcription from
+   the current decoder state.
 */
 struct StreamingState {
-  vector<float> audio_buffer;
-  vector<float> mfcc_buffer;
-  vector<float> batch_buffer;
-  ModelState* model;
-  std::unique_ptr<DecoderState> decoder_state;
+  vector<float> audio_buffer_;
+  vector<float> mfcc_buffer_;
+  vector<float> batch_buffer_;
+  ModelState* model_;
+  std::unique_ptr<DecoderState> decoder_state_;
 
+  StreamingState();
+  ~StreamingState();
+
   void feedAudioContent(const short* buffer, unsigned int buffer_size);
   char* intermediateDecode();
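To make the buffer sizes concrete, using only values visible in this commit: with the 16 kHz default sample rate, audio_buffer_ holds one 32 ms window of 16000 * 0.032 = 512 samples and is shifted left by 16000 * 0.02 = 320 samples (20 ms) after each window; one MFCC timestep occupies (2 * n_context + 1) * n_features floats (the mfcc_feats_per_timestep value derived from the input tensor shape later in the diff); and batch_buffer_ accumulates n_steps such timesteps before a single acoustic-model step runs. The actual n_steps, n_context, and n_features come from the loaded model file, not from this commit.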
@@ -100,133 +86,12 @@ struct StreamingState {
   void processBatch(const vector<float>& buf, unsigned int n_steps);
 };
 
-struct ModelState {
-#ifndef USE_TFLITE
-  MemmappedEnv* mmap_env;
-  Session* session;
-  GraphDef graph_def;
-#else // USE_TFLITE
-  std::unique_ptr<Interpreter> interpreter;
-  std::unique_ptr<FlatBufferModel> fbmodel;
-#endif // USE_TFLITE
-  unsigned int ncep;
-  unsigned int ncontext;
-  Alphabet* alphabet;
-  Scorer* scorer;
-  unsigned int beam_width;
-  unsigned int n_steps;
-  unsigned int n_context;
-  unsigned int n_features;
-  unsigned int mfcc_feats_per_timestep;
-  unsigned int sample_rate;
-  unsigned int audio_win_len;
-  unsigned int audio_win_step;
-
-#ifdef USE_TFLITE
-  size_t previous_state_size;
-  std::unique_ptr<float[]> previous_state_c_;
-  std::unique_ptr<float[]> previous_state_h_;
-
-  int input_node_idx;
-  int previous_state_c_idx;
-  int previous_state_h_idx;
-  int input_samples_idx;
-
-  int logits_idx;
-  int new_state_c_idx;
-  int new_state_h_idx;
-  int mfccs_idx;
-
-  std::vector<int> acoustic_exec_plan;
-  std::vector<int> mfcc_exec_plan;
-#endif
-
-  ModelState();
-  ~ModelState();
-
-  /**
-   * @brief Perform decoding of the logits, using basic CTC decoder or
-   *        CTC decoder with KenLM enabled
-   *
-   * @return String representing the decoded text.
-   */
-  char* decode(DecoderState* state);
-
-  /**
-   * @brief Perform decoding of the logits, using basic CTC decoder or
-   *        CTC decoder with KenLM enabled
-   *
-   * @return Vector of Output structs directly from the CTC decoder for additional processing.
-   */
-  vector<Output> decode_raw(DecoderState* state);
-
-  /**
-   * @brief Return character-level metadata including letter timings.
-   *
-   *
-   * @return Metadata struct containing MetadataItem structs for each character.
-   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
-   */
-  Metadata* decode_metadata(DecoderState* state);
-
-  /**
-   * @brief Do a single inference step in the acoustic model, with:
-   *        input=mfcc
-   *        input_lengths=[n_frames]
-   *
-   * @param mfcc batch input data
-   * @param n_frames number of timesteps in the data
-   *
-   * @param[out] output_logits Where to store computed logits.
-   */
-  void infer(const float* mfcc, unsigned int n_frames, vector<float>& logits_output);
-
-  void compute_mfcc(const vector<float>& audio_buffer, vector<float>& mfcc_output);
-};
-
-ModelState::ModelState()
-  :
-#ifndef USE_TFLITE
-    mmap_env(nullptr)
-  , session(nullptr)
-#else // USE_TFLITE
-    interpreter(nullptr)
-  , fbmodel(nullptr)
-#endif // USE_TFLITE
-  , ncep(0)
-  , ncontext(0)
-  , alphabet(nullptr)
-  , scorer(nullptr)
-  , beam_width(0)
-  , n_steps(-1)
-  , n_context(-1)
-  , n_features(-1)
-  , mfcc_feats_per_timestep(-1)
-  , sample_rate(DEFAULT_SAMPLE_RATE)
-  , audio_win_len(DEFAULT_WINDOW_LENGTH)
-  , audio_win_step(DEFAULT_WINDOW_STEP)
-#ifdef USE_TFLITE
-  , previous_state_size(0)
-  , previous_state_c_(nullptr)
-  , previous_state_h_(nullptr)
-#endif
+StreamingState::StreamingState()
 {
 }
 
-ModelState::~ModelState()
+StreamingState::~StreamingState()
 {
-#ifndef USE_TFLITE
-  if (session) {
-    Status status = session->Close();
-    if (!status.ok()) {
-      std::cerr << "Error closing TensorFlow session: " << status << std::endl;
-    }
-  }
-  delete mmap_env;
-#endif // USE_TFLITE
-
-  delete scorer;
-  delete alphabet;
 }
 
 template<typename T>
@@ -243,19 +108,19 @@ StreamingState::feedAudioContent(const short* buffer,
 {
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
-    while (buffer_size > 0 && audio_buffer.size() < model->audio_win_len) {
+    while (buffer_size > 0 && audio_buffer_.size() < model_->audio_win_len_) {
       // Convert i16 sample into f32
       float multiplier = 1.0f / (1 << 15);
-      audio_buffer.push_back((float)(*buffer) * multiplier);
+      audio_buffer_.push_back((float)(*buffer) * multiplier);
       ++buffer;
       --buffer_size;
     }
 
     // If the buffer is full, process and shift it
-    if (audio_buffer.size() == model->audio_win_len) {
-      processAudioWindow(audio_buffer);
+    if (audio_buffer_.size() == model_->audio_win_len_) {
+      processAudioWindow(audio_buffer_);
       // Shift data by one step
-      shift_buffer_left(audio_buffer, model->audio_win_step);
+      shift_buffer_left(audio_buffer_, model_->audio_win_step_);
     }
 
     // Repeat until buffer empty
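This hunk (and pushMfccBuffer/processMfccWindow below) relies on two helpers, shift_buffer_left and copy_up_to_n, whose definitions sit outside the changed lines; only the template<typename T> opener is visible at the hunk boundary above. A minimal sketch consistent with the call sites, not the verbatim upstream code:

#include <algorithm>
#include <iterator>
#include <vector>

// Drop the first `amount` elements of `buf`, keeping the tail at the front.
// Matches the call shape shift_buffer_left(audio_buffer_, model_->audio_win_step_).
void shift_buffer_left(std::vector<float>& buf, unsigned int amount)
{
  std::copy(buf.begin() + amount, buf.end(), buf.begin());
  buf.resize(buf.size() - amount);
}

// Copy at most `n` elements from [start, end) into `dest`; return where we stopped.
// Matches the call shape used in pushMfccBuffer()/processMfccWindow().
template<typename InputIt, typename OutputIt>
InputIt copy_up_to_n(InputIt start, InputIt end, OutputIt dest, size_t n)
{
  for (; start != end && n > 0; --n) {
    *dest++ = *start++;
  }
  return start;
}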
@@ -265,21 +130,21 @@ StreamingState::feedAudioContent(const short* buffer,
 char*
 StreamingState::intermediateDecode()
 {
-  return model->decode(decoder_state.get());
+  return model_->decode(decoder_state_.get());
 }
 
 char*
 StreamingState::finishStream()
 {
   finalizeStream();
-  return model->decode(decoder_state.get());
+  return model_->decode(decoder_state_.get());
 }
 
 Metadata*
 StreamingState::finishStreamWithMetadata()
 {
   finalizeStream();
-  return model->decode_metadata(decoder_state.get());
+  return model_->decode_metadata(decoder_state_.get());
 }
 
 void
@@ -287,8 +152,8 @@ StreamingState::processAudioWindow(const vector<float>& buf)
 {
   // Compute MFCC features
   vector<float> mfcc;
-  mfcc.reserve(model->n_features);
-  model->compute_mfcc(buf, mfcc);
+  mfcc.reserve(model_->n_features_);
+  model_->compute_mfcc(buf, mfcc);
   pushMfccBuffer(mfcc);
 }
 
@@ -296,23 +161,23 @@ void
 StreamingState::finalizeStream()
 {
   // Flush audio buffer
-  processAudioWindow(audio_buffer);
+  processAudioWindow(audio_buffer_);
 
   // Add empty mfcc vectors at end of sample
-  for (int i = 0; i < model->n_context; ++i) {
+  for (int i = 0; i < model_->n_context_; ++i) {
     addZeroMfccWindow();
   }
 
   // Process final batch
-  if (batch_buffer.size() > 0) {
-    processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
+  if (batch_buffer_.size() > 0) {
+    processBatch(batch_buffer_, batch_buffer_.size()/model_->mfcc_feats_per_timestep_);
   }
 }
 
 void
 StreamingState::addZeroMfccWindow()
 {
-  vector<float> zero_buffer(model->n_features, 0.f);
+  vector<float> zero_buffer(model_->n_features_, 0.f);
   pushMfccBuffer(zero_buffer);
 }
 
@@ -332,15 +197,15 @@ StreamingState::pushMfccBuffer(const vector<float>& buf)
   auto end = buf.end();
   while (start != end) {
     // Copy from input buffer to mfcc_buffer, stopping if we have a full context window
-    start = copy_up_to_n(start, end, std::back_inserter(mfcc_buffer),
-                         model->mfcc_feats_per_timestep - mfcc_buffer.size());
-    assert(mfcc_buffer.size() <= model->mfcc_feats_per_timestep);
+    start = copy_up_to_n(start, end, std::back_inserter(mfcc_buffer_),
+                         model_->mfcc_feats_per_timestep_ - mfcc_buffer_.size());
+    assert(mfcc_buffer_.size() <= model_->mfcc_feats_per_timestep_);
 
     // If we have a full context window
-    if (mfcc_buffer.size() == model->mfcc_feats_per_timestep) {
-      processMfccWindow(mfcc_buffer);
+    if (mfcc_buffer_.size() == model_->mfcc_feats_per_timestep_) {
+      processMfccWindow(mfcc_buffer_);
       // Shift data by one step of one mfcc feature vector
-      shift_buffer_left(mfcc_buffer, model->n_features);
+      shift_buffer_left(mfcc_buffer_, model_->n_features_);
     }
   }
 }
@@ -352,14 +217,14 @@ StreamingState::processMfccWindow(const vector<float>& buf)
   auto end = buf.end();
   while (start != end) {
     // Copy from input buffer to batch_buffer, stopping if we have a full batch
-    start = copy_up_to_n(start, end, std::back_inserter(batch_buffer),
-                         model->n_steps * model->mfcc_feats_per_timestep - batch_buffer.size());
-    assert(batch_buffer.size() <= model->n_steps * model->mfcc_feats_per_timestep);
+    start = copy_up_to_n(start, end, std::back_inserter(batch_buffer_),
+                         model_->n_steps_ * model_->mfcc_feats_per_timestep_ - batch_buffer_.size());
+    assert(batch_buffer_.size() <= model_->n_steps_ * model_->mfcc_feats_per_timestep_);
 
     // If we have a full batch
-    if (batch_buffer.size() == model->n_steps * model->mfcc_feats_per_timestep) {
-      processBatch(batch_buffer, model->n_steps);
-      batch_buffer.resize(0);
+    if (batch_buffer_.size() == model_->n_steps_ * model_->mfcc_feats_per_timestep_) {
+      processBatch(batch_buffer_, model_->n_steps_);
+      batch_buffer_.resize(0);
     }
   }
 }
@@ -368,272 +233,27 @@ void
 StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
 {
   vector<float> logits;
-  model->infer(buf.data(), n_steps, logits);
+  model_->infer(buf.data(), n_steps, logits);
 
   const int cutoff_top_n = 40;
   const double cutoff_prob = 1.0;
-  const size_t num_classes = model->alphabet->GetSize() + 1; // +1 for blank
-  const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
+  const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank
+  const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);
 
   // Convert logits to double
   vector<double> inputs(logits.begin(), logits.end());
 
   decoder_next(inputs.data(),
-               *model->alphabet,
-               decoder_state.get(),
+               *model_->alphabet_,
+               decoder_state_.get(),
                n_frames,
                num_classes,
                cutoff_prob,
               cutoff_top_n,
-               model->beam_width,
-               model->scorer);
+               model_->beam_width_,
+               model_->scorer_);
 }
 
-void
-ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
-{
-  const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
-
-#ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, n_features}));
-
-  auto input_mapped = input.flat<float>();
-  int i;
-  for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
-    input_mapped(i) = aMfcc[i];
-  }
-  for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
-    input_mapped(i) = 0.;
-  }
-
-  Tensor input_lengths(DT_INT32, TensorShape({1}));
-  input_lengths.scalar<int>()() = n_frames;
-
-  vector<Tensor> outputs;
-  Status status = session->Run(
-    {{"input_node", input}, {"input_lengths", input_lengths}},
-    {"logits"}, {}, &outputs);
-
-  if (!status.ok()) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  auto logits_mapped = outputs[0].flat<float>();
-  // The CTCDecoder works with log-probs.
-  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
-    logits_output.push_back(logits_mapped(t));
-  }
-#else // USE_TFLITE
-  // Feeding input_node
-  float* input_node = interpreter->typed_tensor<float>(input_node_idx);
-  {
-    int i;
-    for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
-      input_node[i] = aMfcc[i];
-    }
-    for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
-      input_node[i] = 0;
-    }
-  }
-
-  assert(previous_state_size > 0);
-
-  // Feeding previous_state_c, previous_state_h
-  memcpy(interpreter->typed_tensor<float>(previous_state_c_idx), previous_state_c_.get(), sizeof(float) * previous_state_size);
-  memcpy(interpreter->typed_tensor<float>(previous_state_h_idx), previous_state_h_.get(), sizeof(float) * previous_state_size);
-
-  interpreter->SetExecutionPlan(acoustic_exec_plan);
-  TfLiteStatus status = interpreter->Invoke();
-  if (status != kTfLiteOk) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  float* outputs = interpreter->typed_tensor<float>(logits_idx);
-
-  // The CTCDecoder works with log-probs.
-  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
-    logits_output.push_back(outputs[t]);
-  }
-
-  memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(new_state_c_idx), sizeof(float) * previous_state_size);
-  memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(new_state_h_idx), sizeof(float) * previous_state_size);
-#endif // USE_TFLITE
-}
-
-void
-ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
-{
-#ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({audio_win_len}));
-  auto input_mapped = input.flat<float>();
-  int i;
-  for (i = 0; i < samples.size(); ++i) {
-    input_mapped(i) = samples[i];
-  }
-  for (; i < audio_win_len; ++i) {
-    input_mapped(i) = 0.f;
-  }
-
-  vector<Tensor> outputs;
-  Status status = session->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
-
-  if (!status.ok()) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  // The feature computation graph is hardcoded to one audio length for now
-  const int n_windows = 1;
-  assert(outputs[0].shape().num_elemements() / n_features == n_windows);
-
-  auto mfcc_mapped = outputs[0].flat<float>();
-  for (int i = 0; i < n_windows * n_features; ++i) {
-    mfcc_output.push_back(mfcc_mapped(i));
-  }
-#else
-  // Feeding input_node
-  float* input_samples = interpreter->typed_tensor<float>(input_samples_idx);
-  for (int i = 0; i < samples.size(); ++i) {
-    input_samples[i] = samples[i];
-  }
-
-  interpreter->SetExecutionPlan(mfcc_exec_plan);
-  TfLiteStatus status = interpreter->Invoke();
-  if (status != kTfLiteOk) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  // The feature computation graph is hardcoded to one audio length for now
-  int n_windows = 1;
-  TfLiteIntArray* out_dims = interpreter->tensor(mfccs_idx)->dims;
-  int num_elements = 1;
-  for (int i = 0; i < out_dims->size; ++i) {
-    num_elements *= out_dims->data[i];
-  }
-  assert(num_elements / n_features == n_windows);
-
-  float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
-  for (int i = 0; i < n_windows * n_features; ++i) {
-    mfcc_output.push_back(outputs[i]);
-  }
-#endif
-}
-
-char*
-ModelState::decode(DecoderState* state)
-{
-  vector<Output> out = ModelState::decode_raw(state);
-  return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
-}
-
-vector<Output>
-ModelState::decode_raw(DecoderState* state)
-{
-  vector<Output> out = decoder_decode(state, *alphabet, beam_width, scorer);
-
-  return out;
-}
-
-Metadata*
-ModelState::decode_metadata(DecoderState* state)
-{
-  vector<Output> out = decode_raw(state);
-
-  std::unique_ptr<Metadata> metadata(new Metadata());
-  metadata->num_items = out[0].tokens.size();
-  metadata->probability = out[0].probability;
-
-  std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
-
-  // Loop through each character
-  for (int i = 0; i < out[0].tokens.size(); ++i) {
-    items[i].character = strdup(alphabet->StringFromLabel(out[0].tokens[i]).c_str());
-    items[i].timestep = out[0].timesteps[i];
-    items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step / sample_rate);
-
-    if (items[i].start_time < 0) {
-      items[i].start_time = 0;
-    }
-  }
-
-  metadata->items = items.release();
-  return metadata.release();
-}
-
-#ifdef USE_TFLITE
-int
-tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, const char* name)
-{
-  int rv = -1;
-
-  for (int i = 0; i < list.size(); ++i) {
-    const string& node_name = ctx->interpreter->tensor(list[i])->name;
-    if (node_name.compare(string(name)) == 0) {
-      rv = i;
-    }
-  }
-
-  assert(rv >= 0);
-  return rv;
-}
-
-int
-tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
-{
-  return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
-}
-
-int
-tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
-{
-  return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
-}
-
-void push_back_if_not_present(std::deque<int>& list, int value) {
-  if (std::find(list.begin(), list.end(), value) == list.end()) {
-    list.push_back(value);
-  }
-}
-
-// Backwards BFS on the node DAG. At each iteration we get the next tensor id
-// from the frontier list, then for each node which has that tensor id as an
-// output, add it to the parent list, and add its input tensors to the frontier
-// list. Because we start from the final tensor and work backwards to the inputs,
-// the parents list is constructed in reverse, adding elements to its front.
-std::vector<int>
-tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
-{
-  std::deque<int> parents;
-  std::deque<int> frontier;
-  frontier.push_back(tensor_id);
-  while (!frontier.empty()) {
-    int next_tensor_id = frontier.front();
-    frontier.pop_front();
-    // Find all nodes that have next_tensor_id as an output
-    for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
-      TfLiteNode node = interpreter->node_and_registration(node_id)->first;
-      // Search node outputs for the tensor we're looking for
-      for (int i = 0; i < node.outputs->size; ++i) {
-        if (node.outputs->data[i] == next_tensor_id) {
-          // This node is part of the parent tree, add it to the parent list and
-          // add its input tensors to the frontier list
-          parents.push_front(node_id);
-          for (int j = 0; j < node.inputs->size; ++j) {
-            push_back_if_not_present(frontier, node.inputs->data[j]);
-          }
-        }
-      }
-    }
-  }
-
-  return std::vector<int>(parents.begin(), parents.end());
-}
-
-#endif
-
 int
 DS_CreateModel(const char* aModelPath,
                unsigned int aNCep,
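Taken together, the hunks above exercise three ctc_beam_search_decoder entry points: decoder_init() (in DS_SetupStream, later in this diff), decoder_next() (in processBatch), and decoder_decode() (inside ModelState::decode_raw). A hedged sketch condensing exactly those call shapes into one function, for orientation only; it is not part of the commit:

#include <memory>
#include <vector>
#include "ctcdecode/ctc_beam_search_decoder.h"
#include "modelstate.h"

char* run_decoder_once(ModelState* model, const std::vector<double>& logits,
                       int n_frames)
{
  const size_t num_classes = model->alphabet_->GetSize() + 1; // +1 for blank

  // One DecoderState per stream (see DS_SetupStream).
  std::unique_ptr<DecoderState> state(
      decoder_init(*model->alphabet_, num_classes, model->scorer_));

  // Push one batch of acoustic-model output (see processBatch).
  decoder_next(logits.data(), *model->alphabet_, state.get(), n_frames,
               num_classes, /*cutoff_prob=*/1.0, /*cutoff_top_n=*/40,
               model->beam_width_, model->scorer_);

  // Materialize the best transcription (see ModelState::decode).
  return model->decode(state.get());
}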
@@ -642,15 +262,6 @@ DS_CreateModel(const char* aModelPath,
                unsigned int aBeamWidth,
                ModelState** retval)
 {
-  std::unique_ptr<ModelState> model(new ModelState());
-#ifndef USE_TFLITE
-  model->mmap_env = new MemmappedEnv(Env::Default());
-#endif // USE_TFLITE
-  model->ncep = aNCep;
-  model->ncontext = aNContext;
-  model->alphabet = new Alphabet(aAlphabetConfigPath);
-  model->beam_width = aBeamWidth;
-
   *retval = nullptr;
 
   DS_PrintVersions();
@@ -661,182 +272,23 @@ DS_CreateModel(const char* aModelPath,
   }
 
 #ifndef USE_TFLITE
-  Status status;
-  SessionOptions options;
-
-  bool is_mmap = std::string(aModelPath).find(".pbmm") != std::string::npos;
-  if (!is_mmap) {
-    std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
-  } else {
-    status = model->mmap_env->InitializeFromFile(aModelPath);
-    if (!status.ok()) {
-      std::cerr << status << std::endl;
-      return DS_ERR_FAIL_INIT_MMAP;
-    }
-
-    options.config.mutable_graph_options()
-      ->mutable_optimizer_options()
-      ->set_opt_level(::OptimizerOptions::L0);
-    options.env = model->mmap_env;
-  }
-
-  status = NewSession(options, &model->session);
-  if (!status.ok()) {
-    std::cerr << status << std::endl;
-    return DS_ERR_FAIL_INIT_SESS;
-  }
-
-  if (is_mmap) {
-    status = ReadBinaryProto(model->mmap_env,
-                             MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
-                             &model->graph_def);
-  } else {
-    status = ReadBinaryProto(Env::Default(), aModelPath, &model->graph_def);
-  }
-  if (!status.ok()) {
-    std::cerr << status << std::endl;
-    return DS_ERR_FAIL_READ_PROTOBUF;
-  }
-
-  status = model->session->Create(model->graph_def);
-  if (!status.ok()) {
-    std::cerr << status << std::endl;
-    return DS_ERR_FAIL_CREATE_SESS;
-  }
-
-  int graph_version = model->graph_def.version();
-  if (graph_version < DS_GRAPH_VERSION) {
-    std::cerr << "Specified model file version (" << graph_version << ") is "
-              << "incompatible with minimum version supported by this client ("
-              << DS_GRAPH_VERSION << "). See "
-              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
-              << "for more information" << std::endl;
-    return DS_ERR_MODEL_INCOMPATIBLE;
-  }
-
-  for (int i = 0; i < model->graph_def.node_size(); ++i) {
-    NodeDef node = model->graph_def.node(i);
-    if (node.name() == "input_node") {
-      const auto& shape = node.attr().at("shape").shape();
-      model->n_steps = shape.dim(1).size();
-      model->n_context = (shape.dim(2).size()-1)/2;
-      model->n_features = shape.dim(3).size();
-      model->mfcc_feats_per_timestep = shape.dim(2).size() * shape.dim(3).size();
-    } else if (node.name() == "logits_shape") {
-      Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
-      if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
-        continue;
-      }
-
-      int final_dim_size = logits_shape.vec<int>()(2) - 1;
-      if (final_dim_size != model->alphabet->GetSize()) {
-        std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
-                  << "has size " << model->alphabet->GetSize()
-                  << ", but model has " << final_dim_size
-                  << " classes in its output. Make sure you're passing an alphabet "
-                  << "file with the same size as the one used for training."
-                  << std::endl;
-        return DS_ERR_INVALID_ALPHABET;
-      }
-    } else if (node.name() == "model_metadata") {
-      int sample_rate = node.attr().at("sample_rate").i();
-      model->sample_rate = sample_rate;
-      int win_len_ms = node.attr().at("feature_win_len").i();
-      int win_step_ms = node.attr().at("feature_win_step").i();
-      model->audio_win_len = sample_rate * (win_len_ms / 1000.0);
-      model->audio_win_step = sample_rate * (win_step_ms / 1000.0);
-    }
-  }
-
-  if (model->n_context == -1 || model->n_features == -1) {
-    std::cerr << "Error: Could not infer input shape from model file. "
-              << "Make sure input_node is a 4D tensor with shape "
-              << "[batch_size=1, time, window_size, n_features]."
-              << std::endl;
-    return DS_ERR_INVALID_SHAPE;
-  }
-
-  *retval = model.release();
-  return DS_ERR_OK;
-#else // USE_TFLITE
-  model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
-  if (!model->fbmodel) {
-    std::cerr << "Error at reading model file " << aModelPath << std::endl;
-    return DS_ERR_FAIL_INIT_MMAP;
-  }
-
-  tflite::ops::builtin::BuiltinOpResolver resolver;
-  tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
-  if (!model->interpreter) {
-    std::cerr << "Error at InterpreterBuilder for model file " << aModelPath << std::endl;
-    return DS_ERR_FAIL_INTERPRETER;
-  }
-
-  model->interpreter->AllocateTensors();
-  model->interpreter->SetNumThreads(4);
-
-  // Query all the index once
-  model->input_node_idx = tflite_get_input_tensor_by_name(model.get(), "input_node");
-  model->previous_state_c_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_c");
-  model->previous_state_h_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_h");
-  model->input_samples_idx = tflite_get_input_tensor_by_name(model.get(), "input_samples");
-  model->logits_idx = tflite_get_output_tensor_by_name(model.get(), "logits");
-  model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
-  model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
-  model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
-
-  // When we call Interpreter::Invoke, the whole graph is executed by default,
-  // which means every time compute_mfcc is called the entire acoustic model is
-  // also executed. To workaround that problem, we walk up the dependency DAG
-  // from the mfccs output tensor to find all the relevant nodes required for
-  // feature computation, building an execution plan that runs just those nodes.
-  auto mfcc_plan = tflite_find_parent_node_ids(model->interpreter.get(), model->mfccs_idx);
-  auto orig_plan = model->interpreter->execution_plan();
-
-  // Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
-  auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
-    return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
-  });
-  orig_plan.erase(erase_begin, orig_plan.end());
-
-  model->acoustic_exec_plan = std::move(orig_plan);
-  model->mfcc_exec_plan = std::move(mfcc_plan);
-
-  TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
-
-  model->n_steps = dims_input_node->data[1];
-  model->n_context = (dims_input_node->data[2] - 1 ) / 2;
-  model->n_features = dims_input_node->data[3];
-  model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
-
-  TfLiteIntArray* dims_logits = model->interpreter->tensor(model->logits_idx)->dims;
-  const int final_dim_size = dims_logits->data[1] - 1;
-  if (final_dim_size != model->alphabet->GetSize()) {
-    std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
-              << "has size " << model->alphabet->GetSize()
-              << ", but model has " << final_dim_size
-              << " classes in its output. Make sure you're passing an alphabet "
-              << "file with the same size as the one used for training."
-              << std::endl;
-    return DS_ERR_INVALID_ALPHABET;
-  }
-
-  TfLiteIntArray* dims_c = model->interpreter->tensor(model->previous_state_c_idx)->dims;
-  TfLiteIntArray* dims_h = model->interpreter->tensor(model->previous_state_h_idx)->dims;
-  assert(dims_c->data[1] == dims_h->data[1]);
-
-  model->previous_state_size = dims_c->data[1];
-  model->previous_state_c_.reset(new float[model->previous_state_size]());
-  model->previous_state_h_.reset(new float[model->previous_state_size]());
-
-  // Set initial values for previous_state_c and previous_state_h
-  memset(model->previous_state_c_.get(), 0, sizeof(float) * model->previous_state_size);
-  memset(model->previous_state_h_.get(), 0, sizeof(float) * model->previous_state_size);
-
-  *retval = model.release();
-  return DS_ERR_OK;
+  std::unique_ptr<ModelState> model(new TFModelState());
+#else
+  std::unique_ptr<ModelState> model(new TFLiteModelState());
 #endif // USE_TFLITE
 
+  if (!model) {
+    std::cerr << "Could not allocate model state." << std::endl;
+    return DS_ERR_FAIL_CREATE_MODEL;
+  }
+
+  int err = model->init(aModelPath, aNCep, aNContext, aAlphabetConfigPath, aBeamWidth);
+  if (err != DS_ERR_OK) {
+    return err;
+  }
+
+  *retval = model.release();
+  return DS_ERR_OK;
 }
 
 void
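A hedged caller-side sketch of the refactored factory above. DS_ERR_OK and the parameter order come from this diff; the 26/9/500 values are the conventional DeepSpeech defaults for cepstral coefficients, context window, and beam width, assumed here purely for illustration:

#include <cstdio>
#include "deepspeech.h"

int load_model(const char* model_path, const char* alphabet_path,
               ModelState** model)
{
  // Backend selection (TF vs TFLite) happens at build time inside DS_CreateModel.
  int err = DS_CreateModel(model_path, 26, 9, alphabet_path, 500, model);
  if (err != DS_ERR_OK) {
    fprintf(stderr, "DS_CreateModel failed with error 0x%x\n", err);
  }
  return err;
}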
@@ -854,10 +306,10 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
                        float aLMBeta)
 {
   try {
-    aCtx->scorer = new Scorer(aLMAlpha, aLMBeta,
+    aCtx->scorer_ = new Scorer(aLMAlpha, aLMBeta,
                               aLMPath ? aLMPath : "",
                               aTriePath ? aTriePath : "",
-                              *aCtx->alphabet);
+                              *aCtx->alphabet_);
     return DS_ERR_OK;
   } catch (...) {
     return DS_ERR_INVALID_LM;
@@ -872,13 +324,10 @@ DS_SetupStream(ModelState* aCtx,
 {
   *retval = nullptr;
 
-#ifndef USE_TFLITE
-  Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr);
-  if (!status.ok()) {
-    std::cerr << "Error running session: " << status << std::endl;
-    return DS_ERR_FAIL_RUN_SESS;
+  int err = aCtx->initialize_state();
+  if (err != DS_ERR_OK) {
+    return err;
   }
-#endif // USE_TFLITE
 
   std::unique_ptr<StreamingState> ctx(new StreamingState());
   if (!ctx) {
@@ -886,27 +335,20 @@ DS_SetupStream(ModelState* aCtx,
     return DS_ERR_FAIL_CREATE_STREAM;
   }
 
-  const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank
+  const size_t num_classes = aCtx->alphabet_->GetSize() + 1; // +1 for blank
 
   // Default initial allocation = 3 seconds.
   if (aPreAllocFrames == 0) {
     aPreAllocFrames = 150;
   }
 
-  ctx->audio_buffer.reserve(aCtx->audio_win_len);
-  ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
-  ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
-  ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
+  ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
+  ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
+  ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
+  ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
+  ctx->model_ = aCtx;
 
-  ctx->model = aCtx;
-
-#ifdef USE_TFLITE
-  /* Ensure previous_state_{c,h} are not holding previous stream value */
-  memset(ctx->model->previous_state_c_.get(), 0, sizeof(float) * ctx->model->previous_state_size);
-  memset(ctx->model->previous_state_h_.get(), 0, sizeof(float) * ctx->model->previous_state_size);
-#endif // USE_TFLITE
-
-  ctx->decoder_state.reset(decoder_init(*aCtx->alphabet, num_classes, aCtx->scorer));
+  ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));
 
   *retval = ctx.release();
   return DS_ERR_OK;
@@ -1012,4 +454,3 @@ DS_PrintVersions() {
   LOGD("DeepSpeech: %s", ds_git_version());
 #endif
 }
-
native_client/deepspeech.h

@@ -52,6 +52,7 @@ enum DeepSpeech_Error_Codes
   DS_ERR_FAIL_CREATE_STREAM = 0x3004,
   DS_ERR_FAIL_READ_PROTOBUF = 0x3005,
   DS_ERR_FAIL_CREATE_SESS = 0x3006,
+  DS_ERR_FAIL_CREATE_MODEL = 0x3007,
 };
 
 /**
native_client/modelstate.cc (new file, 81 lines)

@@ -0,0 +1,81 @@
+#include <vector>
+
+#include "ctcdecode/ctc_beam_search_decoder.h"
+
+#include "modelstate.h"
+
+using std::vector;
+
+ModelState::ModelState()
+  : alphabet_(nullptr)
+  , scorer_(nullptr)
+  , beam_width_(-1)
+  , n_steps_(-1)
+  , n_context_(-1)
+  , n_features_(-1)
+  , mfcc_feats_per_timestep_(-1)
+  , sample_rate_(DEFAULT_SAMPLE_RATE)
+  , audio_win_len_(DEFAULT_WINDOW_LENGTH)
+  , audio_win_step_(DEFAULT_WINDOW_STEP)
+{
+}
+
+ModelState::~ModelState()
+{
+  delete scorer_;
+  delete alphabet_;
+}
+
+int
+ModelState::init(const char* model_path,
+                 unsigned int n_features,
+                 unsigned int n_context,
+                 const char* alphabet_path,
+                 unsigned int beam_width)
+{
+  n_features_ = n_features;
+  n_context_ = n_context;
+  alphabet_ = new Alphabet(alphabet_path);
+  beam_width_ = beam_width;
+  return DS_ERR_OK;
+}
+
+vector<Output>
+ModelState::decode_raw(DecoderState* state)
+{
+  vector<Output> out = decoder_decode(state, *alphabet_, beam_width_, scorer_);
+  return out;
+}
+
+char*
+ModelState::decode(DecoderState* state)
+{
+  vector<Output> out = decode_raw(state);
+  return strdup(alphabet_->LabelsToString(out[0].tokens).c_str());
+}
+
+Metadata*
+ModelState::decode_metadata(DecoderState* state)
+{
+  vector<Output> out = decode_raw(state);
+
+  std::unique_ptr<Metadata> metadata(new Metadata());
+  metadata->num_items = out[0].tokens.size();
+  metadata->probability = out[0].probability;
+
+  std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
+
+  // Loop through each character
+  for (int i = 0; i < out[0].tokens.size(); ++i) {
+    items[i].character = strdup(alphabet_->StringFromLabel(out[0].tokens[i]).c_str());
+    items[i].timestep = out[0].timesteps[i];
+    items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
+
+    if (items[i].start_time < 0) {
+      items[i].start_time = 0;
+    }
+  }
+
+  metadata->items = items.release();
+  return metadata.release();
+}
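The doc comment carried into modelstate.h says callers own the returned Metadata and must release it via DS_FreeMetadata(). An illustrative consumer follows; field types are inferred from the assignments in decode_metadata(), and the printing is not part of the commit:

#include <cstdio>
#include "deepspeech.h"

void dump_and_free(Metadata* metadata)
{
  printf("overall probability: %f\n", metadata->probability);
  for (int i = 0; i < metadata->num_items; ++i) {
    const MetadataItem& item = metadata->items[i];
    // start_time = timestep * audio_win_step_ / sample_rate_, clamped at 0.
    printf("%s\ttimestep=%d\t%.2fs\n", item.character, item.timestep, item.start_time);
  }
  DS_FreeMetadata(metadata); // caller owns the result, per the doc comment
}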
native_client/modelstate.h (new file, 88 lines)

@@ -0,0 +1,88 @@
+#ifndef MODELSTATE_H
+#define MODELSTATE_H
+
+#include <vector>
+
+#include "deepspeech.h"
+#include "alphabet.h"
+
+#include "ctcdecode/scorer.h"
+#include "ctcdecode/output.h"
+#include "ctcdecode/decoderstate.h"
+
+struct ModelState {
+  //TODO: infer batch size from model/use dynamic batch size
+  static constexpr unsigned int BATCH_SIZE = 1;
+
+  static constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
+  static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
+  static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
+
+  Alphabet* alphabet_;
+  Scorer* scorer_;
+  unsigned int beam_width_;
+  unsigned int n_steps_;
+  unsigned int n_context_;
+  unsigned int n_features_;
+  unsigned int mfcc_feats_per_timestep_;
+  unsigned int sample_rate_;
+  unsigned int audio_win_len_;
+  unsigned int audio_win_step_;
+
+  ModelState();
+  virtual ~ModelState();
+
+  virtual int init(const char* model_path,
+                   unsigned int n_features,
+                   unsigned int n_context,
+                   const char* alphabet_path,
+                   unsigned int beam_width);
+
+  virtual int initialize_state() = 0;
+
+  virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
+
+  /**
+   * @brief Do a single inference step in the acoustic model, with:
+   *        input=mfcc
+   *        input_lengths=[n_frames]
+   *
+   * @param mfcc batch input data
+   * @param n_frames number of timesteps in the data
+   *
+   * @param[out] output_logits Where to store computed logits.
+   */
+  virtual void infer(const float* mfcc, unsigned int n_frames, std::vector<float>& logits_output) = 0;
+
+  /**
+   * @brief Perform decoding of the logits, using basic CTC decoder or
+   *        CTC decoder with KenLM enabled
+   *
+   * @param state Decoder state to use when decoding.
+   *
+   * @return Vector of Output structs directly from the CTC decoder for additional processing.
+   */
+  virtual std::vector<Output> decode_raw(DecoderState* state);
+
+  /**
+   * @brief Perform decoding of the logits, using basic CTC decoder or
+   *        CTC decoder with KenLM enabled
+   *
+   * @param state Decoder state to use when decoding.
+   *
+   * @return String representing the decoded text.
+   */
+  virtual char* decode(DecoderState* state);
+
+  /**
+   * @brief Return character-level metadata including letter timings.
+   *
+   * @param state Decoder state to use when decoding.
+   *
+   * @return Metadata struct containing MetadataItem structs for each character.
+   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
+   */
+  virtual Metadata* decode_metadata(DecoderState* state);
+};
+
+#endif // MODELSTATE_H
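The header above fixes the contract every backend must implement. A skeletal subclass, only to show the override surface; the real backends are the TFModelState and TFLiteModelState classes this commit introduces:

#include <vector>
#include "modelstate.h"

// Illustrative stub backend; not part of the commit.
struct NullModelState : public ModelState {
  int init(const char* model_path, unsigned int n_features,
           unsigned int n_context, const char* alphabet_path,
           unsigned int beam_width) override {
    // The base class loads the alphabet and records the beam width.
    return ModelState::init(model_path, n_features, n_context,
                            alphabet_path, beam_width);
  }

  int initialize_state() override {
    return DS_ERR_OK; // nothing to reset in this stub
  }

  void compute_mfcc(const std::vector<float>& audio_buffer,
                    std::vector<float>& mfcc_output) override {
    mfcc_output.assign(n_features_, 0.f); // one window of zero features
  }

  void infer(const float* mfcc, unsigned int n_frames,
             std::vector<float>& logits_output) override {
    // A real backend appends n_frames * BATCH_SIZE * (alphabet size + 1) logits.
  }
};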
258
native_client/tflitemodelstate.cc
Normal file
258
native_client/tflitemodelstate.cc
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
#include "tflitemodelstate.h"
|
||||||
|
|
||||||
|
using namespace tflite;
|
||||||
|
using std::vector;
|
||||||
|
|
||||||
|
int
|
||||||
|
tflite_get_tensor_by_name(const Interpreter* interpreter,
|
||||||
|
const vector<int>& list,
|
||||||
|
const char* name)
|
||||||
|
{
|
||||||
|
int rv = -1;
|
||||||
|
|
||||||
|
for (int i = 0; i < list.size(); ++i) {
|
||||||
|
const string& node_name = interpreter->tensor(list[i])->name;
|
||||||
|
if (node_name.compare(string(name)) == 0) {
|
||||||
|
rv = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(rv >= 0);
|
||||||
|
return rv;
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
tflite_get_input_tensor_by_name(const Interpreter* interpreter, const char* name)
|
||||||
|
{
|
||||||
|
int idx = tflite_get_tensor_by_name(interpreter, interpreter->inputs(), name);
|
||||||
|
return interpreter->inputs()[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
tflite_get_output_tensor_by_name(const Interpreter* interpreter, const char* name)
|
||||||
|
{
|
||||||
|
int idx = tflite_get_tensor_by_name(interpreter, interpreter->outputs(), name);
|
||||||
|
return interpreter->outputs()[idx];
|
||||||
|
}
|
||||||
|
|
||||||
|
void push_back_if_not_present(std::deque<int>& list, int value)
|
||||||
|
{
|
||||||
|
if (std::find(list.begin(), list.end(), value) == list.end()) {
|
||||||
|
list.push_back(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backwards BFS on the node DAG. At each iteration we get the next tensor id
|
||||||
|
// from the frontier list, then for each node which has that tensor id as an
|
||||||
|
// output, add it to the parent list, and add its input tensors to the frontier
|
||||||
|
// list. Because we start from the final tensor and work backwards to the inputs,
|
||||||
|
// the parents list is constructed in reverse, adding elements to its front.
|
||||||
|
std::vector<int>
|
||||||
|
tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
|
||||||
|
{
|
||||||
|
std::deque<int> parents;
|
||||||
|
std::deque<int> frontier;
|
||||||
|
frontier.push_back(tensor_id);
|
||||||
|
while (!frontier.empty()) {
|
||||||
|
int next_tensor_id = frontier.front();
|
||||||
|
frontier.pop_front();
|
||||||
|
// Find all nodes that have next_tensor_id as an output
|
||||||
|
for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
|
||||||
|
TfLiteNode node = interpreter->node_and_registration(node_id)->first;
|
||||||
|
// Search node outputs for the tensor we're looking for
|
||||||
|
for (int i = 0; i < node.outputs->size; ++i) {
|
||||||
|
if (node.outputs->data[i] == next_tensor_id) {
|
||||||
|
// This node is part of the parent tree, add it to the parent list and
|
||||||
|
// add its input tensors to the frontier list
|
||||||
|
parents.push_front(node_id);
|
||||||
|
for (int j = 0; j < node.inputs->size; ++j) {
|
||||||
|
push_back_if_not_present(frontier, node.inputs->data[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::vector<int>(parents.begin(), parents.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
TFLiteModelState::TFLiteModelState()
|
||||||
|
: ModelState()
|
||||||
|
, interpreter_(nullptr)
|
||||||
|
, fbmodel_(nullptr)
|
||||||
|
, previous_state_size_(0)
|
||||||
|
, previous_state_c_(nullptr)
|
||||||
|
, previous_state_h_(nullptr)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
TFLiteModelState::init(const char* model_path,
|
||||||
|
unsigned int n_features,
|
||||||
|
unsigned int n_context,
|
||||||
|
const char* alphabet_path,
|
||||||
|
unsigned int beam_width)
|
||||||
|
{
|
||||||
|
int err = ModelState::init(model_path, n_features, n_context, alphabet_path, beam_width);
|
||||||
|
if (err != DS_ERR_OK) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_path);
|
||||||
|
if (!fbmodel_) {
|
||||||
|
std::cerr << "Error at reading model file " << model_path << std::endl;
|
||||||
|
return DS_ERR_FAIL_INIT_MMAP;
|
||||||
|
}
|
||||||
|
|
||||||
|
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||||
|
tflite::InterpreterBuilder(*fbmodel_, resolver)(&interpreter_);
|
||||||
|
if (!interpreter_) {
|
||||||
|
std::cerr << "Error at InterpreterBuilder for model file " << model_path << std::endl;
|
||||||
|
return DS_ERR_FAIL_INTERPRETER;
|
||||||
|
}
|
||||||
|
|
||||||
|
interpreter_->AllocateTensors();
|
||||||
|
interpreter_->SetNumThreads(4);
|
||||||
|
|
||||||
|
// Query all the index once
|
||||||
|
input_node_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_node");
|
||||||
|
previous_state_c_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_c");
|
||||||
|
previous_state_h_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_h");
|
||||||
|
input_samples_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_samples");
|
||||||
|
logits_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "logits");
|
||||||
|
new_state_c_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_c");
|
||||||
|
new_state_h_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_h");
|
||||||
|
mfccs_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "mfccs");
|
||||||
|
|
||||||
|
// When we call Interpreter::Invoke, the whole graph is executed by default,
|
||||||
|
// which means every time compute_mfcc is called the entire acoustic model is
|
||||||
|
// also executed. To workaround that problem, we walk up the dependency DAG
|
||||||
|
// from the mfccs output tensor to find all the relevant nodes required for
|
||||||
|
// feature computation, building an execution plan that runs just those nodes.
|
||||||
|
auto mfcc_plan = tflite_find_parent_node_ids(interpreter_.get(), mfccs_idx_);
|
||||||
|
auto orig_plan = interpreter_->execution_plan();
|
||||||
|
|
||||||
|
// Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
|
||||||
|
auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
|
||||||
|
return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
|
||||||
|
});
|
||||||
|
orig_plan.erase(erase_begin, orig_plan.end());
|
||||||
|
|
||||||
|
acoustic_exec_plan_ = std::move(orig_plan);
|
||||||
|
mfcc_exec_plan_ = std::move(mfcc_plan);
|
||||||
|
|
||||||
|
TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;
|
||||||
|
|
||||||
|
n_steps_ = dims_input_node->data[1];
|
||||||
|
n_context_ = (dims_input_node->data[2] - 1) / 2;
|
||||||
|
n_features_ = dims_input_node->data[3];
|
||||||
|
mfcc_feats_per_timestep_ = dims_input_node->data[2] * dims_input_node->data[3];
|
||||||
|
|
||||||
|
TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
|
||||||
|
const int final_dim_size = dims_logits->data[1] - 1;
|
||||||
|
if (final_dim_size != alphabet_->GetSize()) {
|
||||||
|
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||||
|
<< "has size " << alphabet_->GetSize()
|
||||||
|
<< ", but model has " << final_dim_size
|
||||||
|
<< " classes in its output. Make sure you're passing an alphabet "
|
||||||
|
<< "file with the same size as the one used for training."
|
||||||
|
<< std::endl;
|
||||||
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
|
||||||
|
TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
|
||||||
|
assert(dims_c->data[1] == dims_h->data[1]);
|
||||||
|
|
||||||
|
previous_state_size_ = dims_c->data[1];
|
||||||
|
previous_state_c_.reset(new float[previous_state_size_]());
|
||||||
|
previous_state_h_.reset(new float[previous_state_size_]());
|
||||||
|
|
||||||
|
// Set initial values for previous_state_c and previous_state_h
|
||||||
|
memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
|
||||||
|
memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);
|
||||||
|
|
||||||
|
return DS_ERR_OK;
|
||||||
|
}

int
TFLiteModelState::initialize_state()
{
  /* Ensure previous_state_{c,h} are not holding previous stream value */
  memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
  memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);

  return DS_ERR_OK;
}

void
TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
{
  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

  // Feeding input_node
  float* input_node = interpreter_->typed_tensor<float>(input_node_idx_);
  {
    int i;
    for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
      input_node[i] = aMfcc[i];
    }
    for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
      input_node[i] = 0;
    }
  }

  assert(previous_state_size_ > 0);

  // Feeding previous_state_c, previous_state_h
  memcpy(interpreter_->typed_tensor<float>(previous_state_c_idx_), previous_state_c_.get(), sizeof(float) * previous_state_size_);
  memcpy(interpreter_->typed_tensor<float>(previous_state_h_idx_), previous_state_h_.get(), sizeof(float) * previous_state_size_);

  interpreter_->SetExecutionPlan(acoustic_exec_plan_);
  TfLiteStatus status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  float* outputs = interpreter_->typed_tensor<float>(logits_idx_);

  // The CTCDecoder works with log-probs.
  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
    logits_output.push_back(outputs[t]);
  }
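
  // Carry the recurrent state forward: the state produced by this invocation
  // becomes the input state of the next infer() call in the stream.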
  memcpy(previous_state_c_.get(), interpreter_->typed_tensor<float>(new_state_c_idx_), sizeof(float) * previous_state_size_);
  memcpy(previous_state_h_.get(), interpreter_->typed_tensor<float>(new_state_h_idx_), sizeof(float) * previous_state_size_);
}

void
TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
  // Feeding input_node
  float* input_samples = interpreter_->typed_tensor<float>(input_samples_idx_);
  for (int i = 0; i < samples.size(); ++i) {
    input_samples[i] = samples[i];
  }

  interpreter_->SetExecutionPlan(mfcc_exec_plan_);
  TfLiteStatus status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  int n_windows = 1;
  TfLiteIntArray* out_dims = interpreter_->tensor(mfccs_idx_)->dims;
  int num_elements = 1;
  for (int i = 0; i < out_dims->size; ++i) {
    num_elements *= out_dims->data[i];
  }
  assert(num_elements / n_features_ == n_windows);

  float* outputs = interpreter_->typed_tensor<float>(mfccs_idx_);
  for (int i = 0; i < n_windows * n_features_; ++i) {
    mfcc_output.push_back(outputs[i]);
  }
}
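
// Editor's sketch: a hedged illustration (not part of this commit) of driving
// the shared ModelState interface; decode_one_window_sketch and audio_window
// are illustrative names. The real streaming code stacks context windows of
// MFCC frames before calling infer(); a single frame is used here only to
// show the call sequence.
int
decode_one_window_sketch(ModelState* model, const std::vector<float>& audio_window)
{
  int err = model->initialize_state(); // reset recurrent state for a new stream
  if (err != DS_ERR_OK) {
    return err;
  }
  std::vector<float> mfcc;
  model->compute_mfcc(audio_window, mfcc); // one window of samples -> n_features_ MFCCs
  std::vector<float> logits;
  model->infer(mfcc.data(), /*n_frames=*/1, logits); // logits then feed the CTC decoder
  return DS_ERR_OK;
}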
51 native_client/tflitemodelstate.h Normal file
@ -0,0 +1,51 @@
#ifndef TFLITEMODELSTATE_H
#define TFLITEMODELSTATE_H

#include <memory>
#include <vector>

#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"

#include "modelstate.h"

struct TFLiteModelState : public ModelState
{
  std::unique_ptr<tflite::Interpreter> interpreter_;
  std::unique_ptr<tflite::FlatBufferModel> fbmodel_;

  size_t previous_state_size_;
  std::unique_ptr<float[]> previous_state_c_;
  std::unique_ptr<float[]> previous_state_h_;

  int input_node_idx_;
  int previous_state_c_idx_;
  int previous_state_h_idx_;
  int input_samples_idx_;

  int logits_idx_;
  int new_state_c_idx_;
  int new_state_h_idx_;
  int mfccs_idx_;

  std::vector<int> acoustic_exec_plan_;
  std::vector<int> mfcc_exec_plan_;

  TFLiteModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width) override;

  virtual int initialize_state() override;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override;

  virtual void infer(const float* mfcc, unsigned int n_frames,
                     std::vector<float>& logits_output) override;
};

#endif // TFLITEMODELSTATE_H
214 native_client/tfmodelstate.cc Normal file
@ -0,0 +1,214 @@
#include "tfmodelstate.h"

#include "ds_graph_version.h"

using namespace tensorflow;
using std::vector;

TFModelState::TFModelState()
  : ModelState()
  , mmap_env_(nullptr)
  , session_(nullptr)
{
}

TFModelState::~TFModelState()
{
  if (session_) {
    Status status = session_->Close();
    if (!status.ok()) {
      std::cerr << "Error closing TensorFlow session: " << status << std::endl;
    }
  }
  delete mmap_env_;
}

int
TFModelState::init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width)
{
  int err = ModelState::init(model_path, n_features, n_context, alphabet_path, beam_width);
  if (err != DS_ERR_OK) {
    return err;
  }

  Status status;
  SessionOptions options;

  mmap_env_ = new MemmappedEnv(Env::Default());

  bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
  if (!is_mmap) {
    std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
  } else {
    status = mmap_env_->InitializeFromFile(model_path);
    if (!status.ok()) {
      std::cerr << status << std::endl;
      return DS_ERR_FAIL_INIT_MMAP;
    }

    options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(::OptimizerOptions::L0);
    options.env = mmap_env_;
  }
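
  // Note: a .pbmm file can be produced from a frozen .pb graph with
  // TensorFlow's convert_graphdef_memmapped_format tool; example invocation:
  //   convert_graphdef_memmapped_format --in_graph=model.pb --out_graph=model.pbmm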

  status = NewSession(options, &session_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_INIT_SESS;
  }

  if (is_mmap) {
    status = ReadBinaryProto(mmap_env_,
                             MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
                             &graph_def_);
  } else {
    status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
  }
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_READ_PROTOBUF;
  }

  status = session_->Create(graph_def_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_CREATE_SESS;
  }

  int graph_version = graph_def_.version();
  if (graph_version < DS_GRAPH_VERSION) {
    std::cerr << "Specified model file version (" << graph_version << ") is "
              << "incompatible with minimum version supported by this client ("
              << DS_GRAPH_VERSION << "). See "
              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
              << "for more information" << std::endl;
    return DS_ERR_MODEL_INCOMPATIBLE;
  }
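
  // Walk the graph nodes to recover the model geometry: the input shape,
  // the logits shape (used to validate the alphabet), and optional metadata
  // describing the sample rate and feature window.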
  for (int i = 0; i < graph_def_.node_size(); ++i) {
    NodeDef node = graph_def_.node(i);
    if (node.name() == "input_node") {
      const auto& shape = node.attr().at("shape").shape();
      n_steps_ = shape.dim(1).size();
      n_context_ = (shape.dim(2).size()-1)/2;
      n_features_ = shape.dim(3).size();
      mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
    } else if (node.name() == "logits_shape") {
      Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
      if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
        continue;
      }

      int final_dim_size = logits_shape.vec<int>()(2) - 1;
      if (final_dim_size != alphabet_->GetSize()) {
        std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
                  << "has size " << alphabet_->GetSize()
                  << ", but model has " << final_dim_size
                  << " classes in its output. Make sure you're passing an alphabet "
                  << "file with the same size as the one used for training."
                  << std::endl;
        return DS_ERR_INVALID_ALPHABET;
      }
    } else if (node.name() == "model_metadata") {
      sample_rate_ = node.attr().at("sample_rate").i();
      int win_len_ms = node.attr().at("feature_win_len").i();
      int win_step_ms = node.attr().at("feature_win_step").i();
      audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
      audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
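      // e.g. sample_rate_ = 16000 with win_len = 32 ms and win_step = 20 ms
      // gives audio_win_len_ = 512 and audio_win_step_ = 320 samples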
    }
  }

  if (n_context_ == -1 || n_features_ == -1) {
    std::cerr << "Error: Could not infer input shape from model file. "
              << "Make sure input_node is a 4D tensor with shape "
              << "[batch_size=1, time, window_size, n_features]."
              << std::endl;
    return DS_ERR_INVALID_SHAPE;
  }

  return DS_ERR_OK;
}

int
TFModelState::initialize_state()
{
  Status status = session_->Run({}, {}, {"initialize_state"}, nullptr);
  if (!status.ok()) {
    std::cerr << "Error running session: " << status << std::endl;
    return DS_ERR_FAIL_RUN_SESS;
  }

  return DS_ERR_OK;
}

void
TFModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
{
  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

  Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));

  auto input_mapped = input.flat<float>();
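  // Copy n_frames worth of features and zero-pad the rest of the fixed-size
  // n_steps batch expected by the graph.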
  int i;
  for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
    input_mapped(i) = aMfcc[i];
  }
  for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
    input_mapped(i) = 0.;
  }

  Tensor input_lengths(DT_INT32, TensorShape({1}));
  input_lengths.scalar<int>()() = n_frames;

  vector<Tensor> outputs;
  Status status = session_->Run(
    {{"input_node", input}, {"input_lengths", input_lengths}},
    {"logits"}, {}, &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  auto logits_mapped = outputs[0].flat<float>();
  // The CTCDecoder works with log-probs.
  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
    logits_output.push_back(logits_mapped(t));
  }
}

void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
  Tensor input(DT_FLOAT, TensorShape({audio_win_len_}));
  auto input_mapped = input.flat<float>();
  int i;
  for (i = 0; i < samples.size(); ++i) {
    input_mapped(i) = samples[i];
  }
  for (; i < audio_win_len_; ++i) {
    input_mapped(i) = 0.f;
  }

  vector<Tensor> outputs;
  Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  const int n_windows = 1;
  assert(outputs[0].shape().num_elements() / n_features_ == n_windows);

  auto mfcc_mapped = outputs[0].flat<float>();
  for (int i = 0; i < n_windows * n_features_; ++i) {
    mfcc_output.push_back(mfcc_mapped(i));
  }
}
37 native_client/tfmodelstate.h Normal file
@ -0,0 +1,37 @@
#ifndef TFMODELSTATE_H
#define TFMODELSTATE_H

#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"

#include "modelstate.h"

struct TFModelState : public ModelState
{
  tensorflow::MemmappedEnv* mmap_env_;
  tensorflow::Session* session_;
  tensorflow::GraphDef graph_def_;

  TFModelState();
  virtual ~TFModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width) override;

  virtual int initialize_state() override;

  virtual void infer(const float* mfcc,
                     unsigned int n_frames,
                     std::vector<float>& logits_output) override;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override;
};

#endif // TFMODELSTATE_H
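
// Editor's sketch (assumed shape; the actual factory code is not shown in
// this hunk) of how deepspeech.cc can pick a backend now that both
// implementations share the ModelState interface: the USE_TFLITE
// compile-time switch chooses the concrete class.
ModelState*
create_model_state_sketch()
{
#ifdef USE_TFLITE
  return new TFLiteModelState();
#else
  return new TFModelState();
#endif
}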