Deduplicate Alphabet implementations, use C++ one everywhere
This commit is contained in:
parent
f82c77392d
commit
a84abf813c
@ -75,6 +75,7 @@ cc_library(
|
|||||||
"ctcdecode/scorer.cpp",
|
"ctcdecode/scorer.cpp",
|
||||||
"ctcdecode/path_trie.cpp",
|
"ctcdecode/path_trie.cpp",
|
||||||
"ctcdecode/path_trie.h",
|
"ctcdecode/path_trie.h",
|
||||||
|
"alphabet.cc",
|
||||||
] + OPENFST_SOURCES_PLATFORM,
|
] + OPENFST_SOURCES_PLATFORM,
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"ctcdecode/ctc_beam_search_decoder.h",
|
"ctcdecode/ctc_beam_search_decoder.h",
|
||||||
@ -86,13 +87,17 @@ cc_library(
|
|||||||
".",
|
".",
|
||||||
"ctcdecode/third_party/ThreadPool",
|
"ctcdecode/third_party/ThreadPool",
|
||||||
] + OPENFST_INCLUDES_PLATFORM,
|
] + OPENFST_INCLUDES_PLATFORM,
|
||||||
deps = [":kenlm"]
|
deps = [":kenlm"],
|
||||||
|
linkopts = [
|
||||||
|
"-lm",
|
||||||
|
"-ldl",
|
||||||
|
"-pthread",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_cc_shared_object(
|
tf_cc_shared_object(
|
||||||
name = "libdeepspeech.so",
|
name = "libdeepspeech.so",
|
||||||
srcs = [
|
srcs = [
|
||||||
"alphabet.h",
|
|
||||||
"deepspeech.cc",
|
"deepspeech.cc",
|
||||||
"deepspeech.h",
|
"deepspeech.h",
|
||||||
"deepspeech_errors.cc",
|
"deepspeech_errors.cc",
|
||||||
@ -203,6 +208,11 @@ cc_binary(
|
|||||||
"@com_google_absl//absl/types:optional",
|
"@com_google_absl//absl/types:optional",
|
||||||
"@boost//:program_options",
|
"@boost//:program_options",
|
||||||
],
|
],
|
||||||
|
linkopts = [
|
||||||
|
"-lm",
|
||||||
|
"-ldl",
|
||||||
|
"-pthread",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
@ -221,10 +231,5 @@ cc_binary(
|
|||||||
"trie_load.cc",
|
"trie_load.cc",
|
||||||
],
|
],
|
||||||
copts = ["-std=c++11"],
|
copts = ["-std=c++11"],
|
||||||
linkopts = [
|
|
||||||
"-lm",
|
|
||||||
"-ldl",
|
|
||||||
"-pthread",
|
|
||||||
],
|
|
||||||
deps = [":decoder"],
|
deps = [":decoder"],
|
||||||
)
|
)
|
||||||
|
154
native_client/alphabet.cc
Normal file
154
native_client/alphabet.cc
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
#include "alphabet.h"
#include "ctcdecode/decoder_utils.h"

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
|
||||||
|
|
||||||
|
int
|
||||||
|
Alphabet::init(const char *config_file)
|
||||||
|
{
|
||||||
|
std::ifstream in(config_file, std::ios::in);
|
||||||
|
if (!in) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
unsigned int label = 0;
|
||||||
|
space_label_ = -2;
|
||||||
|
for (std::string line; std::getline(in, line);) {
|
||||||
|
if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
|
||||||
|
line = '#';
|
||||||
|
} else if (line[0] == '#') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
//TODO: we should probably do something more i18n-aware here
|
||||||
|
if (line == " ") {
|
||||||
|
space_label_ = label;
|
||||||
|
}
|
||||||
|
label_to_str_[label] = line;
|
||||||
|
str_to_label_[line] = label;
|
||||||
|
++label;
|
||||||
|
}
|
||||||
|
size_ = label;
|
||||||
|
in.close();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string
|
||||||
|
Alphabet::Serialize()
|
||||||
|
{
|
||||||
|
// Serialization format is a sequence of (key, value) pairs, where key is
|
||||||
|
// a uint16_t and value is a uint16_t length followed by `length` UTF-8
|
||||||
|
// encoded bytes with the label.
|
||||||
|
std::stringstream out;
|
||||||
|
|
||||||
|
// We start by writing the number of pairs in the buffer as uint16_t.
|
||||||
|
uint16_t size = size_;
|
||||||
|
out.write(reinterpret_cast<char*>(&size), sizeof(size));
|
||||||
|
|
||||||
|
for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
|
||||||
|
uint16_t key = it->first;
|
||||||
|
string str = it->second;
|
||||||
|
uint16_t len = str.length();
|
||||||
|
// Then we write the key as uint16_t, followed by the length of the value
|
||||||
|
// as uint16_t, followed by `length` bytes (the value itself).
|
||||||
|
out.write(reinterpret_cast<char*>(&key), sizeof(key));
|
||||||
|
out.write(reinterpret_cast<char*>(&len), sizeof(len));
|
||||||
|
out.write(str.data(), len);
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
Alphabet::Deserialize(const char* buffer, const int buffer_size)
|
||||||
|
{
|
||||||
|
// See util/text.py for an explanation of the serialization format.
|
||||||
|
int offset = 0;
|
||||||
|
if (buffer_size - offset < sizeof(uint16_t)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
uint16_t size = *(uint16_t*)(buffer + offset);
|
||||||
|
offset += sizeof(uint16_t);
|
||||||
|
size_ = size;
|
||||||
|
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
if (buffer_size - offset < sizeof(uint16_t)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
uint16_t label = *(uint16_t*)(buffer + offset);
|
||||||
|
offset += sizeof(uint16_t);
|
||||||
|
|
||||||
|
if (buffer_size - offset < sizeof(uint16_t)) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
uint16_t val_len = *(uint16_t*)(buffer + offset);
|
||||||
|
offset += sizeof(uint16_t);
|
||||||
|
|
||||||
|
if (buffer_size - offset < val_len) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
std::string val(buffer+offset, val_len);
|
||||||
|
offset += val_len;
|
||||||
|
|
||||||
|
label_to_str_[label] = val;
|
||||||
|
str_to_label_[val] = label;
|
||||||
|
|
||||||
|
if (val == " ") {
|
||||||
|
space_label_ = label;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string
|
||||||
|
Alphabet::DecodeSingle(unsigned int label) const
|
||||||
|
{
|
||||||
|
auto it = label_to_str_.find(label);
|
||||||
|
if (it != label_to_str_.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
std::cerr << "Invalid label " << label << std::endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned int
|
||||||
|
Alphabet::EncodeSingle(const std::string& string) const
|
||||||
|
{
|
||||||
|
auto it = str_to_label_.find(string);
|
||||||
|
if (it != str_to_label_.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
std::cerr << "Invalid string " << string << std::endl;
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string
|
||||||
|
Alphabet::Decode(const std::vector<unsigned int>& input) const
|
||||||
|
{
|
||||||
|
std::string word;
|
||||||
|
for (auto ind : input) {
|
||||||
|
word += DecodeSingle(ind);
|
||||||
|
}
|
||||||
|
return word;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string
|
||||||
|
Alphabet::Decode(const unsigned int* input, int length) const
|
||||||
|
{
|
||||||
|
std::string word;
|
||||||
|
for (int i = 0; i < length; ++i) {
|
||||||
|
word += DecodeSingle(input[i]);
|
||||||
|
}
|
||||||
|
return word;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<unsigned int>
|
||||||
|
Alphabet::Encode(const std::string& input) const
|
||||||
|
{
|
||||||
|
std::vector<unsigned int> result;
|
||||||
|
for (auto cp : split_into_codepoints(input)) {
|
||||||
|
result.push_back(EncodeSingle(cp));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
@ -1,9 +1,6 @@
|
|||||||
#ifndef ALPHABET_H
|
#ifndef ALPHABET_H
|
||||||
#define ALPHABET_H
|
#define ALPHABET_H
|
||||||
|
|
||||||
#include <cassert>
|
|
||||||
#include <fstream>
|
|
||||||
#include <iostream>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -18,116 +15,15 @@ public:
|
|||||||
Alphabet() = default;
|
Alphabet() = default;
|
||||||
Alphabet(const Alphabet&) = default;
|
Alphabet(const Alphabet&) = default;
|
||||||
Alphabet& operator=(const Alphabet&) = default;
|
Alphabet& operator=(const Alphabet&) = default;
|
||||||
|
virtual ~Alphabet() = default;
|
||||||
|
|
||||||
virtual int init(const char *config_file) {
|
virtual int init(const char *config_file);
|
||||||
std::ifstream in(config_file, std::ios::in);
|
|
||||||
if (!in) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
unsigned int label = 0;
|
|
||||||
space_label_ = -2;
|
|
||||||
for (std::string line; std::getline(in, line);) {
|
|
||||||
if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
|
|
||||||
line = '#';
|
|
||||||
} else if (line[0] == '#') {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
//TODO: we should probably do something more i18n-aware here
|
|
||||||
if (line == " ") {
|
|
||||||
space_label_ = label;
|
|
||||||
}
|
|
||||||
label_to_str_[label] = line;
|
|
||||||
str_to_label_[line] = label;
|
|
||||||
++label;
|
|
||||||
}
|
|
||||||
size_ = label;
|
|
||||||
in.close();
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string serialize() {
|
// Serialize alphabet into a binary buffer.
|
||||||
// Serialization format is a sequence of (key, value) pairs, where key is
|
std::string Serialize();
|
||||||
// a uint16_t and value is a uint16_t length followed by `length` UTF-8
|
|
||||||
// encoded bytes with the label.
|
|
||||||
std::stringstream out;
|
|
||||||
|
|
||||||
// We start by writing the number of pairs in the buffer as uint16_t.
|
// Deserialize alphabet from a binary buffer.
|
||||||
uint16_t size = size_;
|
int Deserialize(const char* buffer, const int buffer_size);
|
||||||
out.write(reinterpret_cast<char*>(&size), sizeof(size));
|
|
||||||
|
|
||||||
for (auto it = label_to_str_.begin(); it != label_to_str_.end(); ++it) {
|
|
||||||
uint16_t key = it->first;
|
|
||||||
string str = it->second;
|
|
||||||
uint16_t len = str.length();
|
|
||||||
// Then we write the key as uint16_t, followed by the length of the value
|
|
||||||
// as uint16_t, followed by `length` bytes (the value itself).
|
|
||||||
out.write(reinterpret_cast<char*>(&key), sizeof(key));
|
|
||||||
out.write(reinterpret_cast<char*>(&len), sizeof(len));
|
|
||||||
out.write(str.data(), len);
|
|
||||||
}
|
|
||||||
|
|
||||||
return out.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
int deserialize(const char* buffer, const int buffer_size) {
|
|
||||||
// See util/text.py for an explanation of the serialization format.
|
|
||||||
int offset = 0;
|
|
||||||
if (buffer_size - offset < sizeof(uint16_t)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
uint16_t size = *(uint16_t*)(buffer + offset);
|
|
||||||
offset += sizeof(uint16_t);
|
|
||||||
size_ = size;
|
|
||||||
|
|
||||||
for (int i = 0; i < size; ++i) {
|
|
||||||
if (buffer_size - offset < sizeof(uint16_t)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
uint16_t label = *(uint16_t*)(buffer + offset);
|
|
||||||
offset += sizeof(uint16_t);
|
|
||||||
|
|
||||||
if (buffer_size - offset < sizeof(uint16_t)) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
uint16_t val_len = *(uint16_t*)(buffer + offset);
|
|
||||||
offset += sizeof(uint16_t);
|
|
||||||
|
|
||||||
if (buffer_size - offset < val_len) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
std::string val(buffer+offset, val_len);
|
|
||||||
offset += val_len;
|
|
||||||
|
|
||||||
label_to_str_[label] = val;
|
|
||||||
str_to_label_[val] = label;
|
|
||||||
|
|
||||||
if (val == " ") {
|
|
||||||
space_label_ = label;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::string& StringFromLabel(unsigned int label) const {
|
|
||||||
auto it = label_to_str_.find(label);
|
|
||||||
if (it != label_to_str_.end()) {
|
|
||||||
return it->second;
|
|
||||||
} else {
|
|
||||||
std::cerr << "Invalid label " << label << std::endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
unsigned int LabelFromString(const std::string& string) const {
|
|
||||||
auto it = str_to_label_.find(string);
|
|
||||||
if (it != str_to_label_.end()) {
|
|
||||||
return it->second;
|
|
||||||
} else {
|
|
||||||
std::cerr << "Invalid string " << string << std::endl;
|
|
||||||
abort();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t GetSize() const {
|
size_t GetSize() const {
|
||||||
return size_;
|
return size_;
|
||||||
@ -141,14 +37,22 @@ public:
|
|||||||
return space_label_;
|
return space_label_;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
// Decode a single label into a string.
|
||||||
std::string LabelsToString(const std::vector<T>& input) const {
|
std::string DecodeSingle(unsigned int label) const;
|
||||||
std::string word;
|
|
||||||
for (auto ind : input) {
|
// Encode a single character/output class into a label.
|
||||||
word += StringFromLabel(ind);
|
unsigned int EncodeSingle(const std::string& string) const;
|
||||||
}
|
|
||||||
return word;
|
// Decode a sequence of labels into a string.
|
||||||
}
|
std::string Decode(const std::vector<unsigned int>& input) const;
|
||||||
|
|
||||||
|
// We provide a C-style overload for accepting NumPy arrays as input, since
|
||||||
|
// the NumPy library does not have built-in typemaps for std::vector<T>.
|
||||||
|
std::string Decode(const unsigned int* input, int length) const;
|
||||||
|
|
||||||
|
// Encode a sequence of character/output classes into a sequence of labels.
|
||||||
|
// Characters are assumed to always take a single Unicode codepoint.
|
||||||
|
std::vector<unsigned int> Encode(const std::string& input) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
size_t size_;
|
size_t size_;
|
||||||
@ -163,14 +67,16 @@ public:
|
|||||||
UTF8Alphabet() {
|
UTF8Alphabet() {
|
||||||
size_ = 255;
|
size_ = 255;
|
||||||
space_label_ = ' ' - 1;
|
space_label_ = ' ' - 1;
|
||||||
for (int i = 0; i < size_; ++i) {
|
for (size_t i = 0; i < size_; ++i) {
|
||||||
std::string val(1, i+1);
|
std::string val(1, i+1);
|
||||||
label_to_str_[i] = val;
|
label_to_str_[i] = val;
|
||||||
str_to_label_[val] = i;
|
str_to_label_[val] = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int init(const char*) override {}
|
int init(const char*) override {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import absolute_import, division, print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
from . import swigwrapper # pylint: disable=import-self
|
from . import swigwrapper # pylint: disable=import-self
|
||||||
from .swigwrapper import Alphabet
|
from .swigwrapper import UTF8Alphabet
|
||||||
|
|
||||||
__version__ = swigwrapper.__version__
|
__version__ = swigwrapper.__version__
|
||||||
|
|
||||||
@ -30,24 +30,20 @@ class Scorer(swigwrapper.Scorer):
|
|||||||
assert beta is not None, 'beta parameter is required'
|
assert beta is not None, 'beta parameter is required'
|
||||||
assert scorer_path, 'scorer_path parameter is required'
|
assert scorer_path, 'scorer_path parameter is required'
|
||||||
|
|
||||||
serialized = alphabet.serialize()
|
err = self.init(scorer_path, alphabet)
|
||||||
native_alphabet = swigwrapper.Alphabet()
|
|
||||||
err = native_alphabet.deserialize(serialized, len(serialized))
|
|
||||||
if err != 0:
|
if err != 0:
|
||||||
raise ValueError('Error when deserializing alphabet.')
|
raise ValueError('Scorer initialization failed with error code 0x{:X}'.format(err))
|
||||||
|
|
||||||
err = self.init(scorer_path.encode('utf-8'),
|
|
||||||
native_alphabet)
|
|
||||||
if err != 0:
|
|
||||||
raise ValueError('Scorer initialization failed with error code {}'.format(err))
|
|
||||||
|
|
||||||
self.reset_params(alpha, beta)
|
self.reset_params(alpha, beta)
|
||||||
|
|
||||||
def load_lm(self, lm_path):
|
|
||||||
return super(Scorer, self).load_lm(lm_path.encode('utf-8'))
|
|
||||||
|
|
||||||
def save_dictionary(self, save_path, *args, **kwargs):
|
class Alphabet(swigwrapper.Alphabet):
|
||||||
return super(Scorer, self).save_dictionary(save_path.encode('utf-8'), *args, **kwargs)
|
"""Convenience wrapper for Alphabet which calls init in the constructor"""
|
||||||
|
def __init__(self, config_path):
|
||||||
|
super(Alphabet, self).__init__()
|
||||||
|
err = self.init(config_path)
|
||||||
|
if err != 0:
|
||||||
|
raise ValueError('Alphabet initialization failed with error code 0x{:X}'.format(err))
|
||||||
|
|
||||||
|
|
||||||
def ctc_beam_search_decoder(probs_seq,
|
def ctc_beam_search_decoder(probs_seq,
|
||||||
@ -79,15 +75,10 @@ def ctc_beam_search_decoder(probs_seq,
|
|||||||
results, in descending order of the confidence.
|
results, in descending order of the confidence.
|
||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
serialized = alphabet.serialize()
|
|
||||||
native_alphabet = swigwrapper.Alphabet()
|
|
||||||
err = native_alphabet.deserialize(serialized, len(serialized))
|
|
||||||
if err != 0:
|
|
||||||
raise ValueError("Error when deserializing alphabet.")
|
|
||||||
beam_results = swigwrapper.ctc_beam_search_decoder(
|
beam_results = swigwrapper.ctc_beam_search_decoder(
|
||||||
probs_seq, native_alphabet, beam_size, cutoff_prob, cutoff_top_n,
|
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
|
||||||
scorer)
|
scorer)
|
||||||
beam_results = [(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
|
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||||
return beam_results
|
return beam_results
|
||||||
|
|
||||||
|
|
||||||
@ -126,14 +117,9 @@ def ctc_beam_search_decoder_batch(probs_seq,
|
|||||||
results, in descending order of the confidence.
|
results, in descending order of the confidence.
|
||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
serialized = alphabet.serialize()
|
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
|
||||||
native_alphabet = swigwrapper.Alphabet()
|
|
||||||
err = native_alphabet.deserialize(serialized, len(serialized))
|
|
||||||
if err != 0:
|
|
||||||
raise ValueError("Error when deserializing alphabet.")
|
|
||||||
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, native_alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer)
|
|
||||||
batch_beam_results = [
|
batch_beam_results = [
|
||||||
[(res.confidence, alphabet.decode(res.tokens)) for res in beam_results]
|
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
|
||||||
for beam_results in batch_beam_results
|
for beam_results in batch_beam_results
|
||||||
]
|
]
|
||||||
return batch_beam_results
|
return batch_beam_results
|
||||||
|
@ -46,7 +46,8 @@ CTC_DECODER_FILES = [
|
|||||||
'scorer.cpp',
|
'scorer.cpp',
|
||||||
'path_trie.cpp',
|
'path_trie.cpp',
|
||||||
'decoder_utils.cpp',
|
'decoder_utils.cpp',
|
||||||
'workspace_status.cc'
|
'workspace_status.cc',
|
||||||
|
'../alphabet.cc',
|
||||||
]
|
]
|
||||||
|
|
||||||
def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):
|
def build_archive(srcs=[], out_name='', build_dir='temp_build/temp_build', debug=False, num_parallel=1):
|
||||||
|
@ -119,7 +119,7 @@ bool prefix_compare_external(const PathTrie *x, const PathTrie *y, const std::un
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_word_to_fst(const std::vector<int> &word,
|
void add_word_to_fst(const std::vector<unsigned int> &word,
|
||||||
fst::StdVectorFst *dictionary) {
|
fst::StdVectorFst *dictionary) {
|
||||||
if (dictionary->NumStates() == 0) {
|
if (dictionary->NumStates() == 0) {
|
||||||
fst::StdVectorFst::StateId start = dictionary->AddState();
|
fst::StdVectorFst::StateId start = dictionary->AddState();
|
||||||
@ -144,7 +144,7 @@ bool add_word_to_dictionary(
|
|||||||
fst::StdVectorFst *dictionary) {
|
fst::StdVectorFst *dictionary) {
|
||||||
auto characters = utf8 ? split_into_bytes(word) : split_into_codepoints(word);
|
auto characters = utf8 ? split_into_bytes(word) : split_into_codepoints(word);
|
||||||
|
|
||||||
std::vector<int> int_word;
|
std::vector<unsigned int> int_word;
|
||||||
|
|
||||||
for (auto &c : characters) {
|
for (auto &c : characters) {
|
||||||
auto int_c = char_map.find(c);
|
auto int_c = char_map.find(c);
|
||||||
|
@ -86,7 +86,7 @@ std::vector<std::string> split_into_codepoints(const std::string &str);
|
|||||||
std::vector<std::string> split_into_bytes(const std::string &str);
|
std::vector<std::string> split_into_bytes(const std::string &str);
|
||||||
|
|
||||||
// Add a word in index to the dictionary of fst
|
// Add a word in index to the dictionary of fst
|
||||||
void add_word_to_fst(const std::vector<int> &word,
|
void add_word_to_fst(const std::vector<unsigned int> &word,
|
||||||
fst::StdVectorFst *dictionary);
|
fst::StdVectorFst *dictionary);
|
||||||
|
|
||||||
// Return whether a byte is a code point boundary (not a continuation byte).
|
// Return whether a byte is a code point boundary (not a continuation byte).
|
||||||
|
@ -8,8 +8,8 @@
|
|||||||
*/
|
*/
|
||||||
struct Output {
|
struct Output {
|
||||||
double confidence;
|
double confidence;
|
||||||
std::vector<int> tokens;
|
std::vector<unsigned int> tokens;
|
||||||
std::vector<int> timesteps;
|
std::vector<unsigned int> timesteps;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // OUTPUT_H_
|
#endif // OUTPUT_H_
|
||||||
|
@ -35,7 +35,7 @@ PathTrie::~PathTrie() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_prob_c, bool reset) {
|
PathTrie* PathTrie::get_path_trie(unsigned int new_char, unsigned int new_timestep, float cur_log_prob_c, bool reset) {
|
||||||
auto child = children_.begin();
|
auto child = children_.begin();
|
||||||
for (; child != children_.end(); ++child) {
|
for (; child != children_.end(); ++child) {
|
||||||
if (child->first == new_char) {
|
if (child->first == new_char) {
|
||||||
@ -102,7 +102,7 @@ PathTrie* PathTrie::get_path_trie(int new_char, int new_timestep, float cur_log_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timesteps) {
|
void PathTrie::get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps) {
|
||||||
// Recursive call: recurse back until stop condition, then append data in
|
// Recursive call: recurse back until stop condition, then append data in
|
||||||
// correct order as we walk back down the stack in the lines below.
|
// correct order as we walk back down the stack in the lines below.
|
||||||
if (parent != nullptr) {
|
if (parent != nullptr) {
|
||||||
@ -114,8 +114,8 @@ void PathTrie::get_path_vec(std::vector<int>& output, std::vector<int>& timestep
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
PathTrie* PathTrie::get_prev_grapheme(std::vector<unsigned int>& output,
|
||||||
std::vector<int>& timesteps,
|
std::vector<unsigned int>& timesteps,
|
||||||
const Alphabet& alphabet)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
PathTrie* stop = this;
|
PathTrie* stop = this;
|
||||||
@ -124,7 +124,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
|||||||
}
|
}
|
||||||
// Recursive call: recurse back until stop condition, then append data in
|
// Recursive call: recurse back until stop condition, then append data in
|
||||||
// correct order as we walk back down the stack in the lines below.
|
// correct order as we walk back down the stack in the lines below.
|
||||||
if (!byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
|
if (!byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
|
||||||
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
|
stop = parent->get_prev_grapheme(output, timesteps, alphabet);
|
||||||
}
|
}
|
||||||
output.push_back(character);
|
output.push_back(character);
|
||||||
@ -135,7 +135,7 @@ PathTrie* PathTrie::get_prev_grapheme(std::vector<int>& output,
|
|||||||
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
|
int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
|
||||||
const Alphabet& alphabet)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
if (byte_is_codepoint_boundary(alphabet.StringFromLabel(character)[0])) {
|
if (byte_is_codepoint_boundary(alphabet.DecodeSingle(character)[0])) {
|
||||||
*first_byte = (unsigned char)character + 1;
|
*first_byte = (unsigned char)character + 1;
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
@ -146,8 +146,8 @@ int PathTrie::distance_to_codepoint_boundary(unsigned char *first_byte,
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
PathTrie* PathTrie::get_prev_word(std::vector<int>& output,
|
PathTrie* PathTrie::get_prev_word(std::vector<unsigned int>& output,
|
||||||
std::vector<int>& timesteps,
|
std::vector<unsigned int>& timesteps,
|
||||||
const Alphabet& alphabet)
|
const Alphabet& alphabet)
|
||||||
{
|
{
|
||||||
PathTrie* stop = this;
|
PathTrie* stop = this;
|
||||||
@ -225,7 +225,7 @@ void PathTrie::print(const Alphabet& a) {
|
|||||||
for (PathTrie* el : chain) {
|
for (PathTrie* el : chain) {
|
||||||
printf("%X ", (unsigned char)(el->character));
|
printf("%X ", (unsigned char)(el->character));
|
||||||
if (el->character != ROOT_) {
|
if (el->character != ROOT_) {
|
||||||
tr.append(a.StringFromLabel(el->character));
|
tr.append(a.DecodeSingle(el->character));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
printf("\ntimesteps:\t ");
|
printf("\ntimesteps:\t ");
|
||||||
|
@ -21,22 +21,22 @@ public:
|
|||||||
~PathTrie();
|
~PathTrie();
|
||||||
|
|
||||||
// get new prefix after appending new char
|
// get new prefix after appending new char
|
||||||
PathTrie* get_path_trie(int new_char, int new_timestep, float log_prob_c, bool reset = true);
|
PathTrie* get_path_trie(unsigned int new_char, unsigned int new_timestep, float log_prob_c, bool reset = true);
|
||||||
|
|
||||||
// get the prefix data in correct time order from root to current node
|
// get the prefix data in correct time order from root to current node
|
||||||
void get_path_vec(std::vector<int>& output, std::vector<int>& timesteps);
|
void get_path_vec(std::vector<unsigned int>& output, std::vector<unsigned int>& timesteps);
|
||||||
|
|
||||||
// get the prefix data in correct time order from beginning of last grapheme to current node
|
// get the prefix data in correct time order from beginning of last grapheme to current node
|
||||||
PathTrie* get_prev_grapheme(std::vector<int>& output,
|
PathTrie* get_prev_grapheme(std::vector<unsigned int>& output,
|
||||||
std::vector<int>& timesteps,
|
std::vector<unsigned int>& timesteps,
|
||||||
const Alphabet& alphabet);
|
const Alphabet& alphabet);
|
||||||
|
|
||||||
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
// get the distance from current node to the first codepoint boundary, and the byte value at the boundary
|
||||||
int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);
|
int distance_to_codepoint_boundary(unsigned char *first_byte, const Alphabet& alphabet);
|
||||||
|
|
||||||
// get the prefix data in correct time order from beginning of last word to current node
|
// get the prefix data in correct time order from beginning of last word to current node
|
||||||
PathTrie* get_prev_word(std::vector<int>& output,
|
PathTrie* get_prev_word(std::vector<unsigned int>& output,
|
||||||
std::vector<int>& timesteps,
|
std::vector<unsigned int>& timesteps,
|
||||||
const Alphabet& alphabet);
|
const Alphabet& alphabet);
|
||||||
|
|
||||||
// update log probs
|
// update log probs
|
||||||
@ -64,8 +64,8 @@ public:
|
|||||||
float log_prob_c;
|
float log_prob_c;
|
||||||
float score;
|
float score;
|
||||||
float approx_ctc;
|
float approx_ctc;
|
||||||
int character;
|
unsigned int character;
|
||||||
int timestep;
|
unsigned int timestep;
|
||||||
PathTrie* parent;
|
PathTrie* parent;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -73,7 +73,7 @@ private:
|
|||||||
bool exists_;
|
bool exists_;
|
||||||
bool has_dictionary_;
|
bool has_dictionary_;
|
||||||
|
|
||||||
std::vector<std::pair<int, PathTrie*>> children_;
|
std::vector<std::pair<unsigned int, PathTrie*>> children_;
|
||||||
|
|
||||||
// pointer to dictionary of FST
|
// pointer to dictionary of FST
|
||||||
std::shared_ptr<FstType> dictionary_;
|
std::shared_ptr<FstType> dictionary_;
|
||||||
|
@ -65,7 +65,7 @@ void Scorer::setup_char_map()
|
|||||||
// The initial state of FST is state 0, hence the index of chars in
|
// The initial state of FST is state 0, hence the index of chars in
|
||||||
// the FST should start from 1 to avoid the conflict with the initial
|
// the FST should start from 1 to avoid the conflict with the initial
|
||||||
// state, otherwise wrong decoding results would be given.
|
// state, otherwise wrong decoding results would be given.
|
||||||
char_map_[alphabet_.StringFromLabel(i)] = i + 1;
|
char_map_[alphabet_.DecodeSingle(i)] = i + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -314,11 +314,11 @@ void Scorer::reset_params(float alpha, float beta)
|
|||||||
this->beta = beta;
|
this->beta = beta;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<int>& labels)
|
std::vector<std::string> Scorer::split_labels_into_scored_units(const std::vector<unsigned int>& labels)
|
||||||
{
|
{
|
||||||
if (labels.empty()) return {};
|
if (labels.empty()) return {};
|
||||||
|
|
||||||
std::string s = alphabet_.LabelsToString(labels);
|
std::string s = alphabet_.Decode(labels);
|
||||||
std::vector<std::string> words;
|
std::vector<std::string> words;
|
||||||
if (is_utf8_mode_) {
|
if (is_utf8_mode_) {
|
||||||
words = split_into_codepoints(s);
|
words = split_into_codepoints(s);
|
||||||
@ -339,8 +339,8 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> prefix_vec;
|
std::vector<unsigned int> prefix_vec;
|
||||||
std::vector<int> prefix_steps;
|
std::vector<unsigned int> prefix_steps;
|
||||||
|
|
||||||
if (is_utf8_mode_) {
|
if (is_utf8_mode_) {
|
||||||
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
|
new_node = current_node->get_prev_grapheme(prefix_vec, prefix_steps, alphabet_);
|
||||||
@ -350,7 +350,7 @@ std::vector<std::string> Scorer::make_ngram(PathTrie* prefix)
|
|||||||
current_node = new_node->parent;
|
current_node = new_node->parent;
|
||||||
|
|
||||||
// reconstruct word
|
// reconstruct word
|
||||||
std::string word = alphabet_.LabelsToString(prefix_vec);
|
std::string word = alphabet_.Decode(prefix_vec);
|
||||||
ngram.push_back(word);
|
ngram.push_back(word);
|
||||||
}
|
}
|
||||||
std::reverse(ngram.begin(), ngram.end());
|
std::reverse(ngram.begin(), ngram.end());
|
||||||
|
@ -73,7 +73,7 @@ public:
|
|||||||
|
|
||||||
// transform the labels in index to the vector of words (word based lm) or
|
// transform the labels in index to the vector of words (word based lm) or
|
||||||
// the vector of characters (character based lm)
|
// the vector of characters (character based lm)
|
||||||
std::vector<std::string> split_labels_into_scored_units(const std::vector<int> &labels);
|
std::vector<std::string> split_labels_into_scored_units(const std::vector<unsigned int> &labels);
|
||||||
|
|
||||||
void set_alphabet(const Alphabet& alphabet);
|
void set_alphabet(const Alphabet& alphabet);
|
||||||
|
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
%{
|
%{
|
||||||
#include "ctc_beam_search_decoder.h"
|
#include "ctc_beam_search_decoder.h"
|
||||||
#define SWIG_FILE_WITH_INIT
|
#define SWIG_FILE_WITH_INIT
|
||||||
#define SWIG_PYTHON_STRICT_BYTE_CHAR
|
|
||||||
#include "workspace_status.h"
|
#include "workspace_status.h"
|
||||||
%}
|
%}
|
||||||
|
|
||||||
@ -19,6 +18,9 @@ import_array();
|
|||||||
|
|
||||||
namespace std {
|
namespace std {
|
||||||
%template(StringVector) vector<string>;
|
%template(StringVector) vector<string>;
|
||||||
|
%template(UnsignedIntVector) vector<unsigned int>;
|
||||||
|
%template(OutputVector) vector<Output>;
|
||||||
|
%template(OutputVectorVector) vector<vector<Output>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
%shared_ptr(Scorer);
|
%shared_ptr(Scorer);
|
||||||
@ -27,6 +29,7 @@ namespace std {
|
|||||||
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
%apply (double* IN_ARRAY2, int DIM1, int DIM2) {(const double *probs, int time_dim, int class_dim)};
|
||||||
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
|
%apply (double* IN_ARRAY3, int DIM1, int DIM2, int DIM3) {(const double *probs, int batch_size, int time_dim, int class_dim)};
|
||||||
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
|
%apply (int* IN_ARRAY1, int DIM1) {(const int *seq_lengths, int seq_lengths_size)};
|
||||||
|
%apply (unsigned int* IN_ARRAY1, int DIM1) {(const unsigned int *input, int length)};
|
||||||
|
|
||||||
%ignore Scorer::dictionary;
|
%ignore Scorer::dictionary;
|
||||||
|
|
||||||
@ -38,10 +41,6 @@ namespace std {
|
|||||||
%constant const char* __version__ = ds_version();
|
%constant const char* __version__ = ds_version();
|
||||||
%constant const char* __git_version__ = ds_git_version();
|
%constant const char* __git_version__ = ds_git_version();
|
||||||
|
|
||||||
%template(IntVector) std::vector<int>;
|
|
||||||
%template(OutputVector) std::vector<Output>;
|
|
||||||
%template(OutputVectorVector) std::vector<std::vector<Output>>;
|
|
||||||
|
|
||||||
// Import only the error code enum definitions from deepspeech.h
|
// Import only the error code enum definitions from deepspeech.h
|
||||||
// We can't just do |%ignore "";| here because it affects this file globally (even
|
// We can't just do |%ignore "";| here because it affects this file globally (even
|
||||||
// files %include'd above). That causes SWIG to lose destructor information and
|
// files %include'd above). That causes SWIG to lose destructor information and
|
||||||
|
@ -33,7 +33,7 @@ char*
|
|||||||
ModelState::decode(const DecoderState& state) const
|
ModelState::decode(const DecoderState& state) const
|
||||||
{
|
{
|
||||||
vector<Output> out = state.decode();
|
vector<Output> out = state.decode();
|
||||||
return strdup(alphabet_.LabelsToString(out[0].tokens).c_str());
|
return strdup(alphabet_.Decode(out[0].tokens).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
Metadata*
|
Metadata*
|
||||||
@ -50,7 +50,7 @@ ModelState::decode_metadata(const DecoderState& state,
|
|||||||
|
|
||||||
for (int j = 0; j < out[i].tokens.size(); ++j) {
|
for (int j = 0; j < out[i].tokens.size(); ++j) {
|
||||||
TokenMetadata token {
|
TokenMetadata token {
|
||||||
strdup(alphabet_.StringFromLabel(out[i].tokens[j]).c_str()), // text
|
strdup(alphabet_.DecodeSingle(out[i].tokens[j]).c_str()), // text
|
||||||
static_cast<unsigned int>(out[i].timesteps[j]), // timestep
|
static_cast<unsigned int>(out[i].timesteps[j]), // timestep
|
||||||
out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time
|
out[i].timesteps[j] * ((float)audio_win_step_ / sample_rate_), // start_time
|
||||||
};
|
};
|
||||||
|
@ -206,7 +206,7 @@ TFLiteModelState::init(const char* model_path)
|
|||||||
beam_width_ = (unsigned int)(*beam_width);
|
beam_width_ = (unsigned int)(*beam_width);
|
||||||
|
|
||||||
tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
|
tflite::StringRef serialized_alphabet = tflite::GetString(interpreter_->tensor(metadata_alphabet_idx), 0);
|
||||||
err = alphabet_.deserialize(serialized_alphabet.str, serialized_alphabet.len);
|
err = alphabet_.Deserialize(serialized_alphabet.str, serialized_alphabet.len);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
return DS_ERR_INVALID_ALPHABET;
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
}
|
}
|
||||||
|
@ -119,7 +119,7 @@ TFModelState::init(const char* model_path)
|
|||||||
beam_width_ = (unsigned int)(beam_width);
|
beam_width_ = (unsigned int)(beam_width);
|
||||||
|
|
||||||
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
|
string serialized_alphabet = metadata_outputs[4].scalar<tensorflow::tstring>()();
|
||||||
err = alphabet_.deserialize(serialized_alphabet.data(), serialized_alphabet.size());
|
err = alphabet_.Deserialize(serialized_alphabet.data(), serialized_alphabet.size());
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
return DS_ERR_INVALID_ALPHABET;
|
return DS_ERR_INVALID_ALPHABET;
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from deepspeech_training.util.text import Alphabet
|
from ds_ctcdecoder import Alphabet
|
||||||
|
|
||||||
class TestAlphabetParsing(unittest.TestCase):
|
class TestAlphabetParsing(unittest.TestCase):
|
||||||
|
|
||||||
@ -11,12 +11,12 @@ class TestAlphabetParsing(unittest.TestCase):
|
|||||||
label_id = -1
|
label_id = -1
|
||||||
for expected_label, expected_label_id in expected:
|
for expected_label, expected_label_id in expected:
|
||||||
try:
|
try:
|
||||||
label_id = alphabet.encode(expected_label)
|
label_id = alphabet.Encode(expected_label)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
self.assertEqual(label_id, [expected_label_id])
|
self.assertEqual(label_id, [expected_label_id])
|
||||||
try:
|
try:
|
||||||
label = alphabet.decode([expected_label_id])
|
label = alphabet.Decode([expected_label_id])
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
self.assertEqual(label, expected_label)
|
self.assertEqual(label, expected_label)
|
||||||
|
@ -40,7 +40,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
|
|||||||
for i, index in enumerate(indices):
|
for i, index in enumerate(indices):
|
||||||
results[index[0]].append(values[i])
|
results[index[0]].append(values[i])
|
||||||
# List of strings
|
# List of strings
|
||||||
return [alphabet.decode(res) for res in results]
|
return [alphabet.Decode(res) for res in results]
|
||||||
|
|
||||||
|
|
||||||
def evaluate(test_csvs, create_model):
|
def evaluate(test_csvs, create_model):
|
||||||
|
@ -771,7 +771,7 @@ def export():
|
|||||||
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
outputs['metadata_feature_win_len'] = tf.constant([FLAGS.feature_win_len], name='metadata_feature_win_len')
|
||||||
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
outputs['metadata_feature_win_step'] = tf.constant([FLAGS.feature_win_step], name='metadata_feature_win_step')
|
||||||
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
outputs['metadata_beam_width'] = tf.constant([FLAGS.export_beam_width], name='metadata_beam_width')
|
||||||
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.serialize()], name='metadata_alphabet')
|
outputs['metadata_alphabet'] = tf.constant([Config.alphabet.Serialize()], name='metadata_alphabet')
|
||||||
|
|
||||||
if FLAGS.export_language:
|
if FLAGS.export_language:
|
||||||
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
outputs['metadata_language'] = tf.constant([FLAGS.export_language.encode('utf-8')], name='metadata_language')
|
||||||
|
@ -6,14 +6,15 @@ import tensorflow.compat.v1 as tfv1
|
|||||||
|
|
||||||
from attrdict import AttrDict
|
from attrdict import AttrDict
|
||||||
from xdg import BaseDirectory as xdg
|
from xdg import BaseDirectory as xdg
|
||||||
|
from ds_ctcdecoder import Alphabet, UTF8Alphabet
|
||||||
|
|
||||||
from .flags import FLAGS
|
from .flags import FLAGS
|
||||||
from .gpu import get_available_gpus
|
from .gpu import get_available_gpus
|
||||||
from .logging import log_error, log_warn
|
from .logging import log_error, log_warn
|
||||||
from .text import Alphabet, UTF8Alphabet
|
|
||||||
from .helpers import parse_file_size
|
from .helpers import parse_file_size
|
||||||
from .augmentations import parse_augmentations
|
from .augmentations import parse_augmentations
|
||||||
|
|
||||||
|
|
||||||
class ConfigSingleton:
|
class ConfigSingleton:
|
||||||
_config = None
|
_config = None
|
||||||
|
|
||||||
@ -115,7 +116,7 @@ def initialize_globals():
|
|||||||
c.n_hidden_3 = c.n_cell_dim
|
c.n_hidden_3 = c.n_cell_dim
|
||||||
|
|
||||||
# Units in the sixth layer = number of characters in the target language plus one
|
# Units in the sixth layer = number of characters in the target language plus one
|
||||||
c.n_hidden_6 = c.alphabet.size() + 1 # +1 for CTC blank label
|
c.n_hidden_6 = c.alphabet.GetSize() + 1 # +1 for CTC blank label
|
||||||
|
|
||||||
# Size of audio window in samples
|
# Size of audio window in samples
|
||||||
if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:
|
if (FLAGS.feature_win_len * FLAGS.audio_sample_rate) % 1000 != 0:
|
||||||
|
@ -52,12 +52,10 @@ def check_ctcdecoder_version():
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
decoder_version_s = decoder_version.decode()
|
rv = semver.compare(ds_version_s, decoder_version)
|
||||||
|
|
||||||
rv = semver.compare(ds_version_s, decoder_version_s)
|
|
||||||
if rv != 0:
|
if rv != 0:
|
||||||
print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
|
print("DeepSpeech version ({}) and CTC decoder version ({}) do not match. "
|
||||||
"Please ensure matching versions are in use.".format(ds_version_s, decoder_version_s))
|
"Please ensure matching versions are in use.".format(ds_version_s, decoder_version))
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
return rv
|
return rv
|
||||||
|
@ -3,121 +3,6 @@ from __future__ import absolute_import, division, print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
from six.moves import range
|
|
||||||
|
|
||||||
class Alphabet(object):
|
|
||||||
def __init__(self, config_file):
|
|
||||||
self._config_file = config_file
|
|
||||||
self._label_to_str = {}
|
|
||||||
self._str_to_label = {}
|
|
||||||
self._size = 0
|
|
||||||
if config_file:
|
|
||||||
with open(config_file, 'r', encoding='utf-8') as fin:
|
|
||||||
for line in fin:
|
|
||||||
if line[0:2] == '\\#':
|
|
||||||
line = '#\n'
|
|
||||||
elif line[0] == '#':
|
|
||||||
continue
|
|
||||||
self._label_to_str[self._size] = line[:-1] # remove the line ending
|
|
||||||
self._str_to_label[line[:-1]] = self._size
|
|
||||||
self._size += 1
|
|
||||||
|
|
||||||
def _string_from_label(self, label):
|
|
||||||
return self._label_to_str[label]
|
|
||||||
|
|
||||||
def _label_from_string(self, string):
|
|
||||||
try:
|
|
||||||
return self._str_to_label[string]
|
|
||||||
except KeyError as e:
|
|
||||||
raise KeyError(
|
|
||||||
'ERROR: Your transcripts contain characters (e.g. \'{}\') which do not occur in \'{}\'! Use ' \
|
|
||||||
'util/check_characters.py to see what characters are in your [train,dev,test].csv transcripts, and ' \
|
|
||||||
'then add all these to \'{}\'.'.format(string, self._config_file, self._config_file)
|
|
||||||
).with_traceback(e.__traceback__)
|
|
||||||
|
|
||||||
def has_char(self, char):
|
|
||||||
return char in self._str_to_label
|
|
||||||
|
|
||||||
def encode(self, string):
|
|
||||||
res = []
|
|
||||||
for char in string:
|
|
||||||
res.append(self._label_from_string(char))
|
|
||||||
return res
|
|
||||||
|
|
||||||
def decode(self, labels):
|
|
||||||
res = ''
|
|
||||||
for label in labels:
|
|
||||||
res += self._string_from_label(label)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def serialize(self):
|
|
||||||
# Serialization format is a sequence of (key, value) pairs, where key is
|
|
||||||
# a uint16_t and value is a uint16_t length followed by `length` UTF-8
|
|
||||||
# encoded bytes with the label.
|
|
||||||
res = bytearray()
|
|
||||||
|
|
||||||
# We start by writing the number of pairs in the buffer as uint16_t.
|
|
||||||
res += struct.pack('<H', self._size)
|
|
||||||
for key, value in self._label_to_str.items():
|
|
||||||
value = value.encode('utf-8')
|
|
||||||
# struct.pack only takes fixed length strings/buffers, so we have to
|
|
||||||
# construct the correct format string with the length of the encoded
|
|
||||||
# label.
|
|
||||||
res += struct.pack('<HH{}s'.format(len(value)), key, len(value), value)
|
|
||||||
return bytes(res)
|
|
||||||
|
|
||||||
def size(self):
|
|
||||||
return self._size
|
|
||||||
|
|
||||||
def config_file(self):
|
|
||||||
return self._config_file
|
|
||||||
|
|
||||||
|
|
||||||
class UTF8Alphabet(object):
|
|
||||||
@staticmethod
|
|
||||||
def _string_from_label(_):
|
|
||||||
assert False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _label_from_string(_):
|
|
||||||
assert False
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def encode(string):
|
|
||||||
# 0 never happens in the data, so we can shift values by one, use 255 for
|
|
||||||
# the CTC blank, and keep the alphabet size = 256
|
|
||||||
return np.frombuffer(string.encode('utf-8'), np.uint8).astype(np.int32) - 1
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def decode(labels):
|
|
||||||
# And here we need to shift back up
|
|
||||||
return bytes(np.asarray(labels, np.uint8) + 1).decode('utf-8', errors='replace')
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def size():
|
|
||||||
return 255
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def serialize():
|
|
||||||
res = bytearray()
|
|
||||||
res += struct.pack('<h', 255)
|
|
||||||
for i in range(255):
|
|
||||||
# Note that we also shift back up in the mapping constructed here
|
|
||||||
# so that the native client sees the correct byte values when decoding.
|
|
||||||
res += struct.pack('<hh1s', i, 1, bytes([i+1]))
|
|
||||||
return bytes(res)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def deserialize(buf):
|
|
||||||
size = struct.unpack('<I', buf)[0]
|
|
||||||
assert size == 255
|
|
||||||
return UTF8Alphabet()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def config_file():
|
|
||||||
return ''
|
|
||||||
|
|
||||||
|
|
||||||
def text_to_char_array(transcript, alphabet, context=''):
|
def text_to_char_array(transcript, alphabet, context=''):
|
||||||
r"""
|
r"""
|
||||||
Given a transcript string, map characters to
|
Given a transcript string, map characters to
|
||||||
@ -125,7 +10,7 @@ def text_to_char_array(transcript, alphabet, context=''):
|
|||||||
Use a string in `context` for adding text to raised exceptions.
|
Use a string in `context` for adding text to raised exceptions.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
transcript = alphabet.encode(transcript)
|
transcript = alphabet.Encode(transcript)
|
||||||
if len(transcript) == 0:
|
if len(transcript) == 0:
|
||||||
raise ValueError('While processing {}: Found an empty transcript! '
|
raise ValueError('While processing {}: Found an empty transcript! '
|
||||||
'You must include a transcript for all training data.'
|
'You must include a transcript for all training data.'
|
||||||
|
Loading…
Reference in New Issue
Block a user