Change decoder API

This commit is contained in:
Reuben Morais 2020-01-20 08:28:24 +01:00
parent 16d5632d6f
commit ab08f5ee5a
7 changed files with 64 additions and 97 deletions

View File

@ -181,21 +181,6 @@ genrule(
cmd = "dsymutil $(location :libdeepspeech.so) -o $@"
)
# Builds the `generate_trie` executable from generate_trie.cpp and alphabet.h.
# Compiled as C++11, linked against libm, libdl and pthread, and depends on
# the :decoder target.
cc_binary(
name = "generate_trie",
srcs = [
"alphabet.h",
"generate_trie.cpp",
],
copts = ["-std=c++11"],
linkopts = [
"-lm",
"-ldl",
"-pthread",
],
deps = [":decoder"],
)
cc_binary(
name = "trie_load",
srcs = [

View File

@ -24,39 +24,29 @@
#include "decoder_utils.h"
using namespace lm::ngram;
static const int32_t MAGIC = 'TRIE';
static const int32_t FILE_VERSION = 6;
// Initialize the scorer from a language model path and an Alphabet object.
// Copies `alphabet` into alphabet_, builds the character map, then loads the
// language model. Always returns 0: nothing returned by load_lm() is
// propagated to the caller.
// NOTE(review): this span comes from a rendered diff, so both the previous
// (alpha, beta, lm_path, trie_path) and the new (lm_path) signature and
// load_lm() call lines appear below — confirm which lines are live.
int
Scorer::init(double alpha,
double beta,
const std::string& lm_path,
const std::string& trie_path,
Scorer::init(const std::string& lm_path,
const Alphabet& alphabet)
{
reset_params(alpha, beta);
alphabet_ = alphabet;
setup_char_map();
load_lm(lm_path, trie_path);
load_lm(lm_path);
return 0;
}
// Initialize the scorer from a language model path and the path to an
// alphabet configuration file. The alphabet is initialized first; a non-zero
// status from alphabet_.init() is returned to the caller unchanged and
// nothing else is set up. Otherwise the character map is built, the language
// model is loaded, and 0 is returned.
// NOTE(review): this span comes from a rendered diff, so both the previous
// (alpha, beta, lm_path, trie_path) and the new (lm_path) signature and
// load_lm() call lines appear below — confirm which lines are live.
int
Scorer::init(double alpha,
double beta,
const std::string& lm_path,
const std::string& trie_path,
Scorer::init(const std::string& lm_path,
const std::string& alphabet_config_path)
{
reset_params(alpha, beta);
// Bail out early if the alphabet cannot be initialized.
int err = alphabet_.init(alphabet_config_path.c_str());
if (err != 0) {
return err;
}
setup_char_map();
load_lm(lm_path, trie_path);
load_lm(lm_path);
return 0;
}
@ -82,7 +72,7 @@ void Scorer::setup_char_map()
}
}
void Scorer::load_lm(const std::string& lm_path, const std::string& trie_path)
void Scorer::load_lm(const std::string& lm_path)
{
// load language model
const char* filename = lm_path.c_str();

View File

@ -50,16 +50,10 @@ public:
Scorer(const Scorer&) = delete;
Scorer& operator=(const Scorer&) = delete;
int init(double alpha,
double beta,
const std::string &lm_path,
const std::string &trie_path,
int init(const std::string &lm_path,
const Alphabet &alphabet);
int init(double alpha,
double beta,
const std::string &lm_path,
const std::string &trie_path,
int init(const std::string &lm_path,
const std::string &alphabet_config_path);
double get_log_cond_prob(const std::vector<std::string> &words,
@ -104,7 +98,7 @@ public:
void fill_dictionary(const std::vector<std::string> &vocabulary);
// load language model from given path
void load_lm(const std::string &lm_path, const std::string &trie_path);
void load_lm(const std::string &lm_path);
// language model weight
double alpha = 0.;

View File

@ -304,23 +304,38 @@ DS_FreeModel(ModelState* ctx)
}
/**
 * Enable decoding with an external scorer loaded from aScorerPath.
 *
 * A freshly constructed Scorer replaces whatever scorer the model held
 * before and is initialized with the model's alphabet.
 *
 * @param aCtx        Model whose scorer is being replaced.
 * @param aScorerPath Path to the external scorer file. Must not be null.
 *
 * @return DS_ERR_OK on success; DS_ERR_INVALID_LM when aScorerPath is null
 *         or Scorer::init() reports a non-zero status.
 */
int
DS_EnableExternalScorer(ModelState* aCtx,
                        const char* aScorerPath)
{
  // Scorer::init() takes a const std::string&, and constructing a
  // std::string from a null char pointer is undefined behavior. Reject null
  // up front and leave the model's current scorer untouched.
  if (aScorerPath == nullptr) {
    return DS_ERR_INVALID_LM;
  }

  aCtx->scorer_.reset(new Scorer());
  const int err = aCtx->scorer_->init(aScorerPath, aCtx->alphabet_);
  if (err != 0) {
    return DS_ERR_INVALID_LM;
  }
  return DS_ERR_OK;
}
/**
 * Drop the model's external scorer, if one is set.
 *
 * @param aCtx Model whose scorer is being removed.
 *
 * @return DS_ERR_OK when a scorer was present and has been released;
 *         DS_ERR_SCORER_NOT_ENABLED when there was nothing to disable.
 */
int
DS_DisableExternalScorer(ModelState* aCtx)
{
  // Nothing to release when no scorer has been enabled.
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }

  aCtx->scorer_.reset(nullptr);
  return DS_ERR_OK;
}
/**
 * Forward new alpha/beta hyperparameters to the model's external scorer.
 *
 * @param aCtx   Model whose scorer is being updated.
 * @param aAlpha Value passed as the first argument of Scorer::reset_params().
 * @param aBeta  Value passed as the second argument of Scorer::reset_params().
 *
 * @return DS_ERR_OK when the scorer was updated; DS_ERR_SCORER_NOT_ENABLED
 *         when the model has no scorer to update.
 */
int DS_SetScorerAlphaBeta(ModelState* aCtx,
                          float aAlpha,
                          float aBeta)
{
  // The parameters are applied through the scorer, so one must exist first.
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }

  aCtx->scorer_->reset_params(aAlpha, aBeta);
  return DS_ERR_OK;
}
int
DS_CreateStream(ModelState* aCtx,
StreamingState** retval)

View File

@ -61,6 +61,7 @@ enum DeepSpeech_Error_Codes
DS_ERR_INVALID_SHAPE = 0x2001,
DS_ERR_INVALID_LM = 0x2002,
DS_ERR_MODEL_INCOMPATIBLE = 0x2003,
DS_ERR_SCORER_NOT_ENABLED = 0x2004,
// Runtime failures
DS_ERR_FAIL_INIT_MMAP = 0x3000,
@ -106,25 +107,40 @@ DEEPSPEECH_EXPORT
void DS_FreeModel(ModelState* ctx);
/**
* @brief Enable decoding using beam scoring with a KenLM language model.
* @brief Enable decoding using an external scorer.
*
* @param aCtx The ModelState pointer for the model being changed.
* @param aLMPath The path to the language model binary file.
* @param aTriePath The path to the trie file build from the same vocabu-
* lary as the language model binary.
* @param aLMAlpha The alpha hyperparameter of the CTC decoder. Language Model
weight.
* @param aLMBeta The beta hyperparameter of the CTC decoder. Word insertion
weight.
* @param aScorerPath The path to the external scorer file.
*
* @return Zero on success, non-zero on failure (invalid arguments).
*/
DEEPSPEECH_EXPORT
int DS_EnableDecoderWithLM(ModelState* aCtx,
const char* aLMPath,
const char* aTriePath,
float aLMAlpha,
float aLMBeta);
int DS_EnableExternalScorer(ModelState* aCtx,
const char* aScorerPath);
/**
* @brief Disable decoding using an external scorer.
*
* @param aCtx The ModelState pointer for the model being changed.
*
* @return Zero on success, non-zero on failure (no external scorer is
*         currently enabled on the model).
*/
DEEPSPEECH_EXPORT
int DS_DisableExternalScorer(ModelState* aCtx);
/**
* @brief Set hyperparameters alpha and beta of a KenLM external scorer.
*
* @param aCtx The ModelState pointer for the model being changed.
* @param aAlpha The alpha hyperparameter of the decoder. Language model weight.
* @param aBeta The beta hyperparameter of the decoder. Word insertion weight.
*
* @return Zero on success, non-zero on failure (no external scorer is
*         currently enabled on the model).
*/
DEEPSPEECH_EXPORT
int DS_SetScorerAlphaBeta(ModelState* aCtx,
float aAlpha,
float aBeta);
/**
* @brief Use the DeepSpeech model to perform Speech-To-Text.

View File

@ -1,32 +0,0 @@
#include <algorithm>
#include <iostream>
#include <string>
#include "ctcdecode/scorer.h"
#include "alphabet.h"
using namespace std;
// Build a trie from a KenLM language model and write it to disk.
//
// Loads the alphabet from `alphabet_path`, initializes a Scorer on
// `kenlm_path` with zero alpha/beta and an empty trie path, then saves the
// scorer's dictionary to `trie_path`.
//
// Returns 0 on success, or the first non-zero status reported by
// Alphabet::init() or Scorer::init().
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
  Alphabet alphabet;
  int status = alphabet.init(alphabet_path);
  if (status == 0) {
    Scorer scorer;
    status = scorer.init(0.0, 0.0, kenlm_path, "", alphabet);
    if (status == 0) {
      // Only reached once both the alphabet and the scorer are ready.
      scorer.save_dictionary(trie_path);
    }
  }
  return status;
}
// Command-line entry point: expects exactly three arguments (alphabet file,
// language model, output trie path) and hands them to generate_trie().
// Prints a usage line to stderr and returns -1 on any other argument count.
int main(int argc, char** argv) {
  const bool has_expected_args = (argc == 4);
  if (has_expected_args) {
    return generate_trie(argv[1], argv[2], argv[3]);
  }

  std::cerr << "Usage: " << argv[0] << " <alphabet> <lm_model> <trie_path>" << std::endl;
  return -1;
}

View File

@ -27,9 +27,9 @@ int main(int argc, char** argv)
return err;
}
Scorer scorer;
err = scorer.init(kenlm_path, alphabet);
#ifndef DEBUG
return scorer.init(0.0, 0.0, kenlm_path, trie_path, alphabet);
return err;
#else
// Print some info about the FST
using FstType = fst::ConstFst<fst::StdArc>;
@ -60,7 +60,6 @@ int main(int argc, char** argv)
// for (int i = 1; i < 10; ++i) {
// print_states_from(i);
// }
#endif // DEBUG
return 0;
#endif // DEBUG
}