Added BeamRoot to manage instances of BeamEntry for CTCBeamSearchDecoder
to avoid recursive destructor call. PiperOrigin-RevId: 182625992
This commit is contained in:
parent
76d33554c7
commit
b3e464e0d6
@ -52,26 +52,25 @@ struct BeamProbability {
|
||||
float label;
|
||||
};
|
||||
|
||||
template <class CTCBeamState>
|
||||
class BeamRoot;
|
||||
|
||||
template <class CTCBeamState = EmptyBeamState>
|
||||
struct BeamEntry {
|
||||
// Default constructor does not create a vector of children.
|
||||
BeamEntry() : parent(nullptr), label(-1) {}
|
||||
// Constructor giving parent, label, and number of children does
|
||||
// create a vector of children. The object pointed to by p
|
||||
// cannot be copied and should not be moved, otherwise parent will
|
||||
// become invalid.
|
||||
BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {}
|
||||
// BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
|
||||
friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
|
||||
BeamEntry<CTCBeamState>* p, int l);
|
||||
inline bool Active() const { return newp.total != kLogZero; }
|
||||
// Return the child at the given index, or construct a new one in-place if
|
||||
// none was found.
|
||||
BeamEntry& GetChild(int ind) {
|
||||
auto entry = children.emplace(ind, nullptr);
|
||||
auto& child_entry = entry.first->second;
|
||||
// If this is a new child, populate the uniqe_ptr.
|
||||
// If this is a new child, populate the BeamEntry<CTCBeamState>*.
|
||||
if (entry.second) {
|
||||
child_entry.reset(new BeamEntry(this, ind));
|
||||
child_entry = beam_root->AddEntry(this, ind);
|
||||
}
|
||||
return *(child_entry.get());
|
||||
return *child_entry;
|
||||
}
|
||||
std::vector<int> LabelSeq(bool merge_repeated) const {
|
||||
std::vector<int> labels;
|
||||
@ -90,15 +89,45 @@ struct BeamEntry {
|
||||
|
||||
BeamEntry<CTCBeamState>* parent;
|
||||
int label;
|
||||
gtl::FlatMap<int, std::unique_ptr<BeamEntry<CTCBeamState>>> children;
|
||||
// All instances of child BeamEntry are owned by *beam_root.
|
||||
gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
|
||||
BeamProbability oldp;
|
||||
BeamProbability newp;
|
||||
CTCBeamState state;
|
||||
|
||||
private:
|
||||
// Constructor giving parent, label, and the beam_root.
|
||||
// The object pointed to by p cannot be copied and should not be moved,
|
||||
// otherwise parent will become invalid.
|
||||
// This private constructor is only called through the factory method
|
||||
// BeamRoot<CTCBeamState>::AddEntry().
|
||||
BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
|
||||
: parent(p), label(l), beam_root(beam_root) {}
|
||||
BeamRoot<CTCBeamState>* beam_root;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
|
||||
};
|
||||
|
||||
// This class owns all instances of BeamEntry. This is used to avoid recursive
|
||||
// destructor call during destruction.
|
||||
template <class CTCBeamState = EmptyBeamState>
|
||||
class BeamRoot {
|
||||
public:
|
||||
BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
|
||||
BeamRoot(const BeamRoot&) = delete;
|
||||
BeamRoot& operator=(const BeamRoot&) = delete;
|
||||
|
||||
BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
|
||||
auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
|
||||
beam_entries_.emplace_back(new_entry);
|
||||
return new_entry;
|
||||
}
|
||||
BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
|
||||
|
||||
private:
|
||||
BeamEntry<CTCBeamState>* root_entry_ = nullptr;
|
||||
std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
|
||||
};
|
||||
|
||||
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
|
||||
template <class CTCBeamState = EmptyBeamState>
|
||||
class BeamComparer {
|
||||
|
||||
@ -16,11 +16,15 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/top_n.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
@ -69,6 +73,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
|
||||
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
|
||||
// but we calculate it recursively for speed purposes.
|
||||
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
|
||||
typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
|
||||
typedef ctc_beam_search::BeamProbability BeamProbability;
|
||||
|
||||
public:
|
||||
@ -142,7 +147,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
|
||||
float label_selection_margin_ = -1; // -1 means unlimited.
|
||||
|
||||
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
|
||||
std::unique_ptr<BeamEntry> beam_root_;
|
||||
std::unique_ptr<BeamRoot> beam_root_;
|
||||
BaseBeamScorer<CTCBeamState>* beam_scorer_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
|
||||
@ -367,15 +372,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
|
||||
|
||||
// This beam root, and all of its children, will be in memory until
|
||||
// the next reset.
|
||||
beam_root_.reset(new BeamEntry(nullptr, -1));
|
||||
beam_root_->newp.total = 0.0; // ln(1)
|
||||
beam_root_->newp.blank = 0.0; // ln(1)
|
||||
beam_root_.reset(new BeamRoot(nullptr, -1));
|
||||
beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
|
||||
beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
|
||||
|
||||
// Add the root as the initial leaf.
|
||||
leaves_.push(beam_root_.get());
|
||||
leaves_.push(beam_root_->RootEntry());
|
||||
|
||||
// Call initialize state on the root object.
|
||||
beam_scorer_->InitializeState(&beam_root_->state);
|
||||
beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
|
||||
}
|
||||
|
||||
template <typename CTCBeamState, typename CTCBeamComparer>
|
||||
|
||||
@ -16,6 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user