Merge pull request #1844 from lissyx/tflite-error-checks
Tflite error checks
This commit is contained in:
commit
6f7ddd31aa
@ -382,7 +382,7 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
|||||||
}
|
}
|
||||||
#else // USE_TFLITE
|
#else // USE_TFLITE
|
||||||
// Feeding input_node
|
// Feeding input_node
|
||||||
float* input_node = interpreter->typed_tensor<float>(interpreter->inputs()[input_node_idx]);
|
float* input_node = interpreter->typed_tensor<float>(input_node_idx);
|
||||||
{
|
{
|
||||||
int i;
|
int i;
|
||||||
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
|
for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
|
||||||
@ -396,8 +396,8 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
|||||||
assert(previous_state_size > 0);
|
assert(previous_state_size > 0);
|
||||||
|
|
||||||
// Feeding previous_state_c, previous_state_h
|
// Feeding previous_state_c, previous_state_h
|
||||||
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[previous_state_c_idx]), previous_state_c_.get(), sizeof(float) * previous_state_size);
|
memcpy(interpreter->typed_tensor<float>(previous_state_c_idx), previous_state_c_.get(), sizeof(float) * previous_state_size);
|
||||||
memcpy(interpreter->typed_tensor<float>(interpreter->inputs()[previous_state_h_idx]), previous_state_h_.get(), sizeof(float) * previous_state_size);
|
memcpy(interpreter->typed_tensor<float>(previous_state_h_idx), previous_state_h_.get(), sizeof(float) * previous_state_size);
|
||||||
|
|
||||||
TfLiteStatus status = interpreter->Invoke();
|
TfLiteStatus status = interpreter->Invoke();
|
||||||
if (status != kTfLiteOk) {
|
if (status != kTfLiteOk) {
|
||||||
@ -405,15 +405,15 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float* outputs = interpreter->typed_tensor<float>(interpreter->outputs()[logits_idx]);
|
float* outputs = interpreter->typed_tensor<float>(logits_idx);
|
||||||
|
|
||||||
// The CTCDecoder works with log-probs.
|
// The CTCDecoder works with log-probs.
|
||||||
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
for (int t = 0; t < n_frames * BATCH_SIZE * num_classes; ++t) {
|
||||||
logits_output.push_back(outputs[t]);
|
logits_output.push_back(outputs[t]);
|
||||||
}
|
}
|
||||||
|
|
||||||
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[new_state_c_idx]), sizeof(float) * previous_state_size);
|
memcpy(previous_state_c_.get(), interpreter->typed_tensor<float>(new_state_c_idx), sizeof(float) * previous_state_size);
|
||||||
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(interpreter->outputs()[new_state_h_idx]), sizeof(float) * previous_state_size);
|
memcpy(previous_state_h_.get(), interpreter->typed_tensor<float>(new_state_h_idx), sizeof(float) * previous_state_size);
|
||||||
#endif // USE_TFLITE
|
#endif // USE_TFLITE
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -454,12 +454,12 @@ int tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, co
|
|||||||
|
|
||||||
int tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
|
int tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
|
||||||
{
|
{
|
||||||
return tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name);
|
return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
|
||||||
}
|
}
|
||||||
|
|
||||||
int tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
|
int tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
|
||||||
{
|
{
|
||||||
return tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name);
|
return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -579,17 +579,17 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
TfLiteStatus status;
|
TfLiteStatus status;
|
||||||
|
|
||||||
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
|
model->fbmodel = tflite::FlatBufferModel::BuildFromFile(aModelPath);
|
||||||
if (status != kTfLiteOk) {
|
if (!model->fbmodel) {
|
||||||
std::cerr << status << std::endl;
|
std::cerr << "Error at reading model file " << aModelPath << std::endl;
|
||||||
return status;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
tflite::ops::builtin::BuiltinOpResolver resolver;
|
tflite::ops::builtin::BuiltinOpResolver resolver;
|
||||||
status = tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
|
tflite::InterpreterBuilder(*model->fbmodel, resolver)(&model->interpreter);
|
||||||
if (status != kTfLiteOk) {
|
if (!model->interpreter) {
|
||||||
std::cerr << status << std::endl;
|
std::cerr << "Error at InterpreterBuilder for model file " << aModelPath << std::endl;
|
||||||
return status;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
model->interpreter->AllocateTensors();
|
model->interpreter->AllocateTensors();
|
||||||
@ -603,13 +603,13 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
|
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
|
||||||
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
|
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
|
||||||
|
|
||||||
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->interpreter->inputs()[model->input_node_idx])->dims;
|
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
|
||||||
|
|
||||||
model->n_steps = dims_input_node->data[1];
|
model->n_steps = dims_input_node->data[1];
|
||||||
model->n_context = (dims_input_node->data[2] - 1 ) / 2;
|
model->n_context = (dims_input_node->data[2] - 1 ) / 2;
|
||||||
model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
|
model->mfcc_feats_per_timestep = dims_input_node->data[2] * dims_input_node->data[3];
|
||||||
|
|
||||||
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->interpreter->outputs()[model->logits_idx])->dims;
|
TfLiteIntArray* dims_logits = model->interpreter->tensor(model->logits_idx)->dims;
|
||||||
const int final_dim_size = dims_logits->data[1] - 1;
|
const int final_dim_size = dims_logits->data[1] - 1;
|
||||||
if (final_dim_size != model->alphabet->GetSize()) {
|
if (final_dim_size != model->alphabet->GetSize()) {
|
||||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||||
@ -621,11 +621,8 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
return EINVAL;
|
return EINVAL;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int previous_state_c_id = model->interpreter->inputs()[model->previous_state_c_idx];
|
TfLiteIntArray* dims_c = model->interpreter->tensor(model->previous_state_c_idx)->dims;
|
||||||
const int previous_state_h_id = model->interpreter->inputs()[model->previous_state_c_idx];
|
TfLiteIntArray* dims_h = model->interpreter->tensor(model->previous_state_h_idx)->dims;
|
||||||
|
|
||||||
TfLiteIntArray* dims_c = model->interpreter->tensor(previous_state_c_id)->dims;
|
|
||||||
TfLiteIntArray* dims_h = model->interpreter->tensor(previous_state_h_id)->dims;
|
|
||||||
assert(dims_c->data[1] == dims_h->data[1]);
|
assert(dims_c->data[1] == dims_h->data[1]);
|
||||||
|
|
||||||
model->previous_state_size = dims_c->data[1];
|
model->previous_state_size = dims_c->data[1];
|
||||||
|
Loading…
x
Reference in New Issue
Block a user