Deduplicate Alphabet implementations, use C++ one everywhere

This commit is contained in:
Reuben Morais 2020-06-26 11:57:10 +02:00
parent f82c77392d
commit a84abf813c
22 changed files with 257 additions and 322 deletions

View File

@ -75,6 +75,7 @@ cc_library(
"ctcdecode/scorer.cpp",
"ctcdecode/path_trie.cpp",
"ctcdecode/path_trie.h",
"alphabet.cc",
] + OPENFST_SOURCES_PLATFORM,
hdrs = [
"ctcdecode/ctc_beam_search_decoder.h",
@ -86,13 +87,17 @@ cc_library(
".",
"ctcdecode/third_party/ThreadPool",
] + OPENFST_INCLUDES_PLATFORM,
deps = [":kenlm"]
deps = [":kenlm"],
linkopts = [
"-lm",
"-ldl",
"-pthread",
],
)
tf_cc_shared_object(
name = "libdeepspeech.so",
srcs = [
"alphabet.h",
"deepspeech.cc",
"deepspeech.h",
"deepspeech_errors.cc",
@ -203,6 +208,11 @@ cc_binary(
"@com_google_absl//absl/types:optional",
"@boost//:program_options",
],
linkopts = [
"-lm",
"-ldl",
"-pthread",
],
)
cc_binary(
@ -221,10 +231,5 @@ cc_binary(
"trie_load.cc",
],
copts = ["-std=c++11"],
linkopts = [
"-lm",
"-ldl",
"-pthread",
],
deps = [":decoder"],
)

154
native_client/alphabet.cc Normal file
View File

@ -0,0 +1,154 @@
#include "alphabet.h"
#include "ctcdecode/decoder_utils.h"
#include <fstream>
int
Alphabet::init(const char *config_file)
{
std::ifstream in(config_file, std::ios::in);
if (!in) {
return 1;
}
unsigned int label = 0;
space_label_ = -2;
for (std::string line; std::getline(in, line);) {
if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
line = '#';
} else if (line[0] == '#') {
continue;
}
//TODO: we should probably do something more i18n-aware here
if (line == " ") {
space_label_ = label;
}
label_to_str_[label] = line;
str_to_label_[line] = label;
++label;
}
size_ = label;
in.close();
return 0;
}
std::string
Alphabet::Serialize()
{
// Serialization format is a sequence of (key, value) pairs, where key is
// a uint16_t and value is a uint16_t length followed by `length` UTF-8
// encoded bytes with the label.
std::stringstream out;
// We start by writing the number of pairs in the buffer as uint16_t.
uint16_t size = size_;
out.write(reinterpret_cast<char*>(&size), sizeof(size));
for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
uint16_t key = it->first;
string str = it->second;
uint16_t len = str.length();
// Then we write the key as uint16_t, followed by the length of the value
// as uint16_t, followed by `length` bytes (the value itself).
out.write(reinterpret_cast<char*>(&key), sizeof(key));
out.write(reinterpret_cast<char*>(&len), sizeof(len));
out.write(str.data(), len);
}
return out.str();
}
int
Alphabet::Deserialize(const char* buffer, const int buffer_size)
{
// See util/text.py for an explanation of the serialization format.
int offset = 0;
if (buffer_size - offset < sizeof(uint16_t)) {
return 1;
}
uint16_t size = *(uint16_t*)(buffer + offset);
offset += sizeof(uint16_t);
size_ = size;
for (int i = 0; i < size; ++i) {
if (buffer_size - offset < sizeof(uint16_t)) {
return 1;
}
uint16_t label = *(uint16_t*)(buffer + offset);
offset += sizeof(uint16_t);
if (buffer_size - offset < sizeof(uint16_t)) {
return 1;
}
uint16_t val_len = *(uint16_t*)(buffer + offset);
offset += sizeof(uint16_t);
if (buffer_size - offset < val_len) {
return 1;
}
std::string val(buffer+offset, val_len);
offset += val_len;
label_to_str_[label] = val;
str_to_label_[val] = label;
if (val == " ") {
space_label_ = label;
}
}
return 0;
}
std::string
Alphabet::DecodeSingle(unsigned int label) const
{
auto it = label_to_str_.find(label);
if (it != label_to_str_.end()) {
return it->second;
} else {
std::cerr << "Invalid label " << label << std::endl;
abort();
}
}
unsigned int
Alphabet::EncodeSingle(const std::string& string) const
{
auto it = str_to_label_.find(string);
if (it != str_to_label_.end()) {
return it->second;
} else {
std::cerr << "Invalid string " << string << std::endl;
abort();
}
}
std::string
Alphabet::Decode(const std::vector<unsigned int>& input) const
{
std::string word;
for (auto ind : input) {
word += DecodeSingle(ind);
}
return word;
}
std::string
Alphabet::Decode(const unsigned int* input, int length) const
{
std::string word;
for (int i = 0; i < length; ++i) {
word += DecodeSingle(input[i]);
}
return word;
}
std::vector<unsigned int>
Alphabet::Encode(const std::string& input) const
{
std::vector<unsigned int> result;
for (auto cp : split_into_codepoints(input)) {
result.push_back(EncodeSingle(cp));
}
return result;
}

