Merge pull request #1844 from lissyx/tflite-error-checks

Tflite error checks
This commit is contained in:
lissyx 2019-01-22 17:18:09 +01:00 committed by GitHub
commit 6f7ddd31aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -382,7 +382,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
}
#else // USE_TFLITE
// Feeding input_node
float* input_node = interpreter->typed_tensor<float>(interpreter->inputs()[input_node_idx]);
float* input_node = interpreter->typed_tensor<float>(input_node_idx);
{
int i;
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
@ -396,8 +396,8 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
assert(previous_state_size > 0);
// Feeding previous_state_c, previous_state_h
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[previous_state_c_idx]), previous_state_c_.get(), sizeof(float) * previous_state_size);
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[previous_state_h_idx]), previous_state_h_.get(), sizeof(float) * previous_state_size);
memcpy(interpreter->typed_tensor<float>(previous_state_c_idx), previous_state_c_.get(), sizeof(float) * previous_state_size);
memcpy(interpreter->typed_tensor<float>(previous_state_h_idx), previous_state_h_.get(), sizeof(float) * previous_state_size);
TfLiteStatus status = interpreter->Invoke();
if (status != kTfLiteOk) {
@ -405,15 +405,15 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
return;
}
float* outputs = interpreter->typed_tensor<float>(interpreter->outputs()[logits_idx]);
float* outputs = interpreter->typed_tensor<float>(logits_idx);
// The CTCDecoder works with log-probs.
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
logits_output.push_back(outputs[t]);
}
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[new_state_c_idx]), sizeof(float) * previous_state_size);
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[new_state_h_idx]), sizeof(float) * previous_state_size);
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(new_state_c_idx), sizeof(float) * previous_state_size);
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(new_state_h_idx), sizeof(float) * previous_state_size);
#endif // USE_TFLITE
}
@ -454,12 +454,12 @@ int tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, co
int tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
{
return tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name);
return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
}
int tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
{
return tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name);
return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
}
#endif
@ -579,17 +579,17 @@ DS_CreateModel(const char* aModelPath,
TfLiteStatus status;
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
if (status != kTfLiteOk) {
std::cerr << status << std::endl;
return status;
if (!model->fbmodel) {
std::cerr << "Error at reading model file " << aModelPath << std::endl;
return kTfLiteError;
}
tflite::ops::builtin::BuiltinOpResolver resolver;
status = tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
if (status != kTfLiteOk) {
std::cerr << status << std::endl;
return status;
tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
if (!model->interpreter) {
std::cerr << "Error at InterpreterBuilder for model file " << aModelPath << std::endl;
return kTfLiteError;
}
model->interpreter->AllocateTensors();
@ -603,13 +603,13 @@ DS_CreateModel(const char* aModelPath,
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->interpreter->inputs()[model->input_node_idx])->dims;
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
model->n_steps = dims_input_node->data[1];
model->n_context = (dims_input_node->data[2] - 1 ) / 2;
model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->interpreter->outputs()[model->logits_idx])->dims;
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->logits_idx)->dims;
const int final_dim_size = dims_logits->data[1] - 1;
if (final_dim_size != model->alphabet->GetSize()) {
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
@ -621,11 +621,8 @@ DS_CreateModel(const char* aModelPath,
return EINVAL;
}
const int previous_state_c_id = model->interpreter->inputs()[model->previous_state_c_idx];
const int previous_state_h_id = model->interpreter->inputs()[model->previous_state_c_idx];
TfLiteIntArray* dims_c = model->interpreter->tensor(previous_state_c_id)->dims;
TfLiteIntArray* dims_h = model->interpreter->tensor(previous_state_h_id)->dims;
TfLiteIntArray* dims_c = model->interpreter->tensor(model->previous_state_c_idx)->dims;
TfLiteIntArray* dims_h = model->interpreter->tensor(model->previous_state_h_idx)->dims;
assert(dims_c->data[1] == dims_h->data[1]);
model->previous_state_size = dims_c->data[1];