Fix TFLite bug in feature computation graph and clean up deepspeech.cc a bit
This commit is contained in:
parent
a7cda8e761
commit
232df740db
@ -548,11 +548,10 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||||||
batch_size = batch_size if batch_size > 0 else None
|
batch_size = batch_size if batch_size > 0 else None
|
||||||
|
|
||||||
# Create feature computation graph
|
# Create feature computation graph
|
||||||
input_samples = tf.placeholder(tf.float32, [None], 'input_samples')
|
input_samples = tf.placeholder(tf.float32, [512], 'input_samples')
|
||||||
samples = tf.expand_dims(input_samples, -1)
|
samples = tf.expand_dims(input_samples, -1)
|
||||||
mfccs, mfccs_len = samples_to_mfccs(samples, 16000)
|
mfccs, _ = samples_to_mfccs(samples, 16000)
|
||||||
mfccs = tf.identity(mfccs, name='mfccs')
|
mfccs = tf.identity(mfccs, name='mfccs')
|
||||||
mfccs_len = tf.identity(mfccs_len, name='mfccs_len')
|
|
||||||
|
|
||||||
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
# Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
|
||||||
# This shape is read by the native_client in DS_CreateModel to know the
|
# This shape is read by the native_client in DS_CreateModel to know the
|
||||||
@ -633,7 +632,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||||||
'outputs': logits,
|
'outputs': logits,
|
||||||
'initialize_state': initialize_state,
|
'initialize_state': initialize_state,
|
||||||
'mfccs': mfccs,
|
'mfccs': mfccs,
|
||||||
'mfccs_len': mfccs_len,
|
|
||||||
},
|
},
|
||||||
layers
|
layers
|
||||||
)
|
)
|
||||||
@ -659,7 +657,6 @@ def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
|
|||||||
'new_state_c': new_state_c,
|
'new_state_c': new_state_c,
|
||||||
'new_state_h': new_state_h,
|
'new_state_h': new_state_h,
|
||||||
'mfccs': mfccs,
|
'mfccs': mfccs,
|
||||||
'mfccs_len': mfccs_len,
|
|
||||||
},
|
},
|
||||||
layers
|
layers
|
||||||
)
|
)
|
||||||
|
|||||||
@ -149,7 +149,6 @@ struct ModelState {
|
|||||||
int new_state_c_idx;
|
int new_state_c_idx;
|
||||||
int new_state_h_idx;
|
int new_state_h_idx;
|
||||||
int mfccs_idx;
|
int mfccs_idx;
|
||||||
int mfccs_len_idx;
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
ModelState();
|
ModelState();
|
||||||
@ -164,7 +163,7 @@ struct ModelState {
|
|||||||
*
|
*
|
||||||
* @return String representing the decoded text.
|
* @return String representing the decoded text.
|
||||||
*/
|
*/
|
||||||
char* decode(vector<float>& logits);
|
char* decode(const vector<float>& logits);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||||
@ -186,7 +185,7 @@ struct ModelState {
|
|||||||
* @return Metadata struct containing MetadataItem structs for each character.
|
* @return Metadata struct containing MetadataItem structs for each character.
|
||||||
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
|
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
|
||||||
*/
|
*/
|
||||||
Metadata* decode_metadata(vector<float>& logits);
|
Metadata* decode_metadata(const vector<float>& logits);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Do a single inference step in the acoustic model, with:
|
* @brief Do a single inference step in the acoustic model, with:
|
||||||
@ -203,9 +202,6 @@ struct ModelState {
|
|||||||
void compute_mfcc(const vector<float>& audio_buffer, vector<float>& mfcc_output);
|
void compute_mfcc(const vector<float>& audio_buffer, vector<float>& mfcc_output);
|
||||||
};
|
};
|
||||||
|
|
||||||
StreamingState* SetupStreamAndFeedAudioContent(ModelState* aCtx, const short* aBuffer,
|
|
||||||
unsigned int aBufferSize, unsigned int aSampleRate);
|
|
||||||
|
|
||||||
ModelState::ModelState()
|
ModelState::ModelState()
|
||||||
:
|
:
|
||||||
#ifndef USE_TFLITE
|
#ifndef USE_TFLITE
|
||||||
@ -465,22 +461,27 @@ void
|
|||||||
ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
|
ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_output)
|
||||||
{
|
{
|
||||||
#ifndef USE_TFLITE
|
#ifndef USE_TFLITE
|
||||||
Tensor input(DT_FLOAT, TensorShape({static_cast<long long>(samples.size())}));
|
Tensor input(DT_FLOAT, TensorShape({AUDIO_WIN_LEN_SAMPLES}));
|
||||||
auto input_mapped = input.flat<float>();
|
auto input_mapped = input.flat<float>();
|
||||||
for (int i = 0; i < samples.size(); ++i) {
|
int i;
|
||||||
|
for (i = 0; i < samples.size(); ++i) {
|
||||||
input_mapped(i) = samples[i];
|
input_mapped(i) = samples[i];
|
||||||
}
|
}
|
||||||
|
for (; i < AUDIO_WIN_LEN_SAMPLES; ++i) {
|
||||||
|
input_mapped(i) = 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
vector<Tensor> outputs;
|
vector<Tensor> outputs;
|
||||||
Status status = session->Run({{"input_samples", input}}, {"mfccs", "mfccs_len"}, {}, &outputs);
|
Status status = session->Run({{"input_samples", input}}, {"mfccs"}, {}, &outputs);
|
||||||
|
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
std::cerr << "Error running session: " << status << "\n";
|
std::cerr << "Error running session: " << status << "\n";
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mfcc_len_mapped = outputs[1].flat<int32>();
|
// The feature computation graph is hardcoded to one audio length for now
|
||||||
int n_windows = mfcc_len_mapped(0);
|
const int n_windows = 1;
|
||||||
|
assert(outputs[0].shape().num_elemements() / n_features == n_windows);
|
||||||
|
|
||||||
auto mfcc_mapped = outputs[0].flat<float>();
|
auto mfcc_mapped = outputs[0].flat<float>();
|
||||||
for (int i = 0; i < n_windows * n_features; ++i) {
|
for (int i = 0; i < n_windows * n_features; ++i) {
|
||||||
@ -499,7 +500,14 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_windows = *interpreter->typed_tensor<float>(mfccs_len_idx);
|
// The feature computation graph is hardcoded to one audio length for now
|
||||||
|
int n_windows = 1;
|
||||||
|
TfLiteIntArray* out_dims = interpreter->tensor(mfccs_idx)->dims;
|
||||||
|
int num_elements = 1;
|
||||||
|
for (int i = 0; i < out_dims->size; ++i) {
|
||||||
|
num_elements *= out_dims->data[i];
|
||||||
|
}
|
||||||
|
assert(num_elements / n_features == n_windows);
|
||||||
|
|
||||||
float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
|
float* outputs = interpreter->typed_tensor<float>(mfccs_idx);
|
||||||
for (int i = 0; i < n_windows * n_features; ++i) {
|
for (int i = 0; i < n_windows * n_features; ++i) {
|
||||||
@ -509,10 +517,9 @@ ModelState::compute_mfcc(const vector<float>& samples, vector<float>& mfcc_outpu
|
|||||||
}
|
}
|
||||||
|
|
||||||
char*
|
char*
|
||||||
ModelState::decode(vector<float>& logits)
|
ModelState::decode(const vector<float>& logits)
|
||||||
{
|
{
|
||||||
vector<Output> out = ModelState::decode_raw(logits);
|
vector<Output> out = ModelState::decode_raw(logits);
|
||||||
|
|
||||||
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
return strdup(alphabet->LabelsToString(out[0].tokens).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -535,7 +542,8 @@ ModelState::decode_raw(const vector<float>& logits)
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
Metadata* ModelState::decode_metadata(vector<float>& logits)
|
Metadata*
|
||||||
|
ModelState::decode_metadata(const vector<float>& logits)
|
||||||
{
|
{
|
||||||
vector<Output> out = decode_raw(logits);
|
vector<Output> out = decode_raw(logits);
|
||||||
|
|
||||||
@ -559,7 +567,8 @@ Metadata* ModelState::decode_metadata(vector<float>& logits)
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef USE_TFLITE
|
#ifdef USE_TFLITE
|
||||||
int tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, const char* name)
|
int
|
||||||
|
tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, const char* name)
|
||||||
{
|
{
|
||||||
int rv = -1;
|
int rv = -1;
|
||||||
|
|
||||||
@ -574,12 +583,14 @@ int tflite_get_tensor_by_name(const ModelState* ctx, const vector<int>& list, co
|
|||||||
return rv;
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
int tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
|
int
|
||||||
|
tflite_get_input_tensor_by_name(const ModelState* ctx, const char* name)
|
||||||
{
|
{
|
||||||
return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
|
return ctx->interpreter->inputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->inputs(), name)];
|
||||||
}
|
}
|
||||||
|
|
||||||
int tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
|
int
|
||||||
|
tflite_get_output_tensor_by_name(const ModelState* ctx, const char* name)
|
||||||
{
|
{
|
||||||
return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
|
return ctx->interpreter->outputs()[tflite_get_tensor_by_name(ctx, ctx->interpreter->outputs(), name)];
|
||||||
}
|
}
|
||||||
@ -729,7 +740,6 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
|
model->new_state_c_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_c");
|
||||||
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
|
model->new_state_h_idx = tflite_get_output_tensor_by_name(model.get(), "new_state_h");
|
||||||
model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
|
model->mfccs_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs");
|
||||||
model->mfccs_len_idx = tflite_get_output_tensor_by_name(model.get(), "mfccs_len");
|
|
||||||
|
|
||||||
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
|
TfLiteIntArray* dims_input_node = model->interpreter->tensor(model->input_node_idx)->dims;
|
||||||
|
|
||||||
@ -792,41 +802,6 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
char*
|
|
||||||
DS_SpeechToText(ModelState* aCtx,
|
|
||||||
const short* aBuffer,
|
|
||||||
unsigned int aBufferSize,
|
|
||||||
unsigned int aSampleRate)
|
|
||||||
{
|
|
||||||
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
|
|
||||||
return DS_FinishStream(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
Metadata*
|
|
||||||
DS_SpeechToTextWithMetadata(ModelState* aCtx,
|
|
||||||
const short* aBuffer,
|
|
||||||
unsigned int aBufferSize,
|
|
||||||
unsigned int aSampleRate)
|
|
||||||
{
|
|
||||||
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
|
|
||||||
return DS_FinishStreamWithMetadata(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
StreamingState*
|
|
||||||
SetupStreamAndFeedAudioContent(ModelState* aCtx,
|
|
||||||
const short* aBuffer,
|
|
||||||
unsigned int aBufferSize,
|
|
||||||
unsigned int aSampleRate)
|
|
||||||
{
|
|
||||||
StreamingState* ctx;
|
|
||||||
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
|
|
||||||
if (status != DS_ERR_OK) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
|
|
||||||
return ctx;
|
|
||||||
}
|
|
||||||
|
|
||||||
int
|
int
|
||||||
DS_SetupStream(ModelState* aCtx,
|
DS_SetupStream(ModelState* aCtx,
|
||||||
unsigned int aPreAllocFrames,
|
unsigned int aPreAllocFrames,
|
||||||
@ -899,6 +874,41 @@ DS_FinishStreamWithMetadata(StreamingState* aSctx)
|
|||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StreamingState*
|
||||||
|
SetupStreamAndFeedAudioContent(ModelState* aCtx,
|
||||||
|
const short* aBuffer,
|
||||||
|
unsigned int aBufferSize,
|
||||||
|
unsigned int aSampleRate)
|
||||||
|
{
|
||||||
|
StreamingState* ctx;
|
||||||
|
int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx);
|
||||||
|
if (status != DS_ERR_OK) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
DS_FeedAudioContent(ctx, aBuffer, aBufferSize);
|
||||||
|
return ctx;
|
||||||
|
}
|
||||||
|
|
||||||
|
char*
|
||||||
|
DS_SpeechToText(ModelState* aCtx,
|
||||||
|
const short* aBuffer,
|
||||||
|
unsigned int aBufferSize,
|
||||||
|
unsigned int aSampleRate)
|
||||||
|
{
|
||||||
|
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
|
||||||
|
return DS_FinishStream(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
Metadata*
|
||||||
|
DS_SpeechToTextWithMetadata(ModelState* aCtx,
|
||||||
|
const short* aBuffer,
|
||||||
|
unsigned int aBufferSize,
|
||||||
|
unsigned int aSampleRate)
|
||||||
|
{
|
||||||
|
StreamingState* ctx = SetupStreamAndFeedAudioContent(aCtx, aBuffer, aBufferSize, aSampleRate);
|
||||||
|
return DS_FinishStreamWithMetadata(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
DS_DiscardStream(StreamingState* aSctx)
|
DS_DiscardStream(StreamingState* aSctx)
|
||||||
{
|
{
|
||||||
|
|||||||
@ -167,6 +167,11 @@ assert_shows_something()
|
|||||||
fi;
|
fi;
|
||||||
|
|
||||||
case "${stderr}" in
|
case "${stderr}" in
|
||||||
|
*"incompatible with minimum version"*)
|
||||||
|
echo "Prod model too old for client, skipping test."
|
||||||
|
return 0
|
||||||
|
;;
|
||||||
|
|
||||||
*${expected}*)
|
*${expected}*)
|
||||||
echo "Proper output has been produced:"
|
echo "Proper output has been produced:"
|
||||||
echo "${stderr}"
|
echo "${stderr}"
|
||||||
@ -342,10 +347,14 @@ run_all_inference_tests()
|
|||||||
set -e
|
set -e
|
||||||
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_stereo_44k}" "$status"
|
assert_correct_ldc93s1_lm "${phrase_pbmodel_withlm_stereo_44k}" "$status"
|
||||||
|
|
||||||
|
set +e
|
||||||
phrase_pbmodel_nolm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
phrase_pbmodel_nolm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||||
|
set -e
|
||||||
assert_correct_warning_upsampling "${phrase_pbmodel_nolm_mono_8k}"
|
assert_correct_warning_upsampling "${phrase_pbmodel_nolm_mono_8k}"
|
||||||
|
|
||||||
|
set +e
|
||||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||||
|
set -e
|
||||||
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,7 +378,9 @@ run_prod_inference_tests()
|
|||||||
set -e
|
set -e
|
||||||
assert_correct_ldc93s1_prodmodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status"
|
assert_correct_ldc93s1_prodmodel_stereo_44k "${phrase_pbmodel_withlm_stereo_44k}" "$status"
|
||||||
|
|
||||||
|
set +e
|
||||||
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null)
|
||||||
|
set -e
|
||||||
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user