Change decoder API

This commit is contained in:
Reuben Morais 2020-01-20 08:28:24 +01:00
parent 16d5632d6f
commit ab08f5ee5a
7 changed files with 64 additions and 97 deletions

View File

@ -181,21 +181,6 @@ genrule(
cmd = "dsymutil $(location :libdeepspeech.so) -o $@" cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
) )
# Executable target "generate_trie": compiles generate_trie.cpp together with
# alphabet.h as C++11, links against libm, libdl and pthreads, and depends on
# the :decoder target.
cc_binary(
name = "generate_trie",
srcs = [
"alphabet.h",
"generate_trie.cpp",
],
copts = ["-std=c++11"],
linkopts = [
"-lm",
"-ldl",
"-pthread",
],
deps = [":decoder"],
)
cc_binary( cc_binary(
name = "trie_load", name = "trie_load",
srcs = [ srcs = [

View File

@ -24,39 +24,29 @@
#include "decoder_utils.h" #include "decoder_utils.h"
using namespace lm::ngram;
static const int32_t MAGIC = 'TRIE'; static const int32_t MAGIC = 'TRIE';
static const int32_t FILE_VERSION = 6; static const int32_t FILE_VERSION = 6;
int int
Scorer::init(double alpha, Scorer::init(const std::string& lm_path,
double beta,
const std::string& lm_path,
const std::string& trie_path,
const Alphabet& alphabet) const Alphabet& alphabet)
{ {
reset_params(alpha, beta);
alphabet_ = alphabet; alphabet_ = alphabet;
setup_char_map(); setup_char_map();
load_lm(lm_path, trie_path); load_lm(lm_path);
return 0; return 0;
} }
int int
Scorer::init(double alpha, Scorer::init(const std::string& lm_path,
double beta,
const std::string& lm_path,
const std::string& trie_path,
const std::string& alphabet_config_path) const std::string& alphabet_config_path)
{ {
reset_params(alpha, beta);
int err = alphabet_.init(alphabet_config_path.c_str()); int err = alphabet_.init(alphabet_config_path.c_str());
if (err != 0) { if (err != 0) {
return err; return err;
} }
setup_char_map(); setup_char_map();
load_lm(lm_path, trie_path); load_lm(lm_path);
return 0; return 0;
} }
@ -82,7 +72,7 @@ void Scorer::setup_char_map()
} }
} }
void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path) void Scorer::load_lm(const std::string& lm_path)
{ {
// load language model // load language model
const char* filename = lm_path.c_str(); const char* filename = lm_path.c_str();

View File

@ -50,16 +50,10 @@ public:
Scorer(const Scorer&) = delete; Scorer(const Scorer&) = delete;
Scorer& operator=(const Scorer&) = delete; Scorer& operator=(const Scorer&) = delete;
int init(double alpha, int init(const std::string &lm_path,
double beta,
const std::string &lm_path,
const std::string &trie_path,
const Alphabet &alphabet); const Alphabet &alphabet);
int init(double alpha, int init(const std::string &lm_path,
double beta,
const std::string &lm_path,
const std::string &trie_path,
const std::string &alphabet_config_path); const std::string &alphabet_config_path);
double get_log_cond_prob(const std::vector<std::string> &words, double get_log_cond_prob(const std::vector<std::string> &words,
@ -104,7 +98,7 @@ public:
void fill_dictionary(const std::vector<std::string> &vocabulary); void fill_dictionary(const std::vector<std::string> &vocabulary);
// load language model from given path // load language model from given path
void load_lm(const std::string &lm_path, const std::string &trie_path); void load_lm(const std::string &lm_path);
// language model weight // language model weight
double alpha = 0.; double alpha = 0.;

View File

@ -304,23 +304,38 @@ DS_FreeModel(ModelState* ctx)
} }
int int
DS_EnableDecoderWithLM(ModelState* aCtx, DS_EnableExternalScorer(ModelState* aCtx,
const char* aLMPath, const char* aScorerPath)
const char* aTriePath,
float aLMAlpha,
float aLMBeta)
{ {
aCtx->scorer_.reset(new Scorer()); aCtx->scorer_.reset(new Scorer());
int err = aCtx->scorer_->init(aLMAlpha, aLMBeta, int err = aCtx->scorer_->init(aScorerPath, aCtx->alphabet_);
aLMPath ? aLMPath : "",
aTriePath ? aTriePath : "",
aCtx->alphabet_);
if (err != 0) { if (err != 0) {
return DS_ERR_INVALID_LM; return DS_ERR_INVALID_LM;
} }
return DS_ERR_OK; return DS_ERR_OK;
} }
// Detach the external scorer from the given model state.
//
// Returns DS_ERR_SCORER_NOT_ENABLED when no scorer is currently attached,
// otherwise releases the scorer and returns DS_ERR_OK.
int
DS_DisableExternalScorer(ModelState* aCtx)
{
  // Guard: there is nothing to disable if no scorer is set.
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }

  aCtx->scorer_.reset(nullptr);
  return DS_ERR_OK;
}
// Forward new alpha/beta values to the scorer attached to the model state.
//
// Returns DS_ERR_SCORER_NOT_ENABLED when no scorer is currently attached,
// otherwise calls reset_params(aAlpha, aBeta) on it and returns DS_ERR_OK.
int DS_SetScorerAlphaBeta(ModelState* aCtx,
                          float aAlpha,
                          float aBeta)
{
  // Guard: parameters can only be applied to an existing scorer.
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }

  aCtx->scorer_->reset_params(aAlpha, aBeta);
  return DS_ERR_OK;
}
int int
DS_CreateStream(ModelState* aCtx, DS_CreateStream(ModelState* aCtx,
StreamingState** retval) StreamingState** retval)

