Add TFLite engine
This commit is contained in:
parent
4b11736191
commit
69aa316c88
|
@ -3,6 +3,9 @@
|
|||
load("@org_tensorflow//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_shared_object", "if_cuda")
|
||||
|
||||
load("@org_tensorflow//tensorflow/contrib/lite:build_def.bzl",
|
||||
"tflite_copts", "tflite_linkopts")
|
||||
|
||||
genrule(
|
||||
name = "ds_git_version",
|
||||
outs = ["ds_version.h"],
|
||||
|
@ -55,14 +58,19 @@ tf_cc_shared_object(
|
|||
DECODER_SOURCES,
|
||||
# -Wno-sign-compare to silence a lot of warnings from tensorflow itself,
|
||||
# which makes it harder to see our own warnings
|
||||
copts = ["-Wno-sign-compare", "-fvisibility=hidden"],
|
||||
copts = ["-Wno-sign-compare", "-fvisibility=hidden"] + tflite_copts(),
|
||||
linkopts = select({
|
||||
"//tensorflow:darwin": [],
|
||||
"//tensorflow:linux_x86_64": LINUX_LINKOPTS,
|
||||
"//tensorflow:rpi3": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
||||
"//tensorflow:rpi3-armv8": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
||||
}),
|
||||
deps = [
|
||||
"//conditions:default": []
|
||||
}) + tflite_linkopts(),
|
||||
deps = select({
|
||||
"//tensorflow:android": [
|
||||
"//tensorflow/contrib/lite/kernels:builtin_ops",
|
||||
],
|
||||
"//conditions:default": [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//third_party/eigen3",
|
||||
|
@ -94,7 +102,8 @@ tf_cc_shared_object(
|
|||
#### Needed by production model produced without "--use_seq_length False"
|
||||
#"//tensorflow/core/kernels:logging_ops", # Assert
|
||||
#"//tensorflow/core/kernels:reverse_sequence_op", # ReverseSequence
|
||||
] + if_cuda([
|
||||
],
|
||||
}) + if_cuda([
|
||||
"//tensorflow/core:core",
|
||||
]),
|
||||
includes = ["c_speech_features", "kiss_fft130"] + DECODER_INCLUDES,
|
||||
|
|
|
@ -9,12 +9,20 @@
|
|||
#include "deepspeech.h"
|
||||
#include "alphabet.h"
|
||||
|
||||
#ifndef USE_TFLITE
|
||||
#include "tensorflow/core/public/version.h"
|
||||
#endif // USE_TFLITE
|
||||
|
||||
#include "native_client/ds_version.h"
|
||||
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
#ifndef USE_TFLITE
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/util/memmapped_file_system.h"
|
||||
#else // USE_TFLITE
|
||||
#include "tensorflow/contrib/lite/model.h"
|
||||
#include "tensorflow/contrib/lite/kernels/register.h"
|
||||
#endif // USE_TFLITE
|
||||
|
||||
#include "c_speech_features.h"
|
||||
|
||||
|
@ -51,7 +59,11 @@ std::array<float, WINDOW_SIZE> calc_hamming_window() {
|
|||
|
||||
std::array<float, WINDOW_SIZE> hamming_window = calc_hamming_window();
|
||||
|
||||
using namespace tensorflow;
|
||||
#ifndef USE_TFLITE
|
||||
using namespace tensorflow;
|
||||
#else
|
||||
using namespace tflite;
|
||||
#endif
|
||||
|
||||
using std::vector;
|
||||
|
||||
|
@ -104,9 +116,14 @@ struct StreamingState {
|
|||
};
|
||||
|
||||
struct ModelState {
|
||||
#ifndef USE_TFLITE
|
||||
MemmappedEnv* mmap_env;
|
||||
Session* session;
|
||||
GraphDef graph_def;
|
||||
#else // USE_TFLITE
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
std::unique_ptr<FlatBufferModel> fbmodel;
|
||||
#endif // USE_TFLITE
|
||||
unsigned int ncep;
|
||||
unsigned int ncontext;
|
||||
Alphabet* alphabet;
|
||||
|
@ -116,6 +133,12 @@ struct ModelState {
|
|||
unsigned int mfcc_feats_per_timestep;
|
||||
unsigned int n_context;
|
||||
|
||||
#ifdef USE_TFLITE
|
||||
size_t previous_state_size;
|
||||
std::unique_ptr<float[]> previous_state_c_;
|
||||
std::unique_ptr<float[]> previous_state_h_;
|
||||
#endif
|
||||
|
||||
ModelState();
|
||||
~ModelState();
|
||||
|
||||
|
@ -144,8 +167,14 @@ struct ModelState {
|
|||
};
|
||||
|
||||
ModelState::ModelState()
|
||||
: mmap_env(nullptr)
|
||||
:
|
||||
#ifndef USE_TFLITE
|
||||
mmap_env(nullptr)
|
||||
, session(nullptr)
|
||||
#else // USE_TFLITE
|
||||
interpreter(nullptr)
|
||||
, fbmodel(nullptr)
|
||||
#endif // USE_TFLITE
|
||||
, ncep(0)
|
||||
, ncontext(0)
|
||||
, alphabet(nullptr)
|
||||
|
@ -154,20 +183,27 @@ ModelState::ModelState()
|
|||
, n_steps(-1)
|
||||
, mfcc_feats_per_timestep(-1)
|
||||
, n_context(-1)
|
||||
#ifdef USE_TFLITE
|
||||
, previous_state_size(0)
|
||||
, previous_state_c_(nullptr)
|
||||
, previous_state_h_(nullptr)
|
||||
#endif
|
||||
{
|
||||
}
|
||||
|
||||
ModelState::~ModelState()
|
||||
{
|
||||
#ifndef USE_TFLITE
|
||||
if (session) {
|
||||
Status status = session->Close();
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error closing TensorFlow session: " << status << std::endl;
|
||||
}
|
||||
}
|
||||
delete mmap_env;
|
||||
#endif // USE_TFLITE
|
||||
|
||||
delete scorer;
|
||||
delete mmap_env;
|
||||
delete alphabet;
|
||||
}
|
||||
|
||||
|
@ -293,6 +329,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
|||
{
|
||||
const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank
|
||||
|
||||
#ifndef USE_TFLITE
|
||||
Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES}));
|
||||
|
||||
auto input_mapped = input.flat<float>();
|
||||
|
@ -322,6 +359,41 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
|||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||
logits_output.push_back(logits_mapped(t));
|
||||
}
|
||||
#else // USE_TFLITE
|
||||
// Feeding input_node
|
||||
float* input_node = interpreter->typed_tensor<float>(interpreter->inputs()[0]);
|
||||
{
|
||||
int i;
|
||||
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
|
||||
input_node[i] = aMfcc[i];
|
||||
}
|
||||
for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
|
||||
input_node[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
assert(previous_state_size > 0);
|
||||
|
||||
// Feeding previous_state_c, previous_state_h
|
||||
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[1]), previous_state_c_.get(), sizeof(float) * previous_state_size);
|
||||
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[2]), previous_state_h_.get(), sizeof(float) * previous_state_size);
|
||||
|
||||
TfLiteStatus status = interpreter->Invoke();
|
||||
if (status != kTfLiteOk) {
|
||||
std::cerr << "Error running session: " << status << "\n";
|
||||
return;
|
||||
}
|
||||
|
||||
float* outputs = interpreter->typed_tensor<float>(interpreter->outputs()[0]);
|
||||
|
||||
// The CTCDecoder works with log-probs.
|
||||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||
logits_output.push_back(outputs[t]);
|
||||
}
|
||||
|
||||
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[1]), sizeof(float) * previous_state_size);
|
||||
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[2]), sizeof(float) * previous_state_size);
|
||||
#endif // USE_TFLITE
|
||||
}
|
||||
|
||||
char*
|
||||
|
@ -352,7 +424,9 @@ DS_CreateModel(const char* aModelPath,
|
|||
ModelState** retval)
|
||||
{
|
||||
std::unique_ptr<ModelState> model(new ModelState());
|
||||
#ifndef USE_TFLITE
|
||||
model->mmap_env = new MemmappedEnv(Env::Default());
|
||||
#endif // USE_TFLITE
|
||||
model->ncep = aNCep;
|
||||
model->ncontext = aNContext;
|
||||
model->alphabet = new Alphabet(aAlphabetConfigPath);
|
||||
|
@ -364,9 +438,14 @@ DS_CreateModel(const char* aModelPath,
|
|||
|
||||
if (!aModelPath || strlen(aModelPath) < 1) {
|
||||
std::cerr << "No model specified, cannot continue." << std::endl;
|
||||
#ifndef USE_TFLITE
|
||||
return error::INVALID_ARGUMENT;
|
||||
#else // USE_TFLITE
|
||||
return EINVAL;
|
||||
#endif // USE_TFLITE
|
||||
}
|
||||
|
||||
#ifndef USE_TFLITE
|
||||
Status status;
|
||||
SessionOptions options;
|
||||
|
||||
|
@ -448,6 +527,62 @@ DS_CreateModel(const char* aModelPath,
|
|||
|
||||
*retval = model.release();
|
||||
return tensorflow::error::OK;
|
||||
#else // USE_TFLITE
|
||||
TfLiteStatus status;
|
||||
|
||||
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
|
||||
if (status != kTfLiteOk) {
|
||||
std::cerr << status << std::endl;
|
||||
return status;
|
||||
}
|
||||
|
||||
|
||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||
status = tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
|
||||
if (status != kTfLiteOk) {
|
||||
std::cerr << status << std::endl;
|
||||
return status;
|
||||
}
|
||||
|
||||
model->interpreter->AllocateTensors();
|
||||
model->interpreter->SetNumThreads(4);
|
||||
|
||||
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->interpreter->inputs()[0])->dims;
|
||||
|
||||
model->n_steps = dims_input_node->data[1];
|
||||
model->n_context = (dims_input_node->data[2] - 1 ) / 2;
|
||||
model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
|
||||
|
||||
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->interpreter->outputs()[0])->dims;
|
||||
const int final_dim_size = dims_logits->data[1] - 1;
|
||||
if (final_dim_size != model->alphabet->GetSize()) {
|
||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||
<< "has size " << model->alphabet->GetSize()
|
||||
<< ", but model has " << final_dim_size
|
||||
<< " classes in its output. Make sure you're passing an alphabet "
|
||||
<< "file with the same size as the one used for training."
|
||||
<< std::endl;
|
||||
return EINVAL;
|
||||
}
|
||||
|
||||
const int previous_state_c_id = model->interpreter->inputs()[1];
|
||||
const int previous_state_h_id = model->interpreter->inputs()[2];
|
||||
|
||||
TfLiteIntArray* dims_c = model->interpreter->tensor(previous_state_c_id)->dims;
|
||||
TfLiteIntArray* dims_h = model->interpreter->tensor(previous_state_h_id)->dims;
|
||||
assert(dims_c->data[1] == dims_h->data[1]);
|
||||
|
||||
model->previous_state_size = dims_c->data[1];
|
||||
model->previous_state_c_.reset(new float[model->previous_state_size]());
|
||||
model->previous_state_h_.reset(new float[model->previous_state_size]());
|
||||
|
||||
// Set initial values for previous_state_c and previous_state_h
|
||||
memset(model->previous_state_c_.get(), 0, sizeof(float) * model->previous_state_size);
|
||||
memset(model->previous_state_h_.get(), 0, sizeof(float) * model->previous_state_size);
|
||||
|
||||
*retval = model.release();
|
||||
return kTfLiteOk;
|
||||
#endif // USE_TFLITE
|
||||
}
|
||||
|
||||
void
|
||||
|
@ -483,7 +618,11 @@ DS_SpeechToText(ModelState* aCtx,
|
|||
{
|
||||
StreamingState* ctx;
|
||||
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
|
||||
#ifndef USE_TFLITE
|
||||
if (status != tensorflow::error::OK) {
|
||||
#else // USE_TFLITE
|
||||
if (status != kTfLiteOk) {
|
||||
#endif // USE_TFLITE
|
||||
return nullptr;
|
||||
}
|
||||
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
|
||||
|
@ -498,16 +637,22 @@ DS_SetupStream(ModelState* aCtx,
|
|||
{
|
||||
*retval = nullptr;
|
||||
|
||||
#ifndef USE_TFLITE
|
||||
Status status = aCtx->session->Run({}, {}, {"initialize_state"}, nullptr);
|
||||
if (!status.ok()) {
|
||||
std::cerr << "Error running session: " << status << std::endl;
|
||||
return status.code();
|
||||
}
|
||||
#endif // USE_TFLITE
|
||||
|
||||
std::unique_ptr<StreamingState> ctx(new StreamingState());
|
||||
if (!ctx) {
|
||||
std::cerr << "Could not allocate streaming state." << std::endl;
|
||||
#ifndef USE_TFLITE
|
||||
return status.code();
|
||||
#else // USE_TFLITE
|
||||
return ENOMEM;
|
||||
#endif // USE_TFLITE
|
||||
}
|
||||
|
||||
const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank
|
||||
|
@ -528,7 +673,11 @@ DS_SetupStream(ModelState* aCtx,
|
|||
ctx->model = aCtx;
|
||||
|
||||
*retval = ctx.release();
|
||||
#ifndef USE_TFLITE
|
||||
return tensorflow::error::OK;
|
||||
#else // USE_TFLITE
|
||||
return kTfLiteOk;
|
||||
#endif // USE_TFLITE
|
||||
}
|
||||
|
||||
void
|
||||
|
|
|
@ -10,3 +10,19 @@ rm -rf windows include lm/filter lm/builder util/stream util/getopt.* python
|
|||
|
||||
This was done in order to ensure uniqueness of double_conversion:
|
||||
git grep 'double_conversion' | cut -d':' -f1 | sort | uniq | xargs sed -ri 's/double_conversion/kenlm_double_conversion/g'
|
||||
|
||||
Please apply this patch to be able to build on Android:
|
||||
diff --git a/native_client/kenlm/util/file.cc b/native_client/kenlm/util/file.cc
|
||||
index d53dc0a..b5e36b2 100644
|
||||
--- a/native_client/kenlm/util/file.cc
|
||||
+++ b/native_client/kenlm/util/file.cc
|
||||
@@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
|
||||
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
|
||||
for (int i=0; vars[i]; ++i) {
|
||||
char *val =
|
||||
-#if defined(_GNU_SOURCE)
|
||||
+#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
|
||||
#if __GLIBC_PREREQ(2,17)
|
||||
secure_getenv
|
||||
#else // __GLIBC_PREREQ
|
||||
|
||||
|
|
|
@ -540,7 +540,7 @@ std::string DefaultTempDirectory() {
|
|||
const char *const vars[] = {"TMPDIR", "TMP", "TEMPDIR", "TEMP", 0};
|
||||
for (int i=0; vars[i]; ++i) {
|
||||
char *val =
|
||||
#if defined(_GNU_SOURCE)
|
||||
#if defined(_GNU_SOURCE) && defined(__GLIBC_PREREQ)
|
||||
#if __GLIBC_PREREQ(2,17)
|
||||
secure_getenv
|
||||
#else // __GLIBC_PREREQ
|
||||
|
|
Loading…
Reference in New Issue