From b2ef9cca83b40b18def5d25a091c3828b9774028 Mon Sep 17 00:00:00 2001
From: Reuben Morais
Date: Thu, 22 Aug 2019 22:08:31 +0200
Subject: [PATCH] Make Alphabet init fallible and check it in model creation

Alphabet::init() now returns 0 on success and 1 when the config file
cannot be opened, where it previously returned nothing and never
checked the stream. The constructor taking a path is removed, so every
caller has to call init() and look at the result:

  * ModelState::init() returns DS_ERR_INVALID_ALPHABET.
  * generate_trie() and trie_load's main() propagate the error code.
  * The SWIG char* overloads of ctc_beam_search_decoder and
    ctc_beam_search_decoder_batch log the failure and return an empty
    result rather than decoding with an alphabet that failed to load.

---
 native_client/alphabet.h              | 15 +++----
 native_client/ctcdecode/swigwrapper.i | 60 +++++++++++++++------------
 native_client/generate_trie.cpp       |  5 ++-
 native_client/modelstate.cc           |  4 +-
 native_client/trie_load.cc            |  5 ++-
 5 files changed, 52 insertions(+), 37 deletions(-)

diff --git a/native_client/alphabet.h b/native_client/alphabet.h
index 9f793c40..8bd5b98c 100644
--- a/native_client/alphabet.h
+++ b/native_client/alphabet.h
@@ -15,15 +15,15 @@
  */
 class Alphabet {
 public:
-  Alphabet() {
-  }
+  Alphabet() = default;
+  Alphabet(const Alphabet&) = default;
+  Alphabet& operator=(const Alphabet&) = default;
 
-  Alphabet(const char *config_file) {
-    init(config_file);
-  }
-
-  void init(const char *config_file) {
+  int init(const char *config_file) {
     std::ifstream in(config_file, std::ios::in);
+    if (!in) {
+      return 1;
+    }
     unsigned int label = 0;
     space_label_ = -2;
     for (std::string line; std::getline(in, line);) {
@@ -42,6 +42,7 @@ public:
     }
     size_ = label;
     in.close();
+    return 0;
   }
 
   const std::string& StringFromLabel(unsigned int label) const {
diff --git a/native_client/ctcdecode/swigwrapper.i b/native_client/ctcdecode/swigwrapper.i
index 9c16e68b..582357a2 100644
--- a/native_client/ctcdecode/swigwrapper.i
+++ b/native_client/ctcdecode/swigwrapper.i
@@ -19,42 +19,48 @@ import_array();
 %apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
 %apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
 
-// Convert char* to Alphabet
-%rename (ctc_beam_search_decoder) mod_decoder;
+// Add overloads converting char* to Alphabet
 %inline %{
 std::vector<Output>
-mod_decoder(const double *probs,
-            int time_dim,
-            int class_dim,
-            char* alphabet_config_path,
-            size_t beam_size,
-            double cutoff_prob,
-            size_t cutoff_top_n,
-            Scorer *ext_scorer)
+ctc_beam_search_decoder(const double *probs,
+                        int time_dim,
+                        int class_dim,
+                        char* alphabet_config_path,
+                        size_t beam_size,
+                        double cutoff_prob,
+                        size_t cutoff_top_n,
+                        Scorer *ext_scorer)
 {
-    Alphabet a(alphabet_config_path);
+    Alphabet a;
+    if (a.init(alphabet_config_path)) {
+        std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
+        // Bail out: an alphabet that failed to load must not reach the decoder.
+        return {};
+    }
     return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
                                    cutoff_prob, cutoff_top_n, ext_scorer);
 }
-%}
 
-%rename (ctc_beam_search_decoder_batch) mod_decoder_batch;
-%inline %{
 std::vector<std::vector<Output>>
-mod_decoder_batch(const double *probs,
-                  int batch_dim,
-                  int time_dim,
-                  int class_dim,
-                  const int *seq_lengths,
-                  int seq_lengths_size,
-                  char* alphabet_config_path,
-                  size_t beam_size,
-                  size_t num_processes,
-                  double cutoff_prob,
-                  size_t cutoff_top_n,
-                  Scorer *ext_scorer)
+ctc_beam_search_decoder_batch(const double *probs,
+                              int batch_dim,
+                              int time_dim,
+                              int class_dim,
+                              const int *seq_lengths,
+                              int seq_lengths_size,
+                              char* alphabet_config_path,
+                              size_t beam_size,
+                              size_t num_processes,
+                              double cutoff_prob,
+                              size_t cutoff_top_n,
+                              Scorer *ext_scorer)
 {
-    Alphabet a(alphabet_config_path);
+    Alphabet a;
+    if (a.init(alphabet_config_path)) {
+        std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
+        // Bail out: an alphabet that failed to load must not reach the decoder.
+        return {};
+    }
     return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
                                          seq_lengths, seq_lengths_size, a, beam_size,
                                          num_processes, cutoff_prob, cutoff_top_n,
diff --git a/native_client/generate_trie.cpp b/native_client/generate_trie.cpp
index 49944c86..60008804 100644
--- a/native_client/generate_trie.cpp
+++ b/native_client/generate_trie.cpp
@@ -8,7 +8,10 @@
 using namespace std;
 
 int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
-  Alphabet alphabet(alphabet_path);
+  Alphabet alphabet;
+  if (int err = alphabet.init(alphabet_path)) {
+    return err;
+  }
   Scorer scorer(0.0, 0.0, kenlm_path, "", alphabet);
   scorer.save_dictionary(trie_path);
   return 0;
diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc
index 8f24b776..bd80367a 100644
--- a/native_client/modelstate.cc
+++ b/native_client/modelstate.cc
@@ -32,7 +32,9 @@ ModelState::init(const char* model_path,
 {
   n_features_ = n_features;
   n_context_ = n_context;
-  alphabet_.init(alphabet_path);
+  if (alphabet_.init(alphabet_path)) {
+    return DS_ERR_INVALID_ALPHABET;
+  }
   beam_width_ = beam_width;
   return DS_ERR_OK;
 }
diff --git a/native_client/trie_load.cc b/native_client/trie_load.cc
index a719a431..82148300 100644
--- a/native_client/trie_load.cc
+++ b/native_client/trie_load.cc
@@ -16,7 +16,10 @@ int main(int argc, char** argv)
 
   printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path);
 
-  Alphabet alphabet(alphabet_path);
+  Alphabet alphabet;
+  if (int err = alphabet.init(alphabet_path)) {
+    return err;
+  }
   Scorer scorer(0.0, 0.0, kenlm_path, trie_path, alphabet);
 
   return 0;