Make Alphabet init fallible and check it in model creation

This commit is contained in:
Reuben Morais 2019-08-22 22:08:31 +02:00
parent 4dabd248bc
commit b2ef9cca83
5 changed files with 48 additions and 37 deletions

View File

@ -15,15 +15,15 @@
*/ */
class Alphabet { class Alphabet {
public: public:
Alphabet() { Alphabet() = default;
} Alphabet(const Alphabet&) = default;
Alphabet& operator=(const Alphabet&) = default;
Alphabet(const char *config_file) { int init(const char *config_file) {
init(config_file);
}
void init(const char *config_file) {
std::ifstream in(config_file, std::ios::in); std::ifstream in(config_file, std::ios::in);
if (!in) {
return 1;
}
unsigned int label = 0; unsigned int label = 0;
space_label_ = -2; space_label_ = -2;
for (std::string line; std::getline(in, line);) { for (std::string line; std::getline(in, line);) {
@ -42,6 +42,7 @@ public:
} }
size_ = label; size_ = label;
in.close(); in.close();
return 0;
} }
const std::string& StringFromLabel(unsigned int label) const { const std::string& StringFromLabel(unsigned int label) const {

View File

@ -19,11 +19,10 @@ import_array();
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)}; %apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)}; %apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
// Convert char* to Alphabet // Add overloads converting char* to Alphabet
%rename (ctc_beam_search_decoder) mod_decoder;
%inline %{ %inline %{
std::vector<Output> std::vector<Output>
mod_decoder(const double *probs, ctc_beam_search_decoder(const double *probs,
int time_dim, int time_dim,
int class_dim, int class_dim,
char* alphabet_config_path, char* alphabet_config_path,
@ -32,16 +31,16 @@ mod_decoder(const double *probs,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) Scorer *ext_scorer)
{ {
Alphabet a(alphabet_config_path); Alphabet a;
if (a.init(alphabet_config_path)) {
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
}
return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size, return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
cutoff_prob, cutoff_top_n, ext_scorer); cutoff_prob, cutoff_top_n, ext_scorer);
} }
%}
%rename (ctc_beam_search_decoder_batch) mod_decoder_batch;
%inline %{
std::vector<std::vector<Output>> std::vector<std::vector<Output>>
mod_decoder_batch(const double *probs, ctc_beam_search_decoder_batch(const double *probs,
int batch_dim, int batch_dim,
int time_dim, int time_dim,
int class_dim, int class_dim,
@ -54,7 +53,10 @@ mod_decoder_batch(const double *probs,
size_t cutoff_top_n, size_t cutoff_top_n,
Scorer *ext_scorer) Scorer *ext_scorer)
{ {
Alphabet a(alphabet_config_path); Alphabet a;
if (a.init(alphabet_config_path)) {
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
}
return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim, return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
seq_lengths, seq_lengths_size, a, beam_size, seq_lengths, seq_lengths_size, a, beam_size,
num_processes, cutoff_prob, cutoff_top_n, num_processes, cutoff_prob, cutoff_top_n,

View File

@ -8,7 +8,10 @@
using namespace std; using namespace std;
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) { int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
Alphabet alphabet(alphabet_path); Alphabet alphabet;
if (int err = alphabet.init(alphabet_path)) {
return err;
}
Scorer scorer(0.0, 0.0, kenlm_path, "", alphabet); Scorer scorer(0.0, 0.0, kenlm_path, "", alphabet);
scorer.save_dictionary(trie_path); scorer.save_dictionary(trie_path);
return 0; return 0;

View File

@ -32,7 +32,9 @@ ModelState::init(const char* model_path,
{ {
n_features_ = n_features; n_features_ = n_features;
n_context_ = n_context; n_context_ = n_context;
alphabet_.init(alphabet_path); if (alphabet_.init(alphabet_path)) {
return DS_ERR_INVALID_ALPHABET;
}
beam_width_ = beam_width; beam_width_ = beam_width;
return DS_ERR_OK; return DS_ERR_OK;
} }

View File

@ -16,7 +16,10 @@ int main(int argc, char** argv)
printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path); printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path);
Alphabet alphabet(alphabet_path); Alphabet alphabet;
if (int err = alphabet.init(alphabet_path)) {
return err;
}
Scorer scorer(0.0, 0.0, kenlm_path, trie_path, alphabet); Scorer scorer(0.0, 0.0, kenlm_path, trie_path, alphabet);
return 0; return 0;