From 4dabd248bcec90ee4b7e3d4770563c1eaa32e2dd Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Thu, 22 Aug 2019 21:17:23 +0200
Subject: [PATCH] Make Alphabet copyable and default-constructable and avoid pointers

---
 native_client/alphabet.h           |  7 +++++++
 native_client/ctcdecode/scorer.cpp | 31 ++++++++++++++++++++----------
 native_client/ctcdecode/scorer.h   |  5 +++++
 native_client/deepspeech.cc        | 14 +++++++-------
 native_client/modelstate.cc        | 12 ++++--------
 native_client/modelstate.h         |  4 ++--
 native_client/tflitemodelstate.cc  |  6 +++---
 native_client/tfmodelstate.cc      |  6 +++---
 8 files changed, 52 insertions(+), 33 deletions(-)

diff --git a/native_client/alphabet.h b/native_client/alphabet.h
index 70239acb..9f793c40 100644
--- a/native_client/alphabet.h
+++ b/native_client/alphabet.h
@@ -15,7 +15,14 @@
  */
 class Alphabet {
 public:
+  Alphabet() {
+  }
+
   Alphabet(const char *config_file) {
+    init(config_file);
+  }
+
+  void init(const char *config_file) {
     std::ifstream in(config_file, std::ios::in);
     unsigned int label = 0;
     space_label_ = -2;
diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp
index f769a030..a2979b6f 100644
--- a/native_client/ctcdecode/scorer.cpp
+++ b/native_client/ctcdecode/scorer.cpp
@@ -30,18 +30,16 @@ static const int32_t MAGIC = 'TRIE';
 static const int32_t FILE_VERSION = 4;
 
 Scorer::Scorer(double alpha,
-               double beta,
-               const std::string& lm_path,
-               const std::string& trie_path,
-               const Alphabet& alphabet)
-  : dictionary()
-  , language_model_()
-  , is_character_based_(true)
+               double beta)
+  : is_character_based_(true)
   , max_order_(0)
-  , alphabet_(alphabet)
 {
   reset_params(alpha, beta);
+}
 
+void
+Scorer::init(const std::string& lm_path, const std::string& trie_path)
+{
   char_map_.clear();
 
   SPACE_ID_ = alphabet_.GetSpaceLabel();
@@ -60,9 +58,22 @@ Scorer::Scorer(double alpha,
                double beta,
                const std::string& lm_path,
                const std::string& trie_path,
-               const std::string& alphabet_config_path)
-  : Scorer(alpha, beta, lm_path, trie_path, Alphabet(alphabet_config_path.c_str()))
+               const Alphabet& alphabet)
+  : Scorer(alpha, beta)
 {
+  alphabet_ = alphabet;
+  init(lm_path, trie_path);
+}
+
+Scorer::Scorer(double alpha,
+               double beta,
+               const std::string& lm_path,
+               const std::string& trie_path,
+               const std::string& alphabet_config_path)
+  : Scorer(alpha, beta)
+{
+  alphabet_.init(alphabet_config_path.c_str());
+  init(lm_path, trie_path);
 }
 
 Scorer::~Scorer()
diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h
index 2e881077..2753d74f 100644
--- a/native_client/ctcdecode/scorer.h
+++ b/native_client/ctcdecode/scorer.h
@@ -87,6 +87,11 @@ public:
   std::unique_ptr dictionary;
 
 protected:
+  Scorer(double alpha,
+         double beta);
+
+  void init(const std::string& lm_path, const std::string& trie_path);
+
   // necessary setup: load language model, fill FST's dictionary
   void setup(const std::string &lm_path, const std::string &trie_path);
diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc
index 8ddd2c72..dbac4d25 100644
--- a/native_client/deepspeech.cc
+++ b/native_client/deepspeech.cc
@@ -244,7 +244,7 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps)
                 previous_state_c_,
                 previous_state_h_);
 
-  const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank
+  const size_t num_classes = model_->alphabet_.GetSize() + 1; // +1 for blank
   const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);
 
   // Convert logits to double
@@ -309,10 +309,10 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
                        float aLMBeta)
 {
   try {
-    aCtx->scorer_ = new Scorer(aLMAlpha, aLMBeta,
-                               aLMPath ? aLMPath : "",
-                               aTriePath ? aTriePath : "",
-                               *aCtx->alphabet_);
+    aCtx->scorer_.reset(new Scorer(aLMAlpha, aLMBeta,
+                                   aLMPath ? aLMPath : "",
+                                   aTriePath ? aTriePath : "",
+                                   aCtx->alphabet_));
     return DS_ERR_OK;
   } catch (...) {
     return DS_ERR_INVALID_LM;
@@ -343,11 +343,11 @@ DS_SetupStream(ModelState* aCtx,
   const int cutoff_top_n = 40;
   const double cutoff_prob = 1.0;
 
-  ctx->decoder_state_.init(*aCtx->alphabet_,
+  ctx->decoder_state_.init(aCtx->alphabet_,
                            aCtx->beam_width_,
                            cutoff_prob,
                            cutoff_top_n,
-                           aCtx->scorer_);
+                           aCtx->scorer_.get());
 
   *retval = ctx.release();
   return DS_ERR_OK;
diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc
index 48c34eee..8f24b776 100644
--- a/native_client/modelstate.cc
+++ b/native_client/modelstate.cc
@@ -7,9 +7,7 @@
 using std::vector;
 
 ModelState::ModelState()
-  : alphabet_(nullptr)
-  , scorer_(nullptr)
-  , beam_width_(-1)
+  : beam_width_(-1)
   , n_steps_(-1)
   , n_context_(-1)
   , n_features_(-1)
@@ -23,8 +21,6 @@ ModelState::ModelState()
 
 ModelState::~ModelState()
 {
-  delete scorer_;
-  delete alphabet_;
 }
 
 int
@@ -36,7 +32,7 @@ ModelState::init(const char* model_path,
 {
   n_features_ = n_features;
   n_context_ = n_context;
-  alphabet_ = new Alphabet(alphabet_path);
+  alphabet_.init(alphabet_path);
   beam_width_ = beam_width;
   return DS_ERR_OK;
 }
@@ -45,7 +41,7 @@ char*
 ModelState::decode(const DecoderState& state)
 {
   vector out = state.decode();
-  return strdup(alphabet_->LabelsToString(out[0].tokens).c_str());
+  return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
 }
 
 Metadata*
@@ -61,7 +57,7 @@ ModelState::decode_metadata(const DecoderState& state)
 
   // Loop through each character
   for (int i = 0; i < out[0].tokens.size(); ++i) {
-    items[i].character = strdup(alphabet_->StringFromLabel(out[0].tokens[i]).c_str());
+    items[i].character = strdup(alphabet_.StringFromLabel(out[0].tokens[i]).c_str());
     items[i].timestep = out[0].timesteps[i];
     items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
diff --git a/native_client/modelstate.h b/native_client/modelstate.h
index cb7c7d34..4b60f838 100644
--- a/native_client/modelstate.h
+++ b/native_client/modelstate.h
@@ -19,8 +19,8 @@ struct ModelState {
   static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
   static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
 
-  Alphabet* alphabet_;
-  Scorer* scorer_;
+  Alphabet alphabet_;
+  std::unique_ptr scorer_;
   unsigned int beam_width_;
   unsigned int n_steps_;
   unsigned int n_context_;
diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc
index 8c61a83f..026b333d 100644
--- a/native_client/tflitemodelstate.cc
+++ b/native_client/tflitemodelstate.cc
@@ -151,9 +151,9 @@ TFLiteModelState::init(const char* model_path,
 
   TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
   const int final_dim_size = dims_logits->data[1] - 1;
-  if (final_dim_size != alphabet_->GetSize()) {
+  if (final_dim_size != alphabet_.GetSize()) {
     std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
-              << "has size " << alphabet_->GetSize()
+              << "has size " << alphabet_.GetSize()
               << ", but model has " << final_dim_size
               << " classes in its output. Make sure you're passing an alphabet "
               << "file with the same size as the one used for training."
@@ -208,7 +208,7 @@ TFLiteModelState::infer(const vector& mfcc,
                         vector& state_c_output,
                         vector& state_h_output)
 {
-  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
+  const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
 
   // Feeding input_node
   copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);
diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc
index fe1cc887..2f26b142 100644
--- a/native_client/tfmodelstate.cc
+++ b/native_client/tfmodelstate.cc
@@ -108,9 +108,9 @@ TFModelState::init(const char* model_path,
   }
 
   int final_dim_size = logits_shape.vec()(2) - 1;
-  if (final_dim_size != alphabet_->GetSize()) {
+  if (final_dim_size != alphabet_.GetSize()) {
     std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
-              << "has size " << alphabet_->GetSize()
+              << "has size " << alphabet_.GetSize()
               << ", but model has " << final_dim_size
               << " classes in its output. Make sure you're passing an alphabet "
               << "file with the same size as the one used for training."
@@ -173,7 +173,7 @@ TFModelState::infer(const std::vector& mfcc,
                     vector& state_c_output,
                     vector& state_h_output)
 {
-  const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
+  const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
 
   Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
   Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
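
Note on the alphabet.h, modelstate.h, and modelstate.cc hunks above: they share one pattern, giving Alphabet a default constructor plus an init() method so that owners can hold it by value instead of through a raw pointer. The sketch below is a minimal standalone illustration of that pattern; the names DemoAlphabet, DemoModelState, and "alphabet.txt" are invented for the example and are not part of the DeepSpeech sources.

// demo_alphabet.cc -- illustrative sketch only; mirrors the Alphabet/ModelState pattern above.
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>

class DemoAlphabet {
public:
  DemoAlphabet() = default;              // default-constructable: owners can hold it by value

  explicit DemoAlphabet(const char* config_file) {
    init(config_file);                   // eager construction still available
  }

  // Two-phase initialization: fill the label map after construction.
  void init(const char* config_file) {
    std::ifstream in(config_file, std::ios::in);
    std::string line;
    unsigned int label = 0;
    while (std::getline(in, line)) {
      label_to_str_[label++] = line;
    }
    size_ = label;
  }

  unsigned int GetSize() const { return size_; }

private:
  unsigned int size_ = 0;
  std::unordered_map<unsigned int, std::string> label_to_str_;
};

// An owner that keeps the alphabet by value, as ModelState does after the patch.
struct DemoModelState {
  DemoAlphabet alphabet_;                // value member: copyable, nothing to delete

  int init(const char* alphabet_path) {
    alphabet_.init(alphabet_path);       // same call shape as alphabet_.init(alphabet_path) above
    return 0;
  }
};

int main() {
  DemoModelState state;
  state.init("alphabet.txt");            // path is illustrative
  std::cout << "labels: " << state.alphabet_.GetSize() << "\n";
  return 0;
}

Holding the alphabet by value is what lets the patch drop the `delete alphabet_;` call and the `alphabet_(nullptr)` initializer: the member is constructed and destroyed with its owner.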
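
The deepspeech.cc, modelstate.h, and modelstate.cc hunks also replace the raw Scorer* member, and the manual `delete scorer_;` in ~ModelState, with a std::unique_ptr, calling reset() when the scorer is created and get() where a plain non-owning pointer is still expected. A standalone sketch of that ownership change follows; DemoScorer, DemoContext, and the numeric values are invented for the example.

// demo_scorer_ownership.cc -- illustrative sketch only; mirrors the scorer_ ownership change above.
#include <iostream>
#include <memory>

struct DemoScorer {
  double alpha;
  double beta;
  DemoScorer(double a, double b) : alpha(a), beta(b) {}
};

// Non-owning consumer, analogous to a decoder that only borrows the scorer.
void use_scorer(const DemoScorer* scorer) {
  if (scorer) {
    std::cout << "alpha=" << scorer->alpha << " beta=" << scorer->beta << "\n";
  }
}

struct DemoContext {
  // Before: DemoScorer* scorer_ = nullptr;   // required a hand-written delete in the destructor
  std::unique_ptr<DemoScorer> scorer_;        // after: ownership is explicit, cleanup is implicit
};

int main() {
  DemoContext ctx;
  ctx.scorer_.reset(new DemoScorer(0.75, 1.85));  // same shape as aCtx->scorer_.reset(new Scorer(...))
  use_scorer(ctx.scorer_.get());                  // .get() hands out a non-owning pointer
  return 0;
}  // unique_ptr deletes the scorer here; no explicit delete needed

The reset()/get() pair keeps the public decoder interface working with plain pointers while the owning side no longer needs a destructor body at all.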
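
Finally, the scorer.h/scorer.cpp hunks split Scorer construction into a small protected Scorer(alpha, beta) constructor plus an init(lm_path, trie_path) step, with the public constructors delegating to the small constructor before running init(). The sketch below shows that delegating-constructor shape in isolation; DemoScorer and its members are invented for the example and simplify away the real language-model loading.

// demo_delegating_ctor.cc -- illustrative sketch only; mirrors the Scorer constructor split above.
#include <iostream>
#include <string>

class DemoScorer {
public:
  // Public constructor: delegate to the small base constructor, then finish with init().
  DemoScorer(double alpha, double beta,
             const std::string& lm_path, const std::string& trie_path)
    : DemoScorer(alpha, beta)
  {
    init(lm_path, trie_path);
  }

  double alpha() const { return alpha_; }

protected:
  // Base constructor only sets scalar parameters, like Scorer(double, double) above.
  DemoScorer(double alpha, double beta)
    : alpha_(alpha), beta_(beta)
  {
  }

  // Two-phase setup that each public constructor calls once members are in place.
  void init(const std::string& lm_path, const std::string& trie_path) {
    lm_path_ = lm_path;
    trie_path_ = trie_path;
  }

private:
  double alpha_ = 0.0;
  double beta_ = 0.0;
  std::string lm_path_;
  std::string trie_path_;
};

int main() {
  DemoScorer scorer(0.75, 1.85, "lm.binary", "trie");  // paths and values are illustrative
  std::cout << "alpha=" << scorer.alpha() << "\n";
  return 0;
}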