Added BeamRoot to manage instances of BeamEntry for CTCBeamSearchDecoder
to avoid recursive destructor call. PiperOrigin-RevId: 182625992
This commit is contained in:
		
							parent
							
								
									76d33554c7
								
							
						
					
					
						commit
						b3e464e0d6
					
				| @ -52,26 +52,25 @@ struct BeamProbability { | ||||
|   float label; | ||||
| }; | ||||
| 
 | ||||
| template <class CTCBeamState> | ||||
| class BeamRoot; | ||||
| 
 | ||||
| template <class CTCBeamState = EmptyBeamState> | ||||
| struct BeamEntry { | ||||
|   // Default constructor does not create a vector of children.
 | ||||
|   BeamEntry() : parent(nullptr), label(-1) {} | ||||
|   // Constructor giving parent, label, and number of children does
 | ||||
|   // create a vector of children.  The object pointed to by p
 | ||||
|   // cannot be copied and should not be moved, otherwise parent will
 | ||||
|   // become invalid.
 | ||||
|   BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {} | ||||
|   // BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
 | ||||
|   friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry( | ||||
|       BeamEntry<CTCBeamState>* p, int l); | ||||
|   inline bool Active() const { return newp.total != kLogZero; } | ||||
|   // Return the child at the given index, or construct a new one in-place if
 | ||||
|   // none was found.
 | ||||
|   BeamEntry& GetChild(int ind) { | ||||
|     auto entry = children.emplace(ind, nullptr); | ||||
|     auto& child_entry = entry.first->second; | ||||
|     // If this is a new child, populate the uniqe_ptr.
 | ||||
|     // If this is a new child, populate the BeamEntry<CTCBeamState>*.
 | ||||
|     if (entry.second) { | ||||
|       child_entry.reset(new BeamEntry(this, ind)); | ||||
|       child_entry = beam_root->AddEntry(this, ind); | ||||
|     } | ||||
|     return *(child_entry.get()); | ||||
|     return *child_entry; | ||||
|   } | ||||
|   std::vector<int> LabelSeq(bool merge_repeated) const { | ||||
|     std::vector<int> labels; | ||||
| @ -90,15 +89,45 @@ struct BeamEntry { | ||||
| 
 | ||||
|   BeamEntry<CTCBeamState>* parent; | ||||
|   int label; | ||||
|   gtl::FlatMap<int, std::unique_ptr<BeamEntry<CTCBeamState>>> children; | ||||
|   // All instances of child BeamEntry are owned by *beam_root.
 | ||||
|   gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children; | ||||
|   BeamProbability oldp; | ||||
|   BeamProbability newp; | ||||
|   CTCBeamState state; | ||||
| 
 | ||||
|  private: | ||||
|   // Constructor giving parent, label, and the beam_root.
 | ||||
|   // The object pointed to by p cannot be copied and should not be moved,
 | ||||
|   // otherwise parent will become invalid.
 | ||||
|   // This private constructor is only called through the factory method
 | ||||
|   // BeamRoot<CTCBeamState>::AddEntry().
 | ||||
|   BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root) | ||||
|       : parent(p), label(l), beam_root(beam_root) {} | ||||
|   BeamRoot<CTCBeamState>* beam_root; | ||||
|   TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry); | ||||
| }; | ||||
| 
 | ||||
| // This class owns all instances of BeamEntry.  This is used to avoid recursive
 | ||||
| // destructor call during destruction.
 | ||||
| template <class CTCBeamState = EmptyBeamState> | ||||
| class BeamRoot { | ||||
|  public: | ||||
|   BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); } | ||||
|   BeamRoot(const BeamRoot&) = delete; | ||||
|   BeamRoot& operator=(const BeamRoot&) = delete; | ||||
| 
 | ||||
|   BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) { | ||||
|     auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this); | ||||
|     beam_entries_.emplace_back(new_entry); | ||||
|     return new_entry; | ||||
|   } | ||||
|   BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; } | ||||
| 
 | ||||
|  private: | ||||
|   BeamEntry<CTCBeamState>* root_entry_ = nullptr; | ||||
|   std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_; | ||||
| }; | ||||
| 
 | ||||
| // BeamComparer is the default beam comparer provided in CTCBeamSearch.
 | ||||
| template <class CTCBeamState = EmptyBeamState> | ||||
| class BeamComparer { | ||||
|  | ||||
| @ -16,11 +16,15 @@ limitations under the License. | ||||
| #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ | ||||
| #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <cmath> | ||||
| #include <limits> | ||||
| #include <memory> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "third_party/eigen3/Eigen/Core" | ||||
| #include "tensorflow/core/lib/core/errors.h" | ||||
| #include "tensorflow/core/lib/core/status.h" | ||||
| #include "tensorflow/core/lib/gtl/top_n.h" | ||||
| #include "tensorflow/core/platform/logging.h" | ||||
| #include "tensorflow/core/platform/macros.h" | ||||
| @ -69,6 +73,7 @@ class CTCBeamSearchDecoder : public CTCDecoder { | ||||
|   //   P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
 | ||||
|   // but we calculate it recursively for speed purposes.
 | ||||
|   typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry; | ||||
|   typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot; | ||||
|   typedef ctc_beam_search::BeamProbability BeamProbability; | ||||
| 
 | ||||
|  public: | ||||
| @ -142,7 +147,7 @@ class CTCBeamSearchDecoder : public CTCDecoder { | ||||
|   float label_selection_margin_ = -1;  // -1 means unlimited.
 | ||||
| 
 | ||||
|   gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_; | ||||
|   std::unique_ptr<BeamEntry> beam_root_; | ||||
|   std::unique_ptr<BeamRoot> beam_root_; | ||||
|   BaseBeamScorer<CTCBeamState>* beam_scorer_; | ||||
| 
 | ||||
|   TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder); | ||||
| @ -367,15 +372,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() { | ||||
| 
 | ||||
|   // This beam root, and all of its children, will be in memory until
 | ||||
|   // the next reset.
 | ||||
|   beam_root_.reset(new BeamEntry(nullptr, -1)); | ||||
|   beam_root_->newp.total = 0.0;  // ln(1)
 | ||||
|   beam_root_->newp.blank = 0.0;  // ln(1)
 | ||||
|   beam_root_.reset(new BeamRoot(nullptr, -1)); | ||||
|   beam_root_->RootEntry()->newp.total = 0.0;  // ln(1)
 | ||||
|   beam_root_->RootEntry()->newp.blank = 0.0;  // ln(1)
 | ||||
| 
 | ||||
|   // Add the root as the initial leaf.
 | ||||
|   leaves_.push(beam_root_.get()); | ||||
|   leaves_.push(beam_root_->RootEntry()); | ||||
| 
 | ||||
|   // Call initialize state on the root object.
 | ||||
|   beam_scorer_->InitializeState(&beam_root_->state); | ||||
|   beam_scorer_->InitializeState(&beam_root_->RootEntry()->state); | ||||
| } | ||||
| 
 | ||||
| template <typename CTCBeamState, typename CTCBeamComparer> | ||||
|  | ||||
| @ -16,6 +16,9 @@ limitations under the License. | ||||
| #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ | ||||
| #define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ | ||||
| 
 | ||||
| #include <memory> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "third_party/eigen3/Eigen/Core" | ||||
| #include "tensorflow/core/lib/core/errors.h" | ||||
| #include "tensorflow/core/lib/core/status.h" | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user