From fc91e3d7b8a61b7b18e0abee4b21289396a0b723 Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Fri, 1 Sep 2017 11:45:35 +0200 Subject: [PATCH] Write a CTC beam search decoder TF op that scores beams with our LM --- native_client/BUILD | 32 ++ native_client/alphabet.h | 5 + native_client/beam_search.cc | 546 ++++++++++++++++++++++++++++++++ native_client/generate_trie.cpp | 66 ++++ native_client/trie_node.h | 125 ++++++++ 5 files changed, 774 insertions(+) create mode 100644 native_client/beam_search.cc create mode 100644 native_client/generate_trie.cpp create mode 100644 native_client/trie_node.h diff --git a/native_client/BUILD b/native_client/BUILD index 7bb40bf7..9988cef9 100644 --- a/native_client/BUILD +++ b/native_client/BUILD @@ -34,3 +34,35 @@ cc_library( copts = [] + if_linux_x86_64(["-mno-fma", "-mno-avx", "-mno-avx2"]), nocopts = "(-fstack-protector|-fno-omit-frame-pointer)", ) + + +cc_library( + name = "ctc_decoder_with_kenlm", + srcs = ["beam_search.cc", + "alphabet.h", + "trie_node.h"] + + glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc", + "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"], + exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]), + includes = ["kenlm"], + defines = ["KENLM_MAX_ORDER=6"], + deps = ["//tensorflow/core:core", + "//tensorflow/core/util/ctc", + "//third_party/eigen3", + ], +) + +cc_binary( + name = "generate_trie", + srcs = [ + "generate_trie.cpp", + "trie_node.h", + "alphabet.h", + ] + glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc", + "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"], + exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]), + includes = ["kenlm"], + copts = ['-std=c++11'], + linkopts = ['-lm'], + defines = ["KENLM_MAX_ORDER=6"], +) diff --git a/native_client/alphabet.h b/native_client/alphabet.h index 75115e9c..4534e437 100644 --- a/native_client/alphabet.h +++ b/native_client/alphabet.h @@ -54,6 +54,11 @@ public: return size_; } + bool IsSpace(unsigned int label) const { + const std::string& str = StringFromLabel(label); + return str.size() == 1 && str[0] == ' '; + } + private: size_t size_; std::unordered_map label_to_str_; diff --git a/native_client/beam_search.cc b/native_client/beam_search.cc new file mode 100644 index 00000000..9e5c930d --- /dev/null +++ b/native_client/beam_search.cc @@ -0,0 +1,546 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test illustrates how to make use of the CTCBeamSearchDecoder using a +// custom BeamScorer and BeamState based on a dictionary with a few artificial +// words. +#include "tensorflow/core/util/ctc/ctc_beam_search.h" + +#include +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/kernels/bounds_check.h" + +#include "kenlm/lm/model.hh" + +#include "alphabet.h" +#include "trie_node.h" + +namespace tf = tensorflow; +using tf::shape_inference::DimensionHandle; +using tf::shape_inference::InferenceContext; +using tf::shape_inference::ShapeHandle; + +REGISTER_OP("CTCBeamSearchDecoderWithLM") + .Input("inputs: float") + .Input("sequence_length: int32") + .Attr("model_path: string") + .Attr("trie_path: string") + .Attr("alphabet_path: string") + .Attr("beam_width: int >= 1 = 100") + .Attr("top_paths: int >= 1 = 1") + .Attr("merge_repeated: bool = true") + .Output("decoded_indices: top_paths * int64") + .Output("decoded_values: top_paths * int64") + .Output("decoded_shape: top_paths * int64") + .Output("log_probability: float") + .SetShapeFn([](InferenceContext *c) { + ShapeHandle inputs; + ShapeHandle sequence_length; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); + + // Get batch size from inputs and sequence_length. + DimensionHandle batch_size; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); + + tf::int32 top_paths; + TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths)); + + // Outputs. + int out_idx = 0; + for (int i = 0; i < top_paths; ++i) { // decoded_indices + c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2)); + } + for (int i = 0; i < top_paths; ++i) { // decoded_values + c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim)); + } + ShapeHandle shape_v = c->Vector(2); + for (int i = 0; i < top_paths; ++i) { // decoded_shape + c->set_output(out_idx++, shape_v); + } + c->set_output(out_idx++, c->Matrix(batch_size, top_paths)); + return tf::Status::OK(); + }) + .Doc(R"doc( +Performs beam search decoding on the logits given in input. + +A note about the attribute merge_repeated: For the beam search decoder, +this means that if consecutive entries in a beam are the same, only +the first of these is emitted. That is, when the top path is "A B B B B", +"A B" is returned if merge_repeated = True but "A B B B B" is +returned if merge_repeated = False. + +inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +sequence_length: A vector containing sequence lengths, size `(batch)`. +beam_width: A scalar >= 0 (beam search beam width). +top_paths: A scalar >= 0, <= beam_width (controls output size). +merge_repeated: If true, merge repeated classes in output. +decoded_indices: A list (length: top_paths) of indices matrices. Matrix j, + size `(total_decoded_outputs[j] x 2)`, has indices of a + `SparseTensor`. The rows store: [batch, time]. +decoded_values: A list (length: top_paths) of values vectors. Vector j, + size `(length total_decoded_outputs[j])`, has the values of a + `SparseTensor`. The vector stores the decoded classes for beam j. +decoded_shape: A list (length: top_paths) of shape vector. Vector j, + size `(2)`, stores the shape of the decoded `SparseTensor[j]`. + Its values are: `[batch_size, max_decoded_length[j]]`. +log_probability: A matrix, shaped: `(batch_size x top_paths)`. The + sequence log-probabilities. +)doc"); + +struct KenLMBeamState { + float language_model_score; + float score; + float delta_score; + std::string incomplete_word; + TrieNode *incomplete_word_trie_node; + lm::ngram::ProbingModel::State model_state; +}; + +class KenLMBeamScorer : public tf::ctc::BaseBeamScorer { + public: + typedef lm::ngram::ProbingModel Model; + + KenLMBeamScorer(const std::string &kenlm_path, const std::string &trie_path, const std::string &alphabet_path) + : lm_weight(1.0f) + , word_count_weight(-0.1f) + , valid_word_count_weight(1.0f) + { + lm::ngram::Config config; + config.load_method = util::POPULATE_OR_READ; + model = new Model(kenlm_path.c_str(), config); + + alphabet = new Alphabet(alphabet_path.c_str()); + + std::ifstream in; + in.open(trie_path, std::ios::in); + TrieNode::ReadFromStream(in, trieRoot, alphabet->GetSize()); + in.close(); + } + + virtual ~KenLMBeamScorer() { + delete model; + delete trieRoot; + } + + // State initialization. + void InitializeState(KenLMBeamState* root) const { + root->language_model_score = 0.0f; + root->score = 0.0f; + root->delta_score = 0.0f; + root->incomplete_word.clear(); + root->incomplete_word_trie_node = trieRoot; + root->model_state = model->BeginSentenceState(); + } + // ExpandState is called when expanding a beam to one of its children. + // Called at most once per child beam. In the simplest case, no state + // expansion is done. + void ExpandState(const KenLMBeamState& from_state, int from_label, + KenLMBeamState* to_state, int to_label) const { + CopyState(from_state, to_state); + + if (!alphabet->IsSpace(to_label)) { + to_state->incomplete_word += alphabet->StringFromLabel(to_label); + TrieNode *trie_node = from_state.incomplete_word_trie_node; + + // TODO replace with OOV unigram prob? + // If we have no valid prefix we assume a very low log probability + float min_unigram_score = -10.0f; + // If prefix does exist + if (trie_node != nullptr) { + trie_node = trie_node->GetChildAt(to_label); + to_state->incomplete_word_trie_node = trie_node; + + if (trie_node != nullptr) { + min_unigram_score = trie_node->GetMinUnigramScore(); + } + } + // TODO try two options + // 1) unigram score added up to language model scare + // 2) langugage model score of (preceding_words + unigram_word) + to_state->score = min_unigram_score + to_state->language_model_score; + to_state->delta_score = to_state->score - from_state.score; + } else { + float lm_score_delta = ScoreIncompleteWord(from_state.model_state, + to_state->incomplete_word, + to_state->model_state); + // Give fixed word bonus + if (!IsOOV(to_state->incomplete_word)) { + to_state->language_model_score += valid_word_count_weight; + } + to_state->language_model_score += word_count_weight; + UpdateWithLMScore(to_state, lm_score_delta); + ResetIncompleteWord(to_state); + } + } + // ExpandStateEnd is called after decoding has finished. Its purpose is to + // allow a final scoring of the beam in its current state, before resorting + // and retrieving the TopN requested candidates. Called at most once per beam. + void ExpandStateEnd(KenLMBeamState* state) const { + float lm_score_delta = 0.0f; + Model::State out; + if (state->incomplete_word.size() > 0) { + lm_score_delta += ScoreIncompleteWord(state->model_state, + state->incomplete_word, + out); + ResetIncompleteWord(state); + state->model_state = out; + } + lm_score_delta += model->FullScore(state->model_state, + model->GetVocabulary().EndSentence(), + out).prob; + UpdateWithLMScore(state, lm_score_delta); + } + // GetStateExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandState. The score is + // multiplied (log-addition) with the input score at the current step from + // the network. + // + // The score returned should be a log-probability. In the simplest case, as + // there's no state expansion logic, the expansion score is zero. + float GetStateExpansionScore(const KenLMBeamState& state, + float previous_score) const { + return lm_weight * state.delta_score + previous_score; + } + // GetStateEndExpansionScore should be an inexpensive method to retrieve the + // (cached) expansion score computed within ExpandStateEnd. The score is + // multiplied (log-addition) with the final probability of the beam. + // + // The score returned should be a log-probability. + float GetStateEndExpansionScore(const KenLMBeamState& state) const { + return lm_weight * state.delta_score; + } + + void SetLMWeight(float lm_weight) { + this->lm_weight = lm_weight; + } + + void SetWordCountWeight(float word_count_weight) { + this->word_count_weight = word_count_weight; + } + + void SetValidWordCountWeight(float valid_word_count_weight) { + this->valid_word_count_weight = valid_word_count_weight; + } + + private: + Model *model; + Alphabet *alphabet; + TrieNode *trieRoot; + float lm_weight; + float word_count_weight; + float valid_word_count_weight; + + void UpdateWithLMScore(KenLMBeamState *state, float lm_score_delta) const { + float previous_score = state->score; + state->language_model_score += lm_score_delta; + state->score = state->language_model_score; + state->delta_score = state->language_model_score - previous_score; + } + + void ResetIncompleteWord(KenLMBeamState *state) const { + state->incomplete_word.clear(); + state->incomplete_word_trie_node = trieRoot; + } + + bool IsOOV(const std::string& word) const { + auto &vocabulary = model->GetVocabulary(); + return vocabulary.Index(word) == vocabulary.NotFound(); + } + + float ScoreIncompleteWord(const Model::State& model_state, + const std::string& word, + Model::State& out) const { + lm::FullScoreReturn full_score_return; + lm::WordIndex vocab = model->GetVocabulary().Index(word); + full_score_return = model->FullScore(model_state, vocab, out); + return full_score_return.prob; + } + + void CopyState(const KenLMBeamState& from, KenLMBeamState* to) const { + to->language_model_score = from.language_model_score; + to->score = from.score; + to->delta_score = from.delta_score; + to->incomplete_word = from.incomplete_word; + to->incomplete_word_trie_node = from.incomplete_word_trie_node; + to->model_state = from.model_state; + } +}; + +class CTCDecodeHelper { + public: + CTCDecodeHelper() : top_paths_(1) {} + + inline int GetTopPaths() const { return top_paths_; } + void SetTopPaths(int tp) { top_paths_ = tp; } + + tf::Status ValidateInputsGenerateOutputs( + tf::OpKernelContext *ctx, const tf::Tensor **inputs, const tf::Tensor **seq_len, + std::string *model_path, std::string *trie_path, std::string *alphabet_path, + tf::Tensor **log_prob, tf::OpOutputList *decoded_indices, + tf::OpOutputList *decoded_values, tf::OpOutputList *decoded_shape) const { + tf::Status status = ctx->input("inputs", inputs); + if (!status.ok()) return status; + status = ctx->input("sequence_length", seq_len); + if (!status.ok()) return status; + + const tf::TensorShape &inputs_shape = (*inputs)->shape(); + + if (inputs_shape.dims() != 3) { + return tf::errors::InvalidArgument("inputs is not a 3-Tensor"); + } + + const tf::int64 max_time = inputs_shape.dim_size(0); + const tf::int64 batch_size = inputs_shape.dim_size(1); + + if (max_time == 0) { + return tf::errors::InvalidArgument("max_time is 0"); + } + if (!tf::TensorShapeUtils::IsVector((*seq_len)->shape())) { + return tf::errors::InvalidArgument("sequence_length is not a vector"); + } + + if (!(batch_size == (*seq_len)->dim_size(0))) { + return tf::errors::FailedPrecondition( + "len(sequence_length) != batch_size. ", "len(sequence_length): ", + (*seq_len)->dim_size(0), " batch_size: ", batch_size); + } + + auto seq_len_t = (*seq_len)->vec(); + + for (int b = 0; b < batch_size; ++b) { + if (!(seq_len_t(b) <= max_time)) { + return tf::errors::FailedPrecondition("sequence_length(", b, ") <= ", + max_time); + } + } + + tf::Status s = ctx->allocate_output( + "log_probability", tf::TensorShape({batch_size, top_paths_}), log_prob); + if (!s.ok()) return s; + + s = ctx->output_list("decoded_indices", decoded_indices); + if (!s.ok()) return s; + s = ctx->output_list("decoded_values", decoded_values); + if (!s.ok()) return s; + s = ctx->output_list("decoded_shape", decoded_shape); + if (!s.ok()) return s; + + return tf::Status::OK(); + } + + // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b". + tf::Status StoreAllDecodedSequences( + const std::vector>> &sequences, + tf::OpOutputList *decoded_indices, tf::OpOutputList *decoded_values, + tf::OpOutputList *decoded_shape) const { + // Calculate the total number of entries for each path + const tf::int64 batch_size = sequences.size(); + std::vector num_entries(top_paths_, 0); + + // Calculate num_entries per path + for (const auto &batch_s : sequences) { + CHECK_EQ(batch_s.size(), top_paths_); + for (int p = 0; p < top_paths_; ++p) { + num_entries[p] += batch_s[p].size(); + } + } + + for (int p = 0; p < top_paths_; ++p) { + tf::Tensor *p_indices = nullptr; + tf::Tensor *p_values = nullptr; + tf::Tensor *p_shape = nullptr; + + const tf::int64 p_num = num_entries[p]; + + tf::Status s = + decoded_indices->allocate(p, tf::TensorShape({p_num, 2}), &p_indices); + if (!s.ok()) return s; + s = decoded_values->allocate(p, tf::TensorShape({p_num}), &p_values); + if (!s.ok()) return s; + s = decoded_shape->allocate(p, tf::TensorShape({2}), &p_shape); + if (!s.ok()) return s; + + auto indices_t = p_indices->matrix(); + auto values_t = p_values->vec(); + auto shape_t = p_shape->vec(); + + tf::int64 max_decoded = 0; + tf::int64 offset = 0; + + for (tf::int64 b = 0; b < batch_size; ++b) { + auto &p_batch = sequences[b][p]; + tf::int64 num_decoded = p_batch.size(); + max_decoded = std::max(max_decoded, num_decoded); + std::copy_n(p_batch.begin(), num_decoded, &values_t(offset)); + for (tf::int64 t = 0; t < num_decoded; ++t, ++offset) { + indices_t(offset, 0) = b; + indices_t(offset, 1) = t; + } + } + + shape_t(0) = batch_size; + shape_t(1) = max_decoded; + } + return tf::Status::OK(); + } + + private: + int top_paths_; + TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper); +}; + +// CTC beam search +class CTCBeamSearchDecoderOp : public tf::OpKernel { + public: + explicit CTCBeamSearchDecoderOp(tf::OpKernelConstruction *ctx) + : tf::OpKernel(ctx) + , beam_scorer_(GetModelPath(ctx), + GetTriePath(ctx), + GetAlphabetPath(ctx)) + { + OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_)); + int top_paths; + OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths)); + decode_helper_.SetTopPaths(top_paths); + + // const tf::Tensor* model_tensor; + // tf::Status status = ctx->input("model_path", &model_tensor); + // if (!status.ok()) return status; + // auto model_vec = model_tensor->flat(); + // *model_path = model_vec(0); + + // const tf::Tensor* trie_tensor; + // status = ctx->input("trie_path", &trie_tensor); + // if (!status.ok()) return status; + // auto trie_vec = trie_tensor->flat(); + // *trie_path = model_vec(0); + + // const tf::Tensor* alphabet_tensor; + // status = ctx->input("alphabet_path", &alphabet_tensor); + // if (!status.ok()) return status; + // auto alphabet_vec = alphabet_tensor->flat(); + // *alphabet_path = alphabet_vec(0); + } + + std::string GetModelPath(tf::OpKernelConstruction *ctx) { + std::string model_path; + ctx->GetAttr("model_path", &model_path); + return model_path; + } + + std::string GetTriePath(tf::OpKernelConstruction *ctx) { + std::string trie_path; + ctx->GetAttr("trie_path", &trie_path); + return trie_path; + } + + std::string GetAlphabetPath(tf::OpKernelConstruction *ctx) { + std::string alphabet_path; + ctx->GetAttr("alphabet_path", &alphabet_path); + return alphabet_path; + } + + void Compute(tf::OpKernelContext *ctx) override { + const tf::Tensor *inputs; + const tf::Tensor *seq_len; + std::string model_path; + std::string trie_path; + std::string alphabet_path; + tf::Tensor *log_prob = nullptr; + tf::OpOutputList decoded_indices; + tf::OpOutputList decoded_values; + tf::OpOutputList decoded_shape; + OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs( + ctx, &inputs, &seq_len, &model_path, &trie_path, + &alphabet_path, &log_prob, &decoded_indices, + &decoded_values, &decoded_shape)); + + auto inputs_t = inputs->tensor(); + auto seq_len_t = seq_len->vec(); + auto log_prob_t = log_prob->matrix(); + + const tf::TensorShape &inputs_shape = inputs->shape(); + + const tf::int64 max_time = inputs_shape.dim_size(0); + const tf::int64 batch_size = inputs_shape.dim_size(1); + const tf::int64 num_classes_raw = inputs_shape.dim_size(2); + OP_REQUIRES( + ctx, tf::FastBoundsCheck(num_classes_raw, std::numeric_limits::max()), + tf::errors::InvalidArgument("num_classes cannot exceed max int")); + const int num_classes = static_cast(num_classes_raw); + + log_prob_t.setZero(); + + std::vector::UnalignedConstMatrix> input_list_t; + + for (std::size_t t = 0; t < max_time; ++t) { + input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes, + batch_size, num_classes); + } + + tf::ctc::CTCBeamSearchDecoder beam_search(num_classes, beam_width_, + &beam_scorer_, 1 /* batch_size */, + merge_repeated_); + tf::Tensor input_chip(tf::DT_FLOAT, tf::TensorShape({num_classes})); + auto input_chip_t = input_chip.flat(); + + std::vector>> best_paths(batch_size); + std::vector log_probs; + + // Assumption: the blank index is num_classes - 1 + for (int b = 0; b < batch_size; ++b) { + auto &best_paths_b = best_paths[b]; + best_paths_b.resize(decode_helper_.GetTopPaths()); + for (int t = 0; t < seq_len_t(b); ++t) { + input_chip_t = input_list_t[t].chip(b, 0); + auto input_bi = + Eigen::Map(input_chip_t.data(), num_classes); + beam_search.Step(input_bi); + } + OP_REQUIRES_OK( + ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b, + &log_probs, merge_repeated_)); + + beam_search.Reset(); + + for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) { + log_prob_t(b, bp) = log_probs[bp]; + } + } + + OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences( + best_paths, &decoded_indices, &decoded_values, + &decoded_shape)); + } + + private: + CTCDecodeHelper decode_helper_; + KenLMBeamScorer beam_scorer_; + bool merge_repeated_; + int beam_width_; + TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp); +}; + +REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithLM").Device(tf::DEVICE_CPU), + CTCBeamSearchDecoderOp); diff --git a/native_client/generate_trie.cpp b/native_client/generate_trie.cpp new file mode 100644 index 00000000..ce3b9e65 --- /dev/null +++ b/native_client/generate_trie.cpp @@ -0,0 +1,66 @@ +#include +#include +#include +using namespace std; + +#include "lm/model.hh" +#include "trie_node.h" +#include "alphabet.h" + +typedef lm::ngram::ProbingModel Model; + +lm::WordIndex GetWordIndex(const Model& model, const std::string& word) { + lm::WordIndex vocab; + vocab = model.GetVocabulary().Index(word); + return vocab; +} + +float ScoreWord(const Model& model, lm::WordIndex vocab) { + Model::State in_state = model.NullContextState(); + Model::State out; + lm::FullScoreReturn full_score_return; + full_score_return = model.FullScore(in_state, vocab, out); + return full_score_return.prob; +} + +int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* vocab_path, const char* trie_path) { + Alphabet a(alphabet_path); + + lm::ngram::Config config; + config.load_method = util::POPULATE_OR_READ; + Model model(kenlm_path, config); + TrieNode root(a.GetSize()); + + std::ifstream ifs; + ifs.open(vocab_path, std::ifstream::in); + + if (!ifs.is_open()) { + std::cout << "unable to open vocabulary" << std::endl; + return -1; + } + + std::ofstream ofs; + ofs.open(trie_path); + + std::string word; + while (ifs >> word) { + for_each(word.begin(), word.end(), [](char& a) { a = tolower(a); }); + lm::WordIndex vocab = GetWordIndex(model, word); + float unigram_score = ScoreWord(model, vocab); + root.Insert(word.c_str(), [&a](char c) { + return a.LabelFromString(string(1, c)); + }, vocab, unigram_score); + } + + root.WriteToStream(ofs); + ifs.close(); + ofs.close(); + return 0; +} + +int main(void) { + return generate_trie("/Users/remorais/Development/DeepSpeech/data/alphabet.txt", + "/Users/remorais/Development/DeepSpeech/data/lm/lm.binary", + "/Users/remorais/Development/DeepSpeech/data/lm/vocab.txt", + "/Users/remorais/Development/DeepSpeech/data/lm/trie"); +} diff --git a/native_client/trie_node.h b/native_client/trie_node.h new file mode 100644 index 00000000..988a0f2e --- /dev/null +++ b/native_client/trie_node.h @@ -0,0 +1,125 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TRIE_NODE_H +#define TRIE_NODE_H + +#include "lm/model.hh" + +#include +#include +#include +#include + +class TrieNode { +public: + TrieNode(int vocab_size) + : vocab_size(vocab_size) + , prefixCount(0) + , min_score_word(0) + , min_unigram_score(std::numeric_limits::max()) + { + children = new TrieNode*[vocab_size](); + } + + ~TrieNode() { + for (int i = 0; i < vocab_size; i++) { + delete children[i]; + } + delete children; + } + + void WriteToStream(std::ostream& os) { + WriteNode(os); + for (int i = 0; i < vocab_size; i++) { + if (children[i] == nullptr) { + os << -1 << std::endl; + } else { + // Recursive call + children[i]->WriteToStream(os); + } + } + } + + static void ReadFromStream(std::istream& is, TrieNode* &obj, int vocab_size) { + int prefixCount; + is >> prefixCount; + + if (prefixCount == -1) { + // This is an undefined child + obj = nullptr; + return; + } + + obj = new TrieNode(vocab_size); + obj->ReadNode(is, prefixCount); + for (int i = 0; i < vocab_size; i++) { + // Recursive call + ReadFromStream(is, obj->children[i], vocab_size); + } + } + + void Insert(const char* word, std::function translator, + lm::WordIndex lm_word, float unigram_score) { + char wordCharacter = *word; + prefixCount++; + if (unigram_score < min_unigram_score) { + min_unigram_score = unigram_score; + min_score_word = lm_word; + } + if (wordCharacter != '\0') { + int vocabIndex = translator(wordCharacter); + TrieNode *child = children[vocabIndex]; + if (child == nullptr) + child = children[vocabIndex] = new TrieNode(vocab_size); + child->Insert(word + 1, translator, lm_word, unigram_score); + } + } + + int GetFrequency() { + return prefixCount; + } + + lm::WordIndex GetMinScoreWordIndex() { + return min_score_word; + } + + float GetMinUnigramScore() { + return min_unigram_score; + } + + TrieNode *GetChildAt(int vocabIndex) { + return children[vocabIndex]; + } + +private: + int vocab_size; + int prefixCount; + lm::WordIndex min_score_word; + float min_unigram_score; + TrieNode **children; + + void WriteNode(std::ostream& os) const { + os << prefixCount << std::endl; + os << min_score_word << std::endl; + os << min_unigram_score << std::endl; + } + + void ReadNode(std::istream& is, int first_input) { + prefixCount = first_input; + is >> min_score_word; + is >> min_unigram_score; + } + +}; + +#endif //TRIE_NODE_H