Merge pull request #1532 from mozilla/feature-caching-clean
Feature caching
Commit: 5dba8e34cb
DeepSpeech.py
@@ -25,6 +25,7 @@ from tensorflow.python.tools import freeze_graph
 from threading import Thread, Lock
 from util.audio import audiofile_to_input_vector
 from util.feeding import DataSet, ModelFeeder
+from util.preprocess import preprocess
 from util.gpu import get_available_gpus
 from util.shared_lib import check_cupti
 from util.text import sparse_tensor_value_to_texts, wer, levenshtein, Alphabet, ndarray_to_text
@@ -40,6 +41,10 @@ def create_flags():
     tf.app.flags.DEFINE_string  ('test_files', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
     tf.app.flags.DEFINE_boolean ('fulltrace', False, 'if full trace debug info should be generated during training')
 
+    tf.app.flags.DEFINE_string  ('train_cached_features_path', '', 'comma separated list of files specifying the dataset used for training. multiple files will get merged')
+    tf.app.flags.DEFINE_string  ('dev_cached_features_path', '', 'comma separated list of files specifying the dataset used for validation. multiple files will get merged')
+    tf.app.flags.DEFINE_string  ('test_cached_features_path', '', 'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
+
     # Cluster configuration
     # =====================
 
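Note: the three new flags default to the empty string, so nothing changes unless a cache path is given. The intended behaviour is: if the HDF5 file already exists it is loaded instead of recomputing features, otherwise features are computed and, when a path was given, written out for the next run. A minimal sketch of that logic (helper names here are hypothetical; the real implementation is in util/preprocess.py further down):

import os

def load_or_compute(csv_files, cache_path):
    # Sketch only: read_cache / compute_features / write_cache are made-up names.
    if cache_path and os.path.exists(cache_path):
        return read_cache(cache_path)       # reuse features from a previous run
    data = compute_features(csv_files)      # expensive MFCC extraction
    if cache_path:
        write_cache(cache_path, data)       # cache for subsequent runs
    return data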
@@ -402,7 +407,7 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
     # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.
 
     # Permute n_steps and batch_size
-    batch_x = tf.transpose(batch_x, [1, 0, 2])
+    batch_x = tf.transpose(batch_x, [1, 0, 2, 3])
     # Reshape to prepare input for first layer
     batch_x = tf.reshape(batch_x, [-1, n_input + 2*n_input*n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
     layers['input_reshaped'] = batch_x
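BiRNN now receives a rank-4 batch [batch_size, n_steps, 2*n_context+1, n_input] instead of the old flattened rank-3 layout, but the reshape still produces the same rank-2 tensor, because (2*n_context+1)*n_input == n_input + 2*n_input*n_context. A small numpy sketch of the equivalent shape manipulation (toy values, not the training code):

import numpy as np

batch_size, n_steps, n_context, n_input = 2, 5, 9, 26
x = np.random.rand(batch_size, n_steps, 2*n_context + 1, n_input).astype(np.float32)
x = np.transpose(x, [1, 0, 2, 3])                        # [n_steps, batch_size, window, n_input]
x = np.reshape(x, [-1, n_input + 2*n_input*n_context])   # [n_steps*batch_size, 494]
assert x.shape == (n_steps * batch_size, 494)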
@@ -1459,19 +1464,40 @@ def train(server=None):
     global_step = tf.Variable(0, trainable=False, name='global_step')
 
     # Reading training set
-    train_set = DataSet(FLAGS.train_files.split(','),
+    train_data = preprocess(FLAGS.train_files.split(','),
+                            FLAGS.train_batch_size,
+                            n_input,
+                            n_context,
+                            alphabet,
+                            hdf5_cache_path=FLAGS.train_cached_features_path)
+
+    train_set = DataSet(train_data,
                         FLAGS.train_batch_size,
                         limit=FLAGS.limit_train,
                         next_index=lambda i: COORD.get_next_index('train'))
 
     # Reading validation set
-    dev_set = DataSet(FLAGS.dev_files.split(','),
+    dev_data = preprocess(FLAGS.dev_files.split(','),
+                          FLAGS.dev_batch_size,
+                          n_input,
+                          n_context,
+                          alphabet,
+                          hdf5_cache_path=FLAGS.dev_cached_features_path)
+
+    dev_set = DataSet(dev_data,
                       FLAGS.dev_batch_size,
                       limit=FLAGS.limit_dev,
                       next_index=lambda i: COORD.get_next_index('dev'))
 
     # Reading test set
-    test_set = DataSet(FLAGS.test_files.split(','),
+    test_data = preprocess(FLAGS.test_files.split(','),
+                           FLAGS.test_batch_size,
+                           n_input,
+                           n_context,
+                           alphabet,
+                           hdf5_cache_path=FLAGS.test_cached_features_path)
+
+    test_set = DataSet(test_data,
                        FLAGS.test_batch_size,
                        limit=FLAGS.limit_test,
                        next_index=lambda i: COORD.get_next_index('test'))
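preprocess() returns a pandas DataFrame with the columns ('features', 'features_len', 'transcript', 'transcript_len'); DataSet now consumes that frame instead of reading the CSVs itself. A hedged usage sketch (the CSV path, batch size and cache path are made up; alphabet is the Alphabet built earlier in train()):

train_data = preprocess(['train.csv'],          # hypothetical CSV
                        16,                      # batch size
                        26,                      # n_input / numcep
                        9,                       # n_context
                        alphabet,
                        hdf5_cache_path='/tmp/train_features.hdf5')
print(train_data.columns)  # features, features_len, transcript, transcript_len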
@@ -1759,8 +1785,8 @@ def train(server=None):
         sys.exit(1)
 
 def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
-    # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
-    input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, n_input + 2*n_input*n_context], name='input_node')
+    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
+    input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
     seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')
 
     previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)
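The exported inference graph therefore expects the same windowed layout as training. The number of values per time step is unchanged, only the layout differs; with the default n_input = 26 and n_context = 9:

n_input, n_context = 26, 9
assert (2*n_context + 1) * n_input == n_input + 2*n_input*n_context  # 494 values either way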
evaluate.py
@@ -19,6 +19,7 @@ from multiprocessing import Pool
 from six.moves import zip, range
 from util.audio import audiofile_to_input_vector
 from util.text import sparse_tensor_value_to_texts, text_to_char_array, Alphabet, ctc_label_dense_to_sparse, wer, levenshtein
+from util.preprocess import pmap, preprocess
 
 
 FLAGS = tf.app.flags.FLAGS
@@ -28,88 +29,6 @@ N_FEATURES = 26
 N_CONTEXT = 9
 
 
-def pmap(fun, iterable, threads=8):
-    pool = Pool(threads)
-    results = pool.map(fun, iterable)
-    pool.close()
-    return results
-
-
-def process_single_file(row):
-    # row = index, Series
-    _, file = row
-    features = audiofile_to_input_vector(file.wav_filename, N_FEATURES, N_CONTEXT)
-    transcript = text_to_char_array(file.transcript, alphabet)
-
-    return features, len(features), transcript, len(transcript)
-
-
-# load samples from CSV, compute features, optionally cache results on disk
-def preprocess(dataset_files, batch_size, hdf5_dest_path=None):
-    COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len')
-
-    if hdf5_dest_path and os.path.exists(hdf5_dest_path):
-        with tables.open_file(hdf5_dest_path, 'r') as file:
-            features = file.root.features[:]
-            features_len = file.root.features_len[:]
-            transcript = file.root.transcript[:]
-            transcript_len = file.root.transcript_len[:]
-
-            # features are stored flattened, so reshape into
-            # [n_steps, (n_input + 2*n_context*n_input)]
-            for i in range(len(features)):
-                features[i] = np.reshape(features[i], [features_len[i], -1])
-
-            in_data = list(zip(features, features_len,
-                               transcript, transcript_len))
-            return pandas.DataFrame(data=in_data, columns=COLUMNS)
-
-    csv_files = dataset_files.split(',')
-    source_data = None
-    for csv in csv_files:
-        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
-        if source_data is None:
-            source_data = file
-        else:
-            source_data = source_data.append(file)
-
-    # discard last samples if dataset does not divide batch size evenly
-    if len(source_data) % batch_size != 0:
-        source_data = source_data[:-(len(source_data) % batch_size)]
-
-    out_data = pmap(process_single_file, source_data.iterrows())
-
-    if hdf5_dest_path:
-        # list of tuples -> tuple of lists
-        features, features_len, transcript, transcript_len = zip(*out_data)
-
-        with tables.open_file(hdf5_dest_path, 'w') as file:
-            features_dset = file.create_vlarray(file.root,
-                                                'features',
-                                                tables.Float32Atom(),
-                                                filters=tables.Filters(complevel=1))
-            # VLArray atoms need to be 1D, so flatten feature array
-            for f in features:
-                features_dset.append(np.reshape(f, -1))
-
-            features_len_dset = file.create_array(file.root,
-                                                  'features_len',
-                                                  features_len)
-
-            transcript_dset = file.create_vlarray(file.root,
-                                                  'transcript',
-                                                  tables.Int32Atom(),
-                                                  filters=tables.Filters(complevel=1))
-            for t in transcript:
-                transcript_dset.append(t)
-
-            transcript_len_dset = file.create_array(file.root,
-                                                    'transcript_len',
-                                                    transcript_len)
-
-    return pandas.DataFrame(data=out_data, columns=COLUMNS)
-
-
 def split_data(dataset, batch_size):
     remainder = len(dataset) % batch_size
     if remainder != 0:
native_client/deepspeech.cc
@@ -323,14 +323,15 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
     return;
 #endif // DS_NATIVE_MODEL
   } else {
-    Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, mfcc_feats_per_timestep}));
+    Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES}));
 
-    auto input_mapped = input.tensor<float, 3>();
-    int idx = 0;
-    for (int i = 0; i < n_frames; i++) {
-      for (int j = 0; j < mfcc_feats_per_timestep; j++, idx++) {
-        input_mapped(0, i, j) = aMfcc[idx];
-      }
-    }
+    auto input_mapped = input.flat<float>();
+    int i;
+    for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
+      input_mapped(i) = aMfcc[i];
+    }
+    for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
+      input_mapped(i) = 0;
+    }
 
     Tensor input_lengths(DT_INT32, TensorShape({1}));
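On the client side the copy loop no longer needs per-row indexing: the incoming MFCC buffer is already laid out frame-major, so it is copied flat and the remaining time steps are zero-padded up to n_steps. Roughly, in Python terms (a sketch of the idea, not the native-client code):

import numpy as np

def fill_input(mfcc, n_frames, n_steps, feats_per_timestep):
    # Copy the first n_frames*feats_per_timestep values, zero-pad the rest.
    flat = np.zeros(n_steps * feats_per_timestep, dtype=np.float32)
    flat[:n_frames * feats_per_timestep] = mfcc[:n_frames * feats_per_timestep]
    return flat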
@@ -482,9 +483,8 @@ DS_CreateModel(const char* aModelPath,
     if (node.name() == "input_node") {
       const auto& shape = node.attr().at("shape").shape();
       model->n_steps = shape.dim(1).size();
-      model->mfcc_feats_per_timestep = shape.dim(2).size();
-      // mfcc_features_per_timestep = MFCC_FEATURES * ((2*n_context) + 1)
-      model->n_context = (model->mfcc_feats_per_timestep - MFCC_FEATURES) / (2 * MFCC_FEATURES);
+      model->n_context = (shape.dim(2).size()-1)/2;
+      model->mfcc_feats_per_timestep = shape.dim(2).size() * shape.dim(3).size();
     } else if (node.name() == "logits_shape") {
       Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
       if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
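Model metadata is now derived directly from the rank-4 input_node shape [batch_size, n_steps, 2*n_context+1, n_input], so no hard-coded MFCC_FEATURES is needed to recover n_context. The same arithmetic as a small Python sketch (the dims are example values):

def derive_metadata(input_shape):
    # input_shape ~ [batch_size, n_steps, 2*n_context+1, n_input]
    n_steps = input_shape[1]
    window = input_shape[2]
    n_input = input_shape[3]
    n_context = (window - 1) // 2
    feats_per_timestep = window * n_input
    return n_steps, n_context, feats_per_timestep

assert derive_metadata([1, 16, 19, 26]) == (16, 9, 494)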
requirements.txt
@@ -13,3 +13,4 @@ pyxdg
 bs4
 six
 requests
+tables
tc-tests-utils.sh
@@ -47,11 +47,15 @@ model_source_mmap="$(dirname "${model_source}")/${model_name_mmap}"
 SUPPORTED_PYTHON_VERSIONS=${SUPPORTED_PYTHON_VERSIONS:-2.7.14:ucs2 2.7.14:ucs4 3.4.8:ucs4 3.5.5:ucs4 3.6.4:ucs4 3.7.0:ucs4}
 SUPPORTED_NODEJS_VERSIONS=${SUPPORTED_NODEJS_VERSIONS:-4.9.1 5.12.0 6.14.1 7.10.1 8.11.1 9.11.1 10.3.0}
 
+strip() {
+  echo "$(echo $1 | sed -e 's/^[[:space:]]+//' -e 's/[[:space:]]+$//')"
+}
+
 # This verify exact inference result
 assert_correct_inference()
 {
-  phrase=$1
-  expected=$2
+  phrase=$(strip "$1")
+  expected=$(strip "$2")
 
   if [ -z "${phrase}" -o -z "${expected}" ]; then
     echo "One or more empty strings:"
@@ -158,8 +162,8 @@ assert_correct_ldc93s1_prodmodel()
 
 assert_correct_ldc93s1_somodel()
 {
-  somodel_nolm=$1
-  somodel_withlm=$2
+  somodel_nolm=$(strip "$1")
+  somodel_withlm=$(strip "$2")
 
   # We want to be able to return non zero value from the function, while not
   # failing the whole execution
util/audio.py
@@ -1,53 +1,7 @@
-from __future__ import absolute_import, print_function
+import numpy as np
 
 import scipy.io.wavfile as wav
-import sys
-import warnings
-
-class DeepSpeechDeprecationWarning(DeprecationWarning):
-    pass
-
-warnings.simplefilter('once', category=DeepSpeechDeprecationWarning)
-
-try:
-    from deepspeech import audioToInputVector
-except ImportError:
-    warnings.warn('DeepSpeech Python bindings could not be imported, resorting to slower code to compute audio features. '
-                  'Refer to README.md for instructions on how to install (or build) the DeepSpeech Python bindings.',
-                  category=DeepSpeechDeprecationWarning)
-
-    import numpy as np
-    from python_speech_features import mfcc
-    from six.moves import range
-
-    def audioToInputVector(audio, fs, numcep, numcontext):
-        # Get mfcc coefficients
-        features = mfcc(audio, samplerate=fs, numcep=numcep)
-
-        # We only keep every second feature (BiRNN stride = 2)
-        features = features[::2]
-
-        # One stride per time step in the input
-        num_strides = len(features)
-
-        # Add empty initial and final contexts
-        empty_context = np.zeros((numcontext, numcep), dtype=features.dtype)
-        features = np.concatenate((empty_context, features, empty_context))
-
-        # Create a view into the array with overlapping strides of size
-        # numcontext (past) + 1 (present) + numcontext (future)
-        window_size = 2*numcontext+1
-        train_inputs = np.lib.stride_tricks.as_strided(
-            features,
-            (num_strides, window_size, numcep),
-            (features.strides[0], features.strides[0], features.strides[1]),
-            writeable=False)
-
-        # Flatten the second and third dimensions
-        train_inputs = np.reshape(train_inputs, [num_strides, -1])
-
-        # Return results
-        return train_inputs
+
+from python_speech_features import mfcc
 
 
 def audiofile_to_input_vector(audio_filename, numcep, numcontext):
@@ -60,4 +14,14 @@ def audiofile_to_input_vector(audio_filename, numcep, numcontext):
     # Load wav files
     fs, audio = wav.read(audio_filename)
 
-    return audioToInputVector(audio, fs, numcep, numcontext)
+    # Get mfcc coefficients
+    features = mfcc(audio, samplerate=fs, numcep=numcep)
+
+    # We only keep every second feature (BiRNN stride = 2)
+    features = features[::2]
+
+    # Add empty initial and final contexts
+    empty_context = np.zeros((numcontext, numcep), dtype=features.dtype)
+    features = np.concatenate((empty_context, features, empty_context))
+
+    return features
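audiofile_to_input_vector() therefore now returns an un-windowed [time, numcep] array, padded with numcontext rows of zeros on each side; the overlapping context windows are built later, in the feeder and in util/preprocess.py. A hedged usage sketch (the wav path is made up):

features = audiofile_to_input_vector('sample.wav', numcep=26, numcontext=9)
# features has one row per kept MFCC frame (every second frame is kept)
# plus 9 zero rows of context padding at each end, and 26 columns.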
util/feeding.py
@@ -1,12 +1,12 @@
-import pandas
+import numpy as np
 import tensorflow as tf
 
-from threading import Thread
 from math import ceil
 from six.moves import range
-from util.audio import audiofile_to_input_vector
+from threading import Thread
 from util.gpu import get_available_gpus
-from util.text import ctc_label_dense_to_sparse, text_to_char_array
+from util.text import ctc_label_dense_to_sparse
 
 
 class ModelFeeder(object):
     '''
@@ -24,7 +24,7 @@ class ModelFeeder(object):
                  numcontext,
                  alphabet,
                  tower_feeder_count=-1,
-                 threads_per_queue=2):
+                 threads_per_queue=4):
 
         self.train = train_set
         self.dev = dev_set
@@ -35,7 +35,7 @@ class ModelFeeder(object):
         self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count
         self.threads_per_queue = threads_per_queue
 
-        self.ph_x = tf.placeholder(tf.float32, [None, numcep + (2 * numcep * numcontext)])
+        self.ph_x = tf.placeholder(tf.float32, [None, 2*numcontext+1, numcep])
         self.ph_x_length = tf.placeholder(tf.int32, [])
         self.ph_y = tf.placeholder(tf.int32, [None,])
         self.ph_y_length = tf.placeholder(tf.int32, [])
@@ -77,27 +77,19 @@ class ModelFeeder(object):
         '''
         return self._tower_feeders[tower_feeder_index].next_batch()
 
 
 class DataSet(object):
     '''
     Represents a collection of audio samples and their respective transcriptions.
     Takes a set of CSV files produced by importers in /bin.
     '''
-    def __init__(self, csvs, batch_size, skip=0, limit=0, ascending=True, next_index=lambda i: i + 1):
+    def __init__(self, data, batch_size, skip=0, limit=0, ascending=True, next_index=lambda i: i + 1):
+        self.data = data
+        self.data.sort_values(by="features_len", ascending=ascending, inplace=True)
         self.batch_size = batch_size
         self.next_index = next_index
-        self.files = None
-        for csv in csvs:
-            file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
-            if self.files is None:
-                self.files = file
-            else:
-                self.files = self.files.append(file)
-        self.files = self.files.sort_values(by="wav_filesize", ascending=ascending) \
-                         .ix[:, ["wav_filename", "transcript"]] \
-                         .values[skip:]
-        if limit > 0:
-            self.files = self.files[:limit]
-        self.total_batches = int(ceil(len(self.files) / batch_size))
+        self.total_batches = int(ceil(len(self.data) / batch_size))
 
 
 class _DataSetLoader(object):
     '''
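DataSet now wraps the already-preprocessed DataFrame and only sorts it by features_len, which is roughly equivalent to the old wav_filesize ordering since longer audio yields more feature frames. A minimal sketch, assuming train_data is a DataFrame returned by preprocess():

train_set = DataSet(train_data, batch_size=16)   # batch size is an example value
print(train_set.total_batches, len(train_set.data))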
@@ -109,9 +101,9 @@ class _DataSetLoader(object):
     def __init__(self, model_feeder, data_set, alphabet):
         self._model_feeder = model_feeder
         self._data_set = data_set
-        self.queue = tf.PaddingFIFOQueue(shapes=[[None, model_feeder.numcep + (2 * model_feeder.numcep * model_feeder.numcontext)], [], [None,], []],
+        self.queue = tf.PaddingFIFOQueue(shapes=[[None, 2 * model_feeder.numcontext + 1, model_feeder.numcep], [], [None,], []],
                                          dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
-                                         capacity=data_set.batch_size * 2)
+                                         capacity=data_set.batch_size * 8)
         self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length])
         self._close_op = self.queue.close(cancel_pending_enqueues=True)
         self._alphabet = alphabet
@@ -138,25 +130,35 @@ class _DataSetLoader(object):
         '''
         Queue thread routine.
         '''
-        file_count = len(self._data_set.files)
+        file_count = len(self._data_set.data)
         index = -1
         while not coord.should_stop():
             index = self._data_set.next_index(index) % file_count
-            wav_file, transcript = self._data_set.files[index]
-            source = audiofile_to_input_vector(wav_file, self._model_feeder.numcep, self._model_feeder.numcontext)
-            source_len = len(source)
-            target = text_to_char_array(transcript, self._alphabet)
-            target_len = len(target)
-            if source_len < target_len:
-                raise ValueError('Error: Audio file {} is too short for transcription.'.format(wav_file))
+            features, _, transcript, transcript_len = self._data_set.data.iloc[index]
+
+            # One stride per time step in the input
+            num_strides = len(features) - (self._model_feeder.numcontext * 2)
+
+            # Create a view into the array with overlapping strides of size
+            # numcontext (past) + 1 (present) + numcontext (future)
+            window_size = 2*self._model_feeder.numcontext+1
+            features = np.lib.stride_tricks.as_strided(
+                features,
+                (num_strides, window_size, self._model_feeder.numcep),
+                (features.strides[0], features.strides[0], features.strides[1]),
+                writeable=False)
+
             try:
-                session.run(self._enqueue_op, feed_dict={ self._model_feeder.ph_x: source,
-                                                          self._model_feeder.ph_x_length: source_len,
-                                                          self._model_feeder.ph_y: target,
-                                                          self._model_feeder.ph_y_length: target_len })
+                session.run(self._enqueue_op, feed_dict={
+                    self._model_feeder.ph_x: features,
+                    self._model_feeder.ph_x_length: num_strides,
+                    self._model_feeder.ph_y: transcript,
+                    self._model_feeder.ph_y_length: transcript_len
+                })
             except tf.errors.CancelledError:
                 return
 
 
 class _TowerFeeder(object):
     '''
     Internal class that represents a switchable input queue for one tower.
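The feeder thread now builds the context windows as an as_strided view over the cached [time, numcep] feature array, so no extra copy is made per sample. A toy numpy illustration of the same trick (numcep=2, numcontext=1 for readability):

import numpy as np

numcep, numcontext = 2, 1
features = np.arange(10, dtype=np.float32).reshape(5, numcep)   # 5 padded time steps
num_strides = len(features) - 2*numcontext
window_size = 2*numcontext + 1
windows = np.lib.stride_tricks.as_strided(
    features,
    (num_strides, window_size, numcep),
    (features.strides[0], features.strides[0], features.strides[1]),
    writeable=False)
assert windows.shape == (3, 3, 2)
# windows[t] contains padded time steps t, t+1, t+2, i.e. the context around step t+1.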
util/preprocess.py (new file)
@@ -0,0 +1,101 @@
+import numpy as np
+import os
+import pandas
+import tables
+
+from functools import partial
+from multiprocessing.dummy import Pool
+from util.audio import audiofile_to_input_vector
+from util.text import text_to_char_array
+
+def pmap(fun, iterable, threads=8):
+    pool = Pool(threads)
+    results = pool.map(fun, iterable)
+    pool.close()
+    return results
+
+
+def process_single_file(row, numcep, numcontext, alphabet):
+    # row = index, Series
+    _, file = row
+    features = audiofile_to_input_vector(file.wav_filename, numcep, numcontext)
+    transcript = text_to_char_array(file.transcript, alphabet)
+
+    if (2*numcontext + len(features)) < len(transcript):
+        raise ValueError('Error: Audio file {} is too short for transcription.'.format(file.wav_filename))
+
+    return features, len(features), transcript, len(transcript)
+
+
+# load samples from CSV, compute features, optionally cache results on disk
+def preprocess(csv_files, batch_size, numcep, numcontext, alphabet, hdf5_cache_path=None):
+    COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len')
+
+    print('Preprocessing', csv_files)
+
+    if hdf5_cache_path and os.path.exists(hdf5_cache_path):
+        with tables.open_file(hdf5_cache_path, 'r') as file:
+            features = file.root.features[:]
+            features_len = file.root.features_len[:]
+            transcript = file.root.transcript[:]
+            transcript_len = file.root.transcript_len[:]
+
+            # features are stored flattened, so reshape into
+            # [n_steps, (n_input + 2*n_context*n_input)]
+            for i in range(len(features)):
+                features[i] = np.reshape(features[i], [features_len[i], -1])
+
+            in_data = list(zip(features, features_len,
+                               transcript, transcript_len))
+            print('Loaded from cache at', hdf5_cache_path)
+            return pandas.DataFrame(data=in_data, columns=COLUMNS)
+
+    source_data = None
+    for csv in csv_files:
+        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
+        #FIXME: not cross-platform
+        csv_dir = os.path.dirname(os.path.abspath(csv))
+        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
+        if source_data is None:
+            source_data = file
+        else:
+            source_data = source_data.append(file)
+
+    step_fn = partial(process_single_file,
+                      numcep=numcep,
+                      numcontext=numcontext,
+                      alphabet=alphabet)
+    out_data = pmap(step_fn, source_data.iterrows())
+
+    if hdf5_cache_path:
+        print('Saving to', hdf5_cache_path)
+
+        # list of tuples -> tuple of lists
+        features, features_len, transcript, transcript_len = zip(*out_data)
+
+        with tables.open_file(hdf5_cache_path, 'w') as file:
+            features_dset = file.create_vlarray(file.root,
+                                                'features',
+                                                tables.Float32Atom(),
+                                                filters=tables.Filters(complevel=1))
+            # VLArray atoms need to be 1D, so flatten feature array
+            for f in features:
+                features_dset.append(np.reshape(f, -1))
+
+            features_len_dset = file.create_array(file.root,
+                                                  'features_len',
+                                                  features_len)
+
+            transcript_dset = file.create_vlarray(file.root,
+                                                  'transcript',
+                                                  tables.Int32Atom(),
+                                                  filters=tables.Filters(complevel=1))
+            for t in transcript:
+                transcript_dset.append(t)
+
+            transcript_len_dset = file.create_array(file.root,
+                                                    'transcript_len',
+                                                    transcript_len)
+
+    print('Preprocessing done')
+    return pandas.DataFrame(data=out_data, columns=COLUMNS)
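For reference, a hedged usage sketch of the new module (paths are examples only). Note that the cache is keyed purely on the existence of the HDF5 file: if the CSVs or feature settings change, the stale cache file has to be deleted by hand so it gets regenerated.

from util.preprocess import preprocess
from util.text import Alphabet

alphabet = Alphabet('data/alphabet.txt')        # example path
df = preprocess(['data/ldc93s1/ldc93s1.csv'],   # example CSV
                batch_size=1,
                numcep=26,
                numcontext=9,
                alphabet=alphabet,
                hdf5_cache_path='/tmp/ldc93s1_cache.hdf5')
print(len(df), 'samples preprocessed')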