#include "tflitemodelstate.h"
#include "workspace_status.h"

#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/evaluation/utils.h"

#ifdef __ANDROID__
#include <android/log.h>
#define LOG_TAG "libstt"
#define LOGD(...) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__)
#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__)
#else
#define LOGD(...)
#define LOGE(...)
#endif // __ANDROID__

using namespace tflite;
using std::vector;

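// Return the index within `list` (the interpreter's inputs() or outputs()
// vector) of the tensor named `name`. Asserts if no tensor with that name
// exists.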
int
TFLiteModelState::get_tensor_by_name(const vector<int>& list,
                                     const char* name)
{
  int rv = -1;

  for (int i = 0; i < list.size(); ++i) {
    const string& node_name = interpreter_->tensor(list[i])->name;
    if (node_name.compare(string(name)) == 0) {
      rv = i;
    }
  }

  assert(rv >= 0);
  return rv;
}

int
TFLiteModelState::get_input_tensor_by_name(const char* name)
{
  int idx = get_tensor_by_name(interpreter_->inputs(), name);
  return interpreter_->inputs()[idx];
}

int
TFLiteModelState::get_output_tensor_by_name(const char* name)
{
  int idx = get_tensor_by_name(interpreter_->outputs(), name);
  return interpreter_->outputs()[idx];
}

void
push_back_if_not_present(std::deque<int>& list, int value)
{
  if (std::find(list.begin(), list.end(), value) == list.end()) {
    list.push_back(value);
  }
}

// Backwards BFS on the node DAG. At each iteration we get the next tensor id
// from the frontier list, then for each node which has that tensor id as an
// output, add it to the parent list, and add its input tensors to the frontier
// list. Because we start from the final tensor and work backwards to the inputs,
// the parents list is constructed in reverse, adding elements to its front.
vector<int>
TFLiteModelState::find_parent_node_ids(int tensor_id)
{
  std::deque<int> parents;
  std::deque<int> frontier;
  frontier.push_back(tensor_id);
  while (!frontier.empty()) {
    int next_tensor_id = frontier.front();
    frontier.pop_front();
    // Find all nodes that have next_tensor_id as an output
    for (int node_id = 0; node_id < interpreter_->nodes_size(); ++node_id) {
      TfLiteNode node = interpreter_->node_and_registration(node_id)->first;
      // Search node outputs for the tensor we're looking for
      for (int i = 0; i < node.outputs->size; ++i) {
        if (node.outputs->data[i] == next_tensor_id) {
          // This node is part of the parent tree, add it to the parent list and
          // add its input tensors to the frontier list
          parents.push_front(node_id);
          for (int j = 0; j < node.inputs->size; ++j) {
            push_back_if_not_present(frontier, node.inputs->data[j]);
          }
        }
      }
    }
  }

  return vector<int>(parents.begin(), parents.end());
}

TFLiteModelState::TFLiteModelState()
  : ModelState()
  , interpreter_(nullptr)
  , fbmodel_(nullptr)
{
}

TFLiteModelState::~TFLiteModelState()
{
}

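// Enumerate the TFLite delegates requested via the STT_TFLITE_DELEGATE
// environment variable ("gpu", "nnapi" or "hexagon"). Delegates are only
// probed on Android builds; elsewhere the returned map is empty.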
std::map<std::string, tflite::Interpreter::TfLiteDelegatePtr>
getTfliteDelegates()
{
  std::map<std::string, tflite::Interpreter::TfLiteDelegatePtr> delegates;

  const char* env_delegate_c = std::getenv("STT_TFLITE_DELEGATE");
  std::string env_delegate = (env_delegate_c != nullptr) ? env_delegate_c : "";

#ifdef __ANDROID__
  if (env_delegate == std::string("gpu")) {
    LOGD("Trying to get GPU delegate ...");
    // Try to get GPU delegate
    {
      tflite::Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateGPUDelegate();
      if (!delegate) {
        LOGD("GPU delegation not supported");
      } else {
        LOGD("GPU delegation supported");
        delegates.emplace("GPU", std::move(delegate));
      }
    }
  }

  if (env_delegate == std::string("nnapi")) {
    LOGD("Trying to get NNAPI delegate ...");
    // Try to get Android NNAPI delegate
    {
      tflite::Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateNNAPIDelegate();
      if (!delegate) {
        LOGD("NNAPI delegation not supported");
      } else {
        LOGD("NNAPI delegation supported");
        delegates.emplace("NNAPI", std::move(delegate));
      }
    }
  }

  if (env_delegate == std::string("hexagon")) {
    LOGD("Trying to get Hexagon delegate ...");
    // Try to get Android Hexagon delegate
    {
      const std::string libhexagon_path("/data/local/tmp");
      tflite::Interpreter::TfLiteDelegatePtr delegate = evaluation::CreateHexagonDelegate(libhexagon_path, /* profiler */ false);
      if (!delegate) {
        LOGD("Hexagon delegation not supported");
      } else {
        LOGD("Hexagon delegation supported");
        delegates.emplace("Hexagon", std::move(delegate));
      }
    }
  }
#endif // __ANDROID__

  return delegates;
}

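// Load the TFLite model from model_path, apply any requested delegates, read
// the metadata outputs (graph version, sample rate, feature window sizes,
// beam width, alphabet), and split the graph's execution plan into separate
// feature-computation and acoustic-model sub-plans.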
int
TFLiteModelState::init(const char* model_path)
{
  int err = ModelState::init(model_path);
  if (err != STT_ERR_OK) {
    return err;
  }

  fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!fbmodel_) {
    std::cerr << "Error reading model file " << model_path << std::endl;
    return STT_ERR_FAIL_INIT_MMAP;
  }

  tflite::ops::builtin::BuiltinOpResolver resolver;
  tflite::InterpreterBuilder(*fbmodel_, resolver)(&interpreter_);
  if (!interpreter_) {
    std::cerr << "Error at InterpreterBuilder for model file " << model_path << std::endl;
    return STT_ERR_FAIL_INTERPRETER;
  }

  LOGD("Trying to detect delegates ...");
  std::map<std::string, tflite::Interpreter::TfLiteDelegatePtr> delegates = getTfliteDelegates();
  LOGD("Finished enumerating delegates ...");

  interpreter_->AllocateTensors();
  interpreter_->SetNumThreads(4);

  LOGD("Trying to use delegates ...");
  for (const auto& delegate : delegates) {
    LOGD("Trying to apply delegate %s", delegate.first.c_str());
    if (interpreter_->ModifyGraphWithDelegate(delegate.second.get()) != kTfLiteOk) {
      LOGD("FAILED to apply delegate %s to the graph", delegate.first.c_str());
    }
  }

  // Query all the tensor indices once
  input_node_idx_ = get_input_tensor_by_name("input_node");
  previous_state_c_idx_ = get_input_tensor_by_name("previous_state_c");
  previous_state_h_idx_ = get_input_tensor_by_name("previous_state_h");
  input_samples_idx_ = get_input_tensor_by_name("input_samples");
  logits_idx_ = get_output_tensor_by_name("logits");
  new_state_c_idx_ = get_output_tensor_by_name("new_state_c");
  new_state_h_idx_ = get_output_tensor_by_name("new_state_h");
  mfccs_idx_ = get_output_tensor_by_name("mfccs");

  int metadata_version_idx = get_output_tensor_by_name("metadata_version");
  int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
  int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
  int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
  int metadata_beam_width_idx = get_output_tensor_by_name("metadata_beam_width");
  int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");

  // Build an execution plan containing just the parent nodes of the metadata
  // tensors, so they can be run on their own below.
  std::vector<int> metadata_exec_plan;
  metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
  metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
  metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
  metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
  metadata_exec_plan.push_back(find_parent_node_ids(metadata_beam_width_idx)[0]);
  metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);

  for (int i = 0; i < metadata_exec_plan.size(); ++i) {
    assert(metadata_exec_plan[i] > -1);
  }

  // When we call Interpreter::Invoke, the whole graph is executed by default,
  // which means every time compute_mfcc is called the entire acoustic model is
  // also executed. To work around that problem, we walk up the dependency DAG
  // from the mfccs output tensor to find all the relevant nodes required for
  // feature computation, building an execution plan that runs just those nodes.
  auto mfcc_plan = find_parent_node_ids(mfccs_idx_);
  auto orig_plan = interpreter_->execution_plan();

  // Remove MFCC and metadata nodes from the original plan (all nodes) to create the acoustic model plan
  auto erase_begin = std::remove_if(orig_plan.begin(), orig_plan.end(), [&mfcc_plan, &metadata_exec_plan](int elem) {
    return (std::find(mfcc_plan.begin(), mfcc_plan.end(), elem) != mfcc_plan.end()
            || std::find(metadata_exec_plan.begin(), metadata_exec_plan.end(), elem) != metadata_exec_plan.end());
  });
  orig_plan.erase(erase_begin, orig_plan.end());

  acoustic_exec_plan_ = std::move(orig_plan);
  mfcc_exec_plan_ = std::move(mfcc_plan);

  // Run only the metadata nodes so the metadata tensors read below are populated.
  interpreter_->SetExecutionPlan(metadata_exec_plan);
  TfLiteStatus status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return STT_ERR_FAIL_INTERPRETER;
  }

  int* const graph_version = interpreter_->typed_tensor<int>(metadata_version_idx);
  if (graph_version == nullptr) {
    std::cerr << "Unable to read model file version." << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  if (*graph_version < ds_graph_version()) {
    std::cerr << "Specified model file version (" << *graph_version << ") is "
              << "incompatible with minimum version supported by this client ("
              << ds_graph_version() << "). See "
              << "https://stt.readthedocs.io/en/latest/USING.html#model-compatibility "
              << "for more information" << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  int* const model_sample_rate = interpreter_->typed_tensor<int>(metadata_sample_rate_idx);
  if (model_sample_rate == nullptr) {
    std::cerr << "Unable to read model sample rate." << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  sample_rate_ = *model_sample_rate;

  int* const win_len_ms = interpreter_->typed_tensor<int>(metadata_feature_win_len_idx);
  int* const win_step_ms = interpreter_->typed_tensor<int>(metadata_feature_win_step_idx);
  if (win_len_ms == nullptr || win_step_ms == nullptr) {
    std::cerr << "Unable to read model feature window information." << std::endl;
    return STT_ERR_MODEL_INCOMPATIBLE;
  }

  audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
  audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);

  int* const beam_width = interpreter_->typed_tensor<int>(metadata_beam_width_idx);
  beam_width_ = (unsigned int)(*beam_width);

  tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
  err = alphabet_.Deserialize(serialized_alphabet.str, serialized_alphabet.len);
  if (err != 0) {
    return STT_ERR_INVALID_ALPHABET;
  }

  assert(sample_rate_ > 0);
  assert(audio_win_len_ > 0);
  assert(audio_win_step_ > 0);
  assert(beam_width_ > 0);
  assert(alphabet_.GetSize() > 0);

  TfLiteIntArray* dims_input_node = interpreter_->tensor(input_node_idx_)->dims;

  n_steps_ = dims_input_node->data[1];
  n_context_ = (dims_input_node->data[2] - 1) / 2;
  n_features_ = dims_input_node->data[3];
  mfcc_feats_per_timestep_ = dims_input_node->data[2] * dims_input_node->data[3];

  TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
  const int final_dim_size = dims_logits->data[1] - 1;
  if (final_dim_size != alphabet_.GetSize()) {
    std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
              << "has size " << alphabet_.GetSize()
              << ", but model has " << final_dim_size
              << " classes in its output. Make sure you're passing an alphabet "
              << "file with the same size as the one used for training."
              << std::endl;
    return STT_ERR_INVALID_ALPHABET;
  }

  TfLiteIntArray* dims_c = interpreter_->tensor(previous_state_c_idx_)->dims;
  TfLiteIntArray* dims_h = interpreter_->tensor(previous_state_h_idx_)->dims;
  assert(dims_c->data[1] == dims_h->data[1]);
  state_size_ = dims_c->data[1];
  assert(state_size_ > 0);

  return STT_ERR_OK;
}

// Copy contents of vec into the tensor with index tensor_idx.
// If vec.size() < num_elements, set the remainder of the tensor values to zero.
void
TFLiteModelState::copy_vector_to_tensor(const vector<float>& vec,
                                        int tensor_idx,
                                        int num_elements)
{
  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
  int i;
  for (i = 0; i < vec.size(); ++i) {
    tensor[i] = vec[i];
  }
  for (; i < num_elements; ++i) {
    tensor[i] = 0.f;
  }
}

// Copy num_elements elements from the tensor with index tensor_idx into vec
void
TFLiteModelState::copy_tensor_to_vector(int tensor_idx,
                                        int num_elements,
                                        vector<float>& vec)
{
  float* tensor = interpreter_->typed_tensor<float>(tensor_idx);
  for (int i = 0; i < num_elements; ++i) {
    vec.push_back(tensor[i]);
  }
}

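// Run a single step of the acoustic model: feed n_frames of MFCC features and
// the previous LSTM state (c, h), execute the acoustic-model execution plan,
// and copy out the logits and the new LSTM state.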
void
TFLiteModelState::infer(const vector<float>& mfcc,
                        unsigned int n_frames,
                        const vector<float>& previous_state_c,
                        const vector<float>& previous_state_h,
                        vector<float>& logits_output,
                        vector<float>& state_c_output,
                        vector<float>& state_h_output)
{
  const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank

  // Feeding input_node
  copy_vector_to_tensor(mfcc, input_node_idx_, n_frames * mfcc_feats_per_timestep_);

  // Feeding previous_state_c, previous_state_h
  assert(previous_state_c.size() == state_size_);
  copy_vector_to_tensor(previous_state_c, previous_state_c_idx_, state_size_);
  assert(previous_state_h.size() == state_size_);
  copy_vector_to_tensor(previous_state_h, previous_state_h_idx_, state_size_);

  interpreter_->SetExecutionPlan(acoustic_exec_plan_);
  TfLiteStatus status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  copy_tensor_to_vector(logits_idx_, n_frames * BATCH_SIZE * num_classes, logits_output);

  state_c_output.clear();
  state_c_output.reserve(state_size_);
  copy_tensor_to_vector(new_state_c_idx_, state_size_, state_c_output);

  state_h_output.clear();
  state_h_output.reserve(state_size_);
  copy_tensor_to_vector(new_state_h_idx_, state_size_, state_h_output);
}

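// Run only the feature-computation sub-graph (mfcc_exec_plan_) on raw audio
// samples and append the resulting MFCC features to mfcc_output.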
void
TFLiteModelState::compute_mfcc(const vector<float>& samples,
                               vector<float>& mfcc_output)
{
  // Feeding input_node
  copy_vector_to_tensor(samples, input_samples_idx_, samples.size());

  TfLiteStatus status = interpreter_->SetExecutionPlan(mfcc_exec_plan_);
  if (status != kTfLiteOk) {
    std::cerr << "Error setting execution plan: " << status << "\n";
    return;
  }

  status = interpreter_->Invoke();
  if (status != kTfLiteOk) {
    std::cerr << "Error running session: " << status << "\n";
    return;
  }

  // The feature computation graph is hardcoded to one audio length for now
  int n_windows = 1;
  TfLiteIntArray* out_dims = interpreter_->tensor(mfccs_idx_)->dims;
  int num_elements = 1;
  for (int i = 0; i < out_dims->size; ++i) {
    num_elements *= out_dims->data[i];
  }
  assert(num_elements / n_features_ == n_windows);

  copy_tensor_to_vector(mfccs_idx_, n_windows * n_features_, mfcc_output);
}