Write a CTC beam search decoder TF op that scores beams with our LM
This commit is contained in:
parent
af71da0d4d
commit
fc91e3d7b8
@ -34,3 +34,35 @@ cc_library(
|
||||
copts = [] + if_linux_x86_64(["-mno-fma", "-mno-avx", "-mno-avx2"]),
|
||||
nocopts = "(-fstack-protector|-fno-omit-frame-pointer)",
|
||||
)
|
||||
|
||||
|
||||
# Library containing the CTC beam search decoder op (beam_search.cc) compiled
# together with the vendored KenLM sources it scores beams with.
cc_library(
    name = "ctc_decoder_with_kenlm",
    srcs = [
        "beam_search.cc",
        "alphabet.h",
        "trie_node.h",
    ] + glob(
        [
            "kenlm/lm/*.cc",
            "kenlm/util/*.cc",
            "kenlm/util/double-conversion/*.cc",
            "kenlm/lm/*.hh",
            "kenlm/util/*.hh",
            "kenlm/util/double-conversion/*.h",
        ],
        # Leave out KenLM's own tests and command-line entry points.
        exclude = [
            "kenlm/*/*test.cc",
            "kenlm/*/*main.cc",
        ],
    ),
    includes = ["kenlm"],
    defines = ["KENLM_MAX_ORDER=6"],
    deps = [
        "//tensorflow/core:core",
        "//tensorflow/core/util/ctc",
        "//third_party/eigen3",
    ],
)
|
||||
|
||||
# Stand-alone tool (generate_trie.cpp) built together with the vendored KenLM
# sources; it produces the trie file consumed by the decoder op.
cc_binary(
    name = "generate_trie",
    srcs = [
        "generate_trie.cpp",
        "trie_node.h",
        "alphabet.h",
    ] + glob(
        [
            "kenlm/lm/*.cc",
            "kenlm/util/*.cc",
            "kenlm/util/double-conversion/*.cc",
            "kenlm/lm/*.hh",
            "kenlm/util/*.hh",
            "kenlm/util/double-conversion/*.h",
        ],
        # Leave out KenLM's own tests and command-line entry points.
        exclude = [
            "kenlm/*/*test.cc",
            "kenlm/*/*main.cc",
        ],
    ),
    includes = ["kenlm"],
    copts = ["-std=c++11"],
    linkopts = ["-lm"],
    defines = ["KENLM_MAX_ORDER=6"],
)
|
||||
|
||||
@ -54,6 +54,11 @@ public:
|
||||
return size_;
|
||||
}
|
||||
|
||||
bool IsSpace(unsigned int label) const {
|
||||
const std::string& str = StringFromLabel(label);
|
||||
return str.size() == 1 && str[0] == ' ';
|
||||
}
|
||||
|
||||
private:
|
||||
size_t size_;
|
||||
std::unordered_map<unsigned int, std::string> label_to_str_;
|
||||
|
||||
546
native_client/beam_search.cc
Normal file
546
native_client/beam_search.cc
Normal file
@ -0,0 +1,546 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This file defines a TensorFlow op that runs CTC beam search decoding with a
// custom BeamScorer (KenLMBeamScorer) and BeamState (KenLMBeamState). Beams
// are scored with a KenLM language model and a dictionary trie.
|
||||
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
|
||||
#include "kenlm/lm/model.hh"
|
||||
|
||||
#include "alphabet.h"
|
||||
#include "trie_node.h"
|
||||
|
||||
namespace tf = tensorflow;
|
||||
using tf::shape_inference::DimensionHandle;
|
||||
using tf::shape_inference::InferenceContext;
|
||||
using tf::shape_inference::ShapeHandle;
|
||||
|
||||
// Op definition for the language-model-aware CTC beam search decoder.
//
// The shape function requires `inputs` to be rank 3 and `sequence_length` to
// be rank 1, merges their batch dimensions, and emits `top_paths` copies of
// each sparse-tensor component followed by the log-probability matrix.
//
// The documented lower bounds of beam_width and top_paths match the attr
// constraints declared below (both are `int >= 1`).
REGISTER_OP("CTCBeamSearchDecoderWithLM")
    .Input("inputs: float")
    .Input("sequence_length: int32")
    .Attr("model_path: string")
    .Attr("trie_path: string")
    .Attr("alphabet_path: string")
    .Attr("beam_width: int >= 1 = 100")
    .Attr("top_paths: int >= 1 = 1")
    .Attr("merge_repeated: bool = true")
    .Output("decoded_indices: top_paths * int64")
    .Output("decoded_values: top_paths * int64")
    .Output("decoded_shape: top_paths * int64")
    .Output("log_probability: float")
    .SetShapeFn([](InferenceContext *c) {
      ShapeHandle inputs;
      ShapeHandle sequence_length;

      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length));

      // Get batch size from inputs and sequence_length.
      DimensionHandle batch_size;
      TF_RETURN_IF_ERROR(
          c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size));

      tf::int32 top_paths;
      TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths));

      // Outputs.
      int out_idx = 0;
      for (int i = 0; i < top_paths; ++i) {  // decoded_indices
        c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2));
      }
      for (int i = 0; i < top_paths; ++i) {  // decoded_values
        c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim));
      }
      ShapeHandle shape_v = c->Vector(2);
      for (int i = 0; i < top_paths; ++i) {  // decoded_shape
        c->set_output(out_idx++, shape_v);
      }
      c->set_output(out_idx++, c->Matrix(batch_size, top_paths));
      return tf::Status::OK();
    })
    .Doc(R"doc(
Performs beam search decoding on the logits given in input.

A note about the attribute merge_repeated: For the beam search decoder,
this means that if consecutive entries in a beam are the same, only
the first of these is emitted. That is, when the top path is "A B B B B",
"A B" is returned if merge_repeated = True but "A B B B B" is
returned if merge_repeated = False.

inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits.
sequence_length: A vector containing sequence lengths, size `(batch)`.
model_path: Path to the KenLM language model file used to score beams.
trie_path: Path to the dictionary trie file used to score word prefixes.
alphabet_path: Path to the alphabet file that maps labels to strings.
beam_width: A scalar >= 1 (beam search beam width).
top_paths: A scalar >= 1, <= beam_width (controls output size).
merge_repeated: If true, merge repeated classes in output.
decoded_indices: A list (length: top_paths) of indices matrices. Matrix j,
  size `(total_decoded_outputs[j] x 2)`, has indices of a
  `SparseTensor<int64, 2>`. The rows store: [batch, time].
decoded_values: A list (length: top_paths) of values vectors. Vector j,
  size `(length total_decoded_outputs[j])`, has the values of a
  `SparseTensor<int64, 2>`. The vector stores the decoded classes for beam j.
decoded_shape: A list (length: top_paths) of shape vector. Vector j,
  size `(2)`, stores the shape of the decoded `SparseTensor[j]`.
  Its values are: `[batch_size, max_decoded_length[j]]`.
log_probability: A matrix, shaped: `(batch_size x top_paths)`. The
  sequence log-probabilities.
)doc");
|
||||
|
||||
struct KenLMBeamState {
|
||||
float language_model_score;
|
||||
float score;
|
||||
float delta_score;
|
||||
std::string incomplete_word;
|
||||
TrieNode *incomplete_word_trie_node;
|
||||
lm::ngram::ProbingModel::State model_state;
|
||||
};
|
||||
|
||||
// Beam scorer that ranks CTC beams with a KenLM n-gram language model and a
// dictionary trie.
//
// While a word is still being spelled out, a beam is scored with the minimum
// unigram score stored in the trie node reached by its prefix (or a fixed low
// score when the prefix is not in the trie). When a space label completes a
// word, the word is scored with the language model in the context of the
// preceding words.
//
// The scorer owns the model, the alphabet and the trie, and releases all
// three in its destructor. It is therefore non-copyable.
class KenLMBeamScorer : public tf::ctc::BaseBeamScorer<KenLMBeamState> {
 public:
  typedef lm::ngram::ProbingModel Model;

