Add TFLite engine

This commit is contained in:
Alexandre Lissy 2018-11-06 16:37:33 +01:00
parent 4b11736191
commit 69aa316c88
4 changed files with 217 additions and 43 deletions

View File

@ -3,6 +3,9 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", load("@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_cc_shared_object", "if_cuda") "tf_cc_shared_object", "if_cuda")
load("@org_tensorflow//tensorflow/contrib/lite:build_def.bzl",
"tflite_copts", "tflite_linkopts")
genrule( genrule(
name = "ds_git_version", name = "ds_git_version",
outs = ["ds_version.h"], outs = ["ds_version.h"],
@ -55,14 +58,19 @@ tf_cc_shared_object(
DECODER_SOURCES, DECODER_SOURCES,
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself, # -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
# which makes it harder to see our own warnings # which makes it harder to see our own warnings
copts = ["-Wno-sign-compare", "-fvisibility=hidden"], copts = ["-Wno-sign-compare", "-fvisibility=hidden"] + tflite_copts(),
linkopts = select({ linkopts = select({
"//tensorflow:darwin": [], "//tensorflow:darwin": [],
"//tensorflow:linux_x86_64": LINUX_LINKOPTS, "//tensorflow:linux_x86_64": LINUX_LINKOPTS,
"//tensorflow:rpi3": LINUX_LINKOPTS + ["-l:libstdc++.a"], "//tensorflow:rpi3": LINUX_LINKOPTS + ["-l:libstdc++.a"],
"//tensorflow:rpi3-armv8": LINUX_LINKOPTS + ["-l:libstdc++.a"], "//tensorflow:rpi3-armv8": LINUX_LINKOPTS + ["-l:libstdc++.a"],
}), "//conditions:default": []
deps = [ }) + tflite_linkopts(),
deps = select({
"//tensorflow:android": [
"//tensorflow/contrib/lite/kernels:builtin_ops",
],
"//conditions:default": [
"//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session", "//tensorflow/core:direct_session",
"//third_party/eigen3", "//third_party/eigen3",
@ -94,7 +102,8 @@ tf_cc_shared_object(
#### Needed by production model produced without "--use_seq_length False" #### Needed by production model produced without "--use_seq_length False"
#"//tensorflow/core/kernels:logging_ops", # Assert #"//tensorflow/core/kernels:logging_ops", # Assert
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence #"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
] + if_cuda([ ],
}) + if_cuda([
"//tensorflow/core:core", "//tensorflow/core:core",
]), ]),
includes = ["c_speech_features", "kiss_fft130"] + DECODER_INCLUDES, includes = ["c_speech_features", "kiss_fft130"] + DECODER_INCLUDES,

View File

@ -9,12 +9,20 @@
#include "deepspeech.h" #include "deepspeech.h"
#include "alphabet.h" #include "alphabet.h"
#ifndef USE_TFLITE
#include "tensorflow/core/public/version.h" #include "tensorflow/core/public/version.h"
#endif // USE_TFLITE
#include "native_client/ds_version.h" #include "native_client/ds_version.h"
#ifndef USE_TFLITE
#include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/env.h"
#include "tensorflow/core/util/memmapped_file_system.h" #include "tensorflow/core/util/memmapped_file_system.h"
#else // USE_TFLITE
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#endif // USE_TFLITE
#include "c_speech_features.h" #include "c_speech_features.h"
@ -51,7 +59,11 @@ std::array<float, WINDOW_SIZE> calc_hamming_window() {
std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window(); std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
#ifndef USE_TFLITE
using namespace tensorflow; using namespace tensorflow;
#else
using namespace tflite;
#endif
using std::vector; using std::vector;
@ -104,9 +116,14 @@ struct StreamingState {
}; };
struct ModelState { struct ModelState {
#ifndef USE_TFLITE
MemmappedEnv* mmap_env; MemmappedEnv* mmap_env;
Session* session; Session* session;
GraphDef graph_def; GraphDef graph_def;
#else // USE_TFLITE
std::unique_ptr<Interpreter> interpreter;
std::unique_ptr<FlatBufferModel> fbmodel;
#endif // USE_TFLITE
unsigned int ncep; unsigned int ncep;
unsigned int ncontext; unsigned int ncontext;
Alphabet* alphabet; Alphabet* alphabet;
@ -116,6 +133,12 @@ struct ModelState {
unsigned int mfcc_feats_per_timestep; unsigned int mfcc_feats_per_timestep;
unsigned int n_context; unsigned int n_context;
#ifdef USE_TFLITE
size_t previous_state_size;
std::unique_ptr<float[]> previous_state_c_;
std::unique_ptr<float[]> previous_state_h_;
#endif
ModelState(); ModelState();
~ModelState(); ~ModelState();
@ -144,8 +167,14 @@ struct ModelState {
}; };
ModelState::ModelState() ModelState::ModelState()
: mmap_env(nullptr) :
#ifndef USE_TFLITE
mmap_env(nullptr)
, session(nullptr) , session(nullptr)
#else // USE_TFLITE
interpreter(nullptr)
, fbmodel(nullptr)
#endif // USE_TFLITE
, ncep(0) , ncep(0)
, ncontext(0) , ncontext(0)
, alphabet(nullptr) , alphabet(nullptr)
@ -154,20 +183,27 @@ ModelState::ModelState()
, n_steps(-1) , n_steps(-1)
, mfcc_feats_per_timestep(-1) , mfcc_feats_per_timestep(-1)
, n_context(-1) , n_context(-1)
#ifdef USE_TFLITE
, previous_state_size(0)
, previous_state_c_(nullptr)
, previous_state_h_(nullptr)
#endif
{ {
} }
ModelState::~ModelState() ModelState::~ModelState()
{ {
#ifndef USE_TFLITE
if (session) { if (session) {
Status status = session->Close(); Status status = session->Close();
if (!status.ok()) { if (!status.ok()) {
std::cerr << "Error closing TensorFlow session: " << status << std::endl; std::cerr << "Error closing TensorFlow session: " << status << std::endl;
} }
} }
delete mmap_env;
#endif // USE_TFLITE
delete scorer; delete scorer;
delete mmap_env;
delete alphabet; delete alphabet;
} }
@ -293,6 +329,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
{ {
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
#ifndef USE_TFLITE
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES})); Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES}));
auto input_mapped = input.flat<float>(); auto input_mapped = input.flat<float>();
@ -322,6 +359,41 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) { for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
logits_output.push_back(logits_mapped(t)); logits_output.push_back(logits_mapped(t));
} }
#else // USE_TFLITE
// Feeding input_node
float* input_node = interpreter->typed_tensor<float>(interpreter->inputs()[0]);
{
int i;
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
input_node[i] = aMfcc[i];
}
for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
input_node[i] = 0;
}
}
assert(previous_state_size > 0);
// Feeding previous_state_c, previous_state_h
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[1]), previous_state_c_.get(), sizeof(float) * previous_state_size);
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[2]), previous_state_h_.get(), sizeof(float) * previous_state_size);
TfLiteStatus status = interpreter->Invoke();
if (status != kTfLiteOk) {
std::cerr << "Error running session: " << status << "\n";
return;
}
float* outputs = interpreter->typed_tensor<float>(interpreter->outputs()[0]);
// The CTCDecoder works with log-probs.
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
logits_output.push_back(outputs[t]);
}
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[1]), sizeof(float) * previous_state_size);
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[2]), sizeof(float) * previous_state_size);
#endif // USE_TFLITE
} }
char* char*
@ -352,7 +424,9 @@ DS_CreateModel(const char* aModelPath,
ModelState** retval) ModelState** retval)
{ {
std::unique_ptr<ModelState> model(new ModelState()); std::unique_ptr<ModelState> model(new ModelState());
#ifndef USE_TFLITE
model->mmap_env = new MemmappedEnv(Env::Default()); model->mmap_env = new MemmappedEnv(Env::Default());
#endif // USE_TFLITE
model->ncep = aNCep; model->ncep = aNCep;
model->ncontext = aNContext; model->ncontext = aNContext;
model->alphabet = new Alphabet(aAlphabetConfigPath); model->alphabet = new Alphabet(aAlphabetConfigPath);
@ -364,9 +438,14 @@ DS_CreateModel(const char* aModelPath,
if (!aModelPath || strlen(aModelPath) < 1) { if (!aModelPath || strlen(aModelPath) < 1) {
std::cerr << "No model specified, cannot continue." << std::endl; std::cerr << "No model specified, cannot continue." << std::endl;
#ifndef USE_TFLITE
return error::INVALID_ARGUMENT; return error::INVALID_ARGUMENT;
#else // USE_TFLITE
return EINVAL;
#endif // USE_TFLITE
} }
#ifndef USE_TFLITE
Status status; Status status;
SessionOptions options; SessionOptions options;
@ -448,6 +527,62 @@ DS_CreateModel(const char* aModelPath,
*retval = model.release(); *retval = model.release();
return tensorflow::error::OK; return tensorflow::error::OK;
#else // USE_TFLITE
TfLiteStatus status;
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
if (status != kTfLiteOk) {
std::cerr << status << std::endl;
return status;
}
tflite::ops::builtin::BuiltinOpResolver resolver;
status = tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
if (status != kTfLiteOk) {
std::cerr << status << std::endl;
return status;
}
model->interpreter->AllocateTensors();
model->interpreter->SetNumThreads(4);
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->interpreter->inputs()[0])->dims;
model->n_steps = dims_input_node->data[1];
model->n_context = (dims_input_node->data[2] - 1 ) / 2;
model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->interpreter->outputs()[0])->dims;
const int final_dim_size = dims_logits->data[1] - 1;
if (final_dim_size != model->alphabet->GetSize()) {
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
<< "has size " << model->alphabet->GetSize()
<< ", but model has " << final_dim_size
<< " classes in its output. Make sure you're passing an alphabet "
<< "file with the same size as the one used for training."
<< std::endl;
return EINVAL;
}
const int previous_state_c_id = model->interpreter->inputs()[1];
const int previous_state_h_id = model->interpreter->inputs()[2];
TfLiteIntArray* dims_c = model->interpreter->tensor(previous_state_c_id)->dims;
TfLiteIntArray* dims_h = model->interpreter->tensor(previous_state_h_id)->dims;
assert(dims_c->data[1] == dims_h->data[1]);
model->previous_state_size = dims_c->data[1];
model->previous_state_c_.reset(new float[model->previous_state_size]());
model->previous_state_h_.reset(new float[model->previous_state_size]());
// Set initial values for previous_state_c and previous_state_h
memset(model->previous_state_c_.get(), 0, sizeof(float) * model->previous_state_size);
memset(model->previous_state_h_.get(), 0, sizeof(float) * model->previous_state_size);
*retval = model.release();
return kTfLiteOk;
#endif // USE_TFLITE
} }
void void
@ -483,7 +618,11 @@ DS_SpeechToText(ModelState* aCtx,
{ {
StreamingState* ctx; StreamingState* ctx;
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx); int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
#ifndef USE_TFLITE
if (status != tensorflow::error::OK) { if (status != tensorflow::error::OK) {
#else // USE_TFLITE
if (status != kTfLiteOk) {
#endif // USE_TFLITE
return nullptr; return nullptr;
} }
DS_FeedAudioContent(ctx, aBuffer, aBufferSize); DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
@ -498,16 +637,22 @@ DS_SetupStream(ModelState* aCtx,
{ {
*retval = nullptr; *retval = nullptr;
#ifndef USE_TFLITE
Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr); Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr);
if (!status.ok()) { if (!status.ok()) {
std::cerr << "Error running session: " << status << std::endl; std::cerr << "Error running session: " << status << std::endl;
return status.code(); return status.code();
} }
#endif // USE_TFLITE
std::unique_ptr<StreamingState> ctx(new StreamingState()); std::unique_ptr<StreamingState> ctx(new StreamingState());
if (!ctx) { if (!ctx) {
std::cerr << "Could not allocate streaming state." << std::endl; std::cerr << "Could not allocate streaming state." << std::endl;
#ifndef USE_TFLITE
return status.code(); return status.code();
#else // USE_TFLITE
return ENOMEM;
#endif // USE_TFLITE
} }
const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank
@ -528,7 +673,11 @@ DS_SetupStream(ModelState* aCtx,
ctx->model = aCtx; ctx->model = aCtx;
*retval = ctx.release(); *retval = ctx.release();
#ifndef USE_TFLITE
return tensorflow::error::OK; return tensorflow::error::OK;
#else // USE_TFLITE
return kTfLiteOk;
#endif // USE_TFLITE
} }
void void

View File

@ -10,3 +10,19 @@ rm -rf windows include lm/filter lm/builder util/stream util/getopt.* python
This was done in order to ensure uniqueness of double_conversion: This was done in order to ensure uniqueness of double_conversion:
git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g' git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g'
Please apply this patch to be able to build on Android:
diff --git a/native_client/kenlm/util/file.cc b/native_client/kenlm/util/file.cc
index d53dc0a..b5e36b2 100644
--- a/native_client/kenlm/util/file.cc
+++ b/native_client/kenlm/util/file.cc
@@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
for (int i=0; vars[i]; ++i) {
char *val =
-#if defined(_GNU_SOURCE)
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
#if __GLIBC_PREREQ(2,17)
secure_getenv
#else // __GLIBC_PREREQ

View File

@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0}; const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
for (int i=0; vars[i]; ++i) { for (int i=0; vars[i]; ++i) {
char *val = char *val =
#if defined(_GNU_SOURCE) #if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
#if __GLIBC_PREREQ(2,17) #if __GLIBC_PREREQ(2,17)
secure_getenv secure_getenv
#else // __GLIBC_PREREQ #else // __GLIBC_PREREQ