diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index fc1f0be2740..d30ab3f4dad 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -17,9 +17,11 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ #include +#include #include #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -58,30 +60,18 @@ struct BeamEntry { // create a vector of children. The object pointed to by p // cannot be copied and should not be moved, otherwise parent will // become invalid. - BeamEntry(BeamEntry* p, int l, int L, int t) : parent(p), label(l) { - PopulateChildren(L); - } + BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {} inline bool Active() const { return newp.total != kLogZero; } - inline bool HasChildren() const { return !children.empty(); } - void PopulateChildren(int L) { - CHECK(!HasChildren()); - children = std::vector(L); - int ci = 0; - for (auto& c : children) { - // The current object cannot be copied, and should not be moved. - // Otherwise the child's parent will become invalid. - c.parent = this; - c.label = ci; - ++ci; + // Return the child at the given index, or construct a new one in-place if + // none was found. + BeamEntry& GetChild(int ind) { + auto entry = children.emplace(ind, nullptr); + auto& child_entry = entry.first->second; + // If this is a new child, populate the uniqe_ptr. + if (entry.second) { + child_entry.reset(new BeamEntry(this, ind)); } - } - inline std::vector* Children() { - CHECK(HasChildren()); - return &children; - } - inline const std::vector* Children() const { - CHECK(HasChildren()); - return &children; + return *(child_entry.get()); } std::vector LabelSeq(bool merge_repeated) const { std::vector labels; @@ -100,7 +90,7 @@ struct BeamEntry { BeamEntry* parent; int label; - std::vector> children; + gtl::FlatMap>> children; BeamProbability oldp; BeamProbability newp; CTCBeamState state; diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index d1059806d79..f1773bcd95f 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -285,17 +285,14 @@ void CTCBeamSearchDecoder::Step( continue; } - if (!b->HasChildren()) { - b->PopulateChildren(num_classes_ - 1); - } - - for (BeamEntry& c : *b->Children()) { + for (int ind = 0; ind < num_classes_ - 1; ind++) { + // Perform label selection: if input for this label looks very + // unpromising, never evaluate it with a scorer. + if (input(ind) < label_selection_input_min) { + continue; + } + BeamEntry& c = b->GetChild(ind); if (!c.Active()) { - // Perform label selection: if input for this label looks very - // unpromising, never evaluate it with a scorer. - if (input(c.label) < label_selection_input_min) { - continue; - } // Pblank(l=abcd @ t=6) = 0 c.newp.blank = kLogZero; // If new child label is identical to beam label: @@ -320,13 +317,13 @@ void CTCBeamSearchDecoder::Step( } leaves_.push(&c); } else { - // Deactivate child (signal it's not in the beam) + // Deactivate child. c.oldp.Reset(); c.newp.Reset(); } - } // if (!c.Active()) ... - } // for (BeamEntry& c in children... - } // for (BeamEntry* b... + } + } + } // for (BeamEntry* b... } template @@ -335,7 +332,7 @@ void CTCBeamSearchDecoder::Reset() { // This beam root, and all of its children, will be in memory until // the next reset. - beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_ - 1, -1)); + beam_root_.reset(new BeamEntry(nullptr, -1)); beam_root_->newp.total = 0.0; // ln(1) beam_root_->newp.blank = 0.0; // ln(1)