diff --git a/DeepSpeech.py b/DeepSpeech.py index d5ca4c48..9cc84f8d 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -780,12 +780,11 @@ def export(): graph_version = int(file_relative_read('GRAPH_VERSION').strip()) assert graph_version > 0 - # Reshape with dimension [1] required to avoid this error: - # ERROR: Input array not provided for operation 'reshape'. outputs['metadata_version'] = tf.constant([graph_version], name='metadata_version') outputs['metadata_sample_rate'] = tf.constant([FLAGS.audio_sample_rate], name='metadata_sample_rate') outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len') outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step') + outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet') if FLAGS.export_language: outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('ascii')], name='metadata_language') diff --git a/GRAPH_VERSION b/GRAPH_VERSION index b8626c4c..7ed6ff82 100644 --- a/GRAPH_VERSION +++ b/GRAPH_VERSION @@ -1 +1 @@ -4 +5 diff --git a/native_client/alphabet.h b/native_client/alphabet.h index 8bd5b98c..2d60e426 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -36,7 +36,7 @@ public: if (line == " ") { space_label_ = label; } - label_to_str_.push_back(line); + label_to_str_[label] = line; str_to_label_[line] = label; ++label; } @@ -45,9 +45,53 @@ public: return 0; } + int deserialize(const char* buffer, const int buffer_size) { + int offset = 0; + if (buffer_size - offset < sizeof(int16_t)) { + return 1; + } + int16_t size = *(int16_t*)(buffer + offset); + offset += sizeof(int16_t); + size_ = size; + + for (int i = 0; i < size; ++i) { + if (buffer_size - offset < sizeof(int16_t)) { + return 1; + } + int16_t label = *(int16_t*)(buffer + offset); + offset += sizeof(int16_t); + + if (buffer_size - offset < sizeof(int16_t)) { + return 1; + } + int16_t val_len = *(int16_t*)(buffer + offset); + offset += sizeof(int16_t); + + if (buffer_size - offset < val_len) { + return 1; + } + std::string val(buffer+offset, val_len); + offset += val_len; + + label_to_str_[label] = val; + str_to_label_[val] = label; + + if (val == " ") { + space_label_ = label; + } + } + + return 0; + } + const std::string& StringFromLabel(unsigned int label) const { - assert(label < size_); - return label_to_str_[label]; + auto it = label_to_str_.find(label); + if (it != label_to_str_.end()) { + return it->second; + } else { + std::cerr << "Invalid label " << label << std::endl; + abort(); + } } unsigned int LabelFromString(const std::string& string) const { @@ -55,7 +99,7 @@ public: if (it != str_to_label_.end()) { return it->second; } else { - std::cerr << "Invalid label " << string << std::endl; + std::cerr << "Invalid string " << string << std::endl; abort(); } } @@ -84,7 +128,7 @@ public: private: size_t size_; unsigned int space_label_; - std::vector label_to_str_; + std::unordered_map label_to_str_; std::unordered_map str_to_label_; }; diff --git a/native_client/client.cc b/native_client/client.cc index ec2d7162..981b1bda 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -369,7 +369,7 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; - int status = DS_CreateModel(model, alphabet, beam_width, &ctx); + int status = DS_CreateModel(model, beam_width, &ctx); if (status != 0) { fprintf(stderr, "Could not create model.\n"); return 1; diff --git a/native_client/ctcdecode/__init__.py b/native_client/ctcdecode/__init__.py index c000bbaa..41fb3002 100644 --- a/native_client/ctcdecode/__init__.py +++ b/native_client/ctcdecode/__init__.py @@ -18,9 +18,18 @@ class Scorer(swigwrapper.Scorer): def __init__(self, alpha, beta, model_path, trie_path, alphabet): super(Scorer, self).__init__() - err = self.init(alpha, beta, model_path, trie_path, alphabet.config_file()) + serialized = alphabet.serialize() + native_alphabet = swigwrapper.Alphabet() + err = native_alphabet.deserialize(serialized, len(serialized)) if err != 0: - raise ValueError("Scorer initialization failed with error code {}".format(err), err) + raise ValueError("Error when deserializing alphabet.") + + err = self.init(alpha, beta, + model_path.encode('utf-8'), + trie_path.encode('utf-8'), + native_alphabet) + if err != 0: + raise ValueError("Scorer initialization failed with error code {}".format(err), err) def ctc_beam_search_decoder(probs_seq, @@ -35,8 +44,7 @@ def ctc_beam_search_decoder(probs_seq, step, with each element being a list of normalized probabilities over alphabet and blank. :type probs_seq: 2-D list - :param alphabet: alphabet list. - :alphabet: Alphabet + :param alphabet: Alphabet :param beam_size: Width for beam search. :type beam_size: int :param cutoff_prob: Cutoff probability in pruning, @@ -53,8 +61,13 @@ def ctc_beam_search_decoder(probs_seq, results, in descending order of the confidence. :rtype: list """ + serialized = alphabet.serialize() + native_alphabet = swigwrapper.Alphabet() + err = native_alphabet.deserialize(serialized, len(serialized)) + if err != 0: + raise ValueError("Error when deserializing alphabet.") beam_results = swigwrapper.ctc_beam_search_decoder( - probs_seq, alphabet.config_file(), beam_size, cutoff_prob, cutoff_top_n, + probs_seq, native_alphabet, beam_size, cutoff_prob, cutoff_top_n, scorer) beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results] return beam_results @@ -95,9 +108,12 @@ def ctc_beam_search_decoder_batch(probs_seq, results, in descending order of the confidence. :rtype: list """ - batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch( - probs_seq, seq_lengths, alphabet.config_file(), beam_size, num_processes, - cutoff_prob, cutoff_top_n, scorer) + serialized = alphabet.serialize() + native_alphabet = swigwrapper.Alphabet() + err = native_alphabet.deserialize(serialized, len(serialized)) + if err != 0: + raise ValueError("Error when deserializing alphabet.") + batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, native_alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer) batch_beam_results = [ [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results] for beam_results in batch_beam_results diff --git a/native_client/ctcdecode/swigwrapper.i b/native_client/ctcdecode/swigwrapper.i index 582357a2..b7ac8b11 100644 --- a/native_client/ctcdecode/swigwrapper.i +++ b/native_client/ctcdecode/swigwrapper.i @@ -3,6 +3,7 @@ %{ #include "ctc_beam_search_decoder.h" #define SWIG_FILE_WITH_INIT +#define SWIG_PYTHON_STRICT_BYTE_CHAR %} %include "pyabc.i" @@ -16,57 +17,12 @@ import_array(); // Convert NumPy arrays to pointer+lengths %apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)}; -%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)}; +%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)}; %apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)}; -// Add overloads converting char* to Alphabet -%inline %{ -std::vector -ctc_beam_search_decoder(const double *probs, - int time_dim, - int class_dim, - char* alphabet_config_path, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer) -{ - Alphabet a; - if (a.init(alphabet_config_path)) { - std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n"; - } - return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size, - cutoff_prob, cutoff_top_n, ext_scorer); -} - -std::vector> -ctc_beam_search_decoder_batch(const double *probs, - int batch_dim, - int time_dim, - int class_dim, - const int *seq_lengths, - int seq_lengths_size, - char* alphabet_config_path, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - Scorer *ext_scorer) -{ - Alphabet a; - if (a.init(alphabet_config_path)) { - std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n"; - } - return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim, - seq_lengths, seq_lengths_size, a, beam_size, - num_processes, cutoff_prob, cutoff_top_n, - ext_scorer); -} -%} - - %ignore Scorer::dictionary; +%include "../alphabet.h" %include "output.h" %include "scorer.h" %include "ctc_beam_search_decoder.h" diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index 9aee0f8e..716f2267 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -257,7 +257,6 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) int DS_CreateModel(const char* aModelPath, - const char* aAlphabetConfigPath, unsigned int aBeamWidth, ModelState** retval) { @@ -283,7 +282,7 @@ DS_CreateModel(const char* aModelPath, return DS_ERR_FAIL_CREATE_MODEL; } - int err = model->init(aModelPath, aAlphabetConfigPath, aBeamWidth); + int err = model->init(aModelPath, aBeamWidth); if (err != DS_ERR_OK) { return err; } diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index ed9d8638..8dd3b574 100644 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -77,8 +77,6 @@ enum DeepSpeech_Error_Codes * @brief An object providing an interface to a trained DeepSpeech model. * * @param aModelPath The path to the frozen model graph. - * @param aAlphabetConfigPath The path to the configuration file specifying - * the alphabet used by the network. See alphabet.h. * @param aBeamWidth The beam width used by the decoder. A larger beam * width generates better results at the cost of decoding * time. @@ -88,7 +86,6 @@ enum DeepSpeech_Error_Codes */ DEEPSPEECH_EXPORT int DS_CreateModel(const char* aModelPath, - const char* aAlphabetConfigPath, unsigned int aBeamWidth, ModelState** retval); diff --git a/native_client/deepspeech_compat.h b/native_client/deepspeech_compat.h index 57f3d16c..c83bcbc8 100644 --- a/native_client/deepspeech_compat.h +++ b/native_client/deepspeech_compat.h @@ -13,8 +13,7 @@ * @param aModelPath The path to the frozen model graph. * @param aNCep UNUSED, DEPRECATED. * @param aNContext UNUSED, DEPRECATED. - * @param aAlphabetConfigPath The path to the configuration file specifying - * the alphabet used by the network. See alphabet.h. + * @param aAlphabetConfigPath UNUSED, DEPRECATED. * @param aBeamWidth The beam width used by the decoder. A larger beam * width generates better results at the cost of decoding * time. @@ -25,11 +24,11 @@ int DS_CreateModel(const char* aModelPath, unsigned int /*aNCep*/, unsigned int /*aNContext*/, - const char* aAlphabetConfigPath, + const char* /*aAlphabetConfigPath*/, unsigned int aBeamWidth, ModelState** retval) { - return DS_CreateModel(aModelPath, aAlphabetConfigPath, aBeamWidth, retval); + return DS_CreateModel(aModelPath, aBeamWidth, retval); } /** diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index 6c8d11c0..c7fc46a0 100644 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -25,12 +25,8 @@ ModelState::~ModelState() int ModelState::init(const char* model_path, - const char* alphabet_path, unsigned int beam_width) { - if (alphabet_.init(alphabet_path)) { - return DS_ERR_INVALID_ALPHABET; - } beam_width_ = beam_width; return DS_ERR_OK; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h index 3532ab92..ff106a62 100644 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -30,9 +30,7 @@ struct ModelState { ModelState(); virtual ~ModelState(); - virtual int init(const char* model_path, - const char* alphabet_path, - unsigned int beam_width); + virtual int init(const char* model_path, unsigned int beam_width); virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) = 0; diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc index 40457213..b8d491ee 100644 --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -1,5 +1,6 @@ #include "tflitemodelstate.h" +#include "tensorflow/lite/string_util.h" #include "workspace_status.h" using namespace tflite; @@ -91,10 +92,9 @@ TFLiteModelState::~TFLiteModelState() int TFLiteModelState::init(const char* model_path, - const char* alphabet_path, unsigned int beam_width) { - int err = ModelState::init(model_path, alphabet_path, beam_width); + int err = ModelState::init(model_path, beam_width); if (err != DS_ERR_OK) { return err; } @@ -126,17 +126,17 @@ TFLiteModelState::init(const char* model_path, mfccs_idx_ = get_output_tensor_by_name("mfccs"); int metadata_version_idx = get_output_tensor_by_name("metadata_version"); - // int metadata_language_idx = get_output_tensor_by_name("metadata_language"); int metadata_sample_rate_idx = get_output_tensor_by_name("metadata_sample_rate"); int metadata_feature_win_len_idx = get_output_tensor_by_name("metadata_feature_win_len"); int metadata_feature_win_step_idx = get_output_tensor_by_name("metadata_feature_win_step"); + int metadata_alphabet_idx = get_output_tensor_by_name("metadata_alphabet"); std::vector metadata_exec_plan; metadata_exec_plan.push_back(find_parent_node_ids(metadata_version_idx)[0]); - // metadata_exec_plan.push_back(find_parent_node_ids(metadata_language_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_sample_rate_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_len_idx)[0]); metadata_exec_plan.push_back(find_parent_node_ids(metadata_feature_win_step_idx)[0]); + metadata_exec_plan.push_back(find_parent_node_ids(metadata_alphabet_idx)[0]); for (int i = 0; i < metadata_exec_plan.size(); ++i) { assert(metadata_exec_plan[i] > -1); @@ -200,6 +200,12 @@ TFLiteModelState::init(const char* model_path, audio_win_len_ = sample_rate_ * (*win_len_ms / 1000.0); audio_win_step_ = sample_rate_ * (*win_step_ms / 1000.0); + tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0); + err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len); + if (err != 0) { + return DS_ERR_INVALID_ALPHABET; + } + assert(sample_rate_ > 0); assert(audio_win_len_ > 0); assert(audio_win_step_ > 0); diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h index cd876895..77137751 100644 --- a/native_client/tflitemodelstate.h +++ b/native_client/tflitemodelstate.h @@ -31,7 +31,6 @@ struct TFLiteModelState : public ModelState virtual ~TFLiteModelState(); virtual int init(const char* model_path, - const char* alphabet_path, unsigned int beam_width) override; virtual void compute_mfcc(const std::vector& audio_buffer, diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc index c237e2eb..e9b45755 100644 --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -25,10 +25,9 @@ TFModelState::~TFModelState() int TFModelState::init(const char* model_path, - const char* alphabet_path, unsigned int beam_width) { - int err = ModelState::init(model_path, alphabet_path, beam_width); + int err = ModelState::init(model_path, beam_width); if (err != DS_ERR_OK) { return err; } @@ -78,20 +77,16 @@ TFModelState::init(const char* model_path, return DS_ERR_FAIL_CREATE_SESS; } - std::vector metadata_outputs; + std::vector version_output; status = session_->Run({}, { - "metadata_version", - // "metadata_language", - "metadata_sample_rate", - "metadata_feature_win_len", - "metadata_feature_win_step" - }, {}, &metadata_outputs); + "metadata_version" + }, {}, &version_output); if (!status.ok()) { - std::cout << "Unable to fetch metadata: " << status << std::endl; + std::cerr << "Unable to fetch graph version: " << status << std::endl; return DS_ERR_MODEL_INCOMPATIBLE; } - int graph_version = metadata_outputs[0].scalar()(); + int graph_version = version_output[0].scalar()(); if (graph_version < ds_graph_version()) { std::cerr << "Specified model file version (" << graph_version << ") is " << "incompatible with minimum version supported by this client (" @@ -101,12 +96,30 @@ TFModelState::init(const char* model_path, return DS_ERR_MODEL_INCOMPATIBLE; } - sample_rate_ = metadata_outputs[1].scalar()(); - int win_len_ms = metadata_outputs[2].scalar()(); - int win_step_ms = metadata_outputs[3].scalar()(); + std::vector metadata_outputs; + status = session_->Run({}, { + "metadata_sample_rate", + "metadata_feature_win_len", + "metadata_feature_win_step", + "metadata_alphabet", + }, {}, &metadata_outputs); + if (!status.ok()) { + std::cout << "Unable to fetch metadata: " << status << std::endl; + return DS_ERR_MODEL_INCOMPATIBLE; + } + + sample_rate_ = metadata_outputs[0].scalar()(); + int win_len_ms = metadata_outputs[1].scalar()(); + int win_step_ms = metadata_outputs[2].scalar()(); audio_win_len_ = sample_rate_ * (win_len_ms / 1000.0); audio_win_step_ = sample_rate_ * (win_step_ms / 1000.0); + string serialized_alphabet = metadata_outputs[3].scalar()(); + err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size()); + if (err != 0) { + return DS_ERR_INVALID_ALPHABET; + } + assert(sample_rate_ > 0); assert(audio_win_len_ > 0); assert(audio_win_step_ > 0); diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h index 39acbc59..c3d9ed0f 100644 --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -19,7 +19,6 @@ struct TFModelState : public ModelState virtual ~TFModelState(); virtual int init(const char* model_path, - const char* alphabet_path, unsigned int beam_width) override; virtual void infer(const std::vector& mfcc, diff --git a/util/text.py b/util/text.py index cf07977f..0a846897 100644 --- a/util/text.py +++ b/util/text.py @@ -1,27 +1,29 @@ from __future__ import absolute_import, division, print_function import codecs -import re - import numpy as np +import re +import struct +from util.flags import FLAGS from six.moves import range class Alphabet(object): def __init__(self, config_file): self._config_file = config_file - self._label_to_str = [] + self._label_to_str = {} self._str_to_label = {} self._size = 0 - with codecs.open(config_file, 'r', 'utf-8') as fin: - for line in fin: - if line[0:2] == '\\#': - line = '#\n' - elif line[0] == '#': - continue - self._label_to_str += line[:-1] # remove the line ending - self._str_to_label[line[:-1]] = self._size - self._size += 1 + if config_file: + with codecs.open(config_file, 'r', 'utf-8') as fin: + for line in fin: + if line[0:2] == '\\#': + line = '#\n' + elif line[0] == '#': + continue + self._label_to_str[self._size] = line[:-1] # remove the line ending + self._str_to_label[line[:-1]] = self._size + self._size += 1 def _string_from_label(self, label): return self._label_to_str[label] @@ -51,6 +53,35 @@ class Alphabet(object): res += self._string_from_label(label) return res + def serialize(self): + res = bytearray() + res += struct.pack('