Support custom alphabet mappings
This commit is contained in:
parent
498b66867e
commit
1c4cbf1813
|
@ -24,7 +24,7 @@ from util.feeding import DataSet, ModelFeeder
|
||||||
from util.gpu import get_available_gpus
|
from util.gpu import get_available_gpus
|
||||||
from util.shared_lib import check_cupti
|
from util.shared_lib import check_cupti
|
||||||
from util.spell import correction
|
from util.spell import correction
|
||||||
from util.text import sparse_tensor_value_to_texts, wer
|
from util.text import sparse_tensor_value_to_texts, wer, Alphabet
|
||||||
from xdg import BaseDirectory as xdg
|
from xdg import BaseDirectory as xdg
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -139,6 +139,8 @@ tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps t
|
||||||
tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
|
tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
|
||||||
tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
|
tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
|
||||||
|
|
||||||
|
tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
|
||||||
|
|
||||||
for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
|
for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
|
||||||
tf.app.flags.DEFINE_float('%s_stddev' % var, None, 'standard deviation to use when initialising %s' % var)
|
tf.app.flags.DEFINE_float('%s_stddev' % var, None, 'standard deviation to use when initialising %s' % var)
|
||||||
|
|
||||||
|
@ -220,6 +222,9 @@ def initialize_globals():
|
||||||
global session_config
|
global session_config
|
||||||
session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement)
|
session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement)
|
||||||
|
|
||||||
|
global alphabet
|
||||||
|
alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
|
||||||
|
|
||||||
# Geometric Constants
|
# Geometric Constants
|
||||||
# ===================
|
# ===================
|
||||||
|
|
||||||
|
@ -257,7 +262,7 @@ def initialize_globals():
|
||||||
|
|
||||||
# The number of characters in the target language plus one
|
# The number of characters in the target language plus one
|
||||||
global n_character
|
global n_character
|
||||||
n_character = 29 # TODO: Determine if this should be extended with other punctuation
|
n_character = alphabet.size() + 1 # +1 for CTC blank label
|
||||||
|
|
||||||
# The number of units in the sixth layer
|
# The number of units in the sixth layer
|
||||||
global n_hidden_6
|
global n_hidden_6
|
||||||
|
@ -712,7 +717,7 @@ def calculate_report(results_tuple):
|
||||||
items = list(zip(*results_tuple))
|
items = list(zip(*results_tuple))
|
||||||
mean_wer = 0.0
|
mean_wer = 0.0
|
||||||
for label, decoding, distance, loss in items:
|
for label, decoding, distance, loss in items:
|
||||||
corrected = correction(decoding)
|
corrected = correction(decoding, alphabet)
|
||||||
sample_wer = wer(label, corrected)
|
sample_wer = wer(label, corrected)
|
||||||
sample = Sample(label, corrected, loss, distance, sample_wer)
|
sample = Sample(label, corrected, loss, distance, sample_wer)
|
||||||
samples.append(sample)
|
samples.append(sample)
|
||||||
|
@ -750,10 +755,10 @@ def collect_results(results_tuple, returns):
|
||||||
# Each of the arrays within results_tuple will get extended by a batch of each available device
|
# Each of the arrays within results_tuple will get extended by a batch of each available device
|
||||||
for i in range(len(available_devices)):
|
for i in range(len(available_devices)):
|
||||||
# Collect the labels
|
# Collect the labels
|
||||||
results_tuple[0].extend(sparse_tensor_value_to_texts(returns[0][i]))
|
results_tuple[0].extend(sparse_tensor_value_to_texts(returns[0][i], alphabet))
|
||||||
|
|
||||||
# Collect the decodings - at the moment we default to the first one
|
# Collect the decodings - at the moment we default to the first one
|
||||||
results_tuple[1].extend(sparse_tensor_value_to_texts(returns[1][i][0]))
|
results_tuple[1].extend(sparse_tensor_value_to_texts(returns[1][i][0], alphabet))
|
||||||
|
|
||||||
# Collect the distances
|
# Collect the distances
|
||||||
results_tuple[2].extend(returns[2][i])
|
results_tuple[2].extend(returns[2][i])
|
||||||
|
@ -1434,6 +1439,7 @@ def train(server=None):
|
||||||
test_set,
|
test_set,
|
||||||
n_input,
|
n_input,
|
||||||
n_context,
|
n_context,
|
||||||
|
alphabet,
|
||||||
tower_feeder_count=len(available_devices))
|
tower_feeder_count=len(available_devices))
|
||||||
|
|
||||||
# Create the optimizer
|
# Create the optimizer
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Each line in this file represents the Unicode codepoint (UTF-8 encoded)
|
||||||
|
# associated with a numeric label.
|
||||||
|
# A line that starts with # is a comment. You can escape it with \# if you wish
|
||||||
|
# to use '#' as a label.
|
||||||
|
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
d
|
||||||
|
e
|
||||||
|
f
|
||||||
|
g
|
||||||
|
h
|
||||||
|
i
|
||||||
|
j
|
||||||
|
k
|
||||||
|
l
|
||||||
|
m
|
||||||
|
n
|
||||||
|
o
|
||||||
|
p
|
||||||
|
q
|
||||||
|
r
|
||||||
|
s
|
||||||
|
t
|
||||||
|
u
|
||||||
|
v
|
||||||
|
w
|
||||||
|
x
|
||||||
|
y
|
||||||
|
z
|
||||||
|
'
|
||||||
|
# The last (non-comment) line needs to end with a newline.
|
|
@ -5,7 +5,7 @@ load("//tensorflow:tensorflow.bzl",
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "deepspeech",
|
name = "deepspeech",
|
||||||
srcs = ["deepspeech.cc"],
|
srcs = ["deepspeech.cc", "alphabet.h"],
|
||||||
hdrs = ["deepspeech.h"],
|
hdrs = ["deepspeech.h"],
|
||||||
deps = ["//tensorflow/core:core",
|
deps = ["//tensorflow/core:core",
|
||||||
":deepspeech_utils"],
|
":deepspeech_utils"],
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
#ifndef ALPHABET_H
|
||||||
|
#define ALPHABET_H
|
||||||
|
|
||||||
|
#include <cassert>
|
||||||
|
#include <fstream>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Loads a text file describing a mapping of labels to strings, one string per
|
||||||
|
* line. This is used by the decoder, client and Python scripts to convert the
|
||||||
|
* output of the decoder to a human-readable string and vice-versa.
|
||||||
|
*/
|
||||||
|
class Alphabet {
|
||||||
|
public:
|
||||||
|
Alphabet(const char *config_file) {
|
||||||
|
std::ifstream in(config_file, std::ios::in);
|
||||||
|
unsigned int label = 0;
|
||||||
|
for (std::string line; std::getline(in, line);) {
|
||||||
|
if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
|
||||||
|
line = '#';
|
||||||
|
} else if (line[0] == '#') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
label_to_str_[label] = line;
|
||||||
|
str_to_label_[line] = label;
|
||||||
|
++label;
|
||||||
|
}
|
||||||
|
size_ = label;
|
||||||
|
in.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string& StringFromLabel(unsigned int label) const {
|
||||||
|
assert(label < size_);
|
||||||
|
auto it = label_to_str_.find(label);
|
||||||
|
if (it != label_to_str_.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
// unreachable due to assert above
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned int LabelFromString(const std::string& string) const {
|
||||||
|
auto it = str_to_label_.find(string);
|
||||||
|
if (it != str_to_label_.end()) {
|
||||||
|
return it->second;
|
||||||
|
} else {
|
||||||
|
abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t GetSize() {
|
||||||
|
return size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t size_;
|
||||||
|
std::unordered_map<unsigned int, std::string> label_to_str_;
|
||||||
|
std::unordered_map<std::string, unsigned int> str_to_label_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif //ALPHABET_H
|
|
@ -58,17 +58,19 @@ LocalDsSTT(Model& aCtx, const short* aBuffer, size_t aBufferSize,
|
||||||
int
|
int
|
||||||
main(int argc, char **argv)
|
main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
if (argc < 3 || argc > 4) {
|
if (argc < 4 || argc > 5) {
|
||||||
printf("Usage: deepspeech MODEL_PATH AUDIO_PATH [-t]\n");
|
printf("Usage: deepspeech MODEL_PATH AUDIO_PATH [-t]\n");
|
||||||
printf(" MODEL_PATH\tPath to the model (protocol buffer binary file)\n");
|
printf(" MODEL_PATH\tPath to the model (protocol buffer binary file)\n");
|
||||||
printf(" AUDIO_PATH\tPath to the audio file to run"
|
printf(" AUDIO_PATH\tPath to the audio file to run"
|
||||||
" (any file format supported by libsox)\n");
|
" (any file format supported by libsox)\n");
|
||||||
|
printf(" ALPHABET_PATH\tPath to the configuration file specifying"
|
||||||
|
" the alphabet used by the network.\n");
|
||||||
printf(" -t\t\tRun in benchmark mode, output mfcc & inference time\n");
|
printf(" -t\t\tRun in benchmark mode, output mfcc & inference time\n");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialise DeepSpeech
|
// Initialise DeepSpeech
|
||||||
Model ctx = Model(argv[1], N_CEP, N_CONTEXT);
|
Model ctx = Model(argv[1], N_CEP, N_CONTEXT, argv[3]);
|
||||||
|
|
||||||
// Initialise SOX
|
// Initialise SOX
|
||||||
assert(sox_init() == SOX_SUCCESS);
|
assert(sox_init() == SOX_SUCCESS);
|
||||||
|
|
|
@ -9,7 +9,7 @@ Fs.createReadStream(process.argv[3]).
|
||||||
pipe(audioStream);
|
pipe(audioStream);
|
||||||
audioStream.on('finish', () => {
|
audioStream.on('finish', () => {
|
||||||
audioBuffer = audioStream.toBuffer();
|
audioBuffer = audioStream.toBuffer();
|
||||||
var model = new Ds.Model(process.argv[2], 26, 9);
|
var model = new Ds.Model(process.argv[2], 26, 9, process.argv[4]);
|
||||||
// We take half of the buffer_size because buffer is a char* while
|
// We take half of the buffer_size because buffer is a char* while
|
||||||
// LocalDsSTT() expected a short*
|
// LocalDsSTT() expected a short*
|
||||||
console.log(model.stt(audioBuffer.slice(0, audioBuffer.length / 2), 16000));
|
console.log(model.stt(audioBuffer.slice(0, audioBuffer.length / 2), 16000));
|
||||||
|
|
|
@ -6,6 +6,6 @@ import sys
|
||||||
import scipy.io.wavfile as wav
|
import scipy.io.wavfile as wav
|
||||||
from deepspeech.model import Model
|
from deepspeech.model import Model
|
||||||
|
|
||||||
ds = Model(sys.argv[1], 26, 9)
|
ds = Model(sys.argv[1], 26, 9, sys.argv[3])
|
||||||
fs, audio = wav.read(sys.argv[2])
|
fs, audio = wav.read(sys.argv[2])
|
||||||
print(ds.stt(audio, fs))
|
print(ds.stt(audio, fs))
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#include "deepspeech.h"
|
#include "deepspeech.h"
|
||||||
#include "deepspeech_utils.h"
|
#include "deepspeech_utils.h"
|
||||||
|
#include "alphabet.h"
|
||||||
#include "tensorflow/core/public/session.h"
|
#include "tensorflow/core/public/session.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
|
||||||
|
@ -13,9 +14,11 @@ class Private {
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
int ncep;
|
int ncep;
|
||||||
int ncontext;
|
int ncontext;
|
||||||
|
Alphabet* alphabet;
|
||||||
};
|
};
|
||||||
|
|
||||||
Model::Model(const char* aModelPath, int aNCep, int aNContext)
|
Model::Model(const char* aModelPath, int aNCep, int aNContext,
|
||||||
|
const char* aAlphabetConfigPath)
|
||||||
{
|
{
|
||||||
mPriv = new Private;
|
mPriv = new Private;
|
||||||
|
|
||||||
|
@ -44,6 +47,8 @@ Model::Model(const char* aModelPath, int aNCep, int aNContext)
|
||||||
|
|
||||||
mPriv->ncep = aNCep;
|
mPriv->ncep = aNCep;
|
||||||
mPriv->ncontext = aNContext;
|
mPriv->ncontext = aNContext;
|
||||||
|
|
||||||
|
mPriv->alphabet = new Alphabet(aAlphabetConfigPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
Model::~Model()
|
Model::~Model()
|
||||||
|
@ -52,6 +57,8 @@ Model::~Model()
|
||||||
mPriv->session->Close();
|
mPriv->session->Close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
delete mPriv->alphabet;
|
||||||
|
|
||||||
delete mPriv;
|
delete mPriv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,13 +112,24 @@ Model::infer(float* aMfcc, int aNFrames, int aFrameLen)
|
||||||
// Output is an array of shape (1, n_results, result_length).
|
// Output is an array of shape (1, n_results, result_length).
|
||||||
// In this case, n_results is also equal to 1.
|
// In this case, n_results is also equal to 1.
|
||||||
auto output_mapped = outputs[0].tensor<int64, 3>();
|
auto output_mapped = outputs[0].tensor<int64, 3>();
|
||||||
int length = output_mapped.dimension(2) + 1;
|
size_t output_length = output_mapped.dimension(2) + 1;
|
||||||
char* output = (char*)malloc(sizeof(char) * length);
|
|
||||||
for (int i = 0; i < length - 1; i++) {
|
size_t decoded_length = 1; // add 1 for the \0
|
||||||
|
for (int i = 0; i < output_length - 1; i++) {
|
||||||
int64 character = output_mapped(0, 0, i);
|
int64 character = output_mapped(0, 0, i);
|
||||||
output[i] = (character == 0) ? ' ' : (character + 'a' - 1);
|
const std::string& str = mPriv->alphabet->StringFromLabel(character);
|
||||||
|
decoded_length += str.size();
|
||||||
}
|
}
|
||||||
output[length - 1] = '\0';
|
|
||||||
|
char* output = (char*)malloc(sizeof(char) * decoded_length);
|
||||||
|
char* pen = output;
|
||||||
|
for (int i = 0; i < output_length - 1; i++) {
|
||||||
|
int64 character = output_mapped(0, 0, i);
|
||||||
|
const std::string& str = mPriv->alphabet->StringFromLabel(character);
|
||||||
|
strncpy(pen, str.c_str(), str.size());
|
||||||
|
pen += str.size();
|
||||||
|
}
|
||||||
|
*pen = '\0';
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,11 @@ namespace DeepSpeech
|
||||||
* @param aModelPath The path to the frozen model graph.
|
* @param aModelPath The path to the frozen model graph.
|
||||||
* @param aNCep The number of cepstrum the model was trained with.
|
* @param aNCep The number of cepstrum the model was trained with.
|
||||||
* @param aNContext The context window the model was trained with.
|
* @param aNContext The context window the model was trained with.
|
||||||
|
* @param aAlphabetConfigPath The path to the configuration file specifying
|
||||||
|
* the alphabet used by the network. See alphabet.h.
|
||||||
*/
|
*/
|
||||||
Model(const char* aModelPath, int aNCep, int aNContext);
|
Model(const char* aModelPath, int aNCep, int aNContext,
|
||||||
|
const char* aAlphabetConfigPath);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Frees associated resources and destroys model object.
|
* @brief Frees associated resources and destroys model object.
|
||||||
|
|
|
@ -6,6 +6,6 @@ source $(dirname "$0")/tc-tests-utils.sh
|
||||||
|
|
||||||
download_material "/tmp/ds"
|
download_material "/tmp/ds"
|
||||||
|
|
||||||
phrase=$(LD_LIBRARY_PATH=/tmp/ds/:$LD_LIBRARY_PATH /tmp/ds/deepspeech /tmp/${model_name} /tmp/LDC93S1.wav)
|
phrase=$(LD_LIBRARY_PATH=/tmp/ds/:$LD_LIBRARY_PATH /tmp/ds/deepspeech /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt)
|
||||||
|
|
||||||
assert_correct_ldc93s1 "${phrase}"
|
assert_correct_ldc93s1 "${phrase}"
|
||||||
|
|
|
@ -20,7 +20,7 @@ pushd ${HOME}/DeepSpeech/ds/native_client/
|
||||||
npm --version
|
npm --version
|
||||||
npm install ${DEEPSPEECH_ARTIFACTS_ROOT}/deepspeech-0.0.1.tgz
|
npm install ${DEEPSPEECH_ARTIFACTS_ROOT}/deepspeech-0.0.1.tgz
|
||||||
npm install
|
npm install
|
||||||
phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH node client.js /tmp/${model_name} /tmp/LDC93S1.wav)
|
phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH node client.js /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt)
|
||||||
popd
|
popd
|
||||||
|
|
||||||
assert_correct_ldc93s1 "${phrase}"
|
assert_correct_ldc93s1 "${phrase}"
|
||||||
|
|
|
@ -27,3 +27,5 @@ find ${DS_ROOT_TASK}/DeepSpeech/ds/native_client/javascript/ -type f -name "deep
|
||||||
|
|
||||||
pixz -9 ${TASKCLUSTER_ARTIFACTS}/native_client.tar ${TASKCLUSTER_ARTIFACTS}/native_client.tar.xz
|
pixz -9 ${TASKCLUSTER_ARTIFACTS}/native_client.tar ${TASKCLUSTER_ARTIFACTS}/native_client.tar.xz
|
||||||
rm ${TASKCLUSTER_ARTIFACTS}/native_client.tar
|
rm ${TASKCLUSTER_ARTIFACTS}/native_client.tar
|
||||||
|
|
||||||
|
cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/alphabet.txt ${TASKCLUSTER_ARTIFACTS}/
|
||||||
|
|
|
@ -42,7 +42,7 @@ platform=$(python -c 'import sys; import platform; sys.stdout.write("%s_%s" % (p
|
||||||
deepspeech_pkg="deepspeech-0.0.1-cp${pyver_pkg}-cp${pyver_pkg}${py_unicode_type}-${platform}.whl"
|
deepspeech_pkg="deepspeech-0.0.1-cp${pyver_pkg}-cp${pyver_pkg}${py_unicode_type}-${platform}.whl"
|
||||||
pip install --upgrade ${DEEPSPEECH_ARTIFACTS_ROOT}/${deepspeech_pkg}
|
pip install --upgrade ${DEEPSPEECH_ARTIFACTS_ROOT}/${deepspeech_pkg}
|
||||||
|
|
||||||
phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH python ${HOME}/DeepSpeech/ds/native_client/client.py /tmp/${model_name} /tmp/LDC93S1.wav)
|
phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH python ${HOME}/DeepSpeech/ds/native_client/client.py /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt)
|
||||||
|
|
||||||
assert_correct_ldc93s1 "${phrase}"
|
assert_correct_ldc93s1 "${phrase}"
|
||||||
|
|
||||||
|
|
|
@ -61,8 +61,9 @@ download_material()
|
||||||
wget ${DEEPSPEECH_MODEL} -O /tmp/${model_name}
|
wget ${DEEPSPEECH_MODEL} -O /tmp/${model_name}
|
||||||
wget https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav -O /tmp/LDC93S1.wav
|
wget https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav -O /tmp/LDC93S1.wav
|
||||||
wget ${DEEPSPEECH_ARTIFACTS_ROOT}/native_client.tar.xz -O - | pixz -d | tar -C ${target_dir} -xf -
|
wget ${DEEPSPEECH_ARTIFACTS_ROOT}/native_client.tar.xz -O - | pixz -d | tar -C ${target_dir} -xf -
|
||||||
|
wget ${DEEPSPEECH_ARTIFACTS_ROOT}/alphabet.txt -O /tmp/alphabet.txt
|
||||||
|
|
||||||
ls -hal /tmp/${model_name} /tmp/LDC93S1.wav
|
ls -hal /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt
|
||||||
}
|
}
|
||||||
|
|
||||||
install_pyenv()
|
install_pyenv()
|
||||||
|
|
|
@ -22,6 +22,7 @@ class ModelFeeder(object):
|
||||||
test_set,
|
test_set,
|
||||||
numcep,
|
numcep,
|
||||||
numcontext,
|
numcontext,
|
||||||
|
alphabet,
|
||||||
tower_feeder_count=-1,
|
tower_feeder_count=-1,
|
||||||
threads_per_queue=2):
|
threads_per_queue=2):
|
||||||
|
|
||||||
|
@ -41,7 +42,7 @@ class ModelFeeder(object):
|
||||||
self.ph_batch_size = tf.placeholder(tf.int32, [])
|
self.ph_batch_size = tf.placeholder(tf.int32, [])
|
||||||
self.ph_queue_selector = tf.placeholder(tf.int32, name='Queue_Selector')
|
self.ph_queue_selector = tf.placeholder(tf.int32, name='Queue_Selector')
|
||||||
|
|
||||||
self._tower_feeders = [_TowerFeeder(self, i) for i in range(self.tower_feeder_count)]
|
self._tower_feeders = [_TowerFeeder(self, i, alphabet) for i in range(self.tower_feeder_count)]
|
||||||
|
|
||||||
def start_queue_threads(self, session, coord):
|
def start_queue_threads(self, session, coord):
|
||||||
'''
|
'''
|
||||||
|
@ -105,7 +106,7 @@ class _DataSetLoader(object):
|
||||||
Keeps a ModelFeeder reference for accessing shared settings and placeholders.
|
Keeps a ModelFeeder reference for accessing shared settings and placeholders.
|
||||||
Keeps a DataSet reference to access its samples.
|
Keeps a DataSet reference to access its samples.
|
||||||
'''
|
'''
|
||||||
def __init__(self, model_feeder, data_set):
|
def __init__(self, model_feeder, data_set, alphabet):
|
||||||
self._model_feeder = model_feeder
|
self._model_feeder = model_feeder
|
||||||
self._data_set = data_set
|
self._data_set = data_set
|
||||||
self.queue = tf.PaddingFIFOQueue(shapes=[[None, model_feeder.numcep + (2 * model_feeder.numcep * model_feeder.numcontext)], [], [None,], []],
|
self.queue = tf.PaddingFIFOQueue(shapes=[[None, model_feeder.numcep + (2 * model_feeder.numcep * model_feeder.numcontext)], [], [None,], []],
|
||||||
|
@ -113,6 +114,7 @@ class _DataSetLoader(object):
|
||||||
capacity=data_set.batch_size * 2)
|
capacity=data_set.batch_size * 2)
|
||||||
self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length])
|
self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length])
|
||||||
self._close_op = self.queue.close(cancel_pending_enqueues=True)
|
self._close_op = self.queue.close(cancel_pending_enqueues=True)
|
||||||
|
self._alphabet = alphabet
|
||||||
|
|
||||||
def start_queue_threads(self, session, coord):
|
def start_queue_threads(self, session, coord):
|
||||||
'''
|
'''
|
||||||
|
@ -143,7 +145,7 @@ class _DataSetLoader(object):
|
||||||
wav_file, transcript = self._data_set.files[index]
|
wav_file, transcript = self._data_set.files[index]
|
||||||
source = audiofile_to_input_vector(wav_file, self._model_feeder.numcep, self._model_feeder.numcontext)
|
source = audiofile_to_input_vector(wav_file, self._model_feeder.numcep, self._model_feeder.numcontext)
|
||||||
source_len = len(source)
|
source_len = len(source)
|
||||||
target = text_to_char_array(transcript)
|
target = text_to_char_array(transcript, self._alphabet)
|
||||||
target_len = len(target)
|
target_len = len(target)
|
||||||
try:
|
try:
|
||||||
session.run(self._enqueue_op, feed_dict={ self._model_feeder.ph_x: source,
|
session.run(self._enqueue_op, feed_dict={ self._model_feeder.ph_x: source,
|
||||||
|
@ -159,10 +161,10 @@ class _TowerFeeder(object):
|
||||||
It creates, owns and combines three _DataSetLoader instances.
|
It creates, owns and combines three _DataSetLoader instances.
|
||||||
Keeps a ModelFeeder reference for accessing shared settings and placeholders.
|
Keeps a ModelFeeder reference for accessing shared settings and placeholders.
|
||||||
'''
|
'''
|
||||||
def __init__(self, model_feeder, index):
|
def __init__(self, model_feeder, index, alphabet):
|
||||||
self._model_feeder = model_feeder
|
self._model_feeder = model_feeder
|
||||||
self.index = index
|
self.index = index
|
||||||
self._loaders = [_DataSetLoader(model_feeder, data_set) for data_set in model_feeder.sets]
|
self._loaders = [_DataSetLoader(model_feeder, data_set, alphabet) for data_set in model_feeder.sets]
|
||||||
self._queues = [set_queue.queue for set_queue in self._loaders]
|
self._queues = [set_queue.queue for set_queue in self._loaders]
|
||||||
self._queue = tf.QueueBase.from_list(model_feeder.ph_queue_selector, self._queues)
|
self._queue = tf.QueueBase.from_list(model_feeder.ph_queue_selector, self._queues)
|
||||||
self._close_op = self._queue.close(cancel_pending_enqueues=True)
|
self._close_op = self._queue.close(cancel_pending_enqueues=True)
|
||||||
|
|
|
@ -27,26 +27,26 @@ def log_probability(sentence):
|
||||||
"Log base 10 probability of `sentence`, a list of words"
|
"Log base 10 probability of `sentence`, a list of words"
|
||||||
return get_model().score(' '.join(sentence), bos = False, eos = False)
|
return get_model().score(' '.join(sentence), bos = False, eos = False)
|
||||||
|
|
||||||
def correction(sentence):
|
def correction(sentence, alphabet):
|
||||||
"Most probable spelling correction for sentence."
|
"Most probable spelling correction for sentence."
|
||||||
layer = [(0,[])]
|
layer = [(0,[])]
|
||||||
for word in words(sentence):
|
for word in words(sentence):
|
||||||
layer = [(-log_probability(node + [cword]), node + [cword]) for cword in candidate_words(word) for priority, node in layer]
|
layer = [(-log_probability(node + [cword]), node + [cword]) for cword in candidate_words(word, alphabet) for priority, node in layer]
|
||||||
heapify(layer)
|
heapify(layer)
|
||||||
layer = layer[:BEAM_WIDTH]
|
layer = layer[:BEAM_WIDTH]
|
||||||
return ' '.join(layer[0][1])
|
return ' '.join(layer[0][1])
|
||||||
|
|
||||||
def candidate_words(word):
|
def candidate_words(word, alphabet):
|
||||||
"Generate possible spelling corrections for word."
|
"Generate possible spelling corrections for word."
|
||||||
return (known_words([word]) or known_words(edits1(word)) or known_words(edits2(word)) or [word])
|
return (known_words([word]) or known_words(edits1(word, alphabet)) or known_words(edits2(word, alphabet)) or [word])
|
||||||
|
|
||||||
def known_words(words):
|
def known_words(words):
|
||||||
"The subset of `words` that appear in the dictionary of WORDS."
|
"The subset of `words` that appear in the dictionary of WORDS."
|
||||||
return set(w for w in words if w in WORDS)
|
return set(w for w in words if w in WORDS)
|
||||||
|
|
||||||
def edits1(word):
|
def edits1(word, alphabet):
|
||||||
"All edits that are one edit away from `word`."
|
"All edits that are one edit away from `word`."
|
||||||
letters = 'abcdefghijklmnopqrstuvwxyz'
|
letters = [alphabet.string_from_label(i) for i in range(alphabet.size())]
|
||||||
splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
|
splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
|
||||||
deletes = [L + R[1:] for L, R in splits if R]
|
deletes = [L + R[1:] for L, R in splits if R]
|
||||||
transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
|
transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
|
||||||
|
@ -54,6 +54,6 @@ def edits1(word):
|
||||||
inserts = [L + c + R for L, R in splits for c in letters]
|
inserts = [L + c + R for L, R in splits for c in letters]
|
||||||
return set(deletes + transposes + replaces + inserts)
|
return set(deletes + transposes + replaces + inserts)
|
||||||
|
|
||||||
def edits2(word):
|
def edits2(word, alphabet):
|
||||||
"All edits that are two edits away from `word`."
|
"All edits that are two edits away from `word`."
|
||||||
return (e2 for e1 in edits1(word) for e2 in edits1(e1))
|
return (e2 for e1 in edits1(word, alphabet) for e2 in edits1(e1, alphabet))
|
||||||
|
|
56
util/text.py
56
util/text.py
|
@ -5,28 +5,36 @@ import re
|
||||||
from six.moves import range
|
from six.moves import range
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
|
||||||
# Constants
|
class Alphabet(object):
|
||||||
SPACE_TOKEN = '<space>'
|
def __init__(self, config_file):
|
||||||
SPACE_INDEX = 0
|
self._label_to_str = []
|
||||||
FIRST_INDEX = ord('a') - 1 # 0 is reserved to space
|
self._str_to_label = {}
|
||||||
|
self._size = 0
|
||||||
|
with open(config_file, 'r') as fin:
|
||||||
|
for line in fin:
|
||||||
|
if line[0:2] == '\\#':
|
||||||
|
line = '#\n'
|
||||||
|
elif line[0] == '#':
|
||||||
|
continue
|
||||||
|
self._label_to_str += line[:-1] # remove the line ending
|
||||||
|
self._str_to_label[line[:-1]] = self._size
|
||||||
|
self._size += 1
|
||||||
|
|
||||||
def text_to_char_array(original):
|
def string_from_label(self, label):
|
||||||
|
return self._label_to_str[label]
|
||||||
|
|
||||||
|
def label_from_string(self, string):
|
||||||
|
return self._str_to_label[string]
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return self._size
|
||||||
|
|
||||||
|
def text_to_char_array(original, alphabet):
|
||||||
r"""
|
r"""
|
||||||
Given a Python string ``original``, remove unsupported characters, map characters
|
Given a Python string ``original``, map each of its characters through ``alphabet``
|
||||||
to integers and return a numpy array representing the processed string.
|
to integers and return a numpy array representing the processed string.
|
||||||
"""
|
"""
|
||||||
# Create list of sentence's words w/spaces replaced by ''
|
return np.asarray([alphabet.label_from_string(c) for c in original])
|
||||||
result = original.replace(" '", "") # TODO: Deal with this properly
|
|
||||||
result = result.replace("'", "") # TODO: Deal with this properly
|
|
||||||
|
|
||||||
# Tokenize into letters adding in SPACE_TOKEN where required
|
|
||||||
result = np.hstack([SPACE_TOKEN if xt == ' ' else xt for xt in result])
|
|
||||||
|
|
||||||
# Map characters into indicies
|
|
||||||
result = np.asarray([SPACE_INDEX if xt == SPACE_TOKEN else ord(xt) - FIRST_INDEX for xt in result])
|
|
||||||
|
|
||||||
# Add result to results
|
|
||||||
return result
|
|
||||||
|
|
||||||
def sparse_tuple_from(sequences, dtype=np.int32):
|
def sparse_tuple_from(sequences, dtype=np.int32):
|
||||||
r"""Creates a sparse representation of ``sequences``.
|
r"""Creates a sparse representation of ``sequences``.
|
||||||
|
@ -48,29 +56,27 @@ def sparse_tuple_from(sequences, dtype=np.int32):
|
||||||
|
|
||||||
return tf.SparseTensor(indices=indices, values=values, shape=shape)
|
return tf.SparseTensor(indices=indices, values=values, shape=shape)
|
||||||
|
|
||||||
def sparse_tensor_value_to_texts(value):
|
def sparse_tensor_value_to_texts(value, alphabet):
|
||||||
r"""
|
r"""
|
||||||
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
|
||||||
representing its values.
|
representing its values.
|
||||||
"""
|
"""
|
||||||
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape))
|
return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
|
||||||
|
|
||||||
def sparse_tuple_to_texts(tuple):
|
def sparse_tuple_to_texts(tuple, alphabet):
|
||||||
indices = tuple[0]
|
indices = tuple[0]
|
||||||
values = tuple[1]
|
values = tuple[1]
|
||||||
results = [''] * tuple[2][0]
|
results = [''] * tuple[2][0]
|
||||||
for i in range(len(indices)):
|
for i in range(len(indices)):
|
||||||
index = indices[i][0]
|
index = indices[i][0]
|
||||||
c = values[i]
|
results[index] += alphabet.string_from_label(values[i])
|
||||||
c = ' ' if c == SPACE_INDEX else chr(c + FIRST_INDEX)
|
|
||||||
results[index] = results[index] + c
|
|
||||||
# List of strings
|
# List of strings
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def ndarray_to_text(value):
|
def ndarray_to_text(value, alphabet):
|
||||||
results = ''
|
results = ''
|
||||||
for i in range(len(value)):
|
for i in range(len(value)):
|
||||||
results += chr(value[i] + FIRST_INDEX)
|
results += alphabet.string_from_label(value[i])
|
||||||
return results.replace('`', ' ')
|
return results.replace('`', ' ')
|
||||||
|
|
||||||
def wer(original, result):
|
def wer(original, result):
|
||||||
|
|
Loading…
Reference in New Issue