View File

@ -1,9 +1,6 @@
#ifndef ALPHABET_H
#define ALPHABET_H
#include <cassert>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
@ -18,116 +15,15 @@ public:
Alphabet() = default;
Alphabet(const Alphabet&) = default;
Alphabet& operator=(const Alphabet&) = default;
virtual ~Alphabet() = default;
virtual int init(const char *config_file) {
std::ifstream in(config_file, std::ios::in);
if (!in) {
return 1;
}
unsigned int label = 0;
space_label_ = -2;
for (std::string line; std::getline(in, line);) {
if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
line = '#';
} else if (line[0] == '#') {
continue;
}
//TODO: we should probably do something more i18n-aware here
if (line == " ") {
space_label_ = label;
}
label_to_str_[label] = line;
str_to_label_[line] = label;
++label;
}
size_ = label;
in.close();
return 0;
}
virtual int init(const char *config_file);
std::string serialize() {
// Serialization format is a sequence of (key, value) pairs, where key is
// a uint16_t and value is a uint16_t length followed by `length` UTF-8
// encoded bytes with the label.
std::stringstream out;
// Serialize alphabet into a binary buffer.
std::string Serialize();
// We start by writing the number of pairs in the buffer as uint16_t.
uint16_t size = size_;
out.write(reinterpret_cast<char*>(&size), sizeof(size));
for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
uint16_t key = it->first;
string str = it->second;
uint16_t len = str.length();
// Then we write the key as uint16_t, followed by the length of the value
// as uint16_t, followed by `length` bytes (the value itself).
out.write(reinterpret_cast<char*>(&key), sizeof(key));
out.write(reinterpret_cast<char*>(&len), sizeof(len));
out.write(str.data(), len);
}
return out.str();
}
int deserialize(const char* buffer, const int buffer_size) {
// See util/text.py for an explanation of the serialization format.
int offset = 0;
if (buffer_size - offset < sizeof(uint16_t)) {
return 1;
}
uint16_t size = *(uint16_t*)(buffer + offset);
offset += sizeof(uint16_t);
size_ = size;
for (int i = 0; i < size; ++i) {
if (buffer_size - offset < sizeof(uint16_t)) {
return 1;
}
uint16_t label = *(uint16_t*)(buffer + offset);
offset += sizeof(uint16_t);
if (buffer_size - offset < sizeof(uint16_t)) {
return 1;
}
uint16_t val_len = *(uint16_t*)(buffer + offset);
offset += sizeof(uint16_t);
if (buffer_size - offset < val_len) {
return 1;
}
std::string val(buffer+offset, val_len);
offset += val_len;
label_to_str_[label] = val;
str_to_label_[val] = label;
if (val == " ") {
space_label_ = label;
}
}
return 0;
}
const std::string& StringFromLabel(unsigned int label) const {
auto it = label_to_str_.find(label);
if (it != label_to_str_.end()) {
return it->second;
} else {
std::cerr << "Invalid label " << label << std::endl;
abort();
}
}
unsigned int LabelFromString(const std::string& string) const {
auto it = str_to_label_.find(string);
if (it != str_to_label_.end()) {
return it->second;
} else {
std::cerr << "Invalid string " << string << std::endl;
abort();
}
}
// Deserialize alphabet from a binary buffer.
int Deserialize(const char* buffer, const int buffer_size);
size_t GetSize() const {
return size_;
@ -141,14 +37,22 @@ public:
return space_label_;
}
template <typename T>
std::string LabelsToString(const std::vector<T>& input) const {
std::string word;
for (auto ind : input) {
word += StringFromLabel(ind);
}
return word;
}
// Decode a single label into a string.
std::string DecodeSingle(unsigned int label) const;
// Encode a single character/output class into a label.
unsigned int EncodeSingle(const std::string& string) const;
// Decode a sequence of labels into a string.
std::string Decode(const std::vector<unsigned int>& input) const;
// We provide a C-style overload for accepting NumPy arrays as input, since
// the NumPy library does not have built-in typemaps for std::vector<T>.
std::string Decode(const unsigned int* input, int length) const;
// Encode a sequence of character/output classes into a sequence of labels.
// Characters are assumed to always take a single Unicode codepoint.
std::vector<unsigned int> Encode(const std::string& input) const;
protected:
size_t size_;
@ -163,14 +67,16 @@ public:
UTF8Alphabet() {
size_ = 255;
space_label_ = ' ' - 1;
for (int i = 0; i < size_; ++i) {
for (size_t i = 0; i < size_; ++i) {
std::string val(1, i+1);
label_to_str_[i] = val;
str_to_label_[val] = i;
}
}
int init(const char*) override {}
int init(const char*) override {
return 0;
}
};

