Deduplicate Alphabet implementations, use C++ one everywhere
parent f82c77392d
commit a84abf813c
@@ -75,6 +75,7 @@ cc_library(
        "ctcdecode/scorer.cpp",
        "ctcdecode/path_trie.cpp",
        "ctcdecode/path_trie.h",
        "alphabet.cc",
    ] + OPENFST_SOURCES_PLATFORM,
    hdrs = [
        "ctcdecode/ctc_beam_search_decoder.h",
@@ -86,13 +87,17 @@ cc_library(
        ".",
        "ctcdecode/third_party/ThreadPool",
    ] + OPENFST_INCLUDES_PLATFORM,
    deps = [":kenlm"]
    deps = [":kenlm"],
    linkopts = [
        "-lm",
        "-ldl",
        "-pthread",
    ],
)

tf_cc_shared_object(
    name = "libdeepspeech.so",
    srcs = [
        "alphabet.h",
        "deepspeech.cc",
        "deepspeech.h",
        "deepspeech_errors.cc",
@@ -203,6 +208,11 @@ cc_binary(
        "@com_google_absl//absl/types:optional",
        "@boost//:program_options",
    ],
    linkopts = [
        "-lm",
        "-ldl",
        "-pthread",
    ],
)

cc_binary(
@@ -221,10 +231,5 @@ cc_binary(
        "trie_load.cc",
    ],
    copts = ["-std=c++11"],
    linkopts = [
        "-lm",
        "-ldl",
        "-pthread",
    ],
    deps = [":decoder"],
)
native_client/alphabet.cc (new file, 154 lines)
@@ -0,0 +1,154 @@
#include "alphabet.h"
#include "ctcdecode/decoder_utils.h"

#include <fstream>

int
Alphabet::init(const char *config_file)
{
  std::ifstream in(config_file, std::ios::in);
  if (!in) {
    return 1;
  }
  unsigned int label = 0;
  space_label_ = -2;
  for (std::string line; std::getline(in, line);) {
    if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
      line = '#';
    } else if (line[0] == '#') {
      continue;
    }
    //TODO: we should probably do something more i18n-aware here
    if (line == " ") {
      space_label_ = label;
    }
    label_to_str_[label] = line;
    str_to_label_[line] = label;
    ++label;
  }
  size_ = label;
  in.close();
  return 0;
}

std::string
Alphabet::Serialize()
{
  // Serialization format is a sequence of (key, value) pairs, where key is
  // a uint16_t and value is a uint16_t length followed by `length` UTF-8
  // encoded bytes with the label.
  std::stringstream out;

  // We start by writing the number of pairs in the buffer as uint16_t.
  uint16_t size = size_;
  out.write(reinterpret_cast<char*>(&size), sizeof(size));

  for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
    uint16_t key = it->first;
    string str = it->second;
    uint16_t len = str.length();
    // Then we write the key as uint16_t, followed by the length of the value
    // as uint16_t, followed by `length` bytes (the value itself).
    out.write(reinterpret_cast<char*>(&key), sizeof(key));
    out.write(reinterpret_cast<char*>(&len), sizeof(len));
    out.write(str.data(), len);
  }

  return out.str();
}

int
Alphabet::Deserialize(const char* buffer, const int buffer_size)
{
  // See util/text.py for an explanation of the serialization format.
  int offset = 0;
  if (buffer_size - offset < sizeof(uint16_t)) {
    return 1;
  }
  uint16_t size = *(uint16_t*)(buffer + offset);
  offset += sizeof(uint16_t);
  size_ = size;

  for (int i = 0; i < size; ++i) {
    if (buffer_size - offset < sizeof(uint16_t)) {
      return 1;
    }
    uint16_t label = *(uint16_t*)(buffer + offset);
    offset += sizeof(uint16_t);

    if (buffer_size - offset < sizeof(uint16_t)) {
      return 1;
    }
    uint16_t val_len = *(uint16_t*)(buffer + offset);
    offset += sizeof(uint16_t);

    if (buffer_size - offset < val_len) {
      return 1;
    }
    std::string val(buffer+offset, val_len);
    offset += val_len;

    label_to_str_[label] = val;
    str_to_label_[val] = label;

    if (val == " ") {
      space_label_ = label;
    }
  }

  return 0;
}

std::string
Alphabet::DecodeSingle(unsigned int label) const
{
  auto it = label_to_str_.find(label);
  if (it != label_to_str_.end()) {
    return it->second;
  } else {
    std::cerr << "Invalid label " << label << std::endl;
    abort();
  }
}

unsigned int
Alphabet::EncodeSingle(const std::string& string) const
{
  auto it = str_to_label_.find(string);
  if (it != str_to_label_.end()) {
    return it->second;
  } else {
    std::cerr << "Invalid string " << string << std::endl;
    abort();
  }
}

std::string
Alphabet::Decode(const std::vector<unsigned int>& input) const
{
  std::string word;
  for (auto ind : input) {
    word += DecodeSingle(ind);
  }
  return word;
}

std::string
Alphabet::Decode(const unsigned int* input, int length) const
{
  std::string word;
  for (int i = 0; i < length; ++i) {
    word += DecodeSingle(input[i]);
  }
  return word;
}

std::vector<unsigned int>
Alphabet::Encode(const std::string& input) const
{
  std::vector<unsigned int> result;
  for (auto cp : split_into_codepoints(input)) {
    result.push_back(EncodeSingle(cp));
  }
  return result;
}
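
The comments in Serialize() above describe the wire format completely: a uint16 pair count, then for each pair a uint16 label, a uint16 byte length, and that many UTF-8 bytes, written in native (little-endian) byte order, matching the struct.pack('<H', ...) calls in the Python serializer this commit removes. As an illustration only (the helper name is hypothetical, not part of the commit), such a buffer can be unpacked with a few lines of Python:

import struct

def parse_alphabet_buffer(buf):
    # Walk the (label, length, UTF-8 bytes) records written by Alphabet::Serialize().
    offset = 0
    (n_pairs,) = struct.unpack_from('<H', buf, offset)
    offset += 2
    label_to_str = {}
    for _ in range(n_pairs):
        label, length = struct.unpack_from('<HH', buf, offset)
        offset += 4
        label_to_str[label] = buf[offset:offset + length].decode('utf-8')
        offset += length
    return label_to_str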
@@ -1,9 +1,6 @@
#ifndef ALPHABET_H
#define ALPHABET_H

#include <cassert>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
@@ -18,116 +15,15 @@ public:
  Alphabet() = default;
  Alphabet(const Alphabet&) = default;
  Alphabet& operator=(const Alphabet&) = default;
  virtual ~Alphabet() = default;

  virtual int init(const char *config_file) {
    std::ifstream in(config_file, std::ios::in);
    if (!in) {
      return 1;
    }
    unsigned int label = 0;
    space_label_ = -2;
    for (std::string line; std::getline(in, line);) {
      if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
        line = '#';
      } else if (line[0] == '#') {
        continue;
      }
      //TODO: we should probably do something more i18n-aware here
      if (line == " ") {
        space_label_ = label;
      }
      label_to_str_[label] = line;
      str_to_label_[line] = label;
      ++label;
    }
    size_ = label;
    in.close();
    return 0;
  }
  virtual int init(const char *config_file);

  std::string serialize() {
    // Serialization format is a sequence of (key, value) pairs, where key is
    // a uint16_t and value is a uint16_t length followed by `length` UTF-8
    // encoded bytes with the label.
    std::stringstream out;
  // Serialize alphabet into a binary buffer.
  std::string Serialize();

    // We start by writing the number of pairs in the buffer as uint16_t.
    uint16_t size = size_;
    out.write(reinterpret_cast<char*>(&size), sizeof(size));

    for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
      uint16_t key = it->first;
      string str = it->second;
      uint16_t len = str.length();
      // Then we write the key as uint16_t, followed by the length of the value
      // as uint16_t, followed by `length` bytes (the value itself).
      out.write(reinterpret_cast<char*>(&key), sizeof(key));
      out.write(reinterpret_cast<char*>(&len), sizeof(len));
      out.write(str.data(), len);
    }

    return out.str();
  }

  int deserialize(const char* buffer, const int buffer_size) {
    // See util/text.py for an explanation of the serialization format.
    int offset = 0;
    if (buffer_size - offset < sizeof(uint16_t)) {
      return 1;
    }
    uint16_t size = *(uint16_t*)(buffer + offset);
    offset += sizeof(uint16_t);
    size_ = size;

    for (int i = 0; i < size; ++i) {
      if (buffer_size - offset < sizeof(uint16_t)) {
        return 1;
      }
      uint16_t label = *(uint16_t*)(buffer + offset);
      offset += sizeof(uint16_t);

      if (buffer_size - offset < sizeof(uint16_t)) {
        return 1;
      }
      uint16_t val_len = *(uint16_t*)(buffer + offset);
      offset += sizeof(uint16_t);

      if (buffer_size - offset < val_len) {
        return 1;
      }
      std::string val(buffer+offset, val_len);
      offset += val_len;

      label_to_str_[label] = val;
      str_to_label_[val] = label;

      if (val == " ") {
        space_label_ = label;
      }
    }

    return 0;
  }

  const std::string& StringFromLabel(unsigned int label) const {
    auto it = label_to_str_.find(label);
    if (it != label_to_str_.end()) {
      return it->second;
    } else {
      std::cerr << "Invalid label " << label << std::endl;
      abort();
    }
  }

  unsigned int LabelFromString(const std::string& string) const {
    auto it = str_to_label_.find(string);
    if (it != str_to_label_.end()) {
      return it->second;
    } else {
      std::cerr << "Invalid string " << string << std::endl;
      abort();
    }
  }
  // Deserialize alphabet from a binary buffer.
  int Deserialize(const char* buffer, const int buffer_size);

  size_t GetSize() const {
    return size_;
@@ -141,14 +37,22 @@ public:
    return space_label_;
  }

  template <typename T>
  std::string LabelsToString(const std::vector<T>& input) const {
    std::string word;
    for (auto ind : input) {
      word += StringFromLabel(ind);
    }
    return word;
  }
  // Decode a single label into a string.
  std::string DecodeSingle(unsigned int label) const;

  // Encode a single character/output class into a label.
  unsigned int EncodeSingle(const std::string& string) const;

  // Decode a sequence of labels into a string.
  std::string Decode(const std::vector<unsigned int>& input) const;

  // We provide a C-style overload for accepting NumPy arrays as input, since
  // the NumPy library does not have built-in typemaps for std::vector<T>.
  std::string Decode(const unsigned int* input, int length) const;

  // Encode a sequence of character/output classes into a sequence of labels.
  // Characters are assumed to always take a single Unicode codepoint.
  std::vector<unsigned int> Encode(const std::string& input) const;

protected:
  size_t size_;
@@ -163,14 +67,16 @@ public:
  UTF8Alphabet() {
    size_ = 255;
    space_label_ = ' ' - 1;
    for (int i = 0; i < size_; ++i) {
    for (size_t i = 0; i < size_; ++i) {
      std::string val(1, i+1);
      label_to_str_[i] = val;
      str_to_label_[val] = i;
    }
  }

  int init(const char*) override {}
  int init(const char*) override {
    return 0;
  }
};
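
Taken together, the header now swaps the old StringFromLabel/LabelFromString/LabelsToString trio for DecodeSingle/EncodeSingle/Decode/Encode, and the extra Decode(const unsigned int*, int) overload exists so the unsigned-int NumPy typemap added in swigwrapper.i further down can feed it arrays directly. A rough usage sketch from Python, not part of the commit (the alphabet path is illustrative, and the uint32 dtype is assumed to satisfy that typemap):

import numpy as np
from ds_ctcdecoder import Alphabet

alphabet = Alphabet('data/alphabet.txt')   # convenience wrapper added below in ctcdecode/__init__.py
labels = alphabet.Encode('hello')          # sequence of unsigned int labels, one per codepoint
print(alphabet.Decode(labels))             # 'hello'
# The C-style overload accepts a NumPy array of unsigned ints directly.
print(alphabet.Decode(np.array(list(labels), dtype=np.uint32)))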
@@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function

from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import Alphabet
from .swigwrapper import UTF8Alphabet

__version__ = swigwrapper.__version__

@@ -30,24 +30,20 @@ class Scorer(swigwrapper.Scorer):
        assert beta is not None, 'beta parameter is required'
        assert scorer_path, 'scorer_path parameter is required'

        serialized = alphabet.serialize()
        native_alphabet = swigwrapper.Alphabet()
        err = native_alphabet.deserialize(serialized, len(serialized))
        err = self.init(scorer_path, alphabet)
        if err != 0:
            raise ValueError('Error when deserializing alphabet.')

        err = self.init(scorer_path.encode('utf-8'),
                        native_alphabet)
        if err != 0:
            raise ValueError('Scorer initialization failed with error code {}'.format(err))
            raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))

        self.reset_params(alpha, beta)

    def load_lm(self, lm_path):
        return super(Scorer, self).load_lm(lm_path.encode('utf-8'))

    def save_dictionary(self, save_path, *args, **kwargs):
        return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
class Alphabet(swigwrapper.Alphabet):
    """Convenience wrapper for Alphabet which calls init in the constructor"""
    def __init__(self, config_path):
        super(Alphabet, self).__init__()
        err = self.init(config_path)
        if err != 0:
            raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))


def ctc_beam_search_decoder(probs_seq,
@@ -79,15 +75,10 @@ def ctc_beam_search_decoder(probs_seq,
        results, in descending order of the confidence.
    :rtype: list
    """
    serialized = alphabet.serialize()
    native_alphabet = swigwrapper.Alphabet()
    err = native_alphabet.deserialize(serialized, len(serialized))
    if err != 0:
        raise ValueError("Error when deserializing alphabet.")
    beam_results = swigwrapper.ctc_beam_search_decoder(
        probs_seq, native_alphabet, beam_size, cutoff_prob, cutoff_top_n,
        probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
        scorer)
    beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
    beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
    return beam_results


@@ -126,14 +117,9 @@ def ctc_beam_search_decoder_batch(probs_seq,
        results, in descending order of the confidence.
    :rtype: list
    """
    serialized = alphabet.serialize()
    native_alphabet = swigwrapper.Alphabet()
    err = native_alphabet.deserialize(serialized, len(serialized))
    if err != 0:
        raise ValueError("Error when deserializing alphabet.")
    batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, native_alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
    batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
    batch_beam_results = [
        [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
        [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
        for beam_results in batch_beam_results
    ]
    return batch_beam_results
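
The net effect of these __init__.py changes is that callers hand a single Alphabet object to both the Scorer and the decoder, with no serialize()/deserialize() round-trip. A hypothetical call site is sketched below; the keyword names are inferred from the asserts and the swigwrapper call above, and the paths, LM weights, and dummy probabilities are placeholders:

import numpy as np
from ds_ctcdecoder import Alphabet, Scorer, ctc_beam_search_decoder

alphabet = Alphabet('data/alphabet.txt')              # placeholder path
scorer = Scorer(alpha=0.93, beta=1.18,                # placeholder LM weights
                scorer_path='deepspeech.scorer',      # placeholder path
                alphabet=alphabet)

# Dummy acoustic-model output: one row of class probabilities per time step.
logits = np.random.rand(50, alphabet.GetSize() + 1)
probs_seq = logits / logits.sum(axis=1, keepdims=True)

results = ctc_beam_search_decoder(probs_seq, alphabet, beam_size=1024, scorer=scorer)
print(results[0])  # (confidence, transcript) of the best beam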
@@ -46,7 +46,8 @@ CTC_DECODER_FILES = [
    'scorer.cpp',
    'path_trie.cpp',
    'decoder_utils.cpp',
    'workspace_status.cc'
    'workspace_status.cc',
    '../alphabet.cc',
]

def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):
@@ -119,7 +119,7 @@ bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::un
  }
}

void add_word_to_fst(const std::vector<int> &word,
void add_word_to_fst(const std::vector<unsigned int> &word,
                     fst::StdVectorFst *dictionary) {
  if (dictionary->NumStates() == 0) {
    fst::StdVectorFst::StateId start = dictionary->AddState();
@@ -144,7 +144,7 @@ bool add_word_to_dictionary(
    fst::StdVectorFst *dictionary) {
  auto characters = utf8 ? split_into_bytes(word) : split_into_codepoints(word);

  std::vector<int> int_word;
  std::vector<unsigned int> int_word;

  for (auto &c : characters) {
    auto int_c = char_map.find(c);
@@ -86,7 +86,7 @@ std::vector<std::string> split_into_codepoints(const std::string &str);
std::vector<std::string> split_into_bytes(const std::string &str);

// Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int> &word,
void add_word_to_fst(const std::vector<unsigned int> &word,
                     fst::StdVectorFst *dictionary);

// Return whether a byte is a code point boundary (not a continuation byte).
@@ -8,8 +8,8 @@
 */
struct Output {
  double confidence;
  std::vector<int> tokens;
  std::vector<int> timesteps;
  std::vector<unsigned int> tokens;
  std::vector<unsigned int> timesteps;
};

#endif // OUTPUT_H_
@@ -35,7 +35,7 @@ PathTrie::~PathTrie() {
  }
}

PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_prob_c, bool reset) {
PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
  auto child = children_.begin();
  for (; child != children_.end(); ++child) {
    if (child->first == new_char) {
@@ -102,7 +102,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
  }
}

void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps) {
void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
  // Recursive call: recurse back until stop condition, then append data in
  // correct order as we walk back down the stack in the lines below.
  if (parent != nullptr) {
@@ -114,8 +114,8 @@ void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timestep
  }
}

PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
                                      std::vector<int>& timesteps,
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
                                      std::vector<unsigned int>& timesteps,
                                      const Alphabet& alphabet)
{
  PathTrie* stop = this;
@@ -124,7 +124,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
  }
  // Recursive call: recurse back until stop condition, then append data in
  // correct order as we walk back down the stack in the lines below.
  if (!byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
  if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
    stop = parent->get_prev_grapheme(output, timesteps, alphabet);
  }
  output.push_back(character);
@@ -135,7 +135,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
                                             const Alphabet& alphabet)
{
  if (byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
  if (byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
    *first_byte = (unsigned char)character + 1;
    return 1;
  }
@@ -146,8 +146,8 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
  return 0;
}

PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
                                  std::vector<int>& timesteps,
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
                                  std::vector<unsigned int>& timesteps,
                                  const Alphabet& alphabet)
{
  PathTrie* stop = this;
@@ -225,7 +225,7 @@ void PathTrie::print(const Alphabet& a) {
  for (PathTrie* el : chain) {
    printf("%X ", (unsigned char)(el->character));
    if (el->character != ROOT_) {
      tr.append(a.StringFromLabel(el->character));
      tr.append(a.DecodeSingle(el->character));
    }
  }
  printf("\ntimesteps:\t ");
@@ -21,22 +21,22 @@ public:
  ~PathTrie();

  // get new prefix after appending new char
  PathTrie* get_path_trie(int new_char, int new_timestep, float log_prob_c, bool reset = true);
  PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);

  // get the prefix data in correct time order from root to current node
  void get_path_vec(std::vector<int>& output, std::vector<int>& timesteps);
  void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);

  // get the prefix data in correct time order from beginning of last grapheme to current node
  PathTrie* get_prev_grapheme(std::vector<int>& output,
                              std::vector<int>& timesteps,
  PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
                              std::vector<unsigned int>& timesteps,
                              const Alphabet& alphabet);

  // get the distance from current node to the first codepoint boundary, and the byte value at the boundary
  int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);

  // get the prefix data in correct time order from beginning of last word to current node
  PathTrie* get_prev_word(std::vector<int>& output,
                          std::vector<int>& timesteps,
  PathTrie* get_prev_word(std::vector<unsigned int>& output,
                          std::vector<unsigned int>& timesteps,
                          const Alphabet& alphabet);

  // update log probs
@@ -64,8 +64,8 @@ public:
  float log_prob_c;
  float score;
  float approx_ctc;
  int character;
  int timestep;
  unsigned int character;
  unsigned int timestep;
  PathTrie* parent;

private:
@@ -73,7 +73,7 @@ private:
  bool exists_;
  bool has_dictionary_;

  std::vector<std::pair<int, PathTrie*>> children_;
  std::vector<std::pair<unsigned int, PathTrie*>> children_;

  // pointer to dictionary of FST
  std::shared_ptr<FstType> dictionary_;
@@ -65,7 +65,7 @@ void Scorer::setup_char_map()
    // The initial state of FST is state 0, hence the index of chars in
    // the FST should start from 1 to avoid the conflict with the initial
    // state, otherwise wrong decoding results would be given.
    char_map_[alphabet_.StringFromLabel(i)] = i + 1;
    char_map_[alphabet_.DecodeSingle(i)] = i + 1;
  }
}

@@ -314,11 +314,11 @@ void Scorer::reset_params(float alpha, float beta)
  this->beta = beta;
}

std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<int>& labels)
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<unsigned int>& labels)
{
  if (labels.empty()) return {};

  std::string s = alphabet_.LabelsToString(labels);
  std::string s = alphabet_.Decode(labels);
  std::vector<std::string> words;
  if (is_utf8_mode_) {
    words = split_into_codepoints(s);
@@ -339,8 +339,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
      break;
    }

    std::vector<int> prefix_vec;
    std::vector<int> prefix_steps;
    std::vector<unsigned int> prefix_vec;
    std::vector<unsigned int> prefix_steps;

    if (is_utf8_mode_) {
      new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
@@ -350,7 +350,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
    current_node = new_node->parent;

    // reconstruct word
    std::string word = alphabet_.LabelsToString(prefix_vec);
    std::string word = alphabet_.Decode(prefix_vec);
    ngram.push_back(word);
  }
  std::reverse(ngram.begin(), ngram.end());
@@ -73,7 +73,7 @@ public:

  // trransform the labels in index to the vector of words (word based lm) or
  // the vector of characters (character based lm)
  std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
  std::vector<std::string> split_labels_into_scored_units(const std::vector<unsigned int> &labels);

  void set_alphabet(const Alphabet& alphabet);
@@ -3,7 +3,6 @@
%{
#include "ctc_beam_search_decoder.h"
#define SWIG_FILE_WITH_INIT
#define SWIG_PYTHON_STRICT_BYTE_CHAR
#include "workspace_status.h"
%}

@@ -19,6 +18,9 @@ import_array();

namespace std {
    %template(StringVector) vector<string>;
    %template(UnsignedIntVector) vector<unsigned int>;
    %template(OutputVector) vector<Output>;
    %template(OutputVectorVector) vector<vector<Output>>;
}

%shared_ptr(Scorer);
@@ -27,6 +29,7 @@ namespace std {
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
%apply (unsigned int* IN_ARRAY1, int DIM1) {(const unsigned int *input, int length)};

%ignore Scorer::dictionary;

@@ -38,10 +41,6 @@ namespace std {
%constant const char* __version__ = ds_version();
%constant const char* __git_version__ = ds_git_version();

%template(IntVector) std::vector<int>;
%template(OutputVector) std::vector<Output>;
%template(OutputVectorVector) std::vector<std::vector<Output>>;

// Import only the error code enum definitions from deepspeech.h
// We can't just do |%ignore "";| here because it affects this file globally (even
// files %include'd above). That causes SWIG to lose destructor information and
@@ -33,7 +33,7 @@ char*
ModelState::decode(const DecoderState& state) const
{
  vector<Output> out = state.decode();
  return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
  return strdup(alphabet_.Decode(out[0].tokens).c_str());
}

Metadata*
@@ -50,7 +50,7 @@ ModelState::decode_metadata(const DecoderState& state,

    for (int j = 0; j < out[i].tokens.size(); ++j) {
      TokenMetadata token {
        strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str()), // text
        strdup(alphabet_.DecodeSingle(out[i].tokens[j]).c_str()), // text
        static_cast<unsigned int>(out[i].timesteps[j]), // timestep
        out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time
      };
@@ -206,7 +206,7 @@ TFLiteModelState::init(const char* model_path)
  beam_width_ = (unsigned int)(*beam_width);

  tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
  err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
  err = alphabet_.Deserialize(serialized_alphabet.str, serialized_alphabet.len);
  if (err != 0) {
    return DS_ERR_INVALID_ALPHABET;
  }
@@ -119,7 +119,7 @@ TFModelState::init(const char* model_path)
  beam_width_ = (unsigned int)(beam_width);

  string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
  err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
  err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
  if (err != 0) {
    return DS_ERR_INVALID_ALPHABET;
  }
@@ -1,7 +1,7 @@
import unittest
import os

from deepspeech_training.util.text import Alphabet
from ds_ctcdecoder import Alphabet

class TestAlphabetParsing(unittest.TestCase):

@@ -11,12 +11,12 @@ class TestAlphabetParsing(unittest.TestCase):
        label_id = -1
        for expected_label, expected_label_id in expected:
            try:
                label_id = alphabet.encode(expected_label)
                label_id = alphabet.Encode(expected_label)
            except KeyError:
                pass
            self.assertEqual(label_id, [expected_label_id])
            try:
                label = alphabet.decode([expected_label_id])
                label = alphabet.Decode([expected_label_id])
            except KeyError:
                pass
            self.assertEqual(label, expected_label)
@@ -40,7 +40,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
    for i, index in enumerate(indices):
        results[index[0]].append(values[i])
    # List of strings
    return [alphabet.decode(res) for res in results]
    return [alphabet.Decode(res) for res in results]


def evaluate(test_csvs, create_model):
@@ -771,7 +771,7 @@ def export():
    outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
    outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
    outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
    outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet')

    if FLAGS.export_language:
        outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
@@ -6,14 +6,15 @@ import tensorflow.compat.v1 as tfv1

from attrdict import AttrDict
from xdg import BaseDirectory as xdg
from ds_ctcdecoder import Alphabet, UTF8Alphabet

from .flags import FLAGS
from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size
from .augmentations import parse_augmentations


class ConfigSingleton:
    _config = None

@@ -115,7 +116,7 @@ def initialize_globals():
    c.n_hidden_3 = c.n_cell_dim

    # Units in the sixth layer = number of characters in the target language plus one
    c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
    c.n_hidden_6 = c.alphabet.GetSize() + 1 # +1 for CTC blank label

    # Size of audio window in samples
    if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:
@@ -52,12 +52,10 @@ def check_ctcdecoder_version():
            sys.exit(1)
        raise e

    decoder_version_s = decoder_version.decode()

    rv = semver.compare(ds_version_s, decoder_version_s)
    rv = semver.compare(ds_version_s, decoder_version)
    if rv != 0:
        print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
              "Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
              "Please ensure matching versions are in use.".format(ds_version_s, decoder_version))
        sys.exit(1)

    return rv
@@ -3,121 +3,6 @@ from __future__ import absolute_import, division, print_function
import numpy as np
import struct

from six.moves import range

class Alphabet(object):
    def __init__(self, config_file):
        self._config_file = config_file
        self._label_to_str = {}
        self._str_to_label = {}
        self._size = 0
        if config_file:
            with open(config_file, 'r', encoding='utf-8') as fin:
                for line in fin:
                    if line[0:2] == '\\#':
                        line = '#\n'
                    elif line[0] == '#':
                        continue
                    self._label_to_str[self._size] = line[:-1] # remove the line ending
                    self._str_to_label[line[:-1]] = self._size
                    self._size += 1

    def _string_from_label(self, label):
        return self._label_to_str[label]

    def _label_from_string(self, string):
        try:
            return self._str_to_label[string]
        except KeyError as e:
            raise KeyError(
                'ERROR: Your transcripts contain characters (e.g. \'{}\') which do not occur in \'{}\'! Use ' \
                'util/check_characters.py to see what characters are in your [train,dev,test].csv transcripts, and ' \
                'then add all these to \'{}\'.'.format(string, self._config_file, self._config_file)
            ).with_traceback(e.__traceback__)

    def has_char(self, char):
        return char in self._str_to_label

    def encode(self, string):
        res = []
        for char in string:
            res.append(self._label_from_string(char))
        return res

    def decode(self, labels):
        res = ''
        for label in labels:
            res += self._string_from_label(label)
        return res

    def serialize(self):
        # Serialization format is a sequence of (key, value) pairs, where key is
        # a uint16_t and value is a uint16_t length followed by `length` UTF-8
        # encoded bytes with the label.
        res = bytearray()

        # We start by writing the number of pairs in the buffer as uint16_t.
        res += struct.pack('<H', self._size)
        for key, value in self._label_to_str.items():
            value = value.encode('utf-8')
            # struct.pack only takes fixed length strings/buffers, so we have to
            # construct the correct format string with the length of the encoded
            # label.
            res += struct.pack('<HH{}s'.format(len(value)), key, len(value), value)
        return bytes(res)

    def size(self):
        return self._size

    def config_file(self):
        return self._config_file


class UTF8Alphabet(object):
    @staticmethod
    def _string_from_label(_):
        assert False

    @staticmethod
    def _label_from_string(_):
        assert False

    @staticmethod
    def encode(string):
        # 0 never happens in the data, so we can shift values by one, use 255 for
        # the CTC blank, and keep the alphabet size = 256
        return np.frombuffer(string.encode('utf-8'), np.uint8).astype(np.int32) - 1

    @staticmethod
    def decode(labels):
        # And here we need to shift back up
        return bytes(np.asarray(labels, np.uint8) + 1).decode('utf-8', errors='replace')

    @staticmethod
    def size():
        return 255

    @staticmethod
    def serialize():
        res = bytearray()
        res += struct.pack('<h', 255)
        for i in range(255):
            # Note that we also shift back up in the mapping constructed here
            # so that the native client sees the correct byte values when decoding.
            res += struct.pack('<hh1s', i, 1, bytes([i+1]))
        return bytes(res)

    @staticmethod
    def deserialize(buf):
        size = struct.unpack('<I', buf)[0]
        assert size == 255
        return UTF8Alphabet()

    @staticmethod
    def config_file():
        return ''

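The byte-shifting trick documented in the removed UTF8Alphabet above is the same one the C++ UTF8Alphabet constructor in alphabet.h now implements: every UTF-8 byte is shifted down by one so that 255 stays free for the CTC blank and the label space stays within 256 classes. A standalone round-trip check, not part of the commit:

import numpy as np

text = 'héllo'
# Shift each UTF-8 byte down by one on encode (byte 0 never occurs in the data)...
labels = np.frombuffer(text.encode('utf-8'), np.uint8).astype(np.int32) - 1
# ...and shift back up on decode.
assert bytes(np.asarray(labels, np.uint8) + 1).decode('utf-8', errors='replace') == text
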
def text_to_char_array(transcript, alphabet, context=''):
    r"""
    Given a transcript string, map characters to
@@ -125,7 +10,7 @@ def text_to_char_array(transcript, alphabet, context=''):
    Use a string in `context` for adding text to raised exceptions.
    """
    try:
        transcript = alphabet.encode(transcript)
        transcript = alphabet.Encode(transcript)
        if len(transcript) == 0:
            raise ValueError('While processing {}: Found an empty transcript! '
                             'You must include a transcript for all training data.'