Write a CTC beam search decoder TF op that scores beams with our LM
This commit is contained in:
		
							parent
							
								
									af71da0d4d
								
							
						
					
					
						commit
						fc91e3d7b8
					
				| @ -34,3 +34,35 @@ cc_library( | ||||
|     copts = [] + if_linux_x86_64(["-mno-fma", "-mno-avx", "-mno-avx2"]), | ||||
|     nocopts = "(-fstack-protector|-fno-omit-frame-pointer)", | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
# Custom TF op library: CTC beam search decoder whose beams are scored
# with a KenLM language model plus a vocabulary trie.
cc_library(
    name = "ctc_decoder_with_kenlm",
    srcs = ["beam_search.cc",
            "alphabet.h",
            "trie_node.h"] +
           glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
                 "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
                # Keep KenLM's own test and CLI entry points out of the library.
                exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]),
    includes = ["kenlm"],
    # Must match the max order the LM binary was built with.
    defines = ["KENLM_MAX_ORDER=6"],
    deps = ["//tensorflow/core:core",
            "//tensorflow/core/util/ctc",
            "//third_party/eigen3",
    ],
)
| 
 | ||||
# Standalone tool that builds the vocabulary trie consumed by
# :ctc_decoder_with_kenlm (see generate_trie.cpp).
cc_binary(
    name = "generate_trie",
    srcs = [
        "generate_trie.cpp",
        "trie_node.h",
        "alphabet.h",
    ] + glob(["kenlm/lm/*.cc", "kenlm/util/*.cc", "kenlm/util/double-conversion/*.cc",
                 "kenlm/lm/*.hh", "kenlm/util/*.hh", "kenlm/util/double-conversion/*.h"],
                # Keep KenLM's own test and CLI entry points out of the binary.
                exclude = ["kenlm/*/*test.cc", "kenlm/*/*main.cc"]),
    includes = ["kenlm"],
    copts = ['-std=c++11'],
    linkopts = ['-lm'],
    # Must match KENLM_MAX_ORDER used by the decoder library above.
    defines = ["KENLM_MAX_ORDER=6"],
)
|  | ||||
| @ -54,6 +54,11 @@ public: | ||||
|     return size_; | ||||
|   } | ||||
| 
 | ||||
|   bool IsSpace(unsigned int label) const { | ||||
|     const std::string& str = StringFromLabel(label); | ||||
|     return str.size() == 1 && str[0] == ' '; | ||||
|   } | ||||
| 
 | ||||
| private: | ||||
|   size_t size_; | ||||
|   std::unordered_map<unsigned int, std::string> label_to_str_; | ||||
|  | ||||
							
								
								
									
										546
									
								
								native_client/beam_search.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										546
									
								
								native_client/beam_search.cc
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,546 @@ | ||||
| /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // This test illustrates how to make use of the CTCBeamSearchDecoder using a
 | ||||
| // custom BeamScorer and BeamState based on a dictionary with a few artificial
 | ||||
| // words.
 | ||||
| #include "tensorflow/core/util/ctc/ctc_beam_search.h" | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <cmath> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "tensorflow/core/framework/op.h" | ||||
| #include "tensorflow/core/framework/op_kernel.h" | ||||
| #include "tensorflow/core/framework/shape_inference.h" | ||||
| #include "tensorflow/core/kernels/bounds_check.h" | ||||
| 
 | ||||
| #include "kenlm/lm/model.hh" | ||||
| 
 | ||||
| #include "alphabet.h" | ||||
| #include "trie_node.h" | ||||
| 
 | ||||
| namespace tf = tensorflow; | ||||
| using tf::shape_inference::DimensionHandle; | ||||
| using tf::shape_inference::InferenceContext; | ||||
| using tf::shape_inference::ShapeHandle; | ||||
| 
 | ||||
| REGISTER_OP("CTCBeamSearchDecoderWithLM") | ||||
|     .Input("inputs: float") | ||||
|     .Input("sequence_length: int32") | ||||
|     .Attr("model_path: string") | ||||
|     .Attr("trie_path: string") | ||||
|     .Attr("alphabet_path: string") | ||||
|     .Attr("beam_width: int >= 1 = 100") | ||||
|     .Attr("top_paths: int >= 1 = 1") | ||||
|     .Attr("merge_repeated: bool = true") | ||||
|     .Output("decoded_indices: top_paths * int64") | ||||
|     .Output("decoded_values: top_paths * int64") | ||||
|     .Output("decoded_shape: top_paths * int64") | ||||
|     .Output("log_probability: float") | ||||
|     .SetShapeFn([](InferenceContext *c) { | ||||
|       ShapeHandle inputs; | ||||
|       ShapeHandle sequence_length; | ||||
| 
 | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &inputs)); | ||||