  // Loads the KenLM model from `kenlm_path`, the alphabet from
  // `alphabet_path` and the trie from `trie_path`.
  //
  // If the trie file cannot be opened an error is logged and the scorer runs
  // without a trie: every prefix is then treated as unknown (see ExpandState).
  KenLMBeamScorer(const std::string &kenlm_path, const std::string &trie_path,
                  const std::string &alphabet_path)
    : model(nullptr)
    , alphabet(nullptr)
    , trieRoot(nullptr)
    , lm_weight(1.0f)
    , word_count_weight(-0.1f)
    , valid_word_count_weight(1.0f)
  {
    lm::ngram::Config config;
    config.load_method = util::POPULATE_OR_READ;
    model = new Model(kenlm_path.c_str(), config);

    alphabet = new Alphabet(alphabet_path.c_str());

    // Only parse the trie when the file was actually opened; reading from a
    // failed stream would leave the parser working on indeterminate values.
    std::ifstream in(trie_path, std::ios::in);
    if (in.is_open()) {
      TrieNode::ReadFromStream(in, trieRoot, alphabet->GetSize());
    } else {
      LOG(ERROR) << "Unable to open trie file: " << trie_path;
    }
  }

  virtual ~KenLMBeamScorer() {
    delete model;
    delete alphabet;
    delete trieRoot;
  }

