Merge pull request #1532 from mozilla/feature-caching-clean

Feature caching
Authored by Reuben Morais on 2018-09-17 13:36:14 -03:00; committed by GitHub.
commit 5dba8e34cb
8 changed files with 202 additions and 185 deletions


@@ -25,6 +25,7 @@ from tensorflow.python.tools import freeze_graph
 from threading import Thread, Lock
 from util.audio import audiofile_to_input_vector
 from util.feeding import DataSet, ModelFeeder
+from util.preprocess import preprocess
 from util.gpu import get_available_gpus
 from util.shared_lib import check_cupti
 from util.text import sparse_tensor_value_to_texts, wer, levenshtein, Alphabet, ndarray_to_text
@@ -40,6 +41,10 @@ def create_flags():
     tf.app.flags.DEFINE_string  ('test_files',  '',    'comma separated list of files specifying the dataset used for testing. multiple files will get merged')
     tf.app.flags.DEFINE_boolean ('fulltrace',   False, 'if full trace debug info should be generated during training')
+    tf.app.flags.DEFINE_string  ('train_cached_features_path', '', 'path to a HDF5 file used to cache the preprocessed features of the training set')
+    tf.app.flags.DEFINE_string  ('dev_cached_features_path',   '', 'path to a HDF5 file used to cache the preprocessed features of the validation set')
+    tf.app.flags.DEFINE_string  ('test_cached_features_path',  '', 'path to a HDF5 file used to cache the preprocessed features of the test set')

     # Cluster configuration
     # =====================
@@ -402,7 +407,7 @@ def BiRNN(batch_x, seq_length, dropout, reuse=False, batch_size=None, n_steps=-1
     # This is done to prepare the batch for input into the first layer which expects a tensor of rank `2`.

     # Permute n_steps and batch_size
-    batch_x = tf.transpose(batch_x, [1, 0, 2])
+    batch_x = tf.transpose(batch_x, [1, 0, 2, 3])
     # Reshape to prepare input for first layer
     batch_x = tf.reshape(batch_x, [-1, n_input + 2*n_input*n_context]) # (n_steps*batch_size, n_input + 2*n_input*n_context)
     layers['input_reshaped'] = batch_x
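Since the feeder now delivers windowed features, the batch arriving at BiRNN is rank 4 rather than rank 3 with pre-flattened windows, so the transpose gains a fourth axis. A minimal NumPy sketch (synthetic shapes, not the actual TF graph) of why the subsequent reshape still yields the old [n_steps*batch_size, n_input + 2*n_input*n_context] layout:

import numpy as np

# Assumed toy dimensions; n_input=26 and n_context=9 match the stock model, the rest are arbitrary.
batch_size, n_steps, n_context, n_input = 2, 5, 9, 26
window = 2*n_context + 1

batch_x = np.random.rand(batch_size, n_steps, window, n_input).astype(np.float32)

# Permute n_steps and batch_size (now a 4-axis permutation), then collapse the window
# and feature axes into one, recovering the layout the first dense layer expects.
permuted = np.transpose(batch_x, [1, 0, 2, 3])
reshaped = permuted.reshape(-1, n_input + 2*n_input*n_context)
assert reshaped.shape == (n_steps*batch_size, window*n_input)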
@@ -1459,19 +1464,40 @@ def train(server=None):
     global_step = tf.Variable(0, trainable=False, name='global_step')

     # Reading training set
-    train_set = DataSet(FLAGS.train_files.split(','),
+    train_data = preprocess(FLAGS.train_files.split(','),
+                            FLAGS.train_batch_size,
+                            n_input,
+                            n_context,
+                            alphabet,
+                            hdf5_cache_path=FLAGS.train_cached_features_path)
+
+    train_set = DataSet(train_data,
                         FLAGS.train_batch_size,
                         limit=FLAGS.limit_train,
                         next_index=lambda i: COORD.get_next_index('train'))

     # Reading validation set
-    dev_set = DataSet(FLAGS.dev_files.split(','),
+    dev_data = preprocess(FLAGS.dev_files.split(','),
+                          FLAGS.dev_batch_size,
+                          n_input,
+                          n_context,
+                          alphabet,
+                          hdf5_cache_path=FLAGS.dev_cached_features_path)
+
+    dev_set = DataSet(dev_data,
                       FLAGS.dev_batch_size,
                       limit=FLAGS.limit_dev,
                       next_index=lambda i: COORD.get_next_index('dev'))

     # Reading test set
-    test_set = DataSet(FLAGS.test_files.split(','),
+    test_data = preprocess(FLAGS.test_files.split(','),
+                           FLAGS.test_batch_size,
+                           n_input,
+                           n_context,
+                           alphabet,
+                           hdf5_cache_path=FLAGS.test_cached_features_path)
+
+    test_set = DataSet(test_data,
                        FLAGS.test_batch_size,
                        limit=FLAGS.limit_test,
                        next_index=lambda i: COORD.get_next_index('test'))
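For reference, a hedged usage sketch of the new pipeline outside of train(): the CSV path, cache path and alphabet file below are hypothetical, but the function signatures follow the diff. When the HDF5 cache file already exists, preprocess() loads it instead of recomputing MFCC features.

from util.preprocess import preprocess
from util.feeding import DataSet
from util.text import Alphabet

alphabet = Alphabet('data/alphabet.txt')                  # hypothetical alphabet file
train_data = preprocess(['data/train.csv'],               # hypothetical CSV produced by an importer
                        batch_size=24,
                        numcep=26,
                        numcontext=9,
                        alphabet=alphabet,
                        hdf5_cache_path='/tmp/train_cache.hdf5')  # hypothetical cache location
train_set = DataSet(train_data, batch_size=24)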
@@ -1759,8 +1785,8 @@ def train(server=None):
             sys.exit(1)

 def create_inference_graph(batch_size=1, n_steps=16, use_new_decoder=False):
-    # Input tensor will be of shape [batch_size, n_steps, n_input + 2*n_input*n_context]
-    input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, n_input + 2*n_input*n_context], name='input_node')
+    # Input tensor will be of shape [batch_size, n_steps, 2*n_context+1, n_input]
+    input_tensor = tf.placeholder(tf.float32, [batch_size, n_steps if n_steps > 0 else None, 2*n_context+1, n_input], name='input_node')
     seq_length = tf.placeholder(tf.int32, [batch_size], name='input_lengths')

     previous_state_c = variable_on_worker_level('previous_state_c', [batch_size, n_cell_dim], initializer=None)


@@ -19,6 +19,7 @@ from multiprocessing import Pool
 from six.moves import zip, range
 from util.audio import audiofile_to_input_vector
 from util.text import sparse_tensor_value_to_texts, text_to_char_array, Alphabet, ctc_label_dense_to_sparse, wer, levenshtein
+from util.preprocess import pmap, preprocess

 FLAGS = tf.app.flags.FLAGS
@@ -28,88 +29,6 @@ N_FEATURES = 26
 N_CONTEXT = 9

-def pmap(fun, iterable, threads=8):
-    pool = Pool(threads)
-    results = pool.map(fun, iterable)
-    pool.close()
-    return results
-
-def process_single_file(row):
-    # row = index, Series
-    _, file = row
-    features = audiofile_to_input_vector(file.wav_filename, N_FEATURES, N_CONTEXT)
-    transcript = text_to_char_array(file.transcript, alphabet)
-    return features, len(features), transcript, len(transcript)
-
-# load samples from CSV, compute features, optionally cache results on disk
-def preprocess(dataset_files, batch_size, hdf5_dest_path=None):
-    COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len')
-
-    if hdf5_dest_path and os.path.exists(hdf5_dest_path):
-        with tables.open_file(hdf5_dest_path, 'r') as file:
-            features = file.root.features[:]
-            features_len = file.root.features_len[:]
-            transcript = file.root.transcript[:]
-            transcript_len = file.root.transcript_len[:]
-
-            # features are stored flattened, so reshape into
-            # [n_steps, (n_input + 2*n_context*n_input)]
-            for i in range(len(features)):
-                features[i] = np.reshape(features[i], [features_len[i], -1])
-
-            in_data = list(zip(features, features_len,
-                               transcript, transcript_len))
-            return pandas.DataFrame(data=in_data, columns=COLUMNS)
-
-    csv_files = dataset_files.split(',')
-    source_data = None
-    for csv in csv_files:
-        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
-        if source_data is None:
-            source_data = file
-        else:
-            source_data = source_data.append(file)
-
-    # discard last samples if dataset does not divide batch size evenly
-    if len(source_data) % batch_size != 0:
-        source_data = source_data[:-(len(source_data) % batch_size)]
-
-    out_data = pmap(process_single_file, source_data.iterrows())
-
-    if hdf5_dest_path:
-        # list of tuples -> tuple of lists
-        features, features_len, transcript, transcript_len = zip(*out_data)
-
-        with tables.open_file(hdf5_dest_path, 'w') as file:
-            features_dset = file.create_vlarray(file.root,
-                                                'features',
-                                                tables.Float32Atom(),
-                                                filters=tables.Filters(complevel=1))
-            # VLArray atoms need to be 1D, so flatten feature array
-            for f in features:
-                features_dset.append(np.reshape(f, -1))
-            features_len_dset = file.create_array(file.root,
-                                                  'features_len',
-                                                  features_len)
-            transcript_dset = file.create_vlarray(file.root,
-                                                  'transcript',
-                                                  tables.Int32Atom(),
-                                                  filters=tables.Filters(complevel=1))
-            for t in transcript:
-                transcript_dset.append(t)
-            transcript_len_dset = file.create_array(file.root,
-                                                    'transcript_len',
-                                                    transcript_len)
-
-    return pandas.DataFrame(data=out_data, columns=COLUMNS)
-
 def split_data(dataset, batch_size):
     remainder = len(dataset) % batch_size
     if remainder != 0:


@@ -323,14 +323,15 @@ ModelState::infer(const float* aMfcc, unsigned int n_frames, vector<float>& logi
     return;
 #endif // DS_NATIVE_MODEL
   } else {
-    Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, mfcc_feats_per_timestep}));
+    Tensor input(DT_FLOAT, TensorShape({BATCH_SIZE, n_steps, 2*n_context+1, MFCC_FEATURES}));

-    auto input_mapped = input.tensor<float, 3>();
-    int idx = 0;
-    for (int i = 0; i < n_frames; i++) {
-      for (int j = 0; j < mfcc_feats_per_timestep; j++, idx++) {
-        input_mapped(0, i, j) = aMfcc[idx];
-      }
-    }
+    auto input_mapped = input.flat<float>();
+    int i;
+    for (i = 0; i < n_frames*mfcc_feats_per_timestep; ++i) {
+      input_mapped(i) = aMfcc[i];
+    }
+    for (; i < n_steps*mfcc_feats_per_timestep; ++i) {
+      input_mapped(i) = 0;
+    }

     Tensor input_lengths(DT_INT32, TensorShape({1}));
@@ -482,9 +483,8 @@ DS_CreateModel(const char* aModelPath,
     if (node.name() == "input_node") {
       const auto& shape = node.attr().at("shape").shape();
      model->n_steps = shape.dim(1).size();
-      model->mfcc_feats_per_timestep = shape.dim(2).size();
-      // mfcc_features_per_timestep = MFCC_FEATURES * ((2*n_context) + 1)
-      model->n_context = (model->mfcc_feats_per_timestep - MFCC_FEATURES) / (2 * MFCC_FEATURES);
+      model->n_context = (shape.dim(2).size()-1)/2;
+      model->mfcc_feats_per_timestep = shape.dim(2).size() * shape.dim(3).size();
     } else if (node.name() == "logits_shape") {
       Tensor logits_shape = Tensor(DT_INT32, TensorShape({3}));
       if (!logits_shape.FromProto(node.attr().at("value").tensor())) {
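A quick worked check of the new shape arithmetic, assuming the stock model rather than values read from an actual graph:

# Hedged check in Python (MFCC_FEATURES = 26 and n_context = 9 are assumed stock-model values):
MFCC_FEATURES = 26
dim2, dim3 = 2*9 + 1, MFCC_FEATURES           # shape.dim(2), shape.dim(3) of the new input_node
n_context = (dim2 - 1) // 2                   # -> 9
mfcc_feats_per_timestep = dim2 * dim3         # -> 494, same per-timestep width as the old flattened input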


@@ -13,3 +13,4 @@ pyxdg
 bs4
 six
 requests
+tables


@@ -47,11 +47,15 @@ model_source_mmap="$(dirname "${model_source}")/${model_name_mmap}"
 SUPPORTED_PYTHON_VERSIONS=${SUPPORTED_PYTHON_VERSIONS:-2.7.14:ucs2 2.7.14:ucs4 3.4.8:ucs4 3.5.5:ucs4 3.6.4:ucs4 3.7.0:ucs4}
 SUPPORTED_NODEJS_VERSIONS=${SUPPORTED_NODEJS_VERSIONS:-4.9.1 5.12.0 6.14.1 7.10.1 8.11.1 9.11.1 10.3.0}

+strip() {
+  echo "$(echo $1 | sed -e 's/^[[:space:]]+//' -e 's/[[:space:]]+$//')"
+}
+
 # This verify exact inference result
 assert_correct_inference()
 {
-  phrase=$1
-  expected=$2
+  phrase=$(strip "$1")
+  expected=$(strip "$2")

   if [ -z "${phrase}" -o -z "${expected}" ]; then
     echo "One or more empty strings:"
@@ -158,8 +162,8 @@ assert_correct_ldc93s1_prodmodel()
 assert_correct_ldc93s1_somodel()
 {
-  somodel_nolm=$1
-  somodel_withlm=$2
+  somodel_nolm=$(strip "$1")
+  somodel_withlm=$(strip "$2")

   # We want to be able to return non zero value from the function, while not
   # failing the whole execution


@@ -1,53 +1,7 @@
-from __future__ import absolute_import, print_function
+import numpy as np
 import scipy.io.wavfile as wav
-import sys
-import warnings
-
-class DeepSpeechDeprecationWarning(DeprecationWarning):
-    pass
+from python_speech_features import mfcc
-
-warnings.simplefilter('once', category=DeepSpeechDeprecationWarning)
-
-try:
-    from deepspeech import audioToInputVector
-except ImportError:
-    warnings.warn('DeepSpeech Python bindings could not be imported, resorting to slower code to compute audio features. '
-                  'Refer to README.md for instructions on how to install (or build) the DeepSpeech Python bindings.',
-                  category=DeepSpeechDeprecationWarning)
-
-    import numpy as np
-    from python_speech_features import mfcc
-    from six.moves import range
-
-    def audioToInputVector(audio, fs, numcep, numcontext):
-        # Get mfcc coefficients
-        features = mfcc(audio, samplerate=fs, numcep=numcep)
-
-        # We only keep every second feature (BiRNN stride = 2)
-        features = features[::2]
-
-        # One stride per time step in the input
-        num_strides = len(features)
-
-        # Add empty initial and final contexts
-        empty_context = np.zeros((numcontext, numcep), dtype=features.dtype)
-        features = np.concatenate((empty_context, features, empty_context))
-
-        # Create a view into the array with overlapping strides of size
-        # numcontext (past) + 1 (present) + numcontext (future)
-        window_size = 2*numcontext+1
-        train_inputs = np.lib.stride_tricks.as_strided(
-            features,
-            (num_strides, window_size, numcep),
-            (features.strides[0], features.strides[0], features.strides[1]),
-            writeable=False)
-
-        # Flatten the second and third dimensions
-        train_inputs = np.reshape(train_inputs, [num_strides, -1])
-
-        # Return results
-        return train_inputs
-
 def audiofile_to_input_vector(audio_filename, numcep, numcontext):
@@ -60,4 +14,14 @@ def audiofile_to_input_vector(audio_filename, numcep, numcontext):

     # Load wav files
     fs, audio = wav.read(audio_filename)

-    return audioToInputVector(audio, fs, numcep, numcontext)
+    # Get mfcc coefficients
+    features = mfcc(audio, samplerate=fs, numcep=numcep)
+
+    # We only keep every second feature (BiRNN stride = 2)
+    features = features[::2]
+
+    # Add empty initial and final contexts
+    empty_context = np.zeros((numcontext, numcep), dtype=features.dtype)
+    features = np.concatenate((empty_context, features, empty_context))
+
+    return features
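audiofile_to_input_vector therefore now returns un-windowed per-frame MFCCs, padded with numcontext blank frames at each end; the overlapping context windows are built later in the feeder. A small hedged sketch on synthetic audio (not a real wav file) showing the resulting shape:

import numpy as np
from python_speech_features import mfcc

fs = 16000
audio = (np.random.randn(fs) * 1000).astype(np.int16)   # one second of noise standing in for a wav file
numcep, numcontext = 26, 9

features = mfcc(audio, samplerate=fs, numcep=numcep)[::2]            # BiRNN stride = 2
empty_context = np.zeros((numcontext, numcep), dtype=features.dtype)
features = np.concatenate((empty_context, features, empty_context))

print(features.shape)   # (num_strides + 2*numcontext, 26), not the old flattened windows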


@@ -1,12 +1,12 @@
-import pandas
+import numpy as np
 import tensorflow as tf
-from threading import Thread

 from math import ceil
 from six.moves import range
-from util.audio import audiofile_to_input_vector
+from threading import Thread
 from util.gpu import get_available_gpus
-from util.text import ctc_label_dense_to_sparse, text_to_char_array
+from util.text import ctc_label_dense_to_sparse

 class ModelFeeder(object):
     '''
@@ -24,7 +24,7 @@ class ModelFeeder(object):
                  numcontext,
                  alphabet,
                  tower_feeder_count=-1,
-                 threads_per_queue=2):
+                 threads_per_queue=4):

         self.train = train_set
         self.dev = dev_set
@@ -35,7 +35,7 @@ class ModelFeeder(object):
         self.tower_feeder_count = max(len(get_available_gpus()), 1) if tower_feeder_count < 0 else tower_feeder_count
         self.threads_per_queue = threads_per_queue

-        self.ph_x = tf.placeholder(tf.float32, [None, numcep + (2 * numcep * numcontext)])
+        self.ph_x = tf.placeholder(tf.float32, [None, 2*numcontext+1, numcep])
         self.ph_x_length = tf.placeholder(tf.int32, [])
         self.ph_y = tf.placeholder(tf.int32, [None,])
         self.ph_y_length = tf.placeholder(tf.int32, [])
@@ -77,27 +77,19 @@ class ModelFeeder(object):
         '''
         return self._tower_feeders[tower_feeder_index].next_batch()

 class DataSet(object):
     '''
     Represents a collection of audio samples and their respective transcriptions.
     Takes a set of CSV files produced by importers in /bin.
     '''
-    def __init__(self, csvs, batch_size, skip=0, limit=0, ascending=True, next_index=lambda i: i + 1):
+    def __init__(self, data, batch_size, skip=0, limit=0, ascending=True, next_index=lambda i: i + 1):
+        self.data = data
+        self.data.sort_values(by="features_len", ascending=ascending, inplace=True)
         self.batch_size = batch_size
         self.next_index = next_index
-        self.files = None
-        for csv in csvs:
-            file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
-            if self.files is None:
-                self.files = file
-            else:
-                self.files = self.files.append(file)
-        self.files = self.files.sort_values(by="wav_filesize", ascending=ascending) \
-                         .ix[:, ["wav_filename", "transcript"]] \
-                         .values[skip:]
-        if limit > 0:
-            self.files = self.files[:limit]
-        self.total_batches = int(ceil(len(self.files) / batch_size))
+        self.total_batches = int(ceil(len(self.data) / batch_size))
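DataSet now receives the preprocessed DataFrame directly and sorts it by features_len (instead of wav_filesize), so utterances of similar length land in the same padded batch. A toy sketch, with made-up lengths, of what that ordering and batch count look like:

import pandas
from math import ceil

# Made-up rows standing in for the output of preprocess()
data = pandas.DataFrame({'features':       [None]*4,
                         'features_len':   [120, 80, 200, 95],
                         'transcript':     [None]*4,
                         'transcript_len': [10, 7, 15, 9]})

data.sort_values(by='features_len', ascending=True, inplace=True)
total_batches = int(ceil(len(data) / 2))          # batch_size = 2

print(list(data['features_len']), total_batches)  # [80, 95, 120, 200] 2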
class _DataSetLoader(object): class _DataSetLoader(object):
''' '''
@ -109,9 +101,9 @@ class _DataSetLoader(object):
def __init__(self, model_feeder, data_set, alphabet): def __init__(self, model_feeder, data_set, alphabet):
self._model_feeder = model_feeder self._model_feeder = model_feeder
self._data_set = data_set self._data_set = data_set
self.queue = tf.PaddingFIFOQueue(shapes=[[None, model_feeder.numcep + (2 * model_feeder.numcep * model_feeder.numcontext)], [], [None,], []], self.queue = tf.PaddingFIFOQueue(shapes=[[None, 2 * model_feeder.numcontext + 1, model_feeder.numcep], [], [None,], []],
dtypes=[tf.float32, tf.int32, tf.int32, tf.int32], dtypes=[tf.float32, tf.int32, tf.int32, tf.int32],
capacity=data_set.batch_size * 2) capacity=data_set.batch_size * 8)
self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length]) self._enqueue_op = self.queue.enqueue([model_feeder.ph_x, model_feeder.ph_x_length, model_feeder.ph_y, model_feeder.ph_y_length])
self._close_op = self.queue.close(cancel_pending_enqueues=True) self._close_op = self.queue.close(cancel_pending_enqueues=True)
self._alphabet = alphabet self._alphabet = alphabet
@@ -138,25 +130,35 @@ class _DataSetLoader(object):
         '''
         Queue thread routine.
         '''
-        file_count = len(self._data_set.files)
+        file_count = len(self._data_set.data)
         index = -1
         while not coord.should_stop():
             index = self._data_set.next_index(index) % file_count
-            wav_file, transcript = self._data_set.files[index]
-            source = audiofile_to_input_vector(wav_file, self._model_feeder.numcep, self._model_feeder.numcontext)
-            source_len = len(source)
-            target = text_to_char_array(transcript, self._alphabet)
-            target_len = len(target)
-            if source_len < target_len:
-                raise ValueError('Error: Audio file {} is too short for transcription.'.format(wav_file))
+            features, _, transcript, transcript_len = self._data_set.data.iloc[index]
+
+            # One stride per time step in the input
+            num_strides = len(features) - (self._model_feeder.numcontext * 2)
+
+            # Create a view into the array with overlapping strides of size
+            # numcontext (past) + 1 (present) + numcontext (future)
+            window_size = 2*self._model_feeder.numcontext+1
+            features = np.lib.stride_tricks.as_strided(
+                features,
+                (num_strides, window_size, self._model_feeder.numcep),
+                (features.strides[0], features.strides[0], features.strides[1]),
+                writeable=False)
+
             try:
-                session.run(self._enqueue_op, feed_dict={ self._model_feeder.ph_x: source,
-                                                          self._model_feeder.ph_x_length: source_len,
-                                                          self._model_feeder.ph_y: target,
-                                                          self._model_feeder.ph_y_length: target_len })
+                session.run(self._enqueue_op, feed_dict={
+                    self._model_feeder.ph_x: features,
+                    self._model_feeder.ph_x_length: num_strides,
+                    self._model_feeder.ph_y: transcript,
+                    self._model_feeder.ph_y_length: transcript_len
+                })
             except tf.errors.CancelledError:
                 return

 class _TowerFeeder(object):
     '''
     Internal class that represents a switchable input queue for one tower.
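The windowing that audiofile_to_input_vector used to do now happens here, per sample, on the queue thread. A standalone NumPy sketch of that view (synthetic frame count; the real code runs inside the feeder with values taken from the model):

import numpy as np

numcep, numcontext = 26, 9
padded_frames = 100                                     # length of one padded feature matrix
features = np.random.rand(padded_frames, numcep).astype(np.float32)

num_strides = len(features) - 2*numcontext              # one stride per real time step
window_size = 2*numcontext + 1
windows = np.lib.stride_tricks.as_strided(
    features,
    (num_strides, window_size, numcep),
    (features.strides[0], features.strides[0], features.strides[1]),
    writeable=False)

print(windows.shape)   # (82, 19, 26) -> matches the ph_x placeholder [None, 2*numcontext+1, numcep]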

util/preprocess.py (new file, 101 lines)

@@ -0,0 +1,101 @@
+import numpy as np
+import os
+import pandas
+import tables
+
+from functools import partial
+from multiprocessing.dummy import Pool
+from util.audio import audiofile_to_input_vector
+from util.text import text_to_char_array
+
+def pmap(fun, iterable, threads=8):
+    pool = Pool(threads)
+    results = pool.map(fun, iterable)
+    pool.close()
+    return results
+
+def process_single_file(row, numcep, numcontext, alphabet):
+    # row = index, Series
+    _, file = row
+    features = audiofile_to_input_vector(file.wav_filename, numcep, numcontext)
+    transcript = text_to_char_array(file.transcript, alphabet)
+    if (2*numcontext + len(features)) < len(transcript):
+        raise ValueError('Error: Audio file {} is too short for transcription.'.format(file.wav_filename))
+    return features, len(features), transcript, len(transcript)
+
+# load samples from CSV, compute features, optionally cache results on disk
+def preprocess(csv_files, batch_size, numcep, numcontext, alphabet, hdf5_cache_path=None):
+    COLUMNS = ('features', 'features_len', 'transcript', 'transcript_len')
+
+    print('Preprocessing', csv_files)
+
+    if hdf5_cache_path and os.path.exists(hdf5_cache_path):
+        with tables.open_file(hdf5_cache_path, 'r') as file:
+            features = file.root.features[:]
+            features_len = file.root.features_len[:]
+            transcript = file.root.transcript[:]
+            transcript_len = file.root.transcript_len[:]
+
+            # features are stored flattened, so reshape into
+            # [n_steps, (n_input + 2*n_context*n_input)]
+            for i in range(len(features)):
+                features[i] = np.reshape(features[i], [features_len[i], -1])
+
+            in_data = list(zip(features, features_len,
+                               transcript, transcript_len))
+            print('Loaded from cache at', hdf5_cache_path)
+            return pandas.DataFrame(data=in_data, columns=COLUMNS)
+
+    source_data = None
+    for csv in csv_files:
+        file = pandas.read_csv(csv, encoding='utf-8', na_filter=False)
+        #FIXME: not cross-platform
+        csv_dir = os.path.dirname(os.path.abspath(csv))
+        file['wav_filename'] = file['wav_filename'].str.replace(r'(^[^/])', lambda m: os.path.join(csv_dir, m.group(1)))
+        if source_data is None:
+            source_data = file
+        else:
+            source_data = source_data.append(file)
+
+    step_fn = partial(process_single_file,
+                      numcep=numcep,
+                      numcontext=numcontext,
+                      alphabet=alphabet)
+    out_data = pmap(step_fn, source_data.iterrows())
+
+    if hdf5_cache_path:
+        print('Saving to', hdf5_cache_path)
+
+        # list of tuples -> tuple of lists
+        features, features_len, transcript, transcript_len = zip(*out_data)
+
+        with tables.open_file(hdf5_cache_path, 'w') as file:
+            features_dset = file.create_vlarray(file.root,
+                                                'features',
+                                                tables.Float32Atom(),
+                                                filters=tables.Filters(complevel=1))
+            # VLArray atoms need to be 1D, so flatten feature array
+            for f in features:
+                features_dset.append(np.reshape(f, -1))
+            features_len_dset = file.create_array(file.root,
+                                                  'features_len',
+                                                  features_len)
+            transcript_dset = file.create_vlarray(file.root,
+                                                  'transcript',
+                                                  tables.Int32Atom(),
+                                                  filters=tables.Filters(complevel=1))
+            for t in transcript:
+                transcript_dset.append(t)
+            transcript_len_dset = file.create_array(file.root,
+                                                    'transcript_len',
+                                                    transcript_len)
+
+    print('Preprocessing done')
+    return pandas.DataFrame(data=out_data, columns=COLUMNS)
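Finally, a hedged sketch of what the HDF5 cache produced above contains (the path is hypothetical): features are stored as flattened variable-length float rows alongside their lengths, which is why the loading branch reshapes each row back to [n_frames, numcep].

import numpy as np
import tables

with tables.open_file('/tmp/train_cache.hdf5', 'r') as f:    # hypothetical cache written by preprocess()
    first_len = f.root.features_len[0]
    first = np.reshape(f.root.features[0], [first_len, -1])   # back to [n_frames, numcep]
    print(first.shape, f.root.transcript_len[0])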