|       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); | ||||
| 
 | ||||
|       // Get batch size from inputs and sequence_length.
 | ||||
|       DimensionHandle batch_size; | ||||
|       TF_RETURN_IF_ERROR( | ||||
|           c->Merge(c->Dim(inputs, 1), c->Dim(sequence_length, 0), &batch_size)); | ||||
| 
 | ||||
|       tf::int32 top_paths; | ||||
|       TF_RETURN_IF_ERROR(c->GetAttr("top_paths", &top_paths)); | ||||
| 
 | ||||
|       // Outputs.
 | ||||
|       int out_idx = 0; | ||||
|       for (int i = 0; i < top_paths; ++i) {  // decoded_indices
 | ||||
|         c->set_output(out_idx++, c->Matrix(InferenceContext::kUnknownDim, 2)); | ||||
|       } | ||||
|       for (int i = 0; i < top_paths; ++i) {  // decoded_values
 | ||||
|         c->set_output(out_idx++, c->Vector(InferenceContext::kUnknownDim)); | ||||
|       } | ||||
|       ShapeHandle shape_v = c->Vector(2); | ||||
|       for (int i = 0; i < top_paths; ++i) {  // decoded_shape
 | ||||
|         c->set_output(out_idx++, shape_v); | ||||
|       } | ||||
|       c->set_output(out_idx++, c->Matrix(batch_size, top_paths)); | ||||
|       return tf::Status::OK(); | ||||
|     }) | ||||
|     .Doc(R"doc( | ||||
| Performs beam search decoding on the logits given in input. | ||||
| 
 | ||||
| A note about the attribute merge_repeated: For the beam search decoder, | ||||
| this means that if consecutive entries in a beam are the same, only | ||||
| the first of these is emitted.  That is, when the top path is "A B B B B", | ||||
| "A B" is returned if merge_repeated = True but "A B B B B" is | ||||
| returned if merge_repeated = False. | ||||
| 
 | ||||
| inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. | ||||
| sequence_length: A vector containing sequence lengths, size `(batch)`. | ||||
| beam_width: A scalar >= 0 (beam search beam width). | ||||
| top_paths: A scalar >= 0, <= beam_width (controls output size). | ||||
| merge_repeated: If true, merge repeated classes in output. | ||||
| decoded_indices: A list (length: top_paths) of indices matrices.  Matrix j, | ||||
|   size `(total_decoded_outputs[j] x 2)`, has indices of a | ||||
|   `SparseTensor<int64, 2>`.  The rows store: [batch, time]. | ||||
| decoded_values: A list (length: top_paths) of values vectors.  Vector j, | ||||
|   size `(length total_decoded_outputs[j])`, has the values of a | ||||
|   `SparseTensor<int64, 2>`.  The vector stores the decoded classes for beam j. | ||||
| decoded_shape: A list (length: top_paths) of shape vector.  Vector j, | ||||
|   size `(2)`, stores the shape of the decoded `SparseTensor[j]`. | ||||
|   Its values are: `[batch_size, max_decoded_length[j]]`. | ||||
| log_probability: A matrix, shaped: `(batch_size x top_paths)`.  The | ||||
|   sequence log-probabilities. | ||||
| )doc"); | ||||
| 
 | ||||
// Per-beam state threaded through the CTC beam search by KenLMBeamScorer.
struct KenLMBeamState {
  float language_model_score;  // cumulative LM log-probability for this beam
  float score;                 // combined score used for ranking expansions
  float delta_score;           // change in score from the previous expansion
  std::string incomplete_word; // characters accumulated since the last space
  // Trie position for incomplete_word; nullptr once the prefix leaves the
  // vocabulary trie.
  TrieNode *incomplete_word_trie_node;
  lm::ngram::ProbingModel::State model_state;  // KenLM context state
};
| 
 | ||||
