Added BeamRoot to manage instances of BeamEntry for CTCBeamSearchDecoder

to avoid recursive destructor call.

PiperOrigin-RevId: 182625992
This commit is contained in:
A. Unique TensorFlower 2018-01-19 20:52:19 -08:00 committed by TensorFlower Gardener
parent 76d33554c7
commit b3e464e0d6
3 changed files with 54 additions and 17 deletions

View File

@ -52,26 +52,25 @@ struct BeamProbability {
float label;
};
template <class CTCBeamState>
class BeamRoot;
template <class CTCBeamState = EmptyBeamState>
struct BeamEntry {
// Default constructor does not create a vector of children.
BeamEntry() : parent(nullptr), label(-1) {}
// Constructor giving parent, label, and number of children does
// create a vector of children. The object pointed to by p
// cannot be copied and should not be moved, otherwise parent will
// become invalid.
BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {}
// BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
BeamEntry<CTCBeamState>* p, int l);
inline bool Active() const { return newp.total != kLogZero; }
// Return the child at the given index, or construct a new one in-place if
// none was found.
BeamEntry& GetChild(int ind) {
auto entry = children.emplace(ind, nullptr);
auto& child_entry = entry.first->second;
// If this is a new child, populate the uniqe_ptr.
// If this is a new child, populate the BeamEntry<CTCBeamState>*.
if (entry.second) {
child_entry.reset(new BeamEntry(this, ind));
child_entry = beam_root->AddEntry(this, ind);
}
return *(child_entry.get());
return *child_entry;
}
std::vector<int> LabelSeq(bool merge_repeated) const {
std::vector<int> labels;
@ -90,15 +89,45 @@ struct BeamEntry {
BeamEntry<CTCBeamState>* parent;
int label;
gtl::FlatMap<int, std::unique_ptr<BeamEntry<CTCBeamState>>> children;
// All instances of child BeamEntry are owned by *beam_root.
gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
BeamProbability oldp;
BeamProbability newp;
CTCBeamState state;
private:
// Constructor giving parent, label, and the beam_root.
// The object pointed to by p cannot be copied and should not be moved,
// otherwise parent will become invalid.
// This private constructor is only called through the factory method
// BeamRoot<CTCBeamState>::AddEntry().
BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
: parent(p), label(l), beam_root(beam_root) {}
BeamRoot<CTCBeamState>* beam_root;
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
};
// This class owns all instances of BeamEntry. This is used to avoid recursive
// destructor call during destruction.
template <class CTCBeamState = EmptyBeamState>
class BeamRoot {
public:
BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
BeamRoot(const BeamRoot&) = delete;
BeamRoot& operator=(const BeamRoot&) = delete;
BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
beam_entries_.emplace_back(new_entry);
return new_entry;
}
BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
private:
BeamEntry<CTCBeamState>* root_entry_ = nullptr;
std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
};
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
template <class CTCBeamState = EmptyBeamState>
class BeamComparer {

View File

@ -16,11 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@ -69,6 +73,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
// but we calculate it recursively for speed purposes.
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
typedef ctc_beam_search::BeamProbability BeamProbability;
public:
@ -142,7 +147,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
float label_selection_margin_ = -1; // -1 means unlimited.
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
std::unique_ptr<BeamEntry> beam_root_;
std::unique_ptr<BeamRoot> beam_root_;
BaseBeamScorer<CTCBeamState>* beam_scorer_;
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
@ -367,15 +372,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
// This beam root, and all of its children, will be in memory until
// the next reset.
beam_root_.reset(new BeamEntry(nullptr, -1));
beam_root_->newp.total = 0.0; // ln(1)
beam_root_->newp.blank = 0.0; // ln(1)
beam_root_.reset(new BeamRoot(nullptr, -1));
beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
// Add the root as the initial leaf.
leaves_.push(beam_root_.get());
leaves_.push(beam_root_->RootEntry());
// Call initialize state on the root object.
beam_scorer_->InitializeState(&beam_root_->state);
beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
}
template <typename CTCBeamState, typename CTCBeamComparer>

View File

@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
#include <memory>
#include <vector>
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"