From bece65c6f3605f00a72e4163e7e6d5ccda10cd81 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 13 Sep 2017 09:19:44 -0700 Subject: [PATCH] Use a map instead of a vector of Children() in the BeamEntry. The assumption is that since the entries are sparse (they are all populated, but most are never Active()), using the map will save memory and make iterating over the Children() more efficient. PiperOrigin-RevId: 168548814 --- tensorflow/core/util/ctc/ctc_beam_entry.h | 36 ++++++++-------------- tensorflow/core/util/ctc/ctc_beam_search.h | 27 ++++++++-------- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index fc1f0be2740..d30ab3f4dad 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -17,9 +17,11 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ #include +#include #include #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -58,30 +60,18 @@ struct BeamEntry { // create a vector of children. The object pointed to by p // cannot be copied and should not be moved, otherwise parent will // become invalid. - BeamEntry(BeamEntry* p, int l, int L, int t) : parent(p), label(l) { - PopulateChildren(L); - } + BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {} inline bool Active() const { return newp.total != kLogZero; } - inline bool HasChildren() const { return !children.empty(); } - void PopulateChildren(int L) { - CHECK(!HasChildren()); - children = std::vector(L); - int ci = 0; - for (auto& c : children) { - // The current object cannot be copied, and should not be moved. - // Otherwise the child's parent will become invalid. - c.parent = this; - c.label = ci; - ++ci; + // Return the child at the given index, or construct a new one in-place if + // none was found. + BeamEntry& GetChild(int ind) { + auto entry = children.emplace(ind, nullptr); + auto& child_entry = entry.first->second; + // If this is a new child, populate the uniqe_ptr. + if (entry.second) { + child_entry.reset(new BeamEntry(this, ind)); } - } - inline std::vector* Children() { - CHECK(HasChildren()); - return &children; - } - inline const std::vector* Children() const { - CHECK(HasChildren()); - return &children; + return *(child_entry.get()); } std::vector LabelSeq(bool merge_repeated) const { std::vector labels; @@ -100,7 +90,7 @@ struct BeamEntry { BeamEntry* parent; int label; - std::vector> children; + gtl::FlatMap>> children; BeamProbability oldp; BeamProbability newp; CTCBeamState state; diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index d1059806d79..f1773bcd95f 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -285,17 +285,14 @@ void CTCBeamSearchDecoder::Step( continue; } - if (!b->HasChildren()) { - b->PopulateChildren(num_classes_ - 1); - } - - for (BeamEntry& c : *b->Children()) { + for (int ind = 0; ind < num_classes_ - 1; ind++) { + // Perform label selection: if input for this label looks very + // unpromising, never evaluate it with a scorer. + if (input(ind) < label_selection_input_min) { + continue; + } + BeamEntry& c = b->GetChild(ind); if (!c.Active()) { - // Perform label selection: if input for this label looks very - // unpromising, never evaluate it with a scorer. - if (input(c.label) < label_selection_input_min) { - continue; - } // Pblank(l=abcd @ t=6) = 0 c.newp.blank = kLogZero; // If new child label is identical to beam label: @@ -320,13 +317,13 @@ void CTCBeamSearchDecoder::Step( } leaves_.push(&c); } else { - // Deactivate child (signal it's not in the beam) + // Deactivate child. c.oldp.Reset(); c.newp.Reset(); } - } // if (!c.Active()) ... - } // for (BeamEntry& c in children... - } // for (BeamEntry* b... + } + } + } // for (BeamEntry* b... } template @@ -335,7 +332,7 @@ void CTCBeamSearchDecoder::Reset() { // This beam root, and all of its children, will be in memory until // the next reset. - beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_ - 1, -1)); + beam_root_.reset(new BeamEntry(nullptr, -1)); beam_root_->newp.total = 0.0; // ln(1) beam_root_->newp.blank = 0.0; // ln(1)