Add TFLite engine
This commit is contained in:
parent
4b11736191
commit
69aa316c88
|
@ -3,6 +3,9 @@
|
||||||
load("@org_tensorflow//tensorflow:tensorflow.bzl",
|
load("@org_tensorflow//tensorflow:tensorflow.bzl",
|
||||||
"tf_cc_shared_object", "if_cuda")
|
"tf_cc_shared_object", "if_cuda")
|
||||||
|
|
||||||
|
load("@org_tensorflow//tensorflow/contrib/lite:build_def.bzl",
|
||||||
|
"tflite_copts", "tflite_linkopts")
|
||||||
|
|
||||||
genrule(
|
genrule(
|
||||||
name = "ds_git_version",
|
name = "ds_git_version",
|
||||||
outs = ["ds_version.h"],
|
outs = ["ds_version.h"],
|
||||||
|
@ -55,14 +58,19 @@ tf_cc_shared_object(
|
||||||
DECODER_SOURCES,
|
DECODER_SOURCES,
|
||||||
# -Wno-sign-compare to silence a lot of warnings from tensorflow itself,
|
# -Wno-sign-compare to silence a lot of warnings from tensorflow itself,
|
||||||
# which makes it harder to see our own warnings
|
# which makes it harder to see our own warnings
|
||||||
copts = ["-Wno-sign-compare", "-fvisibility=hidden"],
|
copts = ["-Wno-sign-compare", "-fvisibility=hidden"] + tflite_copts(),
|
||||||
linkopts = select({
|
linkopts = select({
|
||||||
"//tensorflow:darwin": [],
|
"//tensorflow:darwin": [],
|
||||||
"//tensorflow:linux_x86_64": LINUX_LINKOPTS,
|
"//tensorflow:linux_x86_64": LINUX_LINKOPTS,
|
||||||
"//tensorflow:rpi3": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
"//tensorflow:rpi3": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
||||||
"//tensorflow:rpi3-armv8": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
"//tensorflow:rpi3-armv8": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
||||||
}),
|
"//conditions:default": []
|
||||||
deps = [
|
}) + tflite_linkopts(),
|
||||||
|
deps = select({
|
||||||
|
"//tensorflow:android": [
|
||||||
|
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||||
|
],
|
||||||
|
"//conditions:default": [
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:direct_session",
|
"//tensorflow/core:direct_session",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
|
@ -94,7 +102,8 @@ tf_cc_shared_object(
|
||||||
#### Needed by production model produced without "--use_seq_length False"
|
#### Needed by production model produced without "--use_seq_length False"
|
||||||
#"//tensorflow/core/kernels:logging_ops", # Assert
|
#"//tensorflow/core/kernels:logging_ops", # Assert
|
||||||
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
|
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
|
||||||
] + if_cuda([
|
],
|
||||||
|
}) + if_cuda([
|
||||||
"//tensorflow/core:core",
|
"//tensorflow/core:core",
|
||||||
]),
|
]),
|
||||||
includes = ["c_speech_features", "kiss_fft130"] + DECODER_INCLUDES,
|
includes = ["c_speech_features", "kiss_fft130"] + DECODER_INCLUDES,
|
||||||
|
|
|
@ -9,12 +9,20 @@
|
||||||
#include "deepspeech.h"
|
#include "deepspeech.h"
|
||||||
#include "alphabet.h"
|
#include "alphabet.h"
|
||||||
|
|
||||||
|
#ifndef USE_TFLITE
|
||||||
#include "tensorflow/core/public/version.h"
|
#include "tensorflow/core/public/version.h"
|
||||||
|
#endif // USE_TFLITE
|
||||||
|
|
||||||
#include "native_client/ds_version.h"
|
#include "native_client/ds_version.h"
|
||||||
|
|
||||||
|
#ifndef USE_TFLITE
|
||||||
#include "tensorflow/core/public/session.h"
|
#include "tensorflow/core/public/session.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||||
|
#else // USE_TFLITE
|
||||||
|
#include "tensorflow/contrib/lite/model.h"
|
||||||
|
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||||
|
#endif // USE_TFLITE
|
||||||
|
|
||||||
#include "c_speech_features.h"
|
#include "c_speech_features.h"
|
||||||
|
|
||||||
|
@ -51,7 +59,11 @@ std::array<float, WINDOW_SIZE> calc_hamming_window() {
|
||||||
|
|
||||||
std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
|
std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
|
||||||
|
|
||||||
|
#ifndef USE_TFLITE
|
||||||
using namespace tensorflow;
|
using namespace tensorflow;
|
||||||
|
#else
|
||||||
|
using namespace tflite;
|
||||||
|
#endif
|
||||||
|
|
||||||
using std::vector;
|
using std::vector;
|
||||||
|
|
||||||
|
@ -104,9 +116,14 @@ struct StreamingState {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ModelState {
|
struct ModelState {
|
||||||
|
#ifndef USE_TFLITE
|
||||||
MemmappedEnv* mmap_env;
|
MemmappedEnv* mmap_env;
|
||||||
Session* session;
|
Session* session;
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
|
#else // USE_TFLITE
|
||||||
|
std::unique_ptr<Interpreter> interpreter;
|
||||||
|
std::unique_ptr<FlatBufferModel> fbmodel;
|
||||||
|
#endif // USE_TFLITE
|
||||||
unsigned int ncep;
|
unsigned int ncep;
|
||||||
unsigned int ncontext;
|
unsigned int ncontext;
|
||||||
Alphabet* alphabet;
|
Alphabet* alphabet;
|
||||||
|
@ -116,6 +133,12 @@ struct ModelState {
|
||||||
unsigned int mfcc_feats_per_timestep;
|
unsigned int mfcc_feats_per_timestep;
|
||||||
unsigned int n_context;
|
unsigned int n_context;
|
||||||
|
|
||||||
|
#ifdef USE_TFLITE
|
||||||
|
size_t previous_state_size;
|
||||||
|
std::unique_ptr<float[]> previous_state_c_;
|
||||||
|
std::unique_ptr<float[]> previous_state_h_;
|
||||||
|
#endif
|
||||||
|
|
||||||
ModelState();
|
ModelState();
|
||||||
~ModelState();
|
~ModelState();
|
||||||
|
|
||||||
|
@ -144,8 +167,14 @@ struct ModelState {
|
||||||
};
|
};
|
||||||
|
|
||||||
ModelState::ModelState()
|
ModelState::ModelState()
|
||||||
: mmap_env(nullptr)
|
:
|
||||||
|
#ifndef USE_TFLITE
|
||||||
|
mmap_env(nullptr)
|
||||||
, session(nullptr)
|
, session(nullptr)
|
||||||
|
#else // USE_TFLITE
|
||||||
|
interpreter(nullptr)
|
||||||
|
, fbmodel(nullptr)
|
||||||
|
#endif // USE_TFLITE
|
||||||
, ncep(0)
|
, ncep(0)
|
||||||
, ncontext(0)
|
, ncontext(0)
|
||||||
, alphabet(nullptr)
|
, alphabet(nullptr)
|
||||||
|
@ -154,20 +183,27 @@ ModelState::ModelState()
|
||||||
, n_steps(-1)
|
, n_steps(-1)
|
||||||
, mfcc_feats_per_timestep(-1)
|
, mfcc_feats_per_timestep(-1)
|
||||||
, n_context(-1)
|
, n_context(-1)
|
||||||
|
#ifdef USE_TFLITE
|
||||||
|
, previous_state_size(0)
|
||||||
|
, previous_state_c_(nullptr)
|
||||||
|
, previous_state_h_(nullptr)
|
||||||
|
#endif
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
ModelState::~ModelState()
|
ModelState::~ModelState()
|
||||||
{
|
{
|
||||||
|
#ifndef USE_TFLITE
|
||||||
if (session) {
|
if (session) {
|
||||||
Status status = session->Close();
|
Status status = session->Close();
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
|
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
delete mmap_env;
|
||||||
|
#endif // USE_TFLITE
|
||||||
|
|
||||||
delete scorer;
|
delete scorer;
|
||||||
delete mmap_env;
|
|
||||||
delete alphabet;
|
delete alphabet;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -293,6 +329,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
||||||
{
|
{
|
||||||
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
|
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
|
||||||
|
|
||||||
|
#ifndef USE_TFLITE
|
||||||
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES}));
|
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES}));
|
||||||
|
|
||||||
auto input_mapped = input.flat<float>();
|
auto input_mapped = input.flat<float>();
|
||||||
|
@ -322,6 +359,41 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
||||||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||||
logits_output.push_back(logits_mapped(t));
|
logits_output.push_back(logits_mapped(t));
|
||||||
}
|
}
|
||||||
|
#else // USE_TFLITE
|
||||||
|
// Feeding input_node
|
||||||
|
float* input_node = interpreter->typed_tensor<float>(interpreter->inputs()[0]);
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
|
||||||
|
input_node[i] = aMfcc[i];
|
||||||
|
}
|
||||||
|
for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
|
||||||
|
input_node[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(previous_state_size > 0);
|
||||||
|
|
||||||
|
// Feeding previous_state_c, previous_state_h
|
||||||
|
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[1]), previous_state_c_.get(), sizeof(float) * previous_state_size);
|
||||||
|
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[2]), previous_state_h_.get(), sizeof(float) * previous_state_size);
|
||||||
|
|
||||||
|
TfLiteStatus status = interpreter->Invoke();
|
||||||
|
if (status != kTfLiteOk) {
|
||||||
|
std::cerr << "Error running session: " << status << "\n";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
float* outputs = interpreter->typed_tensor<float>(interpreter->outputs()[0]);
|
||||||
|
|
||||||
|
// The CTCDecoder works with log-probs.
|
||||||
|
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||||
|
logits_output.push_back(outputs[t]);
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[1]), sizeof(float) * previous_state_size);
|
||||||
|
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[2]), sizeof(float) * previous_state_size);
|
||||||
|
#endif // USE_TFLITE
|
||||||
}
|
}
|
||||||
|
|
||||||
char*
|
char*
|
||||||
|
@ -352,7 +424,9 @@ DS_CreateModel(const char* aModelPath,
|
||||||
ModelState** retval)
|
ModelState** retval)
|
||||||
{
|
{
|
||||||
std::unique_ptr<ModelState> model(new ModelState());
|
std::unique_ptr<ModelState> model(new ModelState());
|
||||||
|
#ifndef USE_TFLITE
|
||||||
model->mmap_env = new MemmappedEnv(Env::Default());
|
model->mmap_env = new MemmappedEnv(Env::Default());
|
||||||
|
#endif // USE_TFLITE
|
||||||
model->ncep = aNCep;
|
model->ncep = aNCep;
|
||||||
model->ncontext = aNContext;
|
model->ncontext = aNContext;
|
||||||
model->alphabet = new Alphabet(aAlphabetConfigPath);
|
model->alphabet = new Alphabet(aAlphabetConfigPath);
|
||||||
|
@ -364,9 +438,14 @@ DS_CreateModel(const char* aModelPath,
|
||||||
|
|
||||||
if (!aModelPath || strlen(aModelPath) < 1) {
|
if (!aModelPath || strlen(aModelPath) < 1) {
|
||||||
std::cerr << "No model specified, cannot continue." << std::endl;
|
std::cerr << "No model specified, cannot continue." << std::endl;
|
||||||
|
#ifndef USE_TFLITE
|
||||||
return error::INVALID_ARGUMENT;
|
return error::INVALID_ARGUMENT;
|
||||||
|
#else // USE_TFLITE
|
||||||
|
return EINVAL;
|
||||||
|
#endif // USE_TFLITE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef USE_TFLITE
|
||||||
Status status;
|
Status status;
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
|
|
||||||
|
@ -448,6 +527,62 @@ DS_CreateModel(const char* aModelPath,
|
||||||
|
|
||||||
*retval = model.release();
|
*retval = model.release();
|
||||||
return tensorflow::error::OK;
|
return tensorflow::error::OK;
|
||||||
|
#else // USE_TFLITE
|
||||||
|
TfLiteStatus status;
|
||||||
|
|
||||||
|
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
|
||||||
|
if (status != kTfLiteOk) {
|
||||||
|
std::cerr << status << std::endl;
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||||
|
status = tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
|
||||||
|
if (status != kTfLiteOk) {
|
||||||
|
std::cerr << status << std::endl;
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
model->interpreter->AllocateTensors();
|
||||||
|
model->interpreter->SetNumThreads(4);
|
||||||
|
|
||||||
|
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->interpreter->inputs()[0])->dims;
|
||||||
|
|
||||||
|
model->n_steps = dims_input_node->data[1];
|
||||||
|
model->n_context = (dims_input_node->data[2] - 1 ) / 2;
|
||||||
|
model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
|
||||||
|
|
||||||
|
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->interpreter->outputs()[0])->dims;
|
||||||
|
const int final_dim_size = dims_logits->data[1] - 1;
|
||||||
|
if (final_dim_size != model->alphabet->GetSize()) {
|
||||||
|
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||||
|
<< "has size " << model->alphabet->GetSize()
|
||||||
|
<< ", but model has " << final_dim_size
|
||||||
|
<< " classes in its output. Make sure you're passing an alphabet "
|
||||||
|
<< "file with the same size as the one used for training."
|
||||||
|
<< std::endl;
|
||||||
|
return EINVAL;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int previous_state_c_id = model->interpreter->inputs()[1];
|
||||||
|
const int previous_state_h_id = model->interpreter->inputs()[2];
|
||||||
|
|
||||||
|
TfLiteIntArray* dims_c = model->interpreter->tensor(previous_state_c_id)->dims;
|
||||||
|
TfLiteIntArray* dims_h = model->interpreter->tensor(previous_state_h_id)->dims;
|
||||||
|
assert(dims_c->data[1] == dims_h->data[1]);
|
||||||
|
|
||||||
|
model->previous_state_size = dims_c->data[1];
|
||||||
|
model->previous_state_c_.reset(new float[model->previous_state_size]());
|
||||||
|
model->previous_state_h_.reset(new float[model->previous_state_size]());
|
||||||
|
|
||||||
|
// Set initial values for previous_state_c and previous_state_h
|
||||||
|
memset(model->previous_state_c_.get(), 0, sizeof(float) * model->previous_state_size);
|
||||||
|
memset(model->previous_state_h_.get(), 0, sizeof(float) * model->previous_state_size);
|
||||||
|
|
||||||
|
*retval = model.release();
|
||||||
|
return kTfLiteOk;
|
||||||
|
#endif // USE_TFLITE
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
|
@ -483,7 +618,11 @@ DS_SpeechToText(ModelState* aCtx,
|
||||||
{
|
{
|
||||||
StreamingState* ctx;
|
StreamingState* ctx;
|
||||||
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
|
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
|
||||||
|
#ifndef USE_TFLITE
|
||||||
if (status != tensorflow::error::OK) {
|
if (status != tensorflow::error::OK) {
|
||||||
|
#else // USE_TFLITE
|
||||||
|
if (status != kTfLiteOk) {
|
||||||
|
#endif // USE_TFLITE
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
|
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
|
||||||
|
@ -498,16 +637,22 @@ DS_SetupStream(ModelState* aCtx,
|
||||||
{
|
{
|
||||||
*retval = nullptr;
|
*retval = nullptr;
|
||||||
|
|
||||||
|
#ifndef USE_TFLITE
|
||||||
Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr);
|
Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
std::cerr << "Error running session: " << status << std::endl;
|
std::cerr << "Error running session: " << status << std::endl;
|
||||||
return status.code();
|
return status.code();
|
||||||
}
|
}
|
||||||
|
#endif // USE_TFLITE
|
||||||
|
|
||||||
std::unique_ptr<StreamingState> ctx(new StreamingState());
|
std::unique_ptr<StreamingState> ctx(new StreamingState());
|
||||||
if (!ctx) {
|
if (!ctx) {
|
||||||
std::cerr << "Could not allocate streaming state." << std::endl;
|
std::cerr << "Could not allocate streaming state." << std::endl;
|
||||||
|
#ifndef USE_TFLITE
|
||||||
return status.code();
|
return status.code();
|
||||||
|
#else // USE_TFLITE
|
||||||
|
return ENOMEM;
|
||||||
|
#endif // USE_TFLITE
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank
|
const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank
|
||||||
|
@ -528,7 +673,11 @@ DS_SetupStream(ModelState* aCtx,
|
||||||
ctx->model = aCtx;
|
ctx->model = aCtx;
|
||||||
|
|
||||||
*retval = ctx.release();
|
*retval = ctx.release();
|
||||||
|
#ifndef USE_TFLITE
|
||||||
return tensorflow::error::OK;
|
return tensorflow::error::OK;
|
||||||
|
#else // USE_TFLITE
|
||||||
|
return kTfLiteOk;
|
||||||
|
#endif // USE_TFLITE
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
|
|
|
@ -10,3 +10,19 @@ rm -rf windows include lm/filter lm/builder util/stream util/getopt.* python
|
||||||
|
|
||||||
This was done in order to ensure uniqueness of double_conversion:
|
This was done in order to ensure uniqueness of double_conversion:
|
||||||
git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g'
|
git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g'
|
||||||
|
|
||||||
|
Please apply this patch to be able to build on Android:
|
||||||
|
diff --git a/native_client/kenlm/util/file.cc b/native_client/kenlm/util/file.cc
|
||||||
|
index d53dc0a..b5e36b2 100644
|
||||||
|
--- a/native_client/kenlm/util/file.cc
|
||||||
|
+++ b/native_client/kenlm/util/file.cc
|
||||||
|
@@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
|
||||||
|
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
|
||||||
|
for (int i=0; vars[i]; ++i) {
|
||||||
|
char *val =
|
||||||
|
-#if defined(_GNU_SOURCE)
|
||||||
|
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
|
||||||
|
#if __GLIBC_PREREQ(2,17)
|
||||||
|
secure_getenv
|
||||||
|
#else // __GLIBC_PREREQ
|
||||||
|
|
||||||
|
|
|
@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
|
||||||
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
|
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
|
||||||
for (int i=0; vars[i]; ++i) {
|
for (int i=0; vars[i]; ++i) {
|
||||||
char *val =
|
char *val =
|
||||||
#if defined(_GNU_SOURCE)
|
#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
|
||||||
#if __GLIBC_PREREQ(2,17)
|
#if __GLIBC_PREREQ(2,17)
|
||||||
secure_getenv
|
secure_getenv
|
||||||
#else // __GLIBC_PREREQ
|
#else // __GLIBC_PREREQ
|
||||||
|
|
Loading…
Reference in New Issue