Make Alphabet copyable and default-constructable and avoid pointers
This commit is contained in:
parent
4d882a8aec
commit
4dabd248bc
@ -15,7 +15,14 @@
|
||||
*/
|
||||
class Alphabet {
|
||||
public:
|
||||
Alphabet() {
|
||||
}
|
||||
|
||||
Alphabet(const char *config_file) {
|
||||
init(config_file);
|
||||
}
|
||||
|
||||
void init(const char *config_file) {
|
||||
std::ifstream in(config_file, std::ios::in);
|
||||
unsigned int label = 0;
|
||||
space_label_ = -2;
|
||||
|
||||
@ -30,18 +30,16 @@ static const int32_t MAGIC = 'TRIE';
|
||||
static const int32_t FILE_VERSION = 4;
|
||||
|
||||
Scorer::Scorer(double alpha,
|
||||
double beta,
|
||||
const std::string& lm_path,
|
||||
const std::string& trie_path,
|
||||
const Alphabet& alphabet)
|
||||
: dictionary()
|
||||
, language_model_()
|
||||
, is_character_based_(true)
|
||||
double beta)
|
||||
: is_character_based_(true)
|
||||
, max_order_(0)
|
||||
, alphabet_(alphabet)
|
||||
{
|
||||
reset_params(alpha, beta);
|
||||
}
|
||||
|
||||
void
|
||||
Scorer::init(const std::string& lm_path, const std::string& trie_path)
|
||||
{
|
||||
char_map_.clear();
|
||||
|
||||
SPACE_ID_ = alphabet_.GetSpaceLabel();
|
||||
@ -60,9 +58,22 @@ Scorer::Scorer(double alpha,
|
||||
double beta,
|
||||
const std::string& lm_path,
|
||||
const std::string& trie_path,
|
||||
const std::string& alphabet_config_path)
|
||||
: Scorer(alpha, beta, lm_path, trie_path, Alphabet(alphabet_config_path.c_str()))
|
||||
const Alphabet& alphabet)
|
||||
: Scorer(alpha, beta)
|
||||
{
|
||||
alphabet_ = alphabet;
|
||||
init(lm_path, trie_path);
|
||||
}
|
||||
|
||||
Scorer::Scorer(double alpha,
|
||||
double beta,
|
||||
const std::string& lm_path,
|
||||
const std::string& trie_path,
|
||||
const std::string& alphabet_config_path)
|
||||
: Scorer(alpha, beta)
|
||||
{
|
||||
alphabet_.init(alphabet_config_path.c_str());
|
||||
init(lm_path, trie_path);
|
||||
}
|
||||
|
||||
Scorer::~Scorer()
|
||||
|
||||
@ -87,6 +87,11 @@ public:
|
||||
std::unique_ptr<FstType> dictionary;
|
||||
|
||||
protected:
|
||||
Scorer(double alpha,
|
||||
double beta);
|
||||
|
||||
void init(const std::string& lm_path, const std::string& trie_path);
|
||||
|
||||
// necessary setup: load language model, fill FST's dictionary
|
||||
void setup(const std::string &lm_path, const std::string &trie_path);
|
||||
|
||||
|
||||
@ -244,7 +244,7 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
|
||||
previous_state_c_,
|
||||
previous_state_h_);
|
||||
|
||||
const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank
|
||||
const size_t num_classes = model_->alphabet_.GetSize() + 1; // +1 for blank
|
||||
const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);
|
||||
|
||||
// Convert logits to double
|
||||
@ -309,10 +309,10 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
|
||||
float aLMBeta)
|
||||
{
|
||||
try {
|
||||
aCtx->scorer_ = new Scorer(aLMAlpha, aLMBeta,
|
||||
aLMPath ? aLMPath : "",
|
||||
aTriePath ? aTriePath : "",
|
||||
*aCtx->alphabet_);
|
||||
aCtx->scorer_.reset(new Scorer(aLMAlpha, aLMBeta,
|
||||
aLMPath ? aLMPath : "",
|
||||
aTriePath ? aTriePath : "",
|
||||
aCtx->alphabet_));
|
||||
return DS_ERR_OK;
|
||||
} catch (...) {
|
||||
return DS_ERR_INVALID_LM;
|
||||
@ -343,11 +343,11 @@ DS_SetupStream(ModelState* aCtx,
|
||||
const int cutoff_top_n = 40;
|
||||
const double cutoff_prob = 1.0;
|
||||
|
||||
ctx->decoder_state_.init(*aCtx->alphabet_,
|
||||
ctx->decoder_state_.init(aCtx->alphabet_,
|
||||
aCtx->beam_width_,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
aCtx->scorer_);
|
||||
aCtx->scorer_.get());
|
||||
|
||||
*retval = ctx.release();
|
||||
return DS_ERR_OK;
|
||||
|
||||
@ -7,9 +7,7 @@
|
||||
using std::vector;
|
||||
|
||||
ModelState::ModelState()
|
||||
: alphabet_(nullptr)
|
||||
, scorer_(nullptr)
|
||||
, beam_width_(-1)
|
||||
: beam_width_(-1)
|
||||
, n_steps_(-1)
|
||||
, n_context_(-1)
|
||||
, n_features_(-1)
|
||||
@ -23,8 +21,6 @@ ModelState::ModelState()
|
||||
|
||||
ModelState::~ModelState()
|
||||
{
|
||||
delete scorer_;
|
||||
delete alphabet_;
|
||||
}
|
||||
|
||||
int
|
||||
@ -36,7 +32,7 @@ ModelState::init(const char* model_path,
|
||||
{
|
||||
n_features_ = n_features;
|
||||
n_context_ = n_context;
|
||||
alphabet_ = new Alphabet(alphabet_path);
|
||||
alphabet_.init(alphabet_path);
|
||||
beam_width_ = beam_width;
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
@ -45,7 +41,7 @@ char*
|
||||
ModelState::decode(const DecoderState& state)
|
||||
{
|
||||
vector<Output> out = state.decode();
|
||||
return strdup(alphabet_->LabelsToString(out[0].tokens).c_str());
|
||||
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
|
||||
}
|
||||
|
||||
Metadata*
|
||||
@ -61,7 +57,7 @@ ModelState::decode_metadata(const DecoderState& state)
|
||||
|
||||
// Loop through each character
|
||||
for (int i = 0; i < out[0].tokens.size(); ++i) {
|
||||
items[i].character = strdup(alphabet_->StringFromLabel(out[0].tokens[i]).c_str());
|
||||
items[i].character = strdup(alphabet_.StringFromLabel(out[0].tokens[i]).c_str());
|
||||
items[i].timestep = out[0].timesteps[i];
|
||||
items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
|
||||
|
||||
|
||||
@ -19,8 +19,8 @@ struct ModelState {
|
||||
static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
|
||||
static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
|
||||
|
||||
Alphabet* alphabet_;
|
||||
Scorer* scorer_;
|
||||
Alphabet alphabet_;
|
||||
std::unique_ptr<Scorer> scorer_;
|
||||
unsigned int beam_width_;
|
||||
unsigned int n_steps_;
|
||||
unsigned int n_context_;
|
||||
|
||||
@ -151,9 +151,9 @@ TFLiteModelState::init(const char* model_path,
|
||||
|
||||
TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
|
||||
const int final_dim_size = dims_logits->data[1] - 1;
|
||||
if (final_dim_size != alphabet_->GetSize()) {
|
||||
if (final_dim_size != alphabet_.GetSize()) {
|
||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||
<< "has size " << alphabet_->GetSize()
|
||||
<< "has size " << alphabet_.GetSize()
|
||||
<< ", but model has " << final_dim_size
|
||||
<< " classes in its output. Make sure you're passing an alphabet "
|
||||
<< "file with the same size as the one used for training."
|
||||
@ -208,7 +208,7 @@ TFLiteModelState::infer(const vector<float>& mfcc,
|
||||
vector<float>& state_c_output,
|
||||
vector<float>& state_h_output)
|
||||
{
|
||||
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
|
||||
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
|
||||
|
||||
// Feeding input_node
|
||||
copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);
|
||||
|
||||
@ -108,9 +108,9 @@ TFModelState::init(const char* model_path,
|
||||
}
|
||||
|
||||
int final_dim_size = logits_shape.vec<int>()(2) - 1;
|
||||
if (final_dim_size != alphabet_->GetSize()) {
|
||||
if (final_dim_size != alphabet_.GetSize()) {
|
||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||
<< "has size " << alphabet_->GetSize()
|
||||
<< "has size " << alphabet_.GetSize()
|
||||
<< ", but model has " << final_dim_size
|
||||
<< " classes in its output. Make sure you're passing an alphabet "
|
||||
<< "file with the same size as the one used for training."
|
||||
@ -173,7 +173,7 @@ TFModelState::infer(const std::vector<float>& mfcc,
|
||||
vector<float>& state_c_output,
|
||||
vector<float>& state_h_output)
|
||||
{
|
||||
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
|
||||
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
|
||||
|
||||
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
|
||||
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user