Use a map instead of a vector of Children() in the BeamEntry.
The assumption is that, because the entries are sparse (every child is populated, but most are never Active()), using a map will save memory and make iterating over the Children() more efficient. PiperOrigin-RevId: 168548814
This commit is contained in:
parent
0d5ab82cec
commit
bece65c6f3
@ -17,9 +17,11 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
|
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "third_party/eigen3/Eigen/Core"
|
#include "third_party/eigen3/Eigen/Core"
|
||||||
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -58,30 +60,18 @@ struct BeamEntry {
|
|||||||
// create a vector of children. The object pointed to by p
|
// create a vector of children. The object pointed to by p
|
||||||
// cannot be copied and should not be moved, otherwise parent will
|
// cannot be copied and should not be moved, otherwise parent will
|
||||||
// become invalid.
|
// become invalid.
|
||||||
BeamEntry(BeamEntry* p, int l, int L, int t) : parent(p), label(l) {
|
BeamEntry(BeamEntry* p, int l) : parent(p), label(l) {}
|
||||||
PopulateChildren(L);
|
|
||||||
}
|
|
||||||
inline bool Active() const { return newp.total != kLogZero; }
|
inline bool Active() const { return newp.total != kLogZero; }
|
||||||
inline bool HasChildren() const { return !children.empty(); }
|
// Return the child at the given index, or construct a new one in-place if
|
||||||
void PopulateChildren(int L) {
|
// none was found.
|
||||||
CHECK(!HasChildren());
|
BeamEntry& GetChild(int ind) {
|
||||||
children = std::vector<BeamEntry>(L);
|
auto entry = children.emplace(ind, nullptr);
|
||||||
int ci = 0;
|
auto& child_entry = entry.first->second;
|
||||||
for (auto& c : children) {
|
// If this is a new child, populate the unique_ptr.
|
||||||
// The current object cannot be copied, and should not be moved.
|
if (entry.second) {
|
||||||
// Otherwise the child's parent will become invalid.
|
child_entry.reset(new BeamEntry(this, ind));
|
||||||
c.parent = this;
|
|
||||||
c.label = ci;
|
|
||||||
++ci;
|
|
||||||
}
|
}
|
||||||
}
|
return *(child_entry.get());
|
||||||
inline std::vector<BeamEntry>* Children() {
|
|
||||||
CHECK(HasChildren());
|
|
||||||
return &children;
|
|
||||||
}
|
|
||||||
inline const std::vector<BeamEntry>* Children() const {
|
|
||||||
CHECK(HasChildren());
|
|
||||||
return &children;
|
|
||||||
}
|
}
|
||||||
std::vector<int> LabelSeq(bool merge_repeated) const {
|
std::vector<int> LabelSeq(bool merge_repeated) const {
|
||||||
std::vector<int> labels;
|
std::vector<int> labels;
|
||||||
@ -100,7 +90,7 @@ struct BeamEntry {
|
|||||||
|
|
||||||
BeamEntry<CTCBeamState>* parent;
|
BeamEntry<CTCBeamState>* parent;
|
||||||
int label;
|
int label;
|
||||||
std::vector<BeamEntry<CTCBeamState>> children;
|
gtl::FlatMap<int, std::unique_ptr<BeamEntry<CTCBeamState>>> children;
|
||||||
BeamProbability oldp;
|
BeamProbability oldp;
|
||||||
BeamProbability newp;
|
BeamProbability newp;
|
||||||
CTCBeamState state;
|
CTCBeamState state;
|
||||||
|
@ -285,17 +285,14 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!b->HasChildren()) {
|
for (int ind = 0; ind < num_classes_ - 1; ind++) {
|
||||||
b->PopulateChildren(num_classes_ - 1);
|
// Perform label selection: if input for this label looks very
|
||||||
}
|
// unpromising, never evaluate it with a scorer.
|
||||||
|
if (input(ind) < label_selection_input_min) {
|
||||||
for (BeamEntry& c : *b->Children()) {
|
continue;
|
||||||
|
}
|
||||||
|
BeamEntry& c = b->GetChild(ind);
|
||||||
if (!c.Active()) {
|
if (!c.Active()) {
|
||||||
// Perform label selection: if input for this label looks very
|
|
||||||
// unpromising, never evaluate it with a scorer.
|
|
||||||
if (input(c.label) < label_selection_input_min) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// Pblank(l=abcd @ t=6) = 0
|
// Pblank(l=abcd @ t=6) = 0
|
||||||
c.newp.blank = kLogZero;
|
c.newp.blank = kLogZero;
|
||||||
// If new child label is identical to beam label:
|
// If new child label is identical to beam label:
|
||||||
@ -320,13 +317,13 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
|
|||||||
}
|
}
|
||||||
leaves_.push(&c);
|
leaves_.push(&c);
|
||||||
} else {
|
} else {
|
||||||
// Deactivate child (signal it's not in the beam)
|
// Deactivate child.
|
||||||
c.oldp.Reset();
|
c.oldp.Reset();
|
||||||
c.newp.Reset();
|
c.newp.Reset();
|
||||||
}
|
}
|
||||||
} // if (!c.Active()) ...
|
}
|
||||||
} // for (BeamEntry& c in children...
|
}
|
||||||
} // for (BeamEntry* b...
|
} // for (BeamEntry* b...
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename CTCBeamState, typename CTCBeamComparer>
|
template <typename CTCBeamState, typename CTCBeamComparer>
|
||||||
@ -335,7 +332,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Reset() {
|
|||||||
|
|
||||||
// This beam root, and all of its children, will be in memory until
|
// This beam root, and all of its children, will be in memory until
|
||||||
// the next reset.
|
// the next reset.
|
||||||
beam_root_.reset(new BeamEntry(nullptr, -1, num_classes_ - 1, -1));
|
beam_root_.reset(new BeamEntry(nullptr, -1));
|
||||||
beam_root_->newp.total = 0.0; // ln(1)
|
beam_root_->newp.total = 0.0; // ln(1)
|
||||||
beam_root_->newp.blank = 0.0; // ln(1)
|
beam_root_->newp.blank = 0.0; // ln(1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user