Make Alphabet init fallible and check it in model creation
This commit is contained in:
		
							parent
							
								
									4dabd248bc
								
							
						
					
					
						commit
						b2ef9cca83
					
				@ -15,15 +15,15 @@
 | 
				
			|||||||
 */
 | 
					 */
 | 
				
			||||||
class Alphabet {
 | 
					class Alphabet {
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
  Alphabet() {
 | 
					  Alphabet() = default;
 | 
				
			||||||
  }
 | 
					  Alphabet(const Alphabet&) = default;
 | 
				
			||||||
 | 
					  Alphabet& operator=(const Alphabet&) = default;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Alphabet(const char *config_file) {
 | 
					  int init(const char *config_file) {
 | 
				
			||||||
    init(config_file);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  void init(const char *config_file) {
 | 
					 | 
				
			||||||
    std::ifstream in(config_file, std::ios::in);
 | 
					    std::ifstream in(config_file, std::ios::in);
 | 
				
			||||||
 | 
					    if (!in) {
 | 
				
			||||||
 | 
					      return 1;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    unsigned int label = 0;
 | 
					    unsigned int label = 0;
 | 
				
			||||||
    space_label_ = -2;
 | 
					    space_label_ = -2;
 | 
				
			||||||
    for (std::string line; std::getline(in, line);) {
 | 
					    for (std::string line; std::getline(in, line);) {
 | 
				
			||||||
@ -42,6 +42,7 @@ public:
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    size_ = label;
 | 
					    size_ = label;
 | 
				
			||||||
    in.close();
 | 
					    in.close();
 | 
				
			||||||
 | 
					    return 0;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const std::string& StringFromLabel(unsigned int label) const {
 | 
					  const std::string& StringFromLabel(unsigned int label) const {
 | 
				
			||||||
 | 
				
			|||||||
@ -19,11 +19,10 @@ import_array();
 | 
				
			|||||||
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
 | 
					%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
 | 
				
			||||||
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
 | 
					%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Convert char* to Alphabet
 | 
					// Add overloads converting char* to Alphabet
 | 
				
			||||||
%rename (ctc_beam_search_decoder) mod_decoder;
 | 
					 | 
				
			||||||
%inline %{
 | 
					%inline %{
 | 
				
			||||||
std::vector<Output>
 | 
					std::vector<Output>
 | 
				
			||||||
mod_decoder(const double *probs,
 | 
					ctc_beam_search_decoder(const double *probs,
 | 
				
			||||||
                        int time_dim,
 | 
					                        int time_dim,
 | 
				
			||||||
                        int class_dim,
 | 
					                        int class_dim,
 | 
				
			||||||
                        char* alphabet_config_path,
 | 
					                        char* alphabet_config_path,
 | 
				
			||||||
@ -32,16 +31,16 @@ mod_decoder(const double *probs,
 | 
				
			|||||||
                        size_t cutoff_top_n,
 | 
					                        size_t cutoff_top_n,
 | 
				
			||||||
                        Scorer *ext_scorer)
 | 
					                        Scorer *ext_scorer)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
    Alphabet a(alphabet_config_path);
 | 
					    Alphabet a;
 | 
				
			||||||
 | 
					    if (a.init(alphabet_config_path)) {
 | 
				
			||||||
 | 
					      std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
 | 
					    return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
 | 
				
			||||||
                                   cutoff_prob, cutoff_top_n, ext_scorer);
 | 
					                                   cutoff_prob, cutoff_top_n, ext_scorer);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
%}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
%rename (ctc_beam_search_decoder_batch) mod_decoder_batch;
 | 
					 | 
				
			||||||
%inline %{
 | 
					 | 
				
			||||||
std::vector<std::vector<Output>>
 | 
					std::vector<std::vector<Output>>
 | 
				
			||||||
mod_decoder_batch(const double *probs,
 | 
					ctc_beam_search_decoder_batch(const double *probs,
 | 
				
			||||||
                              int batch_dim,
 | 
					                              int batch_dim,
 | 
				
			||||||
                              int time_dim,
 | 
					                              int time_dim,
 | 
				
			||||||
                              int class_dim,
 | 
					                              int class_dim,
 | 
				
			||||||
@ -54,7 +53,10 @@ mod_decoder_batch(const double *probs,
 | 
				
			|||||||
                              size_t cutoff_top_n,
 | 
					                              size_t cutoff_top_n,
 | 
				
			||||||
                              Scorer *ext_scorer)
 | 
					                              Scorer *ext_scorer)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
    Alphabet a(alphabet_config_path);
 | 
					    Alphabet a;
 | 
				
			||||||
 | 
					    if (a.init(alphabet_config_path)) {
 | 
				
			||||||
 | 
					      std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
    return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
 | 
					    return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
 | 
				
			||||||
                                         seq_lengths, seq_lengths_size, a, beam_size,
 | 
					                                         seq_lengths, seq_lengths_size, a, beam_size,
 | 
				
			||||||
                                         num_processes, cutoff_prob, cutoff_top_n,
 | 
					                                         num_processes, cutoff_prob, cutoff_top_n,
 | 
				
			||||||
 | 
				
			|||||||
@ -8,7 +8,10 @@
 | 
				
			|||||||
using namespace std;
 | 
					using namespace std;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
 | 
					int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
 | 
				
			||||||
  Alphabet alphabet(alphabet_path);
 | 
					  Alphabet alphabet;
 | 
				
			||||||
 | 
					  if (int err = alphabet.init(alphabet_path)) {
 | 
				
			||||||
 | 
					    return err;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  Scorer scorer(0.0, 0.0, kenlm_path, "", alphabet);
 | 
					  Scorer scorer(0.0, 0.0, kenlm_path, "", alphabet);
 | 
				
			||||||
  scorer.save_dictionary(trie_path);
 | 
					  scorer.save_dictionary(trie_path);
 | 
				
			||||||
  return 0;
 | 
					  return 0;
 | 
				
			||||||
 | 
				
			|||||||
@ -32,7 +32,9 @@ ModelState::init(const char* model_path,
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
  n_features_ = n_features;
 | 
					  n_features_ = n_features;
 | 
				
			||||||
  n_context_ = n_context;
 | 
					  n_context_ = n_context;
 | 
				
			||||||
  alphabet_.init(alphabet_path);
 | 
					  if (alphabet_.init(alphabet_path)) {
 | 
				
			||||||
 | 
					    return DS_ERR_INVALID_ALPHABET;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  beam_width_ = beam_width;
 | 
					  beam_width_ = beam_width;
 | 
				
			||||||
  return DS_ERR_OK;
 | 
					  return DS_ERR_OK;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -16,7 +16,10 @@ int main(int argc, char** argv)
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path);
 | 
					  printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Alphabet alphabet(alphabet_path);
 | 
					  Alphabet alphabet;
 | 
				
			||||||
 | 
					  if (int err = alphabet.init(alphabet_path)) {
 | 
				
			||||||
 | 
					    return err;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
  Scorer scorer(0.0, 0.0, kenlm_path, trie_path, alphabet);
 | 
					  Scorer scorer(0.0, 0.0, kenlm_path, trie_path, alphabet);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return 0;
 | 
					  return 0;
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user