  // State initialization: zero scores, empty word, trie cursor at the root
  // and the language model at its begin-of-sentence state.
  void InitializeState(KenLMBeamState* root) const {
    root->language_model_score = 0.0f;
    root->score = 0.0f;
    root->delta_score = 0.0f;
    root->incomplete_word.clear();
    root->incomplete_word_trie_node = trieRoot;
    root->model_state = model->BeginSentenceState();
  }

  // ExpandState is called when expanding a beam to one of its children.
  // Called at most once per child beam.
  //
  // For a non-space label the character is appended to the incomplete word
  // and the trie cursor advances; for a space label the completed word is
  // scored with the language model and the incomplete word is reset.
  void ExpandState(const KenLMBeamState& from_state, int from_label,
                   KenLMBeamState* to_state, int to_label) const {
    CopyState(from_state, to_state);

    if (!alphabet->IsSpace(to_label)) {
      to_state->incomplete_word += alphabet->StringFromLabel(to_label);
      TrieNode *trie_node = from_state.incomplete_word_trie_node;

      // TODO replace with OOV unigram prob?
      // If we have no valid prefix we assume a very low log probability
      float min_unigram_score = -10.0f;
      // If prefix does exist
      if (trie_node != nullptr) {
        trie_node = trie_node->GetChildAt(to_label);
        to_state->incomplete_word_trie_node = trie_node;

        if (trie_node != nullptr) {
          min_unigram_score = trie_node->GetMinUnigramScore();
        }
      }
      // TODO try two options
      // 1) unigram score added up to language model score
      // 2) language model score of (preceding_words + unigram_word)
      to_state->score = min_unigram_score + to_state->language_model_score;
      to_state->delta_score = to_state->score - from_state.score;
    } else {
      float lm_score_delta = ScoreIncompleteWord(from_state.model_state,
                                                 to_state->incomplete_word,
                                                 to_state->model_state);
      // Give fixed word bonus
      if (!IsOOV(to_state->incomplete_word)) {
        to_state->language_model_score += valid_word_count_weight;
      }
      to_state->language_model_score += word_count_weight;
      UpdateWithLMScore(to_state, lm_score_delta);
      ResetIncompleteWord(to_state);
    }
  }

  // ExpandStateEnd is called after decoding has finished. Its purpose is to
  // allow a final scoring of the beam in its current state, before resorting
  // and retrieving the TopN requested candidates. Called at most once per beam.
  //
  // Any pending incomplete word is scored first, then the end-of-sentence
  // token is scored.
  void ExpandStateEnd(KenLMBeamState* state) const {
    float lm_score_delta = 0.0f;
    Model::State out;
    if (state->incomplete_word.size() > 0) {
      lm_score_delta += ScoreIncompleteWord(state->model_state,
                                            state->incomplete_word,
                                            out);
      ResetIncompleteWord(state);
      state->model_state = out;
    }
    lm_score_delta += model->FullScore(state->model_state,
                                       model->GetVocabulary().EndSentence(),
                                       out).prob;
    UpdateWithLMScore(state, lm_score_delta);
  }

  // GetStateExpansionScore should be an inexpensive method to retrieve the
  // (cached) expansion score computed within ExpandState. The score is
  // multiplied (log-addition) with the input score at the current step from
  // the network.
  //
  // The score returned should be a log-probability.
  float GetStateExpansionScore(const KenLMBeamState& state,
                               float previous_score) const {
    return lm_weight * state.delta_score + previous_score;
  }