View File

@ -61,6 +61,7 @@ enum DeepSpeech_Error_Codes
DS_ERR_INVALID_SHAPE = 0x2001, DS_ERR_INVALID_SHAPE = 0x2001,
DS_ERR_INVALID_LM = 0x2002, DS_ERR_INVALID_LM = 0x2002,
DS_ERR_MODEL_INCOMPATIBLE = 0x2003, DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
// Runtime failures // Runtime failures
DS_ERR_FAIL_INIT_MMAP = 0x3000, DS_ERR_FAIL_INIT_MMAP = 0x3000,
@ -106,25 +107,40 @@ DEEPSPEECH_EXPORT
void DS_FreeModel(ModelState* ctx); void DS_FreeModel(ModelState* ctx);
/** /**
* @brief Enable decoding using beam scoring with a KenLM language model. * @brief Enable decoding using an external scorer.
* *
* @param aCtx The ModelState pointer for the model being changed. * @param aCtx The ModelState pointer for the model being changed.
* @param aLMPath The path to the language model binary file. * @param aScorerPath The path to the external scorer file.
* @param aTriePath The path to the trie file build from the same vocabu-
* lary as the language model binary.
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
weight.
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
weight.
* *
* @return Zero on success, non-zero on failure (invalid arguments). * @return Zero on success, non-zero on failure (invalid arguments).
*/ */
DEEPSPEECH_EXPORT DEEPSPEECH_EXPORT
int DS_EnableDecoderWithLM(ModelState* aCtx, int DS_EnableExternalScorer(ModelState* aCtx,
const char* aLMPath, const char* aScorerPath);
const char* aTriePath,
float aLMAlpha, /**
float aLMBeta); * @brief Disable decoding using an external scorer.
*
* @param aCtx The ModelState pointer for the model being changed.
*
* @return Zero on success, non-zero on failure.
*/
DEEPSPEECH_EXPORT
int DS_DisableExternalScorer(ModelState* aCtx);
/**
* @brief Set hyperparameters alpha and beta of a KenLM external scorer.
*
* @param aCtx The ModelState pointer for the model being changed.
* @param aAlpha The alpha hyperparameter of the decoder. Language model weight.
* @param aBeta The beta hyperparameter of the decoder. Word insertion weight.
*
* @return Zero on success, non-zero on failure.
*/
DEEPSPEECH_EXPORT
int DS_SetScorerAlphaBeta(ModelState* aCtx,
float aAlpha,
float aBeta);
/** /**
* @brief Use the DeepSpeech model to perform Speech-To-Text. * @brief Use the DeepSpeech model to perform Speech-To-Text.

View File

@ -1,32 +0,0 @@
#include <algorithm>
#include <iostream>
#include <string>
#include "ctcdecode/scorer.h"
#include "alphabet.h"
using namespace std;
// Initialize an Alphabet from alphabet_path and a Scorer from kenlm_path, then
// write the scorer's dictionary to trie_path.
//
// Returns the non-zero error code of whichever init step fails, or 0 once the
// dictionary has been saved.
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
  Alphabet alphabet;
  const int alphabet_status = alphabet.init(alphabet_path);
  if (alphabet_status != 0) {
    return alphabet_status;
  }

  // The scorer is initialized with zero alpha/beta and an empty trie path.
  Scorer scorer;
  const int scorer_status = scorer.init(0.0, 0.0, kenlm_path, "", alphabet);
  if (scorer_status != 0) {
    return scorer_status;
  }

  scorer.save_dictionary(trie_path);
  return 0;
}
// Entry point: expects exactly three arguments (alphabet, LM model, output
// trie path) and returns the result of generate_trie; otherwise prints a
// usage line to stderr and returns -1.
int main(int argc, char** argv) {
  if (argc == 4) {
    return generate_trie(argv[1], argv[2], argv[3]);
  }

  std::cerr << "Usage: " << argv[0] << " <alphabet> <lm_model> <trie_path>" << std::endl;
  return -1;
}

View File

@ -27,9 +27,9 @@ int main(int argc, char** argv)
return err; return err;
} }
Scorer scorer; Scorer scorer;
err = scorer.init(kenlm_path, alphabet);
#ifndef DEBUG #ifndef DEBUG
return scorer.init(0.0, 0.0, kenlm_path, trie_path, alphabet); return err;
#else #else
// Print some info about the FST // Print some info about the FST
using FstType = fst::ConstFst<fst::StdArc>; using FstType = fst::ConstFst<fst::StdArc>;
@ -60,7 +60,6 @@ int main(int argc, char** argv)
// for (int i = 1; i < 10; ++i) { // for (int i = 1; i < 10; ++i) {
// print_states_from(i); // print_states_from(i);
// } // }
#endif // DEBUG
return 0; return 0;
#endif // DEBUG
} }