From b3e464e0d658c7039b621ef86b1c725ad7f81eac Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 19 Jan 2018 20:52:19 -0800 Subject: [PATCH] Added BeamRoot to manage instances of BeamEntry for CTCBeamSearchDecoder to avoid recursive destructor call. PiperOrigin-RevId: 182625992 --- tensorflow/core/util/ctc/ctc_beam_entry.h | 51 +++++++++++++++++----- tensorflow/core/util/ctc/ctc_beam_search.h | 17 +++++--- tensorflow/core/util/ctc/ctc_decoder.h | 3 ++ 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index d30ab3f4dad..53087821d7b 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -52,26 +52,25 @@ struct BeamProbability { float label; }; +template +class BeamRoot; + template struct BeamEntry { - // Default constructor does not create a vector of children. - BeamEntry() : parent(nullptr), label(-1) {} - // Constructor giving parent, label, and number of children does - // create a vector of children. The object pointed to by p - // cannot be copied and should not be moved, otherwise parent will - // become invalid. - BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {} + // BeamRoot::AddEntry() serves as the factory method. + friend BeamEntry* BeamRoot::AddEntry( + BeamEntry* p, int l); inline bool Active() const { return newp.total != kLogZero; } // Return the child at the given index, or construct a new one in-place if // none was found. BeamEntry& GetChild(int ind) { auto entry = children.emplace(ind, nullptr); auto& child_entry = entry.first->second; - // If this is a new child, populate the uniqe_ptr. + // If this is a new child, populate the BeamEntry*. if (entry.second) { - child_entry.reset(new BeamEntry(this, ind)); + child_entry = beam_root->AddEntry(this, ind); } - return *(child_entry.get()); + return *child_entry; } std::vector LabelSeq(bool merge_repeated) const { std::vector labels; @@ -90,15 +89,45 @@ struct BeamEntry { BeamEntry* parent; int label; - gtl::FlatMap>> children; + // All instances of child BeamEntry are owned by *beam_root. + gtl::FlatMap*> children; BeamProbability oldp; BeamProbability newp; CTCBeamState state; private: + // Constructor giving parent, label, and the beam_root. + // The object pointed to by p cannot be copied and should not be moved, + // otherwise parent will become invalid. + // This private constructor is only called through the factory method + // BeamRoot::AddEntry(). + BeamEntry(BeamEntry* p, int l, BeamRoot* beam_root) + : parent(p), label(l), beam_root(beam_root) {} + BeamRoot* beam_root; TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry); }; +// This class owns all instances of BeamEntry. This is used to avoid recursive +// destructor call during destruction. +template +class BeamRoot { + public: + BeamRoot(BeamEntry* p, int l) { root_entry_ = AddEntry(p, l); } + BeamRoot(const BeamRoot&) = delete; + BeamRoot& operator=(const BeamRoot&) = delete; + + BeamEntry* AddEntry(BeamEntry* p, int l) { + auto* new_entry = new BeamEntry(p, l, this); + beam_entries_.emplace_back(new_entry); + return new_entry; + } + BeamEntry* RootEntry() const { return root_entry_; } + + private: + BeamEntry* root_entry_ = nullptr; + std::vector>> beam_entries_; +}; + // BeamComparer is the default beam comparer provided in CTCBeamSearch. template class BeamComparer { diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index 372f25a1434..709c65fc965 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -16,11 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ +#include #include +#include #include +#include #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/top_n.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -69,6 +73,7 @@ class CTCBeamSearchDecoder : public CTCDecoder { // P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3) // but we calculate it recursively for speed purposes. typedef ctc_beam_search::BeamEntry BeamEntry; + typedef ctc_beam_search::BeamRoot BeamRoot; typedef ctc_beam_search::BeamProbability BeamProbability; public: @@ -142,7 +147,7 @@ class CTCBeamSearchDecoder : public CTCDecoder { float label_selection_margin_ = -1; // -1 means unlimited. gtl::TopN leaves_; - std::unique_ptr beam_root_; + std::unique_ptr beam_root_; BaseBeamScorer* beam_scorer_; TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder); @@ -367,15 +372,15 @@ void CTCBeamSearchDecoder::Reset() { // This beam root, and all of its children, will be in memory until // the next reset. - beam_root_.reset(new BeamEntry(nullptr, -1)); - beam_root_->newp.total = 0.0; // ln(1) - beam_root_->newp.blank = 0.0; // ln(1) + beam_root_.reset(new BeamRoot(nullptr, -1)); + beam_root_->RootEntry()->newp.total = 0.0; // ln(1) + beam_root_->RootEntry()->newp.blank = 0.0; // ln(1) // Add the root as the initial leaf. - leaves_.push(beam_root_.get()); + leaves_.push(beam_root_->RootEntry()); // Call initialize state on the root object. - beam_scorer_->InitializeState(&beam_root_->state); + beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); } template diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index 5b28aeb70ad..b8bab69053f 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ +#include +#include + #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h"