Support custom alphabet mappings
Commit 1c4cbf1813 (parent 498b66867e)
@@ -24,7 +24,7 @@ from util.feeding import DataSet, ModelFeeder
 from util.gpu import get_available_gpus
 from util.shared_lib import check_cupti
 from util.spell import correction
-from util.text import sparse_tensor_value_to_texts, wer
+from util.text import sparse_tensor_value_to_texts, wer, Alphabet
 from xdg import BaseDirectory as xdg
 import numpy as np
 
@@ -139,6 +139,8 @@ tf.app.flags.DEFINE_integer ('earlystop_nsteps', 4, 'number of steps t
 tf.app.flags.DEFINE_float ('estop_mean_thresh', 0.5, 'mean threshold for loss to determine the condition if early stopping is required')
 tf.app.flags.DEFINE_float ('estop_std_thresh', 0.5, 'standard deviation threshold for loss to determine the condition if early stopping is required')
 
+tf.app.flags.DEFINE_string ('alphabet_config_path', 'data/alphabet.txt', 'path to the configuration file specifying the alphabet used by the network. See the comment in data/alphabet.txt for a description of the format.')
+
 for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']:
     tf.app.flags.DEFINE_float('%s_stddev' % var, None, 'standard deviation to use when initialising %s' % var)
 
@@ -220,6 +222,9 @@ def initialize_globals():
     global session_config
     session_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=FLAGS.log_placement)
 
+    global alphabet
+    alphabet = Alphabet(os.path.abspath(FLAGS.alphabet_config_path))
+
     # Geometric Constants
     # ===================
 
@@ -257,7 +262,7 @@ def initialize_globals():
 
     # The number of characters in the target language plus one
     global n_character
-    n_character = 29 # TODO: Determine if this should be extended with other punctuation
+    n_character = alphabet.size() + 1 # +1 for CTC blank label
 
     # The number of units in the sixth layer
     global n_hidden_6
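With the default data/alphabet.txt added further down (a space, a-z and the apostrophe, i.e. 28 labels), the new expression reproduces the previously hardcoded value. A quick sanity check in Python, assuming the Alphabet class introduced in util/text.py below:

    from util.text import Alphabet

    alphabet = Alphabet('data/alphabet.txt')
    assert alphabet.size() == 28           # space + 26 letters + apostrophe
    n_character = alphabet.size() + 1      # 28 symbols + 1 CTC blank label == 29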
@@ -712,7 +717,7 @@ def calculate_report(results_tuple):
     items = list(zip(*results_tuple))
     mean_wer = 0.0
     for label, decoding, distance, loss in items:
-        corrected = correction(decoding)
+        corrected = correction(decoding, alphabet)
         sample_wer = wer(label, corrected)
         sample = Sample(label, corrected, loss, distance, sample_wer)
         samples.append(sample)
@@ -750,10 +755,10 @@ def collect_results(results_tuple, returns):
     # Each of the arrays within results_tuple will get extended by a batch of each available device
     for i in range(len(available_devices)):
         # Collect the labels
-        results_tuple[0].extend(sparse_tensor_value_to_texts(returns[0][i]))
+        results_tuple[0].extend(sparse_tensor_value_to_texts(returns[0][i], alphabet))
 
         # Collect the decodings - at the moment we default to the first one
-        results_tuple[1].extend(sparse_tensor_value_to_texts(returns[1][i][0]))
+        results_tuple[1].extend(sparse_tensor_value_to_texts(returns[1][i][0], alphabet))
 
         # Collect the distances
         results_tuple[2].extend(returns[2][i])
@@ -1434,6 +1439,7 @@ def train(server=None):
         test_set,
         n_input,
         n_context,
+        alphabet,
         tower_feeder_count=len(available_devices))
 
     # Create the optimizer
@@ -0,0 +1,33 @@
+# Each line in this file represents the Unicode codepoint (UTF-8 encoded)
+# associated with a numeric label.
+# A line that starts with # is a comment. You can escape it with \# if you wish
+# to use '#' as a label.
+ 
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+'
+# The last (non-comment) line needs to end with a newline.
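The format is intentionally minimal: one label string per line, '#' starts a comment, '\#' escapes a literal '#' label, and label numbers are assigned in file order starting at 0 (the first non-comment line above is a single space). A small Python sketch of a conforming reader, given as an illustration that mirrors the loaders in this commit rather than code taken from it:

    def load_alphabet(path):
        # One label per line; '#' starts a comment; '\#' escapes a literal '#'.
        labels = []
        with open(path, 'r') as fin:
            for line in fin:
                if line[0:2] == '\\#':
                    line = '#\n'
                elif line[0] == '#':
                    continue
                labels.append(line[:-1])  # strip the trailing newline
        return labels

    # With the file above: labels[0] == ' ', labels[1] == 'a', ..., labels[27] == "'"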
@@ -5,7 +5,7 @@ load("//tensorflow:tensorflow.bzl",
 
 cc_library(
     name = "deepspeech",
-    srcs = ["deepspeech.cc"],
+    srcs = ["deepspeech.cc", "alphabet.h"],
     hdrs = ["deepspeech.h"],
     deps = ["//tensorflow/core:core",
             ":deepspeech_utils"],
@@ -0,0 +1,63 @@
+#ifndef ALPHABET_H
+#define ALPHABET_H
+
+#include <cassert>
+#include <fstream>
+#include <string>
+#include <unordered_map>
+
+/*
+ * Loads a text file describing a mapping of labels to strings, one string per
+ * line. This is used by the decoder, client and Python scripts to convert the
+ * output of the decoder to a human-readable string and vice-versa.
+ */
+class Alphabet {
+public:
+  Alphabet(const char *config_file) {
+    std::ifstream in(config_file, std::ios::in);
+    unsigned int label = 0;
+    for (std::string line; std::getline(in, line);) {
+      if (line.size() == 2 && line[0] == '\\' && line[1] == '#') {
+        line = '#';
+      } else if (line[0] == '#') {
+        continue;
+      }
+      label_to_str_[label] = line;
+      str_to_label_[line] = label;
+      ++label;
+    }
+    size_ = label;
+    in.close();
+  }
+
+  const std::string& StringFromLabel(unsigned int label) const {
+    assert(label < size_);
+    auto it = label_to_str_.find(label);
+    if (it != label_to_str_.end()) {
+      return it->second;
+    } else {
+      // unreachable due to assert above
+      abort();
+    }
+  }
+
+  unsigned int LabelFromString(const std::string& string) const {
+    auto it = str_to_label_.find(string);
+    if (it != str_to_label_.end()) {
+      return it->second;
+    } else {
+      abort();
+    }
+  }
+
+  size_t GetSize() {
+    return size_;
+  }
+
+private:
+  size_t size_;
+  std::unordered_map<unsigned int, std::string> label_to_str_;
+  std::unordered_map<std::string, unsigned int> str_to_label_;
+};
+
+#endif //ALPHABET_H
@@ -58,17 +58,19 @@ LocalDsSTT(Model& aCtx, const short* aBuffer, size_t aBufferSize,
 int
 main(int argc, char **argv)
 {
-  if (argc < 3 || argc > 4) {
+  if (argc < 4 || argc > 5) {
     printf("Usage: deepspeech MODEL_PATH AUDIO_PATH [-t]\n");
     printf(" MODEL_PATH\tPath to the model (protocol buffer binary file)\n");
     printf(" AUDIO_PATH\tPath to the audio file to run"
            " (any file format supported by libsox)\n");
+    printf(" ALPHABET_PATH\tPath to the configuration file specifying"
+           " the alphabet used by the network.\n");
     printf(" -t\t\tRun in benchmark mode, output mfcc & inference time\n");
     return 1;
   }
 
   // Initialise DeepSpeech
-  Model ctx = Model(argv[1], N_CEP, N_CONTEXT);
+  Model ctx = Model(argv[1], N_CEP, N_CONTEXT, argv[3]);
 
   // Initialise SOX
   assert(sox_init() == SOX_SUCCESS);
@@ -9,7 +9,7 @@ Fs.createReadStream(process.argv[3]).
   pipe(audioStream);
 audioStream.on('finish', () => {
   audioBuffer = audioStream.toBuffer();
-  var model = new Ds.Model(process.argv[2], 26, 9);
+  var model = new Ds.Model(process.argv[2], 26, 9, process.argv[4]);
   // We take half of the buffer_size because buffer is a char* while
   // LocalDsSTT() expected a short*
   console.log(model.stt(audioBuffer.slice(0, audioBuffer.length / 2), 16000));
@@ -6,6 +6,6 @@ import sys
 import scipy.io.wavfile as wav
 from deepspeech.model import Model
 
-ds = Model(sys.argv[1], 26, 9)
+ds = Model(sys.argv[1], 26, 9, sys.argv[3])
 fs, audio = wav.read(sys.argv[2])
 print(ds.stt(audio, fs))
@@ -1,5 +1,6 @@
 #include "deepspeech.h"
 #include "deepspeech_utils.h"
+#include "alphabet.h"
 #include "tensorflow/core/public/session.h"
 #include "tensorflow/core/platform/env.h"
 
@@ -13,9 +14,11 @@ class Private {
   GraphDef graph_def;
   int ncep;
   int ncontext;
+  Alphabet* alphabet;
 };
 
-Model::Model(const char* aModelPath, int aNCep, int aNContext)
+Model::Model(const char* aModelPath, int aNCep, int aNContext,
+             const char* aAlphabetConfigPath)
 {
   mPriv = new Private;
 
@@ -44,6 +47,8 @@ Model::Model(const char* aModelPath, int aNCep, int aNContext)
 
   mPriv->ncep = aNCep;
   mPriv->ncontext = aNContext;
+
+  mPriv->alphabet = new Alphabet(aAlphabetConfigPath);
 }
 
 Model::~Model()
@@ -52,6 +57,8 @@ Model::~Model()
     mPriv->session->Close();
   }
 
+  delete mPriv->alphabet;
+
   delete mPriv;
 }
 
@@ -105,13 +112,24 @@ Model::infer(float* aMfcc, int aNFrames, int aFrameLen)
   // Output is an array of shape (1, n_results, result_length).
   // In this case, n_results is also equal to 1.
   auto output_mapped = outputs[0].tensor<int64, 3>();
-  int length = output_mapped.dimension(2) + 1;
-  char* output = (char*)malloc(sizeof(char) * length);
-  for (int i = 0; i < length - 1; i++) {
+  size_t output_length = output_mapped.dimension(2) + 1;
+
+  size_t decoded_length = 1; // add 1 for the \0
+  for (int i = 0; i < output_length - 1; i++) {
     int64 character = output_mapped(0, 0, i);
-    output[i] = (character == 0) ? ' ' : (character + 'a' - 1);
+    const std::string& str = mPriv->alphabet->StringFromLabel(character);
+    decoded_length += str.size();
   }
-  output[length - 1] = '\0';
+
+  char* output = (char*)malloc(sizeof(char) * decoded_length);
+  char* pen = output;
+  for (int i = 0; i < output_length - 1; i++) {
+    int64 character = output_mapped(0, 0, i);
+    const std::string& str = mPriv->alphabet->StringFromLabel(character);
+    strncpy(pen, str.c_str(), str.size());
+    pen += str.size();
+  }
+  *pen = '\0';
 
   return output;
 }
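The old decoder assumed each label maps to exactly one ASCII character ('a' plus an offset, with 0 reserved for the space). The rewritten version makes two passes: it first sums the lengths of the strings the alphabet maps each label to, then allocates and copies, so a single label may expand to a multi-character (for example multi-byte UTF-8) string. The same decoding idea in Python, as an illustrative sketch only:

    def labels_to_text(labels, alphabet):
        # Map each emitted label through the alphabet and concatenate,
        # instead of assuming one ASCII character per label.
        return ''.join(alphabet.string_from_label(label) for label in labels)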
@@ -20,8 +20,11 @@ namespace DeepSpeech
    * @param aModelPath The path to the frozen model graph.
    * @param aNCep The number of cepstrum the model was trained with.
    * @param aNContext The context window the model was trained with.
+   * @param aAlphabetConfigPath The path to the configuration file specifying
+   *                            the alphabet used by the network. See alphabet.h.
    */
-  Model(const char* aModelPath, int aNCep, int aNContext);
+  Model(const char* aModelPath, int aNCep, int aNContext,
+        const char* aAlphabetConfigPath);
 
   /**
    * @brief Frees associated resources and destroys model object.
@@ -6,6 +6,6 @@ source $(dirname "$0")/tc-tests-utils.sh
 
 download_material "/tmp/ds"
 
-phrase=$(LD_LIBRARY_PATH=/tmp/ds/:$LD_LIBRARY_PATH /tmp/ds/deepspeech /tmp/${model_name} /tmp/LDC93S1.wav)
+phrase=$(LD_LIBRARY_PATH=/tmp/ds/:$LD_LIBRARY_PATH /tmp/ds/deepspeech /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt)
 
 assert_correct_ldc93s1 "${phrase}"
@@ -20,7 +20,7 @@ pushd ${HOME}/DeepSpeech/ds/native_client/
 npm --version
-npm install ${DEEPSPEECH_ARTIFACTS_ROOT}/deepspeech-0.0.1.tgz
+npm install
-phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH node client.js /tmp/${model_name} /tmp/LDC93S1.wav)
+phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH node client.js /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt)
 popd
 
 assert_correct_ldc93s1 "${phrase}"
@@ -27,3 +27,5 @@ find ${DS_ROOT_TASK}/DeepSpeech/ds/native_client/javascript/ -type f -name "deep
 
 pixz -9 ${TASKCLUSTER_ARTIFACTS}/native_client.tar ${TASKCLUSTER_ARTIFACTS}/native_client.tar.xz
 rm ${TASKCLUSTER_ARTIFACTS}/native_client.tar
+
+cp ${DS_ROOT_TASK}/DeepSpeech/ds/data/alphabet.txt ${TASKCLUSTER_ARTIFACTS}/
@@ -42,7 +42,7 @@ platform=$(python -c 'import sys; import platform; sys.stdout.write("%s_%s" % (p
 deepspeech_pkg="deepspeech-0.0.1-cp${pyver_pkg}-cp${pyver_pkg}${py_unicode_type}-${platform}.whl"
 pip install --upgrade ${DEEPSPEECH_ARTIFACTS_ROOT}/${deepspeech_pkg}
 
-phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH python ${HOME}/DeepSpeech/ds/native_client/client.py /tmp/${model_name} /tmp/LDC93S1.wav)
+phrase=$(LD_LIBRARY_PATH=/tmp/ds-lib/:$LD_LIBRARY_PATH python ${HOME}/DeepSpeech/ds/native_client/client.py /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt)
 
 assert_correct_ldc93s1 "${phrase}"
 
@@ -61,8 +61,9 @@ download_material()
   wget ${DEEPSPEECH_MODEL} -O /tmp/${model_name}
   wget https://catalog.ldc.upenn.edu/desc/addenda/LDC93S1.wav -O /tmp/LDC93S1.wav
   wget ${DEEPSPEECH_ARTIFACTS_ROOT}/native_client.tar.xz -O - | pixz -d | tar -C ${target_dir} -xf -
+  wget ${DEEPSPEECH_ARTIFACTS_ROOT}/alphabet.txt -O /tmp/alphabet.txt
 
-  ls -hal /tmp/${model_name} /tmp/LDC93S1.wav
+  ls -hal /tmp/${model_name} /tmp/LDC93S1.wav /tmp/alphabet.txt
 }
 
 install_pyenv()
@@ -22,6 +22,7 @@ class ModelFeeder(object):
                  test_set,
                  numcep,
                  numcontext,
+                 alphabet,
                  tower_feeder_count=-1,
                  threads_per_queue=2):
 
@@ -41,7 +42,7 @@ class ModelFeeder(object):
         self.ph_batch_size = tf.placeholder(tf.int32, [])
         self.ph_queue_selector = tf.placeholder(tf.int32, name='Queue_Selector')
 
-        self._tower_feeders = [_TowerFeeder(self, i) for i in range(self.tower_feeder_count)]
+        self._tower_feeders = [_TowerFeeder(self, i, alphabet) for i in range(self.tower_feeder_count)]
 
     def start_queue_threads(self, session, coord):
         '''
@@ -105,7 +106,7 @@ class _DataSetLoader(object):
     Keeps a ModelFeeder reference for accessing shared settings and placeholders.
     Keeps a DataSet reference to access its samples.
     '''
-    def __init__(self, model_feeder, data_set):
+    def __init__(self, model_feeder, data_set, alphabet):
         self._model_feeder = model_feeder
         self._data_set = data_set
         self.queue = tf.PaddingFIFOQueue(shapes=[[None, model_feeder.numcep + (2 * model_feeder.numcep * model_feeder.numcontext)], [], [None,], []],
@@ -113,6 +114,7 @@ class _DataSetLoader(object):
                                          capacity=data_set.batch_size * 2)
         self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length])
         self._close_op = self.queue.close(cancel_pending_enqueues=True)
+        self._alphabet = alphabet
 
     def start_queue_threads(self, session, coord):
         '''
@@ -143,7 +145,7 @@ class _DataSetLoader(object):
             wav_file, transcript = self._data_set.files[index]
             source = audiofile_to_input_vector(wav_file, self._model_feeder.numcep, self._model_feeder.numcontext)
             source_len = len(source)
-            target = text_to_char_array(transcript)
+            target = text_to_char_array(transcript, self._alphabet)
             target_len = len(target)
             try:
                 session.run(self._enqueue_op, feed_dict={ self._model_feeder.ph_x: source,
@@ -159,10 +161,10 @@ class _TowerFeeder(object):
     It creates, owns and combines three _DataSetLoader instances.
     Keeps a ModelFeeder reference for accessing shared settings and placeholders.
     '''
-    def __init__(self, model_feeder, index):
+    def __init__(self, model_feeder, index, alphabet):
         self._model_feeder = model_feeder
         self.index = index
-        self._loaders = [_DataSetLoader(model_feeder, data_set) for data_set in model_feeder.sets]
+        self._loaders = [_DataSetLoader(model_feeder, data_set, alphabet) for data_set in model_feeder.sets]
         self._queues = [set_queue.queue for set_queue in self._loaders]
         self._queue = tf.QueueBase.from_list(model_feeder.ph_queue_selector, self._queues)
         self._close_op = self._queue.close(cancel_pending_enqueues=True)
@@ -27,26 +27,26 @@ def log_probability(sentence):
     "Log base 10 probability of `sentence`, a list of words"
     return get_model().score(' '.join(sentence), bos = False, eos = False)
 
-def correction(sentence):
+def correction(sentence, alphabet):
     "Most probable spelling correction for sentence."
     layer = [(0,[])]
     for word in words(sentence):
-        layer = [(-log_probability(node + [cword]), node + [cword]) for cword in candidate_words(word) for priority, node in layer]
+        layer = [(-log_probability(node + [cword]), node + [cword]) for cword in candidate_words(word, alphabet) for priority, node in layer]
         heapify(layer)
         layer = layer[:BEAM_WIDTH]
     return ' '.join(layer[0][1])
 
-def candidate_words(word):
+def candidate_words(word, alphabet):
     "Generate possible spelling corrections for word."
-    return (known_words([word]) or known_words(edits1(word)) or known_words(edits2(word)) or [word])
+    return (known_words([word]) or known_words(edits1(word, alphabet)) or known_words(edits2(word, alphabet)) or [word])
 
 def known_words(words):
     "The subset of `words` that appear in the dictionary of WORDS."
     return set(w for w in words if w in WORDS)
 
-def edits1(word):
+def edits1(word, alphabet):
     "All edits that are one edit away from `word`."
-    letters = 'abcdefghijklmnopqrstuvwxyz'
+    letters = [alphabet.string_from_label(i) for i in range(alphabet.size())]
     splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
     deletes = [L + R[1:] for L, R in splits if R]
     transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
@@ -54,6 +54,6 @@ def edits1(word):
     inserts = [L + c + R for L, R in splits for c in letters]
     return set(deletes + transposes + replaces + inserts)
 
-def edits2(word):
+def edits2(word, alphabet):
     "All edits that are two edits away from `word`."
-    return (e2 for e1 in edits1(word) for e2 in edits1(e1))
+    return (e2 for e1 in edits1(word, alphabet) for e2 in edits1(e1, alphabet))
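edits1 now derives its candidate letters from the configured alphabet rather than a hardcoded 'abcdefghijklmnopqrstuvwxyz', so the spell checker generates edits over whatever symbol set the network was trained on. A hedged illustration, assuming the default alphabet file shipped in this commit:

    from util.text import Alphabet

    alphabet = Alphabet('data/alphabet.txt')
    letters = [alphabet.string_from_label(i) for i in range(alphabet.size())]
    # letters == [' ', 'a', 'b', ..., 'z', "'"]; the space and apostrophe are now candidates too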
util/text.py
@@ -5,28 +5,36 @@ import re
 from six.moves import range
 from functools import reduce
 
-# Constants
-SPACE_TOKEN = '<space>'
-SPACE_INDEX = 0
-FIRST_INDEX = ord('a') - 1 # 0 is reserved to space
+class Alphabet(object):
+    def __init__(self, config_file):
+        self._label_to_str = []
+        self._str_to_label = {}
+        self._size = 0
+        with open(config_file, 'r') as fin:
+            for line in fin:
+                if line[0:2] == '\\#':
+                    line = '#\n'
+                elif line[0] == '#':
+                    continue
+                self._label_to_str += line[:-1] # remove the line ending
+                self._str_to_label[line[:-1]] = self._size
+                self._size += 1
 
-def text_to_char_array(original):
+    def string_from_label(self, label):
+        return self._label_to_str[label]
+
+    def label_from_string(self, string):
+        return self._str_to_label[string]
+
+    def size(self):
+        return self._size
+
+def text_to_char_array(original, alphabet):
     r"""
     Given a Python string ``original``, remove unsupported characters, map characters
     to integers and return a numpy array representing the processed string.
     """
-    # Create list of sentence's words w/spaces replaced by ''
-    result = original.replace(" '", "") # TODO: Deal with this properly
-    result = result.replace("'", "") # TODO: Deal with this properly
-
-    # Tokenize into letters adding in SPACE_TOKEN where required
-    result = np.hstack([SPACE_TOKEN if xt == ' ' else xt for xt in result])
-
-    # Map characters into indicies
-    result = np.asarray([SPACE_INDEX if xt == SPACE_TOKEN else ord(xt) - FIRST_INDEX for xt in result])
-
-    # Add result to results
-    return result
+    return np.asarray([alphabet.label_from_string(c) for c in original])
 
 def sparse_tuple_from(sequences, dtype=np.int32):
     r"""Creates a sparse representention of ``sequences``.
@@ -48,29 +56,27 @@ def sparse_tuple_from(sequences, dtype=np.int32):
 
     return tf.SparseTensor(indices=indices, values=values, shape=shape)
 
-def sparse_tensor_value_to_texts(value):
+def sparse_tensor_value_to_texts(value, alphabet):
     r"""
     Given a :class:`tf.SparseTensor` ``value``, return an array of Python strings
     representing its values.
     """
-    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape))
+    return sparse_tuple_to_texts((value.indices, value.values, value.dense_shape), alphabet)
 
-def sparse_tuple_to_texts(tuple):
+def sparse_tuple_to_texts(tuple, alphabet):
     indices = tuple[0]
     values = tuple[1]
     results = [''] * tuple[2][0]
     for i in range(len(indices)):
         index = indices[i][0]
-        c = values[i]
-        c = ' ' if c == SPACE_INDEX else chr(c + FIRST_INDEX)
-        results[index] = results[index] + c
+        results[index] += alphabet.string_from_label(values[i])
     # List of strings
     return results
 
-def ndarray_to_text(value):
+def ndarray_to_text(value, alphabet):
     results = ''
     for i in range(len(value)):
-        results += chr(value[i] + FIRST_INDEX)
+        results += alphabet.string_from_label(value[i])
     return results.replace('`', ' ')
 
 def wer(original, result):
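With the Alphabet class in place, encoding a transcript and decoding it back is a plain table lookup in both directions. A round-trip sketch, given as an illustration only and using the default alphabet added above:

    from util.text import Alphabet, text_to_char_array, ndarray_to_text

    alphabet = Alphabet('data/alphabet.txt')
    labels = text_to_char_array("hello world", alphabet)
    # labels == array([ 8,  5, 12, 12, 15,  0, 23, 15, 18, 12,  4])
    assert ndarray_to_text(labels, alphabet) == "hello world"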