Refactor TF and TFLite model implementations into their own classes/files
parent e136b5299a · commit 6f953837fa
native_client/BUILD

@@ -70,9 +70,20 @@ tf_cc_shared_object(
     srcs = ["deepspeech.cc",
             "deepspeech.h",
             "alphabet.h",
+            "modelstate.h",
+            "modelstate.cc",
             "ds_version.h",
             "ds_graph_version.h"] +
-           DECODER_SOURCES,
+           DECODER_SOURCES +
+           select({
+               "//native_client:tflite": [
+                   "tflitemodelstate.h",
+                   "tflitemodelstate.cc"
+               ],
+               "//conditions:default": [
+                   "tfmodelstate.h",
+                   "tfmodelstate.cc"
+               ]}),
     copts = select({
         # -fvisibility=hidden is not required on Windows, MSVC hides all declarations by default
         "//tensorflow:windows": ["/w"],
native_client/deepspeech.cc

@@ -11,17 +11,14 @@
 #include "deepspeech.h"
 #include "alphabet.h"
+#include "modelstate.h"

 #include "native_client/ds_version.h"
 #include "native_client/ds_graph_version.h"

 #ifndef USE_TFLITE
-#include "tensorflow/core/public/session.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/util/memmapped_file_system.h"
-#else // USE_TFLITE
-#include "tensorflow/lite/model.h"
-#include "tensorflow/lite/kernels/register.h"
+#include "tfmodelstate.h"
+#else
+#include "tflitemodelstate.h"
 #endif // USE_TFLITE

 #include "ctcdecode/ctc_beam_search_decoder.h"
@@ -36,23 +33,9 @@
 #define LOGE(...)
 #endif // __ANDROID__

-//TODO: infer batch size from model/use dynamic batch size
-constexpr unsigned int BATCH_SIZE = 1;
-
-constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
-constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
-constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
-
-#ifndef USE_TFLITE
-using namespace tensorflow;
-#else
-using namespace tflite;
-#endif
-
 using std::vector;

-/* This is the actual implementation of the streaming inference API, with the
-   Model class just forwarding the calls to this class.
+/* This is the implementation of the streaming inference API.

    The streaming process uses three buffers that are fed eagerly as audio data
    is fed in. The buffers only hold the minimum amount of data needed to do a
@@ -75,17 +58,20 @@ using std::vector;
    API. When audio_buffer is full, features are computed from it and pushed to
    mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer.
    When batch_buffer is full, we do a single step through the acoustic model
-   and accumulate results in the DecoderState structure.
+   and accumulate the intermediate decoding state in the DecoderState structure.

-   When finishStream() is called, we decode the accumulated logits and return
-   the corresponding transcription.
+   When finishStream() is called, we return the corresponding transcription from
+   the current decoder state.
 */
 struct StreamingState {
-  vector<float> audio_buffer;
-  vector<float> mfcc_buffer;
-  vector<float> batch_buffer;
-  ModelState* model;
-  std::unique_ptr<DecoderState> decoder_state;
+  vector<float> audio_buffer_;
+  vector<float> mfcc_buffer_;
+  vector<float> batch_buffer_;
+  ModelState* model_;
+  std::unique_ptr<DecoderState> decoder_state_;

   StreamingState();
   ~StreamingState();

   void feedAudioContent(const short* buffer, unsigned int buffer_size);
   char* intermediateDecode();
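The buffer cascade the comment above describes is easiest to see end to end, so here is a minimal standalone sketch of the same eager three-buffer flow with made-up toy sizes (this is not the real API; the real sizes come from the model via audio_win_len_, mfcc_feats_per_timestep_ and n_steps_):

#include <cstdio>
#include <vector>

// Toy version of the audio_buffer_ -> mfcc_buffer_ -> batch_buffer_ cascade.
// All sizes are invented for illustration only.
int main() {
  const size_t win_len = 4, win_step = 2, feats = 3, steps = 2;
  std::vector<float> audio, mfcc, batch;
  for (int sample = 0; sample < 16; ++sample) {
    audio.push_back(sample * 0.1f);
    if (audio.size() == win_len) {
      // Full audio window: "compute" features, then shift left by one step.
      mfcc.insert(mfcc.end(), audio.begin(), audio.begin() + feats);
      audio.erase(audio.begin(), audio.begin() + win_step);
    }
    if (mfcc.size() == feats) {
      // Full timestep: copy into the batch buffer.
      batch.insert(batch.end(), mfcc.begin(), mfcc.end());
      mfcc.clear();
    }
    if (batch.size() == steps * feats) {
      // Full batch: this is where one acoustic-model step would run.
      std::printf("acoustic model step on %zu values\n", batch.size());
      batch.clear();
    }
  }
  return 0;
}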
@@ -100,133 +86,12 @@ struct StreamingState {
   void processBatch(const vector<float>& buf, unsigned int n_steps);
 };

-struct ModelState {
-#ifndef USE_TFLITE
-  MemmappedEnv* mmap_env;
-  Session* session;
-  GraphDef graph_def;
-#else // USE_TFLITE
-  std::unique_ptr<Interpreter> interpreter;
-  std::unique_ptr<FlatBufferModel> fbmodel;
-#endif // USE_TFLITE
-  unsigned int ncep;
-  unsigned int ncontext;
-  Alphabet* alphabet;
-  Scorer* scorer;
-  unsigned int beam_width;
-  unsigned int n_steps;
-  unsigned int n_context;
-  unsigned int n_features;
-  unsigned int mfcc_feats_per_timestep;
-  unsigned int sample_rate;
-  unsigned int audio_win_len;
-  unsigned int audio_win_step;
-
-#ifdef USE_TFLITE
-  size_t previous_state_size;
-  std::unique_ptr<float[]> previous_state_c_;
-  std::unique_ptr<float[]> previous_state_h_;
-
-  int input_node_idx;
-  int previous_state_c_idx;
-  int previous_state_h_idx;
-  int input_samples_idx;
-
-  int logits_idx;
-  int new_state_c_idx;
-  int new_state_h_idx;
-  int mfccs_idx;
-
-  std::vector<int> acoustic_exec_plan;
-  std::vector<int> mfcc_exec_plan;
-#endif
-
-  ModelState();
-  ~ModelState();
-
-  /**
-   * @brief Perform decoding of the logits, using basic CTC decoder or
-   *        CTC decoder with KenLM enabled
-   *
-   * @return String representing the decoded text.
-   */
-  char* decode(DecoderState* state);
-
-  /**
-   * @brief Perform decoding of the logits, using basic CTC decoder or
-   *        CTC decoder with KenLM enabled
-   *
-   * @return Vector of Output structs directly from the CTC decoder for additional processing.
-   */
-  vector<Output> decode_raw(DecoderState* state);
-
-  /**
-   * @brief Return character-level metadata including letter timings.
-   *
-   * @return Metadata struct containing MetadataItem structs for each character.
-   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
-   */
-  Metadata* decode_metadata(DecoderState* state);
-
-  /**
-   * @brief Do a single inference step in the acoustic model, with:
-   *          input=mfcc
-   *          input_lengths=[n_frames]
-   *
-   * @param mfcc batch input data
-   * @param n_frames number of timesteps in the data
-   *
-   * @param[out] output_logits Where to store computed logits.
-   */
-  void infer(const float* mfcc, unsigned int n_frames, vector<float>& logits_output);
-
-  void compute_mfcc(const vector<float>& audio_buffer, vector<float>& mfcc_output);
-};
-
-ModelState::ModelState()
-  :
-#ifndef USE_TFLITE
-    mmap_env(nullptr)
-  , session(nullptr)
-#else // USE_TFLITE
-    interpreter(nullptr)
-  , fbmodel(nullptr)
-#endif // USE_TFLITE
-  , ncep(0)
-  , ncontext(0)
-  , alphabet(nullptr)
-  , scorer(nullptr)
-  , beam_width(0)
-  , n_steps(-1)
-  , n_context(-1)
-  , n_features(-1)
-  , mfcc_feats_per_timestep(-1)
-  , sample_rate(DEFAULT_SAMPLE_RATE)
-  , audio_win_len(DEFAULT_WINDOW_LENGTH)
-  , audio_win_step(DEFAULT_WINDOW_STEP)
-#ifdef USE_TFLITE
-  , previous_state_size(0)
-  , previous_state_c_(nullptr)
-  , previous_state_h_(nullptr)
-#endif
+StreamingState::StreamingState()
 {
 }

-ModelState::~ModelState()
+StreamingState::~StreamingState()
 {
-#ifndef USE_TFLITE
-  if (session) {
-    Status status = session->Close();
-    if (!status.ok()) {
-      std::cerr << "Error closing TensorFlow session: " << status << std::endl;
-    }
-  }
-  delete mmap_env;
-#endif // USE_TFLITE
-
-  delete scorer;
-  delete alphabet;
 }

 template<typename T>
@@ -243,19 +108,19 @@ StreamingState::feedAudioContent(const short* buffer,
 {
   // Consume all the data that was passed in, processing full buffers if needed
   while (buffer_size > 0) {
-    while (buffer_size > 0 && audio_buffer.size() < model->audio_win_len) {
+    while (buffer_size > 0 && audio_buffer_.size() < model_->audio_win_len_) {
       // Convert i16 sample into f32
       float multiplier = 1.0f / (1 << 15);
-      audio_buffer.push_back((float)(*buffer) * multiplier);
+      audio_buffer_.push_back((float)(*buffer) * multiplier);
       ++buffer;
       --buffer_size;
     }

     // If the buffer is full, process and shift it
-    if (audio_buffer.size() == model->audio_win_len) {
-      processAudioWindow(audio_buffer);
+    if (audio_buffer_.size() == model_->audio_win_len_) {
+      processAudioWindow(audio_buffer_);
       // Shift data by one step
-      shift_buffer_left(audio_buffer, model->audio_win_step);
+      shift_buffer_left(audio_buffer_, model_->audio_win_step_);
     }

     // Repeat until buffer empty
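Two details of the loop above can be checked by hand: the i16 to f32 conversion scales by 1/32768, and at the defaults (16 kHz, 32 ms window, 20 ms step) the window constants work out to 512 and 320 samples. A standalone sanity check using those values from the constants above:

#include <cassert>
#include <cmath>

int main() {
  // i16 -> f32, exactly as in feedAudioContent above.
  const float multiplier = 1.0f / (1 << 15);
  assert(std::fabs(32767 * multiplier - 0.999969f) < 1e-5f);

  // DEFAULT_WINDOW_LENGTH / DEFAULT_WINDOW_STEP at the 16 kHz default.
  const unsigned int sample_rate = 16000;
  const unsigned int win_len  = sample_rate * 0.032;  // 512 samples = 32 ms
  const unsigned int win_step = sample_rate * 0.02;   // 320 samples = 20 ms
  assert(win_len == 512 && win_step == 320);
  return 0;
}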
@@ -265,21 +130,21 @@ StreamingState::feedAudioContent(const short* buffer,
 char*
 StreamingState::intermediateDecode()
 {
-  return model->decode(decoder_state.get());
+  return model_->decode(decoder_state_.get());
 }

 char*
 StreamingState::finishStream()
 {
   finalizeStream();
-  return model->decode(decoder_state.get());
+  return model_->decode(decoder_state_.get());
 }

 Metadata*
 StreamingState::finishStreamWithMetadata()
 {
   finalizeStream();
-  return model->decode_metadata(decoder_state.get());
+  return model_->decode_metadata(decoder_state_.get());
 }

 void
@@ -287,8 +152,8 @@ StreamingState::processAudioWindow(const vector<float>& buf)
 {
   // Compute MFCC features
   vector<float> mfcc;
-  mfcc.reserve(model->n_features);
-  model->compute_mfcc(buf, mfcc);
+  mfcc.reserve(model_->n_features_);
+  model_->compute_mfcc(buf, mfcc);
   pushMfccBuffer(mfcc);
 }

@@ -296,23 +161,23 @@ void
 StreamingState::finalizeStream()
 {
   // Flush audio buffer
-  processAudioWindow(audio_buffer);
+  processAudioWindow(audio_buffer_);

   // Add empty mfcc vectors at end of sample
-  for (int i = 0; i < model->n_context; ++i) {
+  for (int i = 0; i < model_->n_context_; ++i) {
     addZeroMfccWindow();
   }

   // Process final batch
-  if (batch_buffer.size() > 0) {
-    processBatch(batch_buffer, batch_buffer.size()/model->mfcc_feats_per_timestep);
+  if (batch_buffer_.size() > 0) {
+    processBatch(batch_buffer_, batch_buffer_.size()/model_->mfcc_feats_per_timestep_);
   }
 }

 void
 StreamingState::addZeroMfccWindow()
 {
-  vector<float> zero_buffer(model->n_features, 0.f);
+  vector<float> zero_buffer(model_->n_features_, 0.f);
   pushMfccBuffer(zero_buffer);
 }

@@ -332,15 +197,15 @@ StreamingState::pushMfccBuffer(const vector<float>& buf)
   auto end = buf.end();
   while (start != end) {
     // Copy from input buffer to mfcc_buffer, stopping if we have a full context window
-    start = copy_up_to_n(start, end, std::back_inserter(mfcc_buffer),
-                         model->mfcc_feats_per_timestep - mfcc_buffer.size());
-    assert(mfcc_buffer.size() <= model->mfcc_feats_per_timestep);
+    start = copy_up_to_n(start, end, std::back_inserter(mfcc_buffer_),
+                         model_->mfcc_feats_per_timestep_ - mfcc_buffer_.size());
+    assert(mfcc_buffer_.size() <= model_->mfcc_feats_per_timestep_);

     // If we have a full context window
-    if (mfcc_buffer.size() == model->mfcc_feats_per_timestep) {
-      processMfccWindow(mfcc_buffer);
+    if (mfcc_buffer_.size() == model_->mfcc_feats_per_timestep_) {
+      processMfccWindow(mfcc_buffer_);
       // Shift data by one step of one mfcc feature vector
-      shift_buffer_left(mfcc_buffer, model->n_features);
+      shift_buffer_left(mfcc_buffer_, model_->n_features_);
     }
   }
 }
@@ -352,14 +217,14 @@ StreamingState::processMfccWindow(const vector<float>& buf)
   auto end = buf.end();
   while (start != end) {
     // Copy from input buffer to batch_buffer, stopping if we have a full batch
-    start = copy_up_to_n(start, end, std::back_inserter(batch_buffer),
-                         model->n_steps * model->mfcc_feats_per_timestep - batch_buffer.size());
-    assert(batch_buffer.size() <= model->n_steps * model->mfcc_feats_per_timestep);
+    start = copy_up_to_n(start, end, std::back_inserter(batch_buffer_),
+                         model_->n_steps_ * model_->mfcc_feats_per_timestep_ - batch_buffer_.size());
+    assert(batch_buffer_.size() <= model_->n_steps_ * model_->mfcc_feats_per_timestep_);

     // If we have a full batch
-    if (batch_buffer.size() == model->n_steps * model->mfcc_feats_per_timestep) {
-      processBatch(batch_buffer, model->n_steps);
-      batch_buffer.resize(0);
+    if (batch_buffer_.size() == model_->n_steps_ * model_->mfcc_feats_per_timestep_) {
+      processBatch(batch_buffer_, model_->n_steps_);
+      batch_buffer_.resize(0);
     }
   }
 }
@@ -368,272 +233,27 @@ void
 StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
 {
   vector<float> logits;
-  model->infer(buf.data(), n_steps, logits);
+  model_->infer(buf.data(), n_steps, logits);

   const int cutoff_top_n = 40;
   const double cutoff_prob = 1.0;
-  const size_t num_classes = model->alphabet->GetSize() + 1; // +1 for blank
-  const int n_frames = logits.size() / (BATCH_SIZE * num_classes);
+  const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank
+  const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);

   // Convert logits to double
   vector<double> inputs(logits.begin(), logits.end());

   decoder_next(inputs.data(),
-               *model->alphabet,
-               decoder_state.get(),
+               *model_->alphabet_,
+               decoder_state_.get(),
                n_frames,
                num_classes,
                cutoff_prob,
                cutoff_top_n,
-               model->beam_width,
-               model->scorer);
+               model_->beam_width_,
+               model_->scorer_);
 }

-void
-ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
-{
-  const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
-
-#ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, n_features}));
-
-  auto input_mapped = input.flat<float>();
-  int i;
-  for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
-    input_mapped(i) = aMfcc[i];
-  }
-  for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
-    input_mapped(i) = 0.;
-  }
-
-  Tensor input_lengths(DT_INT32, TensorShape({1}));
-  input_lengths.scalar<int>()() = n_frames;
-
-  vector<Tensor> outputs;
-  Status status = session->Run(
-    {{"input_node", input}, {"input_lengths", input_lengths}},
-    {"logits"}, {}, &outputs);
-
-  if (!status.ok()) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  auto logits_mapped = outputs[0].flat<float>();
-  // The CTCDecoder works with log-probs.
-  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
-    logits_output.push_back(logits_mapped(t));
-  }
-#else // USE_TFLITE
-  // Feeding input_node
-  float* input_node = interpreter->typed_tensor<float>(input_node_idx);
-  {
-    int i;
-    for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
-      input_node[i] = aMfcc[i];
-    }
-    for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
-      input_node[i] = 0;
-    }
-  }
-
-  assert(previous_state_size > 0);
-
-  // Feeding previous_state_c, previous_state_h
-  memcpy(interpreter->typed_tensor<float>(previous_state_c_idx), previous_state_c_.get(), sizeof(float) * previous_state_size);
-  memcpy(interpreter->typed_tensor<float>(previous_state_h_idx), previous_state_h_.get(), sizeof(float) * previous_state_size);
-
-  interpreter->SetExecutionPlan(acoustic_exec_plan);
-  TfLiteStatus status = interpreter->Invoke();
-  if (status != kTfLiteOk) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  float* outputs = interpreter->typed_tensor<float>(logits_idx);
-
-  // The CTCDecoder works with log-probs.
-  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
-    logits_output.push_back(outputs[t]);
-  }
-
-  memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(new_state_c_idx), sizeof(float) * previous_state_size);
-  memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(new_state_h_idx), sizeof(float) * previous_state_size);
-#endif // USE_TFLITE
-}
-
-void
-ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
-{
-#ifndef USE_TFLITE
-  Tensor input(DT_FLOAT, TensorShape({audio_win_len}));
-  auto input_mapped = input.flat<float>();
-  int i;
-  for (i = 0; i < samples.size(); ++i) {
-    input_mapped(i) = samples[i];
-  }
-  for (; i < audio_win_len; ++i) {
-    input_mapped(i) = 0.f;
-  }
-
-  vector<Tensor> outputs;
-  Status status = session->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
-
-  if (!status.ok()) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  // The feature computation graph is hardcoded to one audio length for now
-  const int n_windows = 1;
-  assert(outputs[0].shape().num_elements() / n_features == n_windows);
-
-  auto mfcc_mapped = outputs[0].flat<float>();
-  for (int i = 0; i < n_windows * n_features; ++i) {
-    mfcc_output.push_back(mfcc_mapped(i));
-  }
-#else
-  // Feeding input_samples
-  float* input_samples = interpreter->typed_tensor<float>(input_samples_idx);
-  for (int i = 0; i < samples.size(); ++i) {
-    input_samples[i] = samples[i];
-  }
-
-  interpreter->SetExecutionPlan(mfcc_exec_plan);
-  TfLiteStatus status = interpreter->Invoke();
-  if (status != kTfLiteOk) {
-    std::cerr << "Error running session: " << status << "\n";
-    return;
-  }
-
-  // The feature computation graph is hardcoded to one audio length for now
-  int n_windows = 1;
-  TfLiteIntArray* out_dims = interpreter->tensor(mfccs_idx)->dims;
-  int num_elements = 1;
-  for (int i = 0; i < out_dims->size; ++i) {
-    num_elements *= out_dims->data[i];
-  }
-  assert(num_elements / n_features == n_windows);
-
-  float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
-  for (int i = 0; i < n_windows * n_features; ++i) {
-    mfcc_output.push_back(outputs[i]);
-  }
-#endif
-}
-
-char*
-ModelState::decode(DecoderState* state)
-{
-  vector<Output> out = ModelState::decode_raw(state);
-  return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
-}
-
-vector<Output>
-ModelState::decode_raw(DecoderState* state)
-{
-  vector<Output> out = decoder_decode(state, *alphabet, beam_width, scorer);
-
-  return out;
-}
-
-Metadata*
-ModelState::decode_metadata(DecoderState* state)
-{
-  vector<Output> out = decode_raw(state);
-
-  std::unique_ptr<Metadata> metadata(new Metadata());
-  metadata->num_items = out[0].tokens.size();
-  metadata->probability = out[0].probability;
-
-  std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());
-
-  // Loop through each character
-  for (int i = 0; i < out[0].tokens.size(); ++i) {
-    items[i].character = strdup(alphabet->StringFromLabel(out[0].tokens[i]).c_str());
-    items[i].timestep = out[0].timesteps[i];
-    items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step / sample_rate);
-
-    if (items[i].start_time < 0) {
-      items[i].start_time = 0;
-    }
-  }
-
-  metadata->items = items.release();
-  return metadata.release();
-}
-
-#ifdef USE_TFLITE
-int
-tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, const char* name)
-{
-  int rv = -1;
-
-  for (int i = 0; i < list.size(); ++i) {
-    const string& node_name = ctx->interpreter->tensor(list[i])->name;
-    if (node_name.compare(string(name)) == 0) {
-      rv = i;
-    }
-  }
-
-  assert(rv >= 0);
-  return rv;
-}
-
-int
-tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
-{
-  return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
-}
-
-int
-tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
-{
-  return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
-}
-
-void push_back_if_not_present(std::deque<int>& list, int value) {
-  if (std::find(list.begin(), list.end(), value) == list.end()) {
-    list.push_back(value);
-  }
-}
-
-// Backwards BFS on the node DAG. At each iteration we get the next tensor id
-// from the frontier list, then for each node which has that tensor id as an
-// output, add it to the parent list, and add its input tensors to the frontier
-// list. Because we start from the final tensor and work backwards to the inputs,
-// the parents list is constructed in reverse, adding elements to its front.
-std::vector<int>
-tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
-{
-  std::deque<int> parents;
-  std::deque<int> frontier;
-  frontier.push_back(tensor_id);
-  while (!frontier.empty()) {
-    int next_tensor_id = frontier.front();
-    frontier.pop_front();
-    // Find all nodes that have next_tensor_id as an output
-    for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
-      TfLiteNode node = interpreter->node_and_registration(node_id)->first;
-      // Search node outputs for the tensor we're looking for
-      for (int i = 0; i < node.outputs->size; ++i) {
-        if (node.outputs->data[i] == next_tensor_id) {
-          // This node is part of the parent tree, add it to the parent list and
-          // add its input tensors to the frontier list
-          parents.push_front(node_id);
-          for (int j = 0; j < node.inputs->size; ++j) {
-            push_back_if_not_present(frontier, node.inputs->data[j]);
-          }
-        }
-      }
-    }
-  }
-
-  return std::vector<int>(parents.begin(), parents.end());
-}
-
-#endif
-
 int
 DS_CreateModel(const char* aModelPath,
                unsigned int aNCep,
@@ -642,15 +262,6 @@ DS_CreateModel(const char* aModelPath,
                unsigned int aBeamWidth,
                ModelState** retval)
 {
-  std::unique_ptr<ModelState> model(new ModelState());
-#ifndef USE_TFLITE
-  model->mmap_env = new MemmappedEnv(Env::Default());
-#endif // USE_TFLITE
-  model->ncep = aNCep;
-  model->ncontext = aNContext;
-  model->alphabet = new Alphabet(aAlphabetConfigPath);
-  model->beam_width = aBeamWidth;
-
   *retval = nullptr;

   DS_PrintVersions();
@@ -661,182 +272,23 @@ DS_CreateModel(const char* aModelPath,
 }

 #ifndef USE_TFLITE
-  Status status;
-  SessionOptions options;
-
-  bool is_mmap = std::string(aModelPath).find(".pbmm") != std::string::npos;
-  if (!is_mmap) {
-    std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
-  } else {
-    status = model->mmap_env->InitializeFromFile(aModelPath);
-    if (!status.ok()) {
-      std::cerr << status << std::endl;
-      return DS_ERR_FAIL_INIT_MMAP;
-    }
-
-    options.config.mutable_graph_options()
-      ->mutable_optimizer_options()
-      ->set_opt_level(::OptimizerOptions::L0);
-    options.env = model->mmap_env;
-  }
-
-  status = NewSession(options, &model->session);
-  if (!status.ok()) {
-    std::cerr << status << std::endl;
-    return DS_ERR_FAIL_INIT_SESS;
-  }
-
-  if (is_mmap) {
-    status = ReadBinaryProto(model->mmap_env,
-                             MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
-                             &model->graph_def);
-  } else {
-    status = ReadBinaryProto(Env::Default(), aModelPath, &model->graph_def);
-  }
-  if (!status.ok()) {
-    std::cerr << status << std::endl;
-    return DS_ERR_FAIL_READ_PROTOBUF;
-  }
-
-  status = model->session->Create(model->graph_def);
-  if (!status.ok()) {
-    std::cerr << status << std::endl;
-    return DS_ERR_FAIL_CREATE_SESS;
-  }
-
-  int graph_version = model->graph_def.version();
-  if (graph_version < DS_GRAPH_VERSION) {
-    std::cerr << "Specified model file version (" << graph_version << ") is "
-              << "incompatible with minimum version supported by this client ("
-              << DS_GRAPH_VERSION << "). See "
-              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
-              << "for more information" << std::endl;
-    return DS_ERR_MODEL_INCOMPATIBLE;
-  }
-
-  for (int i = 0; i < model->graph_def.node_size(); ++i) {
-    NodeDef node = model->graph_def.node(i);
-    if (node.name() == "input_node") {
-      const auto& shape = node.attr().at("shape").shape();
-      model->n_steps = shape.dim(1).size();
-      model->n_context = (shape.dim(2).size()-1)/2;
-      model->n_features = shape.dim(3).size();
-      model->mfcc_feats_per_timestep = shape.dim(2).size() * shape.dim(3).size();
-    } else if (node.name() == "logits_shape") {
-      Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
-      if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
-        continue;
-      }
-
-      int final_dim_size = logits_shape.vec<int>()(2) - 1;
-      if (final_dim_size != model->alphabet->GetSize()) {
-        std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
-                  << "has size " << model->alphabet->GetSize()
-                  << ", but model has " << final_dim_size
-                  << " classes in its output. Make sure you're passing an alphabet "
-                  << "file with the same size as the one used for training."
-                  << std::endl;
-        return DS_ERR_INVALID_ALPHABET;
-      }
-    } else if (node.name() == "model_metadata") {
-      int sample_rate = node.attr().at("sample_rate").i();
-      model->sample_rate = sample_rate;
-      int win_len_ms = node.attr().at("feature_win_len").i();
-      int win_step_ms = node.attr().at("feature_win_step").i();
-      model->audio_win_len = sample_rate * (win_len_ms / 1000.0);
-      model->audio_win_step = sample_rate * (win_step_ms / 1000.0);
-    }
-  }
-
-  if (model->n_context == -1 || model->n_features == -1) {
-    std::cerr << "Error: Could not infer input shape from model file. "
-              << "Make sure input_node is a 4D tensor with shape "
-              << "[batch_size=1, time, window_size, n_features]."
-              << std::endl;
-    return DS_ERR_INVALID_SHAPE;
-  }
-
-  *retval = model.release();
-  return DS_ERR_OK;
-#else // USE_TFLITE
-  model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
-  if (!model->fbmodel) {
-    std::cerr << "Error at reading model file " << aModelPath << std::endl;
-    return DS_ERR_FAIL_INIT_MMAP;
-  }
-
-  tflite::ops::builtin::BuiltinOpResolver resolver;
-  tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
-  if (!model->interpreter) {
-    std::cerr << "Error at InterpreterBuilder for model file " << aModelPath << std::endl;
-    return DS_ERR_FAIL_INTERPRETER;
-  }
-
-  model->interpreter->AllocateTensors();
-  model->interpreter->SetNumThreads(4);
-
-  // Query all the index once
-  model->input_node_idx = tflite_get_input_tensor_by_name(model.get(), "input_node");
-  model->previous_state_c_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_c");
-  model->previous_state_h_idx = tflite_get_input_tensor_by_name(model.get(), "previous_state_h");
-  model->input_samples_idx = tflite_get_input_tensor_by_name(model.get(), "input_samples");
-  model->logits_idx = tflite_get_output_tensor_by_name(model.get(), "logits");
-  model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
-  model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
-  model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
-
-  // When we call Interpreter::Invoke, the whole graph is executed by default,
-  // which means every time compute_mfcc is called the entire acoustic model is
-  // also executed. To workaround that problem, we walk up the dependency DAG
-  // from the mfccs output tensor to find all the relevant nodes required for
-  // feature computation, building an execution plan that runs just those nodes.
-  auto mfcc_plan = tflite_find_parent_node_ids(model->interpreter.get(), model->mfccs_idx);
-  auto orig_plan = model->interpreter->execution_plan();
-
-  // Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
-  auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
-    return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
-  });
-  orig_plan.erase(erase_begin, orig_plan.end());
-
-  model->acoustic_exec_plan = std::move(orig_plan);
-  model->mfcc_exec_plan = std::move(mfcc_plan);
-
-  TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
-
-  model->n_steps = dims_input_node->data[1];
-  model->n_context = (dims_input_node->data[2] - 1) / 2;
-  model->n_features = dims_input_node->data[3];
-  model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
-
-  TfLiteIntArray* dims_logits = model->interpreter->tensor(model->logits_idx)->dims;
-  const int final_dim_size = dims_logits->data[1] - 1;
-  if (final_dim_size != model->alphabet->GetSize()) {
-    std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
-              << "has size " << model->alphabet->GetSize()
-              << ", but model has " << final_dim_size
-              << " classes in its output. Make sure you're passing an alphabet "
-              << "file with the same size as the one used for training."
-              << std::endl;
-    return DS_ERR_INVALID_ALPHABET;
-  }
-
-  TfLiteIntArray* dims_c = model->interpreter->tensor(model->previous_state_c_idx)->dims;
-  TfLiteIntArray* dims_h = model->interpreter->tensor(model->previous_state_h_idx)->dims;
-  assert(dims_c->data[1] == dims_h->data[1]);
-
-  model->previous_state_size = dims_c->data[1];
-  model->previous_state_c_.reset(new float[model->previous_state_size]());
-  model->previous_state_h_.reset(new float[model->previous_state_size]());
-
-  // Set initial values for previous_state_c and previous_state_h
-  memset(model->previous_state_c_.get(), 0, sizeof(float) * model->previous_state_size);
-  memset(model->previous_state_h_.get(), 0, sizeof(float) * model->previous_state_size);
-
-  *retval = model.release();
-  return DS_ERR_OK;
+  std::unique_ptr<ModelState> model(new TFModelState());
+#else
+  std::unique_ptr<ModelState> model(new TFLiteModelState());
 #endif // USE_TFLITE
+
+  if (!model) {
+    std::cerr << "Could not allocate model state." << std::endl;
+    return DS_ERR_FAIL_CREATE_MODEL;
+  }
+
+  int err = model->init(aModelPath, aNCep, aNContext, aAlphabetConfigPath, aBeamWidth);
+  if (err != DS_ERR_OK) {
+    return err;
+  }
+
+  *retval = model.release();
+  return DS_ERR_OK;
 }

 void
@@ -854,10 +306,10 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
                        float aLMBeta)
 {
   try {
-    aCtx->scorer = new Scorer(aLMAlpha, aLMBeta,
-                              aLMPath ? aLMPath : "",
-                              aTriePath ? aTriePath : "",
-                              *aCtx->alphabet);
+    aCtx->scorer_ = new Scorer(aLMAlpha, aLMBeta,
+                               aLMPath ? aLMPath : "",
+                               aTriePath ? aTriePath : "",
+                               *aCtx->alphabet_);
     return DS_ERR_OK;
   } catch (...) {
     return DS_ERR_INVALID_LM;
@@ -872,13 +324,10 @@ DS_SetupStream(ModelState* aCtx,
 {
   *retval = nullptr;

-#ifndef USE_TFLITE
-  Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr);
-  if (!status.ok()) {
-    std::cerr << "Error running session: " << status << std::endl;
-    return DS_ERR_FAIL_RUN_SESS;
+  int err = aCtx->initialize_state();
+  if (err != DS_ERR_OK) {
+    return err;
   }
-#endif // USE_TFLITE

   std::unique_ptr<StreamingState> ctx(new StreamingState());
   if (!ctx) {
@@ -886,27 +335,20 @@ DS_SetupStream(ModelState* aCtx,
     return DS_ERR_FAIL_CREATE_STREAM;
   }

-  const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank
+  const size_t num_classes = aCtx->alphabet_->GetSize() + 1; // +1 for blank

   // Default initial allocation = 3 seconds.
   if (aPreAllocFrames == 0) {
     aPreAllocFrames = 150;
   }

-  ctx->audio_buffer.reserve(aCtx->audio_win_len);
-  ctx->mfcc_buffer.reserve(aCtx->mfcc_feats_per_timestep);
-  ctx->mfcc_buffer.resize(aCtx->n_features*aCtx->n_context, 0.f);
-  ctx->batch_buffer.reserve(aCtx->n_steps * aCtx->mfcc_feats_per_timestep);
-
-  ctx->model = aCtx;
-
-#ifdef USE_TFLITE
-  /* Ensure previous_state_{c,h} are not holding previous stream value */
-  memset(ctx->model->previous_state_c_.get(), 0, sizeof(float) * ctx->model->previous_state_size);
-  memset(ctx->model->previous_state_h_.get(), 0, sizeof(float) * ctx->model->previous_state_size);
-#endif // USE_TFLITE
-
-  ctx->decoder_state.reset(decoder_init(*aCtx->alphabet, num_classes, aCtx->scorer));
+  ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
+  ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
+  ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
+  ctx->batch_buffer_.reserve(aCtx->n_steps_ * aCtx->mfcc_feats_per_timestep_);
+  ctx->model_ = aCtx;
+
+  ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));

   *retval = ctx.release();
   return DS_ERR_OK;
@@ -1012,4 +454,3 @@ DS_PrintVersions() {
   LOGD("DeepSpeech: %s", ds_git_version());
 #endif
 }
-
native_client/deepspeech.h

@@ -52,6 +52,7 @@ enum DeepSpeech_Error_Codes
   DS_ERR_FAIL_CREATE_STREAM = 0x3004,
   DS_ERR_FAIL_READ_PROTOBUF = 0x3005,
   DS_ERR_FAIL_CREATE_SESS = 0x3006,
+  DS_ERR_FAIL_CREATE_MODEL = 0x3007,
 };

 /**
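For context, a sketch of how a caller would hit the new error code. This is not code from the commit: the file paths and hyperparameter values (26 cepstra, 9 context frames, beam width 500) are placeholders, and the parameter list is assembled from the DS_CreateModel signature fragments visible in this diff:

#include <cstdio>
#include "deepspeech.h"

// Hypothetical caller sketch; all argument values are placeholders.
int load(ModelState** ctx) {
  int err = DS_CreateModel("output_graph.pbmm", 26, 9, "alphabet.txt", 500, ctx);
  if (err != DS_ERR_OK) {
    // DS_ERR_FAIL_CREATE_MODEL (0x3007) is the new code: the ModelState
    // allocation itself failed, before init() was even attempted.
    std::fprintf(stderr, "DS_CreateModel failed: 0x%x\n", err);
  }
  return err;
}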
native_client/modelstate.cc (new file, 81 lines)

@@ -0,0 +1,81 @@
#include <vector>

#include "ctcdecode/ctc_beam_search_decoder.h"

#include "modelstate.h"

using std::vector;

ModelState::ModelState()
  : alphabet_(nullptr)
  , scorer_(nullptr)
  , beam_width_(-1)
  , n_steps_(-1)
  , n_context_(-1)
  , n_features_(-1)
  , mfcc_feats_per_timestep_(-1)
  , sample_rate_(DEFAULT_SAMPLE_RATE)
  , audio_win_len_(DEFAULT_WINDOW_LENGTH)
  , audio_win_step_(DEFAULT_WINDOW_STEP)
{
}

ModelState::~ModelState()
{
  delete scorer_;
  delete alphabet_;
}

int
ModelState::init(const char* model_path,
                 unsigned int n_features,
                 unsigned int n_context,
                 const char* alphabet_path,
                 unsigned int beam_width)
{
  n_features_ = n_features;
  n_context_ = n_context;
  alphabet_ = new Alphabet(alphabet_path);
  beam_width_ = beam_width;
  return DS_ERR_OK;
}

vector<Output>
ModelState::decode_raw(DecoderState* state)
{
  vector<Output> out = decoder_decode(state, *alphabet_, beam_width_, scorer_);
  return out;
}

char*
ModelState::decode(DecoderState* state)
{
  vector<Output> out = decode_raw(state);
  return strdup(alphabet_->LabelsToString(out[0].tokens).c_str());
}

Metadata*
ModelState::decode_metadata(DecoderState* state)
{
  vector<Output> out = decode_raw(state);

  std::unique_ptr<Metadata> metadata(new Metadata());
  metadata->num_items = out[0].tokens.size();
  metadata->probability = out[0].probability;

  std::unique_ptr<MetadataItem[]> items(new MetadataItem[metadata->num_items]());

  // Loop through each character
  for (int i = 0; i < out[0].tokens.size(); ++i) {
    items[i].character = strdup(alphabet_->StringFromLabel(out[0].tokens[i]).c_str());
    items[i].timestep = out[0].timesteps[i];
    items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);

    if (items[i].start_time < 0) {
      items[i].start_time = 0;
    }
  }

  metadata->items = items.release();
  return metadata.release();
}
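The start_time computation above has a simple reading: each decoder timestep covers audio_win_step_/sample_rate_ seconds, i.e. 320/16000 = 0.02 s at the defaults, so timestep 50 starts 1.0 s into the audio. A tiny standalone check of that arithmetic (default values hard-coded for illustration):

#include <cassert>
#include <cmath>

// Mirrors the items[i].start_time computation, using the default
// 16 kHz sample rate and 320-sample (20 ms) window step.
float start_time(int timestep) {
  const unsigned int audio_win_step = 320, sample_rate = 16000;
  float t = timestep * ((float)audio_win_step / sample_rate);
  return t < 0 ? 0 : t;
}

int main() {
  assert(std::fabs(start_time(50) - 1.0f) < 1e-6f);
  return 0;
}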
native_client/modelstate.h (new file, 88 lines)

@@ -0,0 +1,88 @@
#ifndef MODELSTATE_H
#define MODELSTATE_H

#include <vector>

#include "deepspeech.h"
#include "alphabet.h"

#include "ctcdecode/scorer.h"
#include "ctcdecode/output.h"
#include "ctcdecode/decoderstate.h"

struct ModelState {
  //TODO: infer batch size from model/use dynamic batch size
  static constexpr unsigned int BATCH_SIZE = 1;

  static constexpr unsigned int DEFAULT_SAMPLE_RATE = 16000;
  static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
  static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;

  Alphabet* alphabet_;
  Scorer* scorer_;
  unsigned int beam_width_;
  unsigned int n_steps_;
  unsigned int n_context_;
  unsigned int n_features_;
  unsigned int mfcc_feats_per_timestep_;
  unsigned int sample_rate_;
  unsigned int audio_win_len_;
  unsigned int audio_win_step_;

  ModelState();
  virtual ~ModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width);

  virtual int initialize_state() = 0;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;

  /**
   * @brief Do a single inference step in the acoustic model, with:
   *          input=mfcc
   *          input_lengths=[n_frames]
   *
   * @param mfcc batch input data
   * @param n_frames number of timesteps in the data
   *
   * @param[out] output_logits Where to store computed logits.
   */
  virtual void infer(const float* mfcc, unsigned int n_frames, std::vector<float>& logits_output) = 0;

  /**
   * @brief Perform decoding of the logits, using basic CTC decoder or
   *        CTC decoder with KenLM enabled
   *
   * @param state Decoder state to use when decoding.
   *
   * @return Vector of Output structs directly from the CTC decoder for additional processing.
   */
  virtual std::vector<Output> decode_raw(DecoderState* state);

  /**
   * @brief Perform decoding of the logits, using basic CTC decoder or
   *        CTC decoder with KenLM enabled
   *
   * @param state Decoder state to use when decoding.
   *
   * @return String representing the decoded text.
   */
  virtual char* decode(DecoderState* state);

  /**
   * @brief Return character-level metadata including letter timings.
   *
   * @param state Decoder state to use when decoding.
   *
   * @return Metadata struct containing MetadataItem structs for each character.
   *         The user is responsible for freeing Metadata by calling DS_FreeMetadata().
   */
  virtual Metadata* decode_metadata(DecoderState* state);
};

#endif // MODELSTATE_H
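Adding a third backend now means subclassing this interface and overriding the virtuals, while the shared decode/decode_raw/decode_metadata logic stays in the base class. A hypothetical skeleton (no such backend exists in this commit; the body comments describe what the two real backends do in their overrides):

#include "modelstate.h"

// Hypothetical example only -- illustrates the extension point, nothing more.
struct MyBackendModelState : public ModelState {
  virtual int init(const char* model_path, unsigned int n_features,
                   unsigned int n_context, const char* alphabet_path,
                   unsigned int beam_width) override {
    // Base class wires up alphabet_/beam_width_; the backend then loads its
    // graph and fills n_steps_, n_context_, n_features_, mfcc_feats_per_timestep_.
    int err = ModelState::init(model_path, n_features, n_context,
                               alphabet_path, beam_width);
    if (err != DS_ERR_OK) {
      return err;
    }
    return DS_ERR_OK;
  }

  virtual int initialize_state() override { return DS_ERR_OK; }

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override {}

  virtual void infer(const float* mfcc, unsigned int n_frames,
                     std::vector<float>& logits_output) override {}
};

This split between shared decoding and backend-specific inference is exactly what the commit carves out into modelstate.cc versus tfmodelstate.cc/tflitemodelstate.cc below.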
native_client/tflitemodelstate.cc (new file, 258 lines)

@@ -0,0 +1,258 @@
#include "tflitemodelstate.h"

using namespace tflite;
using std::vector;

int
tflite_get_tensor_by_name(const Interpreter* interpreter,
                          const vector<int>& list,
                          const char* name)
{
  int rv = -1;

  for (int i = 0; i < list.size(); ++i) {
    const string& node_name = interpreter->tensor(list[i])->name;
    if (node_name.compare(string(name)) == 0) {
      rv = i;
    }
  }

  assert(rv >= 0);
  return rv;
}

int
tflite_get_input_tensor_by_name(const Interpreter* interpreter, const char* name)
{
  int idx = tflite_get_tensor_by_name(interpreter, interpreter->inputs(), name);
  return interpreter->inputs()[idx];
}

int
tflite_get_output_tensor_by_name(const Interpreter* interpreter, const char* name)
{
  int idx = tflite_get_tensor_by_name(interpreter, interpreter->outputs(), name);
  return interpreter->outputs()[idx];
}

void push_back_if_not_present(std::deque<int>& list, int value)
{
  if (std::find(list.begin(), list.end(), value) == list.end()) {
    list.push_back(value);
  }
}

// Backwards BFS on the node DAG. At each iteration we get the next tensor id
// from the frontier list, then for each node which has that tensor id as an
// output, add it to the parent list, and add its input tensors to the frontier
// list. Because we start from the final tensor and work backwards to the inputs,
// the parents list is constructed in reverse, adding elements to its front.
std::vector<int>
tflite_find_parent_node_ids(Interpreter* interpreter, int tensor_id)
{
  std::deque<int> parents;
  std::deque<int> frontier;
  frontier.push_back(tensor_id);
  while (!frontier.empty()) {
    int next_tensor_id = frontier.front();
    frontier.pop_front();
    // Find all nodes that have next_tensor_id as an output
    for (int node_id = 0; node_id < interpreter->nodes_size(); ++node_id) {
      TfLiteNode node = interpreter->node_and_registration(node_id)->first;
      // Search node outputs for the tensor we're looking for
      for (int i = 0; i < node.outputs->size; ++i) {
        if (node.outputs->data[i] == next_tensor_id) {
          // This node is part of the parent tree, add it to the parent list and
          // add its input tensors to the frontier list
          parents.push_front(node_id);
          for (int j = 0; j < node.inputs->size; ++j) {
            push_back_if_not_present(frontier, node.inputs->data[j]);
          }
        }
      }
    }
  }

  return std::vector<int>(parents.begin(), parents.end());
}

TFLiteModelState::TFLiteModelState()
  : ModelState()
  , interpreter_(nullptr)
  , fbmodel_(nullptr)
  , previous_state_size_(0)
  , previous_state_c_(nullptr)
  , previous_state_h_(nullptr)
{
}

int
TFLiteModelState::init(const char* model_path,
                       unsigned int n_features,
                       unsigned int n_context,
                       const char* alphabet_path,
                       unsigned int beam_width)
{
  int err = ModelState::init(model_path, n_features, n_context, alphabet_path, beam_width);
  if (err != DS_ERR_OK) {
    return err;
  }

  fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!fbmodel_) {
    std::cerr << "Error at reading model file " << model_path << std::endl;
    return DS_ERR_FAIL_INIT_MMAP;
  }

  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder(*fbmodel_, resolver)(&interpreter_);
  if (!interpreter_) {
    std::cerr << "Error at InterpreterBuilder for model file " << model_path << std::endl;
    return DS_ERR_FAIL_INTERPRETER;
  }

  interpreter_->AllocateTensors();
  interpreter_->SetNumThreads(4);

  // Query all the indices once
  input_node_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_node");
  previous_state_c_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_c");
  previous_state_h_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "previous_state_h");
  input_samples_idx_ = tflite_get_input_tensor_by_name(interpreter_.get(), "input_samples");
  logits_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "logits");
  new_state_c_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_c");
  new_state_h_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "new_state_h");
  mfccs_idx_ = tflite_get_output_tensor_by_name(interpreter_.get(), "mfccs");

  // When we call Interpreter::Invoke, the whole graph is executed by default,
  // which means every time compute_mfcc is called the entire acoustic model is
  // also executed. To work around that problem, we walk up the dependency DAG
  // from the mfccs output tensor to find all the relevant nodes required for
  // feature computation, building an execution plan that runs just those nodes.
  auto mfcc_plan = tflite_find_parent_node_ids(interpreter_.get(), mfccs_idx_);
  auto orig_plan = interpreter_->execution_plan();

  // Remove MFCC nodes from original plan (all nodes) to create the acoustic model plan
  auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan](int elem) {
    return std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end();
  });
  orig_plan.erase(erase_begin, orig_plan.end());

  acoustic_exec_plan_ = std::move(orig_plan);
  mfcc_exec_plan_ = std::move(mfcc_plan);

  TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;

  n_steps_ = dims_input_node->data[1];
  n_context_ = (dims_input_node->data[2] - 1) / 2;
  n_features_ = dims_input_node->data[3];
  mfcc_feats_per_timestep_ = dims_input_node->data[2] * dims_input_node->data[3];

  TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
  const int final_dim_size = dims_logits->data[1] - 1;
  if (final_dim_size != alphabet_->GetSize()) {
    std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
              << "has size " << alphabet_->GetSize()
              << ", but model has " << final_dim_size
              << " classes in its output. Make sure you're passing an alphabet "
              << "file with the same size as the one used for training."
              << std::endl;
    return DS_ERR_INVALID_ALPHABET;
  }

  TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
  TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
  assert(dims_c->data[1] == dims_h->data[1]);

  previous_state_size_ = dims_c->data[1];
  previous_state_c_.reset(new float[previous_state_size_]());
  previous_state_h_.reset(new float[previous_state_size_]());

  // Set initial values for previous_state_c and previous_state_h
  memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
  memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);

  return DS_ERR_OK;
}

int
TFLiteModelState::initialize_state()
{
  /* Ensure previous_state_{c,h} are not holding previous stream value */
  memset(previous_state_c_.get(), 0, sizeof(float) * previous_state_size_);
  memset(previous_state_h_.get(), 0, sizeof(float) * previous_state_size_);

  return DS_ERR_OK;
}

void
TFLiteModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
{
  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

  // Feeding input_node
  float* input_node = interpreter_->typed_tensor<float>(input_node_idx_);
  {
    int i;
    for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
      input_node[i] = aMfcc[i];
    }
    for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
      input_node[i] = 0;
    }
  }

  assert(previous_state_size_ > 0);

  // Feeding previous_state_c, previous_state_h
  memcpy(interpreter_->typed_tensor<float>(previous_state_c_idx_), previous_state_c_.get(), sizeof(float) * previous_state_size_);
  memcpy(interpreter_->typed_tensor<float>(previous_state_h_idx_), previous_state_h_.get(), sizeof(float) * previous_state_size_);

  interpreter_->SetExecutionPlan(acoustic_exec_plan_);
  TfLiteStatus status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  float* outputs = interpreter_->typed_tensor<float>(logits_idx_);

  // The CTCDecoder works with log-probs.
  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
    logits_output.push_back(outputs[t]);
  }

  memcpy(previous_state_c_.get(), interpreter_->typed_tensor<float>(new_state_c_idx_), sizeof(float) * previous_state_size_);
  memcpy(previous_state_h_.get(), interpreter_->typed_tensor<float>(new_state_h_idx_), sizeof(float) * previous_state_size_);
}

void
TFLiteModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
  // Feeding input_samples
  float* input_samples = interpreter_->typed_tensor<float>(input_samples_idx_);
  for (int i = 0; i < samples.size(); ++i) {
    input_samples[i] = samples[i];
  }

  interpreter_->SetExecutionPlan(mfcc_exec_plan_);
  TfLiteStatus status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  int n_windows = 1;
  TfLiteIntArray* out_dims = interpreter_->tensor(mfccs_idx_)->dims;
  int num_elements = 1;
  for (int i = 0; i < out_dims->size; ++i) {
    num_elements *= out_dims->data[i];
  }
  assert(num_elements / n_features_ == n_windows);

  float* outputs = interpreter_->typed_tensor<float>(mfccs_idx_);
  for (int i = 0; i < n_windows * n_features_; ++i) {
    mfcc_output.push_back(outputs[i]);
  }
}
native_client/tflitemodelstate.h (new file, 51 lines)

@@ -0,0 +1,51 @@
#ifndef TFLITEMODELSTATE_H
#define TFLITEMODELSTATE_H

#include <memory>
#include <vector>

#include "tensorflow/lite/model.h"
#include "tensorflow/lite/kernels/register.h"

#include "modelstate.h"

struct TFLiteModelState : public ModelState
{
  std::unique_ptr<tflite::Interpreter> interpreter_;
  std::unique_ptr<tflite::FlatBufferModel> fbmodel_;

  size_t previous_state_size_;
  std::unique_ptr<float[]> previous_state_c_;
  std::unique_ptr<float[]> previous_state_h_;

  int input_node_idx_;
  int previous_state_c_idx_;
  int previous_state_h_idx_;
  int input_samples_idx_;

  int logits_idx_;
  int new_state_c_idx_;
  int new_state_h_idx_;
  int mfccs_idx_;

  std::vector<int> acoustic_exec_plan_;
  std::vector<int> mfcc_exec_plan_;

  TFLiteModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width) override;

  virtual int initialize_state() override;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override;

  virtual void infer(const float* mfcc, unsigned int n_frames,
                     std::vector<float>& logits_output) override;
};

#endif // TFLITEMODELSTATE_H
native_client/tfmodelstate.cc (new file, 214 lines)

@@ -0,0 +1,214 @@
#include "tfmodelstate.h"

#include "ds_graph_version.h"

using namespace tensorflow;
using std::vector;

TFModelState::TFModelState()
  : ModelState()
  , mmap_env_(nullptr)
  , session_(nullptr)
{
}

TFModelState::~TFModelState()
{
  if (session_) {
    Status status = session_->Close();
    if (!status.ok()) {
      std::cerr << "Error closing TensorFlow session: " << status << std::endl;
    }
  }
  delete mmap_env_;
}

int
TFModelState::init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width)
{
  int err = ModelState::init(model_path, n_features, n_context, alphabet_path, beam_width);
  if (err != DS_ERR_OK) {
    return err;
  }

  Status status;
  SessionOptions options;

  mmap_env_ = new MemmappedEnv(Env::Default());

  bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos;
  if (!is_mmap) {
    std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl;
  } else {
    status = mmap_env_->InitializeFromFile(model_path);
    if (!status.ok()) {
      std::cerr << status << std::endl;
      return DS_ERR_FAIL_INIT_MMAP;
    }

    options.config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_opt_level(::OptimizerOptions::L0);
    options.env = mmap_env_;
  }

  status = NewSession(options, &session_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_INIT_SESS;
  }

  if (is_mmap) {
    status = ReadBinaryProto(mmap_env_,
                             MemmappedFileSystem::kMemmappedPackageDefaultGraphDef,
                             &graph_def_);
  } else {
    status = ReadBinaryProto(Env::Default(), model_path, &graph_def_);
  }
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_READ_PROTOBUF;
  }

  status = session_->Create(graph_def_);
  if (!status.ok()) {
    std::cerr << status << std::endl;
    return DS_ERR_FAIL_CREATE_SESS;
  }

  int graph_version = graph_def_.version();
  if (graph_version < DS_GRAPH_VERSION) {
    std::cerr << "Specified model file version (" << graph_version << ") is "
              << "incompatible with minimum version supported by this client ("
              << DS_GRAPH_VERSION << "). See "
              << "https://github.com/mozilla/DeepSpeech/#model-compatibility "
              << "for more information" << std::endl;
    return DS_ERR_MODEL_INCOMPATIBLE;
  }

  for (int i = 0; i < graph_def_.node_size(); ++i) {
    NodeDef node = graph_def_.node(i);
    if (node.name() == "input_node") {
      const auto& shape = node.attr().at("shape").shape();
      n_steps_ = shape.dim(1).size();
      n_context_ = (shape.dim(2).size()-1)/2;
      n_features_ = shape.dim(3).size();
      mfcc_feats_per_timestep_ = shape.dim(2).size() * shape.dim(3).size();
    } else if (node.name() == "logits_shape") {
      Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
      if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
        continue;
      }

      int final_dim_size = logits_shape.vec<int>()(2) - 1;
      if (final_dim_size != alphabet_->GetSize()) {
        std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
                  << "has size " << alphabet_->GetSize()
                  << ", but model has " << final_dim_size
                  << " classes in its output. Make sure you're passing an alphabet "
                  << "file with the same size as the one used for training."
                  << std::endl;
        return DS_ERR_INVALID_ALPHABET;
      }
    } else if (node.name() == "model_metadata") {
      sample_rate_ = node.attr().at("sample_rate").i();
      int win_len_ms = node.attr().at("feature_win_len").i();
      int win_step_ms = node.attr().at("feature_win_step").i();
      audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
      audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
    }
  }

  if (n_context_ == -1 || n_features_ == -1) {
    std::cerr << "Error: Could not infer input shape from model file. "
              << "Make sure input_node is a 4D tensor with shape "
              << "[batch_size=1, time, window_size, n_features]."
              << std::endl;
    return DS_ERR_INVALID_SHAPE;
  }

  return DS_ERR_OK;
}

int
TFModelState::initialize_state()
{
  Status status = session_->Run({}, {}, {"initialize_state"}, nullptr);
  if (!status.ok()) {
    std::cerr << "Error running session: " << status << std::endl;
    return DS_ERR_FAIL_RUN_SESS;
  }

  return DS_ERR_OK;
}

void
TFModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logits_output)
{
  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank

  Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));

  auto input_mapped = input.flat<float>();
  int i;
  for (i = 0; i < n_frames*mfcc_feats_per_timestep_; ++i) {
    input_mapped(i) = aMfcc[i];
  }
  for (; i < n_steps_*mfcc_feats_per_timestep_; ++i) {
    input_mapped(i) = 0.;
  }

  Tensor input_lengths(DT_INT32, TensorShape({1}));
  input_lengths.scalar<int>()() = n_frames;

  vector<Tensor> outputs;
  Status status = session_->Run(
    {{"input_node", input}, {"input_lengths", input_lengths}},
    {"logits"}, {}, &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  auto logits_mapped = outputs[0].flat<float>();
  // The CTCDecoder works with log-probs.
  for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
    logits_output.push_back(logits_mapped(t));
  }
}

void
TFModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
{
  Tensor input(DT_FLOAT, TensorShape({audio_win_len_}));
  auto input_mapped = input.flat<float>();
  int i;
  for (i = 0; i < samples.size(); ++i) {
    input_mapped(i) = samples[i];
  }
  for (; i < audio_win_len_; ++i) {
    input_mapped(i) = 0.f;
  }

  vector<Tensor> outputs;
  Status status = session_->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);

  if (!status.ok()) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  const int n_windows = 1;
  assert(outputs[0].shape().num_elements() / n_features_ == n_windows);

  auto mfcc_mapped = outputs[0].flat<float>();
  for (int i = 0; i < n_windows * n_features_; ++i) {
    mfcc_output.push_back(mfcc_mapped(i));
  }
}
native_client/tfmodelstate.h (new file, 37 lines)

@@ -0,0 +1,37 @@
#ifndef TFMODELSTATE_H
#define TFMODELSTATE_H

#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h"

#include "modelstate.h"

struct TFModelState : public ModelState
{
  tensorflow::MemmappedEnv* mmap_env_;
  tensorflow::Session* session_;
  tensorflow::GraphDef graph_def_;

  TFModelState();
  virtual ~TFModelState();

  virtual int init(const char* model_path,
                   unsigned int n_features,
                   unsigned int n_context,
                   const char* alphabet_path,
                   unsigned int beam_width) override;

  virtual int initialize_state() override;

  virtual void infer(const float* mfcc,
                     unsigned int n_frames,
                     std::vector<float>& logits_output) override;

  virtual void compute_mfcc(const std::vector<float>& audio_buffer,
                            std::vector<float>& mfcc_output) override;
};

#endif // TFMODELSTATE_H