diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc
index 8cadeac68d7..2d187fb54a6 100644
--- a/tensorflow/core/kernels/ctc_decoder_ops.cc
+++ b/tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
@@ -33,11 +34,11 @@ namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
-inline float RowMax(const TTypes<float>::UnalignedConstMatrix& m, int r,
-                    int* c) {
+template <typename T>
+inline T RowMax(const typename TTypes<T>::UnalignedConstMatrix& m, int r,
+                int* c) {
   *c = 0;
   CHECK_LT(0, m.dimension(1));
-  float p = m(r, 0);
+  auto p = m(r, 0);
   for (int i = 1; i < m.dimension(1); ++i) {
     if (m(r, i) > p) {
       p = m(r, i);
@@ -170,6 +171,7 @@ class CTCDecodeHelper {
   TF_DISALLOW_COPY_AND_ASSIGN(CTCDecodeHelper);
 };
 
+template <typename T>
 class CTCGreedyDecoderOp : public OpKernel {
  public:
   explicit CTCGreedyDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -189,7 +191,7 @@ class CTCGreedyDecoderOp : public OpKernel {
 
     const TensorShape& inputs_shape = inputs->shape();
 
-    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
+    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
     const int64 max_time = inputs_shape.dim_size(0);
     const int64 batch_size = inputs_shape.dim_size(1);
     const int64 num_classes_raw = inputs_shape.dim_size(2);
@@ -198,14 +200,14 @@ class CTCGreedyDecoderOp : public OpKernel {
                 errors::InvalidArgument("num_classes cannot exceed max int"));
     const int num_classes = static_cast<const int>(num_classes_raw);
 
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
     for (std::size_t t = 0; t < max_time; ++t) {
       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                 batch_size, num_classes);
     }
     auto seq_len_t = seq_len->vec<int32>();
-    auto log_prob_t = log_prob->matrix<float>();
+    auto log_prob_t = log_prob->matrix<T>();
 
     log_prob_t.setZero();
 
@@ -221,7 +223,7 @@ class CTCGreedyDecoderOp : public OpKernel {
         int prev_indices = -1;
         for (int t = 0; t < seq_len_t(b); ++t) {
           int max_class_indices;
-          log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
+          log_prob_t(b, 0) += -RowMax<T>(input_list_t[t], b, &max_class_indices);
           if (max_class_indices != blank_index &&
               !(merge_repeated_ && max_class_indices == prev_indices)) {
             sequence.push_back(max_class_indices);
@@ -250,10 +252,18 @@ class CTCGreedyDecoderOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(CTCGreedyDecoderOp);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCGreedyDecoder").Device(DEVICE_CPU),
-                        CTCGreedyDecoderOp);
+#define REGISTER_CPU(T)                                                   \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("CTCGreedyDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCGreedyDecoderOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 // CTC beam search
+template <typename T>
 class CTCBeamSearchDecoderOp : public OpKernel {
  public:
   explicit CTCBeamSearchDecoderOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -275,9 +285,9 @@ class CTCBeamSearchDecoderOp : public OpKernel {
                             ctx, &inputs, &seq_len, &log_prob, &decoded_indices,
                             &decoded_values, &decoded_shape));
 
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
     auto seq_len_t = seq_len->vec<int32>();
-    auto log_prob_t = log_prob->matrix<float>();
+    auto log_prob_t = log_prob->matrix<T>();
 
     const TensorShape& inputs_shape = inputs->shape();
 
@@ -291,21 +301,21 @@ class CTCBeamSearchDecoderOp : public OpKernel {
 
     log_prob_t.setZero();
 
-    std::vector<TTypes<float>::UnalignedConstMatrix> input_list_t;
+    std::vector<typename TTypes<T>::UnalignedConstMatrix> input_list_t;
     for (std::size_t t = 0; t < max_time; ++t) {
       input_list_t.emplace_back(inputs_t.data() + t * batch_size * num_classes,
                                 batch_size, num_classes);
     }
 
-    ctc::CTCBeamSearchDecoder<> beam_search(num_classes, beam_width_,
-                                            &beam_scorer_, 1 /* batch_size */,
-                                            merge_repeated_);
-    Tensor input_chip(DT_FLOAT, TensorShape({num_classes}));
-    auto input_chip_t = input_chip.flat<float>();
+    ctc::CTCBeamSearchDecoder<T> beam_search(num_classes, beam_width_,
+                                             &beam_scorer_, 1 /* batch_size */,
+                                             merge_repeated_);
+    Tensor input_chip(DataTypeToEnum<T>::v(), TensorShape({num_classes}));
+    auto input_chip_t = input_chip.flat<T>();
 
     std::vector<std::vector<std::vector<int> > > best_paths(batch_size);
-    std::vector<float> log_probs;
+    std::vector<T> log_probs;
 
     // Assumption: the blank index is num_classes - 1
     for (int b = 0; b < batch_size; ++b) {
@@ -313,8 +323,8 @@ class CTCBeamSearchDecoderOp : public OpKernel {
       best_paths_b.resize(decode_helper_.GetTopPaths());
       for (int t = 0; t < seq_len_t(b); ++t) {
         input_chip_t = input_list_t[t].chip(b, 0);
-        auto input_bi =
-            Eigen::Map<const Eigen::ArrayXf>(input_chip_t.data(), num_classes);
+        auto input_bi = Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>(
+            input_chip_t.data(), num_classes);
         beam_search.Step(input_bi);
       }
       OP_REQUIRES_OK(
@@ -335,13 +345,20 @@ class CTCBeamSearchDecoderOp : public OpKernel {
 
  private:
   CTCDecodeHelper decode_helper_;
-  ctc::CTCBeamSearchDecoder<>::DefaultBeamScorer beam_scorer_;
+  typename ctc::CTCBeamSearchDecoder<T>::DefaultBeamScorer beam_scorer_;
   bool merge_repeated_;
   int beam_width_;
-  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoderOp<T>);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCBeamSearchDecoder").Device(DEVICE_CPU),
-                        CTCBeamSearchDecoderOp);
+#define REGISTER_CPU(T)                                                       \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("CTCBeamSearchDecoder").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCBeamSearchDecoderOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
 
 }  // end namespace tensorflow
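Editor's note on the allocation change above: `DT_FLOAT` is a fixed runtime constant, so the templated kernel derives the enum from `T` instead via `DataTypeToEnum`. A minimal sketch of the idiom, assuming only TensorFlow's framework headers (values illustrative):

    // DataTypeToEnum<T>::v() maps a C++ scalar type to the matching DataType
    // enum at compile time, so the scratch tensor's dtype tracks the kernel's
    // template parameter instead of being hard-wired to DT_FLOAT.
    const int num_classes = 6;  // illustrative
    Tensor chip_f(DataTypeToEnum<float>::v(), TensorShape({num_classes}));   // DT_FLOAT
    Tensor chip_d(DataTypeToEnum<double>::v(), TensorShape({num_classes}));  // DT_DOUBLE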
diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc
index aa68e105add..96a449da954 100644
--- a/tensorflow/core/kernels/ctc_loss_op.cc
+++ b/tensorflow/core/kernels/ctc_loss_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
@@ -26,14 +27,14 @@ limitations under the License.
 
 namespace tensorflow {
 
-typedef Eigen::ThreadPoolDevice CPUDevice;
-
+template <typename T>
 class CTCLossOp : public OpKernel {
-  typedef Eigen::Map<
-      const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> >
+  typedef Eigen::Map<
+      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> >
       InputMap;
   typedef Eigen::Map<
-      Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> >
+      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> >
       OutputMap;
 
  public:
@@ -110,7 +112,7 @@ class CTCLossOp : public OpKernel {
                 errors::InvalidArgument("label SparseTensor is not valid: ",
                                         labels_sp_valid.error_message()));
 
-    ctc::CTCLossCalculator::LabelSequences labels_t(batch_size);
+    typename ctc::CTCLossCalculator<T>::LabelSequences labels_t(batch_size);
     for (const auto& g : labels_sp.group({0})) {  // iterate by batch
       const int64 batch_indices = g.group()[0];
       OP_REQUIRES(ctx, FastBoundsCheck(batch_indices, batch_size),
@@ -137,13 +139,13 @@ class CTCLossOp : public OpKernel {
 
     Tensor* loss = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
-    auto loss_t = loss->vec<float>();
+    auto loss_t = loss->vec<T>();
 
     Tensor* gradient;
     OP_REQUIRES_OK(ctx,
                    ctx->allocate_output("gradient", inputs_shape, &gradient));
-    auto gradient_t = gradient->tensor<float, 3>();
-    auto inputs_t = inputs->tensor<float, 3>();
+    auto gradient_t = gradient->tensor<T, 3>();
+    auto inputs_t = inputs->tensor<T, 3>();
     std::vector<OutputMap> gradient_list_t;
     std::vector<InputMap> input_list_t;
@@ -158,7 +160,7 @@ class CTCLossOp : public OpKernel {
     gradient_t.setZero();
 
     // Assumption: the blank index is num_classes - 1
-    ctc::CTCLossCalculator ctc_loss_calculator(num_classes - 1, 0);
+    ctc::CTCLossCalculator<T> ctc_loss_calculator(num_classes - 1, 0);
     DeviceBase::CpuWorkerThreads workers =
         *ctx->device()->tensorflow_cpu_worker_threads();
     OP_REQUIRES_OK(ctx, ctc_loss_calculator.CalculateLoss(
@@ -173,9 +175,17 @@ class CTCLossOp : public OpKernel {
   bool ctc_merge_repeated_;
   bool ignore_longer_outputs_than_inputs_;
 
-  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp);
+  TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOp<T>);
 };
 
-REGISTER_KERNEL_BUILDER(Name("CTCLoss").Device(DEVICE_CPU), CTCLossOp);
+#define REGISTER_CPU(T)                                          \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      CTCLossOp<T>);
+
+REGISTER_CPU(float);
+REGISTER_CPU(double);
+
+#undef REGISTER_CPU
+
 }  // end namespace tensorflow
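Editor's note: the registration macro above is the standard TensorFlow pattern for stamping out one kernel per scalar type. For reference, a sketch of what `REGISTER_CPU(double)` expands to (illustrative only; the `TypeConstraint<double>("T")` must match the `T` attr added to the op registration below):

    REGISTER_KERNEL_BUILDER(
        Name("CTCLoss").Device(DEVICE_CPU).TypeConstraint<double>("T"),
        CTCLossOp<double>);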
REGISTER_OP("CTCLoss") - .Input("inputs: float") + .Input("inputs: T") .Input("labels_indices: int64") .Input("labels_values: int32") .Input("sequence_length: int32") .Attr("preprocess_collapse_repeated: bool = false") .Attr("ctc_merge_repeated: bool = true") .Attr("ignore_longer_outputs_than_inputs: bool = false") - .Output("loss: float") - .Output("gradient: float") + .Output("loss: T") + .Output("gradient: T") + .Attr("T: {float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle inputs; ShapeHandle labels_indices; @@ -62,13 +63,14 @@ REGISTER_OP("CTCLoss") }); REGISTER_OP("CTCGreedyDecoder") - .Input("inputs: float") + .Input("inputs: T") .Input("sequence_length: int32") .Attr("merge_repeated: bool = false") .Output("decoded_indices: int64") .Output("decoded_values: int64") .Output("decoded_shape: int64") - .Output("log_probability: float") + .Output("log_probability: T") + .Attr("T: {float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle inputs; ShapeHandle sequence_length; @@ -90,7 +92,7 @@ REGISTER_OP("CTCGreedyDecoder") }); REGISTER_OP("CTCBeamSearchDecoder") - .Input("inputs: float") + .Input("inputs: T") .Input("sequence_length: int32") .Attr("beam_width: int >= 1") .Attr("top_paths: int >= 1") @@ -98,7 +100,8 @@ REGISTER_OP("CTCBeamSearchDecoder") .Output("decoded_indices: top_paths * int64") .Output("decoded_values: top_paths * int64") .Output("decoded_shape: top_paths * int64") - .Output("log_probability: float") + .Output("log_probability: T") + .Attr("T: {float, double}") .SetShapeFn([](InferenceContext* c) { ShapeHandle inputs; ShapeHandle sequence_length; diff --git a/tensorflow/core/util/ctc/BUILD b/tensorflow/core/util/ctc/BUILD index aa00a210f79..45c6618b80e 100644 --- a/tensorflow/core/util/ctc/BUILD +++ b/tensorflow/core/util/ctc/BUILD @@ -59,7 +59,8 @@ tf_cc_tests( name = "ctc_beam_search_test", size = "small", srcs = [ - "ctc_beam_search_test.cc", + "ctc_beam_search_test_float.cc", + "ctc_beam_search_test_double.cc", ], deps = [ ":ctc_beam_search_lib", diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index 7382b8e6849..1da7fe78bfb 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -41,30 +41,33 @@ namespace ctc_beam_search { struct EmptyBeamState {}; +template struct BeamProbability { - BeamProbability() : total(kLogZero), blank(kLogZero), label(kLogZero) {} + BeamProbability() : total(kLogZero::val), + blank(kLogZero::val), + label(kLogZero::val) {} void Reset() { - total = kLogZero; - blank = kLogZero; - label = kLogZero; + total = kLogZero::val; + blank = kLogZero::val; + label = kLogZero::val; } - float total; - float blank; - float label; + T total; + T blank; + T label; }; -template +template class BeamRoot; -template +template struct BeamEntry { // BeamRoot::AddEntry() serves as the factory method. - friend BeamEntry* BeamRoot::AddEntry( - BeamEntry* p, int l); - inline bool Active() const { return newp.total != kLogZero; } + friend BeamEntry* + BeamRoot::AddEntry(BeamEntry* p, int l); + inline bool Active() const { return newp.total != kLogZero::val; } // Return the child at the given index, or construct a new one in-place if // none was found. - BeamEntry& GetChild(int ind) { + BeamEntry& GetChild(int ind) { auto entry = children.emplace(ind, nullptr); auto& child_entry = entry.first->second; // If this is a new child, populate the BeamEntry*. 
@@ -76,7 +79,7 @@ struct BeamEntry { std::vector LabelSeq(bool merge_repeated) const { std::vector labels; int prev_label = -1; - const BeamEntry* c = this; + const BeamEntry* c = this; while (c->parent != nullptr) { // Checking c->parent to skip root leaf. if (!merge_repeated || c->label != prev_label) { labels.push_back(c->label); @@ -88,12 +91,12 @@ struct BeamEntry { return labels; } - BeamEntry* parent; + BeamEntry* parent; int label; // All instances of child BeamEntry are owned by *beam_root. - gtl::FlatMap*> children; - BeamProbability oldp; - BeamProbability newp; + gtl::FlatMap*> children; + BeamProbability oldp; + BeamProbability newp; CTCBeamState state; private: @@ -102,40 +105,40 @@ struct BeamEntry { // otherwise parent will become invalid. // This private constructor is only called through the factory method // BeamRoot::AddEntry(). - BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) + BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) : parent(p), label(l), beam_root(beam_root) {} - BeamRoot* beam_root; + BeamRoot* beam_root; TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry); }; // This class owns all instances of BeamEntry. This is used to avoid recursive // destructor call during destruction. -template +template class BeamRoot { public: - BeamRoot(BeamEntry* p, int l) { root_entry_ = AddEntry(p, l); } + BeamRoot(BeamEntry* p, int l) { root_entry_ = AddEntry(p, l); } BeamRoot(const BeamRoot&) = delete; BeamRoot& operator=(const BeamRoot&) = delete; - BeamEntry* AddEntry(BeamEntry* p, int l) { - auto* new_entry = new BeamEntry(p, l, this); + BeamEntry* AddEntry(BeamEntry* p, int l) { + auto* new_entry = new BeamEntry(p, l, this); beam_entries_.emplace_back(new_entry); return new_entry; } - BeamEntry* RootEntry() const { return root_entry_; } + BeamEntry* RootEntry() const { return root_entry_; } private: - BeamEntry* root_entry_ = nullptr; - std::vector>> beam_entries_; + BeamEntry* root_entry_ = nullptr; + std::vector>> beam_entries_; }; // BeamComparer is the default beam comparer provided in CTCBeamSearch. -template +template class BeamComparer { public: virtual ~BeamComparer() {} - virtual bool inline operator()(const BeamEntry* a, - const BeamEntry* b) const { + virtual bool inline operator()(const BeamEntry* a, + const BeamEntry* b) const { return a->newp.total > b->newp.total; } }; diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h index fc63dfb0fd2..1e50f667e88 100644 --- a/tensorflow/core/util/ctc/ctc_beam_scorer.h +++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h @@ -34,7 +34,7 @@ namespace ctc { // be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex // scoring is required. Its main purpose is to provide a thin layer for // integrating language model scoring easily. -template +template class BaseBeamScorer { public: virtual ~BaseBeamScorer() {} @@ -56,8 +56,8 @@ class BaseBeamScorer { // // The score returned should be a log-probability. In the simplest case, as // there's no state expansion logic, the expansion score is zero. - virtual float GetStateExpansionScore(const CTCBeamState& state, - float previous_score) const { + virtual T GetStateExpansionScore(const CTCBeamState& state, + T previous_score) const { return previous_score; } // GetStateEndExpansionScore should be an inexpensive method to retrieve the @@ -65,8 +65,8 @@ class BaseBeamScorer { // multiplied (log-addition) with the final probability of the beam. // // The score returned should be a log-probability. 
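Editor's note: the scalar type `T` now threads through the whole beam tree. A hedged usage sketch of the templated entry/root pair, assuming the `EmptyBeamState` default (`BeamRoot` owns every `BeamEntry` it creates, which is what avoids the recursive destructor mentioned in the comment above):

    using tensorflow::ctc::ctc_beam_search::BeamEntry;
    using tensorflow::ctc::ctc_beam_search::BeamRoot;
    using tensorflow::ctc::ctc_beam_search::EmptyBeamState;

    BeamRoot<double, EmptyBeamState> root(nullptr, -1);  // -1 is the sentinel root label
    BeamEntry<double, EmptyBeamState>* e =
        root.AddEntry(root.RootEntry(), /*label=*/3);
    e->newp.total = 0.0;  // log(1); the probability fields are now double, not float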
diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h
index fc63dfb0fd2..1e50f667e88 100644
--- a/tensorflow/core/util/ctc/ctc_beam_scorer.h
+++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h
@@ -34,7 +34,7 @@ namespace ctc {
 // be subclassed and provided as an argument to CTCBeamSearchDecoder, if complex
 // scoring is required. Its main purpose is to provide a thin layer for
 // integrating language model scoring easily.
-template <typename CTCBeamState>
+template <typename T, typename CTCBeamState>
 class BaseBeamScorer {
  public:
   virtual ~BaseBeamScorer() {}
@@ -56,8 +56,8 @@ class BaseBeamScorer {
   //
   // The score returned should be a log-probability. In the simplest case, as
   // there's no state expansion logic, the expansion score is zero.
-  virtual float GetStateExpansionScore(const CTCBeamState& state,
-                                       float previous_score) const {
+  virtual T GetStateExpansionScore(const CTCBeamState& state,
+                                   T previous_score) const {
     return previous_score;
   }
   // GetStateEndExpansionScore should be an inexpensive method to retrieve the
   // (cached) expansion score computed within ExpandStateEnd. The score is
   // multiplied (log-addition) with the final probability of the beam.
   //
   // The score returned should be a log-probability.
@@ -65,8 +65,8 @@ class BaseBeamScorer {
-  virtual float GetStateEndExpansionScore(const CTCBeamState& state) const {
-    return 0;
+  virtual T GetStateEndExpansionScore(const CTCBeamState& state) const {
+    return T(0);
   }
 };
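Editor's note: downstream scorers must now name the scalar type as well as the state type. A minimal sketch of a custom scorer against the templated interface (`NullBeamState` is a hypothetical placeholder; the overrides mirror the virtuals declared above):

    struct NullBeamState {};

    template <typename T>
    class PassThroughScorer
        : public tensorflow::ctc::BaseBeamScorer<T, NullBeamState> {
     public:
      void InitializeState(NullBeamState* root) const override {}
      void ExpandState(const NullBeamState& from_state, int from_label,
                       NullBeamState* to_state, int to_label) const override {}
      void ExpandStateEnd(NullBeamState* state) const override {}
      T GetStateExpansionScore(const NullBeamState& state,
                               T previous_score) const override {
        return previous_score;  // no language model: expansion adds log(1) = 0
      }
      T GetStateEndExpansionScore(const NullBeamState& state) const override {
        return T(0);
      }
    };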
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index f2022d486c7..5e6038aba1c 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -38,10 +38,11 @@ limitations under the License.
 namespace tensorflow {
 namespace ctc {
 
-template <class CTCBeamState = ctc_beam_search::EmptyBeamState,
-          class CTCBeamComparer = ctc_beam_search::BeamComparer<CTCBeamState>>
-class CTCBeamSearchDecoder : public CTCDecoder {
+template <class T, class CTCBeamState = ctc_beam_search::EmptyBeamState,
+          class CTCBeamComparer =
+              ctc_beam_search::BeamComparer<T, CTCBeamState>>
+class CTCBeamSearchDecoder : public CTCDecoder<T> {
   // Beam Search
   //
   // Example (GravesTh Fig. 7.5):
@@ -73,12 +74,12 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   // starts at 0).  This special case can be calculated as:
   //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
   // but we calculate it recursively for speed purposes.
-  typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
-  typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
-  typedef ctc_beam_search::BeamProbability BeamProbability;
+  typedef ctc_beam_search::BeamEntry<T, CTCBeamState> BeamEntry;
+  typedef ctc_beam_search::BeamRoot<T, CTCBeamState> BeamRoot;
+  typedef ctc_beam_search::BeamProbability<T> BeamProbability;
 
  public:
-  typedef BaseBeamScorer<CTCBeamState> DefaultBeamScorer;
+  typedef BaseBeamScorer<T, CTCBeamState> DefaultBeamScorer;
 
   // The beam search decoder is constructed specifying the beam_width (number of
   // candidates to keep at each decoding timestep) and a beam scorer (used for
@@ -87,9 +88,10 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   // implementation, CTCBeamSearchDecoder<>::DefaultBeamScorer, generates the
   // standard beam search.
   CTCBeamSearchDecoder(int num_classes, int beam_width,
-                       BaseBeamScorer<CTCBeamState>* scorer, int batch_size = 1,
+                       BaseBeamScorer<T, CTCBeamState>* scorer,
+                       int batch_size = 1,
                        bool merge_repeated = false)
-      : CTCDecoder(num_classes, batch_size, merge_repeated),
+      : CTCDecoder<T>(num_classes, batch_size, merge_repeated),
         beam_width_(beam_width),
         leaves_(beam_width),
         beam_scorer_(CHECK_NOTNULL(scorer)) {
@@ -99,27 +101,27 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   ~CTCBeamSearchDecoder() override {}
 
   // Run the hibernating beam search algorithm on the given input.
-  Status Decode(const CTCDecoder::SequenceLength& seq_len,
-                const std::vector<CTCDecoder::Input>& input,
-                std::vector<CTCDecoder::Output>* output,
-                CTCDecoder::ScoreOutput* scores) override;
+  Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
+                const std::vector<typename CTCDecoder<T>::Input>& input,
+                std::vector<typename CTCDecoder<T>::Output>* output,
+                typename CTCDecoder<T>::ScoreOutput* scores) override;
 
   // Calculate the next step of the beam search and update the internal state.
   template <typename Vector>
   void Step(const Vector& log_input_t);
 
   template <typename Vector>
-  float GetTopK(const int K, const Vector& input,
-                std::vector<float>* top_k_logits,
-                std::vector<int>* top_k_indices);
+  T GetTopK(const int K, const Vector& input,
+            std::vector<T>* top_k_logits,
+            std::vector<int>* top_k_indices);
 
   // Retrieve the beam scorer instance used during decoding.
-  BaseBeamScorer<CTCBeamState>* GetBeamScorer() const { return beam_scorer_; }
+  BaseBeamScorer<T, CTCBeamState>* GetBeamScorer() const {
+    return beam_scorer_;
+  }
 
   // Set label selection parameters for faster decoding.
   // See comments for label_selection_size_ and label_selection_margin_.
   void SetLabelSelectionParameters(int label_selection_size,
-                                   float label_selection_margin) {
+                                   T label_selection_margin) {
     label_selection_size_ = label_selection_size;
     label_selection_margin_ = label_selection_margin;
   }
@@ -129,7 +131,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
 
   // Extract the top n paths at current time step
   Status TopPaths(int n, std::vector<std::vector<int>>* paths,
-                  std::vector<float>* log_probs, bool merge_repeated) const;
+                  std::vector<T>* log_probs, bool merge_repeated) const;
 
 private:
   int beam_width_;
@@ -145,37 +147,38 @@ class CTCBeamSearchDecoder : public CTCDecoder {
   // Default is to do no label selection.
   // For more detail: https://research.google.com/pubs/pub44823.html
   int label_selection_size_ = 0;       // zero means unlimited
-  float label_selection_margin_ = -1;  // -1 means unlimited.
+  T label_selection_margin_ = -1;      // -1 means unlimited.
 
   gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
   std::unique_ptr<BeamRoot> beam_root_;
-  BaseBeamScorer<CTCBeamState>* beam_scorer_;
+  BaseBeamScorer<T, CTCBeamState>* beam_scorer_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
 };
 
-template <class CTCBeamState, class CTCBeamComparer>
-Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Decode(
-    const CTCDecoder::SequenceLength& seq_len,
-    const std::vector<CTCDecoder::Input>& input,
-    std::vector<CTCDecoder::Output>* output, ScoreOutput* scores) {
+template <class T, class CTCBeamState, class CTCBeamComparer>
+Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Decode(
+    const typename CTCDecoder<T>::SequenceLength& seq_len,
+    const std::vector<typename CTCDecoder<T>::Input>& input,
+    std::vector<typename CTCDecoder<T>::Output>* output,
+    typename CTCDecoder<T>::ScoreOutput* scores) {
   // Storage for top paths.
   std::vector<std::vector<int>> beams;
-  std::vector<float> beam_log_probabilities;
+  std::vector<T> beam_log_probabilities;
   int top_n = output->size();
   if (std::any_of(output->begin(), output->end(),
-                  [this](const CTCDecoder::Output& output) -> bool {
+                  [this](const typename CTCDecoder<T>::Output& output) -> bool {
                     return output.size() < this->batch_size_;
                   })) {
     return errors::InvalidArgument(
         "output needs to be of size at least (top_n, batch_size).");
   }
-  if (scores->rows() < batch_size_ || scores->cols() < top_n) {
+  if (scores->rows() < this->batch_size_ || scores->cols() < top_n) {
     return errors::InvalidArgument(
         "scores needs to be of size at least (batch_size, top_n).");
   }
 
-  for (int b = 0; b < batch_size_; ++b) {
+  for (int b = 0; b < this->batch_size_; ++b) {
     int seq_len_b = seq_len[b];
     Reset();
 
@@ -196,7 +199,7 @@ Status CTCBeamSearchDecoder::Decode(
     }
 
     Status status =
-        TopPaths(top_n, &beams, &beam_log_probabilities, merge_repeated_);
+        TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_);
     if (!status.ok()) {
       return status;
     }
@@ -213,20 +216,20 @@ Status CTCBeamSearchDecoder::Decode(
   return Status::OK();
 }
 
-template <class CTCBeamState, class CTCBeamComparer>
+template <class T, class CTCBeamState, class CTCBeamComparer>
 template <typename Vector>
-float CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::GetTopK(
-    const int K, const Vector& input, std::vector<float>* top_k_logits,
+T CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::GetTopK(
+    const int K, const Vector& input, std::vector<T>* top_k_logits,
     std::vector<int>* top_k_indices) {
   // Find Top K choices, complexity nk in worst case. The array input is read
   // just once.
-  CHECK_EQ(num_classes_, input.size());
+  CHECK_EQ(this->num_classes_, input.size());
   top_k_logits->clear();
   top_k_indices->clear();
   top_k_logits->resize(K, -INFINITY);
   top_k_indices->resize(K, -1);
-  for (int j = 0; j < num_classes_ - 1; ++j) {
-    const float logit = input(j);
+  for (int j = 0; j < this->num_classes_ - 1; ++j) {
+    const T logit = input(j);
     if (logit > (*top_k_logits)[K - 1]) {
       int k = K - 1;
       while (k > 0 && logit > (*top_k_logits)[k - 1]) {
@@ -239,43 +242,43 @@ float CTCBeamSearchDecoder::GetTopK(
     }
   }
   // Return max value which is in 0th index or blank character logit
-  return std::max((*top_k_logits)[0], input(num_classes_ - 1));
+  return std::max((*top_k_logits)[0], input(this->num_classes_ - 1));
 }
 
-template <class CTCBeamState, class CTCBeamComparer>
+template <class T, class CTCBeamState, class CTCBeamComparer>
 template <typename Vector>
-void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
+void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Step(
     const Vector& raw_input) {
-  std::vector<float> top_k_logits;
+  std::vector<T> top_k_logits;
   std::vector<int> top_k_indices;
   const bool top_k =
       (label_selection_size_ > 0 && label_selection_size_ < raw_input.size());
   // Number of character classes to consider in each step.
-  const int max_classes = top_k ? label_selection_size_ : (num_classes_ - 1);
+  const int max_classes = top_k ? label_selection_size_ : (this->num_classes_ - 1);
   // Get max coefficient and remove it from raw_input later.
-  float max_coeff;
+  T max_coeff;
   if (top_k) {
     max_coeff = GetTopK(label_selection_size_, raw_input, &top_k_logits,
                         &top_k_indices);
   } else {
     max_coeff = raw_input.maxCoeff();
   }
   // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
-  float logsumexp = 0.0;
+  T logsumexp = T(0.0);
   for (int j = 0; j < raw_input.size(); ++j) {
     logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
   }
   logsumexp = Eigen::numext::log(logsumexp);
   // Final normalization offset to get correct log probabilities.
-  float norm_offset = max_coeff + logsumexp;
+  T norm_offset = max_coeff + logsumexp;
 
-  const float label_selection_input_min =
+  const T label_selection_input_min =
       (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
-                                     : -std::numeric_limits<float>::infinity();
+                                     : -std::numeric_limits<T>::infinity();
 
   // Extract the beams sorted in decreasing new probability
-  CHECK_EQ(num_classes_, raw_input.size());
+  CHECK_EQ(this->num_classes_, raw_input.size());
   std::unique_ptr<std::vector<BeamEntry*>> branches(leaves_.Extract());
   leaves_.Reset();
 
@@ -294,8 +297,8 @@ void CTCBeamSearchDecoder::Step(
       //   else:
       //     Plabel(l=abc @ t=6) = (Plabel(l=abc @ t=5)
       //                           + P(l=ab @ t=5))
-      float previous = (b->label == b->parent->label) ? b->parent->oldp.blank
-                                                      : b->parent->oldp.total;
+      T previous = (b->label == b->parent->label) ? b->parent->oldp.blank
+                                                  : b->parent->oldp.total;
       b->newp.label =
           LogSumExp(b->newp.label,
                     beam_scorer_->GetStateExpansionScore(b->state, previous));
@@ -304,7 +307,7 @@ void CTCBeamSearchDecoder::Step(
       b->newp.label += raw_input(b->label) - norm_offset;
     }
     // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
-    b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
+    b->newp.blank = b->oldp.total + raw_input(this->blank_index_) - norm_offset;
     // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
     b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
 
@@ -325,7 +328,7 @@ void CTCBeamSearchDecoder::Step(
     // isn't full, or the lowest probability entry in the beam has a
     // lower probability than the leaf.
     auto is_candidate = [this](const BeamProbability& prob) {
-      return (prob.total > kLogZero &&
+      return (prob.total > kLogZero<T>::val &&
              (leaves_.size() < beam_width_ ||
               prob.total > leaves_.peek_bottom()->newp.total));
     };
@@ -336,7 +339,7 @@ void CTCBeamSearchDecoder::Step(
 
     for (int ind = 0; ind < max_classes; ind++) {
       const int label = top_k ? top_k_indices[ind] : ind;
-      const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
+      const T logit = top_k ? top_k_logits[ind] : raw_input(ind);
       // Perform label selection: if input for this label looks very
       // unpromising, never evaluate it with a scorer.
       // We may compare logits instead of log probabilities,
@@ -347,13 +350,13 @@ void CTCBeamSearchDecoder::Step(
       BeamEntry& c = b->GetChild(label);
       if (!c.Active()) {
         //   Pblank(l=abcd @ t=6) = 0
-        c.newp.blank = kLogZero;
+        c.newp.blank = kLogZero<T>::val;
         // If new child label is identical to beam label:
         //   Plabel(l=abcc @ t=6) = Pblank(l=abc @ t=5) * P(c @ 6)
         // Otherwise:
         //   Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
         beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
-        float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
+        T previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
         c.newp.label = logit - norm_offset +
                        beam_scorer_->GetStateExpansionScore(c.state, previous);
         // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
@@ -379,15 +382,15 @@ void CTCBeamSearchDecoder::Step(
   }  // for (BeamEntry* b...
 }
 
-template <class CTCBeamState, class CTCBeamComparer>
-void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
+template <class T, class CTCBeamState, class CTCBeamComparer>
+void CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::Reset() {
   leaves_.Reset();
 
   // This beam root, and all of its children, will be in memory until
   // the next reset.
   beam_root_.reset(new BeamRoot(nullptr, -1));
-  beam_root_->RootEntry()->newp.total = 0.0;  // ln(1)
-  beam_root_->RootEntry()->newp.blank = 0.0;  // ln(1)
+  beam_root_->RootEntry()->newp.total = T(0.0);  // ln(1)
+  beam_root_->RootEntry()->newp.blank = T(0.0);  // ln(1)
 
   // Add the root as the initial leaf.
   leaves_.push(beam_root_->RootEntry());
@@ -396,9 +399,9 @@ void CTCBeamSearchDecoder::Reset() {
   beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
 }
 
-template <class CTCBeamState, class CTCBeamComparer>
-Status CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::TopPaths(
-    int n, std::vector<std::vector<int>>* paths, std::vector<float>* log_probs,
+template <class T, class CTCBeamState, class CTCBeamComparer>
+Status CTCBeamSearchDecoder<T, CTCBeamState, CTCBeamComparer>::TopPaths(
+    int n, std::vector<std::vector<int>>* paths, std::vector<T>* log_probs,
     bool merge_repeated) const {
   CHECK_NOTNULL(paths)->clear();
   CHECK_NOTNULL(log_probs)->clear();
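Editor's note: a hedged end-to-end sketch of constructing the now-templated decoder at double precision and enabling label selection (APIs as declared above; the numeric values are illustrative):

    tensorflow::ctc::CTCBeamSearchDecoder<double>::DefaultBeamScorer scorer;
    tensorflow::ctc::CTCBeamSearchDecoder<double> decoder(
        /*num_classes=*/6, /*beam_width=*/10, &scorer);
    // Consider only the 2 best labels per step, and none scoring worse than
    // max_logit - 5.0; the defaults (0 and -1) mean "no label selection".
    decoder.SetLabelSelectionParameters(/*label_selection_size=*/2,
                                        /*label_selection_margin=*/5.0);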
diff --git a/tensorflow/core/util/ctc/ctc_beam_search_test_double.cc b/tensorflow/core/util/ctc/ctc_beam_search_test_double.cc
new file mode 100644
index 00000000000..24a3202127d
--- /dev/null
+++ b/tensorflow/core/util/ctc/ctc_beam_search_test_double.cc
@@ -0,0 +1,346 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This test illustrates how to make use of the CTCBeamSearchDecoder using a
+// custom BeamScorer and BeamState based on a dictionary with a few artificial
+// words.
+#include "tensorflow/core/util/ctc/ctc_beam_search.h"
+
+#include <cmath>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace {
+
+typedef std::vector<std::vector<std::vector<int>>> TestData;
+typedef tensorflow::ctc::CTCBeamSearchDecoder<double> CTCBeamSearchDecoder;
+typedef tensorflow::ctc::CTCDecoder<double> CTCDecoder;
+
+// The HistoryBeamState is used to keep track of the current candidate and
+// caches the expansion score (needed by the scorer).
+struct HistoryBeamState {
+  double score;
+  std::vector<int> labels;
+};
+
+// DictionaryBeamScorer essentially favors candidates that can still become
+// dictionary words. As soon as a beam candidate is not a dictionary word or
+// a prefix of a dictionary word it gets a low probability at each step.
+//
+// The dictionary itself is hard-coded as a static const variable of the class.
+class DictionaryBeamScorer
+    : public tensorflow::ctc::BaseBeamScorer<double, HistoryBeamState> {
+ public:
+  void InitializeState(HistoryBeamState* root) const override {
+    root->score = 0;
+  }
+
+  void ExpandState(const HistoryBeamState& from_state, int from_label,
+                   HistoryBeamState* to_state, int to_label) const override {
+    // Keep track of the current complete candidate by storing the labels along
+    // the expansion path in the beam state.
+    to_state->labels.push_back(to_label);
+    SetStateScoreAccordingToDict(to_state);
+  }
+
+  void ExpandStateEnd(HistoryBeamState* state) const override {
+    SetStateScoreAccordingToDict(state);
+  }
+
+  double GetStateExpansionScore(const HistoryBeamState& state,
+                                double previous_score) const override {
+    return previous_score + state.score;
+  }
+
+  double GetStateEndExpansionScore(
+      const HistoryBeamState& state) const override {
+    return state.score;
+  }
+
+  // Simple dictionary used when scoring the beams to check if they are
+  // prefixes of dictionary words (see SetStateScoreAccordingToDict below).
+  static const std::vector<std::vector<int>> dictionary_;
+
+ private:
+  void SetStateScoreAccordingToDict(HistoryBeamState* state) const;
+};
+
+const std::vector<std::vector<int>> DictionaryBeamScorer::dictionary_ = {
+    {3}, {3, 1}};
+
+void DictionaryBeamScorer::SetStateScoreAccordingToDict(
+    HistoryBeamState* state) const {
+  // Check if the beam can still be a dictionary word (e.g. prefix of one).
+  const std::vector<int>& candidate = state->labels;
+  for (int w = 0; w < dictionary_.size(); ++w) {
+    const std::vector<int>& word = dictionary_[w];
+    // If the length of the current beam is already larger, skip.
+    if (candidate.size() > word.size()) {
+      continue;
+    }
+    if (std::equal(word.begin(), word.begin() + candidate.size(),
+                   candidate.begin())) {
+      state->score = std::log(1.0);
+      return;
+    }
+  }
+  // At this point, the candidate certainly can't be in the dictionary.
+  state->score = std::log(0.01);
+}
+
+TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
+  const int batch_size = 1;
+  const int timesteps = 5;
+  const int top_paths = 3;
+  const int num_classes = 6;
+
+  // Plain decoder using hibernating beam search algorithm.
+  CTCBeamSearchDecoder::DefaultBeamScorer default_scorer;
+  CTCBeamSearchDecoder decoder(num_classes, 10 * top_paths, &default_scorer);
+
+  // Dictionary decoder, allowing only two dictionary words : {3}, {3, 1}.
+  DictionaryBeamScorer dictionary_scorer;
+  tensorflow::ctc::CTCBeamSearchDecoder<double, HistoryBeamState>
+      dictionary_decoder(num_classes, top_paths, &dictionary_scorer);
+
+  // Raw data containers (arrays of float64, ints, etc.).
+  int sequence_lengths[batch_size] = {timesteps};
+  double input_data_mat[timesteps][batch_size][num_classes] = {
+      {{0, 0.6, 0, 0.4, 0, 0}},
+      {{0, 0.5, 0, 0.5, 0, 0}},
+      {{0, 0.4, 0, 0.6, 0, 0}},
+      {{0, 0.4, 0, 0.6, 0, 0}},
+      {{0, 0.4, 0, 0.6, 0, 0}}};
+
+  // The CTCDecoder works with log-probs.
+  for (int t = 0; t < timesteps; ++t) {
+    for (int b = 0; b < batch_size; ++b) {
+      for (int c = 0; c < num_classes; ++c) {
+        input_data_mat[t][b][c] = std::log(input_data_mat[t][b][c]);
+      }
+    }
+  }
+
+  // Plain output, without any additional scoring.
+  std::vector<CTCDecoder::Output> expected_output = {
+      {{1, 3}, {1, 3, 1}, {3, 1, 3}},
+  };
+
+  // Dictionary outputs: preference for dictionary candidates. The
+  // second-candidate is there, despite it not being a dictionary word, due to
+  // stronger probability in the input to the decoder.
+  std::vector<CTCDecoder::Output> expected_dict_output = {
+      {{3}, {1, 3}, {3, 1}},
+  };
+
+  // Convert data containers to the format accepted by the decoder, simply
+  // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXd,
+  // using Eigen::Map.
+  Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
+  std::vector<Eigen::Map<const Eigen::MatrixXd>> inputs;
+  inputs.reserve(timesteps);
+  for (int t = 0; t < timesteps; ++t) {
+    inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
+  }
+
+  // Prepare containers for output and scores.
+  std::vector<CTCDecoder::Output> outputs(top_paths);
+  for (CTCDecoder::Output& output : outputs) {
+    output.resize(batch_size);
+  }
+  double score[batch_size][top_paths] = {{0.0}};
+  Eigen::Map<Eigen::MatrixXd> scores(&score[0][0], batch_size, top_paths);
+
+  EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
+  for (int path = 0; path < top_paths; ++path) {
+    EXPECT_EQ(outputs[path][0], expected_output[0][path]);
+  }
+
+  // Prepare dictionary outputs.
+  std::vector<CTCDecoder::Output> dict_outputs(top_paths);
+  for (CTCDecoder::Output& output : dict_outputs) {
+    output.resize(batch_size);
+  }
+  EXPECT_TRUE(
+      dictionary_decoder.Decode(seq_len, inputs, &dict_outputs, &scores).ok());
+  for (int path = 0; path < top_paths; ++path) {
+    EXPECT_EQ(dict_outputs[path][0], expected_dict_output[0][path]);
+  }
+}
+
+TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
+  const int batch_size = 1;
+  const int timesteps = 1;
+  const int top_paths = 3;
+  const int num_classes = 6;
+
+  // Plain decoder using hibernating beam search algorithm.
+  CTCBeamSearchDecoder::DefaultBeamScorer default_scorer;
+  CTCBeamSearchDecoder decoder(num_classes, top_paths, &default_scorer);
+
+  // Raw data containers (arrays of float64, ints, etc.).
+  int sequence_lengths[batch_size] = {timesteps};
+  double input_data_mat[timesteps][batch_size][num_classes] = {
+      {{0.4, 0.3, 0, 0, 0, 0.5}}};
+
+  // Convert data containers to the format accepted by the decoder, simply
+  // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXd,
+  // using Eigen::Map.
+  Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
+  std::vector<Eigen::Map<const Eigen::MatrixXd>> inputs;
+  inputs.reserve(timesteps);
+  for (int t = 0; t < timesteps; ++t) {
+    inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
+  }
+
+  // Prepare containers for output and scores.
+  std::vector<CTCDecoder::Output> outputs(top_paths);
+  for (CTCDecoder::Output& output : outputs) {
+    output.resize(batch_size);
+  }
+  double score[batch_size][top_paths] = {{0.0}};
+  Eigen::Map<Eigen::MatrixXd> scores(&score[0][0], batch_size, top_paths);
+
+  EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
+  // Make sure all scores are finite.
+  for (int path = 0; path < top_paths; ++path) {
+    LOG(INFO) << "path " << path;
+    EXPECT_FALSE(std::isinf(score[0][path]));
+  }
+}
+
+// A beam decoder to test label selection. It simply models N labels with
+// rapidly dropping off log-probability.
+
+typedef int LabelState;  // The state is simply the final label.
+
+class RapidlyDroppingLabelScorer
+    : public tensorflow::ctc::BaseBeamScorer<double, LabelState> {
+ public:
+  void InitializeState(LabelState* root) const override {}
+
+  void ExpandState(const LabelState& from_state, int from_label,
+                   LabelState* to_state, int to_label) const override {
+    *to_state = to_label;
+  }
+
+  void ExpandStateEnd(LabelState* state) const override {}
+
+  double GetStateExpansionScore(const LabelState& state,
+                                double previous_score) const override {
+    // Drop off rapidly for later labels.
+    const double kRapidly = 100;
+    return previous_score - kRapidly * state;
+  }
+
+  double GetStateEndExpansionScore(const LabelState& state) const override {
+    return 0;
+  }
+};
+
+TEST(CtcBeamSearch, LabelSelection) {
+  const int batch_size = 1;
+  const int timesteps = 3;
+  const int top_paths = 5;
+  const int num_classes = 6;
+
+  // Decoder which drops off log-probabilities for labels 0 >> 1 >> 2 >> 3.
+  RapidlyDroppingLabelScorer scorer;
+  tensorflow::ctc::CTCBeamSearchDecoder<double, LabelState> decoder(
+      num_classes, top_paths, &scorer);
+
+  // Raw data containers (arrays of float64, ints, etc.).
+  int sequence_lengths[batch_size] = {timesteps};
+  // Log probabilities, slightly preferring later labels, this decision
+  // should be overridden by the scorer which strongly prefers earlier labels.
+  // The last one is empty label, and for simplicity we give it an extremely
+  // high cost to ignore it. We also use the first label to break up the
+  // repeated label sequence.
+  double input_data_mat[timesteps][batch_size][num_classes] = {
+      {{-1e6, 1, 2, 3, 4, -1e6}},
+      {{1e6, 0, 0, 0, 0, -1e6}},  // force label 0 to break up repeated
+      {{-1e6, 1.1, 2.2, 3.3, 4.4, -1e6}},
+  };
+
+  // Expected output without label selection
+  std::vector<CTCDecoder::Output> expected_default_output = {
+      {{1, 0, 1}, {1, 0, 2}, {2, 0, 1}, {1, 0, 3}, {2, 0, 2}},
+  };
+
+  // Expected output with label selection limiting to 2 items
+  // this is suboptimal because only labels 3 and 4 were allowed to be seen.
+  std::vector<CTCDecoder::Output> expected_output_size2 = {
+      {{3, 0, 3}, {3, 0, 4}, {4, 0, 3}, {4, 0, 4}, {3}},
+  };
+
+  // Expected output with label width of 2.0. This would permit three labels at
+  // the first timestep, but only two at the last.
+  std::vector<CTCDecoder::Output> expected_output_width2 = {
+      {{2, 0, 3}, {2, 0, 4}, {3, 0, 3}, {3, 0, 4}, {4, 0, 3}},
+  };
+
+  // Convert data containers to the format accepted by the decoder, simply
+  // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXd,
+  // using Eigen::Map.
+  Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
+  std::vector<Eigen::Map<const Eigen::MatrixXd>> inputs;
+  inputs.reserve(timesteps);
+  for (int t = 0; t < timesteps; ++t) {
+    inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
+  }
+
+  // Prepare containers for output and scores.
+  std::vector<CTCDecoder::Output> outputs(top_paths);
+  for (CTCDecoder::Output& output : outputs) {
+    output.resize(batch_size);
+  }
+  double score[batch_size][top_paths] = {{0.0}};
+  Eigen::Map<Eigen::MatrixXd> scores(&score[0][0], batch_size, top_paths);
+
+  EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
+  for (int path = 0; path < top_paths; ++path) {
+    EXPECT_EQ(outputs[path][0], expected_default_output[0][path]);
+  }
+
+  // Try label selection size 2
+  decoder.SetLabelSelectionParameters(2, -1);
+  EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
+  for (int path = 0; path < top_paths; ++path) {
+    EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
+  }
+
+  // Try label selection width 2.0
+  decoder.SetLabelSelectionParameters(0, 2.0);
+  EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
+  for (int path = 0; path < top_paths; ++path) {
+    EXPECT_EQ(outputs[path][0], expected_output_width2[0][path]);
+  }
+
+  // Try both size 2 and width 2.0: the former is more constraining, so
+  // it's equivalent to that.
+  decoder.SetLabelSelectionParameters(2, 2.0);
+  EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
+  for (int path = 0; path < top_paths; ++path) {
+    EXPECT_EQ(outputs[path][0], expected_output_size2[0][path]);
+  }
+
+  // Size 4 and width > 3.3 are equivalent to no label selection
+  decoder.SetLabelSelectionParameters(4, 3.3001);
+  EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
+  for (int path = 0; path < top_paths; ++path) {
+    EXPECT_EQ(outputs[path][0], expected_default_output[0][path]);
+  }
+}
+
+}  // namespace
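Editor's note: the float and double test files differ only in the scalar type. A hedged alternative that avoids the near-duplicate file is googletest's typed tests (sketch only; assumes the googletest version bundled with TensorFlow provides `TYPED_TEST_CASE`):

    template <typename T>
    class CtcBeamSearchTypedTest : public ::testing::Test {};
    typedef ::testing::Types<float, double> ScalarTypes;
    TYPED_TEST_CASE(CtcBeamSearchTypedTest, ScalarTypes);

    TYPED_TEST(CtcBeamSearchTypedTest, DecodesSimpleInput) {
      using Decoder = tensorflow::ctc::CTCBeamSearchDecoder<TypeParam>;
      typename Decoder::DefaultBeamScorer scorer;
      Decoder decoder(/*num_classes=*/6, /*beam_width=*/3, &scorer);
      // ... build TypeParam-typed inputs and call decoder.Decode(...).
    }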
diff --git a/tensorflow/core/util/ctc/ctc_beam_search_test.cc b/tensorflow/core/util/ctc/ctc_beam_search_test_float.cc
similarity index 94%
rename from tensorflow/core/util/ctc/ctc_beam_search_test.cc
rename to tensorflow/core/util/ctc/ctc_beam_search_test_float.cc
index b2d5ef56adf..bbc67e93f09 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search_test.cc
+++ b/tensorflow/core/util/ctc/ctc_beam_search_test_float.cc
@@ -25,8 +25,8 @@ limitations under the License.
 namespace {
 
 typedef std::vector<std::vector<std::vector<float>>> TestData;
-using tensorflow::ctc::CTCBeamSearchDecoder;
-using tensorflow::ctc::CTCDecoder;
+typedef tensorflow::ctc::CTCBeamSearchDecoder<float> CTCBeamSearchDecoder;
+typedef tensorflow::ctc::CTCDecoder<float> CTCDecoder;
 
 // The HistoryBeamState is used to keep track of the current candidate and
 // caches the expansion score (needed by the scorer).
@@ -115,7 +115,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
   CTCBeamSearchDecoder dictionary_decoder(
       num_classes, top_paths, &dictionary_scorer);
 
-  // Raw data containers (arrays of floats, ints, etc.).
+  // Raw data containers (arrays of float32, ints, etc.).
   int sequence_lengths[batch_size] = {timesteps};
   float input_data_mat[timesteps][batch_size][num_classes] = {
       {{0, 0.6, 0, 0.4, 0, 0}},
@@ -149,7 +149,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
   // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
   // using Eigen::Map.
   Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
-  std::vector<CTCDecoder::Input> inputs;
+  std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
   inputs.reserve(timesteps);
   for (int t = 0; t < timesteps; ++t) {
     inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
@@ -161,7 +161,7 @@ TEST(CtcBeamSearch, DecodingWithAndWithoutDictionary) {
     output.resize(batch_size);
   }
   float score[batch_size][top_paths] = {{0.0}};
-  Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
+  Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>> scores(&score[0][0], batch_size, top_paths);
 
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
@@ -190,7 +190,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
   CTCBeamSearchDecoder<>::DefaultBeamScorer default_scorer;
   CTCBeamSearchDecoder<> decoder(num_classes, top_paths, &default_scorer);
 
-  // Raw data containers (arrays of floats, ints, etc.).
+  // Raw data containers (arrays of float32, ints, etc.).
   int sequence_lengths[batch_size] = {timesteps};
   float input_data_mat[timesteps][batch_size][num_classes] = {
       {{0.4, 0.3, 0, 0, 0, 0.5}}};
@@ -199,7 +199,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
   // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
   // using Eigen::Map.
   Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
-  std::vector<CTCDecoder::Input> inputs;
+  std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
   inputs.reserve(timesteps);
   for (int t = 0; t < timesteps; ++t) {
     inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
@@ -211,7 +211,7 @@ TEST(CtcBeamSearch, AllBeamElementsHaveFiniteScores) {
     output.resize(batch_size);
   }
   float score[batch_size][top_paths] = {{0.0}};
-  Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
+  Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>> scores(&score[0][0], batch_size, top_paths);
 
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   // Make sure all scores are finite.
@@ -260,7 +260,7 @@ TEST(CtcBeamSearch, LabelSelection) {
   RapidlyDroppingLabelScorer scorer;
   CTCBeamSearchDecoder decoder(num_classes, top_paths, &scorer);
 
-  // Raw data containers (arrays of floats, ints, etc.).
+  // Raw data containers (arrays of float32, ints, etc.).
   int sequence_lengths[batch_size] = {timesteps};
   // Log probabilities, slightly preferring later labels, this decision
   // should be overridden by the scorer which strongly prefers earlier labels.
@@ -294,7 +294,7 @@ TEST(CtcBeamSearch, LabelSelection) {
   // mapping the memory from the container to an Eigen::ArrayXi,::MatrixXf,
   // using Eigen::Map.
   Eigen::Map<const Eigen::ArrayXi> seq_len(&sequence_lengths[0], batch_size);
-  std::vector<CTCDecoder::Input> inputs;
+  std::vector<Eigen::Map<const Eigen::MatrixXf>> inputs;
   inputs.reserve(timesteps);
   for (int t = 0; t < timesteps; ++t) {
     inputs.emplace_back(&input_data_mat[t][0][0], batch_size, num_classes);
@@ -306,7 +306,7 @@ TEST(CtcBeamSearch, LabelSelection) {
     output.resize(batch_size);
   }
   float score[batch_size][top_paths] = {{0.0}};
-  Eigen::Map<Eigen::MatrixXf> scores(&score[0][0], batch_size, top_paths);
+  Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>> scores(&score[0][0], batch_size, top_paths);
 
   EXPECT_TRUE(decoder.Decode(seq_len, inputs, &outputs, &scores).ok());
   for (int path = 0; path < top_paths; ++path) {
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index f5c9e4bb596..3bec5ee5695 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -33,12 +33,13 @@ namespace ctc {
 // The two types of decoding available are:
 //   - greedy path, through the CTCGreedyDecoder
 //   - beam search, through the CTCBeamSearchDecoder
+template <class T>
 class CTCDecoder {
  public:
   typedef Eigen::Map<const Eigen::ArrayXi> SequenceLength;
-  typedef Eigen::Map<const Eigen::MatrixXf> Input;
+  typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> Input;
   typedef std::vector<std::vector<int>> Output;
-  typedef Eigen::Map<Eigen::MatrixXf> ScoreOutput;
+  typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> ScoreOutput;
 
   CTCDecoder(int num_classes, int batch_size, bool merge_repeated)
       : num_classes_(num_classes),
@@ -69,25 +70,27 @@ class CTCDecoder {
 
 // CTCGreedyDecoder is an implementation of the simple best path decoding
 // algorithm, selecting at each timestep the most likely class at each timestep.
-class CTCGreedyDecoder : public CTCDecoder {
+template <class T>
+class CTCGreedyDecoder : public CTCDecoder<T> {
  public:
+  typedef CTCDecoder<T> Decoder;
   CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated)
-      : CTCDecoder(num_classes, batch_size, merge_repeated) {}
+      : CTCDecoder<T>(num_classes, batch_size, merge_repeated) {}
 
-  Status Decode(const CTCDecoder::SequenceLength& seq_len,
-                const std::vector<CTCDecoder::Input>& input,
-                std::vector<CTCDecoder::Output>* output,
-                CTCDecoder::ScoreOutput* scores) override {
-    if (output->empty() || (*output)[0].size() < batch_size_) {
+  Status Decode(const typename CTCDecoder<T>::SequenceLength& seq_len,
+                const std::vector<typename CTCDecoder<T>::Input>& input,
+                std::vector<typename CTCDecoder<T>::Output>* output,
+                typename CTCDecoder<T>::ScoreOutput* scores) override {
+    if (output->empty() || (*output)[0].size() < Decoder::batch_size_) {
       return errors::InvalidArgument(
           "output needs to be of size at least (1, batch_size).");
     }
-    if (scores->rows() < batch_size_ || scores->cols() == 0) {
+    if (scores->rows() < Decoder::batch_size_ || scores->cols() == 0) {
       return errors::InvalidArgument(
           "scores needs to be of size at least (batch_size, 1).");
     }
     // For each batch entry, identify the transitions
-    for (int b = 0; b < batch_size_; ++b) {
+    for (int b = 0; b < Decoder::batch_size_; ++b) {
       int seq_len_b = seq_len[b];
       // Only writing to beam 0
       std::vector<int>& output_b = (*output)[0][b];
@@ -98,8 +101,8 @@ class CTCGreedyDecoder : public CTCDecoder {
         auto row = input[t].row(b);
         int max_class_ix;
         (*scores)(b, 0) += -row.maxCoeff(&max_class_ix);
-        if (max_class_ix != blank_index_ &&
-            !(merge_repeated_ && max_class_ix == prev_class_ix)) {
+        if (max_class_ix != Decoder::blank_index_ &&
+            !(Decoder::merge_repeated_ && max_class_ix == prev_class_ix)) {
           output_b.push_back(max_class_ix);
         }
         prev_class_ix = max_class_ix;
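Editor's note: a worked example of the greedy rule implemented above (values illustrative). With blank_index = 4 and merge_repeated = true, per-timestep argmaxes {1, 1, 4, 2, 2} first have the blank 4 dropped and the adjacent repeats collapsed, yielding the decoded sequence {1, 2}; with merge_repeated = false the result is {1, 1, 2, 2}. The templating does not change this logic, only the scalar type of the scores being compared.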
diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.cc b/tensorflow/core/util/ctc/ctc_loss_calculator.cc
index a0ac5eec4bc..8e7d3d48447 100644
--- a/tensorflow/core/util/ctc/ctc_loss_calculator.cc
+++ b/tensorflow/core/util/ctc/ctc_loss_calculator.cc
@@ -19,168 +19,6 @@ limitations under the License.
 namespace tensorflow {
 namespace ctc {
 
-// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
-// Starting with t = 0 instead of t = 1 used in the text.
-// Based on Kanishka's CTC.
-void CTCLossCalculator::CalculateForwardVariables(
-    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
-    Matrix* log_alpha) const {
-  // Number of cols is the number of time steps = number of cols in target
-  // after the output delay.
-  log_alpha->setConstant(kLogZero);
-
-  int U = l_prime.size();
-  int T = log_alpha->cols();
-
-  CHECK_EQ(U, log_alpha->rows());
-
-  // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
-  log_alpha->coeffRef(0, 0) = std::log(y(blank_index_, output_delay_));
-  // Below, l_prime[1] == labels[0]
-  auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
-  log_alpha->coeffRef(1, 0) = std::log(y(label_0, output_delay_));
-
-  for (int t = 1; t < T; ++t) {
-    // If there is not enough time to output the remaining labels or
-    // some labels have been skipped, then let log_alpha(u, t) continue to
-    // be kLogZero.
-    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
-         ++u) {
-      // Begin (GravesTh) Eq 7.9
-      // Add in the u, t - 1 term.
-      float sum_log_alpha = kLogZero;
-      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
-        sum_log_alpha = log_alpha->coeff(u, t - 1);
-      }
-
-      // Add in the u - 1, t - 1 term.
-      if (u > 0) {
-        sum_log_alpha =
-            LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
-      }
-
-      // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
-      if (u > 1) {
-        const bool matching_labels_merge =
-            ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
-        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
-          sum_log_alpha =
-              LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
-        }
-      }
-      // Multiply the summed alphas with the activation log probability.
-      log_alpha->coeffRef(u, t) =
-          std::log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
-    }  // End (GravesTh) Eq 7.9.
-  }
-}
-
-// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
-void CTCLossCalculator::CalculateBackwardVariables(
-    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
-    Matrix* log_beta) const {
-  // Number of cols is the number of time steps = number of cols in target.
-  // Matrix log_beta =
-  //    Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
-  // kLogZero);
-  log_beta->setConstant(kLogZero);
-  int T = log_beta->cols();
-  int U = l_prime.size();
-  CHECK_EQ(U, log_beta->rows());
-
-  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
-  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
-
-  for (int t = T - 1 - 1; t >= 0; --t) {
-    // If there is not enough time to output the remaining labels or
-    // some labels have been skipped, then let log_beta(u, t) continue to
-    // be kLogZero.
-    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
-         ++u) {
-      // Begin (GravesTh) Eq 7.15
-      // Add in the u, t + 1 term.
-      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
-        log_beta->coeffRef(u, t) =
-            LogSumExp(log_beta->coeff(u, t),
-                      log_beta->coeff(u, t + 1) +
-                          std::log(y(l_prime[u], output_delay_ + t + 1)));
-      }
-
-      // Add in the u + 1, t + 1 term.
-      if (u + 1 < U) {
-        log_beta->coeffRef(u, t) =
-            LogSumExp(log_beta->coeff(u, t),
-                      log_beta->coeff(u + 1, t + 1) +
-                          std::log(y(l_prime[u + 1], output_delay_ + t + 1)));
-      }
-
-      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
-      if (u + 2 < U) {
-        const bool matching_labels_merge =
-            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
-        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
-          // Add in u + 2 term.
-          log_beta->coeffRef(u, t) =
-              LogSumExp(log_beta->coeff(u, t),
-                        log_beta->coeff(u + 2, t + 1) +
-                            std::log(y(l_prime[u + 2], output_delay_ + t + 1)));
-        }
-      }  // End (GravesTh) Eq. 7.15
-    }
-  }
-}
-
-// Using (GravesTh) Eq 7.26 & 7.34.
-void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
-                                          const Matrix& y,
-                                          const Matrix& log_alpha,
-                                          const Matrix& log_beta,
-                                          float log_p_z_x, Matrix* dy) const {
-  // Only working with the leftmost part of dy for this batch element.
-  auto dy_b = dy->leftCols(y.cols());
-
-  // It is possible that no valid path is found if the activations for the
-  // targets are zero.
-  if (log_p_z_x == kLogZero) {
-    LOG(WARNING) << "No valid path found.";
-    dy_b = y;
-    return;
-  }
-
-  int L = y.rows();
-  int T = y.cols();
-  int U = l_prime.size();
-
-  for (int t = 0; t < T - output_delay_; ++t) {
-    Array prob_sum(L);
-    prob_sum.setConstant(kLogZero);
-
-    for (int u = 0; u < U; ++u) {
-      int l = l_prime[u];
-      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
-    }
-
-    for (int l = 0; l < L; ++l) {
-      // Negative term in (GravesTh) Eq 7.28.
-      float negative_term = expf(prob_sum[l] - log_p_z_x);
-
-      dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
-    }
-  }
-}
-
-void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
-                                         std::vector<int>* l_prime) const {
-  // Assumption is that l_prime is empty.
-  l_prime->reserve(2 * l.size() + 1);
-
-  for (auto label : l) {
-    l_prime->push_back(blank_index_);
-    l_prime->push_back(label);
-  }
-  // Add final blank to l'.
-  l_prime->push_back(blank_index_);
-}
 
 }  // namespace ctc
 }  // namespace tensorflow
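Editor's note: the bodies deleted here reappear in ctc_loss_calculator.h below. Once `CTCLossCalculator` is a class template, its member definitions must be visible wherever `CTCLossCalculator<float>` or `CTCLossCalculator<double>` is instantiated, otherwise the references fail at link time. A hedged sketch of the alternative that would keep the non-template member definitions in this .cc file, via explicit instantiation (not what this patch does):

    // At the bottom of ctc_loss_calculator.cc:
    template class CTCLossCalculator<float>;
    template class CTCLossCalculator<double>;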
public: typedef std::vector> LabelSequences; - typedef Eigen::MatrixXf Matrix; - typedef Eigen::ArrayXf Array; - typedef Eigen::Map InputMap; - typedef Eigen::Map OutputMap; + using Matrix = Eigen::Matrix; + // typedef Eigen::MatrixXd Matrix; + using Array = Eigen::Array; + // typedef Eigen::ArrayXd Array; + using InputMap = Eigen::Map; + // typedef Eigen::Map InputMap; + using OutputMap = Eigen::Map; + // typedef Eigen::Map OutputMap; CTCLossCalculator(int blank_index, int output_delay) : blank_index_(blank_index), output_delay_(output_delay) {} @@ -79,7 +84,7 @@ class CTCLossCalculator { void CalculateGradient(const std::vector& l_prime, const Matrix& y, const Matrix& log_alpha, const Matrix& log_beta, - float log_p_z_x, Matrix* dy) const; + T log_p_z_x, Matrix* dy) const; void GetLPrimeIndices(const std::vector& l, std::vector* l_prime) const; @@ -103,9 +108,10 @@ class CTCLossCalculator { const int output_delay_; }; +template template -Status CTCLossCalculator::CalculateLoss( +Status CTCLossCalculator::CalculateLoss( const VectorIn& seq_len, const LabelSequences& labels, const std::vector& inputs, bool preprocess_collapse_repeated, bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs, @@ -205,11 +211,11 @@ Status CTCLossCalculator::CalculateLoss( // Convert label from DistBelief // y, prob are in num_classes x seq_len(b) // Output activations. - Eigen::ArrayXf y_b_col; + Array y_b_col; for (int t = 0; t < seq_len(b); t++) { - // Calculate the softmax of y_b. Use double precision + // Calculate the softmax of y_b. Use original precision // arithmetic for the sum. - float max_coeff = inputs[t].row(b).maxCoeff(); + T max_coeff = inputs[t].row(b).maxCoeff(); y_b_col = (inputs[t].row(b).array() - max_coeff).exp(); y_b.col(t) = y_b_col / y_b_col.sum(); } @@ -222,7 +228,7 @@ Status CTCLossCalculator::CalculateLoss( // The loss is computed as the log(p(z|x)) between the target and // prediction. Do lazy evaluation of log_prob here. - float log_p_z_x = kLogZero; + T log_p_z_x = kLogZero::val; for (int u = 0; u < l_prime.size(); ++u) { // (GravesTh) Eq 7.26, sum over all paths for t = 0. log_p_z_x = LogSumExp(log_p_z_x, log_alpha_b(u, 0) + log_beta_b(u, 0)); @@ -253,19 +259,19 @@ Status CTCLossCalculator::CalculateLoss( // fwd,bwd: T * 2 * (2*L + 1) * (Cost(LogSumExp) + Cost(Log)) + // grad: T * ((2L + 1) * Cost(LogSumExp) + L * (Cost(Expf) + Cost(Add)). 
const int64 cost_exp = Eigen::internal::functor_traits< - Eigen::internal::scalar_exp_op>::Cost; + Eigen::internal::scalar_exp_op>::Cost; const int64 cost_log = Eigen::internal::functor_traits< - Eigen::internal::scalar_log_op>::Cost; + Eigen::internal::scalar_log_op>::Cost; const int64 cost_log_sum_exp = - Eigen::TensorOpCost::AddCost() + cost_exp + cost_log; + Eigen::TensorOpCost::AddCost() + cost_exp + cost_log; const int64 cost = max_seq_len * num_classes * - (cost_exp + Eigen::TensorOpCost::DivCost()) + + (cost_exp + Eigen::TensorOpCost::DivCost()) + max_seq_len * 2 * (2 * num_classes + 1) * (cost_log_sum_exp + cost_log) + max_seq_len * ((2 * num_classes + 1) * cost_log_sum_exp + - num_classes * (cost_exp + Eigen::TensorOpCost::AddCost())); + num_classes * (cost_exp + Eigen::TensorOpCost::AddCost())); Shard(workers->num_threads, workers->workers, batch_size, cost, ComputeLossAndGradients); } else { @@ -274,8 +280,9 @@ Status CTCLossCalculator::CalculateLoss( return Status::OK(); } +template template -Status CTCLossCalculator::PopulateLPrimes( +Status CTCLossCalculator::PopulateLPrimes( bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs, int batch_size, int num_classes, const Vector& seq_len, const LabelSequences& labels, size_t* max_u_prime, @@ -357,6 +364,174 @@ Status CTCLossCalculator::PopulateLPrimes( return Status::OK(); } + +// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3. +// Starting with t = 0 instead of t = 1 used in the text. +// Based on Kanishka's CTC. +template +void CTCLossCalculator::CalculateForwardVariables( + const std::vector& l_prime, const Matrix& y, bool ctc_merge_repeated, + Matrix* log_alpha) const { + // Number of cols is the number of time steps = number of cols in target + // after the output delay. + log_alpha->setConstant(kLogZero::val); + + int U = l_prime.size(); + int T = log_alpha->cols(); + + CHECK_EQ(U, log_alpha->rows()); + + // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6. + log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_)); + // Below, l_prime[1] == labels[0] + auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_; + log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_)); + + for (int t = 1; t < T; ++t) { + // If there is not enough time to output the remaining labels or + // some labels have been skipped, then let log_alpha(u, t) continue to + // be kLogZero. + for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1)); + ++u) { + // Begin (GravesTh) Eq 7.9 + // Add in the u, t - 1 term. + auto sum_log_alpha = kLogZero::val; + if (ctc_merge_repeated || l_prime[u] == blank_index_) { + sum_log_alpha = log_alpha->coeff(u, t - 1); + } + + // Add in the u - 1, t - 1 term. + if (u > 0) { + sum_log_alpha = + LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1)); + } + + // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2). + if (u > 1) { + const bool matching_labels_merge = + ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]); + if (l_prime[u] != blank_index_ && !matching_labels_merge) { + sum_log_alpha = + LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1)); + } + } + // Multiply the summed alphas with the activation log probability. + log_alpha->coeffRef(u, t) = + log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha; + } // End (GravesTh) Eq 7.9. + } +} + +// Calculates the beta(t, u) as described in (GravesTh) Section 7.3. 
+
+// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
+template <typename TT>
+void CTCLossCalculator<TT>::CalculateBackwardVariables(
+    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
+    Matrix* log_beta) const {
+  // Number of cols is the number of time steps = number of cols in target.
+  // Matrix log_beta =
+  //     Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
+  //                      kLogZero);
+  log_beta->setConstant(kLogZero<TT>::val);
+  int T = log_beta->cols();
+  int U = l_prime.size();
+  CHECK_EQ(U, log_beta->rows());
+
+  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
+  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
+
+  for (int t = T - 1 - 1; t >= 0; --t) {
+    // If there is not enough time to output the remaining labels or
+    // some labels have been skipped, then let log_beta(u, t) continue to
+    // be kLogZero.
+    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
+         ++u) {
+      // Begin (GravesTh) Eq 7.15
+      // Add in the u, t + 1 term.
+      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
+        log_beta->coeffRef(u, t) =
+            LogSumExp(log_beta->coeff(u, t),
+                      log_beta->coeff(u, t + 1) +
+                          log(y(l_prime[u], output_delay_ + t + 1)));
+      }
+
+      // Add in the u + 1, t + 1 term.
+      if (u + 1 < U) {
+        log_beta->coeffRef(u, t) =
+            LogSumExp(log_beta->coeff(u, t),
+                      log_beta->coeff(u + 1, t + 1) +
+                          log(y(l_prime[u + 1], output_delay_ + t + 1)));
+      }
+
+      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
+      if (u + 2 < U) {
+        const bool matching_labels_merge =
+            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
+        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
+          // Add in u + 2 term.
+          log_beta->coeffRef(u, t) =
+              LogSumExp(log_beta->coeff(u, t),
+                        log_beta->coeff(u + 2, t + 1) +
+                            log(y(l_prime[u + 2], output_delay_ + t + 1)));
+        }
+      }  // End (GravesTh) Eq. 7.15
+    }
+  }
+}
+
+// Using (GravesTh) Eq 7.26 & 7.34.
+template <typename TT>
+void CTCLossCalculator<TT>::CalculateGradient(const std::vector<int>& l_prime,
+                                              const Matrix& y,
+                                              const Matrix& log_alpha,
+                                              const Matrix& log_beta,
+                                              TT log_p_z_x, Matrix* dy) const {
+  // Only working with the leftmost part of dy for this batch element.
+  auto dy_b = dy->leftCols(y.cols());
+
+  // It is possible that no valid path is found if the activations for the
+  // targets are zero.
+  if (log_p_z_x == kLogZero<TT>::val) {
+    LOG(WARNING) << "No valid path found.";
+    dy_b = y;
+    return;
+  }
+
+  int L = y.rows();
+  int T = y.cols();
+  int U = l_prime.size();
+
+  for (int t = 0; t < T - output_delay_; ++t) {
+    Array prob_sum(L);
+    prob_sum.setConstant(kLogZero<TT>::val);
+
+    for (int u = 0; u < U; ++u) {
+      int l = l_prime[u];
+      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
+    }
+
+    for (int l = 0; l < L; ++l) {
+      // Negative term in (GravesTh) Eq 7.28. Use exp rather than expf so the
+      // computation stays in TT's precision when TT is double.
+      auto negative_term = exp(prob_sum[l] - log_p_z_x);
+
+      dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
+    }
+  }
+}
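A numeric sanity check (not part of the patch; all values hypothetical) of the negative term in (GravesTh) Eq 7.28 that CalculateGradient evaluates in log space: dy(l, t) = y(l, t) - exp(prob_sum[l] - log_p_z_x).

// Numeric sketch of the Eq 7.28 negative term, accumulated in log space.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const double kLogZero = -std::numeric_limits<double>::infinity();
  // Suppose two l' entries map to the same label l at this time step.
  const double log_alpha_beta[2] = {std::log(0.02), std::log(0.06)};
  const double log_p_z_x = std::log(0.1);  // total log path probability

  double prob_sum = kLogZero;  // log-space accumulator, as in the patch
  for (double v : log_alpha_beta) {
    prob_sum = (prob_sum == kLogZero)
                   ? v
                   : std::max(prob_sum, v) +
                         std::log1p(std::exp(-std::fabs(prob_sum - v)));
  }

  const double y_lt = 0.9;  // softmax output for label l at time t
  const double negative_term = std::exp(prob_sum - log_p_z_x);  // 0.08 / 0.1
  std::printf("dy = %f\n", y_lt - negative_term);  // 0.9 - 0.8 = 0.1
}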
+
+template <typename TT>
+void CTCLossCalculator<TT>::GetLPrimeIndices(const std::vector<int>& l,
+                                             std::vector<int>* l_prime) const {
+  // Assumption is that l_prime is empty.
+  l_prime->reserve(2 * l.size() + 1);
+
+  for (auto label : l) {
+    l_prime->push_back(blank_index_);
+    l_prime->push_back(label);
+  }
+  // Add final blank to l'.
+  l_prime->push_back(blank_index_);
+}
+
 }  // namespace ctc
 }  // namespace tensorflow
diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h
index df0de926d9a..514378e5bb3 100644
--- a/tensorflow/core/util/ctc/ctc_loss_util.h
+++ b/tensorflow/core/util/ctc/ctc_loss_util.h
@@ -23,18 +23,24 @@ limitations under the License.
 namespace tensorflow {
 namespace ctc {

-const float kLogZero = -std::numeric_limits<float>::infinity();
+template <typename T>
+struct kLogZero
+{
+  static constexpr T val = -std::numeric_limits<T>::infinity();
+};

 // Add logarithmic probabilities using:
 // ln(a + b) = ln(a) + ln(1 + exp(ln(b) - ln(a)))
 // The two inputs are assumed to be log probabilities.
 // (GravesTh) Eq. 7.18
-inline float LogSumExp(float log_prob_1, float log_prob_2) {
+template <typename T>
+inline T LogSumExp(T log_prob_1, T log_prob_2) {
+  // const T kLogZero = -std::numeric_limits<T>::infinity();
   // Always have 'b' be the smaller number to avoid the exponential from
   // blowing up.
-  if (log_prob_1 == kLogZero) {
+  if (log_prob_1 == kLogZero<T>::val) {
     return log_prob_2;
-  } else if (log_prob_2 == kLogZero) {
+  } else if (log_prob_2 == kLogZero<T>::val) {
     return log_prob_1;
   } else {
     return (log_prob_1 > log_prob_2)