View File

@ -1,7 +1,7 @@
from __future__ import absolute_import, division, print_function
from . import swigwrapper # pylint: disable=import-self
from .swigwrapper import Alphabet
from .swigwrapper import UTF8Alphabet
__version__ = swigwrapper.__version__
@ -30,24 +30,20 @@ class Scorer(swigwrapper.Scorer):
assert beta is not None, 'beta parameter is required'
assert scorer_path, 'scorer_path parameter is required'
serialized = alphabet.serialize()
native_alphabet = swigwrapper.Alphabet()
err = native_alphabet.deserialize(serialized, len(serialized))
err = self.init(scorer_path, alphabet)
if err != 0:
raise ValueError('Error when deserializing alphabet.')
err = self.init(scorer_path.encode('utf-8'),
native_alphabet)
if err != 0:
raise ValueError('Scorer initialization failed with error code {}'.format(err))
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))
self.reset_params(alpha, beta)
def load_lm(self, lm_path):
return super(Scorer, self).load_lm(lm_path.encode('utf-8'))
def save_dictionary(self, save_path, *args, **kwargs):
return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
class Alphabet(swigwrapper.Alphabet):
"""Convenience wrapper for Alphabet which calls init in the constructor"""
def __init__(self, config_path):
super(Alphabet, self).__init__()
err = self.init(config_path)
if err != 0:
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
def ctc_beam_search_decoder(probs_seq,
@ -79,15 +75,10 @@ def ctc_beam_search_decoder(probs_seq,
results, in descending order of the confidence.
:rtype: list
"""
serialized = alphabet.serialize()
native_alphabet = swigwrapper.Alphabet()
err = native_alphabet.deserialize(serialized, len(serialized))
if err != 0:
raise ValueError("Error when deserializing alphabet.")
beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, native_alphabet, beam_size, cutoff_prob, cutoff_top_n,
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
scorer)
beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
return beam_results
@ -126,14 +117,9 @@ def ctc_beam_search_decoder_batch(probs_seq,
results, in descending order of the confidence.
:rtype: list
"""
serialized = alphabet.serialize()
native_alphabet = swigwrapper.Alphabet()
err = native_alphabet.deserialize(serialized, len(serialized))
if err != 0:
raise ValueError("Error when deserializing alphabet.")
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, native_alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
batch_beam_results = [
[(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results
]
return batch_beam_results

View File

@ -46,7 +46,8 @@ CTC_DECODER_FILES = [
'scorer.cpp',
'path_trie.cpp',
'decoder_utils.cpp',
'workspace_status.cc'
'workspace_status.cc',
'../alphabet.cc',
]
def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):

View File

@ -119,7 +119,7 @@ bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::un
}
}
void add_word_to_fst(const std::vector<int> &word,
void add_word_to_fst(const std::vector<unsigned int> &word,
fst::StdVectorFst *dictionary) {
if (dictionary->NumStates() == 0) {
fst::StdVectorFst::StateId start = dictionary->AddState();
@ -144,7 +144,7 @@ bool add_word_to_dictionary(
fst::StdVectorFst *dictionary) {
auto characters = utf8 ? split_into_bytes(word) : split_into_codepoints(word);
std::vector<int> int_word;
std::vector<unsigned int> int_word;
for (auto &c : characters) {
auto int_c = char_map.find(c);

View File

@ -86,7 +86,7 @@ std::vector<std::string> split_into_codepoints(const std::string &str);
std::vector<std::string> split_into_bytes(const std::string &str);
// Add a word in index to the dicionary of fst
void add_word_to_fst(const std::vector<int> &word,
void add_word_to_fst(const std::vector<unsigned int> &word,
fst::StdVectorFst *dictionary);
// Return whether a byte is a code point boundary (not a continuation byte).

View File

@ -8,8 +8,8 @@
*/
struct Output {
double confidence;
std::vector<int> tokens;
std::vector<int> timesteps;
std::vector<unsigned int> tokens;
std::vector<unsigned int> timesteps;
};
#endif // OUTPUT_H_

View File

@ -35,7 +35,7 @@ PathTrie::~PathTrie() {
}
}
PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_prob_c, bool reset) {
PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
auto child = children_.begin();
for (; child != children_.end(); ++child) {
if (child->first == new_char) {
@ -102,7 +102,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
}
}
void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps) {
void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (parent != nullptr) {
@ -114,8 +114,8 @@ void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timestep
}
}
PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
std::vector<int>& timesteps,
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet)
{
PathTrie* stop = this;
@ -124,7 +124,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
}
// Recursive call: recurse back until stop condition, then append data in
// correct order as we walk back down the stack in the lines below.
if (!byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
}
output.push_back(character);
@ -135,7 +135,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
const Alphabet& alphabet)
{
if (byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
if (byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
*first_byte = (unsigned char)character + 1;
return 1;
}
@ -146,8 +146,8 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
return 0;
}
PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
std::vector<int>& timesteps,
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet)
{
PathTrie* stop = this;
@ -225,7 +225,7 @@ void PathTrie::print(const Alphabet& a) {
for (PathTrie* el : chain) {
printf("%X ", (unsigned char)(el->character));
if (el->character != ROOT_) {
tr.append(a.StringFromLabel(el->character));
tr.append(a.DecodeSingle(el->character));
}
}
printf("\ntimesteps:\t ");

View File

@ -21,22 +21,22 @@ public:
~PathTrie();
// get new prefix after appending new char
PathTrie* get_path_trie(int new_char, int new_timestep, float log_prob_c, bool reset = true);
PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);
// get the prefix data in correct time order from root to current node
void get_path_vec(std::vector<int>& output, std::vector<int>& timesteps);
void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);
// get the prefix data in correct time order from beginning of last grapheme to current node
PathTrie* get_prev_grapheme(std::vector<int>& output,
std::vector<int>& timesteps,
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet);
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);
// get the prefix data in correct time order from beginning of last word to current node
PathTrie* get_prev_word(std::vector<int>& output,
std::vector<int>& timesteps,
PathTrie* get_prev_word(std::vector<unsigned int>& output,
std::vector<unsigned int>& timesteps,
const Alphabet& alphabet);
// update log probs
@ -64,8 +64,8 @@ public:
float log_prob_c;
float score;
float approx_ctc;
int character;
int timestep;
unsigned int character;
unsigned int timestep;
PathTrie* parent;
private:
@ -73,7 +73,7 @@ private:
bool exists_;
bool has_dictionary_;
std::vector<std::pair<int, PathTrie*>> children_;
std::vector<std::pair<unsigned int, PathTrie*>> children_;
// pointer to dictionary of FST
std::shared_ptr<FstType> dictionary_;

View File

@ -65,7 +65,7 @@ void Scorer::setup_char_map()
// The initial state of FST is state 0, hence the index of chars in
// the FST should start from 1 to avoid the conflict with the initial
// state, otherwise wrong decoding results would be given.
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
char_map_[alphabet_.DecodeSingle(i)] = i + 1;
}
}
@ -314,11 +314,11 @@ void Scorer::reset_params(float alpha, float beta)
this->beta = beta;
}
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<int>& labels)
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<unsigned int>& labels)
{
if (labels.empty()) return {};
std::string s = alphabet_.LabelsToString(labels);
std::string s = alphabet_.Decode(labels);
std::vector<std::string> words;
if (is_utf8_mode_) {
words = split_into_codepoints(s);
@ -339,8 +339,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
break;
}
std::vector<int> prefix_vec;
std::vector<int> prefix_steps;
std::vector<unsigned int> prefix_vec;
std::vector<unsigned int> prefix_steps;
if (is_utf8_mode_) {
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
@ -350,7 +350,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
current_node = new_node->parent;
// reconstruct word
std::string word = alphabet_.LabelsToString(prefix_vec);
std::string word = alphabet_.Decode(prefix_vec);
ngram.push_back(word);
}
std::reverse(ngram.begin(), ngram.end());

View File

@ -73,7 +73,7 @@ public:
// trransform the labels in index to the vector of words (word based lm) or
// the vector of characters (character based lm)
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
std::vector<std::string> split_labels_into_scored_units(const std::vector<unsigned int> &labels);
void set_alphabet(const Alphabet& alphabet);

View File

@ -3,7 +3,6 @@
%{
#include "ctc_beam_search_decoder.h"
#define SWIG_FILE_WITH_INIT
#define SWIG_PYTHON_STRICT_BYTE_CHAR
#include "workspace_status.h"
%}
@ -19,6 +18,9 @@ import_array();
namespace std {
%template(StringVector) vector<string>;
%template(UnsignedIntVector) vector<unsigned int>;
%template(OutputVector) vector<Output>;
%template(OutputVectorVector) vector<vector<Output>>;
}
%shared_ptr(Scorer);
@ -27,6 +29,7 @@ namespace std {
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
%apply (unsigned int* IN_ARRAY1, int DIM1) {(const unsigned int *input, int length)};
%ignore Scorer::dictionary;
@ -38,10 +41,6 @@ namespace std {
%constant const char* __version__ = ds_version();
%constant const char* __git_version__ = ds_git_version();
%template(IntVector) std::vector<int>;
%template(OutputVector) std::vector<Output>;
%template(OutputVectorVector) std::vector<std::vector<Output>>;
// Import only the error code enum definitions from deepspeech.h
// We can't just do |%ignore "";| here because it affects this file globally (even
// files %include'd above). That causes SWIG to lose destructor information and

View File

@ -33,7 +33,7 @@ char*
ModelState::decode(const DecoderState& state) const
{
vector<Output> out = state.decode();
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
return strdup(alphabet_.Decode(out[0].tokens).c_str());
}
Metadata*
@ -50,7 +50,7 @@ ModelState::decode_metadata(const DecoderState& state,
for (int j = 0; j < out[i].tokens.size(); ++j) {
TokenMetadata token {
strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str()), // text
strdup(alphabet_.DecodeSingle(out[i].tokens[j]).c_str()), // text
static_cast<unsigned int>(out[i].timesteps[j]), // timestep
out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time
};

View File

@ -206,7 +206,7 @@ TFLiteModelState::init(const char* model_path)
beam_width_ = (unsigned int)(*beam_width);
tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
err = alphabet_.Deserialize(serialized_alphabet.str, serialized_alphabet.len);
if (err != 0) {
return DS_ERR_INVALID_ALPHABET;
}

View File

@ -119,7 +119,7 @@ TFModelState::init(const char* model_path)
beam_width_ = (unsigned int)(beam_width);
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
if (err != 0) {
return DS_ERR_INVALID_ALPHABET;
}

View File

@ -1,7 +1,7 @@
import unittest
import os
from deepspeech_training.util.text import Alphabet
from ds_ctcdecoder import Alphabet
class TestAlphabetParsing(unittest.TestCase):
@ -11,12 +11,12 @@ class TestAlphabetParsing(unittest.TestCase):
label_id = -1
for expected_label, expected_label_id in expected:
try:
label_id = alphabet.encode(expected_label)
label_id = alphabet.Encode(expected_label)
except KeyError:
pass
self.assertEqual(label_id, [expected_label_id])
try:
label = alphabet.decode([expected_label_id])
label = alphabet.Decode([expected_label_id])
except KeyError:
pass
self.assertEqual(label, expected_label)

View File

@ -40,7 +40,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
for i, index in enumerate(indices):
results[index[0]].append(values[i])
# List of strings
return [alphabet.decode(res) for res in results]
return [alphabet.Decode(res) for res in results]
def evaluate(test_csvs, create_model):

View File

@ -771,7 +771,7 @@ def export():
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet')
if FLAGS.export_language:
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')

View File

@ -6,14 +6,15 @@ import tensorflow.compat.v1 as tfv1
from attrdict import AttrDict
from xdg import BaseDirectory as xdg
from ds_ctcdecoder import Alphabet, UTF8Alphabet
from .flags import FLAGS
from .gpu import get_available_gpus
from .logging import log_error, log_warn
from .text import Alphabet, UTF8Alphabet
from .helpers import parse_file_size
from .augmentations import parse_augmentations
class ConfigSingleton:
_config = None
@ -115,7 +116,7 @@ def initialize_globals():
c.n_hidden_3 = c.n_cell_dim
# Units in the sixth layer = number of characters in the target language plus one
c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
c.n_hidden_6 = c.alphabet.GetSize() + 1 # +1 for CTC blank label
# Size of audio window in samples
if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:

View File

@ -52,12 +52,10 @@ def check_ctcdecoder_version():
sys.exit(1)
raise e
decoder_version_s = decoder_version.decode()
rv = semver.compare(ds_version_s, decoder_version_s)
rv = semver.compare(ds_version_s, decoder_version)
if rv != 0:
print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
"Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
"Please ensure matching versions are in use.".format(ds_version_s, decoder_version))
sys.exit(1)
return rv

View File

@ -3,121 +3,6 @@ from __future__ import absolute_import, division, print_function
import numpy as np
import struct
from six.moves import range
class Alphabet(object):
def __init__(self, config_file):
self._config_file = config_file
self._label_to_str = {}
self._str_to_label = {}
self._size = 0
if config_file:
with open(config_file, 'r', encoding='utf-8') as fin:
for line in fin:
if line[0:2] == '\\#':
line = '#\n'
elif line[0] == '#':
continue
self._label_to_str[self._size] = line[:-1] # remove the line ending
self._str_to_label[line[:-1]] = self._size
self._size += 1
def _string_from_label(self, label):
return self._label_to_str[label]
def _label_from_string(self, string):
try:
return self._str_to_label[string]
except KeyError as e:
raise KeyError(
'ERROR: Your transcripts contain characters (e.g. \'{}\') which do not occur in \'{}\'! Use ' \
'util/check_characters.py to see what characters are in your [train,dev,test].csv transcripts, and ' \
'then add all these to \'{}\'.'.format(string, self._config_file, self._config_file)
).with_traceback(e.__traceback__)
def has_char(self, char):
return char in self._str_to_label
def encode(self, string):
res = []
for char in string:
res.append(self._label_from_string(char))
return res
def decode(self, labels):
res = ''
for label in labels:
res += self._string_from_label(label)
return res
def serialize(self):
# Serialization format is a sequence of (key, value) pairs, where key is
# a uint16_t and value is a uint16_t length followed by `length` UTF-8
# encoded bytes with the label.
res = bytearray()
# We start by writing the number of pairs in the buffer as uint16_t.
res += struct.pack('<H', self._size)
for key, value in self._label_to_str.items():
value = value.encode('utf-8')
# struct.pack only takes fixed length strings/buffers, so we have to
# construct the correct format string with the length of the encoded
# label.
res += struct.pack('<HH{}s'.format(len(value)), key, len(value), value)
return bytes(res)
def size(self):
return self._size
def config_file(self):
return self._config_file
class UTF8Alphabet(object):
@staticmethod
def _string_from_label(_):
assert False
@staticmethod
def _label_from_string(_):
assert False
@staticmethod
def encode(string):
# 0 never happens in the data, so we can shift values by one, use 255 for
# the CTC blank, and keep the alphabet size = 256
return np.frombuffer(string.encode('utf-8'), np.uint8).astype(np.int32) - 1
@staticmethod
def decode(labels):
# And here we need to shift back up
return bytes(np.asarray(labels, np.uint8) + 1).decode('utf-8', errors='replace')
@staticmethod
def size():
return 255
@staticmethod
def serialize():
res = bytearray()
res += struct.pack('<h', 255)
for i in range(255):
# Note that we also shift back up in the mapping constructed here
# so that the native client sees the correct byte values when decoding.
res += struct.pack('<hh1s', i, 1, bytes([i+1]))
return bytes(res)
@staticmethod
def deserialize(buf):
size = struct.unpack('<I', buf)[0]
assert size == 255
return UTF8Alphabet()
@staticmethod
def config_file():
return ''
def text_to_char_array(transcript, alphabet, context=''):
r"""
Given a transcript string, map characters to
@ -125,7 +10,7 @@ def text_to_char_array(transcript, alphabet, context=''):
Use a string in `context` for adding text to raised exceptions.
"""
try:
transcript = alphabet.encode(transcript)
transcript = alphabet.Encode(transcript)
if len(transcript) == 0:
raise ValueError('While processing {}: Found an empty transcript! '
'You must include a transcript for all training data.'