Embed alphabet directly in model
This commit is contained in:
parent
493aaed151
commit
8c82081779
@ -780,12 +780,11 @@ def export():
|
|||||||
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
graph_version = int(file_relative_read('GRAPH_VERSION').strip())
|
||||||
assert graph_version > 0
|
assert graph_version > 0
|
||||||
|
|
||||||
# Reshape with dimension [1] required to avoid this error:
|
|
||||||
# ERROR: Input array not provided for operation 'reshape'.
|
|
||||||
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version')
|
||||||
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate')
|
||||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||||
|
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
||||||
|
|
||||||
if FLAGS.export_language:
|
if FLAGS.export_language:
|
||||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language')
|
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language')
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
4
|
5
|
||||||
|
|||||||
@ -36,7 +36,7 @@ public:
|
|||||||
if (line == " ") {
|
if (line == " ") {
|
||||||
space_label_ = label;
|
space_label_ = label;
|
||||||
}
|
}
|
||||||
label_to_str_.push_back(line);
|
label_to_str_[label] = line;
|
||||||
str_to_label_[line] = label;
|
str_to_label_[line] = label;
|
||||||
++label;
|
++label;
|
||||||
}
|
}
|
||||||
@ -45,9 +45,53 @@ public:
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int deserialize(const char* buffer, const int buffer_size) {
|
||||||
|
int offset = 0;
|
||||||
|
if (buffer_size - offset < sizeof(int16_t)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
int16_t size = *(int16_t*)(buffer + offset);
|
||||||
|
offset += sizeof(int16_t);
|
||||||
|
size_ = size;
|
||||||
|
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
if (buffer_size - offset < sizeof(int16_t)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
int16_t label = *(int16_t*)(buffer + offset);
|
||||||
|
offset += sizeof(int16_t);
|
||||||
|
|
||||||
|
if (buffer_size - offset < sizeof(int16_t)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
int16_t val_len = *(int16_t*)(buffer + offset);
|
||||||
|
offset += sizeof(int16_t);
|
||||||
|
|
||||||
|
if (buffer_size - offset < val_len) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
std::string val(buffer+offset, val_len);
|
||||||
|
offset += val_len;
|
||||||
|
|
||||||
|
label_to_str_[label] = val;
|
||||||
|
str_to_label_[val] = label;
|
||||||
|
|
||||||
|
if (val == " ") {
|
||||||
|
space_label_ = label;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
const std::string& StringFromLabel(unsigned int label) const {
|
const std::string& StringFromLabel(unsigned int label) const {
|
||||||
assert(label < size_);
|
auto it = label_to_str_.find(label);
|
||||||
return label_to_str_[label];
|
if (it != label_to_str_.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
std::cerr << "Invalid label " << label << std::endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned int LabelFromString(const std::string& string) const {
|
unsigned int LabelFromString(const std::string& string) const {
|
||||||
@ -55,7 +99,7 @@ public:
|
|||||||
if (it != str_to_label_.end()) {
|
if (it != str_to_label_.end()) {
|
||||||
return it->second;
|
return it->second;
|
||||||
} else {
|
} else {
|
||||||
std::cerr << "Invalid label " << string << std::endl;
|
std::cerr << "Invalid string " << string << std::endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -84,7 +128,7 @@ public:
|
|||||||
private:
|
private:
|
||||||
size_t size_;
|
size_t size_;
|
||||||
unsigned int space_label_;
|
unsigned int space_label_;
|
||||||
std::vector<std::string> label_to_str_;
|
std::unordered_map<unsigned int, std::string> label_to_str_;
|
||||||
std::unordered_map<std::string, unsigned int> str_to_label_;
|
std::unordered_map<std::string, unsigned int> str_to_label_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -369,7 +369,7 @@ main(int argc, char **argv)
|
|||||||
|
|
||||||
// Initialise DeepSpeech
|
// Initialise DeepSpeech
|
||||||
ModelState* ctx;
|
ModelState* ctx;
|
||||||
int status = DS_CreateModel(model, alphabet, beam_width, &ctx);
|
int status = DS_CreateModel(model, beam_width, &ctx);
|
||||||
if (status != 0) {
|
if (status != 0) {
|
||||||
fprintf(stderr, "Could not create model.\n");
|
fprintf(stderr, "Could not create model.\n");
|
||||||
return 1;
|
return 1;
|
||||||
|
|||||||
@ -18,9 +18,18 @@ class Scorer(swigwrapper.Scorer):
|
|||||||
|
|
||||||
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
||||||
super(Scorer, self).__init__()
|
super(Scorer, self).__init__()
|
||||||
err = self.init(alpha, beta, model_path, trie_path, alphabet.config_file())
|
serialized = alphabet.serialize()
|
||||||
|
native_alphabet = swigwrapper.Alphabet()
|
||||||
|
err = native_alphabet.deserialize(serialized, len(serialized))
|
||||||
if err != 0:
|
if err != 0:
|
||||||
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
|
raise ValueError("Error when deserializing alphabet.")
|
||||||
|
|
||||||
|
err = self.init(alpha, beta,
|
||||||
|
model_path.encode('utf-8'),
|
||||||
|
trie_path.encode('utf-8'),
|
||||||
|
native_alphabet)
|
||||||
|
if err != 0:
|
||||||
|
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
|
||||||
|
|
||||||
|
|
||||||
def ctc_beam_search_decoder(probs_seq,
|
def ctc_beam_search_decoder(probs_seq,
|
||||||
@ -35,8 +44,7 @@ def ctc_beam_search_decoder(probs_seq,
|
|||||||
step, with each element being a list of normalized
|
step, with each element being a list of normalized
|
||||||
probabilities over alphabet and blank.
|
probabilities over alphabet and blank.
|
||||||
:type probs_seq: 2-D list
|
:type probs_seq: 2-D list
|
||||||
:param alphabet: alphabet list.
|
:param alphabet: Alphabet
|
||||||
:alphabet: Alphabet
|
|
||||||
:param beam_size: Width for beam search.
|
:param beam_size: Width for beam search.
|
||||||
:type beam_size: int
|
:type beam_size: int
|
||||||
:param cutoff_prob: Cutoff probability in pruning,
|
:param cutoff_prob: Cutoff probability in pruning,
|
||||||
@ -53,8 +61,13 @@ def ctc_beam_search_decoder(probs_seq,
|
|||||||
results, in descending order of the confidence.
|
results, in descending order of the confidence.
|
||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
|
serialized = alphabet.serialize()
|
||||||
|
native_alphabet = swigwrapper.Alphabet()
|
||||||
|
err = native_alphabet.deserialize(serialized, len(serialized))
|
||||||
|
if err != 0:
|
||||||
|
raise ValueError("Error when deserializing alphabet.")
|
||||||
beam_results = swigwrapper.ctc_beam_search_decoder(
|
beam_results = swigwrapper.ctc_beam_search_decoder(
|
||||||
probs_seq, alphabet.config_file(), beam_size, cutoff_prob, cutoff_top_n,
|
probs_seq, native_alphabet, beam_size, cutoff_prob, cutoff_top_n,
|
||||||
scorer)
|
scorer)
|
||||||
beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
|
beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
|
||||||
return beam_results
|
return beam_results
|
||||||
@ -95,9 +108,12 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
|||||||
results, in descending order of the confidence.
|
results, in descending order of the confidence.
|
||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(
|
serialized = alphabet.serialize()
|
||||||
probs_seq, seq_lengths, alphabet.config_file(), beam_size, num_processes,
|
native_alphabet = swigwrapper.Alphabet()
|
||||||
cutoff_prob, cutoff_top_n, scorer)
|
err = native_alphabet.deserialize(serialized, len(serialized))
|
||||||
|
if err != 0:
|
||||||
|
raise ValueError("Error when deserializing alphabet.")
|
||||||
|
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, native_alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
|
||||||
batch_beam_results = [
|
batch_beam_results = [
|
||||||
[(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
|
[(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
|
||||||
for beam_results in batch_beam_results
|
for beam_results in batch_beam_results
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
%{
|
%{
|
||||||
#include "ctc_beam_search_decoder.h"
|
#include "ctc_beam_search_decoder.h"
|
||||||
#define SWIG_FILE_WITH_INIT
|
#define SWIG_FILE_WITH_INIT
|
||||||
|
#define SWIG_PYTHON_STRICT_BYTE_CHAR
|
||||||
%}
|
%}
|
||||||
|
|
||||||
%include "pyabc.i"
|
%include "pyabc.i"
|
||||||
@ -16,57 +17,12 @@ import_array();
|
|||||||
|
|
||||||
// Convert NumPy arrays to pointer+lengths
|
// Convert NumPy arrays to pointer+lengths
|
||||||
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
||||||
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
|
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
|
||||||
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
|
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
|
||||||
|
|
||||||
// Add overloads converting char* to Alphabet
|
|
||||||
%inline %{
|
|
||||||
std::vector<Output>
|
|
||||||
ctc_beam_search_decoder(const double *probs,
|
|
||||||
int time_dim,
|
|
||||||
int class_dim,
|
|
||||||
char* alphabet_config_path,
|
|
||||||
size_t beam_size,
|
|
||||||
double cutoff_prob,
|
|
||||||
size_t cutoff_top_n,
|
|
||||||
Scorer *ext_scorer)
|
|
||||||
{
|
|
||||||
Alphabet a;
|
|
||||||
if (a.init(alphabet_config_path)) {
|
|
||||||
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
|
|
||||||
}
|
|
||||||
return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
|
|
||||||
cutoff_prob, cutoff_top_n, ext_scorer);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<Output>>
|
|
||||||
ctc_beam_search_decoder_batch(const double *probs,
|
|
||||||
int batch_dim,
|
|
||||||
int time_dim,
|
|
||||||
int class_dim,
|
|
||||||
const int *seq_lengths,
|
|
||||||
int seq_lengths_size,
|
|
||||||
char* alphabet_config_path,
|
|
||||||
size_t beam_size,
|
|
||||||
size_t num_processes,
|
|
||||||
double cutoff_prob,
|
|
||||||
size_t cutoff_top_n,
|
|
||||||
Scorer *ext_scorer)
|
|
||||||
{
|
|
||||||
Alphabet a;
|
|
||||||
if (a.init(alphabet_config_path)) {
|
|
||||||
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
|
|
||||||
}
|
|
||||||
return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
|
|
||||||
seq_lengths, seq_lengths_size, a, beam_size,
|
|
||||||
num_processes, cutoff_prob, cutoff_top_n,
|
|
||||||
ext_scorer);
|
|
||||||
}
|
|
||||||
%}
|
|
||||||
|
|
||||||
|
|
||||||
%ignore Scorer::dictionary;
|
%ignore Scorer::dictionary;
|
||||||
|
|
||||||
|
%include "../alphabet.h"
|
||||||
%include "output.h"
|
%include "output.h"
|
||||||
%include "scorer.h"
|
%include "scorer.h"
|
||||||
%include "ctc_beam_search_decoder.h"
|
%include "ctc_beam_search_decoder.h"
|
||||||
|
|||||||
@ -257,7 +257,6 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
|
|||||||
|
|
||||||
int
|
int
|
||||||
DS_CreateModel(const char* aModelPath,
|
DS_CreateModel(const char* aModelPath,
|
||||||
const char* aAlphabetConfigPath,
|
|
||||||
unsigned int aBeamWidth,
|
unsigned int aBeamWidth,
|
||||||
ModelState** retval)
|
ModelState** retval)
|
||||||
{
|
{
|
||||||
@ -283,7 +282,7 @@ DS_CreateModel(const char* aModelPath,
|
|||||||
return DS_ERR_FAIL_CREATE_MODEL;
|
return DS_ERR_FAIL_CREATE_MODEL;
|
||||||
}
|
}
|
||||||
|
|
||||||
int err = model->init(aModelPath, aAlphabetConfigPath, aBeamWidth);
|
int err = model->init(aModelPath, aBeamWidth);
|
||||||
if (err != DS_ERR_OK) {
|
if (err != DS_ERR_OK) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -77,8 +77,6 @@ enum DeepSpeech_Error_Codes
|
|||||||
* @brief An object providing an interface to a trained DeepSpeech model.
|
* @brief An object providing an interface to a trained DeepSpeech model.
|
||||||
*
|
*
|
||||||
* @param aModelPath The path to the frozen model graph.
|
* @param aModelPath The path to the frozen model graph.
|
||||||
* @param aAlphabetConfigPath The path to the configuration file specifying
|
|
||||||
* the alphabet used by the network. See alphabet.h.
|
|
||||||
* @param aBeamWidth The beam width used by the decoder. A larger beam
|
* @param aBeamWidth The beam width used by the decoder. A larger beam
|
||||||
* width generates better results at the cost of decoding
|
* width generates better results at the cost of decoding
|
||||||
* time.
|
* time.
|
||||||
@ -88,7 +86,6 @@ enum DeepSpeech_Error_Codes
|
|||||||
*/
|
*/
|
||||||
DEEPSPEECH_EXPORT
|
DEEPSPEECH_EXPORT
|
||||||
int DS_CreateModel(const char* aModelPath,
|
int DS_CreateModel(const char* aModelPath,
|
||||||
const char* aAlphabetConfigPath,
|
|
||||||
unsigned int aBeamWidth,
|
unsigned int aBeamWidth,
|
||||||
ModelState** retval);
|
ModelState** retval);
|
||||||
|
|
||||||
|
|||||||
@ -13,8 +13,7 @@
|
|||||||
* @param aModelPath The path to the frozen model graph.
|
* @param aModelPath The path to the frozen model graph.
|
||||||
* @param aNCep UNUSED, DEPRECATED.
|
* @param aNCep UNUSED, DEPRECATED.
|
||||||
* @param aNContext UNUSED, DEPRECATED.
|
* @param aNContext UNUSED, DEPRECATED.
|
||||||
* @param aAlphabetConfigPath The path to the configuration file specifying
|
* @param aAlphabetConfigPath UNUSED, DEPRECATED.
|
||||||
* the alphabet used by the network. See alphabet.h.
|
|
||||||
* @param aBeamWidth The beam width used by the decoder. A larger beam
|
* @param aBeamWidth The beam width used by the decoder. A larger beam
|
||||||
* width generates better results at the cost of decoding
|
* width generates better results at the cost of decoding
|
||||||
* time.
|
* time.
|
||||||
@ -25,11 +24,11 @@
|
|||||||
int DS_CreateModel(const char* aModelPath,
|
int DS_CreateModel(const char* aModelPath,
|
||||||
unsigned int /*aNCep*/,
|
unsigned int /*aNCep*/,
|
||||||
unsigned int /*aNContext*/,
|
unsigned int /*aNContext*/,
|
||||||
const char* aAlphabetConfigPath,
|
const char* /*aAlphabetConfigPath*/,
|
||||||
unsigned int aBeamWidth,
|
unsigned int aBeamWidth,
|
||||||
ModelState** retval)
|
ModelState** retval)
|
||||||
{
|
{
|
||||||
return DS_CreateModel(aModelPath, aAlphabetConfigPath, aBeamWidth, retval);
|
return DS_CreateModel(aModelPath, aBeamWidth, retval);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -25,12 +25,8 @@ ModelState::~ModelState()
|
|||||||
|
|
||||||
int
|
int
|
||||||
ModelState::init(const char* model_path,
|
ModelState::init(const char* model_path,
|
||||||
const char* alphabet_path,
|
|
||||||
unsigned int beam_width)
|
unsigned int beam_width)
|
||||||
{
|
{
|
||||||
if (alphabet_.init(alphabet_path)) {
|
|
||||||
return DS_ERR_INVALID_ALPHABET;
|
|
||||||
}
|
|
||||||
beam_width_ = beam_width;
|
beam_width_ = beam_width;
|
||||||
return DS_ERR_OK;
|
return DS_ERR_OK;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,9 +30,7 @@ struct ModelState {
|
|||||||
ModelState();
|
ModelState();
|
||||||
virtual ~ModelState();
|
virtual ~ModelState();
|
||||||
|
|
||||||
virtual int init(const char* model_path,
|
virtual int init(const char* model_path, unsigned int beam_width);
|
||||||
const char* alphabet_path,
|
|
||||||
unsigned int beam_width);
|
|
||||||
|
|
||||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
|
virtual void compute_mfcc(const std::vector<float>& audio_buffer, std::vector<float>& mfcc_output) = 0;
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
#include "tflitemodelstate.h"
|
#include "tflitemodelstate.h"
|
||||||
|
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
#include "workspace_status.h"
|
#include "workspace_status.h"
|
||||||
|
|
||||||
using namespace tflite;
|
using namespace tflite;
|
||||||
@ -91,10 +92,9 @@ TFLiteModelState::~TFLiteModelState()
|
|||||||
|
|
||||||
int
|
int
|
||||||
TFLiteModelState::init(const char* model_path,
|
TFLiteModelState::init(const char* model_path,
|
||||||
const char* alphabet_path,
|
|
||||||
unsigned int beam_width)
|
unsigned int beam_width)
|
||||||
{
|
{
|
||||||
int err = ModelState::init(model_path, alphabet_path, beam_width);
|
int err = ModelState::init(model_path, beam_width);
|
||||||
if (err != DS_ERR_OK) {
|
if (err != DS_ERR_OK) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
@ -126,17 +126,17 @@ TFLiteModelState::init(const char* model_path,
|
|||||||
mfccs_idx_ = get_output_tensor_by_name("mfccs");
|
mfccs_idx_ = get_output_tensor_by_name("mfccs");
|
||||||
|
|
||||||
int metadata_version_idx = get_output_tensor_by_name("metadata_version");
|
int metadata_version_idx = get_output_tensor_by_name("metadata_version");
|
||||||
// int metadata_language_idx = get_output_tensor_by_name("metadata_language");
|
|
||||||
int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
|
int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate");
|
||||||
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
|
int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len");
|
||||||
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
|
int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step");
|
||||||
|
int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet");
|
||||||
|
|
||||||
std::vector<int> metadata_exec_plan;
|
std::vector<int> metadata_exec_plan;
|
||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]);
|
||||||
// metadata_exec_plan.push_back(find_parent_node_ids(metadata_language_idx)[0]);
|
|
||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]);
|
||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]);
|
||||||
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]);
|
||||||
|
metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]);
|
||||||
|
|
||||||
for (int i = 0; i < metadata_exec_plan.size(); ++i) {
|
for (int i = 0; i < metadata_exec_plan.size(); ++i) {
|
||||||
assert(metadata_exec_plan[i] > -1);
|
assert(metadata_exec_plan[i] > -1);
|
||||||
@ -200,6 +200,12 @@ TFLiteModelState::init(const char* model_path,
|
|||||||
audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
|
audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0);
|
||||||
audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);
|
audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0);
|
||||||
|
|
||||||
|
tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
|
||||||
|
err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
|
||||||
|
if (err != 0) {
|
||||||
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
|
}
|
||||||
|
|
||||||
assert(sample_rate_ > 0);
|
assert(sample_rate_ > 0);
|
||||||
assert(audio_win_len_ > 0);
|
assert(audio_win_len_ > 0);
|
||||||
assert(audio_win_step_ > 0);
|
assert(audio_win_step_ > 0);
|
||||||
|
|||||||
@ -31,7 +31,6 @@ struct TFLiteModelState : public ModelState
|
|||||||
virtual ~TFLiteModelState();
|
virtual ~TFLiteModelState();
|
||||||
|
|
||||||
virtual int init(const char* model_path,
|
virtual int init(const char* model_path,
|
||||||
const char* alphabet_path,
|
|
||||||
unsigned int beam_width) override;
|
unsigned int beam_width) override;
|
||||||
|
|
||||||
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
virtual void compute_mfcc(const std::vector<float>& audio_buffer,
|
||||||
|
|||||||
@ -25,10 +25,9 @@ TFModelState::~TFModelState()
|
|||||||
|
|
||||||
int
|
int
|
||||||
TFModelState::init(const char* model_path,
|
TFModelState::init(const char* model_path,
|
||||||
const char* alphabet_path,
|
|
||||||
unsigned int beam_width)
|
unsigned int beam_width)
|
||||||
{
|
{
|
||||||
int err = ModelState::init(model_path, alphabet_path, beam_width);
|
int err = ModelState::init(model_path, beam_width);
|
||||||
if (err != DS_ERR_OK) {
|
if (err != DS_ERR_OK) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
@ -78,20 +77,16 @@ TFModelState::init(const char* model_path,
|
|||||||
return DS_ERR_FAIL_CREATE_SESS;
|
return DS_ERR_FAIL_CREATE_SESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<tensorflow::Tensor> metadata_outputs;
|
std::vector<tensorflow::Tensor> version_output;
|
||||||
status = session_->Run({}, {
|
status = session_->Run({}, {
|
||||||
"metadata_version",
|
"metadata_version"
|
||||||
// "metadata_language",
|
}, {}, &version_output);
|
||||||
"metadata_sample_rate",
|
|
||||||
"metadata_feature_win_len",
|
|
||||||
"metadata_feature_win_step"
|
|
||||||
}, {}, &metadata_outputs);
|
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
std::cout << "Unable to fetch metadata: " << status << std::endl;
|
std::cerr << "Unable to fetch graph version: " << status << std::endl;
|
||||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||||
}
|
}
|
||||||
|
|
||||||
int graph_version = metadata_outputs[0].scalar<int>()();
|
int graph_version = version_output[0].scalar<int>()();
|
||||||
if (graph_version < ds_graph_version()) {
|
if (graph_version < ds_graph_version()) {
|
||||||
std::cerr << "Specified model file version (" << graph_version << ") is "
|
std::cerr << "Specified model file version (" << graph_version << ") is "
|
||||||
<< "incompatible with minimum version supported by this client ("
|
<< "incompatible with minimum version supported by this client ("
|
||||||
@ -101,12 +96,30 @@ TFModelState::init(const char* model_path,
|
|||||||
return DS_ERR_MODEL_INCOMPATIBLE;
|
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||||
}
|
}
|
||||||
|
|
||||||
sample_rate_ = metadata_outputs[1].scalar<int>()();
|
std::vector<tensorflow::Tensor> metadata_outputs;
|
||||||
int win_len_ms = metadata_outputs[2].scalar<int>()();
|
status = session_->Run({}, {
|
||||||
int win_step_ms = metadata_outputs[3].scalar<int>()();
|
"metadata_sample_rate",
|
||||||
|
"metadata_feature_win_len",
|
||||||
|
"metadata_feature_win_step",
|
||||||
|
"metadata_alphabet",
|
||||||
|
}, {}, &metadata_outputs);
|
||||||
|
if (!status.ok()) {
|
||||||
|
std::cout << "Unable to fetch metadata: " << status << std::endl;
|
||||||
|
return DS_ERR_MODEL_INCOMPATIBLE;
|
||||||
|
}
|
||||||
|
|
||||||
|
sample_rate_ = metadata_outputs[0].scalar<int>()();
|
||||||
|
int win_len_ms = metadata_outputs[1].scalar<int>()();
|
||||||
|
int win_step_ms = metadata_outputs[2].scalar<int>()();
|
||||||
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
|
audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0);
|
||||||
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
|
audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0);
|
||||||
|
|
||||||
|
string serialized_alphabet = metadata_outputs[3].scalar<string>()();
|
||||||
|
err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
|
||||||
|
if (err != 0) {
|
||||||
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
|
}
|
||||||
|
|
||||||
assert(sample_rate_ > 0);
|
assert(sample_rate_ > 0);
|
||||||
assert(audio_win_len_ > 0);
|
assert(audio_win_len_ > 0);
|
||||||
assert(audio_win_step_ > 0);
|
assert(audio_win_step_ > 0);
|
||||||
|
|||||||
@ -19,7 +19,6 @@ struct TFModelState : public ModelState
|
|||||||
virtual ~TFModelState();
|
virtual ~TFModelState();
|
||||||
|
|
||||||
virtual int init(const char* model_path,
|
virtual int init(const char* model_path,
|
||||||
const char* alphabet_path,
|
|
||||||
unsigned int beam_width) override;
|
unsigned int beam_width) override;
|
||||||
|
|
||||||
virtual void infer(const std::vector<float>& mfcc,
|
virtual void infer(const std::vector<float>& mfcc,
|
||||||
|
|||||||
55
util/text.py
55
util/text.py
@ -1,27 +1,29 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import re
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import re
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from util.flags import FLAGS
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
|
|
||||||
class Alphabet(object):
|
class Alphabet(object):
|
||||||
def __init__(self, config_file):
|
def __init__(self, config_file):
|
||||||
self._config_file = config_file
|
self._config_file = config_file
|
||||||
self._label_to_str = []
|
self._label_to_str = {}
|
||||||
self._str_to_label = {}
|
self._str_to_label = {}
|
||||||
self._size = 0
|
self._size = 0
|
||||||
with codecs.open(config_file, 'r', 'utf-8') as fin:
|
if config_file:
|
||||||
for line in fin:
|
with codecs.open(config_file, 'r', 'utf-8') as fin:
|
||||||
if line[0:2] == '\\#':
|
for line in fin:
|
||||||
line = '#\n'
|
if line[0:2] == '\\#':
|
||||||
elif line[0] == '#':
|
line = '#\n'
|
||||||
continue
|
elif line[0] == '#':
|
||||||
self._label_to_str += line[:-1] # remove the line ending
|
continue
|
||||||
self._str_to_label[line[:-1]] = self._size
|
self._label_to_str[self._size] = line[:-1] # remove the line ending
|
||||||
self._size += 1
|
self._str_to_label[line[:-1]] = self._size
|
||||||
|
self._size += 1
|
||||||
|
|
||||||
def _string_from_label(self, label):
|
def _string_from_label(self, label):
|
||||||
return self._label_to_str[label]
|
return self._label_to_str[label]
|
||||||
@ -51,6 +53,35 @@ class Alphabet(object):
|
|||||||
res += self._string_from_label(label)
|
res += self._string_from_label(label)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
res = bytearray()
|
||||||
|
res += struct.pack('<h', self._size)
|
||||||
|
for key, value in self._label_to_str.items():
|
||||||
|
value = value.encode('utf-8')
|
||||||
|
res += struct.pack('<hh{}s'.format(len(value)), key, len(value), value)
|
||||||
|
return bytes(res)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def deserialize(buf):
|
||||||
|
#pylint: disable=protected-access
|
||||||
|
res = Alphabet(config_file=None)
|
||||||
|
|
||||||
|
offset = 0
|
||||||
|
def unpack_and_fwd(fmt, buf):
|
||||||
|
nonlocal offset
|
||||||
|
result = struct.unpack_from(fmt, buf, offset)
|
||||||
|
offset += struct.calcsize(fmt)
|
||||||
|
return result
|
||||||
|
|
||||||
|
res.size = unpack_and_fwd('<h', buf)[0]
|
||||||
|
for _ in range(res.size):
|
||||||
|
label, val_len = unpack_and_fwd('<hh', buf)
|
||||||
|
val = unpack_and_fwd('<{}s'.format(val_len), buf)[0].decode('utf-8')
|
||||||
|
res._label_to_str[label] = val
|
||||||
|
res._str_to_label[val] = label
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def size(self):
|
def size(self):
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user