Merge pull request #2303 from mozilla/simplify-decoder
Simplify decoder impl by making it object oriented, avoid pointers where possible
This commit is contained in:
commit
4c14c6b78b
@ -1,15 +1,20 @@
|
|||||||
# Description: Deepspeech native client library.
|
# Description: Deepspeech native client library.
|
||||||
|
|
||||||
load("@org_tensorflow//tensorflow:tensorflow.bzl",
|
load(
|
||||||
"tf_cc_shared_object", "if_cuda")
|
"@org_tensorflow//tensorflow:tensorflow.bzl",
|
||||||
|
"if_cuda",
|
||||||
load("@org_tensorflow//tensorflow/lite:build_def.bzl",
|
"tf_cc_shared_object",
|
||||||
"tflite_copts", "tflite_linkopts")
|
)
|
||||||
|
load(
|
||||||
|
"@org_tensorflow//tensorflow/lite:build_def.bzl",
|
||||||
|
"tflite_copts",
|
||||||
|
"tflite_linkopts",
|
||||||
|
)
|
||||||
|
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "tflite",
|
name = "tflite",
|
||||||
define_values = {
|
define_values = {
|
||||||
"runtime": "tflite"
|
"runtime": "tflite",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,75 +22,98 @@ genrule(
|
|||||||
name = "workspace_status",
|
name = "workspace_status",
|
||||||
outs = ["workspace_status.cc"],
|
outs = ["workspace_status.cc"],
|
||||||
cmd = "$(location :gen_workspace_status.sh) >$@",
|
cmd = "$(location :gen_workspace_status.sh) >$@",
|
||||||
tools = [":gen_workspace_status.sh"],
|
|
||||||
local = 1,
|
local = 1,
|
||||||
stamp = 1,
|
stamp = 1,
|
||||||
|
tools = [":gen_workspace_status.sh"],
|
||||||
)
|
)
|
||||||
|
|
||||||
KENLM_SOURCES = glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
|
KENLM_SOURCES = glob(
|
||||||
"kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
|
[
|
||||||
exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"])
|
"kenlm/lm/*.cc",
|
||||||
|
"kenlm/util/*.cc",
|
||||||
KENLM_INCLUDES = [
|
"kenlm/util/double-conversion/*.cc",
|
||||||
"kenlm",
|
"kenlm/lm/*.hh",
|
||||||
]
|
"kenlm/util/*.hh",
|
||||||
|
"kenlm/util/double-conversion/*.h",
|
||||||
|
],
|
||||||
|
exclude = [
|
||||||
|
"kenlm/*/*test.cc",
|
||||||
|
"kenlm/*/*main.cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
OPENFST_SOURCES_PLATFORM = select({
|
OPENFST_SOURCES_PLATFORM = select({
|
||||||
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
|
"//tensorflow:windows": glob(["ctcdecode/third_party/openfst-1.6.9-win/src/lib/*.cc"]),
|
||||||
"//conditions:default": glob(["ctcdecode/third_party/openfst-1.6.7/src/lib/*.cc"]),
|
"//conditions:default": glob(["ctcdecode/third_party/openfst-1.6.7/src/lib/*.cc"]),
|
||||||
})
|
})
|
||||||
|
|
||||||
DECODER_SOURCES = glob([
|
|
||||||
"ctcdecode/*.h",
|
|
||||||
"ctcdecode/*.cpp",
|
|
||||||
], exclude=["ctcdecode/*_wrap.cpp"]) + OPENFST_SOURCES_PLATFORM + KENLM_SOURCES
|
|
||||||
|
|
||||||
OPENFST_INCLUDES_PLATFORM = select({
|
OPENFST_INCLUDES_PLATFORM = select({
|
||||||
"//tensorflow:windows": ["ctcdecode/third_party/openfst-1.6.9-win/src/include"],
|
"//tensorflow:windows": ["ctcdecode/third_party/openfst-1.6.9-win/src/include"],
|
||||||
"//conditions:default": ["ctcdecode/third_party/openfst-1.6.7/src/include"],
|
"//conditions:default": ["ctcdecode/third_party/openfst-1.6.7/src/include"],
|
||||||
})
|
})
|
||||||
|
|
||||||
DECODER_INCLUDES = [
|
|
||||||
".",
|
|
||||||
"ctcdecode/third_party/ThreadPool",
|
|
||||||
] + OPENFST_INCLUDES_PLATFORM + KENLM_INCLUDES
|
|
||||||
|
|
||||||
LINUX_LINKOPTS = [
|
LINUX_LINKOPTS = [
|
||||||
"-ldl",
|
"-ldl",
|
||||||
"-pthread",
|
"-pthread",
|
||||||
"-Wl,-Bsymbolic",
|
"-Wl,-Bsymbolic",
|
||||||
"-Wl,-Bsymbolic-functions",
|
"-Wl,-Bsymbolic-functions",
|
||||||
"-Wl,-export-dynamic"
|
"-Wl,-export-dynamic",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "decoder",
|
||||||
|
srcs = [
|
||||||
|
"ctcdecode/ctc_beam_search_decoder.cpp",
|
||||||
|
"ctcdecode/decoder_utils.cpp",
|
||||||
|
"ctcdecode/decoder_utils.h",
|
||||||
|
"ctcdecode/scorer.cpp",
|
||||||
|
"ctcdecode/path_trie.cpp",
|
||||||
|
"ctcdecode/path_trie.h",
|
||||||
|
] + KENLM_SOURCES + OPENFST_SOURCES_PLATFORM,
|
||||||
|
hdrs = [
|
||||||
|
"ctcdecode/ctc_beam_search_decoder.h",
|
||||||
|
"ctcdecode/scorer.h",
|
||||||
|
],
|
||||||
|
defines = ["KENLM_MAX_ORDER=6"],
|
||||||
|
includes = [
|
||||||
|
".",
|
||||||
|
"ctcdecode/third_party/ThreadPool",
|
||||||
|
"kenlm",
|
||||||
|
] + OPENFST_INCLUDES_PLATFORM,
|
||||||
|
)
|
||||||
|
|
||||||
tf_cc_shared_object(
|
tf_cc_shared_object(
|
||||||
name = "libdeepspeech.so",
|
name = "libdeepspeech.so",
|
||||||
srcs = ["deepspeech.cc",
|
srcs = [
|
||||||
"deepspeech.h",
|
"deepspeech.cc",
|
||||||
"alphabet.h",
|
"deepspeech.h",
|
||||||
"modelstate.h",
|
"alphabet.h",
|
||||||
"modelstate.cc",
|
"modelstate.h",
|
||||||
"workspace_status.h",
|
"modelstate.cc",
|
||||||
"workspace_status.cc"] +
|
"workspace_status.h",
|
||||||
DECODER_SOURCES +
|
"workspace_status.cc",
|
||||||
select({
|
] + select({
|
||||||
"//native_client:tflite": [
|
"//native_client:tflite": [
|
||||||
"tflitemodelstate.h",
|
"tflitemodelstate.h",
|
||||||
"tflitemodelstate.cc"
|
"tflitemodelstate.cc",
|
||||||
],
|
],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"tfmodelstate.h",
|
"tfmodelstate.h",
|
||||||
"tfmodelstate.cc"
|
"tfmodelstate.cc",
|
||||||
]}),
|
],
|
||||||
|
}),
|
||||||
copts = select({
|
copts = select({
|
||||||
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
|
# -fvisibility=hidden is not required on Windows, MSCV hides all declarations by default
|
||||||
"//tensorflow:windows": ["/w"],
|
"//tensorflow:windows": ["/w"],
|
||||||
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
|
# -Wno-sign-compare to silent a lot of warnings from tensorflow itself,
|
||||||
# which makes it harder to see our own warnings
|
# which makes it harder to see our own warnings
|
||||||
"//conditions:default": ["-Wno-sign-compare", "-fvisibility=hidden"],
|
"//conditions:default": [
|
||||||
|
"-Wno-sign-compare",
|
||||||
|
"-fvisibility=hidden",
|
||||||
|
],
|
||||||
}) + select({
|
}) + select({
|
||||||
"//native_client:tflite": [ "-DUSE_TFLITE" ],
|
"//native_client:tflite": ["-DUSE_TFLITE"],
|
||||||
"//conditions:default": [ "-UUSE_TFLITE" ]
|
"//conditions:default": ["-UUSE_TFLITE"],
|
||||||
}) + tflite_copts(),
|
}) + tflite_copts(),
|
||||||
linkopts = select({
|
linkopts = select({
|
||||||
"//tensorflow:macos": [],
|
"//tensorflow:macos": [],
|
||||||
@ -93,7 +121,7 @@ tf_cc_shared_object(
|
|||||||
"//tensorflow:rpi3": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
"//tensorflow:rpi3": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
||||||
"//tensorflow:rpi3-armv8": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
"//tensorflow:rpi3-armv8": LINUX_LINKOPTS + ["-l:libstdc++.a"],
|
||||||
"//tensorflow:windows": [],
|
"//tensorflow:windows": [],
|
||||||
"//conditions:default": []
|
"//conditions:default": [],
|
||||||
}) + tflite_linkopts(),
|
}) + tflite_linkopts(),
|
||||||
deps = select({
|
deps = select({
|
||||||
"//native_client:tflite": [
|
"//native_client:tflite": [
|
||||||
@ -107,66 +135,70 @@ tf_cc_shared_object(
|
|||||||
### => Trying to be more fine-grained
|
### => Trying to be more fine-grained
|
||||||
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
|
### Use bin/ops_in_graph.py to list all the ops used by a frozen graph.
|
||||||
### CPU only build, libdeepspeech.so file size reduced by ~50%
|
### CPU only build, libdeepspeech.so file size reduced by ~50%
|
||||||
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
|
"//tensorflow/core/kernels:spectrogram_op", # AudioSpectrogram
|
||||||
"//tensorflow/core/kernels:bias_op", # BiasAdd
|
"//tensorflow/core/kernels:bias_op", # BiasAdd
|
||||||
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
|
"//tensorflow/contrib/rnn:lstm_ops_kernels", # BlockLSTM
|
||||||
"//tensorflow/core/kernels:cast_op", # Cast
|
"//tensorflow/core/kernels:cast_op", # Cast
|
||||||
"//tensorflow/core/kernels:concat_op", # ConcatV2
|
"//tensorflow/core/kernels:concat_op", # ConcatV2
|
||||||
"//tensorflow/core/kernels:constant_op", # Const, Placeholder
|
"//tensorflow/core/kernels:constant_op", # Const, Placeholder
|
||||||
"//tensorflow/core/kernels:shape_ops", # ExpandDims, Shape
|
"//tensorflow/core/kernels:shape_ops", # ExpandDims, Shape
|
||||||
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
|
"//tensorflow/core/kernels:gather_nd_op", # GatherNd
|
||||||
"//tensorflow/core/kernels:identity_op", # Identity
|
"//tensorflow/core/kernels:identity_op", # Identity
|
||||||
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
|
"//tensorflow/core/kernels:immutable_constant_op", # ImmutableConst (used in memmapped models)
|
||||||
"//tensorflow/core/kernels:deepspeech_cwise_ops", # Less, Minimum, Mul
|
"//tensorflow/core/kernels:deepspeech_cwise_ops", # Less, Minimum, Mul
|
||||||
"//tensorflow/core/kernels:matmul_op", # MatMul
|
"//tensorflow/core/kernels:matmul_op", # MatMul
|
||||||
"//tensorflow/core/kernels:reduction_ops", # Max
|
"//tensorflow/core/kernels:reduction_ops", # Max
|
||||||
"//tensorflow/core/kernels:mfcc_op", # Mfcc
|
"//tensorflow/core/kernels:mfcc_op", # Mfcc
|
||||||
"//tensorflow/core/kernels:no_op", # NoOp
|
"//tensorflow/core/kernels:no_op", # NoOp
|
||||||
"//tensorflow/core/kernels:pack_op", # Pack
|
"//tensorflow/core/kernels:pack_op", # Pack
|
||||||
"//tensorflow/core/kernels:sequence_ops", # Range
|
"//tensorflow/core/kernels:sequence_ops", # Range
|
||||||
"//tensorflow/core/kernels:relu_op", # Relu
|
"//tensorflow/core/kernels:relu_op", # Relu
|
||||||
"//tensorflow/core/kernels:reshape_op", # Reshape
|
"//tensorflow/core/kernels:reshape_op", # Reshape
|
||||||
"//tensorflow/core/kernels:softmax_op", # Softmax
|
"//tensorflow/core/kernels:softmax_op", # Softmax
|
||||||
"//tensorflow/core/kernels:tile_ops", # Tile
|
"//tensorflow/core/kernels:tile_ops", # Tile
|
||||||
"//tensorflow/core/kernels:transpose_op", # Transpose
|
"//tensorflow/core/kernels:transpose_op", # Transpose
|
||||||
# And we also need the op libs for these ops used in the model:
|
# And we also need the op libs for these ops used in the model:
|
||||||
"//tensorflow/core:audio_ops_op_lib", # AudioSpectrogram, Mfcc
|
"//tensorflow/core:audio_ops_op_lib", # AudioSpectrogram, Mfcc
|
||||||
"//tensorflow/contrib/rnn:lstm_ops_op_lib", # BlockLSTM
|
"//tensorflow/contrib/rnn:lstm_ops_op_lib", # BlockLSTM
|
||||||
"//tensorflow/core:math_ops_op_lib", # Cast, Less, Max, MatMul, Minimum, Range
|
"//tensorflow/core:math_ops_op_lib", # Cast, Less, Max, MatMul, Minimum, Range
|
||||||
"//tensorflow/core:array_ops_op_lib", # ConcatV2, Const, ExpandDims, Fill, GatherNd, Identity, Pack, Placeholder, Reshape, Tile, Transpose
|
"//tensorflow/core:array_ops_op_lib", # ConcatV2, Const, ExpandDims, Fill, GatherNd, Identity, Pack, Placeholder, Reshape, Tile, Transpose
|
||||||
"//tensorflow/core:no_op_op_lib", # NoOp
|
"//tensorflow/core:no_op_op_lib", # NoOp
|
||||||
"//tensorflow/core:nn_ops_op_lib", # Relu, Softmax, BiasAdd
|
"//tensorflow/core:nn_ops_op_lib", # Relu, Softmax, BiasAdd
|
||||||
# And op libs for these ops brought in by dependencies of dependencies to silence unknown OpKernel warnings:
|
# And op libs for these ops brought in by dependencies of dependencies to silence unknown OpKernel warnings:
|
||||||
"//tensorflow/core:dataset_ops_op_lib", # UnwrapDatasetVariant, WrapDatasetVariant
|
"//tensorflow/core:dataset_ops_op_lib", # UnwrapDatasetVariant, WrapDatasetVariant
|
||||||
"//tensorflow/core:sendrecv_ops_op_lib", # _HostRecv, _HostSend, _Recv, _Send
|
"//tensorflow/core:sendrecv_ops_op_lib", # _HostRecv, _HostSend, _Recv, _Send
|
||||||
],
|
],
|
||||||
}) + if_cuda([
|
}) + if_cuda([
|
||||||
"//tensorflow/core:core",
|
"//tensorflow/core:core",
|
||||||
]),
|
]) + [":decoder"],
|
||||||
includes = DECODER_INCLUDES,
|
|
||||||
defines = ["KENLM_MAX_ORDER=6"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "generate_trie",
|
name = "generate_trie",
|
||||||
srcs = [
|
srcs = [
|
||||||
"generate_trie.cpp",
|
"alphabet.h",
|
||||||
"alphabet.h",
|
"generate_trie.cpp",
|
||||||
] + DECODER_SOURCES,
|
],
|
||||||
includes = DECODER_INCLUDES,
|
|
||||||
copts = ["-std=c++11"],
|
copts = ["-std=c++11"],
|
||||||
linkopts = ["-lm", "-ldl", "-pthread"],
|
linkopts = [
|
||||||
defines = ["KENLM_MAX_ORDER=6"],
|
"-lm",
|
||||||
|
"-ldl",
|
||||||
|
"-pthread",
|
||||||
|
],
|
||||||
|
deps = [":decoder"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "trie_load",
|
name = "trie_load",
|
||||||
srcs = [
|
srcs = [
|
||||||
"trie_load.cc",
|
"alphabet.h",
|
||||||
"alphabet.h",
|
"trie_load.cc",
|
||||||
] + DECODER_SOURCES,
|
],
|
||||||
includes = DECODER_INCLUDES,
|
|
||||||
copts = ["-std=c++11"],
|
copts = ["-std=c++11"],
|
||||||
linkopts = ["-lm", "-ldl", "-pthread"],
|
linkopts = [
|
||||||
defines = ["KENLM_MAX_ORDER=6"],
|
"-lm",
|
||||||
|
"-ldl",
|
||||||
|
"-pthread",
|
||||||
|
],
|
||||||
|
deps = [":decoder"],
|
||||||
)
|
)
|
||||||
|
@ -15,8 +15,15 @@
|
|||||||
*/
|
*/
|
||||||
class Alphabet {
|
class Alphabet {
|
||||||
public:
|
public:
|
||||||
Alphabet(const char *config_file) {
|
Alphabet() = default;
|
||||||
|
Alphabet(const Alphabet&) = default;
|
||||||
|
Alphabet& operator=(const Alphabet&) = default;
|
||||||
|
|
||||||
|
int init(const char *config_file) {
|
||||||
std::ifstream in(config_file, std::ios::in);
|
std::ifstream in(config_file, std::ios::in);
|
||||||
|
if (!in) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
unsigned int label = 0;
|
unsigned int label = 0;
|
||||||
space_label_ = -2;
|
space_label_ = -2;
|
||||||
for (std::string line; std::getline(in, line);) {
|
for (std::string line; std::getline(in, line);) {
|
||||||
@ -35,6 +42,7 @@ public:
|
|||||||
}
|
}
|
||||||
size_ = label;
|
size_ = label;
|
||||||
in.close();
|
in.close();
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string& StringFromLabel(unsigned int label) const {
|
const std::string& StringFromLabel(unsigned int label) const {
|
||||||
|
@ -19,7 +19,10 @@ class Scorer(swigwrapper.Scorer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
def __init__(self, alpha, beta, model_path, trie_path, alphabet):
|
||||||
swigwrapper.Scorer.__init__(self, alpha, beta, model_path, trie_path, alphabet.config_file())
|
super(Scorer, self).__init__()
|
||||||
|
err = self.init(alpha, beta, model_path, trie_path, alphabet.config_file())
|
||||||
|
if err != 0:
|
||||||
|
raise ValueError("Scorer initialization failed with error code {}".format(err), err)
|
||||||
|
|
||||||
|
|
||||||
def ctc_beam_search_decoder(probs_seq,
|
def ctc_beam_search_decoder(probs_seq,
|
||||||
|
@ -12,29 +12,29 @@
|
|||||||
#include "fst/fstlib.h"
|
#include "fst/fstlib.h"
|
||||||
#include "path_trie.h"
|
#include "path_trie.h"
|
||||||
|
|
||||||
DecoderState*
|
|
||||||
decoder_init(const Alphabet &alphabet,
|
|
||||||
int class_dim,
|
|
||||||
Scorer* ext_scorer)
|
|
||||||
{
|
|
||||||
// dimension check
|
|
||||||
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
|
|
||||||
"The shape of probs does not match with "
|
|
||||||
"the shape of the vocabulary");
|
|
||||||
|
|
||||||
|
int
|
||||||
|
DecoderState::init(const Alphabet& alphabet,
|
||||||
|
size_t beam_size,
|
||||||
|
double cutoff_prob,
|
||||||
|
size_t cutoff_top_n,
|
||||||
|
Scorer *ext_scorer)
|
||||||
|
{
|
||||||
// assign special ids
|
// assign special ids
|
||||||
DecoderState *state = new DecoderState;
|
abs_time_step_ = 0;
|
||||||
state->time_step = 0;
|
space_id_ = alphabet.GetSpaceLabel();
|
||||||
state->space_id = alphabet.GetSpaceLabel();
|
blank_id_ = alphabet.GetSize();
|
||||||
state->blank_id = alphabet.GetSize();
|
|
||||||
|
beam_size_ = beam_size;
|
||||||
|
cutoff_prob_ = cutoff_prob;
|
||||||
|
cutoff_top_n_ = cutoff_top_n;
|
||||||
|
ext_scorer_ = ext_scorer;
|
||||||
|
|
||||||
// init prefixes' root
|
// init prefixes' root
|
||||||
PathTrie *root = new PathTrie;
|
PathTrie *root = new PathTrie;
|
||||||
root->score = root->log_prob_b_prev = 0.0;
|
root->score = root->log_prob_b_prev = 0.0;
|
||||||
|
prefix_root_.reset(root);
|
||||||
state->prefix_root = root;
|
prefixes_.push_back(root);
|
||||||
|
|
||||||
state->prefixes.push_back(root);
|
|
||||||
|
|
||||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||||
auto dict_ptr = ext_scorer->dictionary->Copy(true);
|
auto dict_ptr = ext_scorer->dictionary->Copy(true);
|
||||||
@ -43,51 +43,45 @@ decoder_init(const Alphabet &alphabet,
|
|||||||
root->set_matcher(matcher);
|
root->set_matcher(matcher);
|
||||||
}
|
}
|
||||||
|
|
||||||
return state;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
decoder_next(const double *probs,
|
DecoderState::next(const double *probs,
|
||||||
const Alphabet &alphabet,
|
int time_dim,
|
||||||
DecoderState *state,
|
int class_dim)
|
||||||
int time_dim,
|
|
||||||
int class_dim,
|
|
||||||
double cutoff_prob,
|
|
||||||
size_t cutoff_top_n,
|
|
||||||
size_t beam_size,
|
|
||||||
Scorer *ext_scorer)
|
|
||||||
{
|
{
|
||||||
// prefix search over time
|
// prefix search over time
|
||||||
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++state->time_step) {
|
for (size_t rel_time_step = 0; rel_time_step < time_dim; ++rel_time_step, ++abs_time_step_) {
|
||||||
auto *prob = &probs[rel_time_step*class_dim];
|
auto *prob = &probs[rel_time_step*class_dim];
|
||||||
|
|
||||||
float min_cutoff = -NUM_FLT_INF;
|
float min_cutoff = -NUM_FLT_INF;
|
||||||
bool full_beam = false;
|
bool full_beam = false;
|
||||||
if (ext_scorer != nullptr) {
|
if (ext_scorer_ != nullptr) {
|
||||||
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
|
size_t num_prefixes = std::min(prefixes_.size(), beam_size_);
|
||||||
std::sort(
|
std::sort(
|
||||||
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
|
prefixes_.begin(), prefixes_.begin() + num_prefixes, prefix_compare);
|
||||||
|
|
||||||
min_cutoff = state->prefixes[num_prefixes - 1]->score +
|
min_cutoff = prefixes_[num_prefixes - 1]->score +
|
||||||
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
|
std::log(prob[blank_id_]) - std::max(0.0, ext_scorer_->beta);
|
||||||
full_beam = (num_prefixes == beam_size);
|
full_beam = (num_prefixes == beam_size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<size_t, float>> log_prob_idx =
|
std::vector<std::pair<size_t, float>> log_prob_idx =
|
||||||
get_pruned_log_probs(prob, class_dim, cutoff_prob, cutoff_top_n);
|
get_pruned_log_probs(prob, class_dim, cutoff_prob_, cutoff_top_n_);
|
||||||
// loop over class dim
|
// loop over class dim
|
||||||
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
||||||
auto c = log_prob_idx[index].first;
|
auto c = log_prob_idx[index].first;
|
||||||
auto log_prob_c = log_prob_idx[index].second;
|
auto log_prob_c = log_prob_idx[index].second;
|
||||||
|
|
||||||
for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
|
for (size_t i = 0; i < prefixes_.size() && i < beam_size_; ++i) {
|
||||||
auto prefix = state->prefixes[i];
|
auto prefix = prefixes_[i];
|
||||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// blank
|
// blank
|
||||||
if (c == state->blank_id) {
|
if (c == blank_id_) {
|
||||||
prefix->log_prob_b_cur =
|
prefix->log_prob_b_cur =
|
||||||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
||||||
continue;
|
continue;
|
||||||
@ -100,7 +94,7 @@ decoder_next(const double *probs,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// get new prefix
|
// get new prefix
|
||||||
auto prefix_new = prefix->get_path_trie(c, state->time_step, log_prob_c);
|
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c);
|
||||||
|
|
||||||
if (prefix_new != nullptr) {
|
if (prefix_new != nullptr) {
|
||||||
float log_p = -NUM_FLT_INF;
|
float log_p = -NUM_FLT_INF;
|
||||||
@ -113,11 +107,11 @@ decoder_next(const double *probs,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// language model scoring
|
// language model scoring
|
||||||
if (ext_scorer != nullptr &&
|
if (ext_scorer_ != nullptr &&
|
||||||
(c == state->space_id || ext_scorer->is_character_based())) {
|
(c == space_id_ || ext_scorer_->is_character_based())) {
|
||||||
PathTrie *prefix_to_score = nullptr;
|
PathTrie *prefix_to_score = nullptr;
|
||||||
// skip scoring the space
|
// skip scoring the space
|
||||||
if (ext_scorer->is_character_based()) {
|
if (ext_scorer_->is_character_based()) {
|
||||||
prefix_to_score = prefix_new;
|
prefix_to_score = prefix_new;
|
||||||
} else {
|
} else {
|
||||||
prefix_to_score = prefix;
|
prefix_to_score = prefix;
|
||||||
@ -125,10 +119,10 @@ decoder_next(const double *probs,
|
|||||||
|
|
||||||
float score = 0.0;
|
float score = 0.0;
|
||||||
std::vector<std::string> ngram;
|
std::vector<std::string> ngram;
|
||||||
ngram = ext_scorer->make_ngram(prefix_to_score);
|
ngram = ext_scorer_->make_ngram(prefix_to_score);
|
||||||
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
|
||||||
log_p += score;
|
log_p += score;
|
||||||
log_p += ext_scorer->beta;
|
log_p += ext_scorer_->beta;
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix_new->log_prob_nb_cur =
|
prefix_new->log_prob_nb_cur =
|
||||||
@ -138,53 +132,50 @@ decoder_next(const double *probs,
|
|||||||
} // end of loop over alphabet
|
} // end of loop over alphabet
|
||||||
|
|
||||||
// update log probs
|
// update log probs
|
||||||
state->prefixes.clear();
|
prefixes_.clear();
|
||||||
state->prefix_root->iterate_to_vec(state->prefixes);
|
prefix_root_->iterate_to_vec(prefixes_);
|
||||||
|
|
||||||
// only preserve top beam_size prefixes
|
// only preserve top beam_size prefixes
|
||||||
if (state->prefixes.size() > beam_size) {
|
if (prefixes_.size() > beam_size_) {
|
||||||
std::nth_element(state->prefixes.begin(),
|
std::nth_element(prefixes_.begin(),
|
||||||
state->prefixes.begin() + beam_size,
|
prefixes_.begin() + beam_size_,
|
||||||
state->prefixes.end(),
|
prefixes_.end(),
|
||||||
prefix_compare);
|
prefix_compare);
|
||||||
for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
|
for (size_t i = beam_size_; i < prefixes_.size(); ++i) {
|
||||||
state->prefixes[i]->remove();
|
prefixes_[i]->remove();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the elements from std::vector
|
// Remove the elements from std::vector
|
||||||
state->prefixes.resize(beam_size);
|
prefixes_.resize(beam_size_);
|
||||||
}
|
}
|
||||||
} // end of loop over time
|
} // end of loop over time
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Output>
|
std::vector<Output>
|
||||||
decoder_decode(DecoderState *state,
|
DecoderState::decode() const
|
||||||
const Alphabet &alphabet,
|
|
||||||
size_t beam_size,
|
|
||||||
Scorer* ext_scorer)
|
|
||||||
{
|
{
|
||||||
std::vector<PathTrie*> prefixes_copy = state->prefixes;
|
std::vector<PathTrie*> prefixes_copy = prefixes_;
|
||||||
std::unordered_map<const PathTrie*, float> scores;
|
std::unordered_map<const PathTrie*, float> scores;
|
||||||
for (PathTrie* prefix : prefixes_copy) {
|
for (PathTrie* prefix : prefixes_copy) {
|
||||||
scores[prefix] = prefix->score;
|
scores[prefix] = prefix->score;
|
||||||
}
|
}
|
||||||
|
|
||||||
// score the last word of each prefix that doesn't end with space
|
// score the last word of each prefix that doesn't end with space
|
||||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
if (ext_scorer_ != nullptr && !ext_scorer_->is_character_based()) {
|
||||||
for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) {
|
for (size_t i = 0; i < beam_size_ && i < prefixes_copy.size(); ++i) {
|
||||||
auto prefix = prefixes_copy[i];
|
auto prefix = prefixes_copy[i];
|
||||||
if (!prefix->is_empty() && prefix->character != state->space_id) {
|
if (!prefix->is_empty() && prefix->character != space_id_) {
|
||||||
float score = 0.0;
|
float score = 0.0;
|
||||||
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
|
std::vector<std::string> ngram = ext_scorer_->make_ngram(prefix);
|
||||||
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
score = ext_scorer_->get_log_cond_prob(ngram) * ext_scorer_->alpha;
|
||||||
score += ext_scorer->beta;
|
score += ext_scorer_->beta;
|
||||||
scores[prefix] += score;
|
scores[prefix] += score;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
using namespace std::placeholders;
|
using namespace std::placeholders;
|
||||||
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size);
|
size_t num_prefixes = std::min(prefixes_copy.size(), beam_size_);
|
||||||
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
|
std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, std::bind(prefix_compare_external, _1, _2, scores));
|
||||||
|
|
||||||
//TODO: expose this as an API parameter
|
//TODO: expose this as an API parameter
|
||||||
@ -194,16 +185,16 @@ decoder_decode(DecoderState *state,
|
|||||||
// return order of decoding result. To delete when decoder gets stable.
|
// return order of decoding result. To delete when decoder gets stable.
|
||||||
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
|
for (size_t i = 0; i < top_paths && i < prefixes_copy.size(); ++i) {
|
||||||
double approx_ctc = scores[prefixes_copy[i]];
|
double approx_ctc = scores[prefixes_copy[i]];
|
||||||
if (ext_scorer != nullptr) {
|
if (ext_scorer_ != nullptr) {
|
||||||
std::vector<int> output;
|
std::vector<int> output;
|
||||||
std::vector<int> timesteps;
|
std::vector<int> timesteps;
|
||||||
prefixes_copy[i]->get_path_vec(output, timesteps);
|
prefixes_copy[i]->get_path_vec(output, timesteps);
|
||||||
auto prefix_length = output.size();
|
auto prefix_length = output.size();
|
||||||
auto words = ext_scorer->split_labels(output);
|
auto words = ext_scorer_->split_labels(output);
|
||||||
// remove word insert
|
// remove word insert
|
||||||
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
|
approx_ctc = approx_ctc - prefix_length * ext_scorer_->beta;
|
||||||
// remove language model weight:
|
// remove language model weight:
|
||||||
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
|
approx_ctc -= (ext_scorer_->get_sent_log_prob(words)) * ext_scorer_->alpha;
|
||||||
}
|
}
|
||||||
prefixes_copy[i]->approx_ctc = approx_ctc;
|
prefixes_copy[i]->approx_ctc = approx_ctc;
|
||||||
}
|
}
|
||||||
@ -221,13 +212,10 @@ std::vector<Output> ctc_beam_search_decoder(
|
|||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
Scorer *ext_scorer)
|
Scorer *ext_scorer)
|
||||||
{
|
{
|
||||||
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
|
DecoderState state;
|
||||||
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
|
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
|
||||||
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
|
state.next(probs, time_dim, class_dim);
|
||||||
|
return state.decode();
|
||||||
delete state;
|
|
||||||
|
|
||||||
return out;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<Output>>
|
std::vector<std::vector<Output>>
|
||||||
|
@ -7,66 +7,67 @@
|
|||||||
#include "scorer.h"
|
#include "scorer.h"
|
||||||
#include "output.h"
|
#include "output.h"
|
||||||
#include "alphabet.h"
|
#include "alphabet.h"
|
||||||
#include "decoderstate.h"
|
|
||||||
|
|
||||||
/* Initialize CTC beam search decoder
|
class DecoderState {
|
||||||
|
int abs_time_step_;
|
||||||
|
int space_id_;
|
||||||
|
int blank_id_;
|
||||||
|
size_t beam_size_;
|
||||||
|
double cutoff_prob_;
|
||||||
|
size_t cutoff_top_n_;
|
||||||
|
|
||||||
* Parameters:
|
Scorer* ext_scorer_; // weak
|
||||||
* alphabet: The alphabet.
|
std::vector<PathTrie*> prefixes_;
|
||||||
* class_dim: Alphabet length (plus 1 for space character).
|
std::unique_ptr<PathTrie> prefix_root_;
|
||||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
|
||||||
* n-gram language model scoring and word insertion term.
|
|
||||||
* Default null, decoding the input sample without scorer.
|
|
||||||
* Return:
|
|
||||||
* A struct containing prefixes and state variables.
|
|
||||||
*/
|
|
||||||
DecoderState* decoder_init(const Alphabet &alphabet,
|
|
||||||
int class_dim,
|
|
||||||
Scorer *ext_scorer);
|
|
||||||
|
|
||||||
/* Send data to the decoder
|
public:
|
||||||
|
DecoderState() = default;
|
||||||
|
~DecoderState() = default;
|
||||||
|
|
||||||
* Parameters:
|
// Disallow copying
|
||||||
* probs: 2-D vector where each element is a vector of probabilities
|
DecoderState(const DecoderState&) = delete;
|
||||||
* over alphabet of one time step.
|
DecoderState& operator=(DecoderState&) = delete;
|
||||||
* alphabet: The alphabet.
|
|
||||||
* state: The state structure previously obtained from decoder_init().
|
|
||||||
* time_dim: Number of timesteps.
|
|
||||||
* class_dim: Alphabet length (plus 1 for space character).
|
|
||||||
* cutoff_prob: Cutoff probability for pruning.
|
|
||||||
* cutoff_top_n: Cutoff number for pruning.
|
|
||||||
* beam_size: The width of beam search.
|
|
||||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
|
||||||
* n-gram language model scoring and word insertion term.
|
|
||||||
* Default null, decoding the input sample without scorer.
|
|
||||||
*/
|
|
||||||
void decoder_next(const double *probs,
|
|
||||||
const Alphabet &alphabet,
|
|
||||||
DecoderState *state,
|
|
||||||
int time_dim,
|
|
||||||
int class_dim,
|
|
||||||
double cutoff_prob,
|
|
||||||
size_t cutoff_top_n,
|
|
||||||
size_t beam_size,
|
|
||||||
Scorer *ext_scorer);
|
|
||||||
|
|
||||||
/* Get transcription for the data you sent via decoder_next()
|
/* Initialize CTC beam search decoder
|
||||||
|
*
|
||||||
|
* Parameters:
|
||||||
|
* alphabet: The alphabet.
|
||||||
|
* beam_size: The width of beam search.
|
||||||
|
* cutoff_prob: Cutoff probability for pruning.
|
||||||
|
* cutoff_top_n: Cutoff number for pruning.
|
||||||
|
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||||
|
* n-gram language model scoring and word insertion term.
|
||||||
|
* Default null, decoding the input sample without scorer.
|
||||||
|
* Return:
|
||||||
|
* Zero on success, non-zero on failure.
|
||||||
|
*/
|
||||||
|
int init(const Alphabet& alphabet,
|
||||||
|
size_t beam_size,
|
||||||
|
double cutoff_prob,
|
||||||
|
size_t cutoff_top_n,
|
||||||
|
Scorer *ext_scorer);
|
||||||
|
|
||||||
|
/* Send data to the decoder
|
||||||
|
*
|
||||||
|
* Parameters:
|
||||||
|
* probs: 2-D vector where each element is a vector of probabilities
|
||||||
|
* over alphabet of one time step.
|
||||||
|
* time_dim: Number of timesteps.
|
||||||
|
* class_dim: Number of classes (alphabet length + 1 for space character).
|
||||||
|
*/
|
||||||
|
void next(const double *probs,
|
||||||
|
int time_dim,
|
||||||
|
int class_dim);
|
||||||
|
|
||||||
|
/* Get transcription from current decoder state
|
||||||
|
*
|
||||||
|
* Return:
|
||||||
|
* A vector where each element is a pair of score and decoding result,
|
||||||
|
* in descending order.
|
||||||
|
*/
|
||||||
|
std::vector<Output> decode() const;
|
||||||
|
};
|
||||||
|
|
||||||
* Parameters:
|
|
||||||
* state: The state structure previously obtained from decoder_init().
|
|
||||||
* alphabet: The alphabet.
|
|
||||||
* beam_size: The width of beam search.
|
|
||||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
|
||||||
* n-gram language model scoring and word insertion term.
|
|
||||||
* Default null, decoding the input sample without scorer.
|
|
||||||
* Return:
|
|
||||||
* A vector where each element is a pair of score and decoding result,
|
|
||||||
* in descending order.
|
|
||||||
*/
|
|
||||||
std::vector<Output> decoder_decode(DecoderState *state,
|
|
||||||
const Alphabet &alphabet,
|
|
||||||
size_t beam_size,
|
|
||||||
Scorer* ext_scorer);
|
|
||||||
|
|
||||||
/* CTC Beam Search Decoder
|
/* CTC Beam Search Decoder
|
||||||
* Parameters:
|
* Parameters:
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
#ifndef DECODERSTATE_H_
|
|
||||||
#define DECODERSTATE_H_
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
|
|
||||||
|
|
||||||
struct DecoderState {
|
|
||||||
int time_step;
|
|
||||||
int space_id;
|
|
||||||
int blank_id;
|
|
||||||
std::vector<PathTrie*> prefixes;
|
|
||||||
PathTrie *prefix_root;
|
|
||||||
|
|
||||||
~DecoderState() {
|
|
||||||
if (prefix_root != nullptr) {
|
|
||||||
delete prefix_root;
|
|
||||||
}
|
|
||||||
prefix_root = nullptr;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#endif // DECODERSTATE_H_
|
|
@ -29,19 +29,38 @@ using namespace lm::ngram;
|
|||||||
static const int32_t MAGIC = 'TRIE';
|
static const int32_t MAGIC = 'TRIE';
|
||||||
static const int32_t FILE_VERSION = 4;
|
static const int32_t FILE_VERSION = 4;
|
||||||
|
|
||||||
Scorer::Scorer(double alpha,
|
int
|
||||||
double beta,
|
Scorer::init(double alpha,
|
||||||
const std::string& lm_path,
|
double beta,
|
||||||
const std::string& trie_path,
|
const std::string& lm_path,
|
||||||
const Alphabet& alphabet)
|
const std::string& trie_path,
|
||||||
: dictionary()
|
const Alphabet& alphabet)
|
||||||
, language_model_()
|
|
||||||
, is_character_based_(true)
|
|
||||||
, max_order_(0)
|
|
||||||
, alphabet_(alphabet)
|
|
||||||
{
|
{
|
||||||
reset_params(alpha, beta);
|
reset_params(alpha, beta);
|
||||||
|
alphabet_ = alphabet;
|
||||||
|
setup(lm_path, trie_path);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
Scorer::init(double alpha,
|
||||||
|
double beta,
|
||||||
|
const std::string& lm_path,
|
||||||
|
const std::string& trie_path,
|
||||||
|
const std::string& alphabet_config_path)
|
||||||
|
{
|
||||||
|
reset_params(alpha, beta);
|
||||||
|
int err = alphabet_.init(alphabet_config_path.c_str());
|
||||||
|
if (err != 0) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
setup(lm_path, trie_path);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
||||||
|
{
|
||||||
|
// (Re-)Initialize character map
|
||||||
char_map_.clear();
|
char_map_.clear();
|
||||||
|
|
||||||
SPACE_ID_ = alphabet_.GetSpaceLabel();
|
SPACE_ID_ = alphabet_.GetSpaceLabel();
|
||||||
@ -53,24 +72,6 @@ Scorer::Scorer(double alpha,
|
|||||||
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
|
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
setup(lm_path, trie_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
Scorer::Scorer(double alpha,
|
|
||||||
double beta,
|
|
||||||
const std::string& lm_path,
|
|
||||||
const std::string& trie_path,
|
|
||||||
const std::string& alphabet_config_path)
|
|
||||||
: Scorer(alpha, beta, lm_path, trie_path, Alphabet(alphabet_config_path.c_str()))
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
Scorer::~Scorer()
|
|
||||||
{
|
|
||||||
}
|
|
||||||
|
|
||||||
void Scorer::setup(const std::string& lm_path, const std::string& trie_path)
|
|
||||||
{
|
|
||||||
// load language model
|
// load language model
|
||||||
const char* filename = lm_path.c_str();
|
const char* filename = lm_path.c_str();
|
||||||
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
VALID_CHECK_EQ(access(filename, R_OK), 0, "Invalid language model path");
|
||||||
|
@ -43,17 +43,24 @@ class Scorer {
|
|||||||
using FstType = PathTrie::FstType;
|
using FstType = PathTrie::FstType;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Scorer(double alpha,
|
Scorer() = default;
|
||||||
double beta,
|
~Scorer() = default;
|
||||||
const std::string &lm_path,
|
|
||||||
const std::string &trie_path,
|
// disallow copying
|
||||||
const Alphabet &alphabet);
|
Scorer(const Scorer&) = delete;
|
||||||
Scorer(double alpha,
|
Scorer& operator=(const Scorer&) = delete;
|
||||||
double beta,
|
|
||||||
const std::string &lm_path,
|
int init(double alpha,
|
||||||
const std::string &trie_path,
|
double beta,
|
||||||
const std::string &alphabet_config_path);
|
const std::string &lm_path,
|
||||||
~Scorer();
|
const std::string &trie_path,
|
||||||
|
const Alphabet &alphabet);
|
||||||
|
|
||||||
|
int init(double alpha,
|
||||||
|
double beta,
|
||||||
|
const std::string &lm_path,
|
||||||
|
const std::string &trie_path,
|
||||||
|
const std::string &alphabet_config_path);
|
||||||
|
|
||||||
double get_log_cond_prob(const std::vector<std::string> &words);
|
double get_log_cond_prob(const std::vector<std::string> &words);
|
||||||
|
|
||||||
@ -79,9 +86,9 @@ public:
|
|||||||
void save_dictionary(const std::string &path);
|
void save_dictionary(const std::string &path);
|
||||||
|
|
||||||
// language model weight
|
// language model weight
|
||||||
double alpha;
|
double alpha = 0.;
|
||||||
// word insertion weight
|
// word insertion weight
|
||||||
double beta;
|
double beta = 0.;
|
||||||
|
|
||||||
// pointer to the dictionary of FST
|
// pointer to the dictionary of FST
|
||||||
std::unique_ptr<FstType> dictionary;
|
std::unique_ptr<FstType> dictionary;
|
||||||
@ -100,8 +107,8 @@ protected:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<lm::base::Model> language_model_;
|
std::unique_ptr<lm::base::Model> language_model_;
|
||||||
bool is_character_based_;
|
bool is_character_based_ = true;
|
||||||
size_t max_order_;
|
size_t max_order_ = 0;
|
||||||
|
|
||||||
int SPACE_ID_;
|
int SPACE_ID_;
|
||||||
Alphabet alphabet_;
|
Alphabet alphabet_;
|
||||||
|
@ -19,42 +19,44 @@ import_array();
|
|||||||
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
|
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_dim, int time_dim, int class_dim)};
|
||||||
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
|
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
|
||||||
|
|
||||||
// Convert char* to Alphabet
|
// Add overloads converting char* to Alphabet
|
||||||
%rename (ctc_beam_search_decoder) mod_decoder;
|
|
||||||
%inline %{
|
%inline %{
|
||||||
std::vector<Output>
|
std::vector<Output>
|
||||||
mod_decoder(const double *probs,
|
ctc_beam_search_decoder(const double *probs,
|
||||||
int time_dim,
|
int time_dim,
|
||||||
int class_dim,
|
int class_dim,
|
||||||
char* alphabet_config_path,
|
char* alphabet_config_path,
|
||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
Scorer *ext_scorer)
|
Scorer *ext_scorer)
|
||||||
{
|
{
|
||||||
Alphabet a(alphabet_config_path);
|
Alphabet a;
|
||||||
|
if (a.init(alphabet_config_path)) {
|
||||||
|
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
|
||||||
|
}
|
||||||
return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
|
return ctc_beam_search_decoder(probs, time_dim, class_dim, a, beam_size,
|
||||||
cutoff_prob, cutoff_top_n, ext_scorer);
|
cutoff_prob, cutoff_top_n, ext_scorer);
|
||||||
}
|
}
|
||||||
%}
|
|
||||||
|
|
||||||
%rename (ctc_beam_search_decoder_batch) mod_decoder_batch;
|
|
||||||
%inline %{
|
|
||||||
std::vector<std::vector<Output>>
|
std::vector<std::vector<Output>>
|
||||||
mod_decoder_batch(const double *probs,
|
ctc_beam_search_decoder_batch(const double *probs,
|
||||||
int batch_dim,
|
int batch_dim,
|
||||||
int time_dim,
|
int time_dim,
|
||||||
int class_dim,
|
int class_dim,
|
||||||
const int *seq_lengths,
|
const int *seq_lengths,
|
||||||
int seq_lengths_size,
|
int seq_lengths_size,
|
||||||
char* alphabet_config_path,
|
char* alphabet_config_path,
|
||||||
size_t beam_size,
|
size_t beam_size,
|
||||||
size_t num_processes,
|
size_t num_processes,
|
||||||
double cutoff_prob,
|
double cutoff_prob,
|
||||||
size_t cutoff_top_n,
|
size_t cutoff_top_n,
|
||||||
Scorer *ext_scorer)
|
Scorer *ext_scorer)
|
||||||
{
|
{
|
||||||
Alphabet a(alphabet_config_path);
|
Alphabet a;
|
||||||
|
if (a.init(alphabet_config_path)) {
|
||||||
|
std::cerr << "Error initializing alphabet from file: \"" << alphabet_config_path << "\"\n";
|
||||||
|
}
|
||||||
return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
|
return ctc_beam_search_decoder_batch(probs, batch_dim, time_dim, class_dim,
|
||||||
seq_lengths, seq_lengths_size, a, beam_size,
|
seq_lengths, seq_lengths_size, a, beam_size,
|
||||||
num_processes, cutoff_prob, cutoff_top_n,
|
num_processes, cutoff_prob, cutoff_top_n,
|
||||||
|
@ -71,7 +71,7 @@ struct StreamingState {
|
|||||||
vector<float> previous_state_h_;
|
vector<float> previous_state_h_;
|
||||||
|
|
||||||
ModelState* model_;
|
ModelState* model_;
|
||||||
std::unique_ptr<DecoderState> decoder_state_;
|
DecoderState decoder_state_;
|
||||||
|
|
||||||
StreamingState();
|
StreamingState();
|
||||||
~StreamingState();
|
~StreamingState();
|
||||||
@ -133,21 +133,21 @@ StreamingState::feedAudioContent(const short* buffer,
|
|||||||
char*
|
char*
|
||||||
StreamingState::intermediateDecode()
|
StreamingState::intermediateDecode()
|
||||||
{
|
{
|
||||||
return model_->decode(decoder_state_.get());
|
return model_->decode(decoder_state_);
|
||||||
}
|
}
|
||||||
|
|
||||||
char*
|
char*
|
||||||
StreamingState::finishStream()
|
StreamingState::finishStream()
|
||||||
{
|
{
|
||||||
finalizeStream();
|
finalizeStream();
|
||||||
return model_->decode(decoder_state_.get());
|
return model_->decode(decoder_state_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Metadata*
|
Metadata*
|
||||||
StreamingState::finishStreamWithMetadata()
|
StreamingState::finishStreamWithMetadata()
|
||||||
{
|
{
|
||||||
finalizeStream();
|
finalizeStream();
|
||||||
return model_->decode_metadata(decoder_state_.get());
|
return model_->decode_metadata(decoder_state_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
@ -244,23 +244,15 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
|
|||||||
previous_state_c_,
|
previous_state_c_,
|
||||||
previous_state_h_);
|
previous_state_h_);
|
||||||
|
|
||||||
const int cutoff_top_n = 40;
|
const size_t num_classes = model_->alphabet_.GetSize() + 1; // +1 for blank
|
||||||
const double cutoff_prob = 1.0;
|
|
||||||
const size_t num_classes = model_->alphabet_->GetSize() + 1; // +1 for blank
|
|
||||||
const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);
|
const int n_frames = logits.size() / (ModelState::BATCH_SIZE * num_classes);
|
||||||
|
|
||||||
// Convert logits to double
|
// Convert logits to double
|
||||||
vector<double> inputs(logits.begin(), logits.end());
|
vector<double> inputs(logits.begin(), logits.end());
|
||||||
|
|
||||||
decoder_next(inputs.data(),
|
decoder_state_.next(inputs.data(),
|
||||||
*model_->alphabet_,
|
n_frames,
|
||||||
decoder_state_.get(),
|
num_classes);
|
||||||
n_frames,
|
|
||||||
num_classes,
|
|
||||||
cutoff_prob,
|
|
||||||
cutoff_top_n,
|
|
||||||
model_->beam_width_,
|
|
||||||
model_->scorer_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
@ -316,15 +308,15 @@ DS_EnableDecoderWithLM(ModelState* aCtx,
|
|||||||
float aLMAlpha,
|
float aLMAlpha,
|
||||||
float aLMBeta)
|
float aLMBeta)
|
||||||
{
|
{
|
||||||
try {
|
aCtx->scorer_.reset(new Scorer());
|
||||||
aCtx->scorer_ = new Scorer(aLMAlpha, aLMBeta,
|
int err = aCtx->scorer_->init(aLMAlpha, aLMBeta,
|
||||||
aLMPath ? aLMPath : "",
|
aLMPath ? aLMPath : "",
|
||||||
aTriePath ? aTriePath : "",
|
aTriePath ? aTriePath : "",
|
||||||
*aCtx->alphabet_);
|
aCtx->alphabet_);
|
||||||
return DS_ERR_OK;
|
if (err != 0) {
|
||||||
} catch (...) {
|
|
||||||
return DS_ERR_INVALID_LM;
|
return DS_ERR_INVALID_LM;
|
||||||
}
|
}
|
||||||
|
return DS_ERR_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
@ -340,8 +332,6 @@ DS_SetupStream(ModelState* aCtx,
|
|||||||
return DS_ERR_FAIL_CREATE_STREAM;
|
return DS_ERR_FAIL_CREATE_STREAM;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t num_classes = aCtx->alphabet_->GetSize() + 1; // +1 for blank
|
|
||||||
|
|
||||||
ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
|
ctx->audio_buffer_.reserve(aCtx->audio_win_len_);
|
||||||
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
|
ctx->mfcc_buffer_.reserve(aCtx->mfcc_feats_per_timestep_);
|
||||||
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
|
ctx->mfcc_buffer_.resize(aCtx->n_features_*aCtx->n_context_, 0.f);
|
||||||
@ -350,7 +340,14 @@ DS_SetupStream(ModelState* aCtx,
|
|||||||
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
|
ctx->previous_state_h_.resize(aCtx->state_size_, 0.f);
|
||||||
ctx->model_ = aCtx;
|
ctx->model_ = aCtx;
|
||||||
|
|
||||||
ctx->decoder_state_.reset(decoder_init(*aCtx->alphabet_, num_classes, aCtx->scorer_));
|
const int cutoff_top_n = 40;
|
||||||
|
const double cutoff_prob = 1.0;
|
||||||
|
|
||||||
|
ctx->decoder_state_.init(aCtx->alphabet_,
|
||||||
|
aCtx->beam_width_,
|
||||||
|
cutoff_prob,
|
||||||
|
cutoff_top_n,
|
||||||
|
aCtx->scorer_.get());
|
||||||
|
|
||||||
*retval = ctx.release();
|
*retval = ctx.release();
|
||||||
return DS_ERR_OK;
|
return DS_ERR_OK;
|
||||||
|
@ -8,8 +8,16 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
|
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* trie_path) {
|
||||||
Alphabet alphabet(alphabet_path);
|
Alphabet alphabet;
|
||||||
Scorer scorer(0.0, 0.0, kenlm_path, "", alphabet);
|
int err = alphabet.init(alphabet_path);
|
||||||
|
if (err != 0) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
Scorer scorer;
|
||||||
|
err = scorer.init(0.0, 0.0, kenlm_path, "", alphabet);
|
||||||
|
if (err != 0) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
scorer.save_dictionary(trie_path);
|
scorer.save_dictionary(trie_path);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -7,9 +7,7 @@
|
|||||||
using std::vector;
|
using std::vector;
|
||||||
|
|
||||||
ModelState::ModelState()
|
ModelState::ModelState()
|
||||||
: alphabet_(nullptr)
|
: beam_width_(-1)
|
||||||
, scorer_(nullptr)
|
|
||||||
, beam_width_(-1)
|
|
||||||
, n_steps_(-1)
|
, n_steps_(-1)
|
||||||
, n_context_(-1)
|
, n_context_(-1)
|
||||||
, n_features_(-1)
|
, n_features_(-1)
|
||||||
@ -23,8 +21,6 @@ ModelState::ModelState()
|
|||||||
|
|
||||||
ModelState::~ModelState()
|
ModelState::~ModelState()
|
||||||
{
|
{
|
||||||
delete scorer_;
|
|
||||||
delete alphabet_;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int
|
int
|
||||||
@ -36,29 +32,24 @@ ModelState::init(const char* model_path,
|
|||||||
{
|
{
|
||||||
n_features_ = n_features;
|
n_features_ = n_features;
|
||||||
n_context_ = n_context;
|
n_context_ = n_context;
|
||||||
alphabet_ = new Alphabet(alphabet_path);
|
if (alphabet_.init(alphabet_path)) {
|
||||||
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
|
}
|
||||||
beam_width_ = beam_width;
|
beam_width_ = beam_width;
|
||||||
return DS_ERR_OK;
|
return DS_ERR_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<Output>
|
|
||||||
ModelState::decode_raw(DecoderState* state)
|
|
||||||
{
|
|
||||||
vector<Output> out = decoder_decode(state, *alphabet_, beam_width_, scorer_);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
char*
|
char*
|
||||||
ModelState::decode(DecoderState* state)
|
ModelState::decode(const DecoderState& state)
|
||||||
{
|
{
|
||||||
vector<Output> out = decode_raw(state);
|
vector<Output> out = state.decode();
|
||||||
return strdup(alphabet_->LabelsToString(out[0].tokens).c_str());
|
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
Metadata*
|
Metadata*
|
||||||
ModelState::decode_metadata(DecoderState* state)
|
ModelState::decode_metadata(const DecoderState& state)
|
||||||
{
|
{
|
||||||
vector<Output> out = decode_raw(state);
|
vector<Output> out = state.decode();
|
||||||
|
|
||||||
std::unique_ptr<Metadata> metadata(new Metadata());
|
std::unique_ptr<Metadata> metadata(new Metadata());
|
||||||
metadata->num_items = out[0].tokens.size();
|
metadata->num_items = out[0].tokens.size();
|
||||||
@ -68,7 +59,7 @@ ModelState::decode_metadata(DecoderState* state)
|
|||||||
|
|
||||||
// Loop through each character
|
// Loop through each character
|
||||||
for (int i = 0; i < out[0].tokens.size(); ++i) {
|
for (int i = 0; i < out[0].tokens.size(); ++i) {
|
||||||
items[i].character = strdup(alphabet_->StringFromLabel(out[0].tokens[i]).c_str());
|
items[i].character = strdup(alphabet_.StringFromLabel(out[0].tokens[i]).c_str());
|
||||||
items[i].timestep = out[0].timesteps[i];
|
items[i].timestep = out[0].timesteps[i];
|
||||||
items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
|
items[i].start_time = out[0].timesteps[i] * ((float)audio_win_step_ / sample_rate_);
|
||||||
|
|
||||||
|
@ -8,7 +8,8 @@
|
|||||||
|
|
||||||
#include "ctcdecode/scorer.h"
|
#include "ctcdecode/scorer.h"
|
||||||
#include "ctcdecode/output.h"
|
#include "ctcdecode/output.h"
|
||||||
#include "ctcdecode/decoderstate.h"
|
|
||||||
|
class DecoderState;
|
||||||
|
|
||||||
struct ModelState {
|
struct ModelState {
|
||||||
//TODO: infer batch size from model/use dynamic batch size
|
//TODO: infer batch size from model/use dynamic batch size
|
||||||
@ -18,8 +19,8 @@ struct ModelState {
|
|||||||
static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
|
static constexpr unsigned int DEFAULT_WINDOW_LENGTH = DEFAULT_SAMPLE_RATE * 0.032;
|
||||||
static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
|
static constexpr unsigned int DEFAULT_WINDOW_STEP = DEFAULT_SAMPLE_RATE * 0.02;
|
||||||
|
|
||||||
Alphabet* alphabet_;
|
Alphabet alphabet_;
|
||||||
Scorer* scorer_;
|
std::unique_ptr<Scorer> scorer_;
|
||||||
unsigned int beam_width_;
|
unsigned int beam_width_;
|
||||||
unsigned int n_steps_;
|
unsigned int n_steps_;
|
||||||
unsigned int n_context_;
|
unsigned int n_context_;
|
||||||
@ -59,16 +60,6 @@ struct ModelState {
|
|||||||
std::vector<float>& state_c_output,
|
std::vector<float>& state_c_output,
|
||||||
std::vector<float>& state_h_output) = 0;
|
std::vector<float>& state_h_output) = 0;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
|
||||||
* CTC decoder with KenLM enabled
|
|
||||||
*
|
|
||||||
* @param state Decoder state to use when decoding.
|
|
||||||
*
|
|
||||||
* @return Vector of Output structs directly from the CTC decoder for additional processing.
|
|
||||||
*/
|
|
||||||
virtual std::vector<Output> decode_raw(DecoderState* state);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Perform decoding of the logits, using basic CTC decoder or
|
* @brief Perform decoding of the logits, using basic CTC decoder or
|
||||||
* CTC decoder with KenLM enabled
|
* CTC decoder with KenLM enabled
|
||||||
@ -77,7 +68,7 @@ struct ModelState {
|
|||||||
*
|
*
|
||||||
* @return String representing the decoded text.
|
* @return String representing the decoded text.
|
||||||
*/
|
*/
|
||||||
virtual char* decode(DecoderState* state);
|
virtual char* decode(const DecoderState& state);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return character-level metadata including letter timings.
|
* @brief Return character-level metadata including letter timings.
|
||||||
@ -87,7 +78,7 @@ struct ModelState {
|
|||||||
* @return Metadata struct containing MetadataItem structs for each character.
|
* @return Metadata struct containing MetadataItem structs for each character.
|
||||||
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
|
* The user is responsible for freeing Metadata by calling DS_FreeMetadata().
|
||||||
*/
|
*/
|
||||||
virtual Metadata* decode_metadata(DecoderState* state);
|
virtual Metadata* decode_metadata(const DecoderState& state);
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // MODELSTATE_H
|
#endif // MODELSTATE_H
|
||||||
|
@ -151,9 +151,9 @@ TFLiteModelState::init(const char* model_path,
|
|||||||
|
|
||||||
TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
|
TfLiteIntArray* dims_logits = interpreter_->tensor(logits_idx_)->dims;
|
||||||
const int final_dim_size = dims_logits->data[1] - 1;
|
const int final_dim_size = dims_logits->data[1] - 1;
|
||||||
if (final_dim_size != alphabet_->GetSize()) {
|
if (final_dim_size != alphabet_.GetSize()) {
|
||||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||||
<< "has size " << alphabet_->GetSize()
|
<< "has size " << alphabet_.GetSize()
|
||||||
<< ", but model has " << final_dim_size
|
<< ", but model has " << final_dim_size
|
||||||
<< " classes in its output. Make sure you're passing an alphabet "
|
<< " classes in its output. Make sure you're passing an alphabet "
|
||||||
<< "file with the same size as the one used for training."
|
<< "file with the same size as the one used for training."
|
||||||
@ -208,7 +208,7 @@ TFLiteModelState::infer(const vector<float>& mfcc,
|
|||||||
vector<float>& state_c_output,
|
vector<float>& state_c_output,
|
||||||
vector<float>& state_h_output)
|
vector<float>& state_h_output)
|
||||||
{
|
{
|
||||||
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
|
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
|
||||||
|
|
||||||
// Feeding input_node
|
// Feeding input_node
|
||||||
copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);
|
copy_vector_to_tensor(mfcc, input_node_idx_, n_frames*mfcc_feats_per_timestep_);
|
||||||
|
@ -108,9 +108,9 @@ TFModelState::init(const char* model_path,
|
|||||||
}
|
}
|
||||||
|
|
||||||
int final_dim_size = logits_shape.vec<int>()(2) - 1;
|
int final_dim_size = logits_shape.vec<int>()(2) - 1;
|
||||||
if (final_dim_size != alphabet_->GetSize()) {
|
if (final_dim_size != alphabet_.GetSize()) {
|
||||||
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
std::cerr << "Error: Alphabet size does not match loaded model: alphabet "
|
||||||
<< "has size " << alphabet_->GetSize()
|
<< "has size " << alphabet_.GetSize()
|
||||||
<< ", but model has " << final_dim_size
|
<< ", but model has " << final_dim_size
|
||||||
<< " classes in its output. Make sure you're passing an alphabet "
|
<< " classes in its output. Make sure you're passing an alphabet "
|
||||||
<< "file with the same size as the one used for training."
|
<< "file with the same size as the one used for training."
|
||||||
@ -173,7 +173,7 @@ TFModelState::infer(const std::vector<float>& mfcc,
|
|||||||
vector<float>& state_c_output,
|
vector<float>& state_c_output,
|
||||||
vector<float>& state_h_output)
|
vector<float>& state_h_output)
|
||||||
{
|
{
|
||||||
const size_t num_classes = alphabet_->GetSize() + 1; // +1 for blank
|
const size_t num_classes = alphabet_.GetSize() + 1; // +1 for blank
|
||||||
|
|
||||||
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
|
Tensor input = tensor_from_vector(mfcc, TensorShape({BATCH_SIZE, n_steps_, 2*n_context_+1, n_features_}));
|
||||||
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
Tensor previous_state_c_t = tensor_from_vector(previous_state_c, TensorShape({BATCH_SIZE, (long long)state_size_}));
|
||||||
|
@ -16,8 +16,11 @@ int main(int argc, char** argv)
|
|||||||
|
|
||||||
printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path);
|
printf("Loading trie(%s) and alphabet(%s)\n", trie_path, alphabet_path);
|
||||||
|
|
||||||
Alphabet alphabet(alphabet_path);
|
Alphabet alphabet;
|
||||||
Scorer scorer(0.0, 0.0, kenlm_path, trie_path, alphabet);
|
int err = alphabet.init(alphabet_path);
|
||||||
|
if (err != 0) {
|
||||||
return 0;
|
return err;
|
||||||
|
}
|
||||||
|
Scorer scorer;
|
||||||
|
return scorer.init(0.0, 0.0, kenlm_path, trie_path, alphabet);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user