  // GetStateEndExpansionScore should be an inexpensive method to retrieve the
  // (cached) expansion score computed within ExpandStateEnd. The score is
  // multiplied (log-addition) with the final probability of the beam.
  //
  // The score returned should be a log-probability.
  float GetStateEndExpansionScore(const KenLMBeamState& state) const {
    return lm_weight * state.delta_score;
  }

  // Sets the factor applied to the cached score delta.
  void SetLMWeight(float lm_weight) {
    this->lm_weight = lm_weight;
  }

  // Sets the score added for every completed word.
  void SetWordCountWeight(float word_count_weight) {
    this->word_count_weight = word_count_weight;
  }

  // Sets the extra score added for every completed in-vocabulary word.
  void SetValidWordCountWeight(float valid_word_count_weight) {
    this->valid_word_count_weight = valid_word_count_weight;
  }

 private:
  Model *model;        // owned
  Alphabet *alphabet;  // owned
  TrieNode *trieRoot;  // owned; nullptr when no trie could be loaded
  float lm_weight;
  float word_count_weight;
  float valid_word_count_weight;

  // Adds `lm_score_delta` to the language model score, makes it the beam
  // score and records the difference to the previous beam score.
  void UpdateWithLMScore(KenLMBeamState *state, float lm_score_delta) const {
    float previous_score = state->score;
    state->language_model_score += lm_score_delta;
    state->score = state->language_model_score;
    state->delta_score = state->language_model_score - previous_score;
  }

  // Clears the incomplete word and moves the trie cursor back to the root.
  void ResetIncompleteWord(KenLMBeamState *state) const {
    state->incomplete_word.clear();
    state->incomplete_word_trie_node = trieRoot;
  }

  // True when `word` is not in the language model vocabulary.
  bool IsOOV(const std::string& word) const {
    auto &vocabulary = model->GetVocabulary();
    return vocabulary.Index(word) == vocabulary.NotFound();
  }

  // Scores `word` given `model_state`; the resulting state is written to
  // `out`. Returns the `prob` field of KenLM's FullScore result.
  float ScoreIncompleteWord(const Model::State& model_state,
                            const std::string& word,
                            Model::State& out) const {
    lm::FullScoreReturn full_score_return;
    lm::WordIndex vocab = model->GetVocabulary().Index(word);
    full_score_return = model->FullScore(model_state, vocab, out);
    return full_score_return.prob;
  }

  // Field-by-field copy of a beam state.
  void CopyState(const KenLMBeamState& from, KenLMBeamState* to) const {
    to->language_model_score = from.language_model_score;
    to->score = from.score;
    to->delta_score = from.delta_score;
    to->incomplete_word = from.incomplete_word;
    to->incomplete_word_trie_node = from.incomplete_word_trie_node;
    to->model_state = from.model_state;
  }

