Change decoder API
This commit is contained in:
parent
16d5632d6f
commit
ab08f5ee5a
|
@ -181,21 +181,6 @@ genrule(
|
|||
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "generate_trie",
|
||||
srcs = [
|
||||
"alphabet.h",
|
||||
"generate_trie.cpp",
|
||||
],
|
||||
copts = ["-std=c++11"],
|
||||
linkopts = [
|
||||
"-lm",
|
||||
"-ldl",
|
||||
"-pthread",
|
||||
],
|
||||
deps = [":decoder"],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "trie_load",
|
||||
srcs = [
|
||||
|
|
|
@ -24,39 +24,29 @@
|
|||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
using namespace lm::ngram;
|
||||
|
||||
static const int32_t MAGIC = 'TRIE';
|
||||
static const int32_t FILE_VERSION = 6;
|
||||
|
||||
int
|
||||
Scorer::init(double alpha,
|
||||
double beta,
|
||||
const std::string& lm_path,
|
||||
const std::string& trie_path,
|
||||
Scorer::init(const std::string& lm_path,
|
||||
const Alphabet& alphabet)
|
||||
{
|
||||
reset_params(alpha, beta);
|
||||
alphabet_ = alphabet;
|
||||
setup_char_map();
|
||||
load_lm(lm_path, trie_path);
|
||||
load_lm(lm_path);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int
|
||||
Scorer::init(double alpha,
|
||||
double beta,
|
||||
const std::string& lm_path,
|
||||
const std::string& trie_path,
|
||||
Scorer::init(const std::string& lm_path,
|
||||
const std::string& alphabet_config_path)
|
||||
{
|
||||
reset_params(alpha, beta);
|
||||
int err = alphabet_.init(alphabet_config_path.c_str());
|
||||
if (err != 0) {
|
||||
return err;
|
||||
}
|
||||
setup_char_map();
|
||||
load_lm(lm_path, trie_path);
|
||||
load_lm(lm_path);
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -82,7 +72,7 @@ void Scorer::setup_char_map()
|
|||
}
|
||||
}
|
||||
|
||||
void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
|
||||
void Scorer::load_lm(const std::string& lm_path)
|
||||
{
|
||||
// load language model
|
||||
const char* filename = lm_path.c_str();
|
||||
|
|
|
@ -50,16 +50,10 @@ public:
|
|||
Scorer(const Scorer&) = delete;
|
||||
Scorer& operator=(const Scorer&) = delete;
|
||||
|
||||
int init(double alpha,
|
||||
double beta,
|
||||
const std::string &lm_path,
|
||||
const std::string &trie_path,
|
||||
int init(const std::string &lm_path,
|
||||
const Alphabet &alphabet);
|
||||
|
||||
int init(double alpha,
|
||||
double beta,
|
||||
const std::string &lm_path,
|
||||
const std::string &trie_path,
|
||||
int init(const std::string &lm_path,
|
||||
const std::string &alphabet_config_path);
|
||||
|
||||
double get_log_cond_prob(const std::vector<std::string> &words,
|
||||
|
@ -104,7 +98,7 @@ public:
|
|||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||
|
||||
// load language model from given path
|
||||
void load_lm(const std::string &lm_path, const std::string &trie_path);
|
||||
void load_lm(const std::string &lm_path);
|
||||
|
||||
// language model weight
|
||||
double alpha = 0.;
|
||||
|
|
|
@ -304,23 +304,38 @@ DS_FreeModel(ModelState* ctx)
|
|||
}
|
||||
|
||||
int
|
||||
DS_EnableDecoderWithLM(ModelState* aCtx,
|
||||
const char* aLMPath,
|
||||
const char* aTriePath,
|
||||
float aLMAlpha,
|
||||
float aLMBeta)
|
||||
DS_EnableExternalScorer(ModelState* aCtx,
|
||||
const char* aScorerPath)
|
||||
{
|
||||
aCtx->scorer_.reset(new Scorer());
|
||||
int err = aCtx->scorer_->init(aLMAlpha, aLMBeta,
|
||||
aLMPath ? aLMPath : "",
|
||||
aTriePath ? aTriePath : "",
|
||||
aCtx->alphabet_);
|
||||
int err = aCtx->scorer_->init(aScorerPath, aCtx->alphabet_);
|
||||
if (err != 0) {
|
||||
return DS_ERR_INVALID_LM;
|
||||
}
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
|
||||
int
|
||||
DS_DisableExternalScorer(ModelState* aCtx)
|
||||
{
|
||||
if (aCtx->scorer_) {
|
||||
aCtx->scorer_.reset(nullptr);
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
return DS_ERR_SCORER_NOT_ENABLED;
|
||||
}
|
||||
|
||||
int DS_SetScorerAlphaBeta(ModelState* aCtx,
|
||||
float aAlpha,
|
||||
float aBeta)
|
||||
{
|
||||
if (aCtx->scorer_) {
|
||||
aCtx->scorer_->reset_params(aAlpha, aBeta);
|
||||
return DS_ERR_OK;
|
||||
}
|
||||
return DS_ERR_SCORER_NOT_ENABLED;
|
||||
}
|
||||
|
||||
int
|
||||
DS_CreateStream(ModelState* aCtx,
|
||||
StreamingState** retval)
|
||||
|
|
|
@ -61,6 +61,7 @@ enum DeepSpeech_Error_Codes
|
|||
DS_ERR_INVALID_SHAPE = 0x2001,
|
||||
DS_ERR_INVALID_LM = 0x2002,
|
||||
DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
|
||||
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
|
||||
|
||||
// Runtime failures
|
||||
DS_ERR_FAIL_INIT_MMAP = 0x3000,
|
||||
|
@ -106,25 +107,40 @@ DEEPSPEECH_EXPORT
|
|||
void DS_FreeModel(ModelState* ctx);
|
||||
|
||||
/**
|
||||
* @brief Enable decoding using beam scoring with a KenLM language model.
|
||||
* @brief Enable decoding using an external scorer.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
* @param aLMPath The path to the language model binary file.
|
||||
* @param aTriePath The path to the trie file build from the same vocabu-
|
||||
* lary as the language model binary.
|
||||
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
|
||||
weight.
|
||||
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
|
||||
weight.
|
||||
* @param aScorerPath The path to the external scorer file.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure (invalid arguments).
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_EnableDecoderWithLM(ModelState* aCtx,
|
||||
const char* aLMPath,
|
||||
const char* aTriePath,
|
||||
float aLMAlpha,
|
||||
float aLMBeta);
|
||||
int DS_EnableExternalScorer(ModelState* aCtx,
|
||||
const char* aScorerPath);
|
||||
|
||||
/**
|
||||
* @brief Disable decoding using an external scorer.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_DisableExternalScorer(ModelState* aCtx);
|
||||
|
||||
/**
|
||||
* @brief Set hyperparameters alpha and beta of a KenLM external scorer.
|
||||
*
|
||||
* @param aCtx The ModelState pointer for the model being changed.
|
||||
* @param aAlpha The alpha hyperparameter of the decoder. Language model weight.
|
||||
* @param aLMBeta The beta hyperparameter of the decoder. Word insertion weight.
|
||||
*
|
||||
* @return Zero on success, non-zero on failure.
|
||||
*/
|
||||
DEEPSPEECH_EXPORT
|
||||
int DS_SetScorerAlphaBeta(ModelState* aCtx,
|
||||
float aAlpha,
|
||||
float aBeta);
|
||||
|
||||
/**
|
||||
* @brief Use the DeepSpeech model to perform Speech-To-Text.
|
||||
|
|
|
@ -1,32 +0,0 @@
|
|||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ctcdecode/scorer.h"
|
||||
#include "alphabet.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
|
||||
Alphabet alphabet;
|
||||
int err = alphabet.init(alphabet_path);
|
||||
if (err != 0) {
|
||||
return err;
|
||||
}
|
||||
Scorer scorer;
|
||||
err = scorer.init(0.0, 0.0, kenlm_path, "", alphabet);
|
||||
if (err != 0) {
|
||||
return err;
|
||||
}
|
||||
scorer.save_dictionary(trie_path);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
if (argc != 4) {
|
||||
std::cerr << "Usage: " << argv[0] << " <alphabet> <lm_model> <trie_path>" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
return generate_trie(argv[1], argv[2], argv[3]);
|
||||
}
|
|
@ -27,9 +27,9 @@ int main(int argc, char** argv)
|
|||
return err;
|
||||
}
|
||||
Scorer scorer;
|
||||
|
||||
err = scorer.init(kenlm_path, alphabet);
|
||||
#ifndef DEBUG
|
||||
return scorer.init(0.0, 0.0, kenlm_path, trie_path, alphabet);
|
||||
return err;
|
||||
#else
|
||||
// Print some info about the FST
|
||||
using FstType = fst::ConstFst<fst::StdArc>;
|
||||
|
@ -60,7 +60,6 @@ int main(int argc, char** argv)
|
|||
// for (int i = 1; i < 10; ++i) {
|
||||
// print_states_from(i);
|
||||
// }
|
||||
#endif // DEBUG
|
||||
|
||||
return 0;
|
||||
#endif // DEBUG
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue