Use a map instead of a vector of Children() in the BeamEntry.
The assumption is that since the entries are sparse (they are all populated, but most are never Active()), using the map will save memory and make iterating over the Children() more efficient. PiperOrigin-RevId: 168548814
This commit is contained in:
parent
0d5ab82cec
commit
bece65c6f3
@ -17,9 +17,11 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
@ -58,30 +60,18 @@ struct BeamEntry {
|
||||
// create a vector of children. The object pointed to by p
|
||||
// cannot be copied and should not be moved, otherwise parent will
|
||||
// become invalid.
|
||||
BeamEntry(BeamEntry* p, int l, int L, int t) : parent(p), label(l) {
|
||||
PopulateChildren(L);
|
||||
}
|
||||
BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {}
|
||||
inline bool Active() const { return newp.total != kLogZero; }
|
||||
inline bool HasChildren() const { return !children.empty(); }
|
||||
void PopulateChildren(int L) {
|
||||
CHECK(!HasChildren());
|
||||
children = std::vector<BeamEntry>(L);
|
||||
int ci = 0;
|
||||
for (auto& c : children) {
|
||||
// The current object cannot be copied, and should not be moved.
|
||||
// Otherwise the child's parent will become invalid.
|
||||
c.parent = this;
|
||||
c.label = ci;
|
||||
++ci;
|
||||
// Return the child at the given index, or construct a new one in-place if
|
||||
// none was found.
|
||||
BeamEntry& GetChild(int ind) {
|
||||
auto entry = children.emplace(ind, nullptr);
|
||||
auto& child_entry = entry.first->second;
|
||||
// If this is a new child, populate the uniqe_ptr.
|
||||
if (entry.second) {
|
||||
child_entry.reset(new BeamEntry(this, ind));
|
||||
}
|
||||
}
|
||||
inline std::vector<BeamEntry>* Children() {
|
||||
CHECK(HasChildren());
|
||||
return &children;
|
||||
}
|
||||
inline const std::vector<BeamEntry>* Children() const {
|
||||
CHECK(HasChildren());
|
||||
return &children;
|
||||
return *(child_entry.get());
|
||||
}
|
||||
std::vector<int> LabelSeq(bool merge_repeated) const {
|
||||
std::vector<int> labels;
|
||||
@ -100,7 +90,7 @@ struct BeamEntry {
|
||||
|
||||
BeamEntry<CTCBeamState>* parent;
|
||||
int label;
|
||||
std::vector<BeamEntry<CTCBeamState>> children;
|
||||
gtl::FlatMap<int, std::unique_ptr<BeamEntry<CTCBeamState>>> children;
|
||||
BeamProbability oldp;
|
||||
BeamProbability newp;
|
||||
CTCBeamState state;
|
||||
|
@ -285,17 +285,14 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!b->HasChildren()) {
|
||||
b->PopulateChildren(num_classes_ - 1);
|
||||
}
|
||||
|
||||
for (BeamEntry& c : *b->Children()) {
|
||||
for (int ind = 0; ind < num_classes_ - 1; ind++) {
|
||||
// Perform label selection: if input for this label looks very
|
||||
// unpromising, never evaluate it with a scorer.
|
||||
if (input(ind) < label_selection_input_min) {
|
||||
continue;
|
||||
}
|
||||
BeamEntry& c = b->GetChild(ind);
|
||||
if (!c.Active()) {
|
||||
// Perform label selection: if input for this label looks very
|
||||
// unpromising, never evaluate it with a scorer.
|
||||
if (input(c.label) < label_selection_input_min) {
|
||||
continue;
|
||||
}
|
||||
// Pblank(l=abcd @ t=6) = 0
|
||||
c.newp.blank = kLogZero;
|
||||
// If new child label is identical to beam label:
|
||||
@ -320,13 +317,13 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
|
||||
}
|
||||
leaves_.push(&c);
|
||||
} else {
|
||||
// Deactivate child (signal it's not in the beam)
|
||||
// Deactivate child.
|
||||
c.oldp.Reset();
|
||||
c.newp.Reset();
|
||||
}
|
||||
} // if (!c.Active()) ...
|
||||
} // for (BeamEntry& c in children...
|
||||
} // for (BeamEntry* b...
|
||||
}
|
||||
}
|
||||
} // for (BeamEntry* b...
|
||||
}
|
||||
|
||||
template <typename CTCBeamState, typename CTCBeamComparer>
|
||||
@ -335,7 +332,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
|
||||
|
||||
// This beam root, and all of its children, will be in memory until
|
||||
// the next reset.
|
||||
beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_ - 1, -1));
|
||||
beam_root_.reset(new BeamEntry(nullptr, -1));
|
||||
beam_root_->newp.total = 0.0; // ln(1)
|
||||
beam_root_->newp.blank = 0.0; // ln(1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user