Added BeamRoot to manage instances of BeamEntry for CTCBeamSearchDecoder
to avoid recursive destructor call. PiperOrigin-RevId: 182625992
This commit is contained in:
parent
76d33554c7
commit
b3e464e0d6
@ -52,26 +52,25 @@ struct BeamProbability {
|
|||||||
float label;
|
float label;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <class CTCBeamState>
|
||||||
|
class BeamRoot;
|
||||||
|
|
||||||
template <class CTCBeamState = EmptyBeamState>
|
template <class CTCBeamState = EmptyBeamState>
|
||||||
struct BeamEntry {
|
struct BeamEntry {
|
||||||
// Default constructor does not create a vector of children.
|
// BeamRoot<CTCBeamState>::AddEntry() serves as the factory method.
|
||||||
BeamEntry() : parent(nullptr), label(-1) {}
|
friend BeamEntry<CTCBeamState>* BeamRoot<CTCBeamState>::AddEntry(
|
||||||
// Constructor giving parent, label, and number of children does
|
BeamEntry<CTCBeamState>* p, int l);
|
||||||
// create a vector of children. The object pointed to by p
|
|
||||||
// cannot be copied and should not be moved, otherwise parent will
|
|
||||||
// become invalid.
|
|
||||||
BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {}
|
|
||||||
inline bool Active() const { return newp.total != kLogZero; }
|
inline bool Active() const { return newp.total != kLogZero; }
|
||||||
// Return the child at the given index, or construct a new one in-place if
|
// Return the child at the given index, or construct a new one in-place if
|
||||||
// none was found.
|
// none was found.
|
||||||
BeamEntry& GetChild(int ind) {
|
BeamEntry& GetChild(int ind) {
|
||||||
auto entry = children.emplace(ind, nullptr);
|
auto entry = children.emplace(ind, nullptr);
|
||||||
auto& child_entry = entry.first->second;
|
auto& child_entry = entry.first->second;
|
||||||
// If this is a new child, populate the uniqe_ptr.
|
// If this is a new child, populate the BeamEntry<CTCBeamState>*.
|
||||||
if (entry.second) {
|
if (entry.second) {
|
||||||
child_entry.reset(new BeamEntry(this, ind));
|
child_entry = beam_root->AddEntry(this, ind);
|
||||||
}
|
}
|
||||||
return *(child_entry.get());
|
return *child_entry;
|
||||||
}
|
}
|
||||||
std::vector<int> LabelSeq(bool merge_repeated) const {
|
std::vector<int> LabelSeq(bool merge_repeated) const {
|
||||||
std::vector<int> labels;
|
std::vector<int> labels;
|
||||||
@ -90,15 +89,45 @@ struct BeamEntry {
|
|||||||
|
|
||||||
BeamEntry<CTCBeamState>* parent;
|
BeamEntry<CTCBeamState>* parent;
|
||||||
int label;
|
int label;
|
||||||
gtl::FlatMap<int, std::unique_ptr<BeamEntry<CTCBeamState>>> children;
|
// All instances of child BeamEntry are owned by *beam_root.
|
||||||
|
gtl::FlatMap<int, BeamEntry<CTCBeamState>*> children;
|
||||||
BeamProbability oldp;
|
BeamProbability oldp;
|
||||||
BeamProbability newp;
|
BeamProbability newp;
|
||||||
CTCBeamState state;
|
CTCBeamState state;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// Constructor giving parent, label, and the beam_root.
|
||||||
|
// The object pointed to by p cannot be copied and should not be moved,
|
||||||
|
// otherwise parent will become invalid.
|
||||||
|
// This private constructor is only called through the factory method
|
||||||
|
// BeamRoot<CTCBeamState>::AddEntry().
|
||||||
|
BeamEntry(BeamEntry* p, int l, BeamRoot<CTCBeamState>* beam_root)
|
||||||
|
: parent(p), label(l), beam_root(beam_root) {}
|
||||||
|
BeamRoot<CTCBeamState>* beam_root;
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
|
TF_DISALLOW_COPY_AND_ASSIGN(BeamEntry);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This class owns all instances of BeamEntry. This is used to avoid recursive
|
||||||
|
// destructor call during destruction.
|
||||||
|
template <class CTCBeamState = EmptyBeamState>
|
||||||
|
class BeamRoot {
|
||||||
|
public:
|
||||||
|
BeamRoot(BeamEntry<CTCBeamState>* p, int l) { root_entry_ = AddEntry(p, l); }
|
||||||
|
BeamRoot(const BeamRoot&) = delete;
|
||||||
|
BeamRoot& operator=(const BeamRoot&) = delete;
|
||||||
|
|
||||||
|
BeamEntry<CTCBeamState>* AddEntry(BeamEntry<CTCBeamState>* p, int l) {
|
||||||
|
auto* new_entry = new BeamEntry<CTCBeamState>(p, l, this);
|
||||||
|
beam_entries_.emplace_back(new_entry);
|
||||||
|
return new_entry;
|
||||||
|
}
|
||||||
|
BeamEntry<CTCBeamState>* RootEntry() const { return root_entry_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
BeamEntry<CTCBeamState>* root_entry_ = nullptr;
|
||||||
|
std::vector<std::unique_ptr<BeamEntry<CTCBeamState>>> beam_entries_;
|
||||||
|
};
|
||||||
|
|
||||||
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
|
// BeamComparer is the default beam comparer provided in CTCBeamSearch.
|
||||||
template <class CTCBeamState = EmptyBeamState>
|
template <class CTCBeamState = EmptyBeamState>
|
||||||
class BeamComparer {
|
class BeamComparer {
|
||||||
|
|||||||
@ -16,11 +16,15 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
||||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/gtl/top_n.h"
|
#include "tensorflow/core/lib/gtl/top_n.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
@ -69,6 +73,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
|
|||||||
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
|
// P(l=abc? @ t=3) = P(a @ 0)*P(b @ 1)*P(c @ 2)*P(? @ 3)
|
||||||
// but we calculate it recursively for speed purposes.
|
// but we calculate it recursively for speed purposes.
|
||||||
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
|
typedef ctc_beam_search::BeamEntry<CTCBeamState> BeamEntry;
|
||||||
|
typedef ctc_beam_search::BeamRoot<CTCBeamState> BeamRoot;
|
||||||
typedef ctc_beam_search::BeamProbability BeamProbability;
|
typedef ctc_beam_search::BeamProbability BeamProbability;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -142,7 +147,7 @@ class CTCBeamSearchDecoder : public CTCDecoder {
|
|||||||
float label_selection_margin_ = -1; // -1 means unlimited.
|
float label_selection_margin_ = -1; // -1 means unlimited.
|
||||||
|
|
||||||
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
|
gtl::TopN<BeamEntry*, CTCBeamComparer> leaves_;
|
||||||
std::unique_ptr<BeamEntry> beam_root_;
|
std::unique_ptr<BeamRoot> beam_root_;
|
||||||
BaseBeamScorer<CTCBeamState>* beam_scorer_;
|
BaseBeamScorer<CTCBeamState>* beam_scorer_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
|
TF_DISALLOW_COPY_AND_ASSIGN(CTCBeamSearchDecoder);
|
||||||
@ -367,15 +372,15 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
|
|||||||
|
|
||||||
// This beam root, and all of its children, will be in memory until
|
// This beam root, and all of its children, will be in memory until
|
||||||
// the next reset.
|
// the next reset.
|
||||||
beam_root_.reset(new BeamEntry(nullptr, -1));
|
beam_root_.reset(new BeamRoot(nullptr, -1));
|
||||||
beam_root_->newp.total = 0.0; // ln(1)
|
beam_root_->RootEntry()->newp.total = 0.0; // ln(1)
|
||||||
beam_root_->newp.blank = 0.0; // ln(1)
|
beam_root_->RootEntry()->newp.blank = 0.0; // ln(1)
|
||||||
|
|
||||||
// Add the root as the initial leaf.
|
// Add the root as the initial leaf.
|
||||||
leaves_.push(beam_root_.get());
|
leaves_.push(beam_root_->RootEntry());
|
||||||
|
|
||||||
// Call initialize state on the root object.
|
// Call initialize state on the root object.
|
||||||
beam_scorer_->InitializeState(&beam_root_->state);
|
beam_scorer_->InitializeState(&beam_root_->RootEntry()->state);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename CTCBeamState, typename CTCBeamComparer>
|
template <typename CTCBeamState, typename CTCBeamComparer>
|
||||||
|
|||||||
@ -16,6 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
||||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user