Merge pull request #3279 from godefv/decoder_timesteps

The CTC decoder timesteps now corresponds to the timesteps of the most probable CTC path, instead of the earliest timesteps of all possible paths.
This commit is contained in:
Reuben Morais 2020-09-17 20:31:05 +02:00 committed by GitHub
commit cc62aa2eb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 277 additions and 31 deletions

View File

@ -100,6 +100,7 @@ cc_library(
includes = [
".",
"ctcdecode/third_party/ThreadPool",
"ctcdecode/third_party/object_pool",
] + OPENFST_INCLUDES_PLATFORM,
deps = [":kenlm"],
linkopts = [

View File

@ -26,7 +26,8 @@ INCLUDES = [
'..',
'../kenlm',
OPENFST_DIR + '/src/include',
'third_party/ThreadPool'
'third_party/ThreadPool',
'third_party/object_pool'
]
KENLM_FILES = (glob.glob('../kenlm/util/*.cc')

View File

@ -35,6 +35,7 @@ DecoderState::init(const Alphabet& alphabet,
PathTrie *root = new PathTrie;
root->score = root->log_prob_b_prev = 0.0;
prefix_root_.reset(root);
prefix_root_->timesteps = &timestep_tree_root_;
prefixes_.push_back(root);
if (ext_scorer && (bool)(ext_scorer_->dictionary)) {
@ -96,24 +97,46 @@ DecoderState::next(const double *probs,
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
break;
}
if (prefix->score == -NUM_FLT_INF) {
continue;
}
assert(prefix->timesteps != nullptr);
// blank
if (c == blank_id_) {
// compute probability of current path
float log_p = log_prob_c + prefix->score;
// combine current path with previous ones with the same prefix
// the blank label comes last, so we can compare log_prob_nb_cur with log_p
if (prefix->log_prob_nb_cur < log_p) {
// keep current timesteps
prefix->previous_timesteps = nullptr;
}
prefix->log_prob_b_cur =
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
log_sum_exp(prefix->log_prob_b_cur, log_p);
continue;
}
// repeated character
if (c == prefix->character) {
// compute probability of current path
float log_p = log_prob_c + prefix->log_prob_nb_prev;
// combine current path with previous ones with the same prefix
if (prefix->log_prob_nb_cur < log_p) {
// keep current timesteps
prefix->previous_timesteps = nullptr;
}
prefix->log_prob_nb_cur = log_sum_exp(
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
prefix->log_prob_nb_cur, log_p);
}
// get new prefix
auto prefix_new = prefix->get_path_trie(c, abs_time_step_, log_prob_c);
auto prefix_new = prefix->get_path_trie(c, log_prob_c);
if (prefix_new != nullptr) {
// compute probability of current path
float log_p = -NUM_FLT_INF;
if (c == prefix->character &&
@ -144,6 +167,13 @@ DecoderState::next(const double *probs,
}
}
// combine current path with previous ones with the same prefix
if (prefix_new->log_prob_nb_cur < log_p) {
// record data needed to update timesteps
// the actual update will be done if nothing better is found
prefix_new->previous_timesteps = prefix->timesteps;
prefix_new->new_timestep = abs_time_step_;
}
prefix_new->log_prob_nb_cur =
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
}
@ -207,7 +237,9 @@ DecoderState::decode(size_t num_results) const
for (size_t i = 0; i < num_returned; ++i) {
Output output;
prefixes_copy[i]->get_path_vec(output.tokens, output.timesteps);
prefixes_copy[i]->get_path_vec(output.tokens);
output.timesteps = get_history(prefixes_copy[i]->timesteps, &timestep_tree_root_);
assert(output.tokens.size() == output.timesteps.size());
output.confidence = scores[prefixes_copy[i]];
outputs.push_back(output);
}

View File

@ -21,6 +21,7 @@ class DecoderState {
std::shared_ptr<Scorer> ext_scorer_;
std::vector<PathTrie*> prefixes_;
std::unique_ptr<PathTrie> prefix_root_;
TimestepTreeNode timestep_tree_root_{nullptr, 0};
public:
DecoderState() = default;

View File

@ -18,7 +18,6 @@ PathTrie::PathTrie() {
ROOT_ = -1;
character = ROOT_;
timestep = 0;
exists_ = true;
parent = nullptr;
@ -35,7 +34,7 @@ PathTrie::~PathTrie() {
}
}
PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
PathTrie* PathTrie::get_path_trie(unsigned int new_char, float cur_log_prob_c, bool reset) {
auto child = children_.begin();
for (; child != children_.end(); ++child) {
if (child->first == new_char) {
@ -67,7 +66,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this;
new_path->dictionary_ = dictionary_;
new_path->has_dictionary_ = true;
@ -93,7 +91,6 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
} else {
PathTrie* new_path = new PathTrie;
new_path->character = new_char;
new_path->timestep = new_timestep;
new_path->parent = this;
new_path->log_prob_c = cur_log_prob_c;
children_.push_back(std::make_pair(new_char, new_path));
@ -102,20 +99,18 @@ PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timest
}
}
void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
void PathTrie::get_path_vec(std::vector<unsigned int>& output) {
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (parent != nullptr) {
parent->get_path_vec(output, timesteps);
parent->get_path_vec(output);
}
if (character != ROOT_) {
output.push_back(character);
timesteps.push_back(timestep);
}
}
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet)
{
PathTrie* stop = this;
@ -125,10 +120,9 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
stop = parent->get_prev_grapheme(output, alphabet);
}
output.push_back(character);
timesteps.push_back(timestep);
return stop;
}
@ -147,7 +141,6 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
}
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet)
{
PathTrie* stop = this;
@ -157,14 +150,18 @@ PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (parent != nullptr) {
stop = parent->get_prev_word(output, timesteps, alphabet);
stop = parent->get_prev_word(output, alphabet);
}
output.push_back(character);
timesteps.push_back(timestep);
return stop;
}
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
// previous_timesteps might point to ancestors' timesteps
// therefore, children must be uptaded first
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
if (exists_) {
log_prob_b_prev = log_prob_b_cur;
log_prob_nb_prev = log_prob_nb_cur;
@ -173,11 +170,23 @@ void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
log_prob_nb_cur = -NUM_FLT_INF;
score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
if (previous_timesteps != nullptr) {
timesteps = nullptr;
for (auto const& child : previous_timesteps->children) {
if (child->data == new_timestep) {
timesteps = child.get();
break;
}
}
if (timesteps == nullptr) {
timesteps = add_child(previous_timesteps, new_timestep);
}
}
previous_timesteps = nullptr;
output.push_back(this);
}
for (auto child : children_) {
child.second->iterate_to_vec(output);
}
}
void PathTrie::remove() {
@ -229,8 +238,8 @@ void PathTrie::print(const Alphabet& a) {
}
}
printf("\ntimesteps:\t ");
for (PathTrie* el : chain) {
printf("%d ", el->timestep);
for (unsigned int timestep : get_history(timesteps)) {
printf("%d ", timestep);
}
printf("\n");
printf("transcript:\t %s\n", tr.c_str());

View File

@ -9,6 +9,34 @@
#include "fst/fstlib.h"
#include "alphabet.h"
#include "object_pool.h"
/* Tree structure with parent and children information
* It is used to store the timesteps data for the PathTrie below
*/
template<class DataT>
struct TreeNode {
TreeNode<DataT>* parent;
std::vector<std::unique_ptr< TreeNode<DataT>, godefv::object_pool_deleter_t<TreeNode<DataT>> >> children;
DataT data;
TreeNode(TreeNode<DataT>* parent_, DataT const& data_): parent{parent_}, data{data_} {}
};
/* Creates a new TreeNode<NodeDataT> with given data as a child to the given node.
* Returns a pointer to the created node. This pointer remains valid as long as the child is not destroyed.
*/
template<class NodeDataT, class ChildDataT>
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* tree_node, ChildDataT&& data);
/* Returns the sequence of tree node's data from the given root (exclusive) to the given tree_node (inclusive).
* By default (if no root is provided), the full sequence from the root of the tree is returned.
*/
template<class DataT>
std::vector<DataT> get_history(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root = nullptr);
using TimestepTreeNode = TreeNode<unsigned int>;
/* Trie tree for prefix storing and manipulating, with a dictionary in
* finite-state transducer for spelling correction.
@ -21,14 +49,13 @@ public:
~PathTrie();
// get new prefix after appending new char
PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);
PathTrie* get_path_trie(unsigned int new_char, float log_prob_c, bool reset = true);
// get the prefix data in correct time order from root to current node
void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);
void get_path_vec(std::vector<unsigned int>& output);
// get the prefix data in correct time order from beginning of last grapheme to current node
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet);
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
@ -36,7 +63,6 @@ public:
// get the prefix data in correct time order from beginning of last word to current node
PathTrie* get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet);
// update log probs
@ -65,7 +91,12 @@ public:
float score;
float approx_ctc;
unsigned int character;
unsigned int timestep;
TimestepTreeNode* timesteps = nullptr;
// timestep temporary storage for each decoding step.
TimestepTreeNode* previous_timesteps = nullptr;
unsigned int new_timestep;
PathTrie* parent;
private:
@ -81,4 +112,28 @@ private:
std::shared_ptr<fst::SortedMatcher<FstType>> matcher_;
};
// TreeNode implementation
template<class NodeDataT, class ChildDataT>
TreeNode<NodeDataT>* add_child(TreeNode<NodeDataT>* tree_node, ChildDataT&& data) {
static thread_local godefv::object_pool_t<TreeNode<NodeDataT>> tree_node_pool;
tree_node->children.push_back(tree_node_pool.make_unique(tree_node, std::forward<ChildDataT>(data)));
return tree_node->children.back().get();
}
template<class DataT>
void get_history_helper(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root, std::vector<DataT>* output) {
if (tree_node == root) return;
assert(tree_node != nullptr);
assert(tree_node->parent != tree_node);
get_history_helper(tree_node->parent, root, output);
output->push_back(tree_node->data);
}
template<class DataT>
std::vector<DataT> get_history(TreeNode<DataT> const* tree_node, TreeNode<DataT> const* root) {
std::vector<DataT> output;
get_history_helper(tree_node, root, &output);
return output;
}
#endif // PATH_TRIE_H

View File

@ -293,12 +293,11 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
}
std::vector<unsigned int> prefix_vec;
std::vector<unsigned int> prefix_steps;
if (is_utf8_mode_) {
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
new_node = current_node->get_prev_grapheme(prefix_vec, alphabet_);
} else {
new_node = current_node->get_prev_word(prefix_vec, prefix_steps, alphabet_);
new_node = current_node->get_prev_word(prefix_vec, alphabet_);
}
current_node = new_node->parent;

View File

@ -0,0 +1 @@
This code was imported from https://github.com/godefv/memory on September 17th 2020, commit 5ff1af8ee09ced04990b4863b2c02a8d07f4356a. It's licensed under "CC0 1.0 Universal" license.

View File

@ -0,0 +1,108 @@
#ifndef GODEFV_MEMORY_OBJECT_POOL_H
#define GODEFV_MEMORY_OBJECT_POOL_H
#include "unique_ptr.h"
#include <memory>
#include <vector>
#include <array>
namespace godefv{
// Forward declaration
template<class Object, template<class T> class Allocator = std::allocator, std::size_t ChunkSize = 1024>
class object_pool_t;
//! Custom deleter to recycle the deleted pointers of the object_pool_t.
template<class Object, template<class T> class Allocator = std::allocator, std::size_t ChunkSize = 1024>
struct object_pool_deleter_t{
private:
object_pool_t<Object, Allocator, ChunkSize>* object_pool_ptr;
public:
explicit object_pool_deleter_t(decltype(object_pool_ptr) input_object_pool_ptr) :
object_pool_ptr(input_object_pool_ptr)
{}
void operator()(Object* object_ptr)
{
object_pool_ptr->delete_object(object_ptr);
}
};
//! Allocates instances of Object efficiently (constant time and log((maximum number of Objects used at the same time)/ChunkSize) calls to malloc in the whole lifetime of the object pool).
//! When an instance returned by the object pool is destroyed, its allocated memory is recycled by the object pool. Defragmenting the object pool to free memory is not possible.
template<class Object, template<class T> class Allocator, std::size_t ChunkSize>
class object_pool_t{
//! An object slot is an uninitialized memory space of the same size as Object.
//! It is initially "free". It can then be "used" to construct an Object in place and the pointer to it is returned by the object pool. When the pointer is destroyed, the object slot is "recycled" and can be used again but it is not "free" anymore because "free" object slots are contiguous in memory.
using object_slot_t=std::array<char, sizeof(Object)>;
//! To minimize calls to malloc, the object slots are allocated in chunks.
//! For example, if ChunkSize=8, a chunk may look like this : |used|recycled|used|used|recycled|free|free|free|. In this example, if more than 5 new Object are now asked from the object pool, at least one new chunk of 8 object slots will be allocated.
using chunk_t=std::array<object_slot_t, ChunkSize>;
Allocator<chunk_t> chunk_allocator; //!< This allocator can be used to have aligned memory if required.
std::vector<unique_ptr_t<chunk_t, decltype(chunk_allocator)>> memory_chunks;
//! Recycled object slots are tracked using a stack of pointers to them. When an object slot is recycled, a pointer to it is pushed in constant time. When a new object is constructed, a recycled object slot can be found and poped in constant time.
std::vector<object_slot_t*> recycled_object_slots;
object_slot_t* free_object_slots_begin;
object_slot_t* free_object_slots_end;
//! When a pointer provided by the ObjectPool is deleted, its memory is converted to an object slot to be recycled.
void delete_object(Object* object_ptr){
object_ptr->~Object();
recycled_object_slots.push_back(reinterpret_cast<object_slot_t*>(object_ptr));
}
friend object_pool_deleter_t<Object, Allocator, ChunkSize>;
public:
using object_t = Object;
using deleter_t = object_pool_deleter_t<Object, Allocator, ChunkSize>;
using object_unique_ptr_t = std::unique_ptr<object_t, deleter_t>; //!< The type returned by the object pool.
object_pool_t(Allocator<chunk_t> const& allocator = Allocator<chunk_t>{}) :
chunk_allocator{ allocator },
free_object_slots_begin{ free_object_slots_end } // At the begining, set the 2 iterators at the same value to simulate a full pool.
{}
//! Returns a unique pointer to an object_t using an unused object slot from the object pool.
template<class... Args> object_unique_ptr_t make_unique(Args&&... vars){
auto construct_object_unique_ptr=[&](object_slot_t* object_slot){
return object_unique_ptr_t{ new (reinterpret_cast<object_t*>(object_slot)) object_t{ std::forward<Args>(vars)... } , deleter_t{ this } };
};
// If a recycled object slot is available, use it.
if (!recycled_object_slots.empty())
{
auto object_slot = recycled_object_slots.back();
recycled_object_slots.pop_back();
return construct_object_unique_ptr(object_slot);
}
// If the pool is full: add a new chunk.
if (free_object_slots_begin == free_object_slots_end)
{
memory_chunks.emplace_back(chunk_allocator);
auto& new_chunk = memory_chunks.back();
free_object_slots_begin=new_chunk->data();
free_object_slots_end =free_object_slots_begin+new_chunk->size();
}
// We know that there is now at least one free object slot, use it.
return construct_object_unique_ptr(free_object_slots_begin++);
}
//! Returns the total number of object slots (free, recycled, or used).
std::size_t capacity() const{
return memory_chunks.size()*ChunkSize;
}
//! Returns the number of currently used object slots.
std::size_t size() const{
return capacity() - static_cast<std::size_t>(free_object_slots_end-free_object_slots_begin) - recycled_object_slots.size();
}
};
} /* namespace godefv */
#endif /* GODEFV_MEMORY_OBJECT_POOL_H */

View File

@ -0,0 +1,39 @@
#ifndef GODEFV_MEMORY_ALLOCATED_UNIQUE_PTR_H
#define GODEFV_MEMORY_ALLOCATED_UNIQUE_PTR_H
#include <memory>
namespace godefv{
//! A deleter to deallocate memory which have been allocated by the given allocator.
template<class Allocator>
struct allocator_deleter_t
{
allocator_deleter_t(Allocator const& allocator) :
mAllocator{ allocator }
{}
void operator()(typename Allocator::value_type* ptr)
{
mAllocator.deallocate(ptr, 1);
}
private:
Allocator mAllocator;
};
//! A smart pointer like std::unique_ptr but templated on an allocator instead of a deleter.
//! The deleter is deduced from the given allocator.
template<class T, class Allocator = std::allocator<T>>
struct unique_ptr_t : public std::unique_ptr<T, allocator_deleter_t<Allocator>>
{
using base_t = std::unique_ptr<T, allocator_deleter_t<Allocator>>;
unique_ptr_t(Allocator allocator = Allocator{}) :
base_t{ allocator.allocate(1), allocator_deleter_t<Allocator>{ allocator } }
{}
};
} // namespace godefv
#endif // GODEFV_MEMORY_ALLOCATED_UNIQUE_PTR_H