  // The scorer owns raw pointers, so copying would lead to double deletion.
  TF_DISALLOW_COPY_AND_ASSIGN(KenLMBeamScorer);
};
|
||||
|
||||
class CTCDecodeHelper {
|
||||
public:
|
||||
CTCDecodeHelper() : top_paths_(1) {}
|
||||
|
||||
inline int GetTopPaths() const { return top_paths_; }
|
||||
void SetTopPaths(int tp) { top_paths_ = tp; }
|
||||
|
||||
tf::Status ValidateInputsGenerateOutputs(
|
||||
tf::OpKernelContext *ctx, const tf::Tensor **inputs, const tf::Tensor **seq_len,
|
||||
std::string *model_path, std::string *trie_path, std::string *alphabet_path,
|
||||
tf::Tensor **log_prob, tf::OpOutputList *decoded_indices,
|
||||
tf::OpOutputList *decoded_values, tf::OpOutputList *decoded_shape) const {
|
||||
tf::Status status = ctx->input("inputs", inputs);
|
||||
if (!status.ok()) return status;
|
||||
status = ctx->input("sequence_length", seq_len);
|
||||
if (!status.ok()) return status;
|
||||
|
||||
const tf::TensorShape &inputs_shape = (*inputs)->shape();
|
||||
|
||||
if (inputs_shape.dims() != 3) {
|
||||
return tf::errors::InvalidArgument("inputs is not a 3-Tensor");
|
||||
}
|
||||
|
||||
const tf::int64 max_time = inputs_shape.dim_size(0);
|
||||
const tf::int64 batch_size = inputs_shape.dim_size(1);
|
||||
|
||||
if (max_time == 0) {
|
||||
return tf::errors::InvalidArgument("max_time is 0");
|
||||
}
|
||||
if (!tf::TensorShapeUtils::IsVector((*seq_len)->shape())) {
|
||||
return tf::errors::InvalidArgument("sequence_length is not a vector");
|
||||
}
|
||||
|
||||
if (!(batch_size == (*seq_len)->dim_size(0))) {
|
||||
return tf::errors::FailedPrecondition(
|
||||
"len(sequence_length) != batch_size. ", "len(sequence_length): ",
|
||||
(*seq_len)->dim_size(0), " batch_size: ", batch_size);
|
||||
}
|
||||
|
||||
auto seq_len_t = (*seq_len)->vec<tf::int32>();
|
||||
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
if (!(seq_len_t(b) <= max_time)) {
|
||||
return tf::errors::FailedPrecondition("sequence_length(", b, ") <= ",
|
||||
max_time);
|
||||
}
|
||||
}
|
||||
|
||||
tf::Status s = ctx->allocate_output(
|
||||
"log_probability", tf::TensorShape({batch_size, top_paths_}), log_prob);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
s = ctx->output_list("decoded_indices", decoded_indices);
|
||||
if (!s.ok()) return s;
|
||||
s = ctx->output_list("decoded_values", decoded_values);
|
||||
if (!s.ok()) return s;
|
||||
s = ctx->output_list("decoded_shape", decoded_shape);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
return tf::Status::OK();
|
||||
}
|
||||
|
||||
// sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
|
||||
tf::Status StoreAllDecodedSequences(
|
||||
const std::vector<std::vector<std::vector<int>>> &sequences,
|
||||
tf::OpOutputList *decoded_indices, tf::OpOutputList *decoded_values,
|
||||
tf::OpOutputList *decoded_shape) const {
|
||||
// Calculate the total number of entries for each path
|
||||
const tf::int64 batch_size = sequences.size();
|
||||
std::vector<tf::int64> num_entries(top_paths_, 0);
|
||||
|
||||
// Calculate num_entries per path
|
||||
for (const auto &batch_s : sequences) {
|
||||
CHECK_EQ(batch_s.size(), top_paths_);
|
||||
for (int p = 0; p < top_paths_; ++p) {
|
||||
num_entries[p] += batch_s[p].size();
|
||||
}
|
||||
}
|
||||
|
||||
for (int p = 0; p < top_paths_; ++p) {
|
||||
tf::Tensor *p_indices = nullptr;
|
||||
tf::Tensor *p_values = nullptr;
|
||||
tf::Tensor *p_shape = nullptr;
|
||||
|
||||
const tf::int64 p_num = num_entries[p];
|
||||
|
||||
tf::Status s =
|
||||
decoded_indices->allocate(p, tf::TensorShape({p_num, 2}), &p_indices);
|
||||
if (!s.ok()) return s;
|
||||
s = decoded_values->allocate(p, tf::TensorShape({p_num}), &p_values);
|
||||
if (!s.ok()) return s;
|
||||
s = decoded_shape->allocate(p, tf::TensorShape({2}), &p_shape);
|
||||
if (!s.ok()) return s;
|
||||
|
||||
auto indices_t = p_indices->matrix<tf::int64>();
|
||||
auto values_t = p_values->vec<tf::int64>();
|
||||
auto shape_t = p_shape->vec<tf::int64>();
|
||||
|
||||
tf::int64 max_decoded = 0;
|
||||
tf::int64 offset = 0;
|
||||
|
||||
for (tf::int64 b = 0; b < batch_size; ++b) {
|
||||
auto &p_batch = sequences[b][p];
|
||||
tf::int64 num_decoded = p_batch.size();
|
||||
max_decoded = std::max(max_decoded, num_decoded);
|
||||
std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
|
||||
for (tf::int64 t = 0; t < num_decoded; ++t, ++offset) {
|
||||
indices_t(offset, 0) = b;
|
||||
indices_t(offset, 1) = t;
|
||||
}
|
||||
}
|
||||
|
||||
shape_t(0) = batch_size;
|
||||
shape_t(1) = max_decoded;
|
||||
}
|
||||
return tf::Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
int top_paths_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
|
||||
};
|
||||
|
||||
// CTC beam search
|
||||
class CTCBeamSearchDecoderOp : public tf::OpKernel {
|
||||
public:
|
||||
explicit CTCBeamSearchDecoderOp(tf::OpKernelConstruction *ctx)
|
||||
: tf::OpKernel(ctx)
|
||||
, beam_scorer_(GetModelPath(ctx),
|
||||
GetTriePath(ctx),
|
||||
GetAlphabetPath(ctx))
|
||||
{
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_));
|
||||
int top_paths;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths));
|
||||
decode_helper_.SetTopPaths(top_paths);
|
||||
|
||||
// const tf::Tensor* model_tensor;
|
||||
// tf::Status status = ctx->input("model_path", &model_tensor);
|
||||
// if (!status.ok()) return status;
|
||||
// auto model_vec = model_tensor->flat<std::string>();
|
||||
// *model_path = model_vec(0);
|
||||
|
||||
// const tf::Tensor* trie_tensor;
|
||||
// status = ctx->input("trie_path", &trie_tensor);
|
||||
// if (!status.ok()) return status;
|
||||
// auto trie_vec = trie_tensor->flat<std::string>();
|
||||
// *trie_path = model_vec(0);
|
||||
|
||||
// const tf::Tensor* alphabet_tensor;
|
||||
// status = ctx->input("alphabet_path", &alphabet_tensor);
|
||||
// if (!status.ok()) return status;
|
||||
// auto alphabet_vec = alphabet_tensor->flat<std::string>();
|
||||
// *alphabet_path = alphabet_vec(0);
|
||||
}
|
||||
|
||||
std::string GetModelPath(tf::OpKernelConstruction *ctx) {
|
||||
std::string model_path;
|
||||
ctx->GetAttr("model_path", &model_path);
|
||||
return model_path;
|
||||
}
|
||||
|
||||
std::string GetTriePath(tf::OpKernelConstruction *ctx) {
|
||||
std::string trie_path;
|
||||
ctx->GetAttr("trie_path", &trie_path);
|
||||
return trie_path;
|
||||
}
|
||||
|
||||
std::string GetAlphabetPath(tf::OpKernelConstruction *ctx) {
|
||||
std::string alphabet_path;
|
||||
ctx->GetAttr("alphabet_path", &alphabet_path);
|
||||
return alphabet_path;
|
||||
}
|
||||
|
||||
void Compute(tf::OpKernelContext *ctx) override {
|
||||
const tf::Tensor *inputs;
|
||||
const tf::Tensor *seq_len;
|
||||
std::string model_path;
|
||||
std::string trie_path;
|
||||
std::string alphabet_path;
|
||||
tf::Tensor *log_prob = nullptr;
|
||||
tf::OpOutputList decoded_indices;
|
||||
tf::OpOutputList decoded_values;
|
||||
tf::OpOutputList decoded_shape;
|
||||
OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs(
|
||||
ctx, &inputs, &seq_len, &model_path, &trie_path,
|
||||
&alphabet_path, &log_prob, &decoded_indices,
|
||||
&decoded_values, &decoded_shape));
|
||||
|
||||
auto inputs_t = inputs->tensor<float, 3>();
|
||||
auto seq_len_t = seq_len->vec<tf::int32>();
|
||||
auto log_prob_t = log_prob->matrix<float>();
|
||||
|
||||
const tf::TensorShape &inputs_shape = inputs->shape();
|
||||
|
||||
const tf::int64 max_time = inputs_shape.dim_size(0);
|
||||
const tf::int64 batch_size = inputs_shape.dim_size(1);
|
||||
const tf::int64 num_classes_raw = inputs_shape.dim_size(2);
|
||||
OP_REQUIRES(
|
||||
ctx, tf::FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
|
||||
tf::errors::InvalidArgument("num_classes cannot exceed max int"));
|
||||
const int num_classes = static_cast<const int>(num_classes_raw);
|
||||
|
||||
log_prob_t.setZero();
|
||||
|
||||
std::vector<tf::TTypes<float>::UnalignedConstMatrix> input_list_t;
|
||||
|
||||
for (std::size_t t = 0; t < max_time; ++t) {
|
||||
input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
|
||||
batch_size, num_classes);
|
||||
}
|
||||
|
||||
tf::ctc::CTCBeamSearchDecoder<KenLMBeamState> beam_search(num_classes, beam_width_,
|
||||
&beam_scorer_, 1 /* batch_size */,
|
||||
merge_repeated_);
|
||||
tf::Tensor input_chip(tf::DT_FLOAT, tf::TensorShape({num_classes}));
|
||||
auto input_chip_t = input_chip.flat<float>();
|
||||
|
||||
std::vector<std::vector<std::vector<int>>> best_paths(batch_size);
|
||||
std::vector<float> log_probs;
|
||||
|
||||
// Assumption: the blank index is num_classes - 1
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
auto &best_paths_b = best_paths[b];
|
||||
best_paths_b.resize(decode_helper_.GetTopPaths());
|
||||
for (int t = 0; t < seq_len_t(b); ++t) {
|
||||
input_chip_t = input_list_t[t].chip(b, 0);
|
||||
auto input_bi =
|
||||
Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
|
||||
beam_search.Step(input_bi);
|
||||
}
|
||||
OP_REQUIRES_OK(
|
||||
ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b,
|
||||
&log_probs, merge_repeated_));
|
||||
|
||||
beam_search.Reset();
|
||||
|
||||
for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) {
|
||||
log_prob_t(b, bp) = log_probs[bp];
|
||||
}
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences(
|
||||
best_paths, &decoded_indices, &decoded_values,
|
||||
&decoded_shape));
|
||||
}
|
||||
|
||||
private:
|
||||
CTCDecodeHelper decode_helper_;
|
||||
KenLMBeamScorer beam_scorer_;
|
||||
bool merge_repeated_;
|
||||
int beam_width_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
|
||||
};
|
||||
|
||||
// Registers CTCBeamSearchDecoderOp as the CPU kernel for the
// CTCBeamSearchDecoderWithLM op.
REGISTER_KERNEL_BUILDER(
    Name("CTCBeamSearchDecoderWithLM").Device(tf::DEVICE_CPU),
    CTCBeamSearchDecoderOp);
