Change decoder API
This commit is contained in:
parent
16d5632d6f
commit
ab08f5ee5a
@ -181,21 +181,6 @@ genrule(
|
|||||||
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
|
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
|
||||||
name = "generate_trie",
|
|
||||||
srcs = [
|
|
||||||
"alphabet.h",
|
|
||||||
"generate_trie.cpp",
|
|
||||||
],
|
|
||||||
copts = ["-std=c++11"],
|
|
||||||
linkopts = [
|
|
||||||
"-lm",
|
|
||||||
"-ldl",
|
|
||||||
"-pthread",
|
|
||||||
],
|
|
||||||
deps = [":decoder"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "trie_load",
|
name = "trie_load",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -24,39 +24,29 @@
|
|||||||
|
|
||||||
#include "decoder_utils.h"
|
#include "decoder_utils.h"
|
||||||
|
|
||||||
using namespace lm::ngram;
|
|
||||||
|
|
||||||
static const int32_t MAGIC = 'TRIE';
|
static const int32_t MAGIC = 'TRIE';
|
||||||
static const int32_t FILE_VERSION = 6;
|
static const int32_t FILE_VERSION = 6;
|
||||||
|
|
||||||
int
|
int
|
||||||
Scorer::init(double alpha,
|
Scorer::init(const std::string& lm_path,
|
||||||
double beta,
|
|
||||||
const std::string& lm_path,
|
|
||||||
const std::string& trie_path,
|
|
||||||
const Alphabet& alphabet)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
reset_params(alpha, beta);
|
|
||||||
alphabet_ = alphabet;
|
alphabet_ = alphabet;
|
||||||
setup_char_map();
|
setup_char_map();
|
||||||
load_lm(lm_path, trie_path);
|
load_lm(lm_path);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
Scorer::init(double alpha,
|
Scorer::init(const std::string& lm_path,
|
||||||
double beta,
|
|
||||||
const std::string& lm_path,
|
|
||||||
const std::string& trie_path,
|
|
||||||
const std::string& alphabet_config_path)
|
const std::string& alphabet_config_path)
|
||||||
{
|
{
|
||||||
reset_params(alpha, beta);
|
|
||||||
int err = alphabet_.init(alphabet_config_path.c_str());
|
int err = alphabet_.init(alphabet_config_path.c_str());
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
setup_char_map();
|
setup_char_map();
|
||||||
load_lm(lm_path, trie_path);
|
load_lm(lm_path);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,7 +72,7 @@ void Scorer::setup_char_map()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
|
void Scorer::load_lm(const std::string& lm_path)
|
||||||
{
|
{
|
||||||
// load language model
|
// load language model
|
||||||
const char* filename = lm_path.c_str();
|
const char* filename = lm_path.c_str();
|
||||||
|
@ -50,16 +50,10 @@ public:
|
|||||||
Scorer(const Scorer&) = delete;
|
Scorer(const Scorer&) = delete;
|
||||||
Scorer& operator=(const Scorer&) = delete;
|
Scorer& operator=(const Scorer&) = delete;
|
||||||
|
|
||||||
int init(double alpha,
|
int init(const std::string &lm_path,
|
||||||
double beta,
|
|
||||||
const std::string &lm_path,
|
|
||||||
const std::string &trie_path,
|
|
||||||
const Alphabet &alphabet);
|
const Alphabet &alphabet);
|
||||||
|
|
||||||
int init(double alpha,
|
int init(const std::string &lm_path,
|
||||||
double beta,
|
|
||||||
const std::string &lm_path,
|
|
||||||
const std::string &trie_path,
|
|
||||||
const std::string &alphabet_config_path);
|
const std::string &alphabet_config_path);
|
||||||
|
|
||||||
double get_log_cond_prob(const std::vector<std::string> &words,
|
double get_log_cond_prob(const std::vector<std::string> &words,
|
||||||
@ -104,7 +98,7 @@ public:
|
|||||||
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
void fill_dictionary(const std::vector<std::string> &vocabulary);
|
||||||
|
|
||||||
// load language model from given path
|
// load language model from given path
|
||||||
void load_lm(const std::string &lm_path, const std::string &trie_path);
|
void load_lm(const std::string &lm_path);
|
||||||
|
|
||||||
// language model weight
|
// language model weight
|
||||||
double alpha = 0.;
|
double alpha = 0.;
|
||||||
|
@ -304,23 +304,38 @@ DS_FreeModel(ModelState* ctx)
|
|||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
DS_EnableDecoderWithLM(ModelState* aCtx,
|
DS_EnableExternalScorer(ModelState* aCtx,
|
||||||
const char* aLMPath,
|
const char* aScorerPath)
|
||||||
const char* aTriePath,
|
|
||||||
float aLMAlpha,
|
|
||||||
float aLMBeta)
|
|
||||||
{
|
{
|
||||||
aCtx->scorer_.reset(new Scorer());
|
aCtx->scorer_.reset(new Scorer());
|
||||||
int err = aCtx->scorer_->init(aLMAlpha, aLMBeta,
|
int err = aCtx->scorer_->init(aScorerPath, aCtx->alphabet_);
|
||||||
aLMPath ? aLMPath : "",
|
|
||||||
aTriePath ? aTriePath : "",
|
|
||||||
aCtx->alphabet_);
|
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
return DS_ERR_INVALID_LM;
|
return DS_ERR_INVALID_LM;
|
||||||
}
|
}
|
||||||
return DS_ERR_OK;
|
return DS_ERR_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
DS_DisableExternalScorer(ModelState* aCtx)
|
||||||
|
{
|
||||||
|
if (aCtx->scorer_) {
|
||||||
|
aCtx->scorer_.reset(nullptr);
|
||||||
|
return DS_ERR_OK;
|
||||||
|
}
|
||||||
|
return DS_ERR_SCORER_NOT_ENABLED;
|
||||||
|
}
|
||||||
|
|
||||||
|
int DS_SetScorerAlphaBeta(ModelState* aCtx,
|
||||||
|
float aAlpha,
|
||||||
|
float aBeta)
|
||||||
|
{
|
||||||
|
if (aCtx->scorer_) {
|
||||||
|
aCtx->scorer_->reset_params(aAlpha, aBeta);
|
||||||
|
return DS_ERR_OK;
|
||||||
|
}
|
||||||
|
return DS_ERR_SCORER_NOT_ENABLED;
|
||||||
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
DS_CreateStream(ModelState* aCtx,
|
DS_CreateStream(ModelState* aCtx,
|
||||||
StreamingState** retval)
|
StreamingState** retval)
|
||||||
|
@ -61,6 +61,7 @@ enum DeepSpeech_Error_Codes
|
|||||||
DS_ERR_INVALID_SHAPE = 0x2001,
|
DS_ERR_INVALID_SHAPE = 0x2001,
|
||||||
DS_ERR_INVALID_LM = 0x2002,
|
DS_ERR_INVALID_LM = 0x2002,
|
||||||
DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
|
DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
|
||||||
|
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
|
||||||
|
|
||||||
// Runtime failures
|
// Runtime failures
|
||||||
DS_ERR_FAIL_INIT_MMAP = 0x3000,
|
DS_ERR_FAIL_INIT_MMAP = 0x3000,
|
||||||
@ -106,25 +107,40 @@ DEEPSPEECH_EXPORT
|
|||||||
void DS_FreeModel(ModelState* ctx);
|
void DS_FreeModel(ModelState* ctx);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Enable decoding using beam scoring with a KenLM language model.
|
* @brief Enable decoding using an external scorer.
|
||||||
*
|
*
|
||||||
* @param aCtx The ModelState pointer for the model being changed.
|
* @param aCtx The ModelState pointer for the model being changed.
|
||||||
* @param aLMPath The path to the language model binary file.
|
* @param aScorerPath The path to the external scorer file.
|
||||||
* @param aTriePath The path to the trie file build from the same vocabu-
|
|
||||||
* lary as the language model binary.
|
|
||||||
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
|
|
||||||
weight.
|
|
||||||
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
|
|
||||||
weight.
|
|
||||||
*
|
*
|
||||||
* @return Zero on success, non-zero on failure (invalid arguments).
|
* @return Zero on success, non-zero on failure (invalid arguments).
|
||||||
*/
|
*/
|
||||||
DEEPSPEECH_EXPORT
|
DEEPSPEECH_EXPORT
|
||||||
int DS_EnableDecoderWithLM(ModelState* aCtx,
|
int DS_EnableExternalScorer(ModelState* aCtx,
|
||||||
const char* aLMPath,
|
const char* aScorerPath);
|
||||||
const char* aTriePath,
|
|
||||||
float aLMAlpha,
|
/**
|
||||||
float aLMBeta);
|
* @brief Disable decoding using an external scorer.
|
||||||
|
*
|
||||||
|
* @param aCtx The ModelState pointer for the model being changed.
|
||||||
|
*
|
||||||
|
* @return Zero on success, non-zero on failure (DS_ERR_SCORER_NOT_ENABLED if no external scorer is currently enabled).
|
||||||
|
*/
|
||||||
|
DEEPSPEECH_EXPORT
|
||||||
|
int DS_DisableExternalScorer(ModelState* aCtx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set hyperparameters alpha and beta of a KenLM external scorer.
|
||||||
|
*
|
||||||
|
* @param aCtx The ModelState pointer for the model being changed.
|
||||||
|
* @param aAlpha The alpha hyperparameter of the decoder. Language model weight.
|
||||||
|
* @param aBeta The beta hyperparameter of the decoder. Word insertion weight.
|
||||||
|
*
|
||||||
|
* @return Zero on success, non-zero on failure (DS_ERR_SCORER_NOT_ENABLED if no external scorer is currently enabled).
|
||||||
|
*/
|
||||||
|
DEEPSPEECH_EXPORT
|
||||||
|
int DS_SetScorerAlphaBeta(ModelState* aCtx,
|
||||||
|
float aAlpha,
|
||||||
|
float aBeta);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Use the DeepSpeech model to perform Speech-To-Text.
|
* @brief Use the DeepSpeech model to perform Speech-To-Text.
|
||||||
|
@ -1,32 +0,0 @@
|
|||||||
#include <algorithm>
|
|
||||||
#include <iostream>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "ctcdecode/scorer.h"
|
|
||||||
#include "alphabet.h"
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
|
|
||||||
Alphabet alphabet;
|
|
||||||
int err = alphabet.init(alphabet_path);
|
|
||||||
if (err != 0) {
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
Scorer scorer;
|
|
||||||
err = scorer.init(0.0, 0.0, kenlm_path, "", alphabet);
|
|
||||||
if (err != 0) {
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
scorer.save_dictionary(trie_path);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
if (argc != 4) {
|
|
||||||
std::cerr << "Usage: " << argv[0] << " <alphabet> <lm_model> <trie_path>" << std::endl;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
return generate_trie(argv[1], argv[2], argv[3]);
|
|
||||||
}
|
|
@ -27,9 +27,9 @@ int main(int argc, char** argv)
|
|||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
Scorer scorer;
|
Scorer scorer;
|
||||||
|
err = scorer.init(kenlm_path, alphabet);
|
||||||
#ifndef DEBUG
|
#ifndef DEBUG
|
||||||
return scorer.init(0.0, 0.0, kenlm_path, trie_path, alphabet);
|
return err;
|
||||||
#else
|
#else
|
||||||
// Print some info about the FST
|
// Print some info about the FST
|
||||||
using FstType = fst::ConstFst<fst::StdArc>;
|
using FstType = fst::ConstFst<fst::StdArc>;
|
||||||
@ -60,7 +60,6 @@ int main(int argc, char** argv)
|
|||||||
// for (int i = 1; i < 10; ++i) {
|
// for (int i = 1; i < 10; ++i) {
|
||||||
// print_states_from(i);
|
// print_states_from(i);
|
||||||
// }
|
// }
|
||||||
#endif // DEBUG
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
#endif // DEBUG
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user