| class KenLMBeamScorer : public tf::ctc::BaseBeamScorer<KenLMBeamState> { | ||||
|  public: | ||||
|   typedef lm::ngram::ProbingModel Model; | ||||
| 
 | ||||
|   KenLMBeamScorer(const std::string &kenlm_path, const std::string &trie_path, const std::string &alphabet_path) | ||||
|     : lm_weight(1.0f) | ||||
|     , word_count_weight(-0.1f) | ||||
|     , valid_word_count_weight(1.0f) | ||||
|   { | ||||
|     lm::ngram::Config config; | ||||
|     config.load_method = util::POPULATE_OR_READ; | ||||
|     model = new Model(kenlm_path.c_str(), config); | ||||
| 
 | ||||
|     alphabet = new Alphabet(alphabet_path.c_str()); | ||||
| 
 | ||||
|     std::ifstream in; | ||||
|     in.open(trie_path, std::ios::in); | ||||
|     TrieNode::ReadFromStream(in, trieRoot, alphabet->GetSize()); | ||||
|     in.close(); | ||||
|   } | ||||
| 
 | ||||
|   virtual ~KenLMBeamScorer() { | ||||
|     delete model; | ||||
|     delete trieRoot; | ||||
|   } | ||||
| 
 | ||||
|   // State initialization.
 | ||||
|   void InitializeState(KenLMBeamState* root) const { | ||||
|     root->language_model_score = 0.0f; | ||||
|     root->score = 0.0f; | ||||
|     root->delta_score = 0.0f; | ||||
|     root->incomplete_word.clear(); | ||||
|     root->incomplete_word_trie_node = trieRoot; | ||||
|     root->model_state = model->BeginSentenceState(); | ||||
|   } | ||||
|   // ExpandState is called when expanding a beam to one of its children.
 | ||||
|   // Called at most once per child beam. In the simplest case, no state
 | ||||
|   // expansion is done.
 | ||||
|   void ExpandState(const KenLMBeamState& from_state, int from_label, | ||||
|                          KenLMBeamState* to_state, int to_label) const { | ||||
|     CopyState(from_state, to_state); | ||||
| 
 | ||||
|     if (!alphabet->IsSpace(to_label)) { | ||||
|       to_state->incomplete_word += alphabet->StringFromLabel(to_label); | ||||
|       TrieNode *trie_node = from_state.incomplete_word_trie_node; | ||||
| 
 | ||||
|       // TODO replace with OOV unigram prob?
 | ||||
|       // If we have no valid prefix we assume a very low log probability
 | ||||
|       float min_unigram_score = -10.0f; | ||||
|       // If prefix does exist
 | ||||
|       if (trie_node != nullptr) { | ||||
|         trie_node = trie_node->GetChildAt(to_label); | ||||
|         to_state->incomplete_word_trie_node = trie_node; | ||||
| 
 | ||||
|         if (trie_node != nullptr) { | ||||
|           min_unigram_score = trie_node->GetMinUnigramScore(); | ||||
|         } | ||||
|       } | ||||
|       // TODO try two options
 | ||||
|       // 1) unigram score added up to language model scare
 | ||||
|       // 2) langugage model score of (preceding_words + unigram_word)
 | ||||
|       to_state->score = min_unigram_score + to_state->language_model_score; | ||||
|       to_state->delta_score = to_state->score - from_state.score; | ||||
|     } else { | ||||
|       float lm_score_delta = ScoreIncompleteWord(from_state.model_state, | ||||
|                             to_state->incomplete_word, | ||||
|                             to_state->model_state); | ||||
|       // Give fixed word bonus
 | ||||
|       if (!IsOOV(to_state->incomplete_word)) { | ||||
|         to_state->language_model_score += valid_word_count_weight; | ||||
|       } | ||||
|       to_state->language_model_score += word_count_weight; | ||||
|       UpdateWithLMScore(to_state, lm_score_delta); | ||||
|       ResetIncompleteWord(to_state); | ||||
|     } | ||||
|   } | ||||
|   // ExpandStateEnd is called after decoding has finished. Its purpose is to
 | ||||
|   // allow a final scoring of the beam in its current state, before resorting
 | ||||
|   // and retrieving the TopN requested candidates. Called at most once per beam.
 | ||||
|   void ExpandStateEnd(KenLMBeamState* state) const { | ||||
|     float lm_score_delta = 0.0f; | ||||
|     Model::State out; | ||||
|     if (state->incomplete_word.size() > 0) { | ||||
|       lm_score_delta += ScoreIncompleteWord(state->model_state, | ||||
|                                             state->incomplete_word, | ||||
|                                             out); | ||||
|       ResetIncompleteWord(state); | ||||
|       state->model_state = out; | ||||
|     } | ||||
|     lm_score_delta += model->FullScore(state->model_state, | ||||
|                                       model->GetVocabulary().EndSentence(), | ||||
|                                       out).prob; | ||||
|     UpdateWithLMScore(state, lm_score_delta); | ||||
|   } | ||||
|   // GetStateExpansionScore should be an inexpensive method to retrieve the
 | ||||
|   // (cached) expansion score computed within ExpandState. The score is
 | ||||
|   // multiplied (log-addition) with the input score at the current step from
 | ||||
|   // the network.
 | ||||
|   //
 | ||||
|   // The score returned should be a log-probability. In the simplest case, as
 | ||||
|   // there's no state expansion logic, the expansion score is zero.
 | ||||
|   float GetStateExpansionScore(const KenLMBeamState& state, | ||||
|                                float previous_score) const { | ||||
|     return lm_weight * state.delta_score + previous_score; | ||||
|   } | ||||
|   // GetStateEndExpansionScore should be an inexpensive method to retrieve the
 | ||||
|   // (cached) expansion score computed within ExpandStateEnd. The score is
 | ||||
|   // multiplied (log-addition) with the final probability of the beam.
 | ||||
|   //
 | ||||
|   // The score returned should be a log-probability.
 | ||||
|   float GetStateEndExpansionScore(const KenLMBeamState& state) const { | ||||
|     return lm_weight * state.delta_score; | ||||
|   } | ||||
| 
 | ||||
|   void SetLMWeight(float lm_weight) { | ||||
|     this->lm_weight = lm_weight; | ||||
|   } | ||||
| 
 | ||||
|   void SetWordCountWeight(float word_count_weight) { | ||||
|     this->word_count_weight = word_count_weight; | ||||
|   } | ||||
| 
 | ||||
|   void SetValidWordCountWeight(float valid_word_count_weight) { | ||||
|     this->valid_word_count_weight = valid_word_count_weight; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   Model *model; | ||||
|   Alphabet *alphabet; | ||||
|   TrieNode *trieRoot; | ||||
|   float lm_weight; | ||||
|   float word_count_weight; | ||||
|   float valid_word_count_weight; | ||||
| 
 | ||||
|   void UpdateWithLMScore(KenLMBeamState *state, float lm_score_delta) const { | ||||
|     float previous_score = state->score; | ||||
|     state->language_model_score += lm_score_delta; | ||||
|     state->score = state->language_model_score; | ||||
|     state->delta_score = state->language_model_score - previous_score; | ||||
|   } | ||||
| 
 | ||||
|   void ResetIncompleteWord(KenLMBeamState *state) const { | ||||
|     state->incomplete_word.clear(); | ||||
|     state->incomplete_word_trie_node = trieRoot; | ||||
|   } | ||||
| 
 | ||||
|   bool IsOOV(const std::string& word) const { | ||||
|     auto &vocabulary = model->GetVocabulary(); | ||||
|     return vocabulary.Index(word) == vocabulary.NotFound(); | ||||
|   } | ||||
| 
 | ||||
|   float ScoreIncompleteWord(const Model::State& model_state, | ||||
|                             const std::string& word, | ||||
|                             Model::State& out) const { | ||||
|     lm::FullScoreReturn full_score_return; | ||||
|     lm::WordIndex vocab = model->GetVocabulary().Index(word); | ||||
|     full_score_return = model->FullScore(model_state, vocab, out); | ||||
|     return full_score_return.prob; | ||||
|   } | ||||
| 
 | ||||
|   void CopyState(const KenLMBeamState& from, KenLMBeamState* to) const { | ||||
|     to->language_model_score = from.language_model_score; | ||||
|     to->score = from.score; | ||||
|     to->delta_score = from.delta_score; | ||||
|     to->incomplete_word = from.incomplete_word; | ||||
|     to->incomplete_word_trie_node = from.incomplete_word_trie_node; | ||||
|     to->model_state = from.model_state; | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
// Shared plumbing for the decoder op: validates the "inputs" and
// "sequence_length" tensors, allocates the log_probability and decoded_*
// outputs, and serializes decoded paths into SparseTensor components.
class CTCDecodeHelper {
 public:
  CTCDecodeHelper() : top_paths_(1) {}

  inline int GetTopPaths() const { return top_paths_; }
  void SetTopPaths(int tp) { top_paths_ = tp; }

  // Fetches and shape-checks the op inputs, then allocates all outputs.
  // NOTE(review): model_path/trie_path/alphabet_path are accepted but never
  // written here -- apparently leftovers of a tensor-input design (see the
  // commented-out code in the op's constructor); confirm before removing.
  tf::Status ValidateInputsGenerateOutputs(
      tf::OpKernelContext *ctx, const tf::Tensor **inputs, const tf::Tensor **seq_len,
      std::string *model_path, std::string *trie_path, std::string *alphabet_path,
      tf::Tensor **log_prob, tf::OpOutputList *decoded_indices,
      tf::OpOutputList *decoded_values, tf::OpOutputList *decoded_shape) const {
    tf::Status status = ctx->input("inputs", inputs);
    if (!status.ok()) return status;
    status = ctx->input("sequence_length", seq_len);
    if (!status.ok()) return status;

    const tf::TensorShape &inputs_shape = (*inputs)->shape();

    // inputs must be (max_time x batch_size x num_classes).
    if (inputs_shape.dims() != 3) {
      return tf::errors::InvalidArgument("inputs is not a 3-Tensor");
    }

    const tf::int64 max_time = inputs_shape.dim_size(0);
    const tf::int64 batch_size = inputs_shape.dim_size(1);

    if (max_time == 0) {
      return tf::errors::InvalidArgument("max_time is 0");
    }
    if (!tf::TensorShapeUtils::IsVector((*seq_len)->shape())) {
      return tf::errors::InvalidArgument("sequence_length is not a vector");
    }

    if (!(batch_size == (*seq_len)->dim_size(0))) {
      return tf::errors::FailedPrecondition(
          "len(sequence_length) != batch_size.  ", "len(sequence_length):  ",
          (*seq_len)->dim_size(0), " batch_size: ", batch_size);
    }

    auto seq_len_t = (*seq_len)->vec<tf::int32>();

    // Each sequence length must fit within the time dimension.
    for (int b = 0; b < batch_size; ++b) {
      if (!(seq_len_t(b) <= max_time)) {
        return tf::errors::FailedPrecondition("sequence_length(", b, ") <= ",
                                          max_time);
      }
    }

    tf::Status s = ctx->allocate_output(
        "log_probability", tf::TensorShape({batch_size, top_paths_}), log_prob);
    if (!s.ok()) return s;

    s = ctx->output_list("decoded_indices", decoded_indices);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_values", decoded_values);
    if (!s.ok()) return s;
    s = ctx->output_list("decoded_shape", decoded_shape);
    if (!s.ok()) return s;

    return tf::Status::OK();
  }

  // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b".
  // Packs each path p across the batch into one SparseTensor triple
  // (indices, values, shape).
  tf::Status StoreAllDecodedSequences(
      const std::vector<std::vector<std::vector<int>>> &sequences,
      tf::OpOutputList *decoded_indices, tf::OpOutputList *decoded_values,
      tf::OpOutputList *decoded_shape) const {
    // Calculate the total number of entries for each path
    const tf::int64 batch_size = sequences.size();
    std::vector<tf::int64> num_entries(top_paths_, 0);

    // Calculate num_entries per path
    for (const auto &batch_s : sequences) {
      CHECK_EQ(batch_s.size(), top_paths_);
      for (int p = 0; p < top_paths_; ++p) {
        num_entries[p] += batch_s[p].size();
      }
    }

    for (int p = 0; p < top_paths_; ++p) {
      tf::Tensor *p_indices = nullptr;
      tf::Tensor *p_values = nullptr;
      tf::Tensor *p_shape = nullptr;

      const tf::int64 p_num = num_entries[p];

      tf::Status s =
          decoded_indices->allocate(p, tf::TensorShape({p_num, 2}), &p_indices);
      if (!s.ok()) return s;
      s = decoded_values->allocate(p, tf::TensorShape({p_num}), &p_values);
      if (!s.ok()) return s;
      s = decoded_shape->allocate(p, tf::TensorShape({2}), &p_shape);
      if (!s.ok()) return s;

      auto indices_t = p_indices->matrix<tf::int64>();
      auto values_t = p_values->vec<tf::int64>();
      auto shape_t = p_shape->vec<tf::int64>();

      tf::int64 max_decoded = 0;
      tf::int64 offset = 0;

      // Concatenate batch entries; indices rows are [batch, time].
      for (tf::int64 b = 0; b < batch_size; ++b) {
        auto &p_batch = sequences[b][p];
        tf::int64 num_decoded = p_batch.size();
        max_decoded = std::max(max_decoded, num_decoded);
        std::copy_n(p_batch.begin(), num_decoded, &values_t(offset));
        for (tf::int64 t = 0; t < num_decoded; ++t, ++offset) {
          indices_t(offset, 0) = b;
          indices_t(offset, 1) = t;
        }
      }

      // Dense shape is [batch_size, longest decoded sequence for this path].
      shape_t(0) = batch_size;
      shape_t(1) = max_decoded;
    }
    return tf::Status::OK();
  }

 private:
  int top_paths_;
  TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
};
| 
 | ||||
| // CTC beam search
 | ||||
| class CTCBeamSearchDecoderOp : public tf::OpKernel { | ||||
|  public: | ||||
|   explicit CTCBeamSearchDecoderOp(tf::OpKernelConstruction *ctx) | ||||
|     : tf::OpKernel(ctx) | ||||
|     , beam_scorer_(GetModelPath(ctx), | ||||
|                    GetTriePath(ctx), | ||||
|                    GetAlphabetPath(ctx)) | ||||
|   { | ||||
|     OP_REQUIRES_OK(ctx, ctx->GetAttr("merge_repeated", &merge_repeated_)); | ||||
|     OP_REQUIRES_OK(ctx, ctx->GetAttr("beam_width", &beam_width_)); | ||||
|     int top_paths; | ||||
|     OP_REQUIRES_OK(ctx, ctx->GetAttr("top_paths", &top_paths)); | ||||
|     decode_helper_.SetTopPaths(top_paths); | ||||
| 
 | ||||
|     // const tf::Tensor* model_tensor;
 | ||||
|     // tf::Status status = ctx->input("model_path", &model_tensor);
 | ||||
|     // if (!status.ok()) return status;
 | ||||
|     // auto model_vec = model_tensor->flat<std::string>();
 | ||||
|     // *model_path = model_vec(0);
 | ||||
| 
 | ||||
|     // const tf::Tensor* trie_tensor;
 | ||||
|     // status = ctx->input("trie_path", &trie_tensor);
 | ||||
|     // if (!status.ok()) return status;
 | ||||
|     // auto trie_vec = trie_tensor->flat<std::string>();
 | ||||
|     // *trie_path = model_vec(0);
 | ||||
| 
 | ||||
|     // const tf::Tensor* alphabet_tensor;
 | ||||
|     // status = ctx->input("alphabet_path", &alphabet_tensor);
 | ||||
|     // if (!status.ok()) return status;
 | ||||
|     // auto alphabet_vec = alphabet_tensor->flat<std::string>();
 | ||||
|     // *alphabet_path = alphabet_vec(0);
 | ||||
|   } | ||||
| 
 | ||||
|   std::string GetModelPath(tf::OpKernelConstruction *ctx) { | ||||
|     std::string model_path; | ||||
|     ctx->GetAttr("model_path", &model_path); | ||||
|     return model_path; | ||||
|   } | ||||
| 
 | ||||
|   std::string GetTriePath(tf::OpKernelConstruction *ctx) { | ||||
|     std::string trie_path; | ||||
|     ctx->GetAttr("trie_path", &trie_path); | ||||
|     return trie_path; | ||||
|   } | ||||
| 
 | ||||
|   std::string GetAlphabetPath(tf::OpKernelConstruction *ctx) { | ||||
|     std::string alphabet_path; | ||||
|     ctx->GetAttr("alphabet_path", &alphabet_path); | ||||
|     return alphabet_path; | ||||
|   } | ||||
| 
 | ||||
|   void Compute(tf::OpKernelContext *ctx) override { | ||||
|     const tf::Tensor *inputs; | ||||
|     const tf::Tensor *seq_len; | ||||
|     std::string model_path; | ||||
|     std::string trie_path; | ||||
|     std::string alphabet_path; | ||||
|     tf::Tensor *log_prob = nullptr; | ||||
|     tf::OpOutputList decoded_indices; | ||||
|     tf::OpOutputList decoded_values; | ||||
|     tf::OpOutputList decoded_shape; | ||||
|     OP_REQUIRES_OK(ctx, decode_helper_.ValidateInputsGenerateOutputs( | ||||
|                             ctx, &inputs, &seq_len, &model_path, &trie_path, | ||||
|                             &alphabet_path, &log_prob, &decoded_indices, | ||||
|                             &decoded_values, &decoded_shape)); | ||||
| 
 | ||||
|     auto inputs_t = inputs->tensor<float, 3>(); | ||||
|     auto seq_len_t = seq_len->vec<tf::int32>(); | ||||
|     auto log_prob_t = log_prob->matrix<float>(); | ||||
| 
 | ||||
|     const tf::TensorShape &inputs_shape = inputs->shape(); | ||||
| 
 | ||||
|     const tf::int64 max_time = inputs_shape.dim_size(0); | ||||
|     const tf::int64 batch_size = inputs_shape.dim_size(1); | ||||
|     const tf::int64 num_classes_raw = inputs_shape.dim_size(2); | ||||
|     OP_REQUIRES( | ||||
|         ctx, tf::FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()), | ||||
|         tf::errors::InvalidArgument("num_classes cannot exceed max int")); | ||||
|     const int num_classes = static_cast<const int>(num_classes_raw); | ||||
| 
 | ||||
|     log_prob_t.setZero(); | ||||
| 
 | ||||
|     std::vector<tf::TTypes<float>::UnalignedConstMatrix> input_list_t; | ||||
| 
 | ||||
|     for (std::size_t t = 0; t < max_time; ++t) { | ||||
|       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes, | ||||
|                                 batch_size, num_classes); | ||||
|     } | ||||
| 
 | ||||
|     tf::ctc::CTCBeamSearchDecoder<KenLMBeamState> beam_search(num_classes, beam_width_, | ||||
|                                             &beam_scorer_, 1 /* batch_size */, | ||||
|                                             merge_repeated_); | ||||
|     tf::Tensor input_chip(tf::DT_FLOAT, tf::TensorShape({num_classes})); | ||||
|     auto input_chip_t = input_chip.flat<float>(); | ||||
| 
 | ||||
|     std::vector<std::vector<std::vector<int>>> best_paths(batch_size); | ||||
|     std::vector<float> log_probs; | ||||
| 
 | ||||
|     // Assumption: the blank index is num_classes - 1
 | ||||
|     for (int b = 0; b < batch_size; ++b) { | ||||
|       auto &best_paths_b = best_paths[b]; | ||||
|       best_paths_b.resize(decode_helper_.GetTopPaths()); | ||||
|       for (int t = 0; t < seq_len_t(b); ++t) { | ||||
|         input_chip_t = input_list_t[t].chip(b, 0); | ||||
|         auto input_bi = | ||||
|             Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes); | ||||
|         beam_search.Step(input_bi); | ||||
|       } | ||||
|       OP_REQUIRES_OK( | ||||
|           ctx, beam_search.TopPaths(decode_helper_.GetTopPaths(), &best_paths_b, | ||||
|                                     &log_probs, merge_repeated_)); | ||||
| 
 | ||||
|       beam_search.Reset(); | ||||
| 
 | ||||
|       for (int bp = 0; bp < decode_helper_.GetTopPaths(); ++bp) { | ||||
|         log_prob_t(b, bp) = log_probs[bp]; | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     OP_REQUIRES_OK(ctx, decode_helper_.StoreAllDecodedSequences( | ||||
|                             best_paths, &decoded_indices, &decoded_values, | ||||
|                             &decoded_shape)); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   CTCDecodeHelper decode_helper_; | ||||
|   KenLMBeamScorer beam_scorer_; | ||||
|   bool merge_repeated_; | ||||
|   int beam_width_; | ||||
|   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp); | ||||
| }; | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoderWithLM").Device(tf::DEVICE_CPU), | ||||
|                         CTCBeamSearchDecoderOp); | ||||
							
								
								
									
										66
									
								
								native_client/generate_trie.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								native_client/generate_trie.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,66 @@ | ||||
| #include <algorithm> | ||||
| #include <iostream> | ||||
| #include <string> | ||||
| using namespace std; | ||||
| 
 | ||||
| #include "lm/model.hh" | ||||
| #include "trie_node.h" | ||||
| #include "alphabet.h" | ||||
| 
 | ||||
| typedef lm::ngram::ProbingModel Model; | ||||
| 
 | ||||
| lm::WordIndex GetWordIndex(const Model& model, const std::string& word) { | ||||
|   lm::WordIndex vocab; | ||||
|   vocab = model.GetVocabulary().Index(word); | ||||
|   return vocab; | ||||
| } | ||||
| 
 | ||||
| float ScoreWord(const Model& model, lm::WordIndex vocab) { | ||||
|   Model::State in_state = model.NullContextState(); | ||||
|   Model::State out; | ||||
|   lm::FullScoreReturn full_score_return; | ||||
|   full_score_return = model.FullScore(in_state, vocab, out); | ||||
|   return full_score_return.prob; | ||||
| } | ||||
| 
 | ||||
| int generate_trie(const char* alphabet_path, const char* kenlm_path, const char* vocab_path, const char* trie_path) { | ||||
|   Alphabet a(alphabet_path); | ||||
| 
 | ||||
|   lm::ngram::Config config; | ||||
|   config.load_method = util::POPULATE_OR_READ; | ||||
|   Model model(kenlm_path, config); | ||||
|   TrieNode root(a.GetSize()); | ||||
| 
 | ||||
|   std::ifstream ifs; | ||||
|   ifs.open(vocab_path, std::ifstream::in); | ||||
| 
 | ||||
|   if (!ifs.is_open()) { | ||||
|     std::cout << "unable to open vocabulary" << std::endl; | ||||
|     return -1; | ||||
|   } | ||||
| 
 | ||||
|   std::ofstream ofs; | ||||
|   ofs.open(trie_path); | ||||
| 
 | ||||
|   std::string word; | ||||
|   while (ifs >> word) { | ||||
|     for_each(word.begin(), word.end(), [](char& a) { a = tolower(a); }); | ||||
|     lm::WordIndex vocab = GetWordIndex(model, word); | ||||
|     float unigram_score = ScoreWord(model, vocab); | ||||
|     root.Insert(word.c_str(), [&a](char c) { | ||||
|                   return a.LabelFromString(string(1, c)); | ||||
|                 }, vocab, unigram_score); | ||||
|   } | ||||
| 
 | ||||
|   root.WriteToStream(ofs); | ||||
|   ifs.close(); | ||||
|   ofs.close(); | ||||
|   return 0; | ||||
| } | ||||
| 
 | ||||
| int main(void) { | ||||
|   return generate_trie("/Users/remorais/Development/DeepSpeech/data/alphabet.txt", | ||||
|                        "/Users/remorais/Development/DeepSpeech/data/lm/lm.binary", | ||||
|                        "/Users/remorais/Development/DeepSpeech/data/lm/vocab.txt", | ||||
|                        "/Users/remorais/Development/DeepSpeech/data/lm/trie"); | ||||
| } | ||||
							
								
								
									
										125
									
								
								native_client/trie_node.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								native_client/trie_node.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,125 @@ | ||||
| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef TRIE_NODE_H | ||||
| #define TRIE_NODE_H | ||||
| 
 | ||||
| #include "lm/model.hh" | ||||
| 
 | ||||
| #include <functional> | ||||
| #include <istream> | ||||
| #include <iostream> | ||||
| #include <limits> | ||||
| 
 | ||||
| class TrieNode { | ||||
| public: | ||||
|   TrieNode(int vocab_size) | ||||
|     : vocab_size(vocab_size) | ||||
|     , prefixCount(0) | ||||
|     , min_score_word(0) | ||||
|     , min_unigram_score(std::numeric_limits<float>::max()) | ||||
|     { | ||||
|       children = new TrieNode*[vocab_size](); | ||||
|     } | ||||
| 
 | ||||
|   ~TrieNode() { | ||||
|     for (int i = 0; i < vocab_size; i++) { | ||||
|       delete children[i]; | ||||
|     } | ||||
|     delete children; | ||||
|   } | ||||
| 
 | ||||
|   void WriteToStream(std::ostream& os) { | ||||
|     WriteNode(os); | ||||
|     for (int i = 0; i < vocab_size; i++) { | ||||
|       if (children[i] == nullptr) { | ||||
|         os << -1 << std::endl; | ||||
|       } else { | ||||
|         // Recursive call
 | ||||
|         children[i]->WriteToStream(os); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   static void ReadFromStream(std::istream& is, TrieNode* &obj, int vocab_size) { | ||||
|     int prefixCount; | ||||
|     is >> prefixCount; | ||||
| 
 | ||||
|     if (prefixCount == -1) { | ||||
|       // This is an undefined child
 | ||||
|       obj = nullptr; | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     obj = new TrieNode(vocab_size); | ||||
|     obj->ReadNode(is, prefixCount); | ||||
|     for (int i = 0; i < vocab_size; i++) { | ||||
|       // Recursive call
 | ||||
|       ReadFromStream(is, obj->children[i], vocab_size); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   void Insert(const char* word, std::function<int (char)> translator, | ||||
|               lm::WordIndex lm_word, float unigram_score) { | ||||
|     char wordCharacter = *word; | ||||
|     prefixCount++; | ||||
|     if (unigram_score < min_unigram_score) { | ||||
|       min_unigram_score = unigram_score; | ||||
|       min_score_word = lm_word; | ||||
|     } | ||||
|     if (wordCharacter != '\0') { | ||||
|       int vocabIndex = translator(wordCharacter); | ||||
|       TrieNode *child = children[vocabIndex]; | ||||
|       if (child == nullptr) | ||||
|         child = children[vocabIndex] = new TrieNode(vocab_size); | ||||
|       child->Insert(word + 1, translator, lm_word, unigram_score); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   int GetFrequency() { | ||||
|     return prefixCount; | ||||
|   } | ||||
| 
 | ||||
|   lm::WordIndex GetMinScoreWordIndex() { | ||||
|     return min_score_word; | ||||
|   } | ||||
| 
 | ||||
|   float GetMinUnigramScore() { | ||||
|     return min_unigram_score; | ||||
|   } | ||||
| 
 | ||||
|   TrieNode *GetChildAt(int vocabIndex) { | ||||
|     return children[vocabIndex]; | ||||
|   } | ||||
| 
 | ||||
| private: | ||||
|   int vocab_size; | ||||
|   int prefixCount; | ||||
|   lm::WordIndex min_score_word; | ||||
|   float min_unigram_score; | ||||
|   TrieNode **children; | ||||
| 
 | ||||
|   void WriteNode(std::ostream& os) const { | ||||
|     os << prefixCount << std::endl; | ||||
|     os << min_score_word << std::endl; | ||||
|     os << min_unigram_score << std::endl; | ||||
|   } | ||||
| 
 | ||||
|   void ReadNode(std::istream& is, int first_input) { | ||||
|     prefixCount = first_input; | ||||
|     is >> min_score_word; | ||||
|     is >> min_unigram_score; | ||||
|   } | ||||
| 
 | ||||
| }; | ||||
| 
 | ||||
| #endif //TRIE_NODE_H
 | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user