Make Alphabet init fallible and check it in model creation

This commit is contained in:
Reuben Morais 2019-08-22 22:08:31 +02:00
parent 4dabd248bc
commit b2ef9cca83
5 changed files with 48 additions and 37 deletions

View File

@ -15,15 +15,15 @@
*/
class Alphabet {
public:
Alphabet() {
}
Alphabet() = default;
Alphabet(const Alphabet&) = default;
Alphabet& operator=(const Alphabet&) = default;
Alphabet(const char *config_file) {
init(config_file);
}
void init(const char *config_file) {
int init(const char *config_file) {
std::ifstream in(config_file, std::ios::in);
if (!in) {
return 1;
}
unsigned int label = 0;
space_label_ = -2;
for (std::string line; std::getline(in, line);) {
@ -42,6 +42,7 @@ public:
}
size_ = label;
in.close();
return 0;
}
const std::string& StringFromLabel(unsigned int label) const {

View File

@ -19,42 +19,44 @@ import_array();
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
// Convert char* to Alphabet
%rename (ctc_beam_search_decoder) mod_decoder;
// Add overloads converting char* to Alphabet
%inline %{
std::vector<Output>
mod_decoder(const double *probs,
int time_dim,
int class_dim,
char* alphabet_config_path,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer)
ctc_beam_search_decoder(const double *probs,
int time_dim,
int class_dim,
char* alphabet_config_path,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer)
{
Alphabet a(alphabet_config_path);
Alphabet a;
if (a.init(alphabet_config_path)) {
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
}
return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
cutoff_prob, cutoff_top_n, ext_scorer);
}
%}
%rename (ctc_beam_search_decoder_batch) mod_decoder_batch;
%inline %{
std::vector<std::vector<Output>>
mod_decoder_batch(const double *probs,
int batch_dim,
int time_dim,
int class_dim,
const int *seq_lengths,
int seq_lengths_size,
char* alphabet_config_path,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer)
ctc_beam_search_decoder_batch(const double *probs,
int batch_dim,
int time_dim,
int class_dim,
const int *seq_lengths,
int seq_lengths_size,
char* alphabet_config_path,
size_t beam_size,
size_t num_processes,
double cutoff_prob,
size_t cutoff_top_n,
Scorer *ext_scorer)
{
Alphabet a(alphabet_config_path);
Alphabet a;
if (a.init(alphabet_config_path)) {
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
}
return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
seq_lengths, seq_lengths_size, a, beam_size,
num_processes, cutoff_prob, cutoff_top_n,

View File

@ -8,7 +8,10 @@
using namespace std;
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
Alphabet alphabet(alphabet_path);
Alphabet alphabet;
if (int err = alphabet.init(alphabet_path)) {
return err;
}
Scorer scorer(0.0, 0.0, kenlm_path, "", alphabet);
scorer.save_dictionary(trie_path);
return 0;

View File

@ -32,7 +32,9 @@ ModelState::init(const char* model_path,
{
n_features_ = n_features;
n_context_ = n_context;
alphabet_.init(alphabet_path);
if (alphabet_.init(alphabet_path)) {
return DS_ERR_INVALID_ALPHABET;
}
beam_width_ = beam_width;
return DS_ERR_OK;
}

View File

@ -16,7 +16,10 @@ int main(int argc, char** argv)
printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path);
Alphabet alphabet(alphabet_path);
Alphabet alphabet;
if (int err = alphabet.init(alphabet_path)) {
return err;
}
Scorer scorer(0.0, 0.0, kenlm_path, trie_path, alphabet);
return 0;