|
||||
66
native_client/generate_trie.cpp
Normal file
66
native_client/generate_trie.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
using namespace std;
|
||||
|
||||
#include "lm/model.hh"
|
||||
#include "trie_node.h"
|
||||
#include "alphabet.h"
|
||||
|
||||
typedef lm::ngram::ProbingModel Model;
|
||||
|
||||
// Looks up `word` in the vocabulary of `model` and returns its word index.
lm::WordIndex GetWordIndex(const Model& model, const std::string& word) {
  return model.GetVocabulary().Index(word);
}
|
||||
|
||||
float ScoreWord(const Model& model, lm::WordIndex vocab) {
|
||||
Model::State in_state = model.NullContextState();
|
||||
Model::State out;
|
||||
lm::FullScoreReturn full_score_return;
|
||||
full_score_return = model.FullScore(in_state, vocab, out);
|
||||
return full_score_return.prob;
|
||||
}
|
||||
|
||||
int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* vocab_path, const char* trie_path) {
|
||||
Alphabet a(alphabet_path);
|
||||
|
||||
lm::ngram::Config config;
|
||||
config.load_method = util::POPULATE_OR_READ;
|
||||
Model model(kenlm_path, config);
|
||||
TrieNode root(a.GetSize());
|
||||
|
||||
std::ifstream ifs;
|
||||
ifs.open(vocab_path, std::ifstream::in);
|
||||
|
||||
if (!ifs.is_open()) {
|
||||
std::cout << "unable to open vocabulary" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::ofstream ofs;
|
||||
ofs.open(trie_path);
|
||||
|
||||
std::string word;
|
||||
while (ifs >> word) {
|
||||
for_each(word.begin(), word.end(), [](char& a) { a = tolower(a); });
|
||||
lm::WordIndex vocab = GetWordIndex(model, word);
|
||||
float unigram_score = ScoreWord(model, vocab);
|
||||
root.Insert(word.c_str(), [&a](char c) {
|
||||
return a.LabelFromString(string(1, c));
|
||||
}, vocab, unigram_score);
|
||||
}
|
||||
|
||||
root.WriteToStream(ofs);
|
||||
ifs.close();
|
||||
ofs.close();
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
return generate_trie("/Users/remorais/Development/DeepSpeech/data/alphabet.txt",
|
||||
"/Users/remorais/Development/DeepSpeech/data/lm/lm.binary",
|
||||
"/Users/remorais/Development/DeepSpeech/data/lm/vocab.txt",
|
||||
"/Users/remorais/Development/DeepSpeech/data/lm/trie");
|
||||
}
|
||||
125
native_client/trie_node.h
Normal file
125
native_client/trie_node.h
Normal file
@ -0,0 +1,125 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TRIE_NODE_H
|
||||
#define TRIE_NODE_H
|
||||
|
||||
#include "lm/model.hh"
|
||||
|
||||
#include <functional>
|
||||
#include <istream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
|
||||
class TrieNode {
|
||||
public:
|
||||
TrieNode(int vocab_size)
|
||||
: vocab_size(vocab_size)
|
||||
, prefixCount(0)
|
||||
, min_score_word(0)
|
||||
, min_unigram_score(std::numeric_limits<float>::max())
|
||||
{
|
||||
children = new TrieNode*[vocab_size]();
|
||||
}
|
||||
|
||||
~TrieNode() {
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
delete children[i];
|
||||
}
|
||||
delete children;
|
||||
}
|
||||
|
||||
void WriteToStream(std::ostream& os) {
|
||||
WriteNode(os);
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
if (children[i] == nullptr) {
|
||||
os << -1 << std::endl;
|
||||
} else {
|
||||
// Recursive call
|
||||
children[i]->WriteToStream(os);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ReadFromStream(std::istream& is, TrieNode* &obj, int vocab_size) {
|
||||
int prefixCount;
|
||||
is >> prefixCount;
|
||||
|
||||
if (prefixCount == -1) {
|
||||
// This is an undefined child
|
||||
obj = nullptr;
|
||||
return;
|
||||
}
|
||||
|
||||
obj = new TrieNode(vocab_size);
|
||||
obj->ReadNode(is, prefixCount);
|
||||
for (int i = 0; i < vocab_size; i++) {
|
||||
// Recursive call
|
||||
ReadFromStream(is, obj->children[i], vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
void Insert(const char* word, std::function<int (char)> translator,
|
||||
lm::WordIndex lm_word, float unigram_score) {
|
||||
char wordCharacter = *word;
|
||||
prefixCount++;
|
||||
if (unigram_score < min_unigram_score) {
|
||||
min_unigram_score = unigram_score;
|
||||
min_score_word = lm_word;
|
||||
}
|
||||
if (wordCharacter != '\0') {
|
||||
int vocabIndex = translator(wordCharacter);
|
||||
TrieNode *child = children[vocabIndex];
|
||||
if (child == nullptr)
|
||||
child = children[vocabIndex] = new TrieNode(vocab_size);
|
||||
child->Insert(word + 1, translator, lm_word, unigram_score);
|
||||
}
|
||||
}
|
||||
|
||||
int GetFrequency() {
|
||||
return prefixCount;
|
||||
}
|
||||
|
||||
lm::WordIndex GetMinScoreWordIndex() {
|
||||
return min_score_word;
|
||||
}
|
||||
|
||||
float GetMinUnigramScore() {
|
||||
return min_unigram_score;
|
||||
}
|
||||
|
||||
TrieNode *GetChildAt(int vocabIndex) {
|
||||
return children[vocabIndex];
|
||||
}
|
||||
|
||||
private:
|
||||
int vocab_size;
|
||||
int prefixCount;
|
||||
lm::WordIndex min_score_word;
|
||||
float min_unigram_score;
|
||||
TrieNode **children;
|
||||
|
||||
void WriteNode(std::ostream& os) const {
|
||||
os << prefixCount << std::endl;
|
||||
os << min_score_word << std::endl;
|
||||
os << min_unigram_score << std::endl;
|
||||
}
|
||||
|
||||
void ReadNode(std::istream& is, int first_input) {
|
||||
prefixCount = first_input;
|
||||
is >> min_score_word;
|
||||
is >> min_unigram_score;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
#endif //TRIE_NODE_H
|
||||
Loading…
x
Reference in New